skypilot-nightly 1.0.0.dev20240925__py3-none-any.whl → 1.0.0.dev20240927__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,13 @@
1
1
  """Kubernetes utilities for SkyPilot."""
2
2
  import dataclasses
3
+ import functools
3
4
  import json
4
5
  import math
5
6
  import os
6
7
  import re
7
8
  import shutil
8
9
  import subprocess
10
+ import typing
9
11
  from typing import Any, Dict, List, Optional, Set, Tuple, Union
10
12
  from urllib.parse import urlparse
11
13
 
@@ -17,6 +19,7 @@ from sky import exceptions
17
19
  from sky import sky_logging
18
20
  from sky import skypilot_config
19
21
  from sky.adaptors import kubernetes
22
+ from sky.provision import constants as provision_constants
20
23
  from sky.provision.kubernetes import network_utils
21
24
  from sky.skylet import constants
22
25
  from sky.utils import common_utils
@@ -25,6 +28,9 @@ from sky.utils import kubernetes_enums
25
28
  from sky.utils import schemas
26
29
  from sky.utils import ux_utils
27
30
 
31
+ if typing.TYPE_CHECKING:
32
+ from sky import backends
33
+
28
34
  # TODO(romilb): Move constants to constants.py
29
35
  DEFAULT_NAMESPACE = 'default'
30
36
 
@@ -64,6 +70,16 @@ PORT_FORWARD_PROXY_CMD_VERSION = 2
64
70
  PORT_FORWARD_PROXY_CMD_PATH = ('~/.sky/kubernetes-port-forward-proxy-command-'
65
71
  f'v{PORT_FORWARD_PROXY_CMD_VERSION}.sh')
66
72
 
73
# Pod phases that filter_pods() can include or exclude through
# ``status.phase`` field selectors.
# NOTE(review): the Kubernetes pod-lifecycle docs list only Pending, Running,
# Succeeded, Failed and Unknown as phases; 'Terminating' appears to be a
# kubectl display state rather than a phase -- confirm it is intended here.
POD_STATUSES = {
    'Pending',
    'Running',
    'Succeeded',
    'Failed',
    'Unknown',
    'Terminating',
}
# Annotation placed on pods to request autodown (set to the string 'true').
AUTODOWN_ANNOTATION_KEY = 'skypilot.co/autodown'
# Annotation holding the idle-minutes-to-autostop value, stored as a string.
IDLE_MINUTES_TO_AUTOSTOP_ANNOTATION_KEY = (
    'skypilot.co/idle_minutes_to_autostop')
# Warning template used when an annotation patch finds no pod (HTTP 404).
# Placeholders: pod_name, namespace, action ('add' / 'remove'), annotation.
ANNOTATIONS_POD_NOT_FOUND_ERROR_MSG = (
    'Pod {pod_name} not found in namespace {namespace} while trying to '
    '{action} an annotation {annotation}.')
82
+
67
83
  logger = sky_logging.init_logger(__name__)
68
84
 
69
85
 
@@ -292,7 +308,9 @@ AUTOSCALER_TO_LABEL_FORMATTER = {
292
308
  }
293
309
 
294
310
 
311
+ @functools.lru_cache()
295
312
  def detect_gpu_label_formatter(
313
+ context: str
296
314
  ) -> Tuple[Optional[GPULabelFormatter], Dict[str, List[Tuple[str, str]]]]:
297
315
  """Detects the GPU label formatter for the Kubernetes cluster
298
316
 
@@ -303,7 +321,7 @@ def detect_gpu_label_formatter(
303
321
  """
304
322
  # Get all labels across all nodes
305
323
  node_labels: Dict[str, List[Tuple[str, str]]] = {}
306
- nodes = get_kubernetes_nodes()
324
+ nodes = get_kubernetes_nodes(context)
307
325
  for node in nodes:
308
326
  node_labels[node.metadata.name] = []
309
327
  for label, value in node.metadata.labels.items():
