skypilot-nightly 1.0.0.dev20250502__py3-none-any.whl → 1.0.0.dev20251203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546) hide show
  1. sky/__init__.py +22 -6
  2. sky/adaptors/aws.py +81 -16
  3. sky/adaptors/common.py +25 -2
  4. sky/adaptors/coreweave.py +278 -0
  5. sky/adaptors/do.py +8 -2
  6. sky/adaptors/gcp.py +11 -0
  7. sky/adaptors/hyperbolic.py +8 -0
  8. sky/adaptors/ibm.py +5 -2
  9. sky/adaptors/kubernetes.py +149 -18
  10. sky/adaptors/nebius.py +173 -30
  11. sky/adaptors/primeintellect.py +1 -0
  12. sky/adaptors/runpod.py +68 -0
  13. sky/adaptors/seeweb.py +183 -0
  14. sky/adaptors/shadeform.py +89 -0
  15. sky/admin_policy.py +187 -4
  16. sky/authentication.py +179 -225
  17. sky/backends/__init__.py +4 -2
  18. sky/backends/backend.py +22 -9
  19. sky/backends/backend_utils.py +1323 -397
  20. sky/backends/cloud_vm_ray_backend.py +1749 -1029
  21. sky/backends/docker_utils.py +1 -1
  22. sky/backends/local_docker_backend.py +11 -6
  23. sky/backends/task_codegen.py +633 -0
  24. sky/backends/wheel_utils.py +55 -9
  25. sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
  26. sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
  27. sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
  28. sky/{clouds/service_catalog → catalog}/common.py +90 -49
  29. sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
  30. sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
  31. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +116 -80
  32. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
  33. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +70 -16
  34. sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
  35. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
  36. sky/catalog/data_fetchers/fetch_nebius.py +338 -0
  37. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  38. sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
  39. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  40. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
  41. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
  42. sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
  43. sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
  44. sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
  45. sky/catalog/hyperbolic_catalog.py +136 -0
  46. sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
  47. sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
  48. sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
  49. sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
  50. sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
  51. sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
  52. sky/catalog/primeintellect_catalog.py +95 -0
  53. sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
  54. sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
  55. sky/catalog/seeweb_catalog.py +184 -0
  56. sky/catalog/shadeform_catalog.py +165 -0
  57. sky/catalog/ssh_catalog.py +167 -0
  58. sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
  59. sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
  60. sky/check.py +533 -185
  61. sky/cli.py +5 -5975
  62. sky/client/{cli.py → cli/command.py} +2591 -1956
  63. sky/client/cli/deprecation_utils.py +99 -0
  64. sky/client/cli/flags.py +359 -0
  65. sky/client/cli/table_utils.py +322 -0
  66. sky/client/cli/utils.py +79 -0
  67. sky/client/common.py +78 -32
  68. sky/client/oauth.py +82 -0
  69. sky/client/sdk.py +1219 -319
  70. sky/client/sdk_async.py +827 -0
  71. sky/client/service_account_auth.py +47 -0
  72. sky/cloud_stores.py +82 -3
  73. sky/clouds/__init__.py +13 -0
  74. sky/clouds/aws.py +564 -164
  75. sky/clouds/azure.py +105 -83
  76. sky/clouds/cloud.py +140 -40
  77. sky/clouds/cudo.py +68 -50
  78. sky/clouds/do.py +66 -48
  79. sky/clouds/fluidstack.py +63 -44
  80. sky/clouds/gcp.py +339 -110
  81. sky/clouds/hyperbolic.py +293 -0
  82. sky/clouds/ibm.py +70 -49
  83. sky/clouds/kubernetes.py +570 -162
  84. sky/clouds/lambda_cloud.py +74 -54
  85. sky/clouds/nebius.py +210 -81
  86. sky/clouds/oci.py +88 -66
  87. sky/clouds/paperspace.py +61 -44
  88. sky/clouds/primeintellect.py +317 -0
  89. sky/clouds/runpod.py +164 -74
  90. sky/clouds/scp.py +89 -86
  91. sky/clouds/seeweb.py +477 -0
  92. sky/clouds/shadeform.py +400 -0
  93. sky/clouds/ssh.py +263 -0
  94. sky/clouds/utils/aws_utils.py +10 -4
  95. sky/clouds/utils/gcp_utils.py +87 -11
  96. sky/clouds/utils/oci_utils.py +38 -14
  97. sky/clouds/utils/scp_utils.py +231 -167
  98. sky/clouds/vast.py +99 -77
  99. sky/clouds/vsphere.py +51 -40
  100. sky/core.py +375 -173
  101. sky/dag.py +15 -0
  102. sky/dashboard/out/404.html +1 -1
  103. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +1 -0
  104. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  105. sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
  106. sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +6 -0
  107. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  108. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
  109. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
  110. sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +26 -0
  111. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +1 -0
  112. sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +1 -0
  113. sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +1 -0
  114. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
  115. sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
  116. sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +1 -0
  117. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
  118. sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
  119. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  120. sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
  121. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
  122. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +1 -0
  123. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
  124. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +1 -0
  125. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
  126. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
  127. sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +1 -0
  128. sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
  129. sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +1 -0
  130. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  131. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  132. sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +31 -0
  133. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
  134. sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
  135. sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
  136. sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
  137. sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
  138. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
  139. sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
  140. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +16 -0
  141. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +1 -0
  142. sky/dashboard/out/_next/static/chunks/pages/clusters-ee39056f9851a3ff.js +1 -0
  143. sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
  144. sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
  145. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
  146. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
  147. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +16 -0
  148. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +21 -0
  149. sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +1 -0
  150. sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +1 -0
  151. sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
  152. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  153. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-84a40f8c7c627fe4.js +1 -0
  154. sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +1 -0
  155. sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +1 -0
  156. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  157. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  158. sky/dashboard/out/clusters/[cluster].html +1 -1
  159. sky/dashboard/out/clusters.html +1 -1
  160. sky/dashboard/out/config.html +1 -0
  161. sky/dashboard/out/index.html +1 -1
  162. sky/dashboard/out/infra/[context].html +1 -0
  163. sky/dashboard/out/infra.html +1 -0
  164. sky/dashboard/out/jobs/[job].html +1 -1
  165. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  166. sky/dashboard/out/jobs.html +1 -1
  167. sky/dashboard/out/users.html +1 -0
  168. sky/dashboard/out/volumes.html +1 -0
  169. sky/dashboard/out/workspace/new.html +1 -0
  170. sky/dashboard/out/workspaces/[name].html +1 -0
  171. sky/dashboard/out/workspaces.html +1 -0
  172. sky/data/data_utils.py +137 -1
  173. sky/data/mounting_utils.py +269 -84
  174. sky/data/storage.py +1460 -1807
  175. sky/data/storage_utils.py +43 -57
  176. sky/exceptions.py +126 -2
  177. sky/execution.py +216 -63
  178. sky/global_user_state.py +2390 -586
  179. sky/jobs/__init__.py +7 -0
  180. sky/jobs/client/sdk.py +300 -58
  181. sky/jobs/client/sdk_async.py +161 -0
  182. sky/jobs/constants.py +15 -8
  183. sky/jobs/controller.py +848 -275
  184. sky/jobs/file_content_utils.py +128 -0
  185. sky/jobs/log_gc.py +193 -0
  186. sky/jobs/recovery_strategy.py +402 -152
  187. sky/jobs/scheduler.py +314 -189
  188. sky/jobs/server/core.py +836 -255
  189. sky/jobs/server/server.py +156 -115
  190. sky/jobs/server/utils.py +136 -0
  191. sky/jobs/state.py +2109 -706
  192. sky/jobs/utils.py +1306 -215
  193. sky/logs/__init__.py +21 -0
  194. sky/logs/agent.py +108 -0
  195. sky/logs/aws.py +243 -0
  196. sky/logs/gcp.py +91 -0
  197. sky/metrics/__init__.py +0 -0
  198. sky/metrics/utils.py +453 -0
  199. sky/models.py +78 -1
  200. sky/optimizer.py +164 -70
  201. sky/provision/__init__.py +90 -4
  202. sky/provision/aws/config.py +147 -26
  203. sky/provision/aws/instance.py +136 -50
  204. sky/provision/azure/instance.py +11 -6
  205. sky/provision/common.py +13 -1
  206. sky/provision/cudo/cudo_machine_type.py +1 -1
  207. sky/provision/cudo/cudo_utils.py +14 -8
  208. sky/provision/cudo/cudo_wrapper.py +72 -71
  209. sky/provision/cudo/instance.py +10 -6
  210. sky/provision/do/instance.py +10 -6
  211. sky/provision/do/utils.py +4 -3
  212. sky/provision/docker_utils.py +140 -33
  213. sky/provision/fluidstack/instance.py +13 -8
  214. sky/provision/gcp/__init__.py +1 -0
  215. sky/provision/gcp/config.py +301 -19
  216. sky/provision/gcp/constants.py +218 -0
  217. sky/provision/gcp/instance.py +36 -8
  218. sky/provision/gcp/instance_utils.py +18 -4
  219. sky/provision/gcp/volume_utils.py +247 -0
  220. sky/provision/hyperbolic/__init__.py +12 -0
  221. sky/provision/hyperbolic/config.py +10 -0
  222. sky/provision/hyperbolic/instance.py +437 -0
  223. sky/provision/hyperbolic/utils.py +373 -0
  224. sky/provision/instance_setup.py +101 -20
  225. sky/provision/kubernetes/__init__.py +5 -0
  226. sky/provision/kubernetes/config.py +9 -52
  227. sky/provision/kubernetes/constants.py +17 -0
  228. sky/provision/kubernetes/instance.py +919 -280
  229. sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
  230. sky/provision/kubernetes/network.py +27 -17
  231. sky/provision/kubernetes/network_utils.py +44 -43
  232. sky/provision/kubernetes/utils.py +1221 -534
  233. sky/provision/kubernetes/volume.py +343 -0
  234. sky/provision/lambda_cloud/instance.py +22 -16
  235. sky/provision/nebius/constants.py +50 -0
  236. sky/provision/nebius/instance.py +19 -6
  237. sky/provision/nebius/utils.py +237 -137
  238. sky/provision/oci/instance.py +10 -5
  239. sky/provision/paperspace/instance.py +10 -7
  240. sky/provision/paperspace/utils.py +1 -1
  241. sky/provision/primeintellect/__init__.py +10 -0
  242. sky/provision/primeintellect/config.py +11 -0
  243. sky/provision/primeintellect/instance.py +454 -0
  244. sky/provision/primeintellect/utils.py +398 -0
  245. sky/provision/provisioner.py +117 -36
  246. sky/provision/runpod/__init__.py +5 -0
  247. sky/provision/runpod/instance.py +27 -6
  248. sky/provision/runpod/utils.py +51 -18
  249. sky/provision/runpod/volume.py +214 -0
  250. sky/provision/scp/__init__.py +15 -0
  251. sky/provision/scp/config.py +93 -0
  252. sky/provision/scp/instance.py +707 -0
  253. sky/provision/seeweb/__init__.py +11 -0
  254. sky/provision/seeweb/config.py +13 -0
  255. sky/provision/seeweb/instance.py +812 -0
  256. sky/provision/shadeform/__init__.py +11 -0
  257. sky/provision/shadeform/config.py +12 -0
  258. sky/provision/shadeform/instance.py +351 -0
  259. sky/provision/shadeform/shadeform_utils.py +83 -0
  260. sky/provision/ssh/__init__.py +18 -0
  261. sky/provision/vast/instance.py +13 -8
  262. sky/provision/vast/utils.py +10 -7
  263. sky/provision/volume.py +164 -0
  264. sky/provision/vsphere/common/ssl_helper.py +1 -1
  265. sky/provision/vsphere/common/vapiconnect.py +2 -1
  266. sky/provision/vsphere/common/vim_utils.py +4 -4
  267. sky/provision/vsphere/instance.py +15 -10
  268. sky/provision/vsphere/vsphere_utils.py +17 -20
  269. sky/py.typed +0 -0
  270. sky/resources.py +845 -119
  271. sky/schemas/__init__.py +0 -0
  272. sky/schemas/api/__init__.py +0 -0
  273. sky/schemas/api/responses.py +227 -0
  274. sky/schemas/db/README +4 -0
  275. sky/schemas/db/env.py +90 -0
  276. sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
  277. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  278. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  279. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  280. sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
  281. sky/schemas/db/global_user_state/006_provision_log.py +41 -0
  282. sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
  283. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  284. sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
  285. sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
  286. sky/schemas/db/global_user_state/011_is_ephemeral.py +34 -0
  287. sky/schemas/db/kv_cache/001_initial_schema.py +29 -0
  288. sky/schemas/db/script.py.mako +28 -0
  289. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  290. sky/schemas/db/serve_state/002_yaml_content.py +34 -0
  291. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  292. sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
  293. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  294. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  295. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  296. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  297. sky/schemas/db/spot_jobs/006_controller_pid_started_at.py +34 -0
  298. sky/schemas/db/spot_jobs/007_config_file_content.py +34 -0
  299. sky/schemas/generated/__init__.py +0 -0
  300. sky/schemas/generated/autostopv1_pb2.py +36 -0
  301. sky/schemas/generated/autostopv1_pb2.pyi +43 -0
  302. sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
  303. sky/schemas/generated/jobsv1_pb2.py +86 -0
  304. sky/schemas/generated/jobsv1_pb2.pyi +254 -0
  305. sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
  306. sky/schemas/generated/managed_jobsv1_pb2.py +76 -0
  307. sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
  308. sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
  309. sky/schemas/generated/servev1_pb2.py +58 -0
  310. sky/schemas/generated/servev1_pb2.pyi +115 -0
  311. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  312. sky/serve/autoscalers.py +357 -5
  313. sky/serve/client/impl.py +310 -0
  314. sky/serve/client/sdk.py +47 -139
  315. sky/serve/client/sdk_async.py +130 -0
  316. sky/serve/constants.py +12 -9
  317. sky/serve/controller.py +68 -17
  318. sky/serve/load_balancer.py +106 -60
  319. sky/serve/load_balancing_policies.py +116 -2
  320. sky/serve/replica_managers.py +434 -249
  321. sky/serve/serve_rpc_utils.py +179 -0
  322. sky/serve/serve_state.py +569 -257
  323. sky/serve/serve_utils.py +775 -265
  324. sky/serve/server/core.py +66 -711
  325. sky/serve/server/impl.py +1093 -0
  326. sky/serve/server/server.py +21 -18
  327. sky/serve/service.py +192 -89
  328. sky/serve/service_spec.py +144 -20
  329. sky/serve/spot_placer.py +3 -0
  330. sky/server/auth/__init__.py +0 -0
  331. sky/server/auth/authn.py +50 -0
  332. sky/server/auth/loopback.py +38 -0
  333. sky/server/auth/oauth2_proxy.py +202 -0
  334. sky/server/common.py +478 -182
  335. sky/server/config.py +85 -23
  336. sky/server/constants.py +44 -6
  337. sky/server/daemons.py +295 -0
  338. sky/server/html/token_page.html +185 -0
  339. sky/server/metrics.py +160 -0
  340. sky/server/middleware_utils.py +166 -0
  341. sky/server/requests/executor.py +558 -138
  342. sky/server/requests/payloads.py +364 -24
  343. sky/server/requests/preconditions.py +21 -17
  344. sky/server/requests/process.py +112 -29
  345. sky/server/requests/request_names.py +121 -0
  346. sky/server/requests/requests.py +822 -226
  347. sky/server/requests/serializers/decoders.py +82 -31
  348. sky/server/requests/serializers/encoders.py +140 -22
  349. sky/server/requests/threads.py +117 -0
  350. sky/server/rest.py +455 -0
  351. sky/server/server.py +1309 -285
  352. sky/server/state.py +20 -0
  353. sky/server/stream_utils.py +327 -61
  354. sky/server/uvicorn.py +217 -3
  355. sky/server/versions.py +270 -0
  356. sky/setup_files/MANIFEST.in +11 -1
  357. sky/setup_files/alembic.ini +160 -0
  358. sky/setup_files/dependencies.py +139 -31
  359. sky/setup_files/setup.py +44 -42
  360. sky/sky_logging.py +114 -7
  361. sky/skylet/attempt_skylet.py +106 -24
  362. sky/skylet/autostop_lib.py +129 -8
  363. sky/skylet/configs.py +29 -20
  364. sky/skylet/constants.py +216 -25
  365. sky/skylet/events.py +101 -21
  366. sky/skylet/job_lib.py +345 -164
  367. sky/skylet/log_lib.py +297 -18
  368. sky/skylet/log_lib.pyi +44 -1
  369. sky/skylet/providers/ibm/node_provider.py +12 -8
  370. sky/skylet/providers/ibm/vpc_provider.py +13 -12
  371. sky/skylet/ray_patches/__init__.py +17 -3
  372. sky/skylet/ray_patches/autoscaler.py.diff +18 -0
  373. sky/skylet/ray_patches/cli.py.diff +19 -0
  374. sky/skylet/ray_patches/command_runner.py.diff +17 -0
  375. sky/skylet/ray_patches/log_monitor.py.diff +20 -0
  376. sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
  377. sky/skylet/ray_patches/updater.py.diff +18 -0
  378. sky/skylet/ray_patches/worker.py.diff +41 -0
  379. sky/skylet/runtime_utils.py +21 -0
  380. sky/skylet/services.py +568 -0
  381. sky/skylet/skylet.py +72 -4
  382. sky/skylet/subprocess_daemon.py +104 -29
  383. sky/skypilot_config.py +506 -99
  384. sky/ssh_node_pools/__init__.py +1 -0
  385. sky/ssh_node_pools/core.py +135 -0
  386. sky/ssh_node_pools/server.py +233 -0
  387. sky/task.py +685 -163
  388. sky/templates/aws-ray.yml.j2 +11 -3
  389. sky/templates/azure-ray.yml.j2 +2 -1
  390. sky/templates/cudo-ray.yml.j2 +1 -0
  391. sky/templates/do-ray.yml.j2 +2 -1
  392. sky/templates/fluidstack-ray.yml.j2 +1 -0
  393. sky/templates/gcp-ray.yml.j2 +62 -1
  394. sky/templates/hyperbolic-ray.yml.j2 +68 -0
  395. sky/templates/ibm-ray.yml.j2 +2 -1
  396. sky/templates/jobs-controller.yaml.j2 +27 -24
  397. sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
  398. sky/templates/kubernetes-ray.yml.j2 +611 -50
  399. sky/templates/lambda-ray.yml.j2 +2 -1
  400. sky/templates/nebius-ray.yml.j2 +34 -12
  401. sky/templates/oci-ray.yml.j2 +1 -0
  402. sky/templates/paperspace-ray.yml.j2 +2 -1
  403. sky/templates/primeintellect-ray.yml.j2 +72 -0
  404. sky/templates/runpod-ray.yml.j2 +10 -1
  405. sky/templates/scp-ray.yml.j2 +4 -50
  406. sky/templates/seeweb-ray.yml.j2 +171 -0
  407. sky/templates/shadeform-ray.yml.j2 +73 -0
  408. sky/templates/sky-serve-controller.yaml.j2 +22 -2
  409. sky/templates/vast-ray.yml.j2 +1 -0
  410. sky/templates/vsphere-ray.yml.j2 +1 -0
  411. sky/templates/websocket_proxy.py +212 -37
  412. sky/usage/usage_lib.py +31 -15
  413. sky/users/__init__.py +0 -0
  414. sky/users/model.conf +15 -0
  415. sky/users/permission.py +397 -0
  416. sky/users/rbac.py +121 -0
  417. sky/users/server.py +720 -0
  418. sky/users/token_service.py +218 -0
  419. sky/utils/accelerator_registry.py +35 -5
  420. sky/utils/admin_policy_utils.py +84 -38
  421. sky/utils/annotations.py +38 -5
  422. sky/utils/asyncio_utils.py +78 -0
  423. sky/utils/atomic.py +1 -1
  424. sky/utils/auth_utils.py +153 -0
  425. sky/utils/benchmark_utils.py +60 -0
  426. sky/utils/cli_utils/status_utils.py +159 -86
  427. sky/utils/cluster_utils.py +31 -9
  428. sky/utils/command_runner.py +354 -68
  429. sky/utils/command_runner.pyi +93 -3
  430. sky/utils/common.py +35 -8
  431. sky/utils/common_utils.py +314 -91
  432. sky/utils/config_utils.py +74 -5
  433. sky/utils/context.py +403 -0
  434. sky/utils/context_utils.py +242 -0
  435. sky/utils/controller_utils.py +383 -89
  436. sky/utils/dag_utils.py +31 -12
  437. sky/utils/db/__init__.py +0 -0
  438. sky/utils/db/db_utils.py +485 -0
  439. sky/utils/db/kv_cache.py +149 -0
  440. sky/utils/db/migration_utils.py +137 -0
  441. sky/utils/directory_utils.py +12 -0
  442. sky/utils/env_options.py +13 -0
  443. sky/utils/git.py +567 -0
  444. sky/utils/git_clone.sh +460 -0
  445. sky/utils/infra_utils.py +195 -0
  446. sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
  447. sky/utils/kubernetes/config_map_utils.py +133 -0
  448. sky/utils/kubernetes/create_cluster.sh +15 -29
  449. sky/utils/kubernetes/delete_cluster.sh +10 -7
  450. sky/utils/kubernetes/deploy_ssh_node_pools.py +1177 -0
  451. sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
  452. sky/utils/kubernetes/generate_kind_config.py +6 -66
  453. sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
  454. sky/utils/kubernetes/gpu_labeler.py +18 -8
  455. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +2 -1
  456. sky/utils/kubernetes/k8s_gpu_labeler_setup.yaml +16 -16
  457. sky/utils/kubernetes/kubernetes_deploy_utils.py +284 -114
  458. sky/utils/kubernetes/rsync_helper.sh +11 -3
  459. sky/utils/kubernetes/ssh-tunnel.sh +379 -0
  460. sky/utils/kubernetes/ssh_utils.py +221 -0
  461. sky/utils/kubernetes_enums.py +8 -15
  462. sky/utils/lock_events.py +94 -0
  463. sky/utils/locks.py +416 -0
  464. sky/utils/log_utils.py +82 -107
  465. sky/utils/perf_utils.py +22 -0
  466. sky/utils/resource_checker.py +298 -0
  467. sky/utils/resources_utils.py +249 -32
  468. sky/utils/rich_utils.py +217 -39
  469. sky/utils/schemas.py +955 -160
  470. sky/utils/serialize_utils.py +16 -0
  471. sky/utils/status_lib.py +10 -0
  472. sky/utils/subprocess_utils.py +29 -15
  473. sky/utils/tempstore.py +70 -0
  474. sky/utils/thread_utils.py +91 -0
  475. sky/utils/timeline.py +26 -53
  476. sky/utils/ux_utils.py +84 -15
  477. sky/utils/validator.py +11 -1
  478. sky/utils/volume.py +165 -0
  479. sky/utils/yaml_utils.py +111 -0
  480. sky/volumes/__init__.py +13 -0
  481. sky/volumes/client/__init__.py +0 -0
  482. sky/volumes/client/sdk.py +150 -0
  483. sky/volumes/server/__init__.py +0 -0
  484. sky/volumes/server/core.py +270 -0
  485. sky/volumes/server/server.py +124 -0
  486. sky/volumes/volume.py +215 -0
  487. sky/workspaces/__init__.py +0 -0
  488. sky/workspaces/core.py +655 -0
  489. sky/workspaces/server.py +101 -0
  490. sky/workspaces/utils.py +56 -0
  491. sky_templates/README.md +3 -0
  492. sky_templates/__init__.py +3 -0
  493. sky_templates/ray/__init__.py +0 -0
  494. sky_templates/ray/start_cluster +183 -0
  495. sky_templates/ray/stop_cluster +75 -0
  496. skypilot_nightly-1.0.0.dev20251203.dist-info/METADATA +676 -0
  497. skypilot_nightly-1.0.0.dev20251203.dist-info/RECORD +611 -0
  498. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/WHEEL +1 -1
  499. skypilot_nightly-1.0.0.dev20251203.dist-info/top_level.txt +2 -0
  500. sky/benchmark/benchmark_state.py +0 -256
  501. sky/benchmark/benchmark_utils.py +0 -641
  502. sky/clouds/service_catalog/constants.py +0 -7
  503. sky/dashboard/out/_next/static/GWvVBSCS7FmUiVmjaL1a7/_buildManifest.js +0 -1
  504. sky/dashboard/out/_next/static/chunks/236-2db3ee3fba33dd9e.js +0 -6
  505. sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
  506. sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
  507. sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
  508. sky/dashboard/out/_next/static/chunks/845-9e60713e0c441abc.js +0 -1
  509. sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
  510. sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
  511. sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
  512. sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
  513. sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
  514. sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
  515. sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
  516. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-6ac338bc2239cb45.js +0 -1
  517. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
  518. sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
  519. sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
  520. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-1c519e1afc523dc9.js +0 -1
  521. sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
  522. sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
  523. sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
  524. sky/jobs/dashboard/dashboard.py +0 -223
  525. sky/jobs/dashboard/static/favicon.ico +0 -0
  526. sky/jobs/dashboard/templates/index.html +0 -831
  527. sky/jobs/server/dashboard_utils.py +0 -69
  528. sky/skylet/providers/scp/__init__.py +0 -2
  529. sky/skylet/providers/scp/config.py +0 -149
  530. sky/skylet/providers/scp/node_provider.py +0 -578
  531. sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
  532. sky/utils/db_utils.py +0 -100
  533. sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
  534. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
  535. skypilot_nightly-1.0.0.dev20250502.dist-info/METADATA +0 -361
  536. skypilot_nightly-1.0.0.dev20250502.dist-info/RECORD +0 -396
  537. skypilot_nightly-1.0.0.dev20250502.dist-info/top_level.txt +0 -1
  538. /sky/{clouds/service_catalog → catalog}/config.py +0 -0
  539. /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
  540. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
  541. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
  542. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
  543. /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
  544. /sky/dashboard/out/_next/static/{GWvVBSCS7FmUiVmjaL1a7 → 96_E2yl3QAiIJGOYCkSpB}/_ssgManifest.js +0 -0
  545. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/entry_points.txt +0 -0
  546. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,11 @@
