skypilot-nightly 1.0.0.dev20251009__py3-none-any.whl → 1.0.0.dev20251107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of skypilot-nightly might be problematic. Click here for more details.
- sky/__init__.py +6 -2
- sky/adaptors/aws.py +25 -7
- sky/adaptors/coreweave.py +278 -0
- sky/adaptors/kubernetes.py +64 -0
- sky/adaptors/shadeform.py +89 -0
- sky/admin_policy.py +20 -0
- sky/authentication.py +59 -149
- sky/backends/backend_utils.py +104 -63
- sky/backends/cloud_vm_ray_backend.py +84 -39
- sky/catalog/data_fetchers/fetch_runpod.py +698 -0
- sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
- sky/catalog/kubernetes_catalog.py +24 -28
- sky/catalog/runpod_catalog.py +5 -1
- sky/catalog/shadeform_catalog.py +165 -0
- sky/check.py +25 -13
- sky/client/cli/command.py +335 -86
- sky/client/cli/flags.py +4 -2
- sky/client/cli/table_utils.py +17 -9
- sky/client/sdk.py +59 -12
- sky/cloud_stores.py +73 -0
- sky/clouds/__init__.py +2 -0
- sky/clouds/aws.py +71 -16
- sky/clouds/azure.py +12 -5
- sky/clouds/cloud.py +19 -9
- sky/clouds/cudo.py +12 -5
- sky/clouds/do.py +4 -1
- sky/clouds/fluidstack.py +12 -5
- sky/clouds/gcp.py +12 -5
- sky/clouds/hyperbolic.py +12 -5
- sky/clouds/ibm.py +12 -5
- sky/clouds/kubernetes.py +62 -25
- sky/clouds/lambda_cloud.py +12 -5
- sky/clouds/nebius.py +12 -5
- sky/clouds/oci.py +12 -5
- sky/clouds/paperspace.py +4 -1
- sky/clouds/primeintellect.py +4 -1
- sky/clouds/runpod.py +12 -5
- sky/clouds/scp.py +12 -5
- sky/clouds/seeweb.py +4 -1
- sky/clouds/shadeform.py +400 -0
- sky/clouds/ssh.py +4 -2
- sky/clouds/vast.py +12 -5
- sky/clouds/vsphere.py +4 -1
- sky/core.py +12 -11
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
- sky/dashboard/out/_next/static/chunks/{1871-49141c317f3a9020.js → 1871-74503c8e80fd253b.js} +1 -1
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
- sky/dashboard/out/_next/static/chunks/2755.fff53c4a3fcae910.js +26 -0
- sky/dashboard/out/_next/static/chunks/3294.72362fa129305b19.js +1 -0
- sky/dashboard/out/_next/static/chunks/{3785.a19328ba41517b8b.js → 3785.ad6adaa2a0fa9768.js} +1 -1
- sky/dashboard/out/_next/static/chunks/{4725.10f7a9a5d3ea8208.js → 4725.a830b5c9e7867c92.js} +1 -1
- sky/dashboard/out/_next/static/chunks/6856-ef8ba11f96d8c4a3.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-32b6e2d3822301fa.js +1 -0
- sky/dashboard/out/_next/static/chunks/7615-3301e838e5f25772.js +1 -0
- sky/dashboard/out/_next/static/chunks/8969-1e4613c651bf4051.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
- sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
- sky/dashboard/out/_next/static/chunks/9360.7310982cf5a0dc79.js +31 -0
- sky/dashboard/out/_next/static/chunks/pages/{_app-ce361c6959bc2001.js → _app-bde01e4a2beec258.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-c736ead69c2d86ec.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-477555ab7c0b13d8.js → [cluster]-a37d2063af475a1c.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{clusters-2f61f65487f6d8ff.js → clusters-d44859594e6f8064.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-553b8b5cb65e100b.js → [context]-c0b5935149902e6f.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{infra-910a22500c50596f.js → infra-aed0ea19df7cf961.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-5796e8d6aea291a0.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/{[pool]-bc979970c247d8f3.js → [pool]-6edeb7d06032adfc.js} +2 -2
- sky/dashboard/out/_next/static/chunks/pages/{jobs-a35a9dc3c5ccd657.js → jobs-479dde13399cf270.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{users-98d2ed979084162a.js → users-5ab3b907622cf0fe.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{volumes-835d14ba94808f79.js → volumes-b84b948ff357c43e.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-e8688c35c06f0ac5.js → [name]-c5a3eeee1c218af1.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{workspaces-69c80d677d3c2949.js → workspaces-22b23febb3e89ce1.js} +1 -1
- sky/dashboard/out/_next/static/chunks/webpack-2679be77fc08a2f8.js +1 -0
- sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
- sky/dashboard/out/_next/static/zB0ed6ge_W1MDszVHhijS/_buildManifest.js +1 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/volumes.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/data/data_utils.py +92 -1
- sky/data/mounting_utils.py +143 -19
- sky/data/storage.py +168 -11
- sky/exceptions.py +13 -1
- sky/execution.py +13 -0
- sky/global_user_state.py +189 -113
- sky/jobs/client/sdk.py +32 -10
- sky/jobs/client/sdk_async.py +9 -3
- sky/jobs/constants.py +3 -1
- sky/jobs/controller.py +164 -192
- sky/jobs/file_content_utils.py +80 -0
- sky/jobs/log_gc.py +201 -0
- sky/jobs/recovery_strategy.py +59 -82
- sky/jobs/scheduler.py +20 -9
- sky/jobs/server/core.py +105 -23
- sky/jobs/server/server.py +40 -28
- sky/jobs/server/utils.py +32 -11
- sky/jobs/state.py +588 -110
- sky/jobs/utils.py +442 -209
- sky/logs/agent.py +1 -1
- sky/metrics/utils.py +45 -6
- sky/optimizer.py +1 -1
- sky/provision/__init__.py +7 -0
- sky/provision/aws/instance.py +2 -1
- sky/provision/azure/instance.py +2 -1
- sky/provision/common.py +2 -0
- sky/provision/cudo/instance.py +2 -1
- sky/provision/do/instance.py +2 -1
- sky/provision/fluidstack/instance.py +4 -3
- sky/provision/gcp/instance.py +2 -1
- sky/provision/hyperbolic/instance.py +2 -1
- sky/provision/instance_setup.py +10 -2
- sky/provision/kubernetes/constants.py +0 -1
- sky/provision/kubernetes/instance.py +222 -89
- sky/provision/kubernetes/network.py +12 -8
- sky/provision/kubernetes/utils.py +114 -53
- sky/provision/kubernetes/volume.py +5 -4
- sky/provision/lambda_cloud/instance.py +2 -1
- sky/provision/nebius/instance.py +2 -1
- sky/provision/oci/instance.py +2 -1
- sky/provision/paperspace/instance.py +2 -1
- sky/provision/provisioner.py +11 -2
- sky/provision/runpod/instance.py +2 -1
- sky/provision/scp/instance.py +2 -1
- sky/provision/seeweb/instance.py +3 -3
- sky/provision/shadeform/__init__.py +11 -0
- sky/provision/shadeform/config.py +12 -0
- sky/provision/shadeform/instance.py +351 -0
- sky/provision/shadeform/shadeform_utils.py +83 -0
- sky/provision/vast/instance.py +2 -1
- sky/provision/vsphere/instance.py +2 -1
- sky/resources.py +1 -1
- sky/schemas/api/responses.py +9 -5
- sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
- sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
- sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
- sky/schemas/generated/jobsv1_pb2.py +52 -52
- sky/schemas/generated/jobsv1_pb2.pyi +4 -2
- sky/schemas/generated/managed_jobsv1_pb2.py +39 -35
- sky/schemas/generated/managed_jobsv1_pb2.pyi +21 -5
- sky/serve/client/impl.py +11 -3
- sky/serve/replica_managers.py +5 -2
- sky/serve/serve_utils.py +9 -2
- sky/serve/server/impl.py +7 -2
- sky/serve/server/server.py +18 -15
- sky/serve/service.py +2 -2
- sky/server/auth/oauth2_proxy.py +2 -5
- sky/server/common.py +31 -28
- sky/server/constants.py +5 -1
- sky/server/daemons.py +27 -19
- sky/server/requests/executor.py +138 -74
- sky/server/requests/payloads.py +9 -1
- sky/server/requests/preconditions.py +13 -10
- sky/server/requests/request_names.py +120 -0
- sky/server/requests/requests.py +485 -153
- sky/server/requests/serializers/decoders.py +26 -13
- sky/server/requests/serializers/encoders.py +56 -11
- sky/server/requests/threads.py +106 -0
- sky/server/rest.py +70 -18
- sky/server/server.py +283 -104
- sky/server/stream_utils.py +233 -59
- sky/server/uvicorn.py +18 -17
- sky/setup_files/alembic.ini +4 -0
- sky/setup_files/dependencies.py +32 -13
- sky/sky_logging.py +0 -2
- sky/skylet/constants.py +30 -7
- sky/skylet/events.py +7 -0
- sky/skylet/log_lib.py +8 -2
- sky/skylet/log_lib.pyi +1 -1
- sky/skylet/services.py +26 -13
- sky/skylet/subprocess_daemon.py +103 -29
- sky/skypilot_config.py +87 -75
- sky/ssh_node_pools/server.py +9 -8
- sky/task.py +67 -54
- sky/templates/kubernetes-ray.yml.j2 +8 -1
- sky/templates/nebius-ray.yml.j2 +1 -0
- sky/templates/shadeform-ray.yml.j2 +72 -0
- sky/templates/websocket_proxy.py +142 -12
- sky/users/permission.py +8 -1
- sky/utils/admin_policy_utils.py +16 -3
- sky/utils/asyncio_utils.py +78 -0
- sky/utils/auth_utils.py +153 -0
- sky/utils/cli_utils/status_utils.py +8 -2
- sky/utils/command_runner.py +11 -0
- sky/utils/common.py +3 -1
- sky/utils/common_utils.py +7 -4
- sky/utils/context.py +57 -51
- sky/utils/context_utils.py +30 -12
- sky/utils/controller_utils.py +35 -8
- sky/utils/db/db_utils.py +37 -10
- sky/utils/db/migration_utils.py +8 -4
- sky/utils/locks.py +24 -6
- sky/utils/resource_checker.py +4 -1
- sky/utils/resources_utils.py +53 -29
- sky/utils/schemas.py +23 -4
- sky/utils/subprocess_utils.py +17 -4
- sky/volumes/server/server.py +7 -6
- sky/workspaces/server.py +13 -12
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/METADATA +306 -55
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/RECORD +215 -195
- sky/dashboard/out/_next/static/chunks/1121-d0782b9251f0fcd3.js +0 -1
- sky/dashboard/out/_next/static/chunks/1141-3b40c39626f99c89.js +0 -11
- sky/dashboard/out/_next/static/chunks/2755.97300e1362fe7c98.js +0 -26
- sky/dashboard/out/_next/static/chunks/3015-8d748834fcc60b46.js +0 -1
- sky/dashboard/out/_next/static/chunks/3294.1fafbf42b3bcebff.js +0 -1
- sky/dashboard/out/_next/static/chunks/6135-4b4d5e824b7f9d3c.js +0 -1
- sky/dashboard/out/_next/static/chunks/6856-5fdc9b851a18acdb.js +0 -1
- sky/dashboard/out/_next/static/chunks/6990-f6818c84ed8f1c86.js +0 -1
- sky/dashboard/out/_next/static/chunks/8969-66237729cdf9749e.js +0 -1
- sky/dashboard/out/_next/static/chunks/9025.c12318fb6a1a9093.js +0 -6
- sky/dashboard/out/_next/static/chunks/9360.71e83b2ddc844ec2.js +0 -31
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8f058b0346db2aff.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-4f7079dcab6ed653.js +0 -16
- sky/dashboard/out/_next/static/chunks/webpack-6a5ddd0184bfa22c.js +0 -1
- sky/dashboard/out/_next/static/css/4614e06482d7309e.css +0 -3
- sky/dashboard/out/_next/static/hIViZcQBkn0HE8SpaSsUU/_buildManifest.js +0 -1
- /sky/dashboard/out/_next/static/{hIViZcQBkn0HE8SpaSsUU → zB0ed6ge_W1MDszVHhijS}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20251009.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/top_level.txt +0 -0
sky/__init__.py
CHANGED
|
@@ -7,7 +7,7 @@ import urllib.request
|
|
|
7
7
|
from sky.utils import directory_utils
|
|
8
8
|
|
|
9
9
|
# Replaced with the current commit when building the wheels.
|
|
10
|
-
_SKYPILOT_COMMIT_SHA = '
|
|
10
|
+
_SKYPILOT_COMMIT_SHA = 'd7530d48bc1a331b5644bd6fbbc51eaebf1432f3'
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
def _get_git_commit():
|
|
@@ -37,7 +37,7 @@ def _get_git_commit():
|
|
|
37
37
|
|
|
38
38
|
|
|
39
39
|
__commit__ = _get_git_commit()
|
|
40
|
-
__version__ = '1.0.0.
|
|
40
|
+
__version__ = '1.0.0.dev20251107'
|
|
41
41
|
__root_dir__ = directory_utils.get_sky_dir()
|
|
42
42
|
|
|
43
43
|
|
|
@@ -122,6 +122,7 @@ from sky.data import StoreType
|
|
|
122
122
|
from sky.jobs import ManagedJobStatus
|
|
123
123
|
from sky.optimizer import Optimizer
|
|
124
124
|
from sky.resources import Resources
|
|
125
|
+
from sky.server.requests.request_names import AdminPolicyRequestName
|
|
125
126
|
from sky.skylet.job_lib import JobStatus
|
|
126
127
|
from sky.task import Task
|
|
127
128
|
from sky.utils.common import OptimizeTarget
|
|
@@ -150,6 +151,7 @@ Vsphere = clouds.Vsphere
|
|
|
150
151
|
Fluidstack = clouds.Fluidstack
|
|
151
152
|
Nebius = clouds.Nebius
|
|
152
153
|
Hyperbolic = clouds.Hyperbolic
|
|
154
|
+
Shadeform = clouds.Shadeform
|
|
153
155
|
Seeweb = clouds.Seeweb
|
|
154
156
|
|
|
155
157
|
__all__ = [
|
|
@@ -172,6 +174,7 @@ __all__ = [
|
|
|
172
174
|
'Fluidstack',
|
|
173
175
|
'Nebius',
|
|
174
176
|
'Hyperbolic',
|
|
177
|
+
'Shadeform',
|
|
175
178
|
'Seeweb',
|
|
176
179
|
'Optimizer',
|
|
177
180
|
'OptimizeTarget',
|
|
@@ -226,6 +229,7 @@ __all__ = [
|
|
|
226
229
|
'MutatedUserRequest',
|
|
227
230
|
'AdminPolicy',
|
|
228
231
|
'Config',
|
|
232
|
+
'AdminPolicyRequestName',
|
|
229
233
|
# Registry
|
|
230
234
|
'CLOUD_REGISTRY',
|
|
231
235
|
'JOBS_RECOVERY_STRATEGY_REGISTRY',
|
sky/adaptors/aws.py
CHANGED
|
@@ -34,6 +34,7 @@ import time
|
|
|
34
34
|
import typing
|
|
35
35
|
from typing import Callable, Literal, Optional, TypeVar
|
|
36
36
|
|
|
37
|
+
from sky import skypilot_config
|
|
37
38
|
from sky.adaptors import common
|
|
38
39
|
from sky.utils import annotations
|
|
39
40
|
from sky.utils import common_utils
|
|
@@ -119,12 +120,27 @@ def _create_aws_object(creation_fn_or_cls: Callable[[], T],
|
|
|
119
120
|
f'{common_utils.format_exception(e)}.')
|
|
120
121
|
|
|
121
122
|
|
|
123
|
+
def get_workspace_profile() -> Optional[str]:
|
|
124
|
+
"""Get AWS profile name from workspace config."""
|
|
125
|
+
return skypilot_config.get_workspace_cloud('aws').get('profile', None)
|
|
126
|
+
|
|
127
|
+
|
|
122
128
|
# The LRU cache needs to be thread-local to avoid multiple threads sharing the
|
|
123
129
|
# same session object, which is not guaranteed to be thread-safe.
|
|
124
130
|
@_thread_local_lru_cache()
|
|
125
|
-
def session(check_credentials: bool = True):
|
|
126
|
-
"""Create an AWS session.
|
|
127
|
-
|
|
131
|
+
def session(check_credentials: bool = True, profile: Optional[str] = None):
|
|
132
|
+
"""Create an AWS session.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
check_credentials: Whether to check if credentials are available.
|
|
136
|
+
profile: AWS profile name to use. If None, uses default credentials.
|
|
137
|
+
"""
|
|
138
|
+
if profile is not None:
|
|
139
|
+
logger.debug(f'Using AWS profile \'{profile}\'.')
|
|
140
|
+
s = _create_aws_object(
|
|
141
|
+
lambda: boto3.session.Session(profile_name=profile), 'session')
|
|
142
|
+
else:
|
|
143
|
+
s = _create_aws_object(boto3.session.Session, 'session')
|
|
128
144
|
if check_credentials and s.get_credentials() is None:
|
|
129
145
|
# s.get_credentials() can be None if there are actually no credentials,
|
|
130
146
|
# or if we fail to get credentials from IMDS (e.g. due to throttling).
|
|
@@ -180,13 +196,14 @@ def resource(service_name: str, **kwargs):
|
|
|
180
196
|
kwargs['config'] = config
|
|
181
197
|
|
|
182
198
|
check_credentials = kwargs.pop('check_credentials', True)
|
|
199
|
+
profile = get_workspace_profile()
|
|
183
200
|
|
|
184
201
|
# Need to use the client retrieved from the per-thread session to avoid
|
|
185
202
|
# thread-safety issues (Directly creating the client with boto3.resource()
|
|
186
203
|
# is not thread-safe). Reference: https://stackoverflow.com/a/59635814
|
|
187
204
|
return _create_aws_object(
|
|
188
|
-
lambda: session(check_credentials=check_credentials).
|
|
189
|
-
|
|
205
|
+
lambda: session(check_credentials=check_credentials, profile=profile).
|
|
206
|
+
resource(service_name, **kwargs), 'resource')
|
|
190
207
|
|
|
191
208
|
|
|
192
209
|
# New typing overloads can be added as needed.
|
|
@@ -221,14 +238,15 @@ def client(service_name: str, **kwargs):
|
|
|
221
238
|
_assert_kwargs_builtin_type(kwargs)
|
|
222
239
|
|
|
223
240
|
check_credentials = kwargs.pop('check_credentials', True)
|
|
241
|
+
profile = get_workspace_profile()
|
|
224
242
|
|
|
225
243
|
# Need to use the client retrieved from the per-thread session to avoid
|
|
226
244
|
# thread-safety issues (Directly creating the client with boto3.client() is
|
|
227
245
|
# not thread-safe). Reference: https://stackoverflow.com/a/59635814
|
|
228
246
|
|
|
229
247
|
return _create_aws_object(
|
|
230
|
-
lambda: session(check_credentials=check_credentials).
|
|
231
|
-
|
|
248
|
+
lambda: session(check_credentials=check_credentials, profile=profile).
|
|
249
|
+
client(service_name, **kwargs), 'client')
|
|
232
250
|
|
|
233
251
|
|
|
234
252
|
@common.load_lazy_modules(modules=_LAZY_MODULES)
|
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
"""CoreWeave cloud adaptor."""
|
|
2
|
+
|
|
3
|
+
import configparser
|
|
4
|
+
import contextlib
|
|
5
|
+
import os
|
|
6
|
+
import threading
|
|
7
|
+
from typing import Dict, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
from sky import exceptions
|
|
10
|
+
from sky import sky_logging
|
|
11
|
+
from sky.adaptors import common
|
|
12
|
+
from sky.clouds import cloud
|
|
13
|
+
from sky.utils import annotations
|
|
14
|
+
from sky.utils import ux_utils
|
|
15
|
+
|
|
16
|
+
logger = sky_logging.init_logger(__name__)
|
|
17
|
+
|
|
18
|
+
COREWEAVE_PROFILE_NAME = 'cw'
|
|
19
|
+
COREWEAVE_CREDENTIALS_PATH = '~/.coreweave/cw.credentials'
|
|
20
|
+
COREWEAVE_CONFIG_PATH = '~/.coreweave/cw.config'
|
|
21
|
+
NAME = 'CoreWeave'
|
|
22
|
+
DEFAULT_REGION = 'US-EAST-01A'
|
|
23
|
+
_DEFAULT_ENDPOINT = 'https://cwobject.com'
|
|
24
|
+
_INDENT_PREFIX = ' '
|
|
25
|
+
|
|
26
|
+
_IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for CoreWeave.'
|
|
27
|
+
'Try pip install "skypilot[coreweave]"')
|
|
28
|
+
|
|
29
|
+
boto3 = common.LazyImport('boto3', import_error_message=_IMPORT_ERROR_MESSAGE)
|
|
30
|
+
botocore = common.LazyImport('botocore',
|
|
31
|
+
import_error_message=_IMPORT_ERROR_MESSAGE)
|
|
32
|
+
|
|
33
|
+
_LAZY_MODULES = (boto3, botocore)
|
|
34
|
+
_session_creation_lock = threading.RLock()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@contextlib.contextmanager
|
|
38
|
+
def _load_cw_credentials_env():
|
|
39
|
+
"""Context manager to temporarily change the AWS credentials file path."""
|
|
40
|
+
prev_credentials_path = os.environ.get('AWS_SHARED_CREDENTIALS_FILE')
|
|
41
|
+
prev_config_path = os.environ.get('AWS_CONFIG_FILE')
|
|
42
|
+
os.environ['AWS_SHARED_CREDENTIALS_FILE'] = COREWEAVE_CREDENTIALS_PATH
|
|
43
|
+
os.environ['AWS_CONFIG_FILE'] = COREWEAVE_CONFIG_PATH
|
|
44
|
+
try:
|
|
45
|
+
yield
|
|
46
|
+
finally:
|
|
47
|
+
if prev_credentials_path is None:
|
|
48
|
+
del os.environ['AWS_SHARED_CREDENTIALS_FILE']
|
|
49
|
+
else:
|
|
50
|
+
os.environ['AWS_SHARED_CREDENTIALS_FILE'] = prev_credentials_path
|
|
51
|
+
if prev_config_path is None:
|
|
52
|
+
del os.environ['AWS_CONFIG_FILE']
|
|
53
|
+
else:
|
|
54
|
+
os.environ['AWS_CONFIG_FILE'] = prev_config_path
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_coreweave_credentials(boto3_session):
|
|
58
|
+
"""Gets the CoreWeave credentials from the boto3 session object.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
boto3_session: The boto3 session object.
|
|
62
|
+
Returns:
|
|
63
|
+
botocore.credentials.ReadOnlyCredentials object with the CoreWeave
|
|
64
|
+
credentials.
|
|
65
|
+
"""
|
|
66
|
+
with _load_cw_credentials_env():
|
|
67
|
+
coreweave_credentials = boto3_session.get_credentials()
|
|
68
|
+
if coreweave_credentials is None:
|
|
69
|
+
with ux_utils.print_exception_no_traceback():
|
|
70
|
+
raise ValueError('CoreWeave credentials not found. Run '
|
|
71
|
+
'`sky check` to verify credentials are '
|
|
72
|
+
'correctly set up.')
|
|
73
|
+
return coreweave_credentials.get_frozen_credentials()
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@annotations.lru_cache(scope='global')
|
|
77
|
+
def session():
|
|
78
|
+
"""Create an AWS session for CoreWeave."""
|
|
79
|
+
# Creating the session object is not thread-safe for boto3,
|
|
80
|
+
# so we add a reentrant lock to synchronize the session creation.
|
|
81
|
+
# Reference: https://github.com/boto/boto3/issues/1592
|
|
82
|
+
# However, the session object itself is thread-safe, so we are
|
|
83
|
+
# able to use lru_cache() to cache the session object.
|
|
84
|
+
with _session_creation_lock:
|
|
85
|
+
with _load_cw_credentials_env():
|
|
86
|
+
session_ = boto3.session.Session(
|
|
87
|
+
profile_name=COREWEAVE_PROFILE_NAME)
|
|
88
|
+
return session_
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@annotations.lru_cache(scope='global')
|
|
92
|
+
def resource(resource_name: str, **kwargs):
|
|
93
|
+
"""Create a CoreWeave resource.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
resource_name: CoreWeave resource name (e.g., 's3').
|
|
97
|
+
kwargs: Other options.
|
|
98
|
+
"""
|
|
99
|
+
# Need to use the resource retrieved from the per-thread session
|
|
100
|
+
# to avoid thread-safety issues (Directly creating the client
|
|
101
|
+
# with boto3.resource() is not thread-safe).
|
|
102
|
+
# Reference: https://stackoverflow.com/a/59635814
|
|
103
|
+
|
|
104
|
+
session_ = session()
|
|
105
|
+
coreweave_credentials = get_coreweave_credentials(session_)
|
|
106
|
+
endpoint = get_endpoint()
|
|
107
|
+
|
|
108
|
+
return session_.resource(
|
|
109
|
+
resource_name,
|
|
110
|
+
endpoint_url=endpoint,
|
|
111
|
+
aws_access_key_id=coreweave_credentials.access_key,
|
|
112
|
+
aws_secret_access_key=coreweave_credentials.secret_key,
|
|
113
|
+
region_name='auto',
|
|
114
|
+
config=botocore.config.Config(s3={'addressing_style': 'virtual'}),
|
|
115
|
+
**kwargs)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@annotations.lru_cache(scope='global')
|
|
119
|
+
def client(service_name: str):
|
|
120
|
+
"""Create CoreWeave client of a certain service.
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
service_name: CoreWeave service name (e.g., 's3').
|
|
124
|
+
"""
|
|
125
|
+
# Need to use the client retrieved from the per-thread session
|
|
126
|
+
# to avoid thread-safety issues (Directly creating the client
|
|
127
|
+
# with boto3.client() is not thread-safe).
|
|
128
|
+
# Reference: https://stackoverflow.com/a/59635814
|
|
129
|
+
|
|
130
|
+
session_ = session()
|
|
131
|
+
coreweave_credentials = get_coreweave_credentials(session_)
|
|
132
|
+
endpoint = get_endpoint()
|
|
133
|
+
|
|
134
|
+
return session_.client(
|
|
135
|
+
service_name,
|
|
136
|
+
endpoint_url=endpoint,
|
|
137
|
+
aws_access_key_id=coreweave_credentials.access_key,
|
|
138
|
+
aws_secret_access_key=coreweave_credentials.secret_key,
|
|
139
|
+
region_name='auto',
|
|
140
|
+
config=botocore.config.Config(s3={'addressing_style': 'virtual'}),
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@common.load_lazy_modules(_LAZY_MODULES)
|
|
145
|
+
def botocore_exceptions():
|
|
146
|
+
"""AWS botocore exception."""
|
|
147
|
+
# pylint: disable=import-outside-toplevel
|
|
148
|
+
from botocore import exceptions as boto_exceptions
|
|
149
|
+
return boto_exceptions
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def get_endpoint():
|
|
153
|
+
"""Parse the COREWEAVE_CONFIG_PATH to get the endpoint_url.
|
|
154
|
+
|
|
155
|
+
The config file is an AWS-style config file with format:
|
|
156
|
+
[profile cw]
|
|
157
|
+
endpoint_url = https://cwobject.com
|
|
158
|
+
s3 =
|
|
159
|
+
addressing_style = virtual
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
str: The endpoint URL from the config file, or the default endpoint
|
|
163
|
+
if the file doesn't exist or doesn't contain the endpoint_url.
|
|
164
|
+
"""
|
|
165
|
+
config_path = os.path.expanduser(COREWEAVE_CONFIG_PATH)
|
|
166
|
+
if not os.path.isfile(config_path):
|
|
167
|
+
return _DEFAULT_ENDPOINT
|
|
168
|
+
|
|
169
|
+
try:
|
|
170
|
+
config = configparser.ConfigParser()
|
|
171
|
+
config.read(config_path)
|
|
172
|
+
|
|
173
|
+
# Try to get endpoint_url from [profile cw] section
|
|
174
|
+
profile_section = f'profile {COREWEAVE_PROFILE_NAME}'
|
|
175
|
+
if config.has_section(profile_section):
|
|
176
|
+
if config.has_option(profile_section, 'endpoint_url'):
|
|
177
|
+
endpoint = config.get(profile_section, 'endpoint_url')
|
|
178
|
+
return endpoint.strip()
|
|
179
|
+
except (configparser.Error, OSError) as e:
|
|
180
|
+
logger.warning(f'Failed to parse CoreWeave config file: {e}. '
|
|
181
|
+
f'Using default endpoint: {_DEFAULT_ENDPOINT}')
|
|
182
|
+
|
|
183
|
+
return _DEFAULT_ENDPOINT
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def check_credentials(
|
|
187
|
+
cloud_capability: cloud.CloudCapability) -> Tuple[bool, Optional[str]]:
|
|
188
|
+
if cloud_capability == cloud.CloudCapability.STORAGE:
|
|
189
|
+
return check_storage_credentials()
|
|
190
|
+
else:
|
|
191
|
+
raise exceptions.NotSupportedError(
|
|
192
|
+
f'{NAME} does not support {cloud_capability}.')
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def check_storage_credentials() -> Tuple[bool, Optional[str]]:
|
|
196
|
+
"""Checks if the user has access credentials to CoreWeave Object Storage.
|
|
197
|
+
|
|
198
|
+
Returns:
|
|
199
|
+
A tuple of a boolean value and a hint message where the bool
|
|
200
|
+
is True when both credentials needed for CoreWeave storage is set.
|
|
201
|
+
It is False when either of those are not set, which would hint with a
|
|
202
|
+
string on unset credential.
|
|
203
|
+
"""
|
|
204
|
+
hints = None
|
|
205
|
+
profile_in_cred = coreweave_profile_in_cred()
|
|
206
|
+
profile_in_config = coreweave_profile_in_config()
|
|
207
|
+
|
|
208
|
+
if not profile_in_cred:
|
|
209
|
+
hints = (f'[{COREWEAVE_PROFILE_NAME}] profile is not set in '
|
|
210
|
+
f'{COREWEAVE_CREDENTIALS_PATH}.')
|
|
211
|
+
if not profile_in_config:
|
|
212
|
+
if hints:
|
|
213
|
+
hints += ' Additionally, '
|
|
214
|
+
else:
|
|
215
|
+
hints = ''
|
|
216
|
+
hints += (f'[{COREWEAVE_PROFILE_NAME}] profile is not set in '
|
|
217
|
+
f'{COREWEAVE_CONFIG_PATH}.')
|
|
218
|
+
|
|
219
|
+
if hints:
|
|
220
|
+
hints += ' Run the following commands:'
|
|
221
|
+
if not profile_in_cred:
|
|
222
|
+
hints += f'\n{_INDENT_PREFIX} $ pip install boto3'
|
|
223
|
+
hints += (f'\n{_INDENT_PREFIX} $ AWS_SHARED_CREDENTIALS_FILE='
|
|
224
|
+
f'{COREWEAVE_CREDENTIALS_PATH} aws configure --profile '
|
|
225
|
+
f'{COREWEAVE_PROFILE_NAME}')
|
|
226
|
+
if not profile_in_config:
|
|
227
|
+
hints += (f'\n{_INDENT_PREFIX} $ AWS_CONFIG_FILE='
|
|
228
|
+
f'{COREWEAVE_CONFIG_PATH} aws configure set endpoint_url'
|
|
229
|
+
f' <ENDPOINT_URL> --profile '
|
|
230
|
+
f'{COREWEAVE_PROFILE_NAME}')
|
|
231
|
+
hints += (f'\n{_INDENT_PREFIX} $ AWS_CONFIG_FILE='
|
|
232
|
+
f'{COREWEAVE_CONFIG_PATH} aws configure set '
|
|
233
|
+
f's3.addressing_style virtual --profile '
|
|
234
|
+
f'{COREWEAVE_PROFILE_NAME}')
|
|
235
|
+
hints += f'\n{_INDENT_PREFIX}For more info: '
|
|
236
|
+
hints += 'https://docs.coreweave.com/docs/products/storage/object-storage/get-started-caios' # pylint: disable=line-too-long
|
|
237
|
+
|
|
238
|
+
return (False, hints) if hints else (True, hints)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def coreweave_profile_in_config() -> bool:
|
|
242
|
+
"""Checks if CoreWeave profile is set in config"""
|
|
243
|
+
conf_path = os.path.expanduser(COREWEAVE_CONFIG_PATH)
|
|
244
|
+
coreweave_profile_exists = False
|
|
245
|
+
if os.path.isfile(conf_path):
|
|
246
|
+
with open(conf_path, 'r', encoding='utf-8') as file:
|
|
247
|
+
for line in file:
|
|
248
|
+
if f'[profile {COREWEAVE_PROFILE_NAME}]' in line:
|
|
249
|
+
coreweave_profile_exists = True
|
|
250
|
+
break
|
|
251
|
+
return coreweave_profile_exists
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def coreweave_profile_in_cred() -> bool:
|
|
255
|
+
"""Checks if CoreWeave profile is set in credentials"""
|
|
256
|
+
cred_path = os.path.expanduser(COREWEAVE_CREDENTIALS_PATH)
|
|
257
|
+
coreweave_profile_exists = False
|
|
258
|
+
if os.path.isfile(cred_path):
|
|
259
|
+
with open(cred_path, 'r', encoding='utf-8') as file:
|
|
260
|
+
for line in file:
|
|
261
|
+
if f'[{COREWEAVE_PROFILE_NAME}]' in line:
|
|
262
|
+
coreweave_profile_exists = True
|
|
263
|
+
break
|
|
264
|
+
return coreweave_profile_exists
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def get_credential_file_mounts() -> Dict[str, str]:
|
|
268
|
+
"""Returns credential file mounts for CoreWeave.
|
|
269
|
+
|
|
270
|
+
Returns:
|
|
271
|
+
Dict[str, str]: A dictionary mapping source paths to destination paths
|
|
272
|
+
for credential files.
|
|
273
|
+
"""
|
|
274
|
+
coreweave_credential_mounts = {
|
|
275
|
+
COREWEAVE_CREDENTIALS_PATH: COREWEAVE_CREDENTIALS_PATH,
|
|
276
|
+
COREWEAVE_CONFIG_PATH: COREWEAVE_CONFIG_PATH
|
|
277
|
+
}
|
|
278
|
+
return coreweave_credential_mounts
|
sky/adaptors/kubernetes.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Kubernetes adaptors"""
|
|
2
|
+
import functools
|
|
2
3
|
import logging
|
|
3
4
|
import os
|
|
4
5
|
import platform
|
|
@@ -162,8 +163,61 @@ def list_kube_config_contexts():
|
|
|
162
163
|
return kubernetes.config.list_kube_config_contexts(_get_config_file())
|
|
163
164
|
|
|
164
165
|
|
|
166
|
+
class ClientWrapper:
|
|
167
|
+
"""Wrapper around the kubernetes API clients.
|
|
168
|
+
|
|
169
|
+
This is needed because we cache kubernetes.client.ApiClient and other typed
|
|
170
|
+
clients (e.g. kubernetes.client.CoreV1Api) and lru_cache.cache_clear() does
|
|
171
|
+
not call close() on the client to cleanup external resources like
|
|
172
|
+
semaphores. This decorator wraps the client with __del__ to ensure the
|
|
173
|
+
external state of kubernetes clients are properly cleaned up on GC.
|
|
174
|
+
"""
|
|
175
|
+
|
|
176
|
+
def __init__(self, client):
|
|
177
|
+
self._client = client
|
|
178
|
+
|
|
179
|
+
def __getattr__(self, name):
|
|
180
|
+
"""Delegate to the underlying client"""
|
|
181
|
+
return getattr(self._client, name)
|
|
182
|
+
|
|
183
|
+
def __del__(self):
|
|
184
|
+
"""Clean up the underlying client"""
|
|
185
|
+
try:
|
|
186
|
+
real_client = None
|
|
187
|
+
if isinstance(self._client, kubernetes.client.ApiClient):
|
|
188
|
+
real_client = self._client
|
|
189
|
+
elif isinstance(self._client, kubernetes.watch.Watch):
|
|
190
|
+
real_client = getattr(self._client, '_api_client', None)
|
|
191
|
+
else:
|
|
192
|
+
# Otherwise, the client is a typed client, the typed client
|
|
193
|
+
# is generated by codegen and all of them should have an
|
|
194
|
+
# 'api_client' attribute referring to the real client.
|
|
195
|
+
real_client = getattr(self._client, 'api_client', None)
|
|
196
|
+
if real_client is not None:
|
|
197
|
+
real_client.close()
|
|
198
|
+
else:
|
|
199
|
+
# logger may already be cleaned up during __del__ at shutdown
|
|
200
|
+
if logger is not None:
|
|
201
|
+
logger.debug(f'No client found for {self._client}')
|
|
202
|
+
except Exception as e: # pylint: disable=broad-except
|
|
203
|
+
if logger is not None:
|
|
204
|
+
logger.debug(f'Error closing Kubernetes client: {e}')
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def wrap_kubernetes_client(func):
|
|
208
|
+
"""Wraps kubernetes API clients for proper cleanup."""
|
|
209
|
+
|
|
210
|
+
@functools.wraps(func)
|
|
211
|
+
def wrapper(*args, **kwargs):
|
|
212
|
+
obj = func(*args, **kwargs)
|
|
213
|
+
return ClientWrapper(obj)
|
|
214
|
+
|
|
215
|
+
return wrapper
|
|
216
|
+
|
|
217
|
+
|
|
165
218
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def core_api(context: Optional[str] = None):
    """Build a ``kubernetes.client.CoreV1Api`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new API object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.CoreV1Api()
    return client
|
|
@@ -171,6 +225,7 @@ def core_api(context: Optional[str] = None):
|
|
|
171
225
|
|
|
172
226
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def storage_api(context: Optional[str] = None):
    """Build a ``kubernetes.client.StorageV1Api`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new API object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.StorageV1Api()
    return client
|
|
@@ -178,6 +233,7 @@ def storage_api(context: Optional[str] = None):
|
|
|
178
233
|
|
|
179
234
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def auth_api(context: Optional[str] = None):
    """Build a ``kubernetes.client.RbacAuthorizationV1Api`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new API object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.RbacAuthorizationV1Api()
    return client
|
|
@@ -185,6 +241,7 @@ def auth_api(context: Optional[str] = None):
|
|
|
185
241
|
|
|
186
242
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def networking_api(context: Optional[str] = None):
    """Build a ``kubernetes.client.NetworkingV1Api`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new API object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.NetworkingV1Api()
    return client
|
|
@@ -192,6 +249,7 @@ def networking_api(context: Optional[str] = None):
|
|
|
192
249
|
|
|
193
250
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def custom_objects_api(context: Optional[str] = None):
    """Build a ``kubernetes.client.CustomObjectsApi`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new API object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.CustomObjectsApi()
    return client
|
|
@@ -199,6 +257,7 @@ def custom_objects_api(context: Optional[str] = None):
|
|
|
199
257
|
|
|
200
258
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='global')
@wrap_kubernetes_client
def node_api(context: Optional[str] = None):
    """Build a ``kubernetes.client.NodeV1Api`` after loading config.

    Unlike the sibling factories, the cache decorator here is applied with
    ``scope='global'``.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new API object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.NodeV1Api()
    return client
|
|
@@ -206,6 +265,7 @@ def node_api(context: Optional[str] = None):
|
|
|
206
265
|
|
|
207
266
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def apps_api(context: Optional[str] = None):
    """Build a ``kubernetes.client.AppsV1Api`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new API object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.AppsV1Api()
    return client
|
|
@@ -213,6 +273,7 @@ def apps_api(context: Optional[str] = None):
|
|
|
213
273
|
|
|
214
274
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def batch_api(context: Optional[str] = None):
    """Build a ``kubernetes.client.BatchV1Api`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new API object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.BatchV1Api()
    return client
|
|
@@ -220,6 +281,7 @@ def batch_api(context: Optional[str] = None):
|
|
|
220
281
|
|
|
221
282
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def api_client(context: Optional[str] = None):
    """Build a ``kubernetes.client.ApiClient`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new client object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.ApiClient()
    return client
|
|
@@ -227,6 +289,7 @@ def api_client(context: Optional[str] = None):
|
|
|
227
289
|
|
|
228
290
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def custom_resources_api(context: Optional[str] = None):
    """Build a ``kubernetes.client.CustomObjectsApi`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new API object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    client = kubernetes.client.CustomObjectsApi()
    return client
|
|
@@ -234,6 +297,7 @@ def custom_resources_api(context: Optional[str] = None):
|
|
|
234
297
|
|
|
235
298
|
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def watch(context: Optional[str] = None):
    """Build a ``kubernetes.watch.Watch`` after loading config.

    Args:
        context: Forwarded to ``_load_config``; presumably the kubeconfig
            context name (confirm against ``_load_config``).

    Returns:
        The new watch object, as transformed by the decorators stacked on
        this function.
    """
    _load_config(context)
    watcher = kubernetes.watch.Watch()
    return watcher
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""Shadeform cloud adaptor."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import socket
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
import requests
|
|
8
|
+
|
|
9
|
+
from sky import sky_logging
|
|
10
|
+
from sky.provision.shadeform import shadeform_utils
|
|
11
|
+
from sky.utils import common_utils
|
|
12
|
+
|
|
13
|
+
logger = sky_logging.init_logger(__name__)
|
|
14
|
+
|
|
15
|
+
_shadeform_sdk = None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def import_package(func):
    """Decorator that imports the ``shadeform`` package on first use.

    The imported module is stored in the module-level ``_shadeform_sdk`` so
    later calls skip the import.

    Args:
        func: The function to run once the package is available.

    Returns:
        A wrapper with the same metadata as ``func``.

    Raises:
        ImportError: From the wrapper, with an installation hint, when the
            ``shadeform`` package cannot be imported.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global _shadeform_sdk
        # Fast path: the package was already imported by an earlier call.
        if _shadeform_sdk is not None:
            return func(*args, **kwargs)
        try:
            import shadeform as sdk_module  # pylint: disable=import-outside-toplevel
        except ImportError:
            # 'from None' hides the original traceback in favour of the hint.
            raise ImportError(
                'Failed to import dependencies for Shadeform. '
                'Try pip install "skypilot[shadeform]"') from None
        _shadeform_sdk = sdk_module
        return func(*args, **kwargs)

    return wrapper
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@import_package
def shadeform():
    """Return the shadeform package.

    The ``import_package`` decorator performs the import before this body
    runs, so the module-level ``_shadeform_sdk`` is populated by then.
    """
    sdk = _shadeform_sdk
    return sdk
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def list_ssh_keys() -> List[Dict[str, Any]]:
    """List all SSH keys in the Shadeform account.

    Returns:
        The value stored under ``'ssh_keys'`` in the response from
        ``shadeform_utils.get_ssh_keys()``, or an empty list when that key is
        absent or when the lookup raises ``ValueError``, ``KeyError`` or a
        ``requests`` exception (a warning is logged in that case).
    """
    try:
        payload = shadeform_utils.get_ssh_keys()
        ssh_keys = payload.get('ssh_keys', [])
    except (ValueError, KeyError, requests.exceptions.RequestException) as e:
        logger.warning(f'Failed to list SSH keys from Shadeform: {e}')
        return []
    return ssh_keys
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def add_ssh_key_to_shadeform(public_key: str) -> Optional[str]:
    """Add SSH key to Shadeform if it doesn't already exist.

    Args:
        public_key: The SSH public key string.

    Returns:
        The ``id`` of the already-registered or newly added key, or None if
        adding the key failed.
    """
    try:
        # Reuse a key that is already registered. The match is an exact
        # string comparison against each entry's 'public_key' field.
        # NOTE(review): list_ssh_keys() returns [] when listing fails, so a
        # listing error falls through to the add path below.
        for key in list_ssh_keys():
            if key.get('public_key') == public_key:
                logger.info('SSH key already exists in Shadeform account')
                return key.get('id')

        # Generate a key name from the local hostname and the first 8
        # characters of the user hash.
        hostname = socket.gethostname()
        key_name = f'skypilot-{hostname}-{common_utils.get_user_hash()[:8]}'

        # Add the key
        response = shadeform_utils.add_ssh_key(name=key_name,
                                               public_key=public_key)
        key_id = response['id']
        # Format name and id explicitly rather than as a tuple repr.
        logger.info(f'Added SSH key to Shadeform: {key_name} (id: {key_id})')
        return key_id

    except (ValueError, KeyError, requests.exceptions.RequestException) as e:
        logger.warning(f'Failed to add SSH key to Shadeform: {e}')
        return None
|