skypilot-nightly 1.0.0.dev20250215__py3-none-any.whl → 1.0.0.dev20250217__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172) hide show
  1. sky/__init__.py +48 -22
  2. sky/adaptors/aws.py +2 -1
  3. sky/adaptors/azure.py +4 -4
  4. sky/adaptors/cloudflare.py +4 -4
  5. sky/adaptors/kubernetes.py +8 -8
  6. sky/authentication.py +42 -45
  7. sky/backends/backend.py +2 -2
  8. sky/backends/backend_utils.py +108 -221
  9. sky/backends/cloud_vm_ray_backend.py +283 -282
  10. sky/benchmark/benchmark_utils.py +6 -2
  11. sky/check.py +40 -28
  12. sky/cli.py +1213 -1116
  13. sky/client/__init__.py +1 -0
  14. sky/client/cli.py +5644 -0
  15. sky/client/common.py +345 -0
  16. sky/client/sdk.py +1757 -0
  17. sky/cloud_stores.py +12 -6
  18. sky/clouds/__init__.py +0 -2
  19. sky/clouds/aws.py +20 -13
  20. sky/clouds/azure.py +5 -3
  21. sky/clouds/cloud.py +1 -1
  22. sky/clouds/cudo.py +2 -1
  23. sky/clouds/do.py +2 -1
  24. sky/clouds/fluidstack.py +3 -2
  25. sky/clouds/gcp.py +10 -8
  26. sky/clouds/ibm.py +8 -7
  27. sky/clouds/kubernetes.py +7 -6
  28. sky/clouds/lambda_cloud.py +8 -7
  29. sky/clouds/oci.py +4 -3
  30. sky/clouds/paperspace.py +2 -1
  31. sky/clouds/runpod.py +2 -1
  32. sky/clouds/scp.py +8 -7
  33. sky/clouds/service_catalog/__init__.py +3 -3
  34. sky/clouds/service_catalog/aws_catalog.py +7 -1
  35. sky/clouds/service_catalog/common.py +4 -2
  36. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +2 -2
  37. sky/clouds/utils/oci_utils.py +1 -1
  38. sky/clouds/vast.py +2 -1
  39. sky/clouds/vsphere.py +2 -1
  40. sky/core.py +263 -99
  41. sky/dag.py +4 -0
  42. sky/data/mounting_utils.py +2 -1
  43. sky/data/storage.py +97 -35
  44. sky/data/storage_utils.py +69 -9
  45. sky/exceptions.py +138 -5
  46. sky/execution.py +47 -50
  47. sky/global_user_state.py +105 -22
  48. sky/jobs/__init__.py +12 -14
  49. sky/jobs/client/__init__.py +0 -0
  50. sky/jobs/client/sdk.py +296 -0
  51. sky/jobs/constants.py +30 -1
  52. sky/jobs/controller.py +12 -6
  53. sky/jobs/dashboard/dashboard.py +2 -6
  54. sky/jobs/recovery_strategy.py +22 -29
  55. sky/jobs/server/__init__.py +1 -0
  56. sky/jobs/{core.py → server/core.py} +101 -34
  57. sky/jobs/server/dashboard_utils.py +64 -0
  58. sky/jobs/server/server.py +182 -0
  59. sky/jobs/utils.py +32 -23
  60. sky/models.py +27 -0
  61. sky/optimizer.py +9 -11
  62. sky/provision/__init__.py +6 -3
  63. sky/provision/aws/config.py +2 -2
  64. sky/provision/aws/instance.py +1 -1
  65. sky/provision/azure/instance.py +1 -1
  66. sky/provision/cudo/instance.py +1 -1
  67. sky/provision/do/instance.py +1 -1
  68. sky/provision/do/utils.py +0 -5
  69. sky/provision/fluidstack/fluidstack_utils.py +4 -3
  70. sky/provision/fluidstack/instance.py +4 -2
  71. sky/provision/gcp/instance.py +1 -1
  72. sky/provision/instance_setup.py +2 -2
  73. sky/provision/kubernetes/constants.py +8 -0
  74. sky/provision/kubernetes/instance.py +1 -1
  75. sky/provision/kubernetes/utils.py +67 -76
  76. sky/provision/lambda_cloud/instance.py +3 -15
  77. sky/provision/logging.py +1 -1
  78. sky/provision/oci/instance.py +7 -4
  79. sky/provision/paperspace/instance.py +1 -1
  80. sky/provision/provisioner.py +3 -2
  81. sky/provision/runpod/instance.py +1 -1
  82. sky/provision/vast/instance.py +1 -1
  83. sky/provision/vast/utils.py +2 -1
  84. sky/provision/vsphere/instance.py +2 -11
  85. sky/resources.py +55 -40
  86. sky/serve/__init__.py +6 -10
  87. sky/serve/client/__init__.py +0 -0
  88. sky/serve/client/sdk.py +366 -0
  89. sky/serve/constants.py +3 -0
  90. sky/serve/replica_managers.py +10 -10
  91. sky/serve/serve_utils.py +56 -36
  92. sky/serve/server/__init__.py +0 -0
  93. sky/serve/{core.py → server/core.py} +37 -17
  94. sky/serve/server/server.py +117 -0
  95. sky/serve/service.py +8 -1
  96. sky/server/__init__.py +1 -0
  97. sky/server/common.py +441 -0
  98. sky/server/constants.py +21 -0
  99. sky/server/html/log.html +174 -0
  100. sky/server/requests/__init__.py +0 -0
  101. sky/server/requests/executor.py +462 -0
  102. sky/server/requests/payloads.py +481 -0
  103. sky/server/requests/queues/__init__.py +0 -0
  104. sky/server/requests/queues/mp_queue.py +76 -0
  105. sky/server/requests/requests.py +567 -0
  106. sky/server/requests/serializers/__init__.py +0 -0
  107. sky/server/requests/serializers/decoders.py +192 -0
  108. sky/server/requests/serializers/encoders.py +166 -0
  109. sky/server/server.py +1095 -0
  110. sky/server/stream_utils.py +144 -0
  111. sky/setup_files/MANIFEST.in +1 -0
  112. sky/setup_files/dependencies.py +12 -4
  113. sky/setup_files/setup.py +1 -1
  114. sky/sky_logging.py +9 -13
  115. sky/skylet/autostop_lib.py +2 -2
  116. sky/skylet/constants.py +46 -12
  117. sky/skylet/events.py +5 -6
  118. sky/skylet/job_lib.py +78 -66
  119. sky/skylet/log_lib.py +17 -11
  120. sky/skypilot_config.py +79 -94
  121. sky/task.py +119 -73
  122. sky/templates/aws-ray.yml.j2 +4 -4
  123. sky/templates/azure-ray.yml.j2 +3 -2
  124. sky/templates/cudo-ray.yml.j2 +3 -2
  125. sky/templates/fluidstack-ray.yml.j2 +3 -2
  126. sky/templates/gcp-ray.yml.j2 +3 -2
  127. sky/templates/ibm-ray.yml.j2 +3 -2
  128. sky/templates/jobs-controller.yaml.j2 +1 -12
  129. sky/templates/kubernetes-ray.yml.j2 +3 -2
  130. sky/templates/lambda-ray.yml.j2 +3 -2
  131. sky/templates/oci-ray.yml.j2 +3 -2
  132. sky/templates/paperspace-ray.yml.j2 +3 -2
  133. sky/templates/runpod-ray.yml.j2 +3 -2
  134. sky/templates/scp-ray.yml.j2 +3 -2
  135. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  136. sky/templates/vsphere-ray.yml.j2 +4 -2
  137. sky/templates/websocket_proxy.py +64 -0
  138. sky/usage/constants.py +8 -0
  139. sky/usage/usage_lib.py +45 -11
  140. sky/utils/accelerator_registry.py +33 -53
  141. sky/utils/admin_policy_utils.py +2 -1
  142. sky/utils/annotations.py +51 -0
  143. sky/utils/cli_utils/status_utils.py +33 -3
  144. sky/utils/cluster_utils.py +356 -0
  145. sky/utils/command_runner.py +69 -14
  146. sky/utils/common.py +74 -0
  147. sky/utils/common_utils.py +133 -93
  148. sky/utils/config_utils.py +204 -0
  149. sky/utils/control_master_utils.py +2 -3
  150. sky/utils/controller_utils.py +133 -147
  151. sky/utils/dag_utils.py +72 -24
  152. sky/utils/kubernetes/deploy_remote_cluster.sh +2 -2
  153. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  154. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  155. sky/utils/log_utils.py +83 -23
  156. sky/utils/message_utils.py +81 -0
  157. sky/utils/registry.py +127 -0
  158. sky/utils/resources_utils.py +2 -2
  159. sky/utils/rich_utils.py +213 -34
  160. sky/utils/schemas.py +19 -2
  161. sky/{status_lib.py → utils/status_lib.py} +12 -7
  162. sky/utils/subprocess_utils.py +51 -35
  163. sky/utils/timeline.py +7 -2
  164. sky/utils/ux_utils.py +95 -25
  165. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/METADATA +8 -3
  166. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/RECORD +170 -132
  167. sky/clouds/cloud_registry.py +0 -76
  168. sky/utils/cluster_yaml_utils.py +0 -24
  169. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/LICENSE +0 -0
  170. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/WHEEL +0 -0
  171. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/entry_points.txt +0 -0
  172. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/top_level.txt +0 -0