1
1
  """Kubernetes utilities for SkyPilot."""
2
+ import collections
3
+ import copy
2
4
  import dataclasses
5
+ import datetime
6
+ import enum
3
7
  import functools
8
+ import hashlib
4
9
  import json
5
10
  import math
6
11
  import os
@@ -9,12 +14,14 @@ import shutil
9
14
  import subprocess
10
15
  import time
11
16
  import typing
12
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
13
- from urllib.parse import urlparse
17
+ from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
18
+ Union)
19
+
20
+ import ijson
14
21
 
15
- import sky
16
22
  from sky import clouds
17
23
  from sky import exceptions
24
+ from sky import global_user_state
18
25
  from sky import models
19
26
  from sky import sky_logging
20
27
  from sky import skypilot_config
@@ -34,6 +41,7 @@ from sky.utils import schemas
34
41
  from sky.utils import status_lib
35
42
  from sky.utils import timeline
36
43
  from sky.utils import ux_utils
44
+ from sky.utils import yaml_utils
37
45
 
38
46
  if typing.TYPE_CHECKING:
39
47
  import jinja2
@@ -55,6 +63,81 @@ HIGH_AVAILABILITY_DEPLOYMENT_VOLUME_MOUNT_NAME = 'sky-data'
55
63
  # and store all data that needs to be persisted in future.
56
64
  HIGH_AVAILABILITY_DEPLOYMENT_VOLUME_MOUNT_PATH = '/home/sky'
57
65
 
66
+ IJSON_BUFFER_SIZE = 64 * 1024 # 64KB, default from ijson
67
+
68
+
69
class KubernetesHighPerformanceNetworkType(enum.Enum):
    """Kubernetes cluster flavors with high-performance networking.

    Members and the networking feature each one represents:
    - GCP_TCPX: GKE clusters with GPUDirect-TCPX support
      (A3 High instances: a3-highgpu-8g)
    - GCP_TCPXO: GKE clusters with GPUDirect-TCPXO support
      (A3 Mega instances: a3-megagpu-8g)
    - GCP_GPUDIRECT_RDMA: GKE clusters with GPUDirect-RDMA support
      (A4/A3 Ultra instances)
    - NEBIUS: Nebius clusters with InfiniBand support for high-throughput,
      low-latency networking
    - COREWEAVE: CoreWeave clusters with InfiniBand support.
    - NONE: Standard clusters without specialized networking optimizations

    The network configurations align with corresponding VM-based
    implementations:
    - GCP settings match
      sky.provision.gcp.constants.GPU_DIRECT_TCPX_SPECIFIC_OPTIONS
    - Nebius settings match the InfiniBand configuration used in Nebius VMs
    """

    GCP_TCPX = 'gcp_tcpx'
    GCP_TCPXO = 'gcp_tcpxo'
    GCP_GPUDIRECT_RDMA = 'gcp_gpudirect_rdma'
    NEBIUS = 'nebius'
    COREWEAVE = 'coreweave'
    NONE = 'none'

    def get_network_env_vars(self) -> Dict[str, str]:
        """Return the NCCL/UCX environment variables for this cluster type.

        Only Nebius and CoreWeave need explicit env vars here; GCP and
        generic clusters have their settings injected directly by the
        pod template, so they get an empty dict.
        """
        cls = type(self)
        if self is cls.NEBIUS:
            # InfiniBand-enabled Nebius cluster: route NCCL/UCX over the
            # mlx5 HCAs.
            return {
                'NCCL_IB_HCA': 'mlx5',
                'UCX_NET_DEVICES': ('mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,'
                                    'mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1')
            }
        if self is cls.COREWEAVE:
            return {
                'NCCL_SOCKET_IFNAME': 'eth0',
                'NCCL_IB_HCA': 'ibp',
                # Restrict UCX to TCP to avoid unnecessary errors; NCCL
                # does not use UCX.
                'UCX_TLS': 'tcp',
                'UCX_NET_DEVICES': 'eth0',
            }
        # GCP clusters and generic clusters: env vars handled in template.
        return {}

    def supports_high_performance_networking(self) -> bool:
        """True for every member except NONE."""
        return self is not KubernetesHighPerformanceNetworkType.NONE

    def supports_gpu_direct(self) -> bool:
        """True only for the GCP GPUDirect variants."""
        cls = type(self)
        return self in (cls.GCP_TCPX, cls.GCP_TCPXO, cls.GCP_GPUDIRECT_RDMA)

    def requires_ipc_lock_capability(self) -> bool:
        """IPC_LOCK is needed whenever high-perf networking is enabled."""
        return self.supports_high_performance_networking()

    def requires_tcpxo_daemon(self) -> bool:
        """Only the TCPXO flavor needs the sidecar daemon."""
        return self is KubernetesHighPerformanceNetworkType.GCP_TCPXO
139
+
140
+
58
141
  # TODO(romilb): Move constants to constants.py
59
142
  DEFAULT_NAMESPACE = 'default'
60
143
 
@@ -72,12 +155,14 @@ MEMORY_SIZE_UNITS = {
72
155
  # The resource keys used by Kubernetes to track NVIDIA GPUs and Google TPUs on
73
156
  # nodes. These keys are typically used in the node's status.allocatable
74
157
  # or status.capacity fields to indicate the available resources on the node.
75
- GPU_RESOURCE_KEY = 'nvidia.com/gpu'
158
+ SUPPORTED_GPU_RESOURCE_KEYS = {'amd': 'amd.com/gpu', 'nvidia': 'nvidia.com/gpu'}
76
159
  TPU_RESOURCE_KEY = 'google.com/tpu'
77
160
 
78
161
  NO_ACCELERATOR_HELP_MESSAGE = (
79
162
  'If your cluster contains GPUs or TPUs, make sure '
80
- f'{GPU_RESOURCE_KEY} or {TPU_RESOURCE_KEY} resource is available '
163
+ f'one of {SUPPORTED_GPU_RESOURCE_KEYS["amd"]}, '
164
+ f'{SUPPORTED_GPU_RESOURCE_KEYS["nvidia"]} or '
165
+ f'{TPU_RESOURCE_KEY} resource is available '
81
166
  'on the nodes and the node labels for identifying GPUs/TPUs '
82
167
  '(e.g., skypilot.co/accelerator) are setup correctly. ')
83
168
 
@@ -131,6 +216,64 @@ DEFAULT_MAX_RETRIES = 3
131
216
  DEFAULT_RETRY_INTERVAL_SECONDS = 1
132
217
 
133
218
 
219
+ def normalize_tpu_accelerator_name(accelerator: str) -> Tuple[str, int]:
220
+ """Normalize TPU names to the k8s-compatible name and extract count."""
221
+ # Examples:
222
+ # 'tpu-v6e-8' -> ('tpu-v6e-slice', 8)
223
+ # 'tpu-v5litepod-4' -> ('tpu-v5-lite-podslice', 4)
224
+
225
+ gcp_to_k8s_patterns = [
226
+ (r'^tpu-v6e-(\d+)$', 'tpu-v6e-slice'),
227
+ (r'^tpu-v5p-(\d+)$', 'tpu-v5p-slice'),
228
+ (r'^tpu-v5litepod-(\d+)$', 'tpu-v5-lite-podslice'),
229
+ (r'^tpu-v5lite-(\d+)$', 'tpu-v5-lite-device'),
230
+ (r'^tpu-v4-(\d+)$', 'tpu-v4-podslice'),
231
+ ]
232
+
233
+ for pattern, replacement in gcp_to_k8s_patterns:
234
+ match = re.match(pattern, accelerator)
235
+ if match:
236
+ count = int(match.group(1))
237
+ return replacement, count
238
+
239
+ # Default fallback
240
+ return accelerator, 1
241
+
242
+
243
+ def _is_cloudflare_403_error(exception: Exception) -> bool:
244
+ """Check if an exception is a transient CloudFlare 403 error.
245
+
246
+ CloudFlare proxy 403 errors with CF-specific headers are transient and
247
+ should be retried, unlike real RBAC 403 errors.
248
+
249
+ Args:
250
+ exception: The exception to check
251
+
252
+ Returns:
253
+ True if this is a CloudFlare 403 error that should be retried
254
+ """
255
+ if not isinstance(exception, kubernetes.api_exception()):
256
+ return False
257
+
258
+ # Only check for 403 errors
259
+ if exception.status != 403:
260
+ return False
261
+
262
+ # Check for CloudFlare-specific headers
263
+ headers = exception.headers if hasattr(exception, 'headers') else {}
264
+ if not headers:
265
+ return False
266
+
267
+ # CloudFlare errors have CF-RAY header and/or Server: cloudflare
268
+ for k, v in headers.items():
269
+ if 'cf-ray' in k.lower():
270
+ return True
271
+ if 'server' in k.lower() and 'cloudflare' in str(v).lower():
272
+ return True
273
+
274
+ return False
275
+
276
+
134
277
  def _retry_on_error(max_retries=DEFAULT_MAX_RETRIES,
135
278
  retry_interval=DEFAULT_RETRY_INTERVAL_SECONDS,
136
279
  resource_type: Optional[str] = None):
@@ -165,19 +308,25 @@ def _retry_on_error(max_retries=DEFAULT_MAX_RETRIES,
165
308
  kubernetes.api_exception(),
166
309
  kubernetes.config_exception()) as e:
167
310
  last_exception = e
311
+
312
+ # Check if this is a CloudFlare transient 403 error
313
+ is_cloudflare_403 = _is_cloudflare_403_error(e)
314
+
168
315
  # Don't retry on permanent errors like 401 (Unauthorized)
169
- # or 403 (Forbidden)
316
+ # or 403 (Forbidden), unless it's a CloudFlare transient 403
170
317
  if (isinstance(e, kubernetes.api_exception()) and
171
- e.status in (401, 403)):
318
+ e.status in (401, 403) and not is_cloudflare_403):
172
319
  # Raise KubeAPIUnreachableError exception so that the
