skypilot-nightly 1.0.0.dev20240926__py3-none-any.whl → 1.0.0.dev20240927__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -79,13 +79,14 @@ def _open_ports_using_ingress(
79
79
  )
80
80
 
81
81
  # Prepare service names, ports, for template rendering
82
- service_details = [(f'{cluster_name_on_cloud}--skypilot-svc--{port}', port,
83
- _PATH_PREFIX.format(
84
- cluster_name_on_cloud=cluster_name_on_cloud,
85
- port=port,
86
- namespace=kubernetes_utils.
87
- get_current_kube_config_context_namespace()).rstrip(
88
- '/').lstrip('/')) for port in ports]
82
+ service_details = [
83
+ (f'{cluster_name_on_cloud}--skypilot-svc--{port}', port,
84
+ _PATH_PREFIX.format(
85
+ cluster_name_on_cloud=cluster_name_on_cloud,
86
+ port=port,
87
+ namespace=kubernetes_utils.get_kube_config_context_namespace(
88
+ context)).rstrip('/').lstrip('/')) for port in ports
89
+ ]
89
90
 
90
91
  # Generate ingress and services specs
91
92
  # We batch ingress rule creation because each rule triggers a hot reload of
@@ -171,7 +172,8 @@ def _cleanup_ports_for_ingress(
171
172
  for port in ports:
172
173
  service_name = f'{cluster_name_on_cloud}--skypilot-svc--{port}'
173
174
  network_utils.delete_namespaced_service(
174
- namespace=provider_config.get('namespace', 'default'),
175
+ namespace=provider_config.get('namespace',
176
+ kubernetes_utils.DEFAULT_NAMESPACE),
175
177
  service_name=service_name,
176
178
  )
177
179
 
@@ -208,11 +210,13 @@ def query_ports(
208
210
  return _query_ports_for_ingress(
209
211
  cluster_name_on_cloud=cluster_name_on_cloud,
210
212
  ports=ports,
213
+ provider_config=provider_config,
211
214
  )
212
215
  elif port_mode == kubernetes_enums.KubernetesPortMode.PODIP:
213
216
  return _query_ports_for_podip(
214
217
  cluster_name_on_cloud=cluster_name_on_cloud,
215
218
  ports=ports,
219
+ provider_config=provider_config,
216
220
  )
217
221
  else:
218
222
  return {}
@@ -231,8 +235,14 @@ def _query_ports_for_loadbalancer(
231
235
  result: Dict[int, List[common.Endpoint]] = {}
232
236
  service_name = _LOADBALANCER_SERVICE_NAME.format(
233
237
  cluster_name_on_cloud=cluster_name_on_cloud)
238
+ context = provider_config.get(
239
+ 'context', kubernetes_utils.get_current_kube_config_context_name())
240
+ namespace = provider_config.get(
241
+ 'namespace',
242
+ kubernetes_utils.get_kube_config_context_namespace(context))
234
243
  external_ip = network_utils.get_loadbalancer_ip(
235
- namespace=provider_config.get('namespace', 'default'),
244
+ context=context,
245
+ namespace=namespace,
236
246
  service_name=service_name,
237
247
  # Timeout is set so that we can retry the query when the
238
248
  # cluster is firstly created and the load balancer is not ready yet.
@@ -251,19 +261,24 @@ def _query_ports_for_loadbalancer(
251
261
  def _query_ports_for_ingress(
252
262
  cluster_name_on_cloud: str,
253
263
  ports: List[int],
264
+ provider_config: Dict[str, Any],
254
265
  ) -> Dict[int, List[common.Endpoint]]:
255
- ingress_details = network_utils.get_ingress_external_ip_and_ports()
266
+ context = provider_config.get(
267
+ 'context', kubernetes_utils.get_current_kube_config_context_name())
268
+ ingress_details = network_utils.get_ingress_external_ip_and_ports(context)
256
269
  external_ip, external_ports = ingress_details
257
270
  if external_ip is None:
258
271
  return {}
259
272
 
273
+ namespace = provider_config.get(
274
+ 'namespace',
275
+ kubernetes_utils.get_kube_config_context_namespace(context))
260
276
  result: Dict[int, List[common.Endpoint]] = {}
261
277
  for port in ports:
262
278
  path_prefix = _PATH_PREFIX.format(
263
279
  cluster_name_on_cloud=cluster_name_on_cloud,
264
280
  port=port,
265
- namespace=kubernetes_utils.
266
- get_current_kube_config_context_namespace())
281
+ namespace=namespace)
267
282
 
268
283
  http_port, https_port = external_ports \
269
284
  if external_ports is not None else (None, None)
@@ -282,10 +297,15 @@ def _query_ports_for_ingress(
282
297
  def _query_ports_for_podip(
283
298
  cluster_name_on_cloud: str,
284
299
  ports: List[int],
300
+ provider_config: Dict[str, Any],
285
301
  ) -> Dict[int, List[common.Endpoint]]:
286
- namespace = kubernetes_utils.get_current_kube_config_context_namespace()
302
+ context = provider_config.get(
303
+ 'context', kubernetes_utils.get_current_kube_config_context_name())
304
+ namespace = provider_config.get(
305
+ 'namespace',
306
+ kubernetes_utils.get_kube_config_context_namespace(context))
287
307
  pod_name = kubernetes_utils.get_head_pod_name(cluster_name_on_cloud)
288
- pod_ip = network_utils.get_pod_ip(namespace, pod_name)
308
+ pod_ip = network_utils.get_pod_ip(context, namespace, pod_name)
289
309
 
290
310
  result: Dict[int, List[common.Endpoint]] = {}
291
311
  if pod_ip is None:
@@ -220,10 +220,11 @@ def ingress_controller_exists(context: str,
220
220
 
221
221
 
222
222
  def get_ingress_external_ip_and_ports(
223
+ context: str,
223
224
  namespace: str = 'ingress-nginx'
224
225
  ) -> Tuple[Optional[str], Optional[Tuple[int, int]]]:
225
226
  """Returns external ip and ports for the ingress controller."""
226
- core_api = kubernetes.core_api()
227
+ core_api = kubernetes.core_api(context)
227
228
  ingress_services = [
228
229
  item for item in core_api.list_namespaced_service(
229
230
  namespace, _request_timeout=kubernetes.API_TIMEOUT).items
@@ -257,11 +258,12 @@ def get_ingress_external_ip_and_ports(
257
258
  return external_ip, None
258
259
 
259
260
 
260
- def get_loadbalancer_ip(namespace: str,
261
+ def get_loadbalancer_ip(context: str,
262
+ namespace: str,
261
263
  service_name: str,
262
264
  timeout: int = 0) -> Optional[str]:
263
265
  """Returns the IP address of the load balancer."""
264
- core_api = kubernetes.core_api()
266
+ core_api = kubernetes.core_api(context)
265
267
 
266
268
  ip = None
267
269
 
@@ -282,9 +284,9 @@ def get_loadbalancer_ip(namespace: str,
282
284
  return ip
283
285
 
284
286
 
285
- def get_pod_ip(namespace: str, pod_name: str) -> Optional[str]:
287
+ def get_pod_ip(context: str, namespace: str, pod_name: str) -> Optional[str]:
286
288
  """Returns the IP address of the pod."""
287
- core_api = kubernetes.core_api()
289
+ core_api = kubernetes.core_api(context)
288
290
  pod = core_api.read_namespaced_pod(pod_name,
289
291
  namespace,
290
292
  _request_timeout=kubernetes.API_TIMEOUT)
@@ -1,5 +1,6 @@
1
1
  """Kubernetes utilities for SkyPilot."""
2
2
  import dataclasses
3
+ import functools
3
4
  import json
4
5
  import math
5
6
  import os
@@ -307,7 +308,9 @@ AUTOSCALER_TO_LABEL_FORMATTER = {
307
308
  }
308
309
 
309
310
 
311
+ @functools.lru_cache()
310
312
  def detect_gpu_label_formatter(
313
+ context: str
311
314
  ) -> Tuple[Optional[GPULabelFormatter], Dict[str, List[Tuple[str, str]]]]:
312
315
  """Detects the GPU label formatter for the Kubernetes cluster
313
316
 
@@ -318,7 +321,7 @@ def detect_gpu_label_formatter(
318
321
  """
319
322
  # Get all labels across all nodes
320
323
  node_labels: Dict[str, List[Tuple[str, str]]] = {}
321
- nodes = get_kubernetes_nodes()
324
+ nodes = get_kubernetes_nodes(context)
322
325
  for node in nodes:
323
326
  node_labels[node.metadata.name] = []
324
327
  for label, value in node.metadata.labels.items():
@@ -338,7 +341,8 @@ def detect_gpu_label_formatter(
338
341
  return label_formatter, node_labels
339
342
 
340
343
 
341
- def detect_gpu_resource() -> Tuple[bool, Set[str]]:
344
+ @functools.lru_cache(maxsize=10)
345
+ def detect_gpu_resource(context: str) -> Tuple[bool, Set[str]]:
342
346
  """Checks if the Kubernetes cluster has nvidia.com/gpu resource.
343
347
 
344
348
  If nvidia.com/gpu resource is missing, that typically means that the
@@ -350,7 +354,7 @@ def detect_gpu_resource() -> Tuple[bool, Set[str]]:
350
354
  """
351
355
  # Get the set of resources across all nodes
352
356
  cluster_resources: Set[str] = set()
353
- nodes = get_kubernetes_nodes()
357
+ nodes = get_kubernetes_nodes(context)
354
358
  for node in nodes:
355
359
  cluster_resources.update(node.status.allocatable.keys())
356
360
  has_gpu = 'nvidia.com/gpu' in cluster_resources
@@ -358,12 +362,17 @@ def detect_gpu_resource() -> Tuple[bool, Set[str]]:
358
362
  return has_gpu, cluster_resources
359
363
 
360
364
 
361
- def get_kubernetes_nodes() -> List[Any]:
362
- # TODO(romilb): Calling kube API can take between 10-100ms depending on
363
- # the control plane. Consider caching calls to this function (using
364
- # kubecontext hash as key).
365
+ @functools.lru_cache(maxsize=10)
366
+ def get_kubernetes_nodes(context: Optional[str] = None) -> List[Any]:
367
+ """Gets the kubernetes nodes in the context.
368
+
369
+ If context is None, gets the nodes in the current context.
370
+ """
371
+ if context is None:
372
+ context = get_current_kube_config_context_name()
373
+
365
374
  try:
366
- nodes = kubernetes.core_api().list_node(
375
+ nodes = kubernetes.core_api(context).list_node(
367
376
  _request_timeout=kubernetes.API_TIMEOUT).items
368
377
  except kubernetes.max_retry_error():
369
378
  raise exceptions.ResourcesUnavailableError(
@@ -373,15 +382,18 @@ def get_kubernetes_nodes() -> List[Any]:
373
382
  return nodes
374
383
 
375
384
 
376
- def get_kubernetes_pods() -> List[Any]:
377
- """Gets the kubernetes pods in the current namespace and current context.
385
+ def get_all_pods_in_kubernetes_cluster(
386
+ context: Optional[str] = None) -> List[Any]:
387
+ """Gets pods in all namespaces in kubernetes cluster indicated by context.
378
388
 
379
389
  Used for computing cluster resource usage.
380
390
  """
391
+ if context is None:
392
+ context = get_current_kube_config_context_name()
393
+
381
394
  try:
382
- ns = get_current_kube_config_context_namespace()
383
- pods = kubernetes.core_api().list_namespaced_pod(
384
- ns, _request_timeout=kubernetes.API_TIMEOUT).items
395
+ pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
396
+ _request_timeout=kubernetes.API_TIMEOUT).items
385
397
  except kubernetes.max_retry_error():
386
398
  raise exceptions.ResourcesUnavailableError(
387
399
  'Timed out when trying to get pod info from Kubernetes cluster. '
@@ -390,7 +402,8 @@ def get_kubernetes_pods() -> List[Any]:
390
402
  return pods
391
403
 
392
404
 
393
- def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
405
+ def check_instance_fits(context: str,
406
+ instance: str) -> Tuple[bool, Optional[str]]:
394
407
  """Checks if the instance fits on the Kubernetes cluster.
395
408
 
396
409
  If the instance has GPU requirements, checks if the GPU type is
@@ -405,6 +418,9 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
405
418
  Optional[str]: Error message if the instance does not fit.
406
419
  """
407
420
 
421
+ # TODO(zhwu): this should check the node for specific context, instead
422
+ # of the default context to make failover fully functional.
423
+
408
424
  def check_cpu_mem_fits(candidate_instance_type: 'KubernetesInstanceType',
409
425
  node_list: List[Any]) -> Tuple[bool, Optional[str]]:
410
426
  """Checks if the instance fits on the cluster based on CPU and memory.
@@ -431,7 +447,7 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
431
447
  'Maximum resources found on a single node: '
432
448
  f'{max_cpu} CPUs, {common_utils.format_float(max_mem)}G Memory')
433
449
 
434
- nodes = get_kubernetes_nodes()
450
+ nodes = get_kubernetes_nodes(context)
435
451
  k8s_instance_type = KubernetesInstanceType.\
436
452
  from_instance_type(instance)
437
453
  acc_type = k8s_instance_type.accelerator_type
@@ -439,7 +455,8 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
439
455
  # If GPUs are requested, check if GPU type is available, and if so,
440
456
  # check if CPU and memory requirements on the specific node are met.
441
457
  try:
442
- gpu_label_key, gpu_label_val = get_gpu_label_key_value(acc_type)
458
+ gpu_label_key, gpu_label_val = get_gpu_label_key_value(
459
+ context, acc_type)
443
460
  except exceptions.ResourcesUnavailableError as e:
444
461
  # If GPU not found, return empty list and error message.
445
462
  return False, str(e)
@@ -471,7 +488,9 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
471
488
  return fits, reason
472
489
 
473
490
 
474
- def get_gpu_label_key_value(acc_type: str, check_mode=False) -> Tuple[str, str]:
491
+ def get_gpu_label_key_value(context: str,
492
+ acc_type: str,
493
+ check_mode=False) -> Tuple[str, str]:
475
494
  """Returns the label key and value for the given GPU type.
476
495
 
477
496
  Args:
@@ -512,11 +531,11 @@ def get_gpu_label_key_value(acc_type: str, check_mode=False) -> Tuple[str, str]:
512
531
  f' {autoscaler_type}')
513
532
  return formatter.get_label_key(), formatter.get_label_value(acc_type)
514
533
 
515
- has_gpus, cluster_resources = detect_gpu_resource()
534
+ has_gpus, cluster_resources = detect_gpu_resource(context)
516
535
  if has_gpus:
517
536
  # Check if the cluster has GPU labels setup correctly
518
537
  label_formatter, node_labels = \
519
- detect_gpu_label_formatter()
538
+ detect_gpu_label_formatter(context)
520
539
  if label_formatter is None:
521
540
  # If none of the GPU labels from LABEL_FORMATTER_REGISTRY are
522
541
  # detected, raise error
@@ -632,7 +651,7 @@ def get_external_ip(network_mode: Optional[
632
651
  return parsed_url.hostname
633
652
 
634
653
 
635
- def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
654
+ def check_credentials(context: str, timeout: int = kubernetes.API_TIMEOUT) -> \
636
655
  Tuple[bool, Optional[str]]:
637
656
  """Check if the credentials in kubeconfig file are valid
638
657
 
@@ -644,10 +663,9 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
644
663
  str: Error message if credentials are invalid, None otherwise
645
664
  """
646
665
  try:
647
- ns = get_current_kube_config_context_namespace()
648
- context = get_current_kube_config_context_name()
666
+ namespace = get_kube_config_context_namespace(context)
649
667
  kubernetes.core_api(context).list_namespaced_pod(
650
- ns, _request_timeout=timeout)
668
+ namespace, _request_timeout=timeout)
651
669
  except ImportError:
652
670
  # TODO(romilb): Update these error strs to also include link to docs
653
671
  # when docs are ready.
@@ -676,7 +694,7 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
676
694
  # We now do softer checks to check if exec based auth is used and to
677
695
  # see if the cluster is GPU-enabled.
678
696
 
679
- _, exec_msg = is_kubeconfig_exec_auth()
697
+ _, exec_msg = is_kubeconfig_exec_auth(context)
680
698
 
681
699
  # We now check if GPUs are available and labels are set correctly on the
682
700
  # cluster, and if not we return hints that may help debug any issues.
@@ -685,7 +703,7 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
685
703
  # provider if their cluster GPUs are not setup correctly.
686
704
  gpu_msg = ''
687
705
  try:
688
- _, _ = get_gpu_label_key_value(acc_type='', check_mode=True)
706
+ _, _ = get_gpu_label_key_value(context, acc_type='', check_mode=True)
689
707
  except exceptions.ResourcesUnavailableError as e:
690
708
  # If GPUs are not available, we return cluster as enabled (since it can
691
709
  # be a CPU-only cluster) but we also return the exception message which
@@ -701,7 +719,8 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
701
719
  return True, None
702
720
 
703
721
 
704
- def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
722
+ def is_kubeconfig_exec_auth(
723
+ context: Optional[str] = None) -> Tuple[bool, Optional[str]]:
705
724
  """Checks if the kubeconfig file uses exec-based authentication
706
725
 
707
726
  Exec-based auth is commonly used for authenticating with cloud hosted
@@ -735,8 +754,16 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
735
754
  return False, None
736
755
 
737
756
  # Get active context and user from kubeconfig using k8s api
738
- _, current_context = k8s.config.list_kube_config_contexts()
739
- target_username = current_context['context']['user']
757
+ all_contexts, current_context = k8s.config.list_kube_config_contexts()
758
+ context_obj = current_context
759
+ if context is not None:
760
+ for c in all_contexts:
761
+ if c['name'] == context:
762
+ context_obj = c
763
+ break
764
+ else:
765
+ raise ValueError(f'Kubernetes context {context!r} not found.')
766
+ target_username = context_obj['context']['user']
740
767
 
741
768
  # K8s api does not provide a mechanism to get the user details from the
742
769
  # context. We need to load the kubeconfig file and parse it to get the
@@ -759,7 +786,7 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
759
786
  schemas.get_default_remote_identity('kubernetes'))
760
787
  if ('exec' in user_details.get('user', {}) and remote_identity
761
788
  == schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value):
762
- ctx_name = current_context['name']
789
+ ctx_name = context_obj['name']
763
790
  exec_msg = ('exec-based authentication is used for '
764
791
  f'Kubernetes context {ctx_name!r}.'
765
792
  ' This may cause issues with autodown or when running '
@@ -775,6 +802,7 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
775
802
  return False, None
776
803
 
777
804
 
805
+ @functools.lru_cache()
778
806
  def get_current_kube_config_context_name() -> Optional[str]:
779
807
  """Get the current kubernetes context from the kubeconfig file
780
808
 
@@ -789,7 +817,27 @@ def get_current_kube_config_context_name() -> Optional[str]:
789
817
  return None
790
818
 
791
819
 
792
- def get_current_kube_config_context_namespace() -> str:
820
+ def get_all_kube_config_context_names() -> Optional[List[str]]:
821
+ """Get all kubernetes context names from the kubeconfig file.
822
+
823
+ We should not cache the result of this function as the admin policy may
824
+ update the contexts.
825
+
826
+ Returns:
827
+ List[str] | None: The list of kubernetes context names if it exists,
828
+ None otherwise
829
+ """
830
+ k8s = kubernetes.kubernetes
831
+ try:
832
+ all_contexts, _ = k8s.config.list_kube_config_contexts()
833
+ return [context['name'] for context in all_contexts]
834
+ except k8s.config.config_exception.ConfigException:
835
+ return None
836
+
837
+
838
+ @functools.lru_cache()
839
+ def get_kube_config_context_namespace(
840
+ context_name: Optional[str] = None) -> str:
793
841
  """Get the current kubernetes context namespace from the kubeconfig file
794
842
 
795
843
  Returns:
@@ -804,9 +852,17 @@ def get_current_kube_config_context_namespace() -> str:
804
852
  return f.read().strip()
805
853
  # If not in-cluster, get the namespace from kubeconfig
806
854
  try:
807
- _, current_context = k8s.config.list_kube_config_contexts()
808
- if 'namespace' in current_context['context']:
809
- return current_context['context']['namespace']
855
+ contexts, current_context = k8s.config.list_kube_config_contexts()
856
+ if context_name is None:
857
+ context = current_context
858
+ else:
859
+ context = next((c for c in contexts if c['name'] == context_name),
860
+ None)
861
+ if context is None:
862
+ return DEFAULT_NAMESPACE
863
+
864
+ if 'namespace' in context['context']:
865
+ return context['context']['namespace']
810
866
  else:
811
867
  return DEFAULT_NAMESPACE
812
868
  except k8s.config.config_exception.ConfigException:
@@ -987,11 +1043,12 @@ def construct_ssh_jump_command(
987
1043
 
988
1044
 
989
1045
  def get_ssh_proxy_command(
990
- k8s_ssh_target: str,
991
- network_mode: kubernetes_enums.KubernetesNetworkingMode,
992
- private_key_path: Optional[str] = None,
993
- namespace: Optional[str] = None,
994
- context: Optional[str] = None) -> str:
1046
+ k8s_ssh_target: str,
1047
+ network_mode: kubernetes_enums.KubernetesNetworkingMode,
1048
+ private_key_path: str,
1049
+ context: str,
1050
+ namespace: str,
1051
+ ) -> str:
995
1052
  """Generates the SSH proxy command to connect to the pod.
996
1053
 
997
1054
  Uses a jump pod if the network mode is NODEPORT, and direct port-forwarding
@@ -1048,8 +1105,6 @@ def get_ssh_proxy_command(
1048
1105
  private_key_path, ssh_jump_ip, ssh_jump_port=ssh_jump_port)
1049
1106
  else:
1050
1107
  ssh_jump_proxy_command_path = create_proxy_command_script()
1051
- current_context = get_current_kube_config_context_name()
1052
- current_namespace = get_current_kube_config_context_namespace()
1053
1108
  ssh_jump_proxy_command = construct_ssh_jump_command(
1054
1109
  private_key_path,
1055
1110
  ssh_jump_ip,
@@ -1059,8 +1114,8 @@ def get_ssh_proxy_command(
1059
1114
  # We embed both the current context and namespace to the SSH proxy
1060
1115
  # command to make sure SSH still works when the current
1061
1116
  # context/namespace is changed by the user.
1062
- current_kube_context=current_context,
1063
- current_kube_namespace=current_namespace)
1117
+ current_kube_context=context,
1118
+ current_kube_namespace=namespace)
1064
1119
  return ssh_jump_proxy_command
1065
1120
 
1066
1121
 
@@ -1647,7 +1702,8 @@ SPOT_LABEL_MAP = {
1647
1702
  }
1648
1703
 
1649
1704
 
1650
- def get_spot_label() -> Tuple[Optional[str], Optional[str]]:
1705
+ def get_spot_label(
1706
+ context: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
1651
1707
  """Get the spot label key and value for using spot instances, if supported.
1652
1708
 
1653
1709
  Checks if the underlying cluster supports spot instances by checking nodes
@@ -1661,7 +1717,7 @@ def get_spot_label() -> Tuple[Optional[str], Optional[str]]:
1661
1717
  """
1662
1718
  # Check if the cluster supports spot instances by checking nodes for known
1663
1719
  # spot label keys and values
1664
- for node in get_kubernetes_nodes():
1720
+ for node in get_kubernetes_nodes(context):
1665
1721
  for _, (key, value) in SPOT_LABEL_MAP.items():
1666
1722
  if key in node.metadata.labels and node.metadata.labels[
1667
1723
  key] == value:
@@ -1706,7 +1762,8 @@ class KubernetesNodeInfo:
1706
1762
  free: Dict[str, int]
1707
1763
 
1708
1764
 
1709
- def get_kubernetes_node_info() -> Dict[str, KubernetesNodeInfo]:
1765
+ def get_kubernetes_node_info(
1766
+ context: Optional[str] = None) -> Dict[str, KubernetesNodeInfo]:
1710
1767
  """Gets the resource information for all the nodes in the cluster.
1711
1768
 
1712
1769
  Currently only GPU resources are supported. The function returns the total
@@ -1717,11 +1774,11 @@ def get_kubernetes_node_info() -> Dict[str, KubernetesNodeInfo]:
1717
1774
  Dict[str, KubernetesNodeInfo]: Dictionary containing the node name as
1718
1775
  key and the KubernetesNodeInfo object as value
1719
1776
  """
1720
- nodes = get_kubernetes_nodes()
1777
+ nodes = get_kubernetes_nodes(context)
1721
1778
  # Get the pods to get the real-time resource usage
1722
- pods = get_kubernetes_pods()
1779
+ pods = get_all_pods_in_kubernetes_cluster(context)
1723
1780
 
1724
- label_formatter, _ = detect_gpu_label_formatter()
1781
+ label_formatter, _ = detect_gpu_label_formatter(context)
1725
1782
  if not label_formatter:
1726
1783
  label_key = None
1727
1784
  else:
@@ -1773,8 +1830,9 @@ def to_label_selector(tags):
1773
1830
 
1774
1831
 
1775
1832
  def get_namespace_from_config(provider_config: Dict[str, Any]) -> str:
1833
+ context = get_context_from_config(provider_config)
1776
1834
  return provider_config.get('namespace',
1777
- get_current_kube_config_context_namespace())
1835
+ get_kube_config_context_namespace(context))
1778
1836
 
1779
1837
 
1780
1838
  def filter_pods(namespace: str,
@@ -1802,8 +1860,10 @@ def filter_pods(namespace: str,
1802
1860
  return {pod.metadata.name: pod for pod in pods}
1803
1861
 
1804
1862
 
1805
- def _remove_pod_annotation(pod: Any, annotation_key: str,
1806
- namespace: str) -> None:
1863
+ def _remove_pod_annotation(pod: Any,
1864
+ annotation_key: str,
1865
+ namespace: str,
1866
+ context: Optional[str] = None) -> None:
1807
1867
  """Removes specified Annotations from a Kubernetes pod."""
1808
1868
  try:
1809
1869
  # Remove the specified annotation
@@ -1811,7 +1871,7 @@ def _remove_pod_annotation(pod: Any, annotation_key: str,
1811
1871
  if annotation_key in pod.metadata.annotations:
1812
1872
  # Patch the pod with the updated metadata.
1813
1873
  body = {'metadata': {'annotations': {annotation_key: None}}}
1814
- kubernetes.core_api().patch_namespaced_pod(
1874
+ kubernetes.core_api(context).patch_namespaced_pod(
1815
1875
  name=pod.metadata.name,
1816
1876
  namespace=namespace,
1817
1877
  body=body,
@@ -1830,13 +1890,15 @@ def _remove_pod_annotation(pod: Any, annotation_key: str,
1830
1890
  raise
1831
1891
 
1832
1892
 
1833
- def _add_pod_annotation(pod: Any, annotation: Dict[str, str],
1834
- namespace: str) -> None:
1893
+ def _add_pod_annotation(pod: Any,
1894
+ annotation: Dict[str, str],
1895
+ namespace: str,
1896
+ context: Optional[str] = None) -> None:
1835
1897
  """Adds specified Annotations on a Kubernetes pod."""
1836
1898
  try:
1837
1899
  # Patch the pod with the updated metadata
1838
1900
  body = {'metadata': {'annotations': annotation}}
1839
- kubernetes.core_api().patch_namespaced_pod(
1901
+ kubernetes.core_api(context).patch_namespaced_pod(
1840
1902
  name=pod.metadata.name,
1841
1903
  namespace=namespace,
1842
1904
  body=body,
@@ -1877,10 +1939,12 @@ def set_autodown_annotations(handle: 'backends.CloudVmRayResourceHandle',
1877
1939
  autodown_annotation = {AUTODOWN_ANNOTATION_KEY: 'true'}
1878
1940
  _add_pod_annotation(pod=pod,
1879
1941
  annotation=idle_minutes_to_autostop_annotation,
1880
- namespace=namespace)
1942
+ namespace=namespace,
1943
+ context=context)
1881
1944
  _add_pod_annotation(pod=pod,
1882
1945
  annotation=autodown_annotation,
1883
- namespace=namespace)
1946
+ namespace=namespace,
1947
+ context=context)
1884
1948
 
1885
1949
  # If idle_minutes_to_autostop is negative, it indicates a request to
1886
1950
  # cancel autostop using the --cancel flag with the `sky autostop`
@@ -1890,10 +1954,12 @@ def set_autodown_annotations(handle: 'backends.CloudVmRayResourceHandle',
1890
1954
  _remove_pod_annotation(
1891
1955
  pod=pod,
1892
1956
  annotation_key=IDLE_MINUTES_TO_AUTOSTOP_ANNOTATION_KEY,
1893
- namespace=namespace)
1957
+ namespace=namespace,
1958
+ context=context)
1894
1959
  _remove_pod_annotation(pod=pod,
1895
1960
  annotation_key=AUTODOWN_ANNOTATION_KEY,
1896
- namespace=namespace)
1961
+ namespace=namespace,
1962
+ context=context)
1897
1963
 
1898
1964
 
1899
1965
  def get_context_from_config(provider_config: Dict[str, Any]) -> str:
@@ -259,6 +259,8 @@ def _ssh_probe_command(ip: str,
259
259
  '-o',
260
260
  'IdentitiesOnly=yes',
261
261
  '-o',
262
+ 'AddKeysToAgent=yes',
263
+ '-o',
262
264
  'ExitOnForwardFailure=yes',
263
265
  '-o',
264
266
  'ServerAliveInterval=5',
@@ -18,7 +18,7 @@ provider:
18
18
 
19
19
  region: kubernetes
20
20
 
21
- # The namespace to create the Ray cluster in.
21
+
22
22
  namespace: {{k8s_namespace}}
23
23
 
24
24
  # The kubecontext used to connect to the Kubernetes cluster.
@@ -85,6 +85,10 @@ def ssh_options_list(
85
85
  'LogLevel': 'ERROR',
86
86
  # Try fewer extraneous key pairs.
87
87
  'IdentitiesOnly': 'yes',
88
+ # Add the current private key used for this SSH connection to the
89
+ # SSH agent, so that forward agent parameter will then make SSH
90
+ # agent forward it.
91
+ 'AddKeysToAgent': 'yes',
88
92
  # Abort if port forwarding fails (instead of just printing to
89
93
  # stderr).
90
94
  'ExitOnForwardFailure': 'yes',
sky/utils/schemas.py CHANGED
@@ -775,6 +775,12 @@ def get_config_schema():
775
775
  'required': [],
776
776
  'additionalProperties': False,
777
777
  'properties': {
778
+ 'allowed_contexts': {
779
+ 'type': 'array',
780
+ 'items': {
781
+ 'type': 'string',
782
+ },
783
+ },
778
784
  'networking': {
779
785
  'type': 'string',
780
786
  'case_insensitive_enum': [