sky/__init__.py CHANGED
@@ -5,7 +5,7 @@ from typing import Optional
5
5
  import urllib.request
6
6
 
7
7
  # Replaced with the current commit when building the wheels.
8
- _SKYPILOT_COMMIT_SHA = '354bbdf3a1d031b350011bc76570cf8c009ecc4a'
8
+ _SKYPILOT_COMMIT_SHA = '7775d44c4c91d1474982fbdb3b1e03cfbbf2385e'
9
9
 
10
10
 
11
11
  def _get_git_commit():
@@ -35,7 +35,7 @@ def _get_git_commit():
35
35
 
36
36
 
37
37
  __commit__ = _get_git_commit()
38
- __version__ = '1.0.0.dev20250215'
38
+ __version__ = '1.0.0.dev20250217'
39
39
  __root_dir__ = os.path.dirname(os.path.abspath(__file__))
40
40
 
41
41
 
@@ -75,6 +75,7 @@ def _set_http_proxy_env_vars() -> None:
75
75
 
76
76
 
77
77
  _set_http_proxy_env_vars()
78
+
78
79
  # ----------------------------------------------------------------- #
79
80
 
80
81
  # Keep this order to avoid cyclic imports
@@ -85,34 +86,46 @@ from sky import clouds
85
86
  from sky.admin_policy import AdminPolicy