173
320
  # optimizer/provisioner can failover to other clouds.
174
321
  raise exceptions.KubeAPIUnreachableError(
175
322
  f'Kubernetes API error: {str(e)}') from e
176
323
  if attempt < max_retries - 1:
177
324
  sleep_time = backoff.current_backoff()
178
- logger.debug(f'Kubernetes API call {func.__name__} '
179
- f'failed with {str(e)}. Retrying in '
180
- f'{sleep_time:.1f}s...')
325
+ error_type = 'CloudFlare 403' if is_cloudflare_403 else 'error'
326
+ logger.debug(
327
+ f'Kubernetes API call {func.__name__} '
328
+ f'failed with {error_type} {str(e)}. Retrying in '
329
+ f'{sleep_time:.1f}s...')
181
330
  time.sleep(sleep_time)
182
331
  continue
183
332
 
@@ -287,8 +436,13 @@ def get_gke_accelerator_name(accelerator: str) -> str:
287
436
  # A100-80GB, L4, H100-80GB and H100-MEGA-80GB
288
437
  # have a different name pattern.
289
438
  return 'nvidia-{}'.format(accelerator.lower())
439
+ elif accelerator == 'H200':
440
+ # H200s on GCP use this label format
441
+ return 'nvidia-h200-141gb'
290
442
  elif accelerator.startswith('tpu-'):
291
443
  return accelerator
444
+ elif accelerator.startswith('amd-'):
445
+ return accelerator
292
446
  else:
293
447
  return 'nvidia-tesla-{}'.format(accelerator.lower())
294
448
 
@@ -342,6 +496,9 @@ class CoreWeaveLabelFormatter(GPULabelFormatter):
342
496
 
343
497
  LABEL_KEY = 'gpu.nvidia.com/class'
344
498
 
499
+ # TODO (kyuds): fill in more label values for different accelerators.
500
+ ACC_VALUE_MAPPINGS = {'H100_NVLINK_80GB': 'H100'}
501
+
345
502
  @classmethod
346
503
  def get_label_key(cls, accelerator: Optional[str] = None) -> str:
347
504
  return cls.LABEL_KEY
@@ -360,7 +517,8 @@ class CoreWeaveLabelFormatter(GPULabelFormatter):
360
517
 
361
518
  @classmethod
362
519
  def get_accelerator_from_label_value(cls, value: str) -> str:
363
- return value
520
+ # return original label value if not found in mappings.
521
+ return cls.ACC_VALUE_MAPPINGS.get(value, value)
364
522
 
365
523
 
366
524
  class GKELabelFormatter(GPULabelFormatter):
@@ -425,6 +583,10 @@ class GKELabelFormatter(GPULabelFormatter):
425
583
 
426
584
  e.g. tpu-v5-lite-podslice:8 -> '2x4'
427
585
  """
586
+ # If the TPU type is in the GKE_TPU_ACCELERATOR_TO_GENERATION, it means
587
+ # that it has been normalized before, no need to normalize again.
588
+ if acc_type not in GKE_TPU_ACCELERATOR_TO_GENERATION:
589
+ acc_type, acc_count = normalize_tpu_accelerator_name(acc_type)
428
590
  count_to_topology = cls.GKE_TPU_TOPOLOGIES.get(acc_type,
429
591
  {}).get(acc_count, None)
430
592
  if count_to_topology is None:
@@ -452,13 +614,26 @@ class GKELabelFormatter(GPULabelFormatter):
452
614
  # we map H100 ---> H100-80GB and keep H100-MEGA-80GB
453
615
  # to distinguish between a3-high and a3-mega instances
454
616
  return 'H100'
617
+ elif acc == 'H200-141GB':
618
+ return 'H200'
455
619
  return acc
456
620
  elif is_tpu_on_gke(value):
457
621
  return value
622
+ elif value == '':
623
+ # heterogenous cluster may have empty labels for cpu nodes.
624
+ return ''
458
625
  else:
459
626
  raise ValueError(
460
627
  f'Invalid accelerator name in GKE cluster: {value}')
461
628
 
629
+ @classmethod
630
+ def validate_label_value(cls, value: str) -> Tuple[bool, str]:
631
+ try:
632
+ _ = cls.get_accelerator_from_label_value(value)
633
+ return True, ''
634
+ except ValueError as e:
635
+ return False, str(e)
636
+
462
637
 
463
638
  class GFDLabelFormatter(GPULabelFormatter):
464
639
  """GPU Feature Discovery label formatter
@@ -563,17 +738,37 @@ def detect_gpu_label_formatter(
563
738
  for label, value in node.metadata.labels.items():
564
739
  node_labels[node.metadata.name].append((label, value))
565
740
 
566
- label_formatter = None
567
-
741
+ invalid_label_values: List[Tuple[str, str, str, str]] = []
568
742
  # Check if the node labels contain any of the GPU label prefixes
569
743
  for lf in LABEL_FORMATTER_REGISTRY:
744
+ skip = False
570
745
  for _, label_list in node_labels.items():
571
- for label, _ in label_list:
746
+ for label, value in label_list:
572
747
  if lf.match_label_key(label):
573
- label_formatter = lf()
574
- return label_formatter, node_labels
748
+ # Skip empty label values
749
+ if not value or value.strip() == '':
750
+ continue
751
+ valid, reason = lf.validate_label_value(value)
752
+ if valid:
753
+ return lf(), node_labels
754
+ else:
755
+ invalid_label_values.append(
756
+ (label, lf.__name__, value, reason))
757
+ skip = True
758
+ break
759
+ if skip:
760
+ break
761
+ if skip:
762
+ continue
575
763
 
576
- return label_formatter, node_labels
764
+ for label, lf_name, value, reason in invalid_label_values:
765
+ logger.warning(f'GPU label {label} matched for label '
766
+ f'formatter {lf_name}, '
767
+ f'but has invalid value {value}. '
768
+ f'Reason: {reason}. '
769
+ 'Skipping...')
770
+
771
+ return None, node_labels
577
772
 
578
773
 
579
774
  class Autoscaler:
@@ -703,6 +898,74 @@ class GKEAutoscaler(Autoscaler):
703
898
  return True
704
899
  return False
705
900
 
901
+ @classmethod
902
+ @annotations.lru_cache(scope='request', maxsize=10)
903
+ def get_available_machine_types(cls, context: str) -> List[str]:
904
+ """Returns the list of machine types that are available in the cluster.
905
+ """
906
+ # Assume context naming convention of
907
+ # gke_PROJECT-ID_LOCATION_CLUSTER-NAME
908
+ valid, project_id, location, cluster_name = cls._validate_context_name(
909
+ context)
910
+ if not valid:
911
+ # Context name is not in the format of
912
+ # gke_PROJECT-ID_LOCATION_CLUSTER-NAME.
913
+ # Cannot determine if the context can autoscale.
914
+ # Return empty list.
915
+ logger.debug(f'Context {context} is not in the format of '
916
+ f'gke_PROJECT-ID_LOCATION_CLUSTER-NAME. '
917
+ 'Returning empty machine type list.')
918
+ return []
919
+ try:
920
+ logger.debug(
921
+ f'Attempting to get information about cluster {cluster_name}')
922
+ container_service = gcp.build('container',
923
+ 'v1',
924
+ credentials=None,
925
+ cache_discovery=False)
926
+ cluster = container_service.projects().locations().clusters().get(
927
+ name=f'projects/{project_id}'
928
+ f'/locations/{location}'
929
+ f'/clusters/{cluster_name}').execute()
930
+ except ImportError:
931
+ # If the gcp module is not installed, return empty list.
932
+ # Remind the user once per day to install the gcp module for better
933
+ # pod scheduling with GKE autoscaler.
934
+ if time.time() - cls._pip_install_gcp_hint_last_sent > 60 * 60 * 24:
935
+ logger.info(
936
+ 'Could not fetch autoscaler information from GKE. '
937
+ 'Run pip install "skypilot[gcp]" for more intelligent pod '
938
+ 'scheduling with GKE autoscaler.')
939
+ cls._pip_install_gcp_hint_last_sent = time.time()
940
+ return []
941
+ except gcp.http_error_exception() as e:
942
+ # Cluster information is not available.
943
+ # Return empty list.
944
+ logger.debug(f'{e.message}', exc_info=True)
945
+ return []
946
+
947
+ machine_types = []
948
+ # Get the list of machine types that are available in the cluster.
949
+ node_pools = cluster.get('nodePools', [])
950
+ for node_pool in node_pools:
951
+ name = node_pool.get('name', '')
952
+ logger.debug(f'Checking if node pool {name} '
953
+ 'has autoscaling enabled.')
954
+ autoscaling_enabled = (node_pool.get('autoscaling',
955
+ {}).get('enabled', False))
956
+ if autoscaling_enabled:
957
+ logger.debug(f'Node pool {name} has autoscaling enabled.')
958
+ try:
959
+ machine_type = node_pool.get('config',
960
+ {}).get('machineType', '')
961
+ if machine_type:
962
+ machine_types.append(machine_type)
963
+ except KeyError:
964
+ logger.debug(f'Encountered KeyError while checking machine '
965
+ f'type of node pool {name}.')
966
+ continue
967
+ return machine_types
968
+
706
969
  @classmethod
707
970
  def _validate_context_name(cls, context: str) -> Tuple[bool, str, str, str]:
708
971
  """Validates the context name is in the format of
@@ -752,6 +1015,8 @@ class GKEAutoscaler(Autoscaler):
752
1015
  f'checking {node_pool_name} for TPU {requested_acc_type}:'
753
1016
  f'{requested_acc_count}')
754
1017
  if 'resourceLabels' in node_config:
1018
+ requested_acc_type, requested_acc_count = normalize_tpu_accelerator_name(
1019
+ requested_acc_type)
755
1020
  accelerator_exists = cls._node_pool_has_tpu_capacity(
756
1021
  node_config['resourceLabels'], machine_type,
757
1022
  requested_acc_type, requested_acc_count)
@@ -801,12 +1066,16 @@ class GKEAutoscaler(Autoscaler):
801
1066
  to fit the instance type.
802
1067
  """
803
1068
  for accelerator in node_pool_accelerators:
1069
+ raw_value = accelerator['acceleratorType']
804
1070
  node_accelerator_type = (
805
- GKELabelFormatter.get_accelerator_from_label_value(
806
- accelerator['acceleratorType']))
1071
+ GKELabelFormatter.get_accelerator_from_label_value(raw_value))
1072
+ # handle heterogenous nodes.
1073
+ if not node_accelerator_type:
1074
+ continue
807
1075
  node_accelerator_count = accelerator['acceleratorCount']
808
- if node_accelerator_type == requested_gpu_type and int(
809
- node_accelerator_count) >= requested_gpu_count:
1076
+ viable_names = [node_accelerator_type.lower(), raw_value.lower()]
1077
+ if (requested_gpu_type.lower() in viable_names and
1078
+ int(node_accelerator_count) >= requested_gpu_count):
810
1079
  return True
811
1080
  return False
812
1081
 
@@ -869,6 +1138,14 @@ class KarpenterAutoscaler(Autoscaler):
869
1138
  can_query_backend: bool = False
870
1139
 
871
1140
 
1141
+ class CoreweaveAutoscaler(Autoscaler):
1142
+ """CoreWeave autoscaler
1143
+ """
1144
+
1145
+ label_formatter: Any = CoreWeaveLabelFormatter
1146
+ can_query_backend: bool = False
1147
+
1148
+
872
1149
  class GenericAutoscaler(Autoscaler):
873
1150
  """Generic autoscaler
874
1151
  """
@@ -881,6 +1158,7 @@ class GenericAutoscaler(Autoscaler):
881
1158
  AUTOSCALER_TYPE_TO_AUTOSCALER = {
882
1159
  kubernetes_enums.KubernetesAutoscalerType.GKE: GKEAutoscaler,
883
1160
  kubernetes_enums.KubernetesAutoscalerType.KARPENTER: KarpenterAutoscaler,
1161
+ kubernetes_enums.KubernetesAutoscalerType.COREWEAVE: CoreweaveAutoscaler,
884
1162
  kubernetes_enums.KubernetesAutoscalerType.GENERIC: GenericAutoscaler,
885
1163
  }
886
1164
 
@@ -894,10 +1172,10 @@ def detect_accelerator_resource(
894
1172
  context: Optional[str]) -> Tuple[bool, Set[str]]:
895
1173
  """Checks if the Kubernetes cluster has GPU/TPU resource.
896
1174
 
897
- Two types of accelerator resources are available which are each checked
898
- with nvidia.com/gpu and google.com/tpu. If nvidia.com/gpu resource is
1175
+ Three types of accelerator resources are available which are each checked
1176
+ with amd.com/gpu, nvidia.com/gpu and google.com/tpu. If amd.com/gpu or nvidia.com/gpu resource is
899
1177
  missing, that typically means that the Kubernetes cluster does not have
900
- GPUs or the nvidia GPU operator and/or device drivers are not installed.
1178
+ GPUs or the amd/nvidia GPU operator and/or device drivers are not installed.
901
1179
 
902
1180
  Returns:
903
1181
  bool: True if the cluster has GPU_RESOURCE_KEY or TPU_RESOURCE_KEY
@@ -908,15 +1186,57 @@ def detect_accelerator_resource(
908
1186
  nodes = get_kubernetes_nodes(context=context)
909
1187
  for node in nodes:
910
1188
  cluster_resources.update(node.status.allocatable.keys())
911
- has_accelerator = (get_gpu_resource_key() in cluster_resources or
1189
+ has_accelerator = (get_gpu_resource_key(context) in cluster_resources or
912
1190
  TPU_RESOURCE_KEY in cluster_resources)
913
1191
 
914
1192
  return has_accelerator, cluster_resources
915
1193
 
916
1194
 
1195
+ @dataclasses.dataclass
1196
+ class V1ObjectMeta:
1197
+ name: str
1198
+ labels: Dict[str, str]
1199
+ namespace: str = '' # Used for pods, not nodes
1200
+
1201
+
1202
+ @dataclasses.dataclass
1203
+ class V1NodeAddress:
1204
+ type: str
1205
+ address: str
1206
+
1207
+
1208
+ @dataclasses.dataclass
1209
+ class V1NodeStatus:
1210
+ allocatable: Dict[str, str]
1211
+ capacity: Dict[str, str]
1212
+ addresses: List[V1NodeAddress]
1213
+
1214
+
1215
+ @dataclasses.dataclass
1216
+ class V1Node:
1217
+ metadata: V1ObjectMeta
1218
+ status: V1NodeStatus
1219
+
1220
+ @classmethod
1221
+ def from_dict(cls, data: dict) -> 'V1Node':
1222
+ """Create V1Node from a dictionary."""
1223
+ return cls(metadata=V1ObjectMeta(
1224
+ name=data['metadata']['name'],
1225
+ labels=data['metadata'].get('labels', {}),
1226
+ ),
1227
+ status=V1NodeStatus(
1228
+ allocatable=data['status']['allocatable'],
1229
+ capacity=data['status']['capacity'],
1230
+ addresses=[
1231
+ V1NodeAddress(type=addr['type'],
1232
+ address=addr['address'])
1233
+ for addr in data['status'].get('addresses', [])
1234
+ ]))
1235
+
1236
+
917
1237
  @annotations.lru_cache(scope='request', maxsize=10)
918
1238
  @_retry_on_error(resource_type='node')
919
- def get_kubernetes_nodes(*, context: Optional[str] = None) -> List[Any]:
1239
+ def get_kubernetes_nodes(*, context: Optional[str] = None) -> List[V1Node]:
920
1240
  """Gets the kubernetes nodes in the context.
921
1241
 
922
1242
  If context is None, gets the nodes in the current context.
@@ -924,25 +1244,113 @@ def get_kubernetes_nodes(*, context: Optional[str] = None) -> List[Any]:
924
1244
  if context is None:
925
1245
  context = get_current_kube_config_context_name()
926
1246
 
927
- nodes = kubernetes.core_api(context).list_node(
928
- _request_timeout=kubernetes.API_TIMEOUT).items
1247
+ # Return raw urllib3.HTTPResponse object so that we can parse the json
1248
+ # more efficiently.
1249
+ response = kubernetes.core_api(context).list_node(
1250
+ _request_timeout=kubernetes.API_TIMEOUT, _preload_content=False)
1251
+ try:
1252
+ nodes = [
1253
+ V1Node.from_dict(item_dict) for item_dict in ijson.items(
1254
+ response, 'items.item', buf_size=IJSON_BUFFER_SIZE)
1255
+ ]
1256
+ finally:
1257
+ response.release_conn()
1258
+
929
1259
  return nodes
930
1260
 
931
1261
 
