skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
sky/provision/kubernetes/utils.py

@@ -1,9 +1,14 @@
 """Kubernetes utilities for SkyPilot."""
+import dataclasses
+import functools
 import json
 import math
 import os
 import re
+import shutil
 import subprocess
+import time
+import typing
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 from urllib.parse import urlparse
 
@@ -11,17 +16,30 @@ import jinja2
 import yaml
 
 import sky
+from sky import clouds
 from sky import exceptions
+from sky import models
 from sky import sky_logging
 from sky import skypilot_config
 from sky.adaptors import kubernetes
+from sky.provision import constants as provision_constants
+from sky.provision.kubernetes import constants as kubernetes_constants
 from sky.provision.kubernetes import network_utils
+from sky.skylet import constants
+from sky.utils import annotations
 from sky.utils import common_utils
+from sky.utils import config_utils
 from sky.utils import env_options
 from sky.utils import kubernetes_enums
 from sky.utils import schemas
+from sky.utils import status_lib
+from sky.utils import timeline
 from sky.utils import ux_utils
 
+if typing.TYPE_CHECKING:
+    from sky import backends
+    from sky import resources as resources_lib
+
 # TODO(romilb): Move constants to constants.py
 DEFAULT_NAMESPACE = 'default'
 
@@ -35,10 +53,18 @@ MEMORY_SIZE_UNITS = {
     'T': 2**40,
     'P': 2**50,
 }
-NO_GPU_HELP_MESSAGE = ('If your cluster contains GPUs, make sure '
-                       'nvidia.com/gpu resource is available on the nodes and '
-                       'the node labels for identifying GPUs '
-                       '(e.g., skypilot.co/accelerator) are setup correctly. ')
+
+# The resource keys used by Kubernetes to track NVIDIA GPUs and Google TPUs on
+# nodes. These keys are typically used in the node's status.allocatable
+# or status.capacity fields to indicate the available resources on the node.
+GPU_RESOURCE_KEY = 'nvidia.com/gpu'
+TPU_RESOURCE_KEY = 'google.com/tpu'
+
+NO_ACCELERATOR_HELP_MESSAGE = (
+    'If your cluster contains GPUs or TPUs, make sure '
+    f'{GPU_RESOURCE_KEY} or {TPU_RESOURCE_KEY} resource is available '
+    'on the nodes and the node labels for identifying GPUs/TPUs '
+    '(e.g., skypilot.co/accelerator) are setup correctly. ')
 
 KUBERNETES_AUTOSCALER_NOTE = (
     'Note: Kubernetes cluster autoscaling is enabled. '
@@ -53,8 +79,106 @@ ENDPOINTS_DEBUG_MESSAGE = ('Additionally, make sure your {endpoint_type} '
 
 KIND_CONTEXT_NAME = 'kind-skypilot'  # Context name used by sky local up
 
+# Port-forward proxy command constants
+PORT_FORWARD_PROXY_CMD_TEMPLATE = 'kubernetes-port-forward-proxy-command.sh'
+# We add a version suffix to the port-forward proxy command to ensure backward
+# compatibility and avoid overwriting the older version.
+PORT_FORWARD_PROXY_CMD_VERSION = 2
+PORT_FORWARD_PROXY_CMD_PATH = ('~/.sky/kubernetes-port-forward-proxy-command-'
+                               f'v{PORT_FORWARD_PROXY_CMD_VERSION}.sh')
+
+# Mapping used to get generation for TPU accelerator name.
+# https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#run
+GKE_TPU_ACCELERATOR_TO_GENERATION = {
+    'tpu-v4-podslice': 'v4',
+    # Only Single-host v5e TPU configurations are allowed.
+    'tpu-v5-lite-device': 'v5e',
+    # Multi-host compatible v5e TPU configurations allowed.
+    'tpu-v5-lite-podslice': 'v5e',
+    'tpu-v5p-slice': 'v5p',
+}
+
+POD_STATUSES = {
+    'Pending', 'Running', 'Succeeded', 'Failed', 'Unknown', 'Terminating'
+}
+AUTODOWN_ANNOTATION_KEY = 'skypilot.co/autodown'
+IDLE_MINUTES_TO_AUTOSTOP_ANNOTATION_KEY = (
+    'skypilot.co/idle_minutes_to_autostop')
+ANNOTATIONS_POD_NOT_FOUND_ERROR_MSG = ('Pod {pod_name} not found in namespace '
+                                       '{namespace} while trying to {action} '
+                                       'an annotation {annotation}.')
+
 logger = sky_logging.init_logger(__name__)
 
+# Default retry settings for Kubernetes API calls
+DEFAULT_MAX_RETRIES = 3
+DEFAULT_RETRY_INTERVAL_SECONDS = 1
+
+
+def _retry_on_error(max_retries=DEFAULT_MAX_RETRIES,
+                    retry_interval=DEFAULT_RETRY_INTERVAL_SECONDS,
+                    resource_type: Optional[str] = None):
+    """Decorator to retry Kubernetes API calls on transient failures.
+
+    Args:
+        max_retries: Maximum number of retry attempts
+        retry_interval: Initial seconds to wait between retries
+        resource_type: Type of resource being accessed (e.g. 'node', 'pod').
+            Used to provide more specific error messages.
+    """

+
+    def decorator(func):
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            last_exception = None
+            backoff = common_utils.Backoff(initial_backoff=retry_interval,
+                                           max_backoff_factor=3)
+
+            for attempt in range(max_retries):
+                try:
+                    return func(*args, **kwargs)
+                except (kubernetes.max_retry_error(),
+                        kubernetes.api_exception(),
+                        kubernetes.config_exception()) as e:
+                    last_exception = e
+                    # Don't retry on permanent errors like 401 (Unauthorized)
+                    # or 403 (Forbidden)
+                    if (isinstance(e, kubernetes.api_exception()) and
+                            e.status in (401, 403)):
+                        raise
+                    if attempt < max_retries - 1:
+                        sleep_time = backoff.current_backoff()
+                        logger.debug(f'Kubernetes API call {func.__name__} '
+                                     f'failed with {str(e)}. Retrying in '
+                                     f'{sleep_time:.1f}s...')
+                        time.sleep(sleep_time)
+                        continue
+
+            # Format error message based on the type of exception
+            resource_msg = f' when trying to get {resource_type} info' \
+                if resource_type else ''
+            debug_cmd = f' To debug, run: kubectl get {resource_type}s' \
+                if resource_type else ''
+
+            if isinstance(last_exception, kubernetes.max_retry_error()):
+                error_msg = f'Timed out{resource_msg} from Kubernetes cluster.'
+            elif isinstance(last_exception, kubernetes.api_exception()):
+                error_msg = (f'Kubernetes API error{resource_msg}: '
+                             f'{str(last_exception)}')
+            else:
+                error_msg = (f'Kubernetes configuration error{resource_msg}: '
+                             f'{str(last_exception)}')
+
+            raise exceptions.ResourcesUnavailableError(
+                f'{error_msg}'
+                f' Please check if the cluster is healthy and retry.'
+                f'{debug_cmd}') from last_exception
+
+        return wrapper
+
+    return decorator
+
 
 class GPULabelFormatter:
     """Base class to define a GPU label formatter for a Kubernetes cluster
@@ -65,15 +189,41 @@ class GPULabelFormatter:
     """
 
     @classmethod
-    def get_label_key(cls) -> str:
+    def get_tpu_topology_label_key(cls) -> str:
+        """Returns the label for TPU topology used by the Kubernetes cluster.
+
+        Only implemented by formatters that support TPUs.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_tpu_topology_label_value(cls, acc_type: str, acc_count: int) -> str:
+        """Returns the TPU topology value for the given TPU type and count.
+
+        Only implemented by formatters that support TPUs.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_label_key(cls, accelerator: Optional[str] = None) -> str:
         """Returns the label key for GPU type used by the Kubernetes cluster"""
         raise NotImplementedError
 
+    @classmethod
+    def get_label_keys(cls) -> List[str]:
+        """Returns a list of label keys for GPU used by Kubernetes cluster."""
+        raise NotImplementedError
+
     @classmethod
     def get_label_value(cls, accelerator: str) -> str:
         """Given a GPU type, returns the label value to be used"""
         raise NotImplementedError
 
+    @classmethod
+    def match_label_key(cls, label_key: str) -> bool:
+        """Checks if the given label key matches the formatter's label keys"""
+        raise NotImplementedError
+
     @classmethod
     def get_accelerator_from_label_value(cls, value: str) -> str:
         """Given a label value, returns the GPU type"""
@@ -95,14 +245,21 @@
 
 
 def get_gke_accelerator_name(accelerator: str) -> str:
-    """Returns the accelerator name for GKE clusters
+    """Returns the accelerator name for GKE clusters.
 
     Uses the format - nvidia-tesla-<accelerator>.
-    A100-80GB, H100-80GB and L4 are an exception. They use nvidia-<accelerator>.
+    A100-80GB, H100-80GB, L4 are an exception. They use nvidia-<accelerator>.
+    TPU types are an exception as well, keeping the given name.
     """
-    if accelerator in ('A100-80GB', 'L4', 'H100-80GB'):
-        # A100-80GB, L4 and H100-80GB have a different name pattern.
+    if accelerator == 'H100':
+        # H100 is named as H100-80GB in GKE.
+        accelerator = 'H100-80GB'
+    if accelerator in ('A100-80GB', 'L4', 'H100-80GB', 'H100-MEGA-80GB'):
+        # A100-80GB, L4, H100-80GB and H100-MEGA-80GB
+        # have a different name pattern.
         return 'nvidia-{}'.format(accelerator.lower())
+    elif accelerator.startswith('tpu-'):
+        return accelerator
     else:
         return 'nvidia-tesla-{}'.format(accelerator.lower())
 
@@ -117,15 +274,23 @@ class SkyPilotLabelFormatter(GPULabelFormatter):
     LABEL_KEY = 'skypilot.co/accelerator'
 
     @classmethod
-    def get_label_key(cls) -> str:
+    def get_label_key(cls, accelerator: Optional[str] = None) -> str:
         return cls.LABEL_KEY
 
+    @classmethod
+    def get_label_keys(cls) -> List[str]:
+        return [cls.LABEL_KEY]
+
     @classmethod
     def get_label_value(cls, accelerator: str) -> str:
         # For SkyPilot formatter, we use the accelerator str directly.
         # See sky.utils.kubernetes.gpu_labeler.
         return accelerator.lower()
 
+    @classmethod
+    def match_label_key(cls, label_key: str) -> bool:
+        return label_key == cls.LABEL_KEY
+
     @classmethod
     def get_accelerator_from_label_value(cls, value: str) -> str:
         return value.upper()
@@ -149,13 +314,21 @@ class CoreWeaveLabelFormatter(GPULabelFormatter):
     LABEL_KEY = 'gpu.nvidia.com/class'
 
     @classmethod
-    def get_label_key(cls) -> str:
+    def get_label_key(cls, accelerator: Optional[str] = None) -> str:
         return cls.LABEL_KEY
 
+    @classmethod
+    def get_label_keys(cls) -> List[str]:
+        return [cls.LABEL_KEY]
+
     @classmethod
     def get_label_value(cls, accelerator: str) -> str:
         return accelerator.upper()
 
+    @classmethod
+    def match_label_key(cls, label_key: str) -> bool:
+        return label_key == cls.LABEL_KEY
+
     @classmethod
     def get_accelerator_from_label_value(cls, value: str) -> str:
         return value
@@ -167,12 +340,67 @@ class GKELabelFormatter(GPULabelFormatter):
     GKE nodes by default are populated with `cloud.google.com/gke-accelerator`
     label, which is used to identify the GPU type.
     """
+    GPU_LABEL_KEY = 'cloud.google.com/gke-accelerator'
+    TPU_LABEL_KEY = 'cloud.google.com/gke-tpu-accelerator'
+    ACCELERATOR_COUNT_LABEL_KEY = 'cloud.google.com/gke-accelerator-count'
+    TPU_TOPOLOGY_LABEL_KEY = 'cloud.google.com/gke-tpu-topology'
+
+    # Mapping from TPU type to {count: topologies}. Used to determine topology
+    # label to use in an autoscaling environment. For list of topologies, see:
+    # tpu v5e: https://cloud.google.com/tpu/docs/tpus-in-gke
+    # tpu v5p: https://cloud.google.com/tpu/docs/v5p
+    # TODO(romilb): Add support for TPU v4 and v6.
+    GKE_TPU_TOPOLOGIES = {
+        'tpu-v5-lite-podslice': {
+            1: '1x1',
+            4: '2x2',
+            8: '2x4'
+        },
+        'tpu-v5-lite-device': {
+            1: '1x1',
+            4: '2x2',
+            8: '2x4'
+        },
+        'tpu-v5p-slice': {
+            4: '2x2x1'
+        },
+    }
 
-    LABEL_KEY = 'cloud.google.com/gke-accelerator'
+    @classmethod
+    def get_label_key(cls, accelerator: Optional[str] = None) -> str:
+        if accelerator is not None and accelerator.startswith('tpu-'):
+            return cls.TPU_LABEL_KEY
+        return cls.GPU_LABEL_KEY
 
     @classmethod
-    def get_label_key(cls) -> str:
-        return cls.LABEL_KEY
+    def get_label_keys(cls) -> List[str]:
+        return [cls.GPU_LABEL_KEY, cls.TPU_LABEL_KEY]
+
+    @classmethod
+    def match_label_key(cls, label_key: str) -> bool:
+        return label_key in cls.get_label_keys()
+
+    @classmethod
+    def get_tpu_topology_label_key(cls) -> str:
+        return cls.TPU_TOPOLOGY_LABEL_KEY
+
+    @classmethod
+    def get_tpu_topology_label_value(cls, acc_type: str, acc_count: int) -> str:
+        """Returns the TPU topology label value for the given TPU count.
+
+        e.g. tpu-v5-lite-podslice:8 -> '2x4'
+        """
+        count_to_topology = cls.GKE_TPU_TOPOLOGIES.get(acc_type,
+                                                       {}).get(acc_count, None)
+        if count_to_topology is None:
+            supported_tpus = {
+                tpu: list(topologies.values())
+                for tpu, topologies in cls.GKE_TPU_TOPOLOGIES.items()
+            }
+            raise ValueError(
+                f'No TPU topology found for {acc_type} with count {acc_count}. '
+                f'Supported TPU types and counts: {supported_tpus}')
+        return count_to_topology
 
     @classmethod
     def get_label_value(cls, accelerator: str) -> str:
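The GKE_TPU_TOPOLOGIES table above picks a topology label value (e.g. '2x4') for a requested chip count, and later in the diff a node's topology label is reduced back to a chip count via reduce_tpu_topology, which is not shown in this excerpt. A small sketch of the assumed round trip, taking the chip count to be the product of the topology dimensions:

from functools import reduce
from operator import mul

GKE_TPU_TOPOLOGIES = {
    'tpu-v5-lite-podslice': {1: '1x1', 4: '2x2', 8: '2x4'},
    'tpu-v5-lite-device': {1: '1x1', 4: '2x2', 8: '2x4'},
    'tpu-v5p-slice': {4: '2x2x1'},
}


def chip_count(topology: str) -> int:
    # Assumed behavior of reduce_tpu_topology: '2x4' -> 8, '2x2x1' -> 4.
    return reduce(mul, (int(dim) for dim in topology.split('x')))


assert GKE_TPU_TOPOLOGIES['tpu-v5-lite-podslice'][8] == '2x4'
assert chip_count('2x4') == 8
assert chip_count('2x2x1') == 4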
@@ -183,12 +411,85 @@ class GKELabelFormatter(GPULabelFormatter):
         if value.startswith('nvidia-tesla-'):
             return value.replace('nvidia-tesla-', '').upper()
         elif value.startswith('nvidia-'):
-            return value.replace('nvidia-', '').upper()
+            acc = value.replace('nvidia-', '').upper()
+            if acc == 'H100-80GB':
+                # H100 can be either H100-80GB or H100-MEGA-80GB in GKE
+                # we map H100 ---> H100-80GB and keep H100-MEGA-80GB
+                # to distinguish between a3-high and a3-mega instances
+                return 'H100'
+            return acc
+        elif is_tpu_on_gke(value):
+            return value
         else:
             raise ValueError(
                 f'Invalid accelerator name in GKE cluster: {value}')
 
 
+class GFDLabelFormatter(GPULabelFormatter):
+    """GPU Feature Discovery label formatter
+
+    NVIDIA GPUs nodes are labeled by GPU feature discovery
+    e.g. nvidia.com/gpu.product=NVIDIA-H100-80GB-HBM3
+    https://github.com/NVIDIA/gpu-feature-discovery
+
+    GPU feature discovery is included as part of the
+    NVIDIA GPU Operator:
+    https://docs.nvidia.com/datacenter/cloud-native/gpu-operator/latest/overview.html
+
+    This LabelFormatter can't be used in autoscaling clusters since accelerators
+    may map to multiple labels, so we're not implementing `get_label_value`.
+    """
+
+    LABEL_KEY = 'nvidia.com/gpu.product'
+
+    @classmethod
+    def get_label_key(cls, accelerator: Optional[str] = None) -> str:
+        return cls.LABEL_KEY
+
+    @classmethod
+    def get_label_keys(cls) -> List[str]:
+        return [cls.LABEL_KEY]
+
+    @classmethod
+    def get_label_value(cls, accelerator: str) -> str:
+        """An accelerator can map to many Nvidia GFD labels
+        (e.g., A100-80GB-PCIE vs. A100-SXM4-80GB).
+        As a result, we do not support get_label_value for GFDLabelFormatter."""
+        raise NotImplementedError
+
+    @classmethod
+    def match_label_key(cls, label_key: str) -> bool:
+        return label_key == cls.LABEL_KEY
+
+    @classmethod
+    def get_accelerator_from_label_value(cls, value: str) -> str:
+        """Searches against a canonical list of NVIDIA GPUs and pattern
+        matches the canonical GPU name against the GFD label.
+        """
+        canonical_gpu_names = [
+            'A100-80GB', 'A100', 'A10G', 'H100', 'K80', 'M60', 'T4g', 'T4',
+            'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L40', 'L4'
+        ]
+        for canonical_name in canonical_gpu_names:
+            # A100-80G accelerator is A100-SXM-80GB or A100-PCIE-80GB
+            if canonical_name == 'A100-80GB' and re.search(
+                    r'A100.*-80GB', value):
+                return canonical_name
+            # Use word boundary matching to prevent substring matches
+            elif re.search(rf'\b{re.escape(canonical_name)}\b', value):
+                return canonical_name
+
+        # If we didn't find a canonical name:
+        # 1. remove 'NVIDIA-' (e.g., 'NVIDIA-RTX-A6000' -> 'RTX-A6000')
+        # 2. remove 'GEFORCE-' (e.g., 'NVIDIA-GEFORCE-RTX-3070' -> 'RTX-3070')
+        # 3. remove 'RTX-' (e.g. 'RTX-6000' -> 'RTX6000')
+        # Same logic, but uppercased, as the SkyPilot labeler job found in
+        # sky/utils/kubernetes/k8s_gpu_labeler_setup.yaml
+        return value.upper().replace('NVIDIA-',
                                      '').replace('GEFORCE-',
                                                  '').replace('RTX-', 'RTX')
+
+
 class KarpenterLabelFormatter(SkyPilotLabelFormatter):
     """Karpenter label formatter
     Karpenter uses the label `karpenter.k8s.aws/instance-gpu-name` to identify
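The GFD formatter's reverse mapping can be exercised on its own. A trimmed sketch of the matching logic above, with a shortened GPU list and an illustrative helper name:

import re


def canonicalize(value: str) -> str:
    """Map a GFD nvidia.com/gpu.product value to a canonical GPU name."""
    names = ['A100-80GB', 'A100', 'H100', 'T4', 'V100', 'L4']
    for name in names:
        # A100-80GB appears as A100-SXM4-80GB or A100-PCIE-80GB.
        if name == 'A100-80GB' and re.search(r'A100.*-80GB', value):
            return name
        if re.search(rf'\b{re.escape(name)}\b', value):
            return name
    # Fallback: strip vendor prefixes, as the labeler job does.
    return value.upper().replace('NVIDIA-', '').replace('GEFORCE-',
                                                        '').replace(
                                                            'RTX-', 'RTX')


assert canonicalize('NVIDIA-H100-80GB-HBM3') == 'H100'
assert canonicalize('NVIDIA-A100-SXM4-80GB') == 'A100-80GB'
assert canonicalize('NVIDIA-GEFORCE-RTX-3070') == 'RTX3070'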
@@ -203,8 +504,8 @@ class KarpenterLabelFormatter(SkyPilotLabelFormatter):
 # it will be used to determine the priority of the label formats when
 # auto-detecting the GPU label type.
 LABEL_FORMATTER_REGISTRY = [
-    SkyPilotLabelFormatter, CoreWeaveLabelFormatter, GKELabelFormatter,
-    KarpenterLabelFormatter
+    SkyPilotLabelFormatter, GKELabelFormatter, KarpenterLabelFormatter,
+    GFDLabelFormatter, CoreWeaveLabelFormatter
 ]
 
 # Mapping of autoscaler type to label formatter
@@ -215,7 +516,9 @@ AUTOSCALER_TO_LABEL_FORMATTER = {
 }
 
 
+@annotations.lru_cache(scope='request')
 def detect_gpu_label_formatter(
+    context: Optional[str]
 ) -> Tuple[Optional[GPULabelFormatter], Dict[str, List[Tuple[str, str]]]]:
     """Detects the GPU label formatter for the Kubernetes cluster
 
@@ -226,7 +529,7 @@
     """
     # Get all labels across all nodes
     node_labels: Dict[str, List[Tuple[str, str]]] = {}
-    nodes = get_kubernetes_nodes()
+    nodes = get_kubernetes_nodes(context)
     for node in nodes:
         node_labels[node.metadata.name] = []
         for label, value in node.metadata.labels.items():
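detect_gpu_label_formatter and the other detection helpers are now memoized per server request with annotations.lru_cache(scope='request'), keyed on the context argument, so repeated failover checks do not re-scan the cluster. A rough standard-library analogue of that caching behavior, with SkyPilot's request-scoped expiry approximated by an explicit cache_clear and an illustrative context name:

import functools
from typing import Optional


@functools.lru_cache(maxsize=10)
def detect_for_context(context: Optional[str]) -> str:
    # Stand-in for an expensive Kubernetes API scan; runs once per
    # distinct context until the cache is cleared.
    print(f'scanning cluster for context {context!r}')
    return 'skypilot.co/accelerator'


detect_for_context('gke_my-project_us-central1_my-cluster')  # scans
detect_for_context('gke_my-project_us-central1_my-cluster')  # cached
detect_for_context.cache_clear()  # what request-scope expiry amounts to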
@@ -236,63 +539,72 @@
 
     # Check if the node labels contain any of the GPU label prefixes
     for lf in LABEL_FORMATTER_REGISTRY:
-        label_key = lf.get_label_key()
         for _, label_list in node_labels.items():
             for label, _ in label_list:
-                if label.startswith(label_key):
+                if lf.match_label_key(label):
                     label_formatter = lf()
                     return label_formatter, node_labels
 
     return label_formatter, node_labels
 
 
-def detect_gpu_resource() -> Tuple[bool, Set[str]]:
-    """Checks if the Kubernetes cluster has nvidia.com/gpu resource.
+@annotations.lru_cache(scope='request', maxsize=10)
+def detect_accelerator_resource(
+        context: Optional[str]) -> Tuple[bool, Set[str]]:
+    """Checks if the Kubernetes cluster has GPU/TPU resource.
 
-    If nvidia.com/gpu resource is missing, that typically means that the
-    Kubernetes cluster does not have GPUs or the nvidia GPU operator and/or
-    device drivers are not installed.
+    Two types of accelerator resources are available which are each checked
+    with nvidia.com/gpu and google.com/tpu. If nvidia.com/gpu resource is
+    missing, that typically means that the Kubernetes cluster does not have
+    GPUs or the nvidia GPU operator and/or device drivers are not installed.
 
     Returns:
-        bool: True if the cluster has nvidia.com/gpu resource, False otherwise.
+        bool: True if the cluster has GPU_RESOURCE_KEY or TPU_RESOURCE_KEY
+            resource, False otherwise.
     """
     # Get the set of resources across all nodes
     cluster_resources: Set[str] = set()
-    nodes = get_kubernetes_nodes()
+    nodes = get_kubernetes_nodes(context)
     for node in nodes:
         cluster_resources.update(node.status.allocatable.keys())
-    has_gpu = 'nvidia.com/gpu' in cluster_resources
+    has_accelerator = (get_gpu_resource_key() in cluster_resources or
+                       TPU_RESOURCE_KEY in cluster_resources)
 
-    return has_gpu, cluster_resources
+    return has_accelerator, cluster_resources
 
 
-def get_kubernetes_nodes() -> List[Any]:
-    # TODO(romilb): Calling kube API can take between 10-100ms depending on
-    # the control plane. Consider caching calls to this function (using
-    # kubecontext hash as key).
-    try:
-        nodes = kubernetes.core_api().list_node(
-            _request_timeout=kubernetes.API_TIMEOUT).items
-    except kubernetes.max_retry_error():
-        raise exceptions.ResourcesUnavailableError(
-            'Timed out when trying to get node info from Kubernetes cluster. '
-            'Please check if the cluster is healthy and retry.') from None
+@annotations.lru_cache(scope='request', maxsize=10)
+@_retry_on_error(resource_type='node')
+def get_kubernetes_nodes(context: Optional[str] = None) -> List[Any]:
+    """Gets the kubernetes nodes in the context.
+
+    If context is None, gets the nodes in the current context.
+    """
+    if context is None:
+        context = get_current_kube_config_context_name()
+
+    nodes = kubernetes.core_api(context).list_node(
+        _request_timeout=kubernetes.API_TIMEOUT).items
     return nodes
 
 
-def get_kubernetes_pods() -> List[Any]:
-    try:
-        ns = get_current_kube_config_context_namespace()
-        pods = kubernetes.core_api().list_namespaced_pod(
-            ns, _request_timeout=kubernetes.API_TIMEOUT).items
-    except kubernetes.max_retry_error():
-        raise exceptions.ResourcesUnavailableError(
-            'Timed out when trying to get pod info from Kubernetes cluster. '
-            'Please check if the cluster is healthy and retry.') from None
+@_retry_on_error(resource_type='pod')
+def get_all_pods_in_kubernetes_cluster(
+        context: Optional[str] = None) -> List[Any]:
+    """Gets pods in all namespaces in kubernetes cluster indicated by context.
+
+    Used for computing cluster resource usage.
+    """
+    if context is None:
+        context = get_current_kube_config_context_name()
+
+    pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
+        _request_timeout=kubernetes.API_TIMEOUT).items
     return pods
 
 
-def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
+def check_instance_fits(context: Optional[str],
+                        instance: str) -> Tuple[bool, Optional[str]]:
     """Checks if the instance fits on the Kubernetes cluster.
 
     If the instance has GPU requirements, checks if the GPU type is
@@ -307,6 +619,9 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
         Optional[str]: Error message if the instance does not fit.
     """
 
+    # TODO(zhwu): this should check the node for specific context, instead
+    # of the default context to make failover fully functional.
+
    def check_cpu_mem_fits(candidate_instance_type: 'KubernetesInstanceType',
                           node_list: List[Any]) -> Tuple[bool, Optional[str]]:
        """Checks if the instance fits on the cluster based on CPU and memory.
@@ -333,15 +648,53 @@
             'Maximum resources found on a single node: '
             f'{max_cpu} CPUs, {common_utils.format_float(max_mem)}G Memory')
 
-    nodes = get_kubernetes_nodes()
+    def check_tpu_fits(candidate_instance_type: 'KubernetesInstanceType',
+                       node_list: List[Any]) -> Tuple[bool, Optional[str]]:
+        """Checks if the instance fits on the cluster based on requested TPU.
+
+        It checks if the TPU type and count on each node match the required
+        number of TPU chips for the instance. In the case of multi-host TPU
+        podslice, the function ensures that the number of TPU chips on a single
+        node (node_tpu_chip_count) and the total TPU chips across the entire
+        podslice (topology_chip_count) are correctly handled.
+        """
+        acc_type = candidate_instance_type.accelerator_type
+        acc_count = candidate_instance_type.accelerator_count
+        tpu_list_in_cluster = []
+        for node in node_list:
+            if acc_type == node.metadata.labels[
+                    GKELabelFormatter.TPU_LABEL_KEY]:
+                # TODO(Doyoung): Update the logic when adding support for
+                # multi-host TPUs.
+                if is_multi_host_tpu(node.metadata.labels):
+                    continue
+                node_tpu_chip_count = int(node.metadata.labels[
+                    GKELabelFormatter.ACCELERATOR_COUNT_LABEL_KEY])
+                tpu_type = f'{acc_type}:{node_tpu_chip_count}'
+                tpu_list_in_cluster.append(tpu_type)
+                if node_tpu_chip_count == acc_count:
+                    return True, None
+        tpu_list_in_cluster_str = ','.join(tpu_list_in_cluster)
+        # TODO(Doyoung): Update the error message raised with the multi-host
+        # TPU support.
+        return False, ('Requested TPU type was not found in the cluster. TPU '
+                       'types found in the cluster: '
+                       f'{tpu_list_in_cluster_str}. Note that multi-host TPU '
+                       'podslices are currently not supported.')
+
+    nodes = get_kubernetes_nodes(context)
     k8s_instance_type = KubernetesInstanceType.\
         from_instance_type(instance)
     acc_type = k8s_instance_type.accelerator_type
+    acc_count = k8s_instance_type.accelerator_count
     if acc_type is not None:
-        # If GPUs are requested, check if GPU type is available, and if so,
-        # check if CPU and memory requirements on the specific node are met.
+        # If GPU/TPUs are requested, check if GPU/TPU type is available, and
+        # if so, check if CPU and memory requirements on the specific node are
+        # met.
+        assert acc_count is not None, (acc_type, acc_count)
         try:
-            gpu_label_key, gpu_label_val = get_gpu_label_key_value(acc_type)
+            gpu_label_key, gpu_label_val, _, _ = (
+                get_accelerator_label_key_value(context, acc_type, acc_count))
         except exceptions.ResourcesUnavailableError as e:
             # If GPU not found, return empty list and error message.
             return False, str(e)
@@ -350,14 +703,26 @@
             node for node in nodes if gpu_label_key in node.metadata.labels and
             node.metadata.labels[gpu_label_key] == gpu_label_val
         ]
-        assert len(gpu_nodes) > 0, 'GPU nodes not found'
+        assert gpu_nodes, 'GPU nodes not found'
+        if is_tpu_on_gke(acc_type):
+            # If requested accelerator is a TPU type, check if the cluster
+            # has sufficient TPU resource to meet the requirement.
+            fits, reason = check_tpu_fits(k8s_instance_type, gpu_nodes)
+            if reason is not None:
+                return fits, reason
+
         candidate_nodes = gpu_nodes
-        not_fit_reason_prefix = (f'GPU nodes with {acc_type} do not have '
-                                 'enough CPU and/or memory. ')
+        not_fit_reason_prefix = (
+            f'GPU nodes with {acc_type} do not have '
+            f'enough CPU (> {k8s_instance_type.cpus} CPUs) and/or '
+            f'memory (> {k8s_instance_type.memory} G). ')
     else:
         candidate_nodes = nodes
-        not_fit_reason_prefix = 'No nodes found with enough CPU and/or memory. '
-        # Check if CPU and memory requirements are met on at least one
+        not_fit_reason_prefix = (f'No nodes found with enough '
+                                 f'CPU (> {k8s_instance_type.cpus} CPUs) '
+                                 'and/or memory '
+                                 f'(> {k8s_instance_type.memory} G). ')
+    # Check if CPU and memory requirements are met on at least one
     # candidate node.
     fits, reason = check_cpu_mem_fits(k8s_instance_type, candidate_nodes)
     if not fits:
@@ -368,23 +733,33 @@
     return fits, reason
 
 
-def get_gpu_label_key_value(acc_type: str, check_mode=False) -> Tuple[str, str]:
-    """Returns the label key and value for the given GPU type.
+def get_accelerator_label_key_value(
+        context: Optional[str],
+        acc_type: str,
+        acc_count: int,
+        check_mode=False
+) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
+    """Returns the label key and value for the given GPU/TPU type.
 
     Args:
-        acc_type: The GPU type required by the task.
-        check_mode: If True, only checks if the cluster has GPU resources and
-            labels are setup on the cluster. acc_type is ignore does not return
-            the label key and value. Useful for checking if GPUs are configured
-            correctly on the cluster without explicitly requesting a acc_type.
+        acc_type: The GPU/TPU type required by the task.
+        acc_count: Number of GPU/TPUs required by the task.
+        check_mode: If True, only checks if the cluster has GPU/TPU resources
+            and labels are setup on the cluster. acc_type is ignored and does
+            not return the label key and value. Useful for checking if GPUs
+            are configured correctly on the cluster without explicitly
+            requesting an acc_type.
     Returns:
-        A tuple of the label key and value. Returns empty strings if check_mode
-        is True.
+        A tuple of the accelerator label key, value, topology label key, and
+        topology value. The topology label key and value are populated only if
+        the requested accelerator type is TPU. Returns None if check_mode is
+        True.
     Raises:
         ResourcesUnavailableError: Can be raised from the following conditions:
-            - The cluster does not have GPU resources (nvidia.com/gpu)
-            - The cluster does not have GPU labels setup correctly
-            - The cluster doesn't have any nodes with acc_type GPU
+            - The cluster does not have GPU/TPU resources
+                (nvidia.com/gpu, google.com/tpu)
+            - The cluster does not have GPU/TPU labels setup correctly
+            - The cluster doesn't have any nodes with acc_type GPU/TPU
     """
     # Check if the cluster has GPU resources
     # TODO(romilb): This assumes the accelerator is a nvidia GPU. We
@@ -403,23 +778,33 @@ def get_gpu_label_key_value(acc_type: str, check_mode=False) -> Tuple[str, str]:
         # If check mode is enabled and autoscaler is set, we can return
         # early since we assume the cluster autoscaler will handle GPU
         # node provisioning.
-        return '', ''
+        return None, None, None, None
        formatter = AUTOSCALER_TO_LABEL_FORMATTER.get(autoscaler_type)
        assert formatter is not None, ('Unsupported autoscaler type:'
                                       f' {autoscaler_type}')
-        return formatter.get_label_key(), formatter.get_label_value(acc_type)
-
-    has_gpus, cluster_resources = detect_gpu_resource()
+        tpu_topology_label_key = None
+        tpu_topology_label_value = None
+        if is_tpu_on_gke(acc_type):
+            assert formatter == GKELabelFormatter, formatter
+            tpu_topology_label_key = formatter.get_tpu_topology_label_key()
+            tpu_topology_label_value = formatter.get_tpu_topology_label_value(
+                acc_type, acc_count)
+        return formatter.get_label_key(acc_type), formatter.get_label_value(
+            acc_type), tpu_topology_label_key, tpu_topology_label_value
+
+    has_gpus, cluster_resources = detect_accelerator_resource(context)
     if has_gpus:
         # Check if the cluster has GPU labels setup correctly
         label_formatter, node_labels = \
-            detect_gpu_label_formatter()
+            detect_gpu_label_formatter(context)
         if label_formatter is None:
             # If none of the GPU labels from LABEL_FORMATTER_REGISTRY are
             # detected, raise error
             with ux_utils.print_exception_no_traceback():
-                supported_formats = ', '.join(
-                    [f.get_label_key() for f in LABEL_FORMATTER_REGISTRY])
+                supported_formats = ', '.join([
+                    key for f in LABEL_FORMATTER_REGISTRY
+                    for key in f.get_label_keys()
+                ])
                 suffix = ''
                 if env_options.Options.SHOW_DEBUG_INFO.get():
                     suffix = f' Found node labels: {node_labels}'
@@ -430,12 +815,12 @@
                     f'{supported_formats}. Please refer to '
                     'the documentation on how to set up node labels.'
                     f'{suffix}')
-        if label_formatter is not None:
+        else:
             # Validate the label value on all nodes labels to ensure they are
             # correctly setup and will behave as expected.
             for node_name, label_list in node_labels.items():
                 for label, value in label_list:
-                    if label == label_formatter.get_label_key():
+                    if label_formatter.match_label_key(label):
                         is_valid, reason = label_formatter.validate_label_value(
                             value)
                         if not is_valid:
@@ -445,9 +830,7 @@
             if check_mode:
                 # If check mode is enabled and we reached so far, we can
                 # conclude that the cluster is setup correctly and return.
-                return '', ''
-            k8s_acc_label_key = label_formatter.get_label_key()
-            k8s_acc_label_value = label_formatter.get_label_value(acc_type)
+                return None, None, None, None
             # Search in node_labels to see if any node has the requested
             # GPU type.
             # Note - this only checks if the label is available on a
@@ -455,12 +838,43 @@
             # quantity is available since that is dynamic and can change
             # during scheduling.
             for node_name, label_list in node_labels.items():
+                node_metadata_labels = dict(label_list)
+                # TODO(Doyoung): Update the logic when adding support for
+                # multi-host TPUs.
+                if is_multi_host_tpu(node_metadata_labels):
+                    continue
                 for label, value in label_list:
-                    if (label == k8s_acc_label_key and
-                            value == k8s_acc_label_value):
-                        # If a node is found, we can break out of the loop
-                        # and proceed to deploy.
-                        return k8s_acc_label_key, k8s_acc_label_value
+                    if (label_formatter.match_label_key(label) and
+                            label_formatter.get_accelerator_from_label_value(
+                                value) == acc_type):
+                        if is_tpu_on_gke(acc_type):
+                            assert isinstance(label_formatter,
+                                              GKELabelFormatter)
+                            if node_metadata_labels.get(
+                                    label_formatter.TPU_LABEL_KEY) == acc_type:
+                                topology_label_key = (
+                                    label_formatter.get_tpu_topology_label_key(
+                                    ))
+                                # Instead of using get_tpu_topology_label_value,
+                                # we use the node's label value to determine the
+                                # topology. This is to make sure the node's
+                                # available topology matches our request.
+                                topology_value = node_metadata_labels.get(
+                                    topology_label_key)
+                                assert topology_value is not None
+                                tpu_topology_chip_count = reduce_tpu_topology(
+                                    topology_value)
+                                # For single-host TPUs, there aren't multiple
+                                # different topologies that map to an identical
+                                # number of TPU chips.
+                                if tpu_topology_chip_count == acc_count:
+                                    return (label, value, topology_label_key,
+                                            topology_value)
+                                else:
+                                    continue
+                        else:
+                            return label, value, None, None
+
             # If no node is found with the requested acc_type, raise error
             with ux_utils.print_exception_no_traceback():
                 suffix = ''
@@ -468,15 +882,19 @@ def get_gpu_label_key_value(acc_type: str, check_mode=False) -> Tuple[str, str]:
468
882
  all_labels = []
469
883
  for node_name, label_list in node_labels.items():
470
884
  all_labels.extend(label_list)
471
- gpus_available = set(
472
- v for k, v in all_labels if k == k8s_acc_label_key)
473
- suffix = f' Available GPUs on the cluster: {gpus_available}'
885
+ acc_available = set(v for k, v in all_labels
886
+ if label_formatter.match_label_key(k))
887
+ suffix = (' Available GPU/TPUs on the cluster: '
888
+ f'{acc_available}')
889
+ # TODO(Doyoung): Update the error message raised with the
890
+ # multi-host TPU support.
474
891
  raise exceptions.ResourcesUnavailableError(
475
892
  'Could not find any node in the Kubernetes cluster '
476
- f'with {acc_type} GPU. Please ensure at least '
477
- f'one node in the cluster has {acc_type} GPU and node '
478
- 'labels are setup correctly. '
479
- f'Please refer to the documentation for more. {suffix}')
893
+ f'with {acc_type}. Please ensure at least one node in the '
894
+ f'cluster has {acc_type} and node labels are setup '
895
+ 'correctly. Please refer to the documentration for more. '
896
+ f'{suffix}. Note that multi-host TPU podslices are '
897
+ 'currently not unsupported.')
480
898
  else:
481
899
  # If GPU resources are not detected, raise error
482
900
  with ux_utils.print_exception_no_traceback():
@@ -485,55 +903,62 @@ def get_gpu_label_key_value(acc_type: str, check_mode=False) -> Tuple[str, str]:
485
903
  suffix = (' Available resources on the cluster: '
486
904
  f'{cluster_resources}')
487
905
  raise exceptions.ResourcesUnavailableError(
488
- 'Could not detect GPU resources (`nvidia.com/gpu`) in '
489
- 'Kubernetes cluster. If this cluster contains GPUs, please '
490
- 'ensure GPU drivers are installed on the node. Check if the '
491
- 'GPUs are setup correctly by running `kubectl describe nodes` '
492
- 'and looking for the nvidia.com/gpu resource. '
493
- 'Please refer to the documentation on how '
494
- f'to set up GPUs.{suffix}')
495
-
496
-
497
- def get_head_ssh_port(cluster_name: str, namespace: str) -> int:
906
+ f'Could not detect GPU/TPU resources ({GPU_RESOURCE_KEY!r} or '
907
+ f'{TPU_RESOURCE_KEY!r}) in Kubernetes cluster. If this cluster'
908
+ ' contains GPUs, please ensure GPU drivers are installed on '
909
+ 'the node. Check if the GPUs are setup correctly by running '
910
+ '`kubectl describe nodes` and looking for the '
911
+ f'{GPU_RESOURCE_KEY!r} or {TPU_RESOURCE_KEY!r} resource. '
912
+ 'Please refer to the documentation on how to set up GPUs.'
913
+ f'{suffix}')
914
+ assert False, 'This should not be reached'
915
+
916
+
917
+ def get_head_ssh_port(cluster_name: str, namespace: str,
918
+ context: Optional[str]) -> int:
498
919
  svc_name = f'{cluster_name}-head-ssh'
499
- return get_port(svc_name, namespace)
920
+ return get_port(svc_name, namespace, context)
500
921
 
501
922
 
502
- def get_port(svc_name: str, namespace: str) -> int:
923
+ def get_port(svc_name: str, namespace: str, context: Optional[str]) -> int:
503
924
  """Gets the nodeport of the specified service.
504
925
 
505
926
  Args:
506
927
  svc_name (str): Name of the kubernetes service. Note that this may be
507
928
  different from the cluster name.
508
929
  namespace (str): Kubernetes namespace to look for the service in.
930
+ context (str): Kubernetes context to use.
509
931
  """
510
- head_service = kubernetes.core_api().read_namespaced_service(
932
+ head_service = kubernetes.core_api(context).read_namespaced_service(
511
933
  svc_name, namespace)
512
934
  return head_service.spec.ports[0].node_port
513
935
 
514
936
 
515
- def get_external_ip(
516
- network_mode: Optional[kubernetes_enums.KubernetesNetworkingMode]):
937
+ def get_external_ip(network_mode: Optional[
938
+ kubernetes_enums.KubernetesNetworkingMode], context: Optional[str]) -> str:
517
939
  if network_mode == kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD:
518
940
  return '127.0.0.1'
519
941
  # Return the IP address of the first node with an external IP
520
- nodes = kubernetes.core_api().list_node().items
942
+ nodes = kubernetes.core_api(context).list_node().items
521
943
  for node in nodes:
522
944
  if node.status.addresses:
523
945
  for address in node.status.addresses:
524
946
  if address.type == 'ExternalIP':
525
947
  return address.address
526
948
  # If no external IP is found, use the API server IP
527
- api_host = kubernetes.core_api().api_client.configuration.host
949
+ api_host = kubernetes.core_api(context).api_client.configuration.host
528
950
  parsed_url = urlparse(api_host)
529
951
  return parsed_url.hostname
530
952
 
531
953
 
532
- def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
954
+ def check_credentials(context: Optional[str],
955
+ timeout: int = kubernetes.API_TIMEOUT) -> \
533
956
  Tuple[bool, Optional[str]]:
534
957
  """Check if the credentials in kubeconfig file are valid
535
958
 
536
959
  Args:
960
+ context (Optional[str]): The Kubernetes context to use. If none, uses
961
+ in-cluster auth to check credentials, if available.
537
962
  timeout (int): Timeout in seconds for the test API call
538
963
 
539
964
  Returns:
@@ -541,8 +966,9 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
541
966
  str: Error message if credentials are invalid, None otherwise
542
967
  """
543
968
  try:
544
- ns = get_current_kube_config_context_namespace()
545
- kubernetes.core_api().list_namespaced_pod(ns, _request_timeout=timeout)
969
+ namespace = get_kube_config_context_namespace(context)
970
+ kubernetes.core_api(context).list_namespaced_pod(
971
+ namespace, _request_timeout=timeout)
546
972
  except ImportError:
547
973
  # TODO(romilb): Update these error strs to also include link to docs
548
974
  # when docs are ready.
@@ -571,7 +997,7 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
571
997
  # We now do softer checks to check if exec based auth is used and to
572
998
  # see if the cluster is GPU-enabled.
573
999
 
574
- _, exec_msg = is_kubeconfig_exec_auth()
1000
+ _, exec_msg = is_kubeconfig_exec_auth(context)
575
1001
 
576
1002
  # We now check if GPUs are available and labels are set correctly on the
577
1003
  # cluster, and if not we return hints that may help debug any issues.
@@ -580,7 +1006,10 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
580
1006
  # provider if their cluster GPUs are not setup correctly.
581
1007
  gpu_msg = ''
582
1008
  try:
583
- _, _ = get_gpu_label_key_value(acc_type='', check_mode=True)
1009
+ get_accelerator_label_key_value(context,
1010
+ acc_type='',
1011
+ acc_count=0,
1012
+ check_mode=True)
584
1013
  except exceptions.ResourcesUnavailableError as e:
585
1014
  # If GPUs are not available, we return cluster as enabled (since it can
586
1015
  # be a CPU-only cluster) but we also return the exception message which
@@ -596,7 +1025,54 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
596
1025
  return True, None
597
1026
 
598
1027
 
599
- def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
1028
+ def check_pod_config(pod_config: dict) \
1029
+ -> Tuple[bool, Optional[str]]:
1030
+ """Check if the pod_config is a valid pod config
1031
+
1032
+ Uses the kubernetes deserialize API to check whether pod_config is valid.
1033
+
1034
+ Returns:
1035
+ bool: True if pod_config is valid.
1036
+ str: Error message about why the pod_config is invalid, None otherwise.
1037
+ """
1038
+ errors = []
1039
+ # This api_client won't be used to send any requests, so there is no need to
1040
+ # load kubeconfig
1041
+ api_client = kubernetes.kubernetes.client.ApiClient()
1042
+
1043
+ # Wrapper for the kubernetes api_client deserialize function, which reads
1044
+ # the response's `data` attribute. For details, see:
1045
+ # https://github.com/kubernetes-client/python/blob/master/kubernetes/client/api_client.py#L244
1046
+ class InnerResponse():
1047
+
1048
+ def __init__(self, data: dict):
1049
+ self.data = json.dumps(data)
1050
+
1051
+ try:
1052
+ # Validate metadata if present
1053
+ if 'metadata' in pod_config:
1054
+ try:
1055
+ value = InnerResponse(pod_config['metadata'])
1056
+ api_client.deserialize(
1057
+ value, kubernetes.kubernetes.client.V1ObjectMeta)
1058
+ except ValueError as e:
1059
+ errors.append(f'Invalid metadata: {str(e)}')
1060
+ # Validate spec if present
1061
+ if 'spec' in pod_config:
1062
+ try:
1063
+ value = InnerResponse(pod_config['spec'])
1064
+ api_client.deserialize(value,
1065
+ kubernetes.kubernetes.client.V1PodSpec)
1066
+ except ValueError as e:
1067
+ errors.append(f'Invalid spec: {str(e)}')
1068
+ return len(errors) == 0, '.'.join(errors) if errors else None
1069
+ except Exception as e: # pylint: disable=broad-except
1070
+ errors.append(f'Validation error: {str(e)}')
1071
+ return False, '.'.join(errors)
1072
+
1073
+
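To make the validation contract above concrete, here is a minimal usage sketch. It is illustrative only: it assumes this module is importable as `kubernetes_utils`, and the pod_config contents are hypothetical.

    # Hypothetical usage sketch for check_pod_config.
    pod_config = {
        'metadata': {'labels': {'team': 'ml'}},
        'spec': {'containers': [{'name': 'main', 'image': 'ubuntu:22.04'}]},
    }
    valid, err = kubernetes_utils.check_pod_config(pod_config)
    if not valid:
        print(f'Invalid pod_config: {err}')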
1074
+ def is_kubeconfig_exec_auth(
1075
+ context: Optional[str] = None) -> Tuple[bool, Optional[str]]:
600
1076
  """Checks if the kubeconfig file uses exec-based authentication
601
1077
 
602
1078
  Exec-based auth is commonly used for authenticating with cloud hosted
@@ -623,6 +1099,9 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
623
1099
  str: Error message if exec-based authentication is used, None otherwise
624
1100
  """
625
1101
  k8s = kubernetes.kubernetes
1102
+ if context == kubernetes.in_cluster_context_name():
1103
+ # If in-cluster config is used, exec-based auth is not used.
1104
+ return False, None
626
1105
  try:
627
1106
  k8s.config.load_kube_config()
628
1107
  except kubernetes.config_exception():
@@ -630,8 +1109,16 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
630
1109
  return False, None
631
1110
 
632
1111
  # Get active context and user from kubeconfig using k8s api
633
- _, current_context = k8s.config.list_kube_config_contexts()
634
- target_username = current_context['context']['user']
1112
+ all_contexts, current_context = k8s.config.list_kube_config_contexts()
1113
+ context_obj = current_context
1114
+ if context is not None:
1115
+ for c in all_contexts:
1116
+ if c['name'] == context:
1117
+ context_obj = c
1118
+ break
1119
+ else:
1120
+ raise ValueError(f'Kubernetes context {context!r} not found.')
1121
+ target_username = context_obj['context']['user']
635
1122
 
636
1123
  # K8s api does not provide a mechanism to get the user details from the
637
1124
  # context. We need to load the kubeconfig file and parse it to get the
@@ -654,7 +1141,7 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
654
1141
  schemas.get_default_remote_identity('kubernetes'))
655
1142
  if ('exec' in user_details.get('user', {}) and remote_identity
656
1143
  == schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value):
657
- ctx_name = current_context['name']
1144
+ ctx_name = context_obj['name']
658
1145
  exec_msg = ('exec-based authentication is used for '
659
1146
  f'Kubernetes context {ctx_name!r}.'
660
1147
  ' This may cause issues with autodown or when running '
@@ -664,12 +1151,13 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
664
1151
  '~/.sky/config.yaml:\n'
665
1152
  ' kubernetes:\n'
666
1153
  ' remote_identity: SERVICE_ACCOUNT\n'
667
- ' More: https://skypilot.readthedocs.io/en/latest/'
1154
+ ' More: https://docs.skypilot.co/en/latest/'
668
1155
  'reference/config.html')
669
1156
  return True, exec_msg
670
1157
  return False, None
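A short, hedged sketch of how a caller might surface the warning produced above; the context name is hypothetical and `kubernetes_utils` is an assumed import alias for this module.

    # Hypothetical sketch: warn early if a context relies on exec-based auth.
    uses_exec, msg = kubernetes_utils.is_kubeconfig_exec_auth(
        context='gke_my-project_us-central1_my-cluster')
    if uses_exec:
        logger.warning(msg)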
671
1158
 
672
1159
 
1160
+ @annotations.lru_cache(scope='request')
673
1161
  def get_current_kube_config_context_name() -> Optional[str]:
674
1162
  """Get the current kubernetes context from the kubeconfig file
675
1163
 
@@ -684,18 +1172,90 @@ def get_current_kube_config_context_name() -> Optional[str]:
684
1172
  return None
685
1173
 
686
1174
 
687
- def get_current_kube_config_context_namespace() -> str:
1175
+ def is_incluster_config_available() -> bool:
1176
+ """Check if in-cluster auth is available.
1177
+
1178
+ Note: We cannot use load_incluster_config() to check if in-cluster config
1179
+ is available because it will load the in-cluster config (if available)
1180
+ and modify the current global kubernetes config. We simply check if the
1181
+ service account token file exists to determine if in-cluster config may
1182
+ be available.
1183
+ """
1184
+ return os.path.exists('/var/run/secrets/kubernetes.io/serviceaccount/token')
1185
+
1186
+
1187
+ def get_all_kube_context_names() -> List[str]:
1188
+ """Get all kubernetes context names available in the environment.
1189
+
1190
+ Fetches context names from the kubeconfig file and in-cluster auth, if any.
1191
+
1192
+ If running in-cluster and IN_CLUSTER_CONTEXT_NAME_ENV_VAR is not set,
1193
+ returns the default in-cluster kubernetes context name.
1194
+
1195
+ We should not cache the result of this function as the admin policy may
1196
+ update the contexts.
1197
+
1198
+ Returns:
1199
+ List[str]: The list of kubernetes context names if
1200
+ available, an empty list otherwise.
1201
+ """
1202
+ k8s = kubernetes.kubernetes
1203
+ context_names = []
1204
+ try:
1205
+ all_contexts, _ = k8s.config.list_kube_config_contexts()
1206
+ # all_contexts will always have at least one context. If kubeconfig
1207
+ # does not have any contexts defined, it will raise ConfigException.
1208
+ context_names = [context['name'] for context in all_contexts]
1209
+ except k8s.config.config_exception.ConfigException:
1210
+ # If no config found, continue
1211
+ pass
1212
+ if is_incluster_config_available():
1213
+ context_names.append(kubernetes.in_cluster_context_name())
1214
+ return context_names
1215
+
1216
+
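A hedged sketch combining this helper with check_credentials above to validate every reachable context (the `kubernetes_utils` import alias is an assumption):

    # Hypothetical sketch: validate credentials for every available context.
    for ctx in kubernetes_utils.get_all_kube_context_names():
        ok, reason = kubernetes_utils.check_credentials(ctx, timeout=5)
        print(f'{ctx}: {"ok" if ok else reason}')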
1217
+ @annotations.lru_cache(scope='request')
1218
+ def get_kube_config_context_namespace(
1219
+ context_name: Optional[str] = None) -> str:
688
1220
  """Get the current kubernetes context namespace from the kubeconfig file
689
1221
 
690
1222
  Returns:
691
- str | None: The current kubernetes context namespace if it exists, else
1223
+ str: The current kubernetes context namespace if it exists, else
692
1224
  the default namespace.
693
1225
  """
694
1226
  k8s = kubernetes.kubernetes
1227
+ ns_path = '/var/run/secrets/kubernetes.io/serviceaccount/namespace'
1228
+ # If using in-cluster context, first check for the environment variable,
1229
+ # then fall back to the service account namespace file. Uses the same logic
1230
+ # as adaptors.kubernetes._load_config() to stay consistent with in-cluster
1231
+ # config loading.
1232
+ if (context_name == kubernetes.in_cluster_context_name() or
1233
+ context_name is None):
1234
+ # First check for environment variable. We allow the env var to take
1235
+ # effect only when using in-cluster auth because the recommended way to
1236
+ # set the namespace when using kubeconfig is to change the namespace
1237
+ # configured in the context.
1238
+ env_namespace = os.getenv(
1239
+ kubernetes_constants.KUBERNETES_IN_CLUSTER_NAMESPACE_ENV_VAR)
1240
+ if env_namespace:
1241
+ return env_namespace
1242
+ # Fall back to service account namespace file
1243
+ if os.path.exists(ns_path):
1244
+ with open(ns_path, encoding='utf-8') as f:
1245
+ return f.read().strip()
1246
+ # If not in-cluster, get the namespace from kubeconfig
695
1247
  try:
696
- _, current_context = k8s.config.list_kube_config_contexts()
697
- if 'namespace' in current_context['context']:
698
- return current_context['context']['namespace']
1248
+ contexts, current_context = k8s.config.list_kube_config_contexts()
1249
+ if context_name is None:
1250
+ context = current_context
1251
+ else:
1252
+ context = next((c for c in contexts if c['name'] == context_name),
1253
+ None)
1254
+ if context is None:
1255
+ return DEFAULT_NAMESPACE
1256
+
1257
+ if 'namespace' in context['context']:
1258
+ return context['context']['namespace']
699
1259
  else:
700
1260
  return DEFAULT_NAMESPACE
701
1261
  except k8s.config.config_exception.ConfigException:
@@ -742,13 +1302,13 @@ class KubernetesInstanceType:
742
1302
  - Accelerators
743
1303
  The name format is "{n}CPU--{k}GB" where n is the number of vCPUs and
744
1304
  k is the amount of memory in GB. Accelerators can be specified by
745
- appending "--{a}{type}" where a is the number of accelerators and
746
- type is the accelerator type.
1305
+ appending "--{type}:{a}" where type is the accelerator type and a
1306
+ is the number of accelerators.
747
1307
  CPU and memory can be specified as floats. Accelerator count must be int.
748
1308
  Examples:
749
1309
  - 4CPU--16GB
750
1310
  - 0.5CPU--1.5GB
751
- - 4CPU--16GB--1V100
1311
+ - 4CPU--16GB--V100:1
752
1312
  """
753
1313
 
754
1314
  def __init__(self,
@@ -769,13 +1329,18 @@ class KubernetesInstanceType:
769
1329
  name = (f'{common_utils.format_float(self.cpus)}CPU--'
770
1330
  f'{common_utils.format_float(self.memory)}GB')
771
1331
  if self.accelerator_count:
772
- name += f'--{self.accelerator_count}{self.accelerator_type}'
1332
+ # Replace spaces with underscores in accelerator type to make it a
1333
+ # valid logical instance type name.
1334
+ assert self.accelerator_type is not None, self.accelerator_count
1335
+ acc_name = self.accelerator_type.replace(' ', '_')
1336
+ name += f'--{acc_name}:{self.accelerator_count}'
773
1337
  return name
774
1338
 
775
1339
  @staticmethod
776
1340
  def is_valid_instance_type(name: str) -> bool:
777
1341
  """Returns whether the given name is a valid instance type."""
778
- pattern = re.compile(r'^(\d+(\.\d+)?CPU--\d+(\.\d+)?GB)(--\d+\S+)?$')
1342
+ pattern = re.compile(
1343
+ r'^(\d+(\.\d+)?CPU--\d+(\.\d+)?GB)(--[\w\d-]+:\d+)?$')
779
1344
  return bool(pattern.match(name))
780
1345
 
781
1346
  @classmethod
@@ -790,7 +1355,7 @@ class KubernetesInstanceType:
790
1355
  accelerator_type | str: Type of accelerator
791
1356
  """
792
1357
  pattern = re.compile(
793
- r'^(?P<cpus>\d+(\.\d+)?)CPU--(?P<memory>\d+(\.\d+)?)GB(?:--(?P<accelerator_count>\d+)(?P<accelerator_type>\S+))?$' # pylint: disable=line-too-long
1358
+ r'^(?P<cpus>\d+(\.\d+)?)CPU--(?P<memory>\d+(\.\d+)?)GB(?:--(?P<accelerator_type>[\w\d-]+):(?P<accelerator_count>\d+))?$' # pylint: disable=line-too-long
794
1359
  )
795
1360
  match = pattern.match(name)
796
1361
  if match:
@@ -800,7 +1365,9 @@ class KubernetesInstanceType:
800
1365
  accelerator_type = match.group('accelerator_type')
801
1366
  if accelerator_count:
802
1367
  accelerator_count = int(accelerator_count)
803
- accelerator_type = str(accelerator_type)
1368
+ # This is to revert the accelerator types with spaces back to
1369
+ # the original format.
1370
+ accelerator_type = str(accelerator_type).replace('_', ' ')
804
1371
  else:
805
1372
  accelerator_count = None
806
1373
  accelerator_type = None
@@ -834,7 +1401,7 @@ class KubernetesInstanceType:
834
1401
  # Round up accelerator_count if it is not an int.
835
1402
  accelerator_count = math.ceil(accelerator_count)
836
1403
  if accelerator_count > 0:
837
- name += f'--{accelerator_count}{accelerator_type}'
1404
+ name += f'--{accelerator_type}:{accelerator_count}'
838
1405
  return cls(cpus=cpus,
839
1406
  memory=memory,
840
1407
  accelerator_count=accelerator_count,
@@ -844,30 +1411,49 @@ class KubernetesInstanceType:
844
1411
  return self.name
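A small sketch of the name format described in the class docstring; only is_valid_instance_type from above is exercised, and the accelerator names are hypothetical.

    # The logical instance type format is '{n}CPU--{k}GB[--{type}:{count}]'.
    assert KubernetesInstanceType.is_valid_instance_type('4CPU--16GB--V100:1')
    assert KubernetesInstanceType.is_valid_instance_type('0.5CPU--1.5GB')
    # Accelerator types containing spaces are encoded with underscores, so a
    # hypothetical 'TPU V4' would round-trip as '8CPU--32GB--TPU_V4:8'.
    assert KubernetesInstanceType.is_valid_instance_type('8CPU--32GB--TPU_V4:8')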
845
1412
 
846
1413
 
847
- def construct_ssh_jump_command(private_key_path: str,
848
- ssh_jump_ip: str,
849
- ssh_jump_port: Optional[int] = None,
850
- proxy_cmd_path: Optional[str] = None) -> str:
1414
+ def construct_ssh_jump_command(
1415
+ private_key_path: str,
1416
+ ssh_jump_ip: str,
1417
+ ssh_jump_port: Optional[int] = None,
1418
+ ssh_jump_user: str = 'sky',
1419
+ proxy_cmd_path: Optional[str] = None,
1420
+ proxy_cmd_target_pod: Optional[str] = None,
1421
+ current_kube_context: Optional[str] = None,
1422
+ current_kube_namespace: Optional[str] = None) -> str:
851
1423
  ssh_jump_proxy_command = (f'ssh -tt -i {private_key_path} '
852
1424
  '-o StrictHostKeyChecking=no '
853
1425
  '-o UserKnownHostsFile=/dev/null '
854
1426
  f'-o IdentitiesOnly=yes '
855
- f'-W %h:%p sky@{ssh_jump_ip}')
1427
+ r'-W \[%h\]:%p '
1428
+ f'{ssh_jump_user}@{ssh_jump_ip}')
856
1429
  if ssh_jump_port is not None:
857
1430
  ssh_jump_proxy_command += f' -p {ssh_jump_port} '
858
1431
  if proxy_cmd_path is not None:
859
1432
  proxy_cmd_path = os.path.expanduser(proxy_cmd_path)
860
1433
  # adding execution permission to the proxy command script
861
1434
  os.chmod(proxy_cmd_path, os.stat(proxy_cmd_path).st_mode | 0o111)
862
- ssh_jump_proxy_command += f' -o ProxyCommand=\'{proxy_cmd_path}\' '
1435
+ kube_context_flag = f'-c {current_kube_context} ' if (
1436
+ current_kube_context is not None) else ''
1437
+ kube_namespace_flag = f'-n {current_kube_namespace} ' if (
1438
+ current_kube_namespace is not None) else ''
1439
+ ssh_jump_proxy_command += (f' -o ProxyCommand=\'{proxy_cmd_path} '
1440
+ f'{kube_context_flag}'
1441
+ f'{kube_namespace_flag}'
1442
+ f'{proxy_cmd_target_pod}\'')
863
1443
  return ssh_jump_proxy_command
864
1444
 
865
1445
 
866
1446
  def get_ssh_proxy_command(
867
- private_key_path: str, ssh_jump_name: str,
868
- network_mode: kubernetes_enums.KubernetesNetworkingMode, namespace: str,
869
- port_fwd_proxy_cmd_path: str, port_fwd_proxy_cmd_template: str) -> str:
870
- """Generates the SSH proxy command to connect through the SSH jump pod.
1447
+ k8s_ssh_target: str,
1448
+ network_mode: kubernetes_enums.KubernetesNetworkingMode,
1449
+ private_key_path: str,
1450
+ context: Optional[str],
1451
+ namespace: str,
1452
+ ) -> str:
1453
+ """Generates the SSH proxy command to connect to the pod.
1454
+
1455
+ Uses a jump pod if the network mode is NODEPORT, and direct port-forwarding
1456
+ if the network mode is PORTFORWARD.
871
1457
 
872
1458
  By default, establishing an SSH connection creates a communication
873
1459
  channel to a remote node by setting up a TCP connection. When a
@@ -883,58 +1469,87 @@ def get_ssh_proxy_command(
883
1469
 
884
1470
  With the NodePort networking mode, a NodePort service is launched. This
885
1471
  service opens an external port on the node which redirects to the desired
886
- port within the pod. When establishing an SSH session in this mode, the
1472
+ port to an SSH jump pod. When establishing an SSH session in this mode, the
887
1473
  ProxyCommand makes use of this external port to create a communication
888
1474
  channel directly to port 22, which is the default port ssh server listens
889
1475
  on, of the jump pod.
890
1476
 
891
1477
  With Port-forward mode, instead of directly exposing an external port,
892
1478
  'kubectl port-forward' sets up a tunnel between a local port
893
- (127.0.0.1:23100) and port 22 of the jump pod. Then we establish a TCP
1479
+ (127.0.0.1:23100) and port 22 of the provisioned pod. Then we establish a TCP
894
1480
  connection to the local end of this tunnel, 127.0.0.1:23100, using 'socat'.
895
- This is setup in the inner ProxyCommand of the nested ProxyCommand, and the
896
- rest is the same as NodePort approach, which the outer ProxyCommand
897
- establishes a communication channel between 127.0.0.1:23100 and port 22 on
898
- the jump pod. Consequently, any stdin provided on the local machine is
899
- forwarded through this tunnel to the application (SSH server) listening in
900
- the pod. Similarly, any output from the application in the pod is tunneled
901
- back and displayed in the terminal on the local machine.
1481
+ All of this is done in a ProxyCommand script. Any stdin provided on the
1482
+ local machine is forwarded through this tunnel to the application
1483
+ (SSH server) listening in the pod. Similarly, any output from the
1484
+ application in the pod is tunneled back and displayed in the terminal on
1485
+ the local machine.
902
1486
 
903
1487
  Args:
904
- private_key_path: str; Path to the private key to use for SSH.
905
- This key must be authorized to access the SSH jump pod.
906
- ssh_jump_name: str; Name of the SSH jump service to use
1488
+ k8s_ssh_target: str; The Kubernetes object that will be used as the
1489
+ target for SSH. If network_mode is NODEPORT, this is the name of the
1490
+ service. If network_mode is PORTFORWARD, this is the pod name.
907
1491
  network_mode: KubernetesNetworkingMode; networking mode for ssh
908
1492
  session. It is either 'NODEPORT' or 'PORTFORWARD'
909
- namespace: Kubernetes namespace to use
910
- port_fwd_proxy_cmd_path: str; path to the script used as Proxycommand
911
- with 'kubectl port-forward'
912
- port_fwd_proxy_cmd_template: str; template used to create
913
- 'kubectl port-forward' Proxycommand
1493
+ private_key_path: str; Path to the private key to use for SSH.
1494
+ This key must be authorized to access the SSH jump pod.
1495
+ Required for NODEPORT networking mode.
1496
+ namespace: Kubernetes namespace to use.
1497
+ Required for NODEPORT networking mode.
914
1498
  """
915
1499
  # Fetch IP to connect to for the jump svc
916
- ssh_jump_ip = get_external_ip(network_mode)
1500
+ ssh_jump_ip = get_external_ip(network_mode, context)
1501
+ assert private_key_path is not None, 'Private key path must be provided'
917
1502
  if network_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
918
- ssh_jump_port = get_port(ssh_jump_name, namespace)
1503
+ assert namespace is not None, 'Namespace must be provided for NodePort'
1504
+ ssh_jump_port = get_port(k8s_ssh_target, namespace, context)
919
1505
  ssh_jump_proxy_command = construct_ssh_jump_command(
920
1506
  private_key_path, ssh_jump_ip, ssh_jump_port=ssh_jump_port)
921
- # Setting kubectl port-forward/socat to establish ssh session using
922
- # ClusterIP service to disallow any ports opened
923
1507
  else:
924
- vars_to_fill = {
925
- 'ssh_jump_name': ssh_jump_name,
926
- }
927
- common_utils.fill_template(port_fwd_proxy_cmd_template,
928
- vars_to_fill,
929
- output_path=port_fwd_proxy_cmd_path)
1508
+ ssh_jump_proxy_command_path = create_proxy_command_script()
930
1509
  ssh_jump_proxy_command = construct_ssh_jump_command(
931
1510
  private_key_path,
932
1511
  ssh_jump_ip,
933
- proxy_cmd_path=port_fwd_proxy_cmd_path)
1512
+ ssh_jump_user=constants.SKY_SSH_USER_PLACEHOLDER,
1513
+ proxy_cmd_path=ssh_jump_proxy_command_path,
1514
+ proxy_cmd_target_pod=k8s_ssh_target,
1515
+ # We embed both the current context and namespace to the SSH proxy
1516
+ # command to make sure SSH still works when the current
1517
+ # context/namespace is changed by the user.
1518
+ current_kube_context=context,
1519
+ current_kube_namespace=namespace)
934
1520
  return ssh_jump_proxy_command
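A hedged sketch of invoking get_ssh_proxy_command in PORTFORWARD mode. All names are hypothetical; note the returned string may still contain the SSH user placeholder, which is substituted elsewhere before use.

    # Hypothetical sketch: generate the ProxyCommand for a provisioned pod.
    proxy_cmd = get_ssh_proxy_command(
        k8s_ssh_target='my-cluster-head',  # pod name in PORTFORWARD mode
        network_mode=kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD,
        private_key_path='~/.ssh/sky-key',
        context='my-context',
        namespace='default')
    # The backend later substitutes the SSH user placeholder and writes this
    # string into the generated SSH config as the ProxyCommand option.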
935
1521
 
936
1522
 
1523
+ def create_proxy_command_script() -> str:
1524
+ """Creates a ProxyCommand script that uses kubectl port-forward to setup
1525
+ a tunnel between a local port and the SSH server in the pod.
1526
+
1527
+ Returns:
1528
+ str: Path to the ProxyCommand script.
1529
+ """
1530
+ port_fwd_proxy_cmd_path = os.path.expanduser(PORT_FORWARD_PROXY_CMD_PATH)
1531
+ os.makedirs(os.path.dirname(port_fwd_proxy_cmd_path),
1532
+ exist_ok=True,
1533
+ mode=0o700)
1534
+
1535
+ root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
1536
+ template_path = os.path.join(root_dir, 'templates',
1537
+ PORT_FORWARD_PROXY_CMD_TEMPLATE)
1538
+ # Copy the template to the proxy command path. We create a copy to allow
1539
+ # different users sharing the same SkyPilot installation to have their own
1540
+ # proxy command scripts.
1541
+ shutil.copy(template_path, port_fwd_proxy_cmd_path)
1542
+ # Set the permissions to 700 to ensure only the owner can read, write,
1543
+ # and execute the file.
1544
+ os.chmod(port_fwd_proxy_cmd_path, 0o700)
1545
+ # Return the path to the proxy command script without expanding the user
1546
+ # home directory to be compatible when a SSH is called from a client in
1547
+ # client-server mode.
1548
+ return PORT_FORWARD_PROXY_CMD_PATH
1549
+
1550
+
937
1551
  def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str,
1552
+ context: Optional[str],
938
1553
  service_type: kubernetes_enums.KubernetesServiceType):
939
1554
  """Sets up Kubernetes service resource to access for SSH jump pod.
940
1555
 
@@ -956,13 +1571,14 @@ def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str,
956
1571
 
957
1572
  # Create service
958
1573
  try:
959
- kubernetes.core_api().create_namespaced_service(namespace,
960
- content['service_spec'])
1574
+ kubernetes.core_api(context).create_namespaced_service(
1575
+ namespace, content['service_spec'])
961
1576
  except kubernetes.api_exception() as e:
962
1577
  # SSH Jump Pod service already exists.
963
1578
  if e.status == 409:
964
- ssh_jump_service = kubernetes.core_api().read_namespaced_service(
965
- name=ssh_jump_name, namespace=namespace)
1579
+ ssh_jump_service = kubernetes.core_api(
1580
+ context).read_namespaced_service(name=ssh_jump_name,
1581
+ namespace=namespace)
966
1582
  curr_svc_type = ssh_jump_service.spec.type
967
1583
  if service_type.value == curr_svc_type:
968
1584
  # If the currently existing SSH Jump service's type is identical
@@ -974,9 +1590,9 @@ def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str,
974
1590
  # If a different type of service type for SSH Jump pod compared
975
1591
  # to user's configuration for networking mode exists, we remove
976
1592
  # existing servie to create a new one following user's config
977
- kubernetes.core_api().delete_namespaced_service(
1593
+ kubernetes.core_api(context).delete_namespaced_service(
978
1594
  name=ssh_jump_name, namespace=namespace)
979
- kubernetes.core_api().create_namespaced_service(
1595
+ kubernetes.core_api(context).create_namespaced_service(
980
1596
  namespace, content['service_spec'])
981
1597
  port_forward_mode = (
982
1598
  kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value)
@@ -1005,7 +1621,8 @@ def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str,
1005
1621
 
1006
1622
 
1007
1623
  def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
1008
- ssh_key_secret: str, namespace: str):
1624
+ ssh_key_secret: str, namespace: str,
1625
+ context: Optional[str]):
1009
1626
  """Sets up Kubernetes RBAC and pod for SSH jump host.
1010
1627
 
1011
1628
  Our Kubernetes implementation uses a SSH jump pod to reach SkyPilot clusters
@@ -1034,7 +1651,7 @@ def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
1034
1651
 
1035
1652
  # ServiceAccount
1036
1653
  try:
1037
- kubernetes.core_api().create_namespaced_service_account(
1654
+ kubernetes.core_api(context).create_namespaced_service_account(
1038
1655
  namespace, content['service_account'])
1039
1656
  except kubernetes.api_exception() as e:
1040
1657
  if e.status == 409:
@@ -1047,7 +1664,8 @@ def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
1047
1664
  logger.info('Created SSH Jump ServiceAccount.')
1048
1665
  # Role
1049
1666
  try:
1050
- kubernetes.auth_api().create_namespaced_role(namespace, content['role'])
1667
+ kubernetes.auth_api(context).create_namespaced_role(
1668
+ namespace, content['role'])
1051
1669
  except kubernetes.api_exception() as e:
1052
1670
  if e.status == 409:
1053
1671
  logger.info(
@@ -1058,7 +1676,7 @@ def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
1058
1676
  logger.info('Created SSH Jump Role.')
1059
1677
  # RoleBinding
1060
1678
  try:
1061
- kubernetes.auth_api().create_namespaced_role_binding(
1679
+ kubernetes.auth_api(context).create_namespaced_role_binding(
1062
1680
  namespace, content['role_binding'])
1063
1681
  except kubernetes.api_exception() as e:
1064
1682
  if e.status == 409:
@@ -1071,8 +1689,8 @@ def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
1071
1689
  logger.info('Created SSH Jump RoleBinding.')
1072
1690
  # Pod
1073
1691
  try:
1074
- kubernetes.core_api().create_namespaced_pod(namespace,
1075
- content['pod_spec'])
1692
+ kubernetes.core_api(context).create_namespaced_pod(
1693
+ namespace, content['pod_spec'])
1076
1694
  except kubernetes.api_exception() as e:
1077
1695
  if e.status == 409:
1078
1696
  logger.info(
@@ -1084,7 +1702,8 @@ def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
1084
1702
  logger.info(f'Created SSH Jump Host {ssh_jump_name}.')
1085
1703
 
1086
1704
 
1087
- def clean_zombie_ssh_jump_pod(namespace: str, node_id: str):
1705
+ def clean_zombie_ssh_jump_pod(namespace: str, context: Optional[str],
1706
+ node_id: str):
1088
1707
  """Analyzes SSH jump pod and removes if it is in a bad state
1089
1708
 
1090
1709
  Prevents the existence of a dangling SSH jump pod. This could happen
@@ -1100,11 +1719,12 @@ def clean_zombie_ssh_jump_pod(namespace: str, node_id: str):
1100
1719
  def find(l, predicate):
1101
1720
  """Utility function to find element in given list"""
1102
1721
  results = [x for x in l if predicate(x)]
1103
- return results[0] if len(results) > 0 else None
1722
+ return results[0] if results else None
1104
1723
 
1105
1724
  # Get the SSH jump pod name from the head pod
1106
1725
  try:
1107
- pod = kubernetes.core_api().read_namespaced_pod(node_id, namespace)
1726
+ pod = kubernetes.core_api(context).read_namespaced_pod(
1727
+ node_id, namespace)
1108
1728
  except kubernetes.api_exception() as e:
1109
1729
  if e.status == 404:
1110
1730
  logger.warning(f'Failed to get pod {node_id},'
@@ -1113,7 +1733,7 @@ def clean_zombie_ssh_jump_pod(namespace: str, node_id: str):
1113
1733
  else:
1114
1734
  ssh_jump_name = pod.metadata.labels.get('skypilot-ssh-jump')
1115
1735
  try:
1116
- ssh_jump_pod = kubernetes.core_api().read_namespaced_pod(
1736
+ ssh_jump_pod = kubernetes.core_api(context).read_namespaced_pod(
1117
1737
  ssh_jump_name, namespace)
1118
1738
  cont_ready_cond = find(ssh_jump_pod.status.conditions,
1119
1739
  lambda c: c.type == 'ContainersReady')
@@ -1124,9 +1744,9 @@ def clean_zombie_ssh_jump_pod(namespace: str, node_id: str):
1124
1744
  # ssh jump pod, let's remove it and the service. Otherwise, main
1125
1745
  # container is ready and its lifecycle management script takes
1126
1746
  # care of the cleaning.
1127
- kubernetes.core_api().delete_namespaced_pod(ssh_jump_name,
1128
- namespace)
1129
- kubernetes.core_api().delete_namespaced_service(
1747
+ kubernetes.core_api(context).delete_namespaced_pod(
1748
+ ssh_jump_name, namespace)
1749
+ kubernetes.core_api(context).delete_namespaced_service(
1130
1750
  ssh_jump_name, namespace)
1131
1751
  except kubernetes.api_exception() as e:
1132
1752
  # We keep the warning in debug to avoid polluting the `sky launch`
@@ -1138,7 +1758,7 @@ def clean_zombie_ssh_jump_pod(namespace: str, node_id: str):
1138
1758
  # We encountered an issue while checking ssh jump pod. To be on
1139
1759
  # the safe side, let's remove its service so the port is freed
1140
1760
  try:
1141
- kubernetes.core_api().delete_namespaced_service(
1761
+ kubernetes.core_api(context).delete_namespaced_service(
1142
1762
  ssh_jump_name, namespace)
1143
1763
  except kubernetes.api_exception():
1144
1764
  pass
@@ -1245,50 +1865,12 @@ def get_endpoint_debug_message() -> str:
1245
1865
  debug_cmd=debug_cmd)
1246
1866
 
1247
1867
 
1248
- def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]):
1249
- """Merge two dictionaries into the destination dictionary.
1250
-
1251
- Updates nested dictionaries instead of replacing them.
1252
- If a list is encountered, it will be appended to the destination list.
1253
-
1254
- An exception is when the key is 'containers', in which case the
1255
- first container in the list will be fetched and merge_dict will be
1256
- called on it with the first container in the destination list.
1257
- """
1258
- for key, value in source.items():
1259
- if isinstance(value, dict) and key in destination:
1260
- merge_dicts(value, destination[key])
1261
- elif isinstance(value, list) and key in destination:
1262
- assert isinstance(destination[key], list), \
1263
- f'Expected {key} to be a list, found {destination[key]}'
1264
- if key == 'containers':
1265
- # If the key is 'containers', we take the first and only
1266
- # container in the list and merge it.
1267
- assert len(value) == 1, \
1268
- f'Expected only one container, found {value}'
1269
- merge_dicts(value[0], destination[key][0])
1270
- elif key in ['volumes', 'volumeMounts']:
1271
- # If the key is 'volumes' or 'volumeMounts', we search for
1272
- # item with the same name and merge it.
1273
- for new_volume in value:
1274
- new_volume_name = new_volume.get('name')
1275
- if new_volume_name is not None:
1276
- destination_volume = next(
1277
- (v for v in destination[key]
1278
- if v.get('name') == new_volume_name), None)
1279
- if destination_volume is not None:
1280
- merge_dicts(new_volume, destination_volume)
1281
- else:
1282
- destination[key].append(new_volume)
1283
- else:
1284
- destination[key].extend(value)
1285
- else:
1286
- destination[key] = value
1287
-
1288
-
1289
- def combine_pod_config_fields(cluster_yaml_path: str) -> None:
1290
- """Adds or updates fields in the YAML with fields from the ~/.sky/config's
1291
- kubernetes.pod_spec dict.
1868
+ def combine_pod_config_fields(
1869
+ cluster_yaml_path: str,
1870
+ cluster_config_overrides: Dict[str, Any],
1871
+ ) -> None:
1872
+ """Adds or updates fields in the YAML with fields from the
1873
+ ~/.sky/config.yaml's kubernetes.pod_spec dict.
1292
1874
  This can be used to add fields to the YAML that are not supported by
1293
1875
  SkyPilot yet, or require simple configuration (e.g., adding an
1294
1876
  imagePullSecrets field).
@@ -1328,13 +1910,19 @@ def combine_pod_config_fields(cluster_yaml_path: str) -> None:
1328
1910
  with open(cluster_yaml_path, 'r', encoding='utf-8') as f:
1329
1911
  yaml_content = f.read()
1330
1912
  yaml_obj = yaml.safe_load(yaml_content)
1913
+ # We don't use override_configs in `skypilot_config.get_nested`, as merging
1914
+ # the pod config requires special handling.
1331
1915
  kubernetes_config = skypilot_config.get_nested(('kubernetes', 'pod_config'),
1332
- {})
1916
+ default_value={},
1917
+ override_configs={})
1918
+ override_pod_config = (cluster_config_overrides.get('kubernetes', {}).get(
1919
+ 'pod_config', {}))
1920
+ config_utils.merge_k8s_configs(kubernetes_config, override_pod_config)
1333
1921
 
1334
1922
  # Merge the kubernetes config into the YAML for both head and worker nodes.
1335
- merge_dicts(
1336
- kubernetes_config,
1337
- yaml_obj['available_node_types']['ray_head_default']['node_config'])
1923
+ config_utils.merge_k8s_configs(
1924
+ yaml_obj['available_node_types']['ray_head_default']['node_config'],
1925
+ kubernetes_config)
1338
1926
 
1339
1927
  # Write the updated YAML back to the file
1340
1928
  common_utils.dump_yaml(cluster_yaml_path, yaml_obj)
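The merge semantics inherited from the removed merge_dicts helper (nested dicts merge, lists append, and the single 'containers' entries merge element-wise) are easiest to see in a small sketch. The override values below are hypothetical; merge_k8s_configs updates its first argument in place.

    # Hypothetical illustration of config_utils.merge_k8s_configs semantics.
    destination = {'spec': {'containers': [{'name': 'ray', 'env': []}]}}
    override = {
        'spec': {
            'containers': [{'env': [{'name': 'MY_VAR', 'value': '1'}]}],
            'imagePullSecrets': [{'name': 'my-secret'}],
        }
    }
    config_utils.merge_k8s_configs(destination, override)
    # destination now has the env var appended to the first container and the
    # new imagePullSecrets entry added.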
@@ -1342,7 +1930,7 @@ def combine_pod_config_fields(cluster_yaml_path: str) -> None:
1342
1930
 
1343
1931
  def combine_metadata_fields(cluster_yaml_path: str) -> None:
1344
1932
  """Updates the metadata for all Kubernetes objects created by SkyPilot with
1345
- fields from the ~/.sky/config's kubernetes.custom_metadata dict.
1933
+ fields from the ~/.sky/config.yaml's kubernetes.custom_metadata dict.
1346
1934
 
1347
1935
  Obeys the same add or update semantics as combine_pod_config_fields().
1348
1936
  """
@@ -1368,7 +1956,7 @@ def combine_metadata_fields(cluster_yaml_path: str) -> None:
1368
1956
  ]
1369
1957
 
1370
1958
  for destination in combination_destinations:
1371
- merge_dicts(custom_metadata, destination)
1959
+ config_utils.merge_k8s_configs(destination, custom_metadata)
1372
1960
 
1373
1961
  # Write the updated YAML back to the file
1374
1962
  common_utils.dump_yaml(cluster_yaml_path, yaml_obj)
@@ -1381,13 +1969,13 @@ def merge_custom_metadata(original_metadata: Dict[str, Any]) -> None:
1381
1969
  """
1382
1970
  custom_metadata = skypilot_config.get_nested(
1383
1971
  ('kubernetes', 'custom_metadata'), {})
1384
- merge_dicts(custom_metadata, original_metadata)
1972
+ config_utils.merge_k8s_configs(original_metadata, custom_metadata)
1385
1973
 
1386
1974
 
1387
- def check_nvidia_runtime_class() -> bool:
1975
+ def check_nvidia_runtime_class(context: Optional[str] = None) -> bool:
1388
1976
  """Checks if the 'nvidia' RuntimeClass exists in the cluster"""
1389
1977
  # Fetch the list of available RuntimeClasses
1390
- runtime_classes = kubernetes.node_api().list_runtime_class()
1978
+ runtime_classes = kubernetes.node_api(context).list_runtime_class()
1391
1979
 
1392
1980
  # Check if 'nvidia' RuntimeClass exists
1393
1981
  nvidia_exists = any(
@@ -1395,7 +1983,8 @@ def check_nvidia_runtime_class() -> bool:
1395
1983
  return nvidia_exists
1396
1984
 
1397
1985
 
1398
- def check_secret_exists(secret_name: str, namespace: str) -> bool:
1986
+ def check_secret_exists(secret_name: str, namespace: str,
1987
+ context: Optional[str]) -> bool:
1399
1988
  """Checks if a secret exists in a namespace
1400
1989
 
1401
1990
  Args:
@@ -1404,7 +1993,7 @@ def check_secret_exists(secret_name: str, namespace: str) -> bool:
1404
1993
  """
1405
1994
 
1406
1995
  try:
1407
- kubernetes.core_api().read_namespaced_secret(
1996
+ kubernetes.core_api(context).read_namespaced_secret(
1408
1997
  secret_name, namespace, _request_timeout=kubernetes.API_TIMEOUT)
1409
1998
  except kubernetes.api_exception() as e:
1410
1999
  if e.status == 404:
@@ -1414,20 +2003,29 @@ def check_secret_exists(secret_name: str, namespace: str) -> bool:
1414
2003
  return True
1415
2004
 
1416
2005
 
1417
- def create_namespace(namespace: str) -> None:
2006
+ def create_namespace(namespace: str, context: Optional[str]) -> None:
1418
2007
  """Creates a namespace in the cluster.
1419
2008
 
1420
2009
  If the namespace already exists, logs a message and does nothing.
1421
2010
 
1422
2011
  Args:
1423
2012
  namespace: Name of the namespace to create
2013
+ context: Name of the context to use. Can be none to use default context.
1424
2014
  """
1425
2015
  kubernetes_client = kubernetes.kubernetes.client
2016
+ try:
2017
+ kubernetes.core_api(context).read_namespace(namespace)
2018
+ except kubernetes.api_exception() as e:
2019
+ if e.status != 404:
2020
+ raise
2021
+ else:
2022
+ return
2023
+
1426
2024
  ns_metadata = dict(name=namespace, labels={'parent': 'skypilot'})
1427
2025
  merge_custom_metadata(ns_metadata)
1428
2026
  namespace_obj = kubernetes_client.V1Namespace(metadata=ns_metadata)
1429
2027
  try:
1430
- kubernetes.core_api().create_namespace(namespace_obj)
2028
+ kubernetes.core_api(context).create_namespace(namespace_obj)
1431
2029
  except kubernetes.api_exception() as e:
1432
2030
  if e.status == 409:
1433
2031
  logger.info(f'Namespace {namespace} already exists in the cluster.')
@@ -1453,7 +2051,7 @@ def get_head_pod_name(cluster_name_on_cloud: str):
1453
2051
  def get_autoscaler_type(
1454
2052
  ) -> Optional[kubernetes_enums.KubernetesAutoscalerType]:
1455
2053
  """Returns the autoscaler type by reading from config"""
1456
- autoscaler_type = skypilot_config.get_nested(['kubernetes', 'autoscaler'],
2054
+ autoscaler_type = skypilot_config.get_nested(('kubernetes', 'autoscaler'),
1457
2055
  None)
1458
2056
  if autoscaler_type is not None:
1459
2057
  autoscaler_type = kubernetes_enums.KubernetesAutoscalerType(
@@ -1461,6 +2059,45 @@ def get_autoscaler_type(
1461
2059
  return autoscaler_type
1462
2060
 
1463
2061
 
2062
+ # Mapping of known spot label keys and values for different cluster types
2063
+ # Add new cluster types here if they support spot instances along with the
2064
+ # corresponding spot label key and value.
2065
+ SPOT_LABEL_MAP = {
2066
+ kubernetes_enums.KubernetesAutoscalerType.GKE.value:
2067
+ ('cloud.google.com/gke-spot', 'true')
2068
+ }
2069
+
2070
+
2071
+ def get_spot_label(
2072
+ context: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
2073
+ """Get the spot label key and value for using spot instances, if supported.
2074
+
2075
+ Checks if the underlying cluster supports spot instances by checking nodes
2076
+ for known spot label keys and values. If found, returns the spot label key
2077
+ and value. If not, checks if autoscaler is configured and returns
2078
+ appropriate labels. If neither are found, returns None.
2079
+
2080
+ Returns:
2081
+ Tuple[str, str]: Tuple containing the spot label key and value. Returns
2082
+ None if spot instances are not supported.
2083
+ """
2084
+ # Check if the cluster supports spot instances by checking nodes for known
2085
+ # spot label keys and values
2086
+ for node in get_kubernetes_nodes(context):
2087
+ for _, (key, value) in SPOT_LABEL_MAP.items():
2088
+ if key in node.metadata.labels and node.metadata.labels[
2089
+ key] == value:
2090
+ return key, value
2091
+
2092
+ # Check if autoscaler is configured. Allow spot instances if autoscaler type
2093
+ # is known to support spot instances.
2094
+ autoscaler_type = get_autoscaler_type()
2095
+ if autoscaler_type == kubernetes_enums.KubernetesAutoscalerType.GKE:
2096
+ return SPOT_LABEL_MAP[autoscaler_type.value]
2097
+
2098
+ return None, None
2099
+
2100
+
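A hedged sketch of consuming get_spot_label when constructing a pod spec; the context name is hypothetical, and the GKE label shown comes from SPOT_LABEL_MAP above.

    # Hypothetical sketch: request spot nodes only if the cluster supports them.
    key, value = get_spot_label(context='my-context')
    node_selector = {}
    if key is not None:
        node_selector[key] = value  # e.g. {'cloud.google.com/gke-spot': 'true'}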
1464
2101
  def dict_to_k8s_object(object_dict: Dict[str, Any], object_type: 'str') -> Any:
1465
2102
  """Converts a dictionary to a Kubernetes object.
1466
2103
 
@@ -1479,3 +2116,485 @@ def dict_to_k8s_object(object_dict: Dict[str, Any], object_type: 'str') -> Any:
1479
2116
 
1480
2117
  fake_kube_response = FakeKubeResponse(object_dict)
1481
2118
  return kubernetes.api_client().deserialize(fake_kube_response, object_type)
2119
+
2120
+
2121
+ def get_kubernetes_node_info(
2122
+ context: Optional[str] = None) -> Dict[str, models.KubernetesNodeInfo]:
2123
+ """Gets the resource information for all the nodes in the cluster.
2124
+
2125
+ Currently only GPU resources are supported. The function returns the total
2126
+ number of GPUs available on the node and the number of free GPUs on the
2127
+ node.
2128
+
2129
+ If the user does not have sufficient permissions to list pods in all
2130
+ namespaces, the function will return free GPUs as -1.
2131
+
2132
+ Returns:
2133
+ Dict[str, KubernetesNodeInfo]: Dictionary containing the node name as
2134
+ key and the KubernetesNodeInfo object as value
2135
+ """
2136
+ nodes = get_kubernetes_nodes(context)
2137
+ # Get the pods to get the real-time resource usage
2138
+ try:
2139
+ pods = get_all_pods_in_kubernetes_cluster(context)
2140
+ except kubernetes.api_exception() as e:
2141
+ if e.status == 403:
2142
+ pods = None
2143
+ else:
2144
+ raise
2145
+
2146
+ lf, _ = detect_gpu_label_formatter(context)
2147
+ if not lf:
2148
+ label_keys = []
2149
+ else:
2150
+ label_keys = lf.get_label_keys()
2151
+
2152
+ node_info_dict: Dict[str, models.KubernetesNodeInfo] = {}
2153
+
2154
+ for node in nodes:
2155
+ accelerator_name = None
2156
+ # Determine the accelerator name from the node labels and pick the
2157
+ # first one found. We assume that the node has only one accelerator type
2158
+ # (e.g., either GPU or TPU).
2159
+ for label_key in label_keys:
2160
+ if lf is not None and label_key in node.metadata.labels:
2161
+ accelerator_name = lf.get_accelerator_from_label_value(
2162
+ node.metadata.labels.get(label_key))
2163
+ break
2164
+
2165
+ allocated_qty = 0
2166
+ accelerator_count = get_node_accelerator_count(node.status.allocatable)
2167
+
2168
+ if pods is None:
2169
+ accelerators_available = -1
2170
2171
+ else:
2172
+ for pod in pods:
2173
+ # Get all the pods running on the node
2174
+ if (pod.spec.node_name == node.metadata.name and
2175
+ pod.status.phase in ['Running', 'Pending']):
2176
+ # Iterate over all the containers in the pod and sum the
2177
+ # GPU requests
2178
+ for container in pod.spec.containers:
2179
+ if container.resources.requests:
2180
+ allocated_qty += get_node_accelerator_count(
2181
+ container.resources.requests)
2182
+
2183
+ accelerators_available = accelerator_count - allocated_qty
2184
+
2185
+ # Exclude multi-host TPUs from being processed.
2186
+ # TODO(Doyoung): Remove the logic when adding support for
2187
+ # multi-host TPUs.
2188
+ if is_multi_host_tpu(node.metadata.labels):
2189
+ continue
2190
+
2191
+ node_info_dict[node.metadata.name] = models.KubernetesNodeInfo(
2192
+ name=node.metadata.name,
2193
+ accelerator_type=accelerator_name,
2194
+ total={'accelerator_count': int(accelerator_count)},
2195
+ free={'accelerators_available': int(accelerators_available)})
2196
+
2197
+ return node_info_dict
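A hedged usage sketch; the field names follow the models.KubernetesNodeInfo construction above, and -1 free accelerators marks missing pod-list permissions. The context name is hypothetical.

    # Hypothetical sketch: summarize per-node accelerator availability.
    for name, info in get_kubernetes_node_info(context='my-context').items():
        total = info.total['accelerator_count']
        free = info.free['accelerators_available']
        print(f'{name}: {info.accelerator_type or "-"} '
              f'{"unknown" if free < 0 else free}/{total} free')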
2198
+
2199
+
2200
+ def to_label_selector(tags):
2201
+ return ','.join('{}={}'.format(k, v) for k, v in tags.items())
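For instance, a quick sketch of the selector string this produces:

    # to_label_selector({'app': 'sky', 'tier': 'head'}) -> 'app=sky,tier=head'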
2207
+
2208
+
2209
+ def get_namespace_from_config(provider_config: Dict[str, Any]) -> str:
2210
+ context = get_context_from_config(provider_config)
2211
+ return provider_config.get('namespace',
2212
+ get_kube_config_context_namespace(context))
2213
+
2214
+
2215
+ @timeline.event
2216
+ def filter_pods(namespace: str,
2217
+ context: Optional[str],
2218
+ tag_filters: Dict[str, str],
2219
+ status_filters: Optional[List[str]] = None) -> Dict[str, Any]:
2220
+ """Filters pods by tags and status."""
2221
+ non_included_pod_statuses = POD_STATUSES.copy()
2222
+
2223
+ field_selector = ''
2224
+ if status_filters is not None:
2225
+ non_included_pod_statuses -= set(status_filters)
2226
+ field_selector = ','.join(
2227
+ [f'status.phase!={status}' for status in non_included_pod_statuses])
2228
+
2229
+ label_selector = to_label_selector(tag_filters)
2230
+ pod_list = kubernetes.core_api(context).list_namespaced_pod(
2231
+ namespace, field_selector=field_selector, label_selector=label_selector)
2232
+
2233
+ # Don't return pods marked for deletion,
2234
+ # i.e. pods with non-null metadata.DeletionTimestamp.
2235
+ pods = [
2236
+ pod for pod in pod_list.items if pod.metadata.deletion_timestamp is None
2237
+ ]
2238
+ return {pod.metadata.name: pod for pod in pods}
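A hedged sketch of a filter_pods call fetching only the Running pods of one cluster; the namespace, context, and cluster name are hypothetical, and the 'skypilot-cluster' label mirrors the selector used by get_skypilot_pods below.

    # Hypothetical sketch: fetch the Running pods of a single SkyPilot cluster.
    running = filter_pods(namespace='default',
                          context='my-context',
                          tag_filters={'skypilot-cluster': 'mycluster-2ea4'},
                          status_filters=['Running'])
    print(sorted(running.keys()))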
2239
+
2240
+
2241
+ def _remove_pod_annotation(pod: Any,
2242
+ annotation_key: str,
2243
+ namespace: str,
2244
+ context: Optional[str] = None) -> None:
2245
+ """Removes specified Annotations from a Kubernetes pod."""
2246
+ try:
2247
+ # Remove the specified annotation
2248
+ if pod.metadata.annotations:
2249
+ if annotation_key in pod.metadata.annotations:
2250
+ # Patch the pod with the updated metadata.
2251
+ body = {'metadata': {'annotations': {annotation_key: None}}}
2252
+ kubernetes.core_api(context).patch_namespaced_pod(
2253
+ name=pod.metadata.name,
2254
+ namespace=namespace,
2255
+ body=body,
2256
+ _request_timeout=kubernetes.API_TIMEOUT)
2257
+
2258
+ except kubernetes.api_exception() as e:
2259
+ if e.status == 404:
2260
+ logger.warning(
2261
+ ANNOTATIONS_POD_NOT_FOUND_ERROR_MSG.format(
2262
+ pod_name=pod.metadata.name,
2263
+ namespace=namespace,
2264
+ action='remove',
2265
+ annotation=annotation_key))
2266
+ else:
2267
+ with ux_utils.print_exception_no_traceback():
2268
+ raise
2269
+
2270
+
2271
+ def _add_pod_annotation(pod: Any,
2272
+ annotation: Dict[str, str],
2273
+ namespace: str,
2274
+ context: Optional[str] = None) -> None:
2275
+ """Adds specified Annotations on a Kubernetes pod."""
2276
+ try:
2277
+ # Patch the pod with the updated metadata
2278
+ body = {'metadata': {'annotations': annotation}}
2279
+ kubernetes.core_api(context).patch_namespaced_pod(
2280
+ name=pod.metadata.name,
2281
+ namespace=namespace,
2282
+ body=body,
2283
+ _request_timeout=kubernetes.API_TIMEOUT)
2284
+
2285
+ except kubernetes.api_exception() as e:
2286
+ if e.status == 404:
2287
+ logger.warning(
2288
+ ANNOTATIONS_POD_NOT_FOUND_ERROR_MSG.format(
2289
+ pod_name=pod.metadata.name,
2290
+ namespace=namespace,
2291
+ action='add',
2292
+ annotation=annotation))
2293
+ else:
2294
+ with ux_utils.print_exception_no_traceback():
2295
+ raise
2296
+
2297
+
2298
+ def set_autodown_annotations(handle: 'backends.CloudVmRayResourceHandle',
2299
+ idle_minutes_to_autostop: Optional[int],
2300
+ down: bool = False) -> None:
2301
+ """Adds or removes Annotations of autodown on Kubernetes pods."""
2302
+ tags = {
2303
+ provision_constants.TAG_RAY_CLUSTER_NAME: handle.cluster_name_on_cloud,
2304
+ }
2305
+ ray_config = common_utils.read_yaml(handle.cluster_yaml)
2306
+ provider_config = ray_config['provider']
2307
+ namespace = get_namespace_from_config(provider_config)
2308
+ context = get_context_from_config(provider_config)
2309
+ running_pods = filter_pods(namespace, context, tags)
2310
+
2311
+ for _, pod in running_pods.items():
2312
+ if down:
2313
+ idle_minutes_to_autostop_annotation = {
2314
+ IDLE_MINUTES_TO_AUTOSTOP_ANNOTATION_KEY:
2315
+ str(idle_minutes_to_autostop)
2316
+ }
2317
+ autodown_annotation = {AUTODOWN_ANNOTATION_KEY: 'true'}
2318
+ _add_pod_annotation(pod=pod,
2319
+ annotation=idle_minutes_to_autostop_annotation,
2320
+ namespace=namespace,
2321
+ context=context)
2322
+ _add_pod_annotation(pod=pod,
2323
+ annotation=autodown_annotation,
2324
+ namespace=namespace,
2325
+ context=context)
2326
+
2327
+ # If idle_minutes_to_autostop is negative, it indicates a request to
2328
+ # cancel autostop using the --cancel flag with the `sky autostop`
2329
+ # command.
2330
+ elif (idle_minutes_to_autostop is not None and
2331
+ idle_minutes_to_autostop < 0):
2332
+ _remove_pod_annotation(
2333
+ pod=pod,
2334
+ annotation_key=IDLE_MINUTES_TO_AUTOSTOP_ANNOTATION_KEY,
2335
+ namespace=namespace,
2336
+ context=context)
2337
+ _remove_pod_annotation(pod=pod,
2338
+ annotation_key=AUTODOWN_ANNOTATION_KEY,
2339
+ namespace=namespace,
2340
+ context=context)
2341
+
2342
+
2343
+ def get_context_from_config(provider_config: Dict[str, Any]) -> Optional[str]:
2344
+ context = provider_config.get('context',
2345
+ get_current_kube_config_context_name())
2346
+ if context == kubernetes.in_cluster_context_name():
2347
+ # If the context (also used as the region) is in-cluster, we need to
2348
+ # we need to use in-cluster auth by setting the context to None.
2349
+ context = None
2350
+ return context
2351
+
2352
+
2353
+ def get_skypilot_pods(context: Optional[str] = None) -> List[Any]:
2354
+ """Gets all SkyPilot pods in the Kubernetes cluster.
2355
+
2356
+ Args:
2357
+ context: Kubernetes context to use. If None, uses the current context.
2358
+
2359
+ Returns:
2360
+ A list of Kubernetes pod objects.
2361
+ """
2362
+ if context is None:
2363
+ context = get_current_kube_config_context_name()
2364
+
2365
+ try:
2366
+ pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
2367
+ label_selector='skypilot-cluster',
2368
+ _request_timeout=kubernetes.API_TIMEOUT).items
2369
+ except kubernetes.max_retry_error():
2370
+ raise exceptions.ResourcesUnavailableError(
2371
+ 'Timed out trying to get SkyPilot pods from Kubernetes cluster. '
2372
+ 'Please check if the cluster is healthy and retry. To debug, run: '
2373
+ 'kubectl get pods --selector=skypilot-cluster --all-namespaces'
2374
+ ) from None
2375
+ return pods
2376
+
2377
+
2378
+ def is_tpu_on_gke(accelerator: str) -> bool:
2379
+ """Determines if the given accelerator is a TPU supported on GKE."""
2380
+ return accelerator in GKE_TPU_ACCELERATOR_TO_GENERATION
2381
+
2382
+
2383
+ def get_node_accelerator_count(attribute_dict: dict) -> int:
2384
+ """Retrieves the count of accelerators from a node's resource dictionary.
2385
+
2386
+ This method checks the node's allocatable resources or the accelerators
2387
+ already deployed on the node, using pod objects that describe resource
2388
+ requests.
2389
+
2390
+ Args:
2391
+ attribute_dict: Containing resource information from a node, such as
2392
+ allocatable or requested resources.
2393
+
2394
+ Returns:
2395
+ Number of accelerators allocated or available from the node. If no
2396
+ resource is found, it returns 0.
2397
+ """
2398
+ gpu_resource_name = get_gpu_resource_key()
2399
+ assert not (gpu_resource_name in attribute_dict and
2400
+ TPU_RESOURCE_KEY in attribute_dict)
2401
+ if gpu_resource_name in attribute_dict:
2402
+ return int(attribute_dict[gpu_resource_name])
2403
+ elif TPU_RESOURCE_KEY in attribute_dict:
2404
+ return int(attribute_dict[TPU_RESOURCE_KEY])
2405
+ return 0
2406
+
2407
+
2408
+ def reduce_tpu_topology(topology: str) -> int:
2409
+ """Computes the number of TPU chips from its topology string."""
2410
+ chip_dimensions = [int(chip_count) for chip_count in topology.split('x')]
2411
+ # tpu_topology_chip_count represents the total number of TPU chips in the
2412
+ # entire podslice, whether it is a single-host or multi-host TPU podslice.
2413
+ tpu_topology_chip_count = functools.reduce(lambda x, y: x * y,
2414
+ chip_dimensions)
2415
+ return tpu_topology_chip_count
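For example, the arithmetic for two common topology strings (a sketch; the values follow directly from the product of the dimensions):

    # '2x2x4' describes a 16-chip podslice; '2x2x1' a 4-chip single host.
    assert reduce_tpu_topology('2x2x4') == 2 * 2 * 4 == 16
    assert reduce_tpu_topology('2x2x1') == 4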
2416
+
2417
+
2418
+ def is_multi_host_tpu(node_metadata_labels: dict) -> bool:
2419
+ """Determines whether the given node is a multi-host TPU configuration."""
2420
+ if GKELabelFormatter.TPU_LABEL_KEY in node_metadata_labels:
2421
+ assert GKELabelFormatter.TPU_TOPOLOGY_LABEL_KEY in node_metadata_labels
2422
+ topology_value = (
2423
+ node_metadata_labels[GKELabelFormatter.TPU_TOPOLOGY_LABEL_KEY])
2424
+ accelerator_count_label_key = (
2425
+ GKELabelFormatter.ACCELERATOR_COUNT_LABEL_KEY)
2426
+ assert accelerator_count_label_key in node_metadata_labels
2427
+ # node_tpu_chip_count represents the number of TPU chips
2428
+ # available in this node. If the node is part of a node pool
2429
+ # forming a multi-host TPU podslice, it only reflects the
2430
+ # number of TPU chips in this individual node, not the entire
2431
+ # multi-host TPU podslice.
2432
+ node_tpu_chip_count = int(
2433
+ node_metadata_labels[accelerator_count_label_key])
2434
+ topology_chip_count = reduce_tpu_topology(topology_value)
2435
+ # For multi-host TPU podslices, topology_chip_count and
2436
+ # node_tpu_chip_count will differ, as topology_chip_count
2437
+ # reflects the total across all hosts, while
2438
+ # node_tpu_chip_count reflects only the chips in a single node.
2439
+ if node_tpu_chip_count != topology_chip_count:
2440
+ return True
2441
+ return False
2442
+
2443
+
2444
+ def multi_host_tpu_exists_in_cluster(context: Optional[str] = None) -> bool:
2445
+ """Checks if there exists a multi-host TPU within the cluster."""
2446
+ nodes = get_kubernetes_nodes(context)
2447
+ for node in nodes:
2448
+ if is_multi_host_tpu(node.metadata.labels):
2449
+ return True
2450
+ return False
2451
+
2452
+
2453
+ @dataclasses.dataclass
2454
+ class KubernetesSkyPilotClusterInfo:
2455
+ cluster_name_on_cloud: str
2456
+ cluster_name: str
2457
+ user: str
2458
+ status: status_lib.ClusterStatus
2459
+ pods: List[Any]
2460
+ launched_at: float
2461
+ resources: 'resources_lib.Resources'
2462
+ resources_str: str
2463
+
2464
+
2465
+ @dataclasses.dataclass
2466
+ class KubernetesSkyPilotClusterInfoPayload:
2467
+ """SkyPilot Cluster on Kubernetes payload."""
2468
+ cluster_name_on_cloud: str
2469
+ cluster_name: str
2470
+ user: str
2471
+ status: status_lib.ClusterStatus
2472
+ resources_str: str
2473
+ launched_at: float
2474
+
2475
+ @classmethod
2476
+ def from_cluster(
2477
+ cls, cluster: KubernetesSkyPilotClusterInfo
2478
+ ) -> 'KubernetesSkyPilotClusterInfoPayload':
2479
+ resources_str = f'{len(cluster.pods)}x {cluster.resources}'
2480
+ return cls(
2481
+ cluster_name_on_cloud=cluster.cluster_name_on_cloud,
2482
+ cluster_name=cluster.cluster_name,
2483
+ user=cluster.user,
2484
+ status=cluster.status,
2485
+ resources_str=resources_str,
2486
+ launched_at=cluster.launched_at,
2487
+ )
2488
+
2489
+
2490
+ def process_skypilot_pods(
2491
+ pods: List[Any],
2492
+ context: Optional[str] = None
2493
+ ) -> Tuple[List[KubernetesSkyPilotClusterInfo],
2494
+ List[KubernetesSkyPilotClusterInfo],
2495
+ List[KubernetesSkyPilotClusterInfo]]:
2496
+ """Process SkyPilot pods on k8s to extract cluster and controller info.
2497
+
2498
+ Args:
2499
+ pods: List of Kubernetes pod objects.
2500
+ context: Kubernetes context name, used to detect GPU label formatter.
2501
+
2502
+ Returns:
2503
+ A tuple containing:
2504
+ - List of KubernetesSkyPilotClusterInfo with all cluster info.
2505
+ - List of KubernetesSkyPilotClusterInfo with job controller info.
2506
+ - List of KubernetesSkyPilotClusterInfo with serve controller info.
2507
+ """
2508
+ # pylint: disable=import-outside-toplevel
2509
+ from sky import resources as resources_lib
2510
+ clusters: Dict[str, KubernetesSkyPilotClusterInfo] = {}
2511
+ jobs_controllers: List[KubernetesSkyPilotClusterInfo] = []
2512
+ serve_controllers: List[KubernetesSkyPilotClusterInfo] = []
2513
+
2514
+ for pod in pods:
2515
+ cluster_name_on_cloud = pod.metadata.labels.get('skypilot-cluster')
2516
+ cluster_name = cluster_name_on_cloud.rsplit(
2517
+ '-', 1
2518
+ )[0] # Remove the user hash to get cluster name (e.g., mycluster-2ea4)
2519
+ if cluster_name_on_cloud not in clusters:
2520
+ # Parse the start time for the cluster
2521
+ start_time = pod.status.start_time
2522
+ if start_time is not None:
2523
+ start_time = pod.status.start_time.timestamp()
2524
+
2525
+ # Parse resources
2526
+ cpu_request = parse_cpu_or_gpu_resource(
2527
+ pod.spec.containers[0].resources.requests.get('cpu', '0'))
2528
+ memory_request = parse_memory_resource(
2529
+ pod.spec.containers[0].resources.requests.get('memory', '0'),
2530
+ unit='G')
2531
+ gpu_count = parse_cpu_or_gpu_resource(
2532
+ pod.spec.containers[0].resources.requests.get(
2533
+ 'nvidia.com/gpu', '0'))
2534
+ gpu_name = None
2535
+ if gpu_count > 0:
2536
+ label_formatter, _ = (detect_gpu_label_formatter(context))
2537
+ assert label_formatter is not None, (
2538
+ 'GPU label formatter cannot be None if there are pods '
2539
+ f'requesting GPUs: {pod.metadata.name}')
2540
+ gpu_label = label_formatter.get_label_key()
2541
+ # Get GPU name from pod node selector
2542
+ if pod.spec.node_selector is not None:
2543
+ gpu_name = label_formatter.get_accelerator_from_label_value(
2544
+ pod.spec.node_selector.get(gpu_label))
2545
+
2546
+ resources = resources_lib.Resources(
2547
+ cloud=clouds.Kubernetes(),
2548
+ cpus=int(cpu_request),
2549
+ memory=int(memory_request),
2550
+ accelerators=(f'{gpu_name}:{gpu_count}'
2551
+ if gpu_count > 0 else None))
2552
+ if pod.status.phase == 'Pending':
2553
+ # If pod is pending, do not show it in the status
2554
+ continue
2555
+
2556
+ cluster_info = KubernetesSkyPilotClusterInfo(
2557
+ cluster_name_on_cloud=cluster_name_on_cloud,
2558
+ cluster_name=cluster_name,
2559
+ user=pod.metadata.labels.get('skypilot-user'),
2560
+ status=status_lib.ClusterStatus.UP,
2561
+ pods=[],
2562
+ launched_at=start_time,
2563
+ resources=resources,
2564
+ resources_str='')
2565
+ clusters[cluster_name_on_cloud] = cluster_info
2566
+ # Check if cluster name is name of a controller
2567
+ # Can't use controller_utils.Controllers.from_name(cluster_name)
2568
+ # because hash is different across users
2569
+ if 'sky-jobs-controller' in cluster_name_on_cloud:
2570
+ jobs_controllers.append(cluster_info)
2571
+ elif 'sky-serve-controller' in cluster_name_on_cloud:
2572
+ serve_controllers.append(cluster_info)
2573
+ else:
2574
+ # Update start_time if this pod started earlier
2575
+ pod_start_time = pod.status.start_time
2576
+ if pod_start_time is not None:
2577
+ pod_start_time = pod_start_time.timestamp()
2578
+ if pod_start_time < clusters[cluster_name_on_cloud].launched_at:
2579
+ clusters[cluster_name_on_cloud].launched_at = pod_start_time
2580
+ clusters[cluster_name_on_cloud].pods.append(pod)
2581
+ # Update resources_str in clusters:
2582
+ for cluster in clusters.values():
2583
+ num_pods = len(cluster.pods)
2584
+ cluster.resources_str = f'{num_pods}x {cluster.resources}'
2585
+ return list(clusters.values()), jobs_controllers, serve_controllers
2586
+
2587
+
2588
+ def get_gpu_resource_key():
2589
+ """Get the GPU resource name to use in kubernetes.
2590
+ The function first checks the CUSTOM_GPU_RESOURCE_KEY environment variable.
2591
+ If it is set, its value is used; otherwise, the default GPU_RESOURCE_KEY
2592
+ ("nvidia.com/gpu") is returned.
2594
+ Returns:
2595
+ str: The selected GPU resource name.
2596
+ """
2597
+ # Retrieve GPU resource name from environment variable, if set.
2598
+ # Else use default.
2599
+ # E.g., can be nvidia.com/gpu-h100, amd.com/gpu etc.
2600
+ return os.getenv('CUSTOM_GPU_RESOURCE_KEY', default=GPU_RESOURCE_KEY)
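A hedged sketch of overriding the resource key for a non-NVIDIA device; the environment variable name comes from the code above, while the 'amd.com/gpu' value is an illustrative assumption.

    # Hypothetical sketch: point SkyPilot at a custom GPU resource key.
    os.environ['CUSTOM_GPU_RESOURCE_KEY'] = 'amd.com/gpu'
    assert get_gpu_resource_key() == 'amd.com/gpu'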