86
87
  from sky.admin_policy import MutatedUserRequest
87
88
  from sky.admin_policy import UserRequest
89
+ from sky.client.sdk import api_cancel
90
+ from sky.client.sdk import api_info
91
+ from sky.client.sdk import api_server_logs
92
+ from sky.client.sdk import api_start
93
+ from sky.client.sdk import api_status
94
+ from sky.client.sdk import api_stop
95
+ from sky.client.sdk import autostop
96
+ from sky.client.sdk import cancel
97
+ from sky.client.sdk import cost_report
98
+ from sky.client.sdk import down
99
+ from sky.client.sdk import download_logs
100
+ from sky.client.sdk import exec # pylint: disable=redefined-builtin
101
+ from sky.client.sdk import get
102
+ from sky.client.sdk import job_status
103
+ from sky.client.sdk import launch
104
+ from sky.client.sdk import optimize
105
+ from sky.client.sdk import queue
106
+ from sky.client.sdk import start
107
+ from sky.client.sdk import status
108
+ from sky.client.sdk import stop
109
+ from sky.client.sdk import storage_delete
110
+ from sky.client.sdk import storage_ls
111
+ from sky.client.sdk import stream_and_get
112
+ from sky.client.sdk import tail_logs
88
113
  from sky.clouds.service_catalog import list_accelerators
89
- from sky.core import autostop
90
- from sky.core import cancel
91
- from sky.core import cost_report
92
- from sky.core import down
93
- from sky.core import download_logs
94
- from sky.core import job_status
95
- from sky.core import queue
96
- from sky.core import start
97
- from sky.core import status
98
- from sky.core import stop
99
- from sky.core import storage_delete
100
- from sky.core import storage_ls
101
- from sky.core import tail_logs
102
114
  from sky.dag import Dag
103
115
  from sky.data import Storage
104
116
  from sky.data import StorageMode
105
117
  from sky.data import StoreType
106
- from sky.execution import exec # pylint: disable=redefined-builtin
107
- from sky.execution import launch
108
118
  from sky.jobs import ManagedJobStatus