932
- @_retry_on_error(resource_type='pod')
933
- def get_all_pods_in_kubernetes_cluster(*,
934
- context: Optional[str] = None
935
- ) -> List[Any]:
936
- """Gets pods in all namespaces in kubernetes cluster indicated by context.
1262
+ @dataclasses.dataclass
1263
+ class V1PodStatus:
1264
+ phase: str
1265
+
937
1266
 
938
- Used for computing cluster resource usage.
1267
+ @dataclasses.dataclass
1268
+ class V1ResourceRequirements:
1269
+ requests: Optional[Dict[str, str]]
1270
+
1271
+
1272
+ @dataclasses.dataclass
1273
+ class V1Container:
1274
+ resources: V1ResourceRequirements
1275
+
1276
+
1277
+ @dataclasses.dataclass
1278
+ class V1PodSpec:
1279
+ containers: List[V1Container]
1280
+ node_name: Optional[str]
1281
+
1282
+
1283
+ @dataclasses.dataclass
1284
+ class V1Pod:
1285
+ metadata: V1ObjectMeta
1286
+ status: V1PodStatus
1287
+ spec: V1PodSpec
1288
+
1289
+ @classmethod
1290
+ def from_dict(cls, data: dict) -> 'V1Pod':
1291
+ """Create V1Pod from a dictionary."""
1292
+ return cls(metadata=V1ObjectMeta(
1293
+ name=data['metadata']['name'],
1294
+ labels=data['metadata'].get('labels', {}),
1295
+ namespace=data['metadata'].get('namespace'),
1296
+ ),
1297
+ status=V1PodStatus(phase=data['status'].get('phase'),),
1298
+ spec=V1PodSpec(
1299
+ node_name=data['spec'].get('nodeName'),
1300
+ containers=[
1301
+ V1Container(resources=V1ResourceRequirements(
1302
+ requests=container.get('resources', {}).get(
1303
+ 'requests') or None))
1304
+ for container in data['spec'].get('containers', [])
1305
+ ]))
1306
+
1307
+
1308
+ @_retry_on_error(resource_type='pod')
1309
+ def get_allocated_gpu_qty_by_node(
1310
+ *,
1311
+ context: Optional[str] = None,
1312
+ ) -> Dict[str, int]:
1313
+ """Gets allocated GPU quantity by each node by fetching pods in
1314
+ all namespaces in kubernetes cluster indicated by context.
939
1315
  """
940
1316
  if context is None:
941
1317
  context = get_current_kube_config_context_name()
1318
+ non_included_pod_statuses = POD_STATUSES.copy()
1319
+ status_filters = ['Running', 'Pending']
1320
+ if status_filters is not None:
1321
+ non_included_pod_statuses -= set(status_filters)
1322
+ field_selector = ','.join(
1323
+ [f'status.phase!={status}' for status in non_included_pod_statuses])
942
1324
 
943
- pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
944
- _request_timeout=kubernetes.API_TIMEOUT).items
945
- return pods
1325
+ # Return raw urllib3.HTTPResponse object so that we can parse the json
1326
+ # more efficiently.
1327
+ response = kubernetes.core_api(context).list_pod_for_all_namespaces(
1328
+ _request_timeout=kubernetes.API_TIMEOUT,
1329
+ _preload_content=False,
1330
+ field_selector=field_selector)
1331
+ try:
1332
+ allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
1333
+ for item_dict in ijson.items(response,
1334
+ 'items.item',
1335
+ buf_size=IJSON_BUFFER_SIZE):
1336
+ pod = V1Pod.from_dict(item_dict)
1337
+ if should_exclude_pod_from_gpu_allocation(pod):
1338
+ logger.debug(
1339
+ f'Excluding pod {pod.metadata.name} from GPU count '
1340
+ f'calculations on node {pod.spec.node_name}')
1341
+ continue
1342
+ # Iterate over all the containers in the pod and sum the
1343
+ # GPU requests
1344
+ pod_allocated_qty = 0
1345
+ for container in pod.spec.containers:
1346
+ if container.resources.requests:
1347
+ pod_allocated_qty += get_node_accelerator_count(
1348
+ context, container.resources.requests)
1349
+ if pod_allocated_qty > 0 and pod.spec.node_name:
1350
+ allocated_qty_by_node[pod.spec.node_name] += pod_allocated_qty
1351
+ return allocated_qty_by_node
1352
+ finally:
1353
+ response.release_conn()
946
1354
 
947
1355
 
948
1356
  def check_instance_fits(context: Optional[str],
@@ -980,14 +1388,18 @@ def check_instance_fits(context: Optional[str],
980
1388
  if node_cpus > max_cpu:
981
1389
  max_cpu = node_cpus
982
1390
  max_mem = node_memory_gb
983
- if (node_cpus >= candidate_instance_type.cpus and
984
- node_memory_gb >= candidate_instance_type.memory):
1391
+ # We don't consider nodes that have exactly the same amount of
1392
+ # CPU or memory as the candidate instance type.
1393
+ # This is to account for the fact that each node always has some
1394
+ # amount kube-system pods running on it and consuming resources.
1395
+ if (node_cpus > candidate_instance_type.cpus and
1396
+ node_memory_gb > candidate_instance_type.memory):
985
1397
  return True, None
986
1398
  return False, (
987
1399
  'Maximum resources found on a single node: '
988
1400
  f'{max_cpu} CPUs, {common_utils.format_float(max_mem)}G Memory')
989
1401
 
990
- def check_tpu_fits(candidate_instance_type: 'KubernetesInstanceType',
1402
+ def check_tpu_fits(acc_type: str, acc_count: int,
991
1403
  node_list: List[Any]) -> Tuple[bool, Optional[str]]:
992
1404
  """Checks if the instance fits on the cluster based on requested TPU.
993
1405
 