@@ -323,7 +341,8 @@ def detect_gpu_label_formatter(
323
341
  return label_formatter, node_labels
324
342
 
325
343
 
326
- def detect_gpu_resource() -> Tuple[bool, Set[str]]:
344
+ @functools.lru_cache(maxsize=10)
345
+ def detect_gpu_resource(context: str) -> Tuple[bool, Set[str]]:
327
346
  """Checks if the Kubernetes cluster has nvidia.com/gpu resource.
328
347
 
329
348
  If nvidia.com/gpu resource is missing, that typically means that the
@@ -335,7 +354,7 @@ def detect_gpu_resource() -> Tuple[bool, Set[str]]:
335
354
  """
336
355
  # Get the set of resources across all nodes
337
356
  cluster_resources: Set[str] = set()
338
- nodes = get_kubernetes_nodes()
357
+ nodes = get_kubernetes_nodes(context)
339
358
  for node in nodes:
340
359
  cluster_resources.update(node.status.allocatable.keys())
341
360
  has_gpu = 'nvidia.com/gpu' in cluster_resources
@@ -343,12 +362,17 @@ def detect_gpu_resource() -> Tuple[bool, Set[str]]:
343
362
  return has_gpu, cluster_resources
344
363
 
345
364
 
346
- def get_kubernetes_nodes() -> List[Any]:
347
- # TODO(romilb): Calling kube API can take between 10-100ms depending on
348
- # the control plane. Consider caching calls to this function (using
349
- # kubecontext hash as key).
365
+ @functools.lru_cache(maxsize=10)
366
+ def get_kubernetes_nodes(context: Optional[str] = None) -> List[Any]:
367
+ """Gets the kubernetes nodes in the context.
368
+
369
+ If context is None, gets the nodes in the current context.
370
+ """
371
+ if context is None:
372
+ context = get_current_kube_config_context_name()
373
+
350
374
  try:
351
- nodes = kubernetes.core_api().list_node(
375
+ nodes = kubernetes.core_api(context).list_node(
352
376
  _request_timeout=kubernetes.API_TIMEOUT).items
353
377
  except kubernetes.max_retry_error():
354
378
  raise exceptions.ResourcesUnavailableError(
@@ -358,15 +382,18 @@ def get_kubernetes_nodes() -> List[Any]:
358
382
  return nodes
359
383
 
360
384
 
361
- def get_kubernetes_pods() -> List[Any]:
362
- """Gets the kubernetes pods in the current namespace and current context.
385
+ def get_all_pods_in_kubernetes_cluster(
386
+ context: Optional[str] = None) -> List[Any]:
387
+ """Gets pods in all namespaces in kubernetes cluster indicated by context.
363
388
 
364
389
  Used for computing cluster resource usage.
365
390
  """
391
+ if context is None:
392
+ context = get_current_kube_config_context_name()
393
+
366
394
  try:
367
- ns = get_current_kube_config_context_namespace()
368
- pods = kubernetes.core_api().list_namespaced_pod(
369
- ns, _request_timeout=kubernetes.API_TIMEOUT).items
395
+ pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
396
+ _request_timeout=kubernetes.API_TIMEOUT).items
370
397
  except kubernetes.max_retry_error():
371
398
  raise exceptions.ResourcesUnavailableError(
372
399
  'Timed out when trying to get pod info from Kubernetes cluster. '
@@ -375,7 +402,8 @@ def get_kubernetes_pods() -> List[Any]:
375
402
  return pods
376
403
 
377
404
 
378
- def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
405
+ def check_instance_fits(context: str,
406
+ instance: str) -> Tuple[bool, Optional[str]]:
379
407
  """Checks if the instance fits on the Kubernetes cluster.
380
408
 
381
409
  If the instance has GPU requirements, checks if the GPU type is
@@ -390,6 +418,9 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
390
418
  Optional[str]: Error message if the instance does not fit.
391
419
  """
392
420
 
421
+ # TODO(zhwu): this should check the node for specific context, instead
422
+ # of the default context to make failover fully functional.
423
+
393
424
  def check_cpu_mem_fits(candidate_instance_type: 'KubernetesInstanceType',
394
425
  node_list: List[Any]) -> Tuple[bool, Optional[str]]:
395
426
  """Checks if the instance fits on the cluster based on CPU and memory.
@@ -416,7 +447,7 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
416
447
  'Maximum resources found on a single node: '
417
448
  f'{max_cpu} CPUs, {common_utils.format_float(max_mem)}G Memory')
418
449
 
419
- nodes = get_kubernetes_nodes()
450
+ nodes = get_kubernetes_nodes(context)
420
451
  k8s_instance_type = KubernetesInstanceType.\
421
452
  from_instance_type(instance)
422
453
  acc_type = k8s_instance_type.accelerator_type
@@ -424,7 +455,8 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
424
455
  # If GPUs are requested, check if GPU type is available, and if so,
425
456
  # check if CPU and memory requirements on the specific node are met.
426
457
  try:
427
- gpu_label_key, gpu_label_val = get_gpu_label_key_value(acc_type)
458
+ gpu_label_key, gpu_label_val = get_gpu_label_key_value(
459
+ context, acc_type)
428
460
  except exceptions.ResourcesUnavailableError as e:
429
461
  # If GPU not found, return empty list and error message.
430
462
  return False, str(e)
@@ -456,7 +488,9 @@ def check_instance_fits(instance: str) -> Tuple[bool, Optional[str]]:
456
488
  return fits, reason
457
489
 
458
490
 
459
- def get_gpu_label_key_value(acc_type: str, check_mode=False) -> Tuple[str, str]:
491
+ def get_gpu_label_key_value(context: str,
492
+ acc_type: str,
493
+ check_mode=False) -> Tuple[str, str]:
460
494
  """Returns the label key and value for the given GPU type.
461
495
 
462
496
  Args:
@@ -497,11 +531,11 @@ def get_gpu_label_key_value(acc_type: str, check_mode=False) -> Tuple[str, str]:
497
531
  f' {autoscaler_type}')
498
532
  return formatter.get_label_key(), formatter.get_label_value(acc_type)
499
533
 
500
- has_gpus, cluster_resources = detect_gpu_resource()
534
+ has_gpus, cluster_resources = detect_gpu_resource(context)
501
535
  if has_gpus:
502
536
  # Check if the cluster has GPU labels setup correctly
503
537
  label_formatter, node_labels = \
504
- detect_gpu_label_formatter()
538
+ detect_gpu_label_formatter(context)
505
539
  if label_formatter is None:
506
540
  # If none of the GPU labels from LABEL_FORMATTER_REGISTRY are
507
541
  # detected, raise error
@@ -617,7 +651,7 @@ def get_external_ip(network_mode: Optional[
617
651
  return parsed_url.hostname
618
652
 
619
653
 
620
- def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
654
+ def check_credentials(context: str, timeout: int = kubernetes.API_TIMEOUT) -> \
621
655
  Tuple[bool, Optional[str]]:
622
656
  """Check if the credentials in kubeconfig file are valid
623
657
 
@@ -629,10 +663,9 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
629
663
  str: Error message if credentials are invalid, None otherwise
630
664
  """
631
665
  try:
632
- ns = get_current_kube_config_context_namespace()
633
- context = get_current_kube_config_context_name()
666
+ namespace = get_kube_config_context_namespace(context)
634
667
  kubernetes.core_api(context).list_namespaced_pod(
635
- ns, _request_timeout=timeout)
668
+ namespace, _request_timeout=timeout)
636
669
  except ImportError:
637
670
  # TODO(romilb): Update these error strs to also include link to docs
638
671
  # when docs are ready.
@@ -661,7 +694,7 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
661
694
  # We now do softer checks to check if exec based auth is used and to
662
695
  # see if the cluster is GPU-enabled.
663
696
 
664
- _, exec_msg = is_kubeconfig_exec_auth()
697
+ _, exec_msg = is_kubeconfig_exec_auth(context)
665
698
 
666
699
  # We now check if GPUs are available and labels are set correctly on the
667
700
  # cluster, and if not we return hints that may help debug any issues.
@@ -670,7 +703,7 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
670
703
  # provider if their cluster GPUs are not setup correctly.
671
704
  gpu_msg = ''
672
705
  try:
673
- _, _ = get_gpu_label_key_value(acc_type='', check_mode=True)
706
+ _, _ = get_gpu_label_key_value(context, acc_type='', check_mode=True)
674
707
  except exceptions.ResourcesUnavailableError as e:
675
708
  # If GPUs are not available, we return cluster as enabled (since it can
676
709
  # be a CPU-only cluster) but we also return the exception message which
@@ -686,7 +719,8 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
686
719
  return True, None
687
720
 
688
721
 
689
- def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
722
+ def is_kubeconfig_exec_auth(
723
+ context: Optional[str] = None) -> Tuple[bool, Optional[str]]:
690
724
  """Checks if the kubeconfig file uses exec-based authentication
691
725
 
692
726
  Exec-based auth is commonly used for authenticating with cloud hosted
@@ -720,8 +754,16 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
720
754
  return False, None
721
755
 
722
756
  # Get active context and user from kubeconfig using k8s api
723
- _, current_context = k8s.config.list_kube_config_contexts()
724
- target_username = current_context['context']['user']
757
+ all_contexts, current_context = k8s.config.list_kube_config_contexts()
758
+ context_obj = current_context
759
+ if context is not None:
760
+ for c in all_contexts:
761
+ if c['name'] == context:
762
+ context_obj = c
763
+ break
764
+ else:
765
+ raise ValueError(f'Kubernetes context {context!r} not found.')
766
+ target_username = context_obj['context']['user']
725
767
 
726
768
  # K8s api does not provide a mechanism to get the user details from the
727
769
  # context. We need to load the kubeconfig file and parse it to get the
@@ -744,7 +786,7 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
744
786
  schemas.get_default_remote_identity('kubernetes'))
745
787
  if ('exec' in user_details.get('user', {}) and remote_identity
746
788
  == schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value):
747
- ctx_name = current_context['name']
789
+ ctx_name = context_obj['name']
748
790
  exec_msg = ('exec-based authentication is used for '
749
791
  f'Kubernetes context {ctx_name!r}.'
750
792
  ' This may cause issues with autodown or when running '
@@ -760,6 +802,7 @@ def is_kubeconfig_exec_auth() -> Tuple[bool, Optional[str]]:
760
802
  return False, None
761
803
 
762
804
 
805
+ @functools.lru_cache()
763
806
  def get_current_kube_config_context_name() -> Optional[str]:
764
807
  """Get the current kubernetes context from the kubeconfig file
765
808
 
@@ -774,7 +817,27 @@ def get_current_kube_config_context_name() -> Optional[str]:
774
817
  return None
775
818
 
776
819
 
777
- def get_current_kube_config_context_namespace() -> str:
820
def get_all_kube_config_context_names() -> Optional[List[str]]:
    """Return the name of every context defined in the kubeconfig file.

    This is deliberately not cached: an admin policy may change the set of
    available contexts, so the kubeconfig is consulted on every call.

    Returns:
        List[str] | None: The context names, or None if listing the contexts
        raises ConfigException (e.g. when no kubeconfig exists).
    """
    k8s_client = kubernetes.kubernetes
    try:
        contexts, _ = k8s_client.config.list_kube_config_contexts()
        names: List[str] = []
        for ctx in contexts:
            names.append(ctx['name'])
    except k8s_client.config.config_exception.ConfigException:
        return None
    return names
836
+
837
+
838
+ @functools.lru_cache()
839
+ def get_kube_config_context_namespace(
840
+ context_name: Optional[str] = None) -> str:
778
841
  """Get the current kubernetes context namespace from the kubeconfig file
779
842
 
780
843
  Returns:
@@ -789,9 +852,17 @@ def get_current_kube_config_context_namespace() -> str:
789
852
  return f.read().strip()
790
853
  # If not in-cluster, get the namespace from kubeconfig
791
854
  try:
792
- _, current_context = k8s.config.list_kube_config_contexts()
793
- if 'namespace' in current_context['context']:
794
- return current_context['context']['namespace']
855
+ contexts, current_context = k8s.config.list_kube_config_contexts()
856
+ if context_name is None:
857
+ context = current_context
858
+ else:
859
+ context = next((c for c in contexts if c['name'] == context_name),
860
+ None)
861
+ if context is None:
862
+ return DEFAULT_NAMESPACE
863
+
864
+ if 'namespace' in context['context']:
865
+ return context['context']['namespace']
795
866
  else:
796
867
  return DEFAULT_NAMESPACE
797
868
  except k8s.config.config_exception.ConfigException:
@@ -972,11 +1043,12 @@ def construct_ssh_jump_command(
972
1043
 
973
1044
 
974
1045
  def get_ssh_proxy_command(
975
- k8s_ssh_target: str,
976
- network_mode: kubernetes_enums.KubernetesNetworkingMode,
977
- private_key_path: Optional[str] = None,
978
- namespace: Optional[str] = None,
979
- context: Optional[str] = None) -> str:
1046
+ k8s_ssh_target: str,
1047
+ network_mode: kubernetes_enums.KubernetesNetworkingMode,
1048
+ private_key_path: str,
1049
+ context: str,
1050
+ namespace: str,
1051
+ ) -> str:
980
1052
  """Generates the SSH proxy command to connect to the pod.
981
1053
 
982
1054
  Uses a jump pod if the network mode is NODEPORT, and direct port-forwarding
@@ -1033,8 +1105,6 @@ def get_ssh_proxy_command(
1033
1105
  private_key_path, ssh_jump_ip, ssh_jump_port=ssh_jump_port)
1034
1106
  else:
1035
1107
  ssh_jump_proxy_command_path = create_proxy_command_script()
1036
- current_context = get_current_kube_config_context_name()
1037
- current_namespace = get_current_kube_config_context_namespace()
1038
1108
  ssh_jump_proxy_command = construct_ssh_jump_command(
1039
1109
  private_key_path,
1040
1110
  ssh_jump_ip,
@@ -1044,8 +1114,8 @@ def get_ssh_proxy_command(
1044
1114
  # We embed both the current context and namespace to the SSH proxy
1045
1115
  # command to make sure SSH still works when the current
1046
1116
  # context/namespace is changed by the user.
1047
- current_kube_context=current_context,
1048
- current_kube_namespace=current_namespace)
1117
+ current_kube_context=context,
1118
+ current_kube_namespace=namespace)
1049
1119
  return ssh_jump_proxy_command
1050
1120
 
1051
1121
 
@@ -1632,7 +1702,8 @@ SPOT_LABEL_MAP = {
1632
1702
  }
1633
1703
 
1634
1704
 
1635
- def get_spot_label() -> Tuple[Optional[str], Optional[str]]:
1705
+ def get_spot_label(
1706
+ context: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
1636
1707
  """Get the spot label key and value for using spot instances, if supported.
1637
1708
 
1638
1709
  Checks if the underlying cluster supports spot instances by checking nodes
@@ -1646,7 +1717,7 @@ def get_spot_label() -> Tuple[Optional[str], Optional[str]]:
1646
1717
  """
1647
1718
  # Check if the cluster supports spot instances by checking nodes for known
1648
1719
  # spot label keys and values
1649
- for node in get_kubernetes_nodes():
1720
+ for node in get_kubernetes_nodes(context):
1650
1721
  for _, (key, value) in SPOT_LABEL_MAP.items():
1651
1722
  if key in node.metadata.labels and node.metadata.labels[
1652
1723
  key] == value:
@@ -1691,7 +1762,8 @@ class KubernetesNodeInfo:
1691
1762
  free: Dict[str, int]
1692
1763
 
1693
1764
 
1694
- def get_kubernetes_node_info() -> Dict[str, KubernetesNodeInfo]:
1765
+ def get_kubernetes_node_info(
1766
+ context: Optional[str] = None) -> Dict[str, KubernetesNodeInfo]:
1695
1767
  """Gets the resource information for all the nodes in the cluster.
1696
1768
 
1697
1769
  Currently only GPU resources are supported. The function returns the total
@@ -1702,11 +1774,11 @@ def get_kubernetes_node_info() -> Dict[str, KubernetesNodeInfo]:
1702
1774
  Dict[str, KubernetesNodeInfo]: Dictionary containing the node name as
1703
1775
  key and the KubernetesNodeInfo object as value
1704
1776
  """
1705
- nodes = get_kubernetes_nodes()
1777
+ nodes = get_kubernetes_nodes(context)
1706
1778
  # Get the pods to get the real-time resource usage
1707
- pods = get_kubernetes_pods()
1779
+ pods = get_all_pods_in_kubernetes_cluster(context)
1708
1780
 
1709
- label_formatter, _ = detect_gpu_label_formatter()
1781
+ label_formatter, _ = detect_gpu_label_formatter(context)
1710
1782
  if not label_formatter:
1711
1783
  label_key = None
1712
1784
  else:
@@ -1748,9 +1820,146 @@ def get_kubernetes_node_info() -> Dict[str, KubernetesNodeInfo]:
1748
1820
  return node_info_dict
1749
1821
 
1750
1822
 
1823
def to_label_selector(tags: Dict[str, Any]) -> str:
    """Build an equality-based Kubernetes label selector from a tag mapping.

    Each ``key: value`` pair becomes ``key=value``, and the terms are joined
    with commas, e.g. ``{'a': 'x', 'b': 'y'}`` -> ``'a=x,b=y'``. An empty
    mapping yields the empty string.

    Args:
        tags: Label keys mapped to the values they must equal. Values are
            formatted as text.

    Returns:
        The comma-separated selector string.
    """
    return ','.join(f'{key}={value}' for key, value in tags.items())
1830
+
1831
+
1751
1832
def get_namespace_from_config(provider_config: Dict[str, Any]) -> str:
    """Return the Kubernetes namespace for a provider config.

    An explicit ``namespace`` entry in ``provider_config`` wins. Otherwise
    the namespace of the kubeconfig context named by the provider config is
    used, as returned by ``get_kube_config_context_namespace``.

    Args:
        provider_config: The ``provider`` section of a cluster config.

    Returns:
        The namespace to use.
    """
    if 'namespace' in provider_config:
        return provider_config['namespace']
    # Resolve the context / kubeconfig namespace only when needed; passing
    # them as the dict.get() default would evaluate them on every call.
    context = get_context_from_config(provider_config)
    return get_kube_config_context_namespace(context)
1836
+
1837
+
1838
def filter_pods(namespace: str,
                context: str,
                tag_filters: Dict[str, str],
                status_filters: Optional[List[str]] = None) -> Dict[str, Any]:
    """Return the pods in a namespace that match the given labels and phases.

    Args:
        namespace: Namespace whose pods are listed.
        context: Kubernetes context used to obtain the core API client.
        tag_filters: Label key/value pairs, turned into a label selector.
        status_filters: Pod phases to keep. Every other entry of
            POD_STATUSES is excluded with a ``status.phase!=<phase>`` field
            selector. When None, an empty field selector is sent.

    Returns:
        Dict mapping pod name to pod object. Pods with a non-null
        ``metadata.deletion_timestamp`` (marked for deletion) are omitted.
    """
    if status_filters is None:
        phase_selector = ''
    else:
        excluded_phases = POD_STATUSES.copy()
        excluded_phases -= set(status_filters)
        phase_selector = ','.join(
            f'status.phase!={phase}' for phase in excluded_phases)

    selector_from_tags = to_label_selector(tag_filters)
    response = kubernetes.core_api(context).list_namespaced_pod(
        namespace,
        field_selector=phase_selector,
        label_selector=selector_from_tags)

    matching: Dict[str, Any] = {}
    for pod in response.items:
        if pod.metadata.deletion_timestamp is not None:
            # Already marked for deletion; treat it as gone.
            continue
        matching[pod.metadata.name] = pod
    return matching
1861
+
1862
+
1863
def _remove_pod_annotation(pod: Any,
                           annotation_key: str,
                           namespace: str,
                           context: Optional[str] = None) -> None:
    """Remove a single annotation from a Kubernetes pod, if it is present.

    No API call is made when the pod has no annotations or lacks
    ``annotation_key``. A 404 from the API is logged as a warning. Any other
    API error is re-raised inside ``ux_utils.print_exception_no_traceback()``.

    Args:
        pod: Pod object; its ``metadata.name`` and ``metadata.annotations``
            are read.
        annotation_key: Key of the annotation to remove.
        namespace: Namespace the pod lives in.
        context: Forwarded to ``kubernetes.core_api``; may be None.
    """
    try:
        existing = pod.metadata.annotations
        if not existing or annotation_key not in existing:
            return
        # Mapping the key to None in the patch body is how this removes it.
        patch_body = {'metadata': {'annotations': {annotation_key: None}}}
        kubernetes.core_api(context).patch_namespaced_pod(
            name=pod.metadata.name,
            namespace=namespace,
            body=patch_body,
            _request_timeout=kubernetes.API_TIMEOUT)
    except kubernetes.api_exception() as err:
        if err.status != 404:
            with ux_utils.print_exception_no_traceback():
                raise
        else:
            logger.warning(
                ANNOTATIONS_POD_NOT_FOUND_ERROR_MSG.format(
                    pod_name=pod.metadata.name,
                    namespace=namespace,
                    action='remove',
                    annotation=annotation_key))
1891
+
1892
+
1893
def _add_pod_annotation(pod: Any,
                        annotation: Dict[str, str],
                        namespace: str,
                        context: Optional[str] = None) -> None:
    """Add the key/value pairs in ``annotation`` to a Kubernetes pod.

    A 404 from the API is logged as a warning. Any other API error is
    re-raised inside ``ux_utils.print_exception_no_traceback()``.

    Args:
        pod: Pod object; its ``metadata.name`` is read.
        annotation: Annotation keys and values to set on the pod.
        namespace: Namespace the pod lives in.
        context: Forwarded to ``kubernetes.core_api``; may be None.
    """
    patch_body = {'metadata': {'annotations': annotation}}
    try:
        core_api = kubernetes.core_api(context)
        core_api.patch_namespaced_pod(name=pod.metadata.name,
                                      namespace=namespace,
                                      body=patch_body,
                                      _request_timeout=kubernetes.API_TIMEOUT)
    except kubernetes.api_exception() as err:
        if err.status != 404:
            with ux_utils.print_exception_no_traceback():
                raise
        else:
            logger.warning(
                ANNOTATIONS_POD_NOT_FOUND_ERROR_MSG.format(
                    pod_name=pod.metadata.name,
                    namespace=namespace,
                    action='add',
                    annotation=annotation))
1918
+
1919
+
1920
def set_autodown_annotations(handle: 'backends.CloudVmRayResourceHandle',
                             idle_minutes_to_autostop: Optional[int],
                             down: bool = False) -> None:
    """Add or remove the autodown annotations on a cluster's Kubernetes pods.

    Pods are selected by the cluster-name tag of ``handle``, in the namespace
    and context recorded in the ``provider`` section of the cluster YAML.

    Args:
        handle: Resource handle of the cluster; ``cluster_name_on_cloud`` and
            ``cluster_yaml`` are read.
        idle_minutes_to_autostop: Idle minutes before autostop. When ``down``
            is True, it is stored (stringified) in an annotation. When
            ``down`` is False and the value is negative (the
            ``sky autostop --cancel`` path), both annotations are removed.
        down: If True, annotate the pods for autodown.
    """
    tags = {
        provision_constants.TAG_RAY_CLUSTER_NAME: handle.cluster_name_on_cloud,
    }
    ray_config = common_utils.read_yaml(handle.cluster_yaml)
    provider_config = ray_config['provider']
    namespace = get_namespace_from_config(provider_config)
    context = get_context_from_config(provider_config)
    # No phase filter: every cluster pod that is not being deleted.
    cluster_pods = filter_pods(namespace, context, tags)

    # A negative value is a request to cancel autostop.
    cancel_autostop = (idle_minutes_to_autostop is not None and
                       idle_minutes_to_autostop < 0)

    for pod in cluster_pods.values():
        if down:
            # Send both annotations in a single patch request rather than
            # two separate ones, so a failure cannot leave only one of them
            # applied, and each pod needs one API call instead of two.
            autodown_annotations = {
                IDLE_MINUTES_TO_AUTOSTOP_ANNOTATION_KEY:
                    str(idle_minutes_to_autostop),
                AUTODOWN_ANNOTATION_KEY: 'true',
            }
            _add_pod_annotation(pod=pod,
                                annotation=autodown_annotations,
                                namespace=namespace,
                                context=context)
        elif cancel_autostop:
            for key in (IDLE_MINUTES_TO_AUTOSTOP_ANNOTATION_KEY,
                        AUTODOWN_ANNOTATION_KEY):
                _remove_pod_annotation(pod=pod,
                                       annotation_key=key,
                                       namespace=namespace,
                                       context=context)
1754
1963
 
1755
1964
 
1756
1965
  def get_context_from_config(provider_config: Dict[str, Any]) -> str:
@@ -19,6 +19,12 @@ INSTANCE_TO_TEMPLATEID = {
19
19
  'V100-32Gx2': 'twnlo3zj',
20
20
  'V100-32G': 'twnlo3zj',
21
21
  'V100': 'twnlo3zj',
22
+ 'GPU+': 'twnlo3zj',
23
+ 'P4000': 'twnlo3zj',
24
+ 'P4000x2': 'twnlo3zj',
25
+ 'A4000': 'twnlo3zj',
26
+ 'A4000x2': 'twnlo3zj',
27
+ 'A4000x4': 'twnlo3zj',
22
28
  **CPU_INSTANCES_TEMPLATEID
23
29
  }
24
30
  NVLINK_INSTANCES = {
@@ -259,6 +259,8 @@ def _ssh_probe_command(ip: str,
259
259
  '-o',
260
260
  'IdentitiesOnly=yes',
261
261
  '-o',
262
+ 'AddKeysToAgent=yes',
263
+ '-o',
262
264
  'ExitOnForwardFailure=yes',
263
265
  '-o',
264
266
  'ServerAliveInterval=5',
@@ -18,7 +18,7 @@ provider:
18
18
 
19
19
  region: kubernetes
20
20
 
21
- # The namespace to create the Ray cluster in.
21
+
22
22
  namespace: {{k8s_namespace}}
23
23
 
24
24
  # The kubecontext used to connect to the Kubernetes cluster.
@@ -85,6 +85,10 @@ def ssh_options_list(
85
85
  'LogLevel': 'ERROR',
86
86
  # Try fewer extraneous key pairs.
87
87
  'IdentitiesOnly': 'yes',
88
+ # Add the current private key used for this SSH connection to the
89
+ # SSH agent, so that forward agent parameter will then make SSH
90
+ # agent forward it.
91
+ 'AddKeysToAgent': 'yes',
88
92
  # Abort if port forwarding fails (instead of just printing to
89
93
  # stderr).
90
94
  'ExitOnForwardFailure': 'yes',
sky/utils/schemas.py CHANGED
@@ -775,6 +775,12 @@ def get_config_schema():
775
775
  'required': [],
776
776
  'additionalProperties': False,
777
777
  'properties': {
778
+ 'allowed_contexts': {
779
+ 'type': 'array',
780
+ 'items': {
781
+ 'type': 'string',
782
+ },
783
+ },
778
784
  'networking': {
779
785
  'type': 'string',
780
786
  'case_insensitive_enum': [