109
119
  from sky.optimizer import Optimizer
110
- from sky.optimizer import OptimizeTarget
111
120
  from sky.resources import Resources
112
121
  from sky.skylet.job_lib import JobStatus
113
- from sky.skypilot_config import Config
114
- from sky.status_lib import ClusterStatus
115
122
  from sky.task import Task
123
+ from sky.utils.common import OptimizeTarget
124
+ from sky.utils.common import StatusRefreshMode
125
+ from sky.utils.config_utils import Config
126
+ from sky.utils.registry import CLOUD_REGISTRY
127
+ from sky.utils.registry import JOBS_RECOVERY_STRATEGY_REGISTRY
128
+ from sky.utils.status_lib import ClusterStatus
116
129
 
117
130
  # Aliases.
118
131
  IBM = clouds.IBM
@@ -130,7 +143,6 @@ RunPod = clouds.RunPod
130
143
  Vast = clouds.Vast
131
144
  Vsphere = clouds.Vsphere
132
145
  Fluidstack = clouds.Fluidstack
133
- optimize = Optimizer.optimize
134
146
 
135
147
  __all__ = [
136
148
  '__version__',
@@ -161,11 +173,13 @@ __all__ = [
161
173
  'ClusterStatus',
162
174
  'JobStatus',
163
175
  'ManagedJobStatus',
176
+ 'StatusRefreshMode',
164
177
  # APIs
165
178
  'Dag',
166
179
  'Task',
167
180
  'Resources',
168
- # execution APIs
181
+ # core APIs
182
+ 'optimize',
169
183
  'launch',
170
184
  'exec',
171
185
  # core APIs
@@ -184,9 +198,21 @@ __all__ = [
184
198
  # core APIs Storage Management
185
199
  'storage_ls',
186
200
  'storage_delete',
201
+ # API server APIs
202
+ 'get',
203
+ 'stream_and_get',
204
+ 'api_status',
205
+ 'api_cancel',
206
+ 'api_info',
207
+ 'api_start',
208
+ 'api_stop',
209
+ 'api_server_logs',
187
210
  # Admin Policy
188
211
  'UserRequest',
189
212
  'MutatedUserRequest',
190
213
  'AdminPolicy',
191
214
  'Config',
215
+ # Registry
216
+ 'CLOUD_REGISTRY',
217
+ 'JOBS_RECOVERY_STRATEGY_REGISTRY',
192
218
  ]
sky/adaptors/aws.py CHANGED
@@ -35,6 +35,7 @@ import time
35
35
  from typing import Any, Callable
36
36
 
37
37
  from sky.adaptors import common
38
+ from sky.utils import annotations
38
39
  from sky.utils import common_utils
39
40
 
40
41
  _IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for AWS. '
@@ -59,7 +60,7 @@ class _ThreadLocalLRUCache(threading.local):
59
60
 
60
61
  def __init__(self, maxsize=32):
61
62
  super().__init__()
62
- self.cache = functools.lru_cache(maxsize=maxsize)
63
+ self.cache = annotations.lru_cache(scope='global', maxsize=maxsize)
63
64
 
64
65
 
65
66
  def _thread_local_lru_cache(maxsize=32):
sky/adaptors/azure.py CHANGED
@@ -3,7 +3,6 @@
3
3
  # pylint: disable=import-outside-toplevel
4
4
  import asyncio
5
5
  import datetime
6
- import functools
7
6
  import logging
8
7
  import threading
9
8
  import time
@@ -14,6 +13,7 @@ from sky import exceptions as sky_exceptions
14
13
  from sky import sky_logging
15
14
  from sky.adaptors import common
16
15
  from sky.skylet import constants
16
+ from sky.utils import annotations
17
17
  from sky.utils import common_utils
18
18
  from sky.utils import ux_utils
19
19
 
@@ -33,7 +33,7 @@ _MAX_RETRY_FOR_GET_SUBSCRIPTION_ID = 5
33
33
 
34
34
 
35
35
  @common.load_lazy_modules(modules=_LAZY_MODULES)
36
- @functools.lru_cache()
36
+ @annotations.lru_cache(scope='global', maxsize=1)
37
37
  def get_subscription_id() -> str:
38
38
  """Get the default subscription id."""
39
39
  from azure.common import credentials
@@ -69,7 +69,7 @@ def exceptions():
69
69
  return azure_exceptions
70
70
 
71
71
 
72
- @functools.lru_cache()
72
+ @annotations.lru_cache(scope='global')
73
73
  @common.load_lazy_modules(modules=_LAZY_MODULES)
74
74
  def azure_mgmt_models(name: str):
75
75
  if name == 'compute':
@@ -83,7 +83,7 @@ def azure_mgmt_models(name: str):
83
83
  # We should keep the order of the decorators having 'lru_cache' followed
84
84
  # by 'load_lazy_modules' as we need to make sure a caller can call
85
85
  # 'get_client.cache_clear', which is a function provided by 'lru_cache'
86
- @functools.lru_cache()
86
+ @annotations.lru_cache(scope='global')
87
87
  @common.load_lazy_modules(modules=_LAZY_MODULES)
88
88
  def get_client(name: str,
89
89
  subscription_id: Optional[str] = None,
@@ -2,12 +2,12 @@
2
2
  # pylint: disable=import-outside-toplevel
3
3
 
4
4
  import contextlib
5
- import functools
6
5
  import os
7
6
  import threading
8
7
  from typing import Dict, Optional, Tuple
9
8
 
10
9
  from sky.adaptors import common
10
+ from sky.utils import annotations
11
11
  from sky.utils import ux_utils
12
12
 
13
13
  _IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for Cloudflare.'
@@ -62,7 +62,7 @@ def get_r2_credentials(boto3_session):
62
62
  # lru_cache() is thread-safe and it will return the same session object
63
63
  # for different threads.
64
64
  # Reference: https://docs.python.org/3/library/functools.html#functools.lru_cache # pylint: disable=line-too-long
65
- @functools.lru_cache()
65
+ @annotations.lru_cache(scope='global')
66
66
  def session():
67
67
  """Create an AWS session."""
68
68
  # Creating the session object is not thread-safe for boto3,
@@ -76,7 +76,7 @@ def session():
76
76
  return session_
77
77
 
78
78
 
79
- @functools.lru_cache()
79
+ @annotations.lru_cache(scope='global')
80
80
  def resource(resource_name: str, **kwargs):
81
81
  """Create a Cloudflare resource.
82
82
 
@@ -102,7 +102,7 @@ def resource(resource_name: str, **kwargs):
102
102
  **kwargs)
103
103
 
104
104
 
105
- @functools.lru_cache()
105
+ @annotations.lru_cache(scope='global')
106
106
  def client(service_name: str, region):
107
107
  """Create a CLOUDFLARE client of a certain service.
108
108
 
@@ -1,11 +1,11 @@
1
1
  """Kubernetes adaptors"""
2
- import functools
3
2
  import logging
4
3
  import os
5
4
  from typing import Any, Callable, Optional, Set
6
5
 
7
6
  from sky.adaptors import common
8
7
  from sky.sky_logging import set_logging_level
8
+ from sky.utils import annotations
9
9
  from sky.utils import env_options
10
10
  from sky.utils import ux_utils
11
11
 
@@ -106,49 +106,49 @@ def _load_config(context: Optional[str] = None):
106
106
 
107
107
 
108
108
  @_api_logging_decorator('urllib3', logging.ERROR)
109
- @functools.lru_cache()
109
+ @annotations.lru_cache(scope='request')
110
110
  def core_api(context: Optional[str] = None):
111
111
  _load_config(context)
112
112
  return kubernetes.client.CoreV1Api()
113
113
 
114
114
 
115
115
  @_api_logging_decorator('urllib3', logging.ERROR)
116
- @functools.lru_cache()
116
+ @annotations.lru_cache(scope='request')
117
117
  def auth_api(context: Optional[str] = None):
118
118
  _load_config(context)
119
119
  return kubernetes.client.RbacAuthorizationV1Api()
120
120
 
121
121
 
122
122
  @_api_logging_decorator('urllib3', logging.ERROR)
123
- @functools.lru_cache()
123
+ @annotations.lru_cache(scope='request')
124
124
  def networking_api(context: Optional[str] = None):
125
125
  _load_config(context)
126
126
  return kubernetes.client.NetworkingV1Api()
127
127
 
128
128
 
129
129
  @_api_logging_decorator('urllib3', logging.ERROR)
130
- @functools.lru_cache()
130
+ @annotations.lru_cache(scope='request')
131
131
  def custom_objects_api(context: Optional[str] = None):
132
132
  _load_config(context)
133
133
  return kubernetes.client.CustomObjectsApi()
134
134
 
135
135
 
136
136
  @_api_logging_decorator('urllib3', logging.ERROR)
137
- @functools.lru_cache()
137
+ @annotations.lru_cache(scope='global')
138
138
  def node_api(context: Optional[str] = None):
139
139
  _load_config(context)
140
140
  return kubernetes.client.NodeV1Api()
141
141
 
142
142
 
143
143
  @_api_logging_decorator('urllib3', logging.ERROR)
144
- @functools.lru_cache()
144
+ @annotations.lru_cache(scope='request')
145
145
  def apps_api(context: Optional[str] = None):
146
146
  _load_config(context)
147
147
  return kubernetes.client.AppsV1Api()
148
148
 
149
149
 
150
150
  @_api_logging_decorator('urllib3', logging.ERROR)
151
- @functools.lru_cache()
151
+ @annotations.lru_cache(scope='request')
152
152
  def api_client(context: Optional[str] = None):
153
153
  _load_config(context)
154
154
  return kubernetes.client.ApiClient()
sky/authentication.py CHANGED
@@ -12,12 +12,11 @@ in ray yaml config as input,
12
12
  2. Setup the `authorized_keys` on the remote VM with the public key content,
13
13
  by cloud-init or directly using cloud provider's API.
14
14
 
15
- The local machine's public key should not be uploaded to the
16
- `~/.ssh/sky-key.pub` on the remote VM, because it will cause private/public
17
- key pair mismatch when the user tries to launch new VM from that remote VM
18
- using SkyPilot, e.g., the node is used as a jobs controller. (Lambda cloud
19
- is an exception, due to the limitation of the cloud provider. See the
20
- comments in setup_lambda_authentication)
15
+ The local machine's public key should not be uploaded to the remote VM, because
16
+ it will cause private/public key pair mismatch when the user tries to launch new
17
+ VM from that remote VM using SkyPilot, e.g., the node is used as a jobs
18
+ controller. (Lambda cloud is an exception, due to the limitation of the cloud
19
+ provider. See the comments in setup_lambda_authentication)
21
20
  """
22
21
  import copy
23
22
  import functools
@@ -48,6 +47,7 @@ from sky.provision.fluidstack import fluidstack_utils
48
47
  from sky.provision.kubernetes import utils as kubernetes_utils
49
48
  from sky.provision.lambda_cloud import lambda_utils
50
49
  from sky.utils import common_utils
50
+ from sky.utils import config_utils
51
51
  from sky.utils import kubernetes_enums
52
52
  from sky.utils import subprocess_utils
53
53
  from sky.utils import ux_utils
@@ -61,9 +61,24 @@ logger = sky_logging.init_logger(__name__)
61
61
 
62
62
  MAX_TRIALS = 64
63
63
  # TODO(zhwu): Support user specified key pair.
64
- PRIVATE_SSH_KEY_PATH = '~/.ssh/sky-key'
65
- PUBLIC_SSH_KEY_PATH = '~/.ssh/sky-key.pub'
66
- _SSH_KEY_GENERATION_LOCK = '~/.sky/generated/ssh/.__internal-sky-key.lock'
64
+ # We intentionally do not store the ssh key pair in
65
+ # ~/.sky/api_server/clients, i.e. sky.server.common.API_SERVER_CLIENT_DIR,
66
+ # because the ssh key pair needs to persist across API server restarts, while
67
+ # the former dir is ephemeral.
68
+ _SSH_KEY_PATH_PREFIX = '~/.sky/clients/{user_hash}/ssh'
69
+
70
+
71
+ def get_ssh_key_and_lock_path() -> Tuple[str, str, str]:
72
+ user_hash = common_utils.get_user_hash()
73
+ user_ssh_key_prefix = _SSH_KEY_PATH_PREFIX.format(user_hash=user_hash)
74
+
75
+ os.makedirs(os.path.expanduser(user_ssh_key_prefix),
76
+ exist_ok=True,
77
+ mode=0o700)
78
+ private_key_path = os.path.join(user_ssh_key_prefix, 'sky-key')
79
+ public_key_path = os.path.join(user_ssh_key_prefix, 'sky-key.pub')
80
+ lock_path = os.path.join(user_ssh_key_prefix, '.__internal-sky-key.lock')
81
+ return private_key_path, public_key_path, lock_path
67
82
 
68
83
 
69
84
  def _generate_rsa_key_pair() -> Tuple[str, str]:
@@ -106,16 +121,17 @@ def _save_key_pair(private_key_path: str, public_key_path: str,
106
121
 
107
122
  def get_or_generate_keys() -> Tuple[str, str]:
108
123
  """Returns the absolute private and public key paths."""
109
- private_key_path = os.path.expanduser(PRIVATE_SSH_KEY_PATH)
110
- public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
124
+ private_key_path, public_key_path, lock_path = get_ssh_key_and_lock_path()
125
+ private_key_path = os.path.expanduser(private_key_path)
126
+ public_key_path = os.path.expanduser(public_key_path)
127
+ lock_path = os.path.expanduser(lock_path)
111
128
 
112
- key_file_lock = os.path.expanduser(_SSH_KEY_GENERATION_LOCK)
113
- lock_dir = os.path.dirname(key_file_lock)
129
+ lock_dir = os.path.dirname(lock_path)
114
130
  # We should have the folder ~/.sky/generated/ssh to have 0o700 permission,
115
131
  # as the ssh configs will be written to this folder as well in
116
132
  # backend_utils.SSHConfigHelper
117
133
  os.makedirs(lock_dir, exist_ok=True, mode=0o700)
118
- with filelock.FileLock(key_file_lock, timeout=10):
134
+ with filelock.FileLock(lock_path, timeout=10):
119
135
  if not os.path.exists(private_key_path):
120
136
  public_key, private_key = _generate_rsa_key_pair()
121
137
  _save_key_pair(private_key_path, public_key_path, private_key,
@@ -276,7 +292,7 @@ def setup_lambda_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
276
292
 
277
293
  # Ensure ssh key is registered with Lambda Cloud
278
294
  lambda_client = lambda_utils.LambdaCloudClient()
279
- public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
295
+ _, public_key_path = get_or_generate_keys()
280
296
  with open(public_key_path, 'r', encoding='utf-8') as f:
281
297
  public_key = f.read().strip()
282
298
  prefix = f'sky-key-{common_utils.get_user_hash()}'
@@ -284,26 +300,16 @@ def setup_lambda_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
284
300
  if not exists:
285
301
  lambda_client.register_ssh_key(name, public_key)
286
302
 
287
- # Need to use ~ relative path because Ray uses the same
288
- # path for finding the public key path on both local and head node.
289
- config['auth']['ssh_public_key'] = PUBLIC_SSH_KEY_PATH
290
-
291
- # TODO(zhwu): we need to avoid uploading the public ssh key to the
292
- # nodes, as that will cause problem when the node is used as jobs
293
- # controller, i.e., the public and private key on the node may
294
- # not match.
295
- file_mounts = config['file_mounts']
296
- file_mounts[PUBLIC_SSH_KEY_PATH] = PUBLIC_SSH_KEY_PATH
297
- config['file_mounts'] = file_mounts
298
-
303
+ config['auth']['remote_key_name'] = name
299
304
  return config
300
305
 
301
306
 
302
- def setup_ibm_authentication(config):
307
+ def setup_ibm_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
303
308
  """ registers keys if they do not exist in sky folder
304
309
  and updates config file.
305
310
  keys default location: '~/.sky/clients/{user_hash}/ssh/sky-key' (and 'sky-key.pub')
306
311
  """
312
+ private_key_path, _ = get_or_generate_keys()
307
313
 
308
314
  def _get_unique_key_name():
309
315
  suffix_len = 10
@@ -343,17 +349,11 @@ def setup_ibm_authentication(config):
343
349
  else:
344
350
  raise Exception('Failed to register a key') from e
345
351
 
346
- config['auth']['ssh_private_key'] = PRIVATE_SSH_KEY_PATH
352
+ config['auth']['ssh_private_key'] = private_key_path
347
353
 
348
354
  for node_type in config['available_node_types']:
349
355
  config['available_node_types'][node_type]['node_config'][
350
356
  'key_id'] = vpc_key_id
351
-
352
- # Add public key path to file mounts
353
- file_mounts = config['file_mounts']
354
- file_mounts[PUBLIC_SSH_KEY_PATH] = PUBLIC_SSH_KEY_PATH
355
- config['file_mounts'] = file_mounts
356
-
357
357
  return config
358
358
 
359
359
 
@@ -373,10 +373,9 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
373
373
  with ux_utils.print_exception_no_traceback():
374
374
  raise ValueError(str(e) + ' Please check: ~/.sky/config.yaml.') \
375
375
  from None
376
- get_or_generate_keys()
376
+ _, public_key_path = get_or_generate_keys()
377
377
 
378
378
  # Add the user's public key to the SkyPilot cluster.
379
- public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
380
379
  secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
381
380
  secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name
382
381
  context = config['provider'].get(
@@ -386,9 +385,7 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
386
385
  # with in-cluster configuration. We need to set the context to None
387
386
  # to use the mounted service account.
388
387
  context = None
389
- namespace = config['provider'].get(
390
- 'namespace',
391
- kubernetes_utils.get_kube_config_context_namespace(context))
388
+ namespace = kubernetes_utils.get_namespace_from_config(config['provider'])
392
389
  k8s = kubernetes.kubernetes
393
390
  with open(public_key_path, 'r', encoding='utf-8') as f:
394
391
  public_key = f.read()
@@ -404,7 +401,7 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
404
401
  }
405
402
  custom_metadata = skypilot_config.get_nested(
406
403
  ('kubernetes', 'custom_metadata'), {})
407
- kubernetes_utils.merge_dicts(custom_metadata, secret_metadata)
404
+ config_utils.merge_k8s_configs(secret_metadata, custom_metadata)
408
405
 
409
406
  secret = k8s.client.V1Secret(
410
407
  metadata=k8s.client.V1ObjectMeta(**secret_metadata),
@@ -468,6 +465,7 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
468
465
  # This should never happen because we check for this in from_str above.
469
466
  raise ValueError(f'Unsupported networking mode: {network_mode_str}')
470
467
  config['auth']['ssh_proxy_command'] = ssh_proxy_cmd
468
+ config['auth']['ssh_private_key'] = private_key_path
471
469
 
472
470
  return config
473
471
 
@@ -499,19 +497,18 @@ def setup_vast_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
499
497
  if not any(x['public_key'] == public_key for x in current_key_list):
500
498
  vast.vast().create_ssh_key(ssh_key=public_key)
501
499
 
502
- config['auth']['ssh_public_key'] = PUBLIC_SSH_KEY_PATH
500
+ config['auth']['ssh_public_key'] = public_key_path
503
501
  return configure_ssh_info(config)
504
502
 
505
503
 
506
504
  def setup_fluidstack_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
507
505
 
508
- get_or_generate_keys()
506
+ _, public_key_path = get_or_generate_keys()
509
507
 
510
508
  client = fluidstack_utils.FluidstackClient()
511
- public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
512
509
  public_key = None
513
510
  with open(public_key_path, 'r', encoding='utf-8') as f:
514
511
  public_key = f.read()
515
512
  client.get_or_add_ssh_key(public_key)
516
- config['auth']['ssh_public_key'] = PUBLIC_SSH_KEY_PATH
513
+ config['auth']['ssh_public_key'] = public_key_path
517
514
  return configure_ssh_info(config)
sky/backends/backend.py CHANGED
@@ -2,8 +2,8 @@
2
2
  import typing
3
3
  from typing import Dict, Generic, Optional
4
4
 
5
- import sky
6
5
  from sky.usage import usage_lib
6
+ from sky.utils import cluster_utils
7
7
  from sky.utils import rich_utils
8
8
  from sky.utils import timeline
9
9
  from sky.utils import ux_utils
@@ -77,7 +77,7 @@ class Backend(Generic[_ResourceHandleType]):
77
77
  dryrun is True.
78
78
  """
79
79
  if cluster_name is None:
80
- cluster_name = sky.backends.backend_utils.generate_cluster_name()
80
+ cluster_name = cluster_utils.generate_cluster_name()
81
81
  usage_lib.record_cluster_name_for_current_operation(cluster_name)
82
82
  usage_lib.messages.usage.update_actual_task(task)
83
83
  with rich_utils.safe_status(ux_utils.spinner_message('Launching')):