@@ -997,8 +1409,6 @@ def check_instance_fits(context: Optional[str],
997
1409
  node (node_tpu_chip_count) and the total TPU chips across the entire
998
1410
  podslice (topology_chip_count) are correctly handled.
999
1411
  """
1000
- acc_type = candidate_instance_type.accelerator_type
1001
- acc_count = candidate_instance_type.accelerator_count
1002
1412
  tpu_list_in_cluster = []
1003
1413
  for node in node_list:
1004
1414
  if acc_type == node.metadata.labels[
@@ -1049,14 +1459,15 @@ def check_instance_fits(context: Optional[str],
1049
1459
  if is_tpu_on_gke(acc_type):
1050
1460
  # If requested accelerator is a TPU type, check if the cluster
1051
1461
  # has sufficient TPU resource to meet the requirement.
1052
- fits, reason = check_tpu_fits(k8s_instance_type, gpu_nodes)
1462
+ acc_type, acc_count = normalize_tpu_accelerator_name(acc_type)
1463
+ fits, reason = check_tpu_fits(acc_type, acc_count, gpu_nodes)
1053
1464
  if reason is not None:
1054
1465
  return fits, reason
1055
1466
  else:
1056
1467
  # Check if any of the GPU nodes have sufficient number of GPUs.
1057
1468
  gpu_nodes = [
1058
- node for node in gpu_nodes if
1059
- get_node_accelerator_count(node.status.allocatable) >= acc_count
1469
+ node for node in gpu_nodes if get_node_accelerator_count(
1470
+ context, node.status.allocatable) >= acc_count
1060
1471
  ]
1061
1472
  if not gpu_nodes:
1062
1473
  return False, (
@@ -1118,14 +1529,14 @@ def get_accelerator_label_key_values(
1118
1529
  Raises:
1119
1530
  ResourcesUnavailableError: Can be raised from the following conditions:
1120
1531
  - The cluster does not have GPU/TPU resources
1121
- (nvidia.com/gpu, google.com/tpu)
1532
+ (amd.com/gpu, nvidia.com/gpu, google.com/tpu)
1122
1533
  - The cluster has GPU/TPU resources, but no node in the cluster has
1123
1534
  an accelerator label.
1124
1535
  - The cluster has a node with an invalid accelerator label value.
1125
1536
  - The cluster doesn't have any nodes with acc_type GPU/TPU
1126
1537
  """
1127
1538
  # Check if the cluster has GPU resources
1128
- # TODO(romilb): This assumes the accelerator is a nvidia GPU. We
1539
+ # TODO(romilb): This assumes the accelerator is a amd/nvidia GPU. We
1129
1540
  # need to support TPUs and other accelerators as well.
1130
1541
  # TODO(romilb): Currently, we broadly disable all GPU checks if autoscaling
1131
1542
  # is configured in config.yaml since the cluster may be scaling up from
@@ -1133,7 +1544,16 @@ def get_accelerator_label_key_values(
1133
1544
  # support pollingthe clusters for autoscaling information, such as the
1134
1545
  # node pools configured etc.
1135
1546
 
1136
- autoscaler_type = get_autoscaler_type()
1547
+ is_ssh_node_pool = context.startswith('ssh-') if context else False
1548
+ cloud_name = 'SSH Node Pool' if is_ssh_node_pool else 'Kubernetes cluster'
1549
+ context_display_name = common_utils.removeprefix(
1550
+ context, 'ssh-') if (context and is_ssh_node_pool) else context
1551
+
1552
+ autoscaler_type = skypilot_config.get_effective_region_config(
1553
+ cloud='kubernetes',
1554
+ region=context,
1555
+ keys=('autoscaler',),
1556
+ default_value=None)
1137
1557
  if autoscaler_type is not None:
1138
1558
  # If autoscaler is set in config.yaml, override the label key and value
1139
1559
  # to the autoscaler's format and bypass the GPU checks.
@@ -1142,7 +1562,8 @@ def get_accelerator_label_key_values(
1142
1562
  # early since we assume the cluster autoscaler will handle GPU
1143
1563
  # node provisioning.
1144
1564
  return None, None, None, None
1145
- autoscaler = AUTOSCALER_TYPE_TO_AUTOSCALER.get(autoscaler_type)
1565
+ autoscaler = AUTOSCALER_TYPE_TO_AUTOSCALER.get(
1566
+ kubernetes_enums.KubernetesAutoscalerType(autoscaler_type))
1146
1567
  assert autoscaler is not None, ('Unsupported autoscaler type:'
1147
1568
  f' {autoscaler_type}')
1148
1569
  formatter = autoscaler.label_formatter
@@ -1172,13 +1593,17 @@ def get_accelerator_label_key_values(
1172
1593
  suffix = ''
1173
1594
  if env_options.Options.SHOW_DEBUG_INFO.get():
1174
1595
  suffix = f' Found node labels: {node_labels}'
1175
- raise exceptions.ResourcesUnavailableError(
1176
- 'Could not detect GPU labels in Kubernetes cluster. '
1177
- 'If this cluster has GPUs, please ensure GPU nodes have '
1178
- 'node labels of either of these formats: '
1179
- f'{supported_formats}. Please refer to '
1180
- 'the documentation on how to set up node labels.'
1181
- f'{suffix}')
1596
+ msg = (f'Could not detect GPU labels in {cloud_name}.')
1597
+ if not is_ssh_node_pool:
1598
+ msg += (' Run `sky check ssh` to debug.')
1599
+ else:
1600
+ msg += (
1601
+ ' If this cluster has GPUs, please ensure GPU nodes have '
1602
+ 'node labels of either of these formats: '
1603
+ f'{supported_formats}. Please refer to '
1604
+ 'the documentation on how to set up node labels.')
1605
+ msg += f'{suffix}'
1606
+ raise exceptions.ResourcesUnavailableError(msg)
1182
1607
  else:
1183
1608
  # Validate the label value on all nodes labels to ensure they are
1184
1609
  # correctly setup and will behave as expected.
@@ -1189,7 +1614,7 @@ def get_accelerator_label_key_values(
1189
1614
  value)
1190
1615
  if not is_valid:
1191
1616
  raise exceptions.ResourcesUnavailableError(
1192
- f'Node {node_name!r} in Kubernetes cluster has '
1617
+ f'Node {node_name!r} in {cloud_name} has '
1193
1618
  f'invalid GPU label: {label}={value}. {reason}')
1194
1619
  if check_mode:
1195
1620
  # If check mode is enabled and we reached so far, we can
@@ -1208,9 +1633,13 @@ def get_accelerator_label_key_values(
1208
1633
  if is_multi_host_tpu(node_metadata_labels):
1209
1634
  continue
1210
1635
  for label, value in label_list:
1211
- if (label_formatter.match_label_key(label) and
1212
- label_formatter.get_accelerator_from_label_value(
1213
- value).lower() == acc_type.lower()):
1636
+ if label_formatter.match_label_key(label):
1637
+ # match either canonicalized name or raw name
1638
+ accelerator = (label_formatter.
1639
+ get_accelerator_from_label_value(value))
1640
+ viable = [value.lower(), accelerator.lower()]
1641
+ if acc_type.lower() not in viable:
1642
+ continue
1214
1643
  if is_tpu_on_gke(acc_type):
1215
1644
  assert isinstance(label_formatter,
1216
1645
  GKELabelFormatter)
@@ -1253,10 +1682,10 @@ def get_accelerator_label_key_values(
1253
1682
  # TODO(Doyoung): Update the error message raised with the
1254
1683
  # multi-host TPU support.
1255
1684
  raise exceptions.ResourcesUnavailableError(
1256
- 'Could not find any node in the Kubernetes cluster '
1685
+ f'Could not find any node in the {cloud_name} '
1257
1686
  f'with {acc_type}. Please ensure at least one node in the '
1258
1687
  f'cluster has {acc_type} and node labels are setup '
1259
- 'correctly. Please refer to the documentration for more. '
1688
+ 'correctly. Please refer to the documentation for more. '
1260
1689
  f'{suffix}. Note that multi-host TPU podslices are '
1261
1690
  'currently not unsupported.')
1262
1691
  else:
@@ -1266,15 +1695,27 @@ def get_accelerator_label_key_values(
1266
1695
  if env_options.Options.SHOW_DEBUG_INFO.get():
1267
1696
  suffix = (' Available resources on the cluster: '
1268
1697
  f'{cluster_resources}')
1269
- raise exceptions.ResourcesUnavailableError(
1270
- f'Could not detect GPU/TPU resources ({GPU_RESOURCE_KEY!r} or '
1271
- f'{TPU_RESOURCE_KEY!r}) in Kubernetes cluster. If this cluster'
1272
- ' contains GPUs, please ensure GPU drivers are installed on '
1273
- 'the node. Check if the GPUs are setup correctly by running '
1274
- '`kubectl describe nodes` and looking for the '
1275
- f'{GPU_RESOURCE_KEY!r} or {TPU_RESOURCE_KEY!r} resource. '
1276
- 'Please refer to the documentation on how to set up GPUs.'
1277
- f'{suffix}')
1698
+ if is_ssh_node_pool:
1699
+ msg = (
1700
+ f'Could not detect GPUs in SSH Node Pool '
1701
+ f'\'{context_display_name}\'. If this cluster contains '
1702
+ 'GPUs, please ensure GPU drivers are installed on the node '
1703
+ 'and re-run '
1704
+ f'`sky ssh up --infra {context_display_name}`. {suffix}')
1705
+ else:
1706
+ msg = (
1707
+ f'Could not detect GPU/TPU resources ({SUPPORTED_GPU_RESOURCE_KEYS["amd"]!r}, '
1708
+ f'{SUPPORTED_GPU_RESOURCE_KEYS["nvidia"]!r} or '
1709
+ f'{TPU_RESOURCE_KEY!r}) in Kubernetes cluster. If this cluster'
1710
+ ' contains GPUs, please ensure GPU drivers are installed on '
1711
+ 'the node. Check if the GPUs are setup correctly by running '
1712
+ '`kubectl describe nodes` and looking for the '
1713
+ f'{SUPPORTED_GPU_RESOURCE_KEYS["amd"]!r}, '
1714
+ f'{SUPPORTED_GPU_RESOURCE_KEYS["nvidia"]!r} or '
1715
+ f'{TPU_RESOURCE_KEY!r} resource. '
1716
+ 'Please refer to the documentation on how to set up GPUs.'
1717
+ f'{suffix}')
1718
+ raise exceptions.ResourcesUnavailableError(msg)
1278
1719
  assert False, 'This should not be reached'
1279
1720
 
1280
1721
 
@@ -1298,23 +1739,6 @@ def get_port(svc_name: str, namespace: str, context: Optional[str]) -> int:
1298
1739
  return head_service.spec.ports[0].node_port
1299
1740
 
1300
1741
 
1301
- def get_external_ip(network_mode: Optional[
1302
- kubernetes_enums.KubernetesNetworkingMode], context: Optional[str]) -> str:
1303
- if network_mode == kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD:
1304
- return '127.0.0.1'
1305
- # Return the IP address of the first node with an external IP
1306
- nodes = kubernetes.core_api(context).list_node().items
1307
- for node in nodes:
1308
- if node.status.addresses:
1309
- for address in node.status.addresses:
1310
- if address.type == 'ExternalIP':
1311
- return address.address
1312
- # If no external IP is found, use the API server IP
1313
- api_host = kubernetes.core_api(context).api_client.configuration.host
1314
- parsed_url = urlparse(api_host)
1315
- return parsed_url.hostname
1316
-
1317
-
1318
1742
  def check_credentials(context: Optional[str],
1319
1743
  timeout: int = kubernetes.API_TIMEOUT,
1320
1744
  run_optional_checks: bool = False) -> \
@@ -1333,7 +1757,10 @@ def check_credentials(context: Optional[str],
1333
1757
  try:
1334
1758
  namespace = get_kube_config_context_namespace(context)
1335
1759
  kubernetes.core_api(context).list_namespaced_pod(
1336
- namespace, _request_timeout=timeout)
1760
+ namespace, limit=1, _request_timeout=timeout)
1761
+ # This call is "free" because this function is a cached call,
1762
+ # and it will not be called again in this function.
1763
+ get_kubernetes_nodes(context=context)
1337
1764
  except ImportError:
1338
1765
  # TODO(romilb): Update these error strs to also include link to docs
1339
1766
  # when docs are ready.
@@ -1361,7 +1788,7 @@ def check_credentials(context: Optional[str],
1361
1788
  # Check if $KUBECONFIG envvar consists of multiple paths. We run this before
1362
1789
  # optional checks.
1363
1790
  try:
1364
- _ = _get_kubeconfig_path()
1791
+ _ = get_kubeconfig_paths()
1365
1792
  except ValueError as e:
1366
1793
  return False, f'{common_utils.format_exception(e, use_bracket=True)}'
1367
1794
 
@@ -1419,50 +1846,197 @@ def check_credentials(context: Optional[str],
1419
1846
  return True, None
1420
1847
 
1421
1848
 
1849
+ class PodValidator:
1850
+ """Validates Kubernetes pod configs against the OpenAPI spec.
1851
+
1852
+ Adapted from kubernetes.client.ApiClient:
1853
+ https://github.com/kubernetes-client/python/blob/0c56ef1c8c4b50087bc7b803f6af896fb973309e/kubernetes/client/api_client.py#L33
1854
+
1855
+ We needed to adapt it because the original implementation ignores
1856
+ unknown fields, whereas we want to raise an error so that users
1857
+ are aware of the issue.
1858
+ """
1859
+ PRIMITIVE_TYPES = (int, float, bool, str)
1860
+ NATIVE_TYPES_MAPPING = {
1861
+ 'int': int,
1862
+ 'float': float,
1863
+ 'str': str,
1864
+ 'bool': bool,
1865
+ 'date': datetime.date,
1866
+ 'datetime': datetime.datetime,
1867
+ 'object': object,
1868
+ }
1869
+
1870
+ @classmethod
1871
+ def validate(cls, data):
1872
+ return cls.__validate(data, kubernetes.models.V1Pod)
1873
+
1874
+ @classmethod
1875
+ def __validate(cls, data, klass):
1876
+ """Deserializes dict, list, str into an object.
1877
+
1878
+ :param data: dict, list or str.
1879
+ :param klass: class literal, or string of class name.
1880
+
1881
+ :return: object.
1882
+ """
1883
+ if data is None:
1884
+ return None
1885
+
1886
+ if isinstance(klass, str):
1887
+ if klass.startswith('list['):
1888
+ match = re.match(r'list\[(.*)\]', klass)
1889
+ if match is None:
1890
+ raise ValueError(f'Invalid list type format: {klass}')
1891
+ sub_kls = match.group(1)
1892
+ return [cls.__validate(sub_data, sub_kls) for sub_data in data]
1893
+
1894
+ if klass.startswith('dict('):
1895
+ match = re.match(r'dict\(([^,]*), (.*)\)', klass)
1896
+ if match is None:
1897
+ raise ValueError(f'Invalid dict type format: {klass}')
1898
+ sub_kls = match.group(2)
1899
+ return {k: cls.__validate(v, sub_kls) for k, v in data.items()}
1900
+
1901
+ # convert str to class
1902
+ if klass in cls.NATIVE_TYPES_MAPPING:
1903
+ klass = cls.NATIVE_TYPES_MAPPING[klass]
1904
+ else:
1905
+ klass = getattr(kubernetes.models, klass)
1906
+
1907
+ if klass in cls.PRIMITIVE_TYPES:
1908
+ return cls.__validate_primitive(data, klass)
1909
+ elif klass == object:
1910
+ return cls.__validate_object(data)
1911
+ elif klass == datetime.date:
1912
+ return cls.__validate_date(data)
1913
+ elif klass == datetime.datetime:
1914
+ return cls.__validate_datetime(data)
1915
+ else:
1916
+ return cls.__validate_model(data, klass)
1917
+
1918
+ @classmethod
1919
+ def __validate_primitive(cls, data, klass):
1920
+ """Deserializes string to primitive type.
1921
+
1922
+ :param data: str.
1923
+ :param klass: class literal.
1924
+
1925
+ :return: int, long, float, str, bool.
1926
+ """
1927
+ try:
1928
+ return klass(data)
1929
+ except UnicodeEncodeError:
1930
+ return str(data)
1931
+ except TypeError:
1932
+ return data
1933
+
1934
+ @classmethod
1935
+ def __validate_object(cls, value):
1936
+ """Return an original value.
1937
+
1938
+ :return: object.
1939
+ """
1940
+ return value
1941
+
1942
+ @classmethod
1943
+ def __validate_date(cls, string):
1944
+ """Deserializes string to date.
1945
+
1946
+ :param string: str.
1947
+ :return: date.
1948
+ """
1949
+ try:
1950
+ return kubernetes.dateutil_parser.parse(string).date()
1951
+ except ValueError as exc:
1952
+ raise ValueError(
1953
+ f'Failed to parse `{string}` as date object') from exc
1954
+
1955
+ @classmethod
1956
+ def __validate_datetime(cls, string):
1957
+ """Deserializes string to datetime.
1958
+
1959
+ The string should be in iso8601 datetime format.
1960
+
1961
+ :param string: str.
1962
+ :return: datetime.
1963
+ """
1964
+ try:
1965
+ return kubernetes.dateutil_parser.parse(string)
1966
+ except ValueError as exc:
1967
+ raise ValueError(
1968
+ f'Failed to parse `{string}` as datetime object') from exc
1969
+
1970
+ @classmethod
1971
+ def __validate_model(cls, data, klass):
1972
+ """Deserializes list or dict to model.
1973
+
1974
+ :param data: dict, list.
1975
+ :param klass: class literal.
1976
+ :return: model object.
1977
+ """
1978
+
1979
+ if not klass.openapi_types and not hasattr(klass,
1980
+ 'get_real_child_model'):
1981
+ return data
1982
+
1983
+ kwargs = {}
1984
+ try:
1985
+ if (data is not None and klass.openapi_types is not None and
1986
+ isinstance(data, (list, dict))):
1987
+ # attribute_map is a dict that maps field names in snake_case
1988
+ # to camelCase.
1989
+ reverse_attribute_map = {
1990
+ v: k for k, v in klass.attribute_map.items()
1991
+ }
1992
+ for k, v in data.items():
1993
+ field_name = reverse_attribute_map.get(k, None)
1994
+ if field_name is None:
1995
+ raise ValueError(
1996
+ f'Unknown field `{k}`. Please ensure '
1997
+ 'pod_config follows the Kubernetes '
1998
+ 'Pod schema: '
1999
+ 'https://github.com/kubernetes/kubernetes/blob/master/api/openapi-spec/v3/api__v1_openapi.json'
2000
+ )
2001
+ kwargs[field_name] = cls.__validate(
2002
+ v, klass.openapi_types[field_name])
2003
+ except exceptions.KubernetesValidationError as e:
2004
+ raise exceptions.KubernetesValidationError([k] + e.path,
2005
+ str(e)) from e
2006
+ except Exception as e:
2007
+ raise exceptions.KubernetesValidationError([k], str(e)) from e
2008
+
2009
+ instance = klass(**kwargs)
2010
+
2011
+ if hasattr(instance, 'get_real_child_model'):
2012
+ klass_name = instance.get_real_child_model(data)
2013
+ if klass_name:
2014
+ instance = cls.__validate(data, klass_name)
2015
+ return instance
2016
+
1422
2017
  def check_pod_config(pod_config: dict) \
1423
2018
  -> Tuple[bool, Optional[str]]:
1424
- """Check if the pod_config is a valid pod config
2019
+ """Check if the pod_config is a valid pod config.
1425
2020
 
1426
- Using deserialize api to check the pod_config is valid or not.
2021
+ Uses the deserialize API from the kubernetes client library.
2022
+
2023
+ This is a client-side validation, meant to catch common errors like
2024
+ unknown/misspelled fields, and missing required fields.
2025
+
2026
+ The full validation however is done later on by the Kubernetes API server
2027
+ when the pod creation request is sent.
1427
2028
 
1428
2029
  Returns:
1429
2030
  bool: True if pod_config is valid.
1430
2031
  str: Error message about why the pod_config is invalid, None otherwise.
1431
2032
  """
1432
- errors = []
1433
- # This api_client won't be used to send any requests, so there is no need to
1434
- # load kubeconfig
1435
- api_client = kubernetes.kubernetes.client.ApiClient()
1436
-
1437
- # Used for kubernetes api_client deserialize function, the function will use
1438
- # data attr, the detail ref:
1439
- # https://github.com/kubernetes-client/python/blob/master/kubernetes/client/api_client.py#L244
1440
- class InnerResponse():
1441
-
1442
- def __init__(self, data: dict):
1443
- self.data = json.dumps(data)
1444
-
1445
2033
  try:
1446
- # Validate metadata if present
1447
- if 'metadata' in pod_config:
1448
- try:
1449
- value = InnerResponse(pod_config['metadata'])
1450
- api_client.deserialize(
1451
- value, kubernetes.kubernetes.client.V1ObjectMeta)
1452
- except ValueError as e:
1453
- errors.append(f'Invalid metadata: {str(e)}')
1454
- # Validate spec if present
1455
- if 'spec' in pod_config:
1456
- try:
1457
- value = InnerResponse(pod_config['spec'])
1458
- api_client.deserialize(value,
1459
- kubernetes.kubernetes.client.V1PodSpec)
1460
- except ValueError as e:
1461
- errors.append(f'Invalid spec: {str(e)}')
1462
- return len(errors) == 0, '.'.join(errors)
2034
+ PodValidator.validate(pod_config)
2035
+ except exceptions.KubernetesValidationError as e:
2036
+ return False, f'Validation error in {".".join(e.path)}: {str(e)}'
1463
2037
  except Exception as e: # pylint: disable=broad-except
1464
- errors.append(f'Validation error: {str(e)}')
1465
- return False, '.'.join(errors)
2038
+ return False, f'Unexpected error: {str(e)}'
2039
+ return True, None
1466
2040
 
1467
2041
 
1468
2042
  def is_kubeconfig_exec_auth(
@@ -1503,7 +2077,7 @@ def is_kubeconfig_exec_auth(
1503
2077
  return False, None
1504
2078
 
1505
2079
  # Get active context and user from kubeconfig using k8s api
1506
- all_contexts, current_context = k8s.config.list_kube_config_contexts()
2080
+ all_contexts, current_context = kubernetes.list_kube_config_contexts()
1507
2081
  context_obj = current_context
1508
2082
  if context is not None:
1509
2083
  for c in all_contexts:
@@ -1514,33 +2088,31 @@ def is_kubeconfig_exec_auth(
1514
2088
  raise ValueError(f'Kubernetes context {context!r} not found.')
1515
2089
  target_username = context_obj['context']['user']
1516
2090
 
1517
- # K8s api does not provide a mechanism to get the user details from the
1518
- # context. We need to load the kubeconfig file and parse it to get the
1519
- # user details.
1520
- kubeconfig_path = _get_kubeconfig_path()
1521
-
1522
- # Load the kubeconfig file as a dictionary
1523
- with open(kubeconfig_path, 'r', encoding='utf-8') as f:
1524
- kubeconfig = yaml.safe_load(f)
2091
+ # Load the kubeconfig for the context
2092
+ kubeconfig_text = _get_kubeconfig_text_for_context(context)
2093
+ kubeconfig = yaml_utils.safe_load(kubeconfig_text)
1525
2094
 
2095
+ # Get the user details
1526
2096
  user_details = kubeconfig['users']
1527
2097
 
1528
2098
  # Find user matching the target username
1529
2099
  user_details = next(
1530
2100
  user for user in user_details if user['name'] == target_username)
1531
2101
 
1532
- remote_identity = skypilot_config.get_nested(
1533
- ('kubernetes', 'remote_identity'),
1534
- schemas.get_default_remote_identity('kubernetes'))
2102
+ remote_identity = skypilot_config.get_effective_region_config(
2103
+ cloud='kubernetes',
2104
+ region=context,
2105
+ keys=('remote_identity',),
2106
+ default_value=schemas.get_default_remote_identity('kubernetes'))
1535
2107
  if ('exec' in user_details.get('user', {}) and remote_identity
1536
2108
  == schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value):
1537
2109
  ctx_name = context_obj['name']
1538
2110
  exec_msg = ('exec-based authentication is used for '
1539
- f'Kubernetes context {ctx_name!r}.'
1540
- ' This may cause issues with autodown or when running '
1541
- 'Managed Jobs or SkyServe controller on Kubernetes. '
1542
- 'To fix, configure SkyPilot to create a service account '
1543
- 'for running pods by setting the following in '
2111
+ f'Kubernetes context {ctx_name!r}. '
2112
+ 'Make sure that the corresponding cloud provider is '
2113
+ 'also enabled through `sky check` (e.g.: GCP for GKE). '
2114
+ 'Alternatively, configure SkyPilot to create a service '
2115
+ 'account for running pods by setting the following in '
1544
2116
  '~/.sky/config.yaml:\n'
1545
2117
  ' kubernetes:\n'
1546
2118
  ' remote_identity: SERVICE_ACCOUNT\n'
@@ -1550,6 +2122,33 @@ def is_kubeconfig_exec_auth(
1550
2122
  return False, None
1551
2123
 
1552
2124
 
2125
+ def _get_kubeconfig_text_for_context(context: Optional[str] = None) -> str:
2126
+ """Get the kubeconfig text for the given context.
2127
+
2128
+ The kubeconfig might be multiple files, this function use kubectl to
2129
+ handle merging automatically.
2130
+ """
2131
+ command = 'kubectl config view --minify'
2132
+ if context is not None:
2133
+ command += f' --context={context}'
2134
+
2135
+ # Ensure subprocess inherits the current environment properly
2136
+ # This fixes the issue where kubectl can't find ~/.kube/config in API server context
2137
+ env = os.environ.copy()
2138
+
2139
+ proc = subprocess.run(command,
2140
+ shell=True,
2141
+ check=False,
2142
+ env=env,
2143
+ stdout=subprocess.PIPE,
2144
+ stderr=subprocess.PIPE)
2145
+ if proc.returncode != 0:
2146
+ raise RuntimeError(
2147
+ f'Failed to get kubeconfig text for context {context}: {proc.stderr.decode("utf-8")}'
2148
+ )
2149
+ return proc.stdout.decode('utf-8')
2150
+
2151
+
1553
2152
  @annotations.lru_cache(scope='request')
1554
2153
  def get_current_kube_config_context_name() -> Optional[str]:
1555
2154
  """Get the current kubernetes context from the kubeconfig file
@@ -1559,7 +2158,7 @@ def get_current_kube_config_context_name() -> Optional[str]:
1559
2158
  """
1560
2159
  k8s = kubernetes.kubernetes
1561
2160
  try:
1562
- _, current_context = k8s.config.list_kube_config_contexts()
2161
+ _, current_context = kubernetes.list_kube_config_contexts()
1563
2162
  return current_context['name']
1564
2163
  except k8s.config.config_exception.ConfigException:
1565
2164
  return None
@@ -1595,7 +2194,7 @@ def get_all_kube_context_names() -> List[str]:
1595
2194
  k8s = kubernetes.kubernetes
1596
2195
  context_names = []
1597
2196
  try:
1598
- all_contexts, _ = k8s.config.list_kube_config_contexts()
2197
+ all_contexts, _ = kubernetes.list_kube_config_contexts()
1599
2198
  # all_contexts will always have at least one context. If kubeconfig
1600
2199
  # does not have any contexts defined, it will raise ConfigException.
1601
2200
  context_names = [context['name'] for context in all_contexts]
@@ -1638,7 +2237,7 @@ def get_kube_config_context_namespace(
1638
2237
  return f.read().strip()
1639
2238
  # If not in-cluster, get the namespace from kubeconfig
1640
2239
  try:
1641
- contexts, current_context = k8s.config.list_kube_config_contexts()
2240
+ contexts, current_context = kubernetes.list_kube_config_contexts()
1642
2241
  if context_name is None:
1643
2242
  context = current_context
1644
2243
  else:
@@ -1655,6 +2254,15 @@ def get_kube_config_context_namespace(
1655
2254
  return DEFAULT_NAMESPACE
1656
2255
 
1657
2256
 
2257
+ def parse_cpu_or_gpu_resource_to_float(resource_str: str) -> float:
2258
+ if not resource_str:
2259
+ return 0.0
2260
+ if resource_str[-1] == 'm':
2261
+ return float(resource_str[:-1]) / 1000
2262
+ else:
2263
+ return float(resource_str)
2264
+
2265
+
1658
2266
  def parse_cpu_or_gpu_resource(resource_qty_str: str) -> Union[int, float]:
1659
2267
  resource_str = str(resource_qty_str)
1660
2268
  if resource_str[-1] == 'm':
@@ -1758,9 +2366,7 @@ class KubernetesInstanceType:
1758
2366
  accelerator_type = match.group('accelerator_type')
1759
2367
  if accelerator_count:
1760
2368
  accelerator_count = int(accelerator_count)
1761
- # This is to revert the accelerator types with spaces back to
1762
- # the original format.
1763
- accelerator_type = str(accelerator_type).replace('_', ' ')
2369
+ accelerator_type = str(accelerator_type)
1764
2370
  else:
1765
2371
  accelerator_count = None
1766
2372
  accelerator_type = None
@@ -1837,16 +2443,14 @@ def construct_ssh_jump_command(
1837
2443
 
1838
2444
 
1839
2445
  def get_ssh_proxy_command(
1840
- k8s_ssh_target: str,
1841
- network_mode: kubernetes_enums.KubernetesNetworkingMode,
2446
+ pod_name: str,
1842
2447
  private_key_path: str,
1843
2448
  context: Optional[str],
1844
2449
  namespace: str,
1845
2450
  ) -> str:
1846
2451
  """Generates the SSH proxy command to connect to the pod.
1847
2452
 
1848
- Uses a jump pod if the network mode is NODEPORT, and direct port-forwarding
1849
- if the network mode is PORTFORWARD.
2453
+ Uses a direct port-forwarding.
1850
2454
 
1851
2455
  By default, establishing an SSH connection creates a communication
1852
2456
  channel to a remote node by setting up a TCP connection. When a
@@ -1857,17 +2461,8 @@ def get_ssh_proxy_command(
1857
2461
  Pods within a Kubernetes cluster have internal IP addresses that are
1858
2462
  typically not accessible from outside the cluster. Since the default TCP
1859
2463
  connection of SSH won't allow access to these pods, we employ a
1860
- ProxyCommand to establish the required communication channel. We offer this
1861
- in two different networking options: NodePort/port-forward.
1862
-
1863
- With the NodePort networking mode, a NodePort service is launched. This
1864
- service opens an external port on the node which redirects to the desired
1865
- port to a SSH jump pod. When establishing an SSH session in this mode, the
1866
- ProxyCommand makes use of this external port to create a communication
1867
- channel directly to port 22, which is the default port ssh server listens
1868
- on, of the jump pod.
2464
+ ProxyCommand to establish the required communication channel.
1869
2465
 
1870
- With Port-forward mode, instead of directly exposing an external port,
1871
2466
  'kubectl port-forward' sets up a tunnel between a local port
1872
2467
  (127.0.0.1:23100) and port 22 of the provisioned pod. Then we establish TCP
1873
2468
  connection to the local end of this tunnel, 127.0.0.1:23100, using 'socat'.
@@ -1878,38 +2473,26 @@ def get_ssh_proxy_command(
1878
2473
  the local machine.
1879
2474
 
1880
2475
  Args:
1881
- k8s_ssh_target: str; The Kubernetes object that will be used as the
1882
- target for SSH. If network_mode is NODEPORT, this is the name of the
1883
- service. If network_mode is PORTFORWARD, this is the pod name.
1884
- network_mode: KubernetesNetworkingMode; networking mode for ssh
1885
- session. It is either 'NODEPORT' or 'PORTFORWARD'
2476
+ pod_name: str; The Kubernetes pod name that will be used as the
2477
+ target for SSH.
1886
2478
  private_key_path: str; Path to the private key to use for SSH.
1887
2479
  This key must be authorized to access the SSH jump pod.
1888
- Required for NODEPORT networking mode.
1889
2480
  namespace: Kubernetes namespace to use.
1890
- Required for NODEPORT networking mode.
1891
2481
  """
1892
- # Fetch IP to connect to for the jump svc
1893
- ssh_jump_ip = get_external_ip(network_mode, context)
2482
+ ssh_jump_ip = '127.0.0.1' # Local end of the port-forward tunnel
1894
2483
  assert private_key_path is not None, 'Private key path must be provided'
1895
- if network_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
1896
- assert namespace is not None, 'Namespace must be provided for NodePort'
1897
- ssh_jump_port = get_port(k8s_ssh_target, namespace, context)
1898
- ssh_jump_proxy_command = construct_ssh_jump_command(
1899
- private_key_path, ssh_jump_ip, ssh_jump_port=ssh_jump_port)
1900
- else:
1901
- ssh_jump_proxy_command_path = create_proxy_command_script()
1902
- ssh_jump_proxy_command = construct_ssh_jump_command(
1903
- private_key_path,
1904
- ssh_jump_ip,
1905
- ssh_jump_user=constants.SKY_SSH_USER_PLACEHOLDER,
1906
- proxy_cmd_path=ssh_jump_proxy_command_path,
1907
- proxy_cmd_target_pod=k8s_ssh_target,
1908
- # We embed both the current context and namespace to the SSH proxy
1909
- # command to make sure SSH still works when the current
1910
- # context/namespace is changed by the user.
1911
- current_kube_context=context,
1912
- current_kube_namespace=namespace)
2484
+ ssh_jump_proxy_command_path = create_proxy_command_script()
2485
+ ssh_jump_proxy_command = construct_ssh_jump_command(
2486
+ private_key_path,
2487
+ ssh_jump_ip,
2488
+ ssh_jump_user=constants.SKY_SSH_USER_PLACEHOLDER,
2489
+ proxy_cmd_path=ssh_jump_proxy_command_path,
2490
+ proxy_cmd_target_pod=pod_name,
2491
+ # We embed both the current context and namespace to the SSH proxy
2492
+ # command to make sure SSH still works when the current
2493
+ # context/namespace is changed by the user.
2494
+ current_kube_context=context,
2495
+ current_kube_namespace=namespace)
1913
2496
  return ssh_jump_proxy_command
1914
2497
 
1915
2498
 
@@ -1941,240 +2524,6 @@ def create_proxy_command_script() -> str:
1941
2524
  return PORT_FORWARD_PROXY_CMD_PATH
1942
2525
 
1943
2526
 
1944
- def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str,
1945
- context: Optional[str],
1946
- service_type: kubernetes_enums.KubernetesServiceType):
1947
- """Sets up Kubernetes service resource to access for SSH jump pod.
1948
-
1949
- This method acts as a necessary complement to be run along with
1950
- setup_ssh_jump_pod(...) method. This service ensures the pod is accessible.
1951
-
1952
- Args:
1953
- ssh_jump_name: Name to use for the SSH jump service
1954
- namespace: Namespace to create the SSH jump service in
1955
- service_type: Networking configuration on either to use NodePort
1956
- or ClusterIP service to ssh in
1957
- """
1958
- # Fill in template - ssh_key_secret and ssh_jump_image are not required for
1959
- # the service spec, so we pass in empty strs.
1960
- content = fill_ssh_jump_template('', '', ssh_jump_name, service_type.value)
1961
-
1962
- # Add custom metadata from config
1963
- merge_custom_metadata(content['service_spec']['metadata'])
1964
-
1965
- # Create service
1966
- try:
1967
- kubernetes.core_api(context).create_namespaced_service(
1968
- namespace, content['service_spec'])
1969
- except kubernetes.api_exception() as e:
1970
- # SSH Jump Pod service already exists.
1971
- if e.status == 409:
1972
- ssh_jump_service = kubernetes.core_api(
1973
- context).read_namespaced_service(name=ssh_jump_name,
1974
- namespace=namespace)
1975
- curr_svc_type = ssh_jump_service.spec.type
1976
- if service_type.value == curr_svc_type:
1977
- # If the currently existing SSH Jump service's type is identical
1978
- # to user's configuration for networking mode
1979
- logger.debug(
1980
- f'SSH Jump Service {ssh_jump_name} already exists in the '
1981
- 'cluster, using it.')
1982
- else:
1983
- # If a different type of service type for SSH Jump pod compared
1984
- # to user's configuration for networking mode exists, we remove
1985
- # existing servie to create a new one following user's config
1986
- kubernetes.core_api(context).delete_namespaced_service(
1987
- name=ssh_jump_name, namespace=namespace)
1988
- kubernetes.core_api(context).create_namespaced_service(
1989
- namespace, content['service_spec'])
1990
- port_forward_mode = (
1991
- kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value)
1992
- nodeport_mode = (
1993
- kubernetes_enums.KubernetesNetworkingMode.NODEPORT.value)
1994
- clusterip_svc = (
1995
- kubernetes_enums.KubernetesServiceType.CLUSTERIP.value)
1996
- nodeport_svc = (
1997
- kubernetes_enums.KubernetesServiceType.NODEPORT.value)
1998
- curr_network_mode = port_forward_mode \
1999
- if curr_svc_type == clusterip_svc else nodeport_mode
2000
- new_network_mode = nodeport_mode \
2001
- if curr_svc_type == clusterip_svc else port_forward_mode
2002
- new_svc_type = nodeport_svc \
2003
- if curr_svc_type == clusterip_svc else clusterip_svc
2004
- logger.info(
2005
- f'Switching the networking mode from '
2006
- f'\'{curr_network_mode}\' to \'{new_network_mode}\' '
2007
- f'following networking configuration. Deleting existing '
2008
- f'\'{curr_svc_type}\' service and recreating as '
2009
- f'\'{new_svc_type}\' service.')
2010
- else:
2011
- raise
2012
- else:
2013
- logger.info(f'Created SSH Jump Service {ssh_jump_name}.')
2014
-
2015
-
2016
- def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
2017
- ssh_key_secret: str, namespace: str,
2018
- context: Optional[str]):
2019
- """Sets up Kubernetes RBAC and pod for SSH jump host.
2020
-
2021
- Our Kubernetes implementation uses a SSH jump pod to reach SkyPilot clusters
2022
- running inside a cluster. This function sets up the resources needed for
2023
- the SSH jump pod. This includes a service account which grants the jump pod
2024
- permission to watch for other SkyPilot pods and terminate itself if there
2025
- are no SkyPilot pods running.
2026
-
2027
- setup_ssh_jump_service must also be run to ensure that the SSH jump pod is
2028
- reachable.
2029
-
2030
- Args:
2031
- ssh_jump_image: Container image to use for the SSH jump pod
2032
- ssh_jump_name: Name to use for the SSH jump pod
2033
- ssh_key_secret: Secret name for the SSH key stored in the cluster
2034
- namespace: Namespace to create the SSH jump pod in
2035
- """
2036
- # Fill in template - service is created separately so service_type is not
2037
- # required, so we pass in empty str.
2038
- content = fill_ssh_jump_template(ssh_key_secret, ssh_jump_image,
2039
- ssh_jump_name, '')
2040
-
2041
- # Add custom metadata to all objects
2042
- for object_type in content.keys():
2043
- merge_custom_metadata(content[object_type]['metadata'])
2044
-
2045
- # ServiceAccount
2046
- try:
2047
- kubernetes.core_api(context).create_namespaced_service_account(
2048
- namespace, content['service_account'])
2049
- except kubernetes.api_exception() as e:
2050
- if e.status == 409:
2051
- logger.info(
2052
- 'SSH Jump ServiceAccount already exists in the cluster, using '
2053
- 'it.')
2054
- else:
2055
- raise
2056
- else:
2057
- logger.info('Created SSH Jump ServiceAccount.')
2058
- # Role
2059
- try:
2060
- kubernetes.auth_api(context).create_namespaced_role(
2061
- namespace, content['role'])
2062
- except kubernetes.api_exception() as e:
2063
- if e.status == 409:
2064
- logger.info(
2065
- 'SSH Jump Role already exists in the cluster, using it.')
2066
- else:
2067
- raise
2068
- else:
2069
- logger.info('Created SSH Jump Role.')
2070
- # RoleBinding
2071
- try:
2072
- kubernetes.auth_api(context).create_namespaced_role_binding(
2073
- namespace, content['role_binding'])
2074
- except kubernetes.api_exception() as e:
2075
- if e.status == 409:
2076
- logger.info(
2077
- 'SSH Jump RoleBinding already exists in the cluster, using '
2078
- 'it.')
2079
- else:
2080
- raise
2081
- else:
2082
- logger.info('Created SSH Jump RoleBinding.')
2083
- # Pod
2084
- try:
2085
- kubernetes.core_api(context).create_namespaced_pod(
2086
- namespace, content['pod_spec'])
2087
- except kubernetes.api_exception() as e:
2088
- if e.status == 409:
2089
- logger.info(
2090
- f'SSH Jump Host {ssh_jump_name} already exists in the cluster, '
2091
- 'using it.')
2092
- else:
2093
- raise
2094
- else:
2095
- logger.info(f'Created SSH Jump Host {ssh_jump_name}.')
2096
-
2097
-
2098
- def clean_zombie_ssh_jump_pod(namespace: str, context: Optional[str],
2099
- node_id: str):
2100
- """Analyzes SSH jump pod and removes if it is in a bad state
2101
-
2102
- Prevents the existence of a dangling SSH jump pod. This could happen
2103
- in case the pod main container did not start properly (or failed). In that
2104
- case, jump pod lifecycle manager will not function properly to
2105
- remove the pod and service automatically, and must be done manually.
2106
-
2107
- Args:
2108
- namespace: Namespace to remove the SSH jump pod and service from
2109
- node_id: Name of head pod
2110
- """
2111
-
2112
- def find(l, predicate):
2113
- """Utility function to find element in given list"""
2114
- results = [x for x in l if predicate(x)]
2115
- return results[0] if results else None
2116
-
2117
- # Get the SSH jump pod name from the head pod
2118
- try:
2119
- pod = kubernetes.core_api(context).read_namespaced_pod(
2120
- node_id, namespace)
2121
- except kubernetes.api_exception() as e:
2122
- if e.status == 404:
2123
- logger.warning(f'Failed to get pod {node_id},'
2124
- ' but the pod was not found (404).')
2125
- raise
2126
- else:
2127
- ssh_jump_name = pod.metadata.labels.get('skypilot-ssh-jump')
2128
- try:
2129
- ssh_jump_pod = kubernetes.core_api(context).read_namespaced_pod(
2130
- ssh_jump_name, namespace)
2131
- cont_ready_cond = find(ssh_jump_pod.status.conditions,
2132
- lambda c: c.type == 'ContainersReady')
2133
- if (cont_ready_cond and cont_ready_cond.status
2134
- == 'False') or ssh_jump_pod.status.phase == 'Pending':
2135
- # Either the main container is not ready or the pod failed
2136
- # to schedule. To be on the safe side and prevent a dangling
2137
- # ssh jump pod, lets remove it and the service. Otherwise, main
2138
- # container is ready and its lifecycle management script takes
2139
- # care of the cleaning.
2140
- kubernetes.core_api(context).delete_namespaced_pod(
2141
- ssh_jump_name, namespace)
2142
- kubernetes.core_api(context).delete_namespaced_service(
2143
- ssh_jump_name, namespace)
2144
- except kubernetes.api_exception() as e:
2145
- # We keep the warning in debug to avoid polluting the `sky launch`
2146
- # output.
2147
- logger.debug(f'Tried to check ssh jump pod {ssh_jump_name},'
2148
- f' but got error {e}\n. Consider running `kubectl '
2149
- f'delete pod {ssh_jump_name} -n {namespace}` to manually '
2150
- 'remove the pod if it has crashed.')
2151
- # We encountered an issue while checking ssh jump pod. To be on
2152
- # the safe side, lets remove its service so the port is freed
2153
- try:
2154
- kubernetes.core_api(context).delete_namespaced_service(
2155
- ssh_jump_name, namespace)
2156
- except kubernetes.api_exception():
2157
- pass
2158
-
2159
-
2160
- def fill_ssh_jump_template(ssh_key_secret: str, ssh_jump_image: str,
2161
- ssh_jump_name: str, service_type: str) -> Dict:
2162
- template_path = os.path.join(sky.__root_dir__, 'templates',
2163
- 'kubernetes-ssh-jump.yml.j2')
2164
- if not os.path.exists(template_path):
2165
- raise FileNotFoundError(
2166
- 'Template "kubernetes-ssh-jump.j2" does not exist.')
2167
- with open(template_path, 'r', encoding='utf-8') as fin:
2168
- template = fin.read()
2169
- j2_template = jinja2.Template(template)
2170
- cont = j2_template.render(name=ssh_jump_name,
2171
- image=ssh_jump_image,
2172
- secret=ssh_key_secret,
2173
- service_type=service_type)
2174
- content = yaml.safe_load(cont)
2175
- return content
2176
-
2177
-
2178
2527
  def check_port_forward_mode_dependencies(
2179
2528
  raise_error: bool = True) -> Optional[List[str]]:
2180
2529
  """Checks if 'socat' and 'nc' are installed
@@ -2252,7 +2601,7 @@ def check_port_forward_mode_dependencies(
2252
2601
  return None
2253
2602
 
2254
2603
 
2255
- def get_endpoint_debug_message() -> str:
2604
+ def get_endpoint_debug_message(context: Optional[str] = None) -> str:
2256
2605
  """ Returns a string message for user to debug Kubernetes port opening
2257
2606
 
2258
2607
  Polls the configured ports mode on Kubernetes to produce an
@@ -2260,7 +2609,7 @@ def get_endpoint_debug_message() -> str:
2260
2609
 
2261
2610
  Also checks if the
2262
2611
  """
2263
- port_mode = network_utils.get_port_mode()
2612
+ port_mode = network_utils.get_port_mode(None, context)
2264
2613
  if port_mode == kubernetes_enums.KubernetesPortMode.INGRESS:
2265
2614
  endpoint_type = 'Ingress'
2266
2615
  debug_cmd = 'kubectl describe ingress && kubectl describe ingressclass'
@@ -2275,9 +2624,11 @@ def get_endpoint_debug_message() -> str:
2275
2624
 
2276
2625
 
2277
2626
  def combine_pod_config_fields(
2278
- cluster_yaml_path: str,
2627
+ cluster_yaml_obj: Dict[str, Any],
2279
2628
  cluster_config_overrides: Dict[str, Any],
2280
- ) -> None:
2629
+ cloud: Optional[clouds.Cloud] = None,
2630
+ context: Optional[str] = None,
2631
+ ) -> Dict[str, Any]:
2281
2632
  """Adds or updates fields in the YAML with fields from the
2282
2633
  ~/.sky/config.yaml's kubernetes.pod_spec dict.
2283
2634
  This can be used to add fields to the YAML that are not supported by
@@ -2316,72 +2667,138 @@ def combine_pod_config_fields(
2316
2667
  - name: my-secret
2317
2668
  ```
2318
2669
  """
2319
- with open(cluster_yaml_path, 'r', encoding='utf-8') as f:
2320
- yaml_content = f.read()
2321
- yaml_obj = yaml.safe_load(yaml_content)
2322
- # We don't use override_configs in `skypilot_config.get_nested`, as merging
2670
+ merged_cluster_yaml_obj = copy.deepcopy(cluster_yaml_obj)
2671
+ # We don't use override_configs in `get_effective_region_config`, as merging
2323
2672
  # the pod config requires special handling.
2324
- kubernetes_config = skypilot_config.get_nested(('kubernetes', 'pod_config'),
2325
- default_value={},
2326
- override_configs={})
2327
- override_pod_config = (cluster_config_overrides.get('kubernetes', {}).get(
2328
- 'pod_config', {}))
2673
+ cloud_str = 'ssh' if isinstance(cloud, clouds.SSH) else 'kubernetes'
2674
+ context_str = context
2675
+ if isinstance(cloud, clouds.SSH) and context is not None:
2676
+ assert context.startswith('ssh-'), 'SSH context must start with "ssh-"'
2677
+ context_str = context[len('ssh-'):]
2678
+ kubernetes_config = skypilot_config.get_effective_region_config(
2679
+ cloud=cloud_str,
2680
+ region=context_str,
2681
+ keys=('pod_config',),
2682
+ default_value={})
2683
+ override_pod_config = config_utils.get_cloud_config_value_from_dict(
2684
+ dict_config=cluster_config_overrides,
2685
+ cloud=cloud_str,
2686
+ region=context_str,
2687
+ keys=('pod_config',),
2688
+ default_value={})
2329
2689
  config_utils.merge_k8s_configs(kubernetes_config, override_pod_config)
2330
2690
 
2331
2691
  # Merge the kubernetes config into the YAML for both head and worker nodes.
2332
2692
  config_utils.merge_k8s_configs(
2333
- yaml_obj['available_node_types']['ray_head_default']['node_config'],
2334
- kubernetes_config)
2335
-
2336
- # Write the updated YAML back to the file
2337
- common_utils.dump_yaml(cluster_yaml_path, yaml_obj)
2693
+ merged_cluster_yaml_obj['available_node_types']['ray_head_default']
2694
+ ['node_config'], kubernetes_config)
2695
+ return merged_cluster_yaml_obj
2338
2696
 
2339
2697
 
2340
- def combine_metadata_fields(cluster_yaml_path: str) -> None:
2698
+ def combine_metadata_fields(cluster_yaml_obj: Dict[str, Any],
2699
+ cluster_config_overrides: Dict[str, Any],
2700
+ context: Optional[str] = None) -> Dict[str, Any]:
2341
2701
  """Updates the metadata for all Kubernetes objects created by SkyPilot with
2342
2702
  fields from the ~/.sky/config.yaml's kubernetes.custom_metadata dict.
2343
2703
 
2344
2704
  Obeys the same add or update semantics as combine_pod_config_fields().
2345
2705
  """
2346
-
2347
- with open(cluster_yaml_path, 'r', encoding='utf-8') as f:
2348
- yaml_content = f.read()
2349
- yaml_obj = yaml.safe_load(yaml_content)
2350
- custom_metadata = skypilot_config.get_nested(
2351
- ('kubernetes', 'custom_metadata'), {})
2706
+ merged_cluster_yaml_obj = copy.deepcopy(cluster_yaml_obj)
2707
+ context, cloud_str = get_cleaned_context_and_cloud_str(context)
2708
+
2709
+ # Get custom_metadata from global config
2710
+ custom_metadata = skypilot_config.get_effective_region_config(
2711
+ cloud=cloud_str,
2712
+ region=context,
2713
+ keys=('custom_metadata',),
2714
+ default_value={})
2715
+
2716
+ # Get custom_metadata from task-level config overrides
2717
+ override_custom_metadata = config_utils.get_cloud_config_value_from_dict(
2718
+ dict_config=cluster_config_overrides,
2719
+ cloud=cloud_str,
2720
+ region=context,
2721
+ keys=('custom_metadata',),
2722
+ default_value={})
2723
+
2724
+ # Merge task-level overrides with global config
2725
+ config_utils.merge_k8s_configs(custom_metadata, override_custom_metadata)
2352
2726
 
2353
2727
  # List of objects in the cluster YAML to be updated
2354
2728
  combination_destinations = [
2355
2729
  # Service accounts
2356
- yaml_obj['provider']['autoscaler_service_account']['metadata'],
2357
- yaml_obj['provider']['autoscaler_role']['metadata'],
2358
- yaml_obj['provider']['autoscaler_role_binding']['metadata'],
2359
- yaml_obj['provider']['autoscaler_service_account']['metadata'],
2360
- # Pod spec
2361
- yaml_obj['available_node_types']['ray_head_default']['node_config']
2730
+ merged_cluster_yaml_obj['provider']['autoscaler_service_account']
2731
+ ['metadata'],
2732
+ merged_cluster_yaml_obj['provider']['autoscaler_role']['metadata'],
2733
+ merged_cluster_yaml_obj['provider']['autoscaler_role_binding']
2362
2734
  ['metadata'],
2735
+ merged_cluster_yaml_obj['provider']['autoscaler_service_account']
2736
+ ['metadata'],
2737
+ # Pod spec
2738
+ merged_cluster_yaml_obj['available_node_types']['ray_head_default']
2739
+ ['node_config']['metadata'],
2363
2740
  # Services for pods
2364
- *[svc['metadata'] for svc in yaml_obj['provider']['services']]
2741
+ *[
2742
+ svc['metadata']
2743
+ for svc in merged_cluster_yaml_obj['provider']['services']
2744
+ ]
2365
2745
  ]
2366
2746
 
2367
2747
  for destination in combination_destinations:
2368
2748
  config_utils.merge_k8s_configs(destination, custom_metadata)
2369
2749
 
2370
- # Write the updated YAML back to the file
2371
- common_utils.dump_yaml(cluster_yaml_path, yaml_obj)
2750
+ return merged_cluster_yaml_obj
2751
+
2372
2752
 
2753
+ def combine_pod_config_fields_and_metadata(
2754
+ cluster_yaml_obj: Dict[str, Any],
2755
+ cluster_config_overrides: Dict[str, Any],
2756
+ cloud: Optional[clouds.Cloud] = None,
2757
+ context: Optional[str] = None) -> Dict[str, Any]:
2758
+ """Combines pod config fields and metadata fields"""
2759
+ combined_yaml_obj = combine_pod_config_fields(cluster_yaml_obj,
2760
+ cluster_config_overrides,
2761
+ cloud, context)
2762
+ combined_yaml_obj = combine_metadata_fields(combined_yaml_obj,
2763
+ cluster_config_overrides,
2764
+ context)
2765
+ return combined_yaml_obj
2373
2766
 
2374
- def merge_custom_metadata(original_metadata: Dict[str, Any]) -> None:
2767
+
2768
+ def merge_custom_metadata(
2769
+ original_metadata: Dict[str, Any],
2770
+ context: Optional[str] = None,
2771
+ cluster_config_overrides: Optional[Dict[str, Any]] = None) -> None:
2375
2772
  """Merges original metadata with custom_metadata from config
2376
2773
 
2377
2774
  Merge is done in-place, so return is not required
2378
2775
  """
2379
- custom_metadata = skypilot_config.get_nested(
2380
- ('kubernetes', 'custom_metadata'), {})
2776
+ context, cloud_str = get_cleaned_context_and_cloud_str(context)
2777
+
2778
+ # Get custom_metadata from global config
2779
+ custom_metadata = skypilot_config.get_effective_region_config(
2780
+ cloud=cloud_str,
2781
+ region=context,
2782
+ keys=('custom_metadata',),
2783
+ default_value={})
2784
+
2785
+ # Get custom_metadata from task-level config overrides if available
2786
+ if cluster_config_overrides is not None:
2787
+ override_custom_metadata = config_utils.get_cloud_config_value_from_dict(
2788
+ dict_config=cluster_config_overrides,
2789
+ cloud=cloud_str,
2790
+ region=context,
2791
+ keys=('custom_metadata',),
2792
+ default_value={})
2793
+ # Merge task-level overrides with global config
2794
+ config_utils.merge_k8s_configs(custom_metadata,
2795
+ override_custom_metadata)
2796
+
2381
2797
  config_utils.merge_k8s_configs(original_metadata, custom_metadata)
2382
2798
 
2383
2799
 
2384
- def check_nvidia_runtime_class(context: Optional[str] = None) -> bool:
2800
+ @_retry_on_error(resource_type='runtimeclass')
2801
+ def check_nvidia_runtime_class(*, context: Optional[str] = None) -> bool:
2385
2802
  """Checks if the 'nvidia' RuntimeClass exists in the cluster"""
2386
2803
  # Fetch the list of available RuntimeClasses
2387
2804
  runtime_classes = kubernetes.node_api(context).list_runtime_class()
@@ -2431,7 +2848,7 @@ def create_namespace(namespace: str, context: Optional[str]) -> None:
2431
2848
  return
2432
2849
 
2433
2850
  ns_metadata = dict(name=namespace, labels={'parent': 'skypilot'})
2434
- merge_custom_metadata(ns_metadata)
2851
+ merge_custom_metadata(ns_metadata, context)
2435
2852
  namespace_obj = kubernetes_client.V1Namespace(metadata=ns_metadata)
2436
2853
  try:
2437
2854
  kubernetes.core_api(context).create_namespace(namespace_obj)
@@ -2457,15 +2874,14 @@ def get_head_pod_name(cluster_name_on_cloud: str):
2457
2874
  return f'{cluster_name_on_cloud}-head'
2458
2875
 
2459
2876
 
2460
- def get_autoscaler_type(
2461
- ) -> Optional[kubernetes_enums.KubernetesAutoscalerType]:
2462
- """Returns the autoscaler type by reading from config"""
2463
- autoscaler_type = skypilot_config.get_nested(('kubernetes', 'autoscaler'),
2464
- None)
2465
- if autoscaler_type is not None:
2466
- autoscaler_type = kubernetes_enums.KubernetesAutoscalerType(
2467
- autoscaler_type)
2468
- return autoscaler_type
2877
+ def get_custom_config_k8s_contexts() -> List[str]:
2878
+ """Returns the list of context names from the config"""
2879
+ contexts = skypilot_config.get_effective_region_config(
2880
+ cloud='kubernetes',
2881
+ region=None,
2882
+ keys=('context_configs',),
2883
+ default_value={})
2884
+ return [*contexts] or []
2469
2885
 
2470
2886
 
2471
2887
  # Mapping of known spot label keys and values for different cluster types
@@ -2477,6 +2893,21 @@ SPOT_LABEL_MAP = {
2477
2893
  }
2478
2894
 
2479
2895
 
2896
+ def get_autoscaler_type(
2897
+ context: Optional[str] = None
2898
+ ) -> Optional[kubernetes_enums.KubernetesAutoscalerType]:
2899
+ """Returns the autoscaler type by reading from config"""
2900
+ autoscaler_type = skypilot_config.get_effective_region_config(
2901
+ cloud='kubernetes',
2902
+ region=context,
2903
+ keys=('autoscaler',),
2904
+ default_value=None)
2905
+ if autoscaler_type is not None:
2906
+ autoscaler_type = kubernetes_enums.KubernetesAutoscalerType(
2907
+ autoscaler_type)
2908
+ return autoscaler_type
2909
+
2910
+
2480
2911
  def get_spot_label(
2481
2912
  context: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
2482
2913
  """Get the spot label key and value for using spot instances, if supported.
@@ -2500,7 +2931,7 @@ def get_spot_label(
2500
2931
 
2501
2932
  # Check if autoscaler is configured. Allow spot instances if autoscaler type
2502
2933
  # is known to support spot instances.
2503
- autoscaler_type = get_autoscaler_type()
2934
+ autoscaler_type = get_autoscaler_type(context=context)
2504
2935
  if autoscaler_type == kubernetes_enums.KubernetesAutoscalerType.GKE:
2505
2936
  return SPOT_LABEL_MAP[autoscaler_type.value]
2506
2937
 
@@ -2542,7 +2973,7 @@ def get_unlabeled_accelerator_nodes(context: Optional[str] = None) -> List[Any]:
2542
2973
  nodes = get_kubernetes_nodes(context=context)
2543
2974
  nodes_with_accelerator = []
2544
2975
  for node in nodes:
2545
- if get_gpu_resource_key() in node.status.capacity:
2976
+ if get_gpu_resource_key(context) in node.status.capacity:
2546
2977
  nodes_with_accelerator.append(node)
2547
2978
 
2548
2979
  label_formatter, _ = detect_gpu_label_formatter(context)
@@ -2586,14 +3017,6 @@ def get_kubernetes_node_info(
2586
3017
  information.
2587
3018
  """
2588
3019
  nodes = get_kubernetes_nodes(context=context)
2589
- # Get the pods to get the real-time resource usage
2590
- try:
2591
- pods = get_all_pods_in_kubernetes_cluster(context=context)
2592
- except kubernetes.api_exception() as e:
2593
- if e.status == 403:
2594
- pods = None
2595
- else:
2596
- raise
2597
3020
 
2598
3021
  lf, _ = detect_gpu_label_formatter(context)
2599
3022
  if not lf:
@@ -2601,6 +3024,29 @@ def get_kubernetes_node_info(
2601
3024
  else:
2602
3025
  label_keys = lf.get_label_keys()
2603
3026
 
3027
+ # Check if all nodes have no accelerators to avoid fetching pods
3028
+ has_accelerator_nodes = False
3029
+ for node in nodes:
3030
+ accelerator_count = get_node_accelerator_count(context,
3031
+ node.status.allocatable)
3032
+ if accelerator_count > 0:
3033
+ has_accelerator_nodes = True
3034
+ break
3035
+
3036
+ # Get the allocated GPU quantity by each node
3037
+ allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
3038
+ error_on_get_allocated_gpu_qty_by_node = False
3039
+ if has_accelerator_nodes:
3040
+ try:
3041
+ allocated_qty_by_node = get_allocated_gpu_qty_by_node(
3042
+ context=context)
3043
+ except kubernetes.api_exception() as e:
3044
+ if e.status == 403:
3045
+ error_on_get_allocated_gpu_qty_by_node = True
3046
+ pass
3047
+ else:
3048
+ raise
3049
+
2604
3050
  node_info_dict: Dict[str, models.KubernetesNodeInfo] = {}
2605
3051
  has_multi_host_tpu = False
2606
3052
 
@@ -2615,24 +3061,36 @@ def get_kubernetes_node_info(
2615
3061
  node.metadata.labels.get(label_key))
2616
3062
  break
2617
3063
 
2618
- allocated_qty = 0
2619
- accelerator_count = get_node_accelerator_count(node.status.allocatable)
3064
+ # Extract IP address from node addresses (prefer external, fallback to internal)
3065
+ node_ip = None
3066
+ if node.status.addresses:
3067
+ # First try to find external IP
3068
+ for address in node.status.addresses:
3069
+ if address.type == 'ExternalIP':
3070
+ node_ip = address.address
3071
+ break
3072
+ # If no external IP, try to find internal IP
3073
+ if node_ip is None:
3074
+ for address in node.status.addresses:
3075
+ if address.type == 'InternalIP':
3076
+ node_ip = address.address
3077
+ break
3078
+
3079
+ accelerator_count = get_node_accelerator_count(context,
3080
+ node.status.allocatable)
3081
+ if accelerator_count == 0:
3082
+ node_info_dict[node.metadata.name] = models.KubernetesNodeInfo(
3083
+ name=node.metadata.name,
3084
+ accelerator_type=accelerator_name,
3085
+ total={'accelerator_count': 0},
3086
+ free={'accelerators_available': 0},
3087
+ ip_address=node_ip)
3088
+ continue
2620
3089
 
2621
- if pods is None:
3090
+ if not has_accelerator_nodes or error_on_get_allocated_gpu_qty_by_node:
2622
3091
  accelerators_available = -1
2623
-
2624
3092
  else:
2625
- for pod in pods:
2626
- # Get all the pods running on the node
2627
- if (pod.spec.node_name == node.metadata.name and
2628
- pod.status.phase in ['Running', 'Pending']):
2629
- # Iterate over all the containers in the pod and sum the
2630
- # GPU requests
2631
- for container in pod.spec.containers:
2632
- if container.resources.requests:
2633
- allocated_qty += get_node_accelerator_count(
2634
- container.resources.requests)
2635
-
3093
+ allocated_qty = allocated_qty_by_node[node.metadata.name]
2636
3094
  accelerators_available = accelerator_count - allocated_qty
2637
3095
 
2638
3096
  # Exclude multi-host TPUs from being processed.
@@ -2646,7 +3104,8 @@ def get_kubernetes_node_info(
2646
3104
  name=node.metadata.name,
2647
3105
  accelerator_type=accelerator_name,
2648
3106
  total={'accelerator_count': int(accelerator_count)},
2649
- free={'accelerators_available': int(accelerators_available)})
3107
+ free={'accelerators_available': int(accelerators_available)},
3108
+ ip_address=node_ip)
2650
3109
  hint = ''
2651
3110
  if has_multi_host_tpu:
2652
3111
  hint = ('(Note: Multi-host TPUs are detected and excluded from the '
@@ -2678,7 +3137,11 @@ def filter_pods(namespace: str,
2678
3137
  context: Optional[str],
2679
3138
  tag_filters: Dict[str, str],
2680
3139
  status_filters: Optional[List[str]] = None) -> Dict[str, Any]:
2681
- """Filters pods by tags and status."""
3140
+ """Filters pods by tags and status.
3141
+
3142
+ Returned dict is sorted by name, with workers sorted by their numeric suffix.
3143
+ This ensures consistent ordering for SSH configuration and other operations.
3144
+ """
2682
3145
  non_included_pod_statuses = POD_STATUSES.copy()
2683
3146
 
2684
3147
  field_selector = ''
@@ -2696,7 +3159,32 @@ def filter_pods(namespace: str,
2696
3159
  pods = [
2697
3160
  pod for pod in pod_list.items if pod.metadata.deletion_timestamp is None
2698
3161
  ]
2699
- return {pod.metadata.name: pod for pod in pods}
3162
+
3163
+ # Sort pods by name, with workers sorted by their numeric suffix.
3164
+ # This ensures consistent ordering (e.g., cluster-head, cluster-worker1,
3165
+ # cluster-worker2, cluster-worker3, ...) even when Kubernetes API
3166
+ # returns them in arbitrary order. This works even if there were
3167
+ # somehow pod names other than head/worker ones, and those end up at
3168
+ # the end of the list.
3169
+ def get_pod_sort_key(
3170
+ pod: V1Pod
3171
+ ) -> Union[Tuple[Literal[0], str], Tuple[Literal[1], int], Tuple[Literal[2],
3172
+ str]]:
3173
+ name = pod.metadata.name
3174
+ name_suffix = name.split('-')[-1]
3175
+ if name_suffix == 'head':
3176
+ return (0, name)
3177
+ elif name_suffix.startswith('worker'):
3178
+ try:
3179
+ return (1, int(name_suffix.split('worker')[-1]))
3180
+ except (ValueError, IndexError):
3181
+ return (2, name)
3182
+ else:
3183
+ return (2, name)
3184
+
3185
+ sorted_pods = sorted(pods, key=get_pod_sort_key)
3186
+
3187
+ return {pod.metadata.name: pod for pod in sorted_pods}
2700
3188
 
2701
3189
 
2702
3190
  def _remove_pod_annotation(pod: Any,
@@ -2763,7 +3251,7 @@ def set_autodown_annotations(handle: 'backends.CloudVmRayResourceHandle',
2763
3251
  tags = {
2764
3252
  provision_constants.TAG_RAY_CLUSTER_NAME: handle.cluster_name_on_cloud,
2765
3253
  }
2766
- ray_config = common_utils.read_yaml(handle.cluster_yaml)
3254
+ ray_config = global_user_state.get_cluster_yaml_dict(handle.cluster_yaml)
2767
3255
  provider_config = ray_config['provider']
2768
3256
  namespace = get_namespace_from_config(provider_config)
2769
3257
  context = get_context_from_config(provider_config)
@@ -2805,8 +3293,8 @@ def get_context_from_config(provider_config: Dict[str, Any]) -> Optional[str]:
2805
3293
  context = provider_config.get('context',
2806
3294
  get_current_kube_config_context_name())
2807
3295
  if context == kubernetes.in_cluster_context_name():
2808
- # If the context (also used as the region) is in-cluster, we need to
2809
- # we need to use in-cluster auth by setting the context to None.
3296
+ # If the context (also used as the region) is in-cluster, we need
3297
+ # to use in-cluster auth by setting the context to None.
2810
3298
  context = None
2811
3299
  return context
2812
3300
 
@@ -2825,23 +3313,27 @@ def get_skypilot_pods(context: Optional[str] = None) -> List[Any]:
2825
3313
 
2826
3314
  try:
2827
3315
  pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
2828
- label_selector='skypilot-cluster',
3316
+ label_selector=provision_constants.TAG_SKYPILOT_CLUSTER_NAME,
2829
3317
  _request_timeout=kubernetes.API_TIMEOUT).items
2830
3318
  except kubernetes.max_retry_error():
2831
3319
  raise exceptions.ResourcesUnavailableError(
2832
3320
  'Timed out trying to get SkyPilot pods from Kubernetes cluster. '
2833
3321
  'Please check if the cluster is healthy and retry. To debug, run: '
2834
- 'kubectl get pods --selector=skypilot-cluster --all-namespaces'
3322
+ 'kubectl get pods --selector=skypilot-cluster-name --all-namespaces'
2835
3323
  ) from None
2836
3324
  return pods
2837
3325
 
2838
3326
 
2839
- def is_tpu_on_gke(accelerator: str) -> bool:
3327
+ def is_tpu_on_gke(accelerator: str, normalize: bool = True) -> bool:
2840
3328
  """Determines if the given accelerator is a TPU supported on GKE."""
3329
+ if normalize:
3330
+ normalized, _ = normalize_tpu_accelerator_name(accelerator)
3331
+ return normalized in GKE_TPU_ACCELERATOR_TO_GENERATION
2841
3332
  return accelerator in GKE_TPU_ACCELERATOR_TO_GENERATION
2842
3333
 
2843
3334
 
2844
- def get_node_accelerator_count(attribute_dict: dict) -> int:
3335
+ def get_node_accelerator_count(context: Optional[str],
3336
+ attribute_dict: dict) -> int:
2845
3337
  """Retrieves the count of accelerators from a node's resource dictionary.
2846
3338
 
2847
3339
  This method checks the node's allocatable resources or the accelerators
@@ -2856,7 +3348,7 @@ def get_node_accelerator_count(attribute_dict: dict) -> int:
2856
3348
  Number of accelerators allocated or available from the node. If no
2857
3349
  resource is found, it returns 0.
2858
3350
  """
2859
- gpu_resource_name = get_gpu_resource_key()
3351
+ gpu_resource_name = get_gpu_resource_key(context)
2860
3352
  assert not (gpu_resource_name in attribute_dict and
2861
3353
  TPU_RESOURCE_KEY in attribute_dict)
2862
3354
  if gpu_resource_name in attribute_dict:
@@ -2964,7 +3456,8 @@ def process_skypilot_pods(
2964
3456
  serve_controllers: List[KubernetesSkyPilotClusterInfo] = []
2965
3457
 
2966
3458
  for pod in pods:
2967
- cluster_name_on_cloud = pod.metadata.labels.get('skypilot-cluster')
3459
+ cluster_name_on_cloud = pod.metadata.labels.get(
3460
+ provision_constants.TAG_SKYPILOT_CLUSTER_NAME)
2968
3461
  cluster_name = cluster_name_on_cloud.rsplit(
2969
3462
  '-', 1
2970
3463
  )[0] # Remove the user hash to get cluster name (e.g., mycluster-2ea4)
@@ -2982,7 +3475,7 @@ def process_skypilot_pods(
2982
3475
  unit='G')
2983
3476
  gpu_count = parse_cpu_or_gpu_resource(
2984
3477
  pod.spec.containers[0].resources.requests.get(
2985
- 'nvidia.com/gpu', '0'))
3478
+ get_gpu_resource_key(context), '0'))
2986
3479
  gpu_name = None
2987
3480
  if gpu_count > 0:
2988
3481
  label_formatter, _ = (detect_gpu_label_formatter(context))
@@ -2991,9 +3484,20 @@ def process_skypilot_pods(
2991
3484
  f'requesting GPUs: {pod.metadata.name}')
2992
3485
  gpu_label = label_formatter.get_label_key()
2993
3486
  # Get GPU name from pod node selector
2994
- if pod.spec.node_selector is not None:
2995
- gpu_name = label_formatter.get_accelerator_from_label_value(
2996
- pod.spec.node_selector.get(gpu_label))
3487
+ node_selector_terms = (
3488
+ pod.spec.affinity.node_affinity.
3489
+ required_during_scheduling_ignored_during_execution.
3490
+ node_selector_terms)
3491
+ if node_selector_terms is not None:
3492
+ expressions = []
3493
+ for term in node_selector_terms:
3494
+ if term.match_expressions:
3495
+ expressions.extend(term.match_expressions)
3496
+ for expression in expressions:
3497
+ if expression.key == gpu_label and expression.operator == 'In':
3498
+ gpu_name = label_formatter.get_accelerator_from_label_value(
3499
+ expression.values[0])
3500
+ break
2997
3501
 
2998
3502
  resources = resources_lib.Resources(
2999
3503
  cloud=clouds.Kubernetes(),
@@ -3037,33 +3541,216 @@ def process_skypilot_pods(
3037
3541
  return list(clusters.values()), jobs_controllers, serve_controllers
3038
3542
 
3039
3543
 
3040
- def get_gpu_resource_key():
3041
- """Get the GPU resource name to use in kubernetes.
3042
- The function first checks for an environment variable.
3043
- If defined, it uses its value; otherwise, it returns the default value.
3044
- Args:
3045
- name (str): Default GPU resource name, default is "nvidia.com/gpu".
3544
+ def _gpu_resource_key_helper(context: Optional[str]) -> str:
3545
+ """Helper function to get the GPU resource key."""
3546
+ gpu_resource_key = SUPPORTED_GPU_RESOURCE_KEYS['nvidia']
3547
+ try:
3548
+ nodes = kubernetes.core_api(context).list_node().items
3549
+ for gpu_key in SUPPORTED_GPU_RESOURCE_KEYS.values():
3550
+ if any(gpu_key in node.status.capacity for node in nodes):
3551
+ return gpu_key
3552
+ except Exception as e: # pylint: disable=broad-except
3553
+ logger.warning(f'Failed to load kube config or query nodes: {e}. '
3554
+ 'Falling back to default GPU resource key.')
3555
+ return gpu_resource_key
3556
+
3557
+
3558
+ @annotations.lru_cache(scope='request')
3559
+ def get_gpu_resource_key(context: Optional[str] = None) -> str:
3560
+ """Get the GPU resource name to use in Kubernetes.
3561
+
3562
+ The function auto-detects the GPU resource key by querying the Kubernetes node API.
3563
+ If detection fails, it falls back to a default value.
3564
+ An environment variable can override the detected or default value.
3565
+
3046
3566
  Returns:
3047
3567
  str: The selected GPU resource name.
3048
3568
  """
3049
- # Retrieve GPU resource name from environment variable, if set.
3050
- # Else use default.
3051
- # E.g., can be nvidia.com/gpu-h100, amd.com/gpu etc.
3052
- return os.getenv('CUSTOM_GPU_RESOURCE_KEY', default=GPU_RESOURCE_KEY)
3569
+ gpu_resource_key = _gpu_resource_key_helper(context)
3570
+ return os.getenv('CUSTOM_GPU_RESOURCE_KEY', default=gpu_resource_key)
3053
3571
 
3054
3572
 
3055
- def _get_kubeconfig_path() -> str:
3056
- """Get the path to the kubeconfig file.
3573
+ def get_kubeconfig_paths() -> List[str]:
3574
+ """Get the path to the kubeconfig files.
3057
3575
  Parses `KUBECONFIG` env var if present, else uses the default path.
3058
- Currently, specifying multiple KUBECONFIG paths in the envvar is not
3059
- allowed, hence will raise a ValueError.
3060
3576
  """
3061
- kubeconfig_path = os.path.expanduser(
3062
- os.getenv(
3063
- 'KUBECONFIG', kubernetes.kubernetes.config.kube_config.
3064
- KUBE_CONFIG_DEFAULT_LOCATION))
3065
- if len(kubeconfig_path.split(os.pathsep)) > 1:
3066
- raise ValueError('SkyPilot currently only supports one '
3067
- 'config file path with $KUBECONFIG. Current '
3068
- f'path(s) are {kubeconfig_path}.')
3069
- return kubeconfig_path
3577
+ # We should always use the latest KUBECONFIG environment variable to
3578
+ # make sure env var overrides get respected.
3579
+ paths = os.getenv('KUBECONFIG', kubernetes.DEFAULT_KUBECONFIG_PATH)
3580
+ expanded = []
3581
+ for path in paths.split(kubernetes.ENV_KUBECONFIG_PATH_SEPARATOR):
3582
+ expanded.append(os.path.expanduser(path))
3583
+ return expanded
3584
+
3585
+
3586
+ def format_kubeconfig_exec_auth(config: Any,
3587
+ output_path: str,
3588
+ inject_wrapper: bool = True) -> bool:
3589
+ """Reformat the kubeconfig so that exec-based authentication can be used
3590
+ with SkyPilot. Will create a new kubeconfig file under <output_path>
3591
+ regardless of whether a change has been made.
3592
+
3593
+ kubectl internally strips all environment variables except for system
3594
+ defaults. If `inject_wrapper` is true, a wrapper executable is applied
3595
+ to inject the relevant PATH information before exec-auth is executed.
3596
+
3597
+ Contents of sky-kube-exec-wrapper:
3598
+
3599
+ #!/bin/bash
3600
+ export PATH="$HOME/skypilot-runtime/bin:$HOME/google-cloud-sdk:$PATH"
3601
+ exec "$@"
3602
+
3603
+ refer to `skylet/constants.py` for more information.
3604
+
3605
+ Args:
3606
+ config (dict): kubeconfig parsed by yaml.safe_load
3607
+ output_path (str): Path where the potentially modified kubeconfig file
3608
+ will be saved
3609
+ inject_wrapper (bool): Whether to inject the wrapper script
3610
+ Returns: whether config was updated, for logging purposes
3611
+ """
3612
+ updated = False
3613
+ for user in config.get('users', []):
3614
+ exec_info = user.get('user', {}).get('exec', {})
3615
+ current_command = exec_info.get('command', '')
3616
+
3617
+ if current_command:
3618
+ # Strip the path and keep only the executable name
3619
+ executable = os.path.basename(current_command)
3620
+ if executable == kubernetes_constants.SKY_K8S_EXEC_AUTH_WRAPPER:
3621
+ # we don't want this happening recursively.
3622
+ continue
3623
+
3624
+ if inject_wrapper:
3625
+ exec_info[
3626
+ 'command'] = kubernetes_constants.SKY_K8S_EXEC_AUTH_WRAPPER
3627
+ if exec_info.get('args') is None:
3628
+ exec_info['args'] = []
3629
+ exec_info['args'].insert(0, executable)
3630
+ updated = True
3631
+ elif executable != current_command:
3632
+ exec_info['command'] = executable
3633
+ updated = True
3634
+
3635
+ # Handle Nebius kubeconfigs: change --profile to 'sky'
3636
+ if executable == 'nebius':
3637
+ args = exec_info.get('args', [])
3638
+ if args and '--profile' in args:
3639
+ try:
3640
+ profile_index = args.index('--profile')
3641
+ if profile_index + 1 < len(args):
3642
+ old_profile = args[profile_index + 1]
3643
+ if old_profile != 'sky':
3644
+ args[profile_index + 1] = 'sky'
3645
+ updated = True
3646
+ except ValueError:
3647
+ pass
3648
+
3649
+ os.makedirs(os.path.dirname(os.path.expanduser(output_path)), exist_ok=True)
3650
+ with open(output_path, 'w', encoding='utf-8') as file:
3651
+ yaml.safe_dump(config, file)
3652
+
3653
+ return updated
3654
+
3655
+
3656
+ def format_kubeconfig_exec_auth_with_cache(kubeconfig_path: str) -> str:
3657
+ """Reformat the kubeconfig file or retrieve it from cache if it has already
3658
+ been formatted before. Store it in the cache directory if necessary.
3659
+
3660
+ Having a cache for this is good if users spawn an extreme number of jobs
3661
+ concurrently.
3662
+
3663
+ Args:
3664
+ kubeconfig_path (str): kubeconfig path
3665
+ Returns: updated kubeconfig path
3666
+ """
3667
+ # TODO(kyuds): GC cache files
3668
+ with open(kubeconfig_path, 'r', encoding='utf-8') as file:
3669
+ config = yaml_utils.safe_load(file)
3670
+ normalized = yaml.dump(config, sort_keys=True)
3671
+ hashed = hashlib.sha1(normalized.encode('utf-8')).hexdigest()
3672
+ path = os.path.expanduser(
3673
+ f'{kubernetes_constants.SKY_K8S_EXEC_AUTH_KUBECONFIG_CACHE}/{hashed}.yaml'
3674
+ )
3675
+
3676
+ # If we have already converted the same kubeconfig before, just return.
3677
+ if os.path.isfile(path):
3678
+ return path
3679
+
3680
+ try:
3681
+ format_kubeconfig_exec_auth(config, path)
3682
+ return path
3683
+ except Exception as e: # pylint: disable=broad-except
3684
+ # There may be problems with kubeconfig, but the user is not actually
3685
+ # using Kubernetes (or SSH Node Pools)
3686
+ logger.warning(
3687
+ f'Failed to format kubeconfig at {kubeconfig_path}. '
3688
+ 'Please check if the kubeconfig is valid. This may cause '
3689
+ 'problems when Kubernetes infra is used. '
3690
+ f'Reason: {common_utils.format_exception(e)}')
3691
+ return kubeconfig_path
3692
+
3693
+
3694
+ def delete_k8s_resource_with_retry(delete_func: Callable, resource_type: str,
3695
+ resource_name: str) -> None:
3696
+ """Helper to delete Kubernetes resources with 404 handling and retries.
3697
+
3698
+ Args:
3699
+ delete_func: Function to call to delete the resource
3700
+ resource_type: Type of resource being deleted (e.g. 'service'),
3701
+ used in logging
3702
+ resource_name: Name of the resource being deleted, used in logging
3703
+ """
3704
+ max_retries = 3
3705
+ retry_delay = 5 # seconds
3706
+
3707
+ for attempt in range(max_retries):
3708
+ try:
3709
+ delete_func()
3710
+ return
3711
+ except kubernetes.api_exception() as e:
3712
+ if e.status == 404:
3713
+ logger.warning(
3714
+ f'terminate_instances: Tried to delete {resource_type} '
3715
+ f'{resource_name}, but the {resource_type} was not '
3716
+ 'found (404).')
3717
+ return
3718
+ elif attempt < max_retries - 1:
3719
+ logger.warning(f'terminate_instances: Failed to delete '
3720
+ f'{resource_type} {resource_name} (attempt '
3721
+ f'{attempt + 1}/{max_retries}). Error: {e}. '
3722
+ f'Retrying in {retry_delay} seconds...')
3723
+ time.sleep(retry_delay)
3724
+ else:
3725
+ raise
3726
+
3727
+
3728
+ def should_exclude_pod_from_gpu_allocation(pod) -> bool:
3729
+ """Check if a pod should be excluded from GPU count calculations.
3730
+
3731
+ Some cloud providers run low priority test/verification pods that request
3732
+ GPUs but should not count against real GPU availability since they are
3733
+ designed to be evicted when higher priority workloads need resources.
3734
+
3735
+ Args:
3736
+ pod: Kubernetes pod object
3737
+
3738
+ Returns:
3739
+ bool: True if the pod should be excluded from GPU count calculations.
3740
+ """
3741
+ # CoreWeave HPC verification pods - identified by namespace
3742
+ if (hasattr(pod.metadata, 'namespace') and
3743
+ pod.metadata.namespace == 'cw-hpc-verification'):
3744
+ return True
3745
+
3746
+ return False
3747
+
3748
+
3749
+ def get_cleaned_context_and_cloud_str(
3750
+ context: Optional[str]) -> Tuple[Optional[str], str]:
3751
+ """Return the cleaned context and relevant cloud string from a context."""
3752
+ cloud_str = 'kubernetes'
3753
+ if context is not None and context.startswith('ssh-'):
3754
+ cloud_str = 'ssh'
3755
+ context = context[len('ssh-'):]
3756
+ return context, cloud_str