skypilot-nightly 1.0.0.dev20251203__py3-none-any.whl → 1.0.0.dev20251210__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +4 -2
- sky/adaptors/aws.py +1 -61
- sky/adaptors/slurm.py +478 -0
- sky/backends/backend_utils.py +45 -4
- sky/backends/cloud_vm_ray_backend.py +32 -33
- sky/backends/task_codegen.py +340 -2
- sky/catalog/__init__.py +0 -3
- sky/catalog/kubernetes_catalog.py +12 -4
- sky/catalog/slurm_catalog.py +243 -0
- sky/check.py +14 -3
- sky/client/cli/command.py +329 -22
- sky/client/sdk.py +56 -2
- sky/clouds/__init__.py +2 -0
- sky/clouds/cloud.py +7 -0
- sky/clouds/slurm.py +578 -0
- sky/clouds/ssh.py +2 -1
- sky/clouds/vast.py +10 -0
- sky/core.py +128 -36
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/KYAhEFa3FTfq4JyKVgo-s/_buildManifest.js +1 -0
- sky/dashboard/out/_next/static/chunks/3294.ddda8c6c6f9f24dc.js +1 -0
- sky/dashboard/out/_next/static/chunks/3850-fd5696f3bbbaddae.js +1 -0
- sky/dashboard/out/_next/static/chunks/6856-da20c5fd999f319c.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-09cbf02d3cd518c3.js +1 -0
- sky/dashboard/out/_next/static/chunks/9353-8369df1cf105221c.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/_app-68b647e26f9d2793.js +34 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-33f525539665fdfd.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-abfcac9c137aa543.js → [cluster]-a7565f586ef86467.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{clusters-ee39056f9851a3ff.js → clusters-9e5d47818b9bdadd.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{config-dfb9bf07b13045f4.js → config-718cdc365de82689.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-c0b5935149902e6f.js → [context]-12c559ec4d81fdbd.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{infra-aed0ea19df7cf961.js → infra-d187cd0413d72475.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-895847b6cf200b04.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/{[pool]-9faf940b253e3e06.js → [pool]-8d0f4655400b4eb9.js} +2 -2
- sky/dashboard/out/_next/static/chunks/pages/{jobs-2072b48b617989c9.js → jobs-e5a98f17f8513a96.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/plugins/[...slug]-4f46050ca065d8f8.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/{users-f42674164aa73423.js → users-2f7646eb77785a2c.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{volumes-b84b948ff357c43e.js → volumes-ef19d49c6d0e8500.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-84a40f8c7c627fe4.js → [name]-96e0f298308da7e2.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{workspaces-531b2f8c4bf89f82.js → workspaces-cb4da3abe08ebf19.js} +1 -1
- sky/dashboard/out/_next/static/chunks/{webpack-64e05f17bf2cf8ce.js → webpack-fba3de387ff6bb08.js} +1 -1
- sky/dashboard/out/_next/static/css/c5a4cfd2600fc715.css +3 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/plugins/[...slug].html +1 -0
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/volumes.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/data/mounting_utils.py +16 -2
- sky/global_user_state.py +3 -3
- sky/models.py +2 -0
- sky/optimizer.py +6 -5
- sky/provision/__init__.py +1 -0
- sky/provision/common.py +20 -0
- sky/provision/docker_utils.py +15 -2
- sky/provision/kubernetes/utils.py +42 -6
- sky/provision/provisioner.py +15 -6
- sky/provision/slurm/__init__.py +12 -0
- sky/provision/slurm/config.py +13 -0
- sky/provision/slurm/instance.py +572 -0
- sky/provision/slurm/utils.py +583 -0
- sky/provision/vast/instance.py +4 -1
- sky/provision/vast/utils.py +10 -6
- sky/serve/server/impl.py +1 -1
- sky/server/constants.py +1 -1
- sky/server/plugins.py +222 -0
- sky/server/requests/executor.py +5 -2
- sky/server/requests/payloads.py +12 -1
- sky/server/requests/request_names.py +2 -0
- sky/server/requests/requests.py +5 -1
- sky/server/requests/serializers/encoders.py +17 -0
- sky/server/requests/serializers/return_value_serializers.py +60 -0
- sky/server/server.py +78 -8
- sky/server/server_utils.py +30 -0
- sky/setup_files/dependencies.py +2 -0
- sky/skylet/attempt_skylet.py +13 -3
- sky/skylet/constants.py +34 -9
- sky/skylet/events.py +10 -4
- sky/skylet/executor/__init__.py +1 -0
- sky/skylet/executor/slurm.py +189 -0
- sky/skylet/job_lib.py +2 -1
- sky/skylet/log_lib.py +22 -6
- sky/skylet/log_lib.pyi +8 -6
- sky/skylet/skylet.py +5 -1
- sky/skylet/subprocess_daemon.py +2 -1
- sky/ssh_node_pools/constants.py +12 -0
- sky/ssh_node_pools/core.py +40 -3
- sky/ssh_node_pools/deploy/__init__.py +4 -0
- sky/{utils/kubernetes/deploy_ssh_node_pools.py → ssh_node_pools/deploy/deploy.py} +279 -504
- sky/ssh_node_pools/deploy/tunnel_utils.py +199 -0
- sky/ssh_node_pools/deploy/utils.py +173 -0
- sky/ssh_node_pools/server.py +11 -13
- sky/{utils/kubernetes/ssh_utils.py → ssh_node_pools/utils.py} +9 -6
- sky/templates/kubernetes-ray.yml.j2 +8 -0
- sky/templates/slurm-ray.yml.j2 +85 -0
- sky/templates/vast-ray.yml.j2 +1 -0
- sky/users/model.conf +1 -1
- sky/users/permission.py +24 -1
- sky/users/rbac.py +31 -3
- sky/utils/annotations.py +108 -8
- sky/utils/command_runner.py +197 -5
- sky/utils/command_runner.pyi +27 -4
- sky/utils/common_utils.py +18 -3
- sky/utils/kubernetes/kubernetes_deploy_utils.py +2 -94
- sky/utils/kubernetes/ssh-tunnel.sh +7 -376
- sky/utils/schemas.py +31 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/METADATA +48 -36
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/RECORD +125 -107
- sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +0 -1
- sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +0 -1
- sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +0 -1
- sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +0 -1
- sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +0 -34
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +0 -16
- sky/dashboard/out/_next/static/css/0748ce22df867032.css +0 -3
- sky/utils/kubernetes/cleanup-tunnel.sh +0 -62
- /sky/dashboard/out/_next/static/{96_E2yl3QAiIJGOYCkSpB → KYAhEFa3FTfq4JyKVgo-s}/_ssgManifest.js +0 -0
- /sky/dashboard/out/_next/static/chunks/{1141-e6aa9ab418717c59.js → 1141-9c810f01ff4f398a.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/{3800-7b45f9fbb6308557.js → 3800-b589397dc09c5b4e.js} +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/top_level.txt +0 -0
sky/users/rbac.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""RBAC (Role-Based Access Control) functionality for SkyPilot API Server."""
|
|
2
2
|
|
|
3
3
|
import enum
|
|
4
|
-
from typing import Dict, List
|
|
4
|
+
from typing import Dict, List, Optional
|
|
5
5
|
|
|
6
6
|
from sky import sky_logging
|
|
7
7
|
from sky import skypilot_config
|
|
@@ -55,8 +55,13 @@ def get_default_role() -> str:
|
|
|
55
55
|
|
|
56
56
|
|
|
57
57
|
def get_role_permissions(
|
|
58
|
+
plugin_rules: Optional[Dict[str, List[Dict[str, str]]]] = None
|
|
58
59
|
) -> Dict[str, Dict[str, Dict[str, List[Dict[str, str]]]]]:
|
|
59
|
-
"""Get all role permissions from config.
|
|
60
|
+
"""Get all role permissions from config and plugins.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
plugin_rules: Optional dictionary of plugin RBAC rules to merge.
|
|
64
|
+
Format: {'user': [{'path': '...', 'method': '...'}]}
|
|
60
65
|
|
|
61
66
|
Returns:
|
|
62
67
|
Dictionary containing all roles and their permissions configuration.
|
|
@@ -91,9 +96,32 @@ def get_role_permissions(
|
|
|
91
96
|
if 'user' not in config_permissions:
|
|
92
97
|
config_permissions['user'] = {
|
|
93
98
|
'permissions': {
|
|
94
|
-
'blocklist': _DEFAULT_USER_BLOCKLIST
|
|
99
|
+
'blocklist': _DEFAULT_USER_BLOCKLIST.copy()
|
|
95
100
|
}
|
|
96
101
|
}
|
|
102
|
+
|
|
103
|
+
# Merge plugin rules into the appropriate roles
|
|
104
|
+
if plugin_rules:
|
|
105
|
+
for role, rules in plugin_rules.items():
|
|
106
|
+
if role not in supported_roles:
|
|
107
|
+
logger.warning(f'Plugin specified invalid role: {role}')
|
|
108
|
+
continue
|
|
109
|
+
if role not in config_permissions:
|
|
110
|
+
config_permissions[role] = {'permissions': {'blocklist': []}}
|
|
111
|
+
if 'permissions' not in config_permissions[role]:
|
|
112
|
+
config_permissions[role]['permissions'] = {'blocklist': []}
|
|
113
|
+
if 'blocklist' not in config_permissions[role]['permissions']:
|
|
114
|
+
config_permissions[role]['permissions']['blocklist'] = []
|
|
115
|
+
|
|
116
|
+
# Merge plugin rules, avoiding duplicates
|
|
117
|
+
existing_rules = config_permissions[role]['permissions'][
|
|
118
|
+
'blocklist']
|
|
119
|
+
for rule in rules:
|
|
120
|
+
if rule not in existing_rules:
|
|
121
|
+
existing_rules.append(rule)
|
|
122
|
+
logger.debug(f'Added plugin RBAC rule for {role}: '
|
|
123
|
+
f'{rule["method"]} {rule["path"]}')
|
|
124
|
+
|
|
97
125
|
return config_permissions
|
|
98
126
|
|
|
99
127
|
|
sky/utils/annotations.py
CHANGED
|
@@ -1,14 +1,20 @@
|
|
|
1
1
|
"""Annotations for public APIs."""
|
|
2
2
|
|
|
3
3
|
import functools
|
|
4
|
-
|
|
4
|
+
import threading
|
|
5
|
+
import time
|
|
6
|
+
from typing import Callable, List, Literal, TypeVar
|
|
7
|
+
import weakref
|
|
5
8
|
|
|
6
9
|
import cachetools
|
|
7
10
|
from typing_extensions import ParamSpec
|
|
8
11
|
|
|
9
12
|
# Whether the current process is a SkyPilot API server process.
|
|
10
13
|
is_on_api_server = True
|
|
11
|
-
|
|
14
|
+
_FUNCTIONS_NEED_RELOAD_CACHE_LOCK = threading.Lock()
|
|
15
|
+
# Caches can be thread-local, use weakref to avoid blocking the GC when the
|
|
16
|
+
# thread is destroyed.
|
|
17
|
+
_FUNCTIONS_NEED_RELOAD_CACHE: List[weakref.ReferenceType] = []
|
|
12
18
|
|
|
13
19
|
T = TypeVar('T')
|
|
14
20
|
P = ParamSpec('P')
|
|
@@ -30,6 +36,94 @@ def client_api(func: Callable[P, T]) -> Callable[P, T]:
|
|
|
30
36
|
return wrapper
|
|
31
37
|
|
|
32
38
|
|
|
39
|
+
def _register_functions_need_reload_cache(func: Callable) -> Callable:
|
|
40
|
+
"""Register a cachefunction that needs to be reloaded for a new request.
|
|
41
|
+
|
|
42
|
+
The function will be registered as a weak reference to avoid blocking GC.
|
|
43
|
+
"""
|
|
44
|
+
assert hasattr(func, 'cache_clear'), f'{func.__name__} is not cacheable'
|
|
45
|
+
wrapped_fn = func
|
|
46
|
+
try:
|
|
47
|
+
func_ref = weakref.ref(func)
|
|
48
|
+
except TypeError:
|
|
49
|
+
# The function might be not weakrefable (e.g. functools.lru_cache),
|
|
50
|
+
# wrap it in this case.
|
|
51
|
+
@functools.wraps(func)
|
|
52
|
+
def wrapper(*args, **kwargs):
|
|
53
|
+
return func(*args, **kwargs)
|
|
54
|
+
|
|
55
|
+
wrapper.cache_clear = func.cache_clear # type: ignore[attr-defined]
|
|
56
|
+
func_ref = weakref.ref(wrapper)
|
|
57
|
+
wrapped_fn = wrapper
|
|
58
|
+
with _FUNCTIONS_NEED_RELOAD_CACHE_LOCK:
|
|
59
|
+
_FUNCTIONS_NEED_RELOAD_CACHE.append(func_ref)
|
|
60
|
+
return wrapped_fn
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ThreadLocalTTLCache(threading.local):
|
|
64
|
+
"""Thread-local storage for _thread_local_lru_cache decorator."""
|
|
65
|
+
|
|
66
|
+
def __init__(self, func, maxsize: int, ttl: int):
|
|
67
|
+
super().__init__()
|
|
68
|
+
self.func = func
|
|
69
|
+
self.maxsize = maxsize
|
|
70
|
+
self.ttl = ttl
|
|
71
|
+
|
|
72
|
+
def get_cache(self):
|
|
73
|
+
if not hasattr(self, 'cache'):
|
|
74
|
+
self.cache = ttl_cache(scope='request',
|
|
75
|
+
maxsize=self.maxsize,
|
|
76
|
+
ttl=self.ttl,
|
|
77
|
+
timer=time.time)(self.func)
|
|
78
|
+
return self.cache
|
|
79
|
+
|
|
80
|
+
def __del__(self):
|
|
81
|
+
if hasattr(self, 'cache'):
|
|
82
|
+
self.cache.cache_clear()
|
|
83
|
+
self.cache = None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def thread_local_ttl_cache(maxsize=32, ttl=60 * 55):
|
|
87
|
+
"""Thread-local TTL cache decorator.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
maxsize: Maximum size of the cache.
|
|
91
|
+
ttl: Time to live for the cache in seconds.
|
|
92
|
+
Default is 55 minutes, a bit less than 1 hour
|
|
93
|
+
default lifetime of an STS token.
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
def decorator(func):
|
|
97
|
+
# Create thread-local storage for the LRU cache
|
|
98
|
+
local_cache = ThreadLocalTTLCache(func, maxsize, ttl)
|
|
99
|
+
|
|
100
|
+
# We can't apply the TTL cache here, because this runs at import time
|
|
101
|
+
# so we will always have the main thread's cache.
|
|
102
|
+
|
|
103
|
+
@functools.wraps(func)
|
|
104
|
+
def wrapper(*args, **kwargs):
|
|
105
|
+
# We are within the actual function call, which may be on a thread,
|
|
106
|
+
# so local_cache.cache will return the correct thread-local cache,
|
|
107
|
+
# which we can now apply and immediately call.
|
|
108
|
+
return local_cache.get_cache()(*args, **kwargs)
|
|
109
|
+
|
|
110
|
+
def cache_info():
|
|
111
|
+
# Note that this will only give the cache info for the current
|
|
112
|
+
# thread's cache.
|
|
113
|
+
return local_cache.get_cache().cache_info()
|
|
114
|
+
|
|
115
|
+
def cache_clear():
|
|
116
|
+
# Note that this will only clear the cache for the current thread.
|
|
117
|
+
local_cache.get_cache().cache_clear()
|
|
118
|
+
|
|
119
|
+
wrapper.cache_info = cache_info # type: ignore[attr-defined]
|
|
120
|
+
wrapper.cache_clear = cache_clear # type: ignore[attr-defined]
|
|
121
|
+
|
|
122
|
+
return wrapper
|
|
123
|
+
|
|
124
|
+
return decorator
|
|
125
|
+
|
|
126
|
+
|
|
33
127
|
def lru_cache(scope: Literal['global', 'request'], *lru_cache_args,
|
|
34
128
|
**lru_cache_kwargs) -> Callable:
|
|
35
129
|
"""LRU cache decorator for functions.
|
|
@@ -51,8 +145,7 @@ def lru_cache(scope: Literal['global', 'request'], *lru_cache_args,
|
|
|
51
145
|
else:
|
|
52
146
|
cached_func = functools.lru_cache(*lru_cache_args,
|
|
53
147
|
**lru_cache_kwargs)(func)
|
|
54
|
-
|
|
55
|
-
return cached_func
|
|
148
|
+
return _register_functions_need_reload_cache(cached_func)
|
|
56
149
|
|
|
57
150
|
return decorator
|
|
58
151
|
|
|
@@ -72,13 +165,20 @@ def ttl_cache(scope: Literal['global', 'request'], *ttl_cache_args,
|
|
|
72
165
|
else:
|
|
73
166
|
cached_func = cachetools.cached(
|
|
74
167
|
cachetools.TTLCache(*ttl_cache_args, **ttl_cache_kwargs))(func)
|
|
75
|
-
|
|
76
|
-
return cached_func
|
|
168
|
+
return _register_functions_need_reload_cache(cached_func)
|
|
77
169
|
|
|
78
170
|
return decorator
|
|
79
171
|
|
|
80
172
|
|
|
81
173
|
def clear_request_level_cache():
|
|
82
174
|
"""Clear the request-level cache."""
|
|
83
|
-
|
|
84
|
-
|
|
175
|
+
alive_entries = []
|
|
176
|
+
with _FUNCTIONS_NEED_RELOAD_CACHE_LOCK:
|
|
177
|
+
for entry in _FUNCTIONS_NEED_RELOAD_CACHE:
|
|
178
|
+
func = entry()
|
|
179
|
+
if func is None:
|
|
180
|
+
# Has been GC'ed, drop the reference.
|
|
181
|
+
continue
|
|
182
|
+
func.cache_clear()
|
|
183
|
+
alive_entries.append(entry)
|
|
184
|
+
_FUNCTIONS_NEED_RELOAD_CACHE[:] = alive_entries
|
sky/utils/command_runner.py
CHANGED
|
@@ -63,6 +63,22 @@ def _ssh_control_path(ssh_control_filename: Optional[str]) -> Optional[str]:
|
|
|
63
63
|
return path
|
|
64
64
|
|
|
65
65
|
|
|
66
|
+
def _is_skypilot_managed_key(key_path: str) -> bool:
|
|
67
|
+
"""Check if SSH key follows SkyPilot's managed key format.
|
|
68
|
+
|
|
69
|
+
SkyPilot-managed keys follow the pattern: ~/.sky/clients/<hash>/ssh/sky-key
|
|
70
|
+
External keys (like ~/.ssh/id_rsa) do not follow this pattern.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
key_path: Path to the SSH private key.
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
True if the key follows SkyPilot's managed format, False otherwise.
|
|
77
|
+
"""
|
|
78
|
+
parts = os.path.normpath(key_path).split(os.path.sep)
|
|
79
|
+
return len(parts) >= 2 and parts[-1] == 'sky-key' and parts[-2] == 'ssh'
|
|
80
|
+
|
|
81
|
+
|
|
66
82
|
# Disable sudo for root user. This is useful when the command is running in a
|
|
67
83
|
# docker container, i.e. image_id is a docker image.
|
|
68
84
|
ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD = (
|
|
@@ -603,7 +619,7 @@ class SSHCommandRunner(CommandRunner):
|
|
|
603
619
|
self,
|
|
604
620
|
node: Tuple[str, int],
|
|
605
621
|
ssh_user: str,
|
|
606
|
-
ssh_private_key: str,
|
|
622
|
+
ssh_private_key: Optional[str],
|
|
607
623
|
ssh_control_name: Optional[str] = '__default__',
|
|
608
624
|
ssh_proxy_command: Optional[str] = None,
|
|
609
625
|
docker_user: Optional[str] = None,
|
|
@@ -613,7 +629,7 @@ class SSHCommandRunner(CommandRunner):
|
|
|
613
629
|
"""Initialize SSHCommandRunner.
|
|
614
630
|
|
|
615
631
|
Example Usage:
|
|
616
|
-
runner = SSHCommandRunner(ip, ssh_user, ssh_private_key)
|
|
632
|
+
runner = SSHCommandRunner((ip, port), ssh_user, ssh_private_key)
|
|
617
633
|
runner.run('ls -l', mode=SshMode.NON_INTERACTIVE)
|
|
618
634
|
runner.rsync(source, target, up=True)
|
|
619
635
|
|
|
@@ -650,8 +666,17 @@ class SSHCommandRunner(CommandRunner):
|
|
|
650
666
|
self.disable_control_master = (
|
|
651
667
|
disable_control_master or
|
|
652
668
|
control_master_utils.should_disable_control_master())
|
|
653
|
-
#
|
|
654
|
-
|
|
669
|
+
# Ensure SSH key is available. For SkyPilot-managed keys, create from
|
|
670
|
+
# database. For external keys (e.g., Slurm clusters), verify existence.
|
|
671
|
+
if ssh_private_key is not None and _is_skypilot_managed_key(
|
|
672
|
+
ssh_private_key):
|
|
673
|
+
auth_utils.create_ssh_key_files_from_db(ssh_private_key)
|
|
674
|
+
elif ssh_private_key is not None:
|
|
675
|
+
# Externally managed key - just verify it exists
|
|
676
|
+
expanded_key_path = os.path.expanduser(ssh_private_key)
|
|
677
|
+
if not os.path.exists(expanded_key_path):
|
|
678
|
+
raise FileNotFoundError(
|
|
679
|
+
f'SSH private key not found: {expanded_key_path}')
|
|
655
680
|
if docker_user is not None:
|
|
656
681
|
assert port is None or port == 22, (
|
|
657
682
|
f'port must be None or 22 for docker_user, got {port}.')
|
|
@@ -867,6 +892,7 @@ class SSHCommandRunner(CommandRunner):
|
|
|
867
892
|
log_path: str = os.devnull,
|
|
868
893
|
stream_logs: bool = True,
|
|
869
894
|
max_retry: int = 1,
|
|
895
|
+
get_remote_home_dir: Callable[[], str] = lambda: '~',
|
|
870
896
|
) -> None:
|
|
871
897
|
"""Uses 'rsync' to sync 'source' to 'target'.
|
|
872
898
|
|
|
@@ -879,6 +905,8 @@ class SSHCommandRunner(CommandRunner):
|
|
|
879
905
|
stream_logs: Stream logs to the stdout/stderr.
|
|
880
906
|
max_retry: The maximum number of retries for the rsync command.
|
|
881
907
|
This value should be non-negative.
|
|
908
|
+
get_remote_home_dir: A callable that returns the remote home
|
|
909
|
+
directory. Defaults to '~'.
|
|
882
910
|
|
|
883
911
|
Raises:
|
|
884
912
|
exceptions.CommandError: rsync command failed.
|
|
@@ -903,7 +931,8 @@ class SSHCommandRunner(CommandRunner):
|
|
|
903
931
|
rsh_option=rsh_option,
|
|
904
932
|
log_path=log_path,
|
|
905
933
|
stream_logs=stream_logs,
|
|
906
|
-
max_retry=max_retry
|
|
934
|
+
max_retry=max_retry,
|
|
935
|
+
get_remote_home_dir=get_remote_home_dir)
|
|
907
936
|
|
|
908
937
|
|
|
909
938
|
class KubernetesCommandRunner(CommandRunner):
|
|
@@ -1247,3 +1276,166 @@ class LocalProcessCommandRunner(CommandRunner):
|
|
|
1247
1276
|
log_path=log_path,
|
|
1248
1277
|
stream_logs=stream_logs,
|
|
1249
1278
|
max_retry=max_retry)
|
|
1279
|
+
|
|
1280
|
+
|
|
1281
|
+
class SlurmCommandRunner(SSHCommandRunner):
|
|
1282
|
+
"""Runner for Slurm commands.
|
|
1283
|
+
|
|
1284
|
+
SlurmCommandRunner sends commands over an SSH connection through the Slurm
|
|
1285
|
+
controller, to the virtual instances.
|
|
1286
|
+
"""
|
|
1287
|
+
|
|
1288
|
+
def __init__(
|
|
1289
|
+
self,
|
|
1290
|
+
node: Tuple[str, int],
|
|
1291
|
+
ssh_user: str,
|
|
1292
|
+
ssh_private_key: Optional[str],
|
|
1293
|
+
*,
|
|
1294
|
+
sky_dir: str,
|
|
1295
|
+
skypilot_runtime_dir: str,
|
|
1296
|
+
job_id: str,
|
|
1297
|
+
slurm_node: str,
|
|
1298
|
+
**kwargs,
|
|
1299
|
+
):
|
|
1300
|
+
"""Initialize SlurmCommandRunner.
|
|
1301
|
+
|
|
1302
|
+
Example Usage:
|
|
1303
|
+
runner = SlurmCommandRunner(
|
|
1304
|
+
(ip, port),
|
|
1305
|
+
ssh_user,
|
|
1306
|
+
ssh_private_key,
|
|
1307
|
+
sky_dir=sky_dir,
|
|
1308
|
+
skypilot_runtime_dir=skypilot_runtime_dir,
|
|
1309
|
+
job_id=job_id,
|
|
1310
|
+
slurm_node=slurm_node)
|
|
1311
|
+
runner.run('ls -l', mode=SshMode.NON_INTERACTIVE)
|
|
1312
|
+
runner.rsync(source, target, up=True)
|
|
1313
|
+
|
|
1314
|
+
Args:
|
|
1315
|
+
node: (ip, port) The IP address and port of the remote machine
|
|
1316
|
+
(login node).
|
|
1317
|
+
ssh_user: SSH username.
|
|
1318
|
+
ssh_private_key: Path to SSH private key.
|
|
1319
|
+
sky_dir: The private directory for the SkyPilot cluster on the
|
|
1320
|
+
Slurm cluster.
|
|
1321
|
+
skypilot_runtime_dir: The directory for the SkyPilot runtime
|
|
1322
|
+
on the Slurm cluster.
|
|
1323
|
+
job_id: The Slurm job ID for this instance.
|
|
1324
|
+
slurm_node: The Slurm node hostname for this instance
|
|
1325
|
+
(compute node).
|
|
1326
|
+
**kwargs: Additional arguments forwarded to SSHCommandRunner
|
|
1327
|
+
(e.g., ssh_proxy_command).
|
|
1328
|
+
"""
|
|
1329
|
+
super().__init__(node, ssh_user, ssh_private_key, **kwargs)
|
|
1330
|
+
self.sky_dir = sky_dir
|
|
1331
|
+
self.skypilot_runtime_dir = skypilot_runtime_dir
|
|
1332
|
+
self.job_id = job_id
|
|
1333
|
+
self.slurm_node = slurm_node
|
|
1334
|
+
|
|
1335
|
+
# Build a chained ProxyCommand that goes through the login node to reach
|
|
1336
|
+
# the compute node where the job is running.
|
|
1337
|
+
|
|
1338
|
+
# First, build SSH options to reach the login node, using the user's
|
|
1339
|
+
# existing proxy command if provided.
|
|
1340
|
+
proxy_ssh_options = ' '.join(
|
|
1341
|
+
ssh_options_list(self.ssh_private_key,
|
|
1342
|
+
None,
|
|
1343
|
+
ssh_proxy_command=self._ssh_proxy_command,
|
|
1344
|
+
port=self.port,
|
|
1345
|
+
disable_control_master=True))
|
|
1346
|
+
login_node_proxy_command = (f'ssh {proxy_ssh_options} '
|
|
1347
|
+
f'-W %h:%p {self.ssh_user}@{self.ip}')
|
|
1348
|
+
|
|
1349
|
+
# Update the proxy command to be the login node proxy, which will
|
|
1350
|
+
# be used by super().run() to reach the compute node.
|
|
1351
|
+
self._ssh_proxy_command = login_node_proxy_command
|
|
1352
|
+
# Update self.ip to target the compute node.
|
|
1353
|
+
self.ip = slurm_node
|
|
1354
|
+
# Assume the compute node's SSH port is 22.
|
|
1355
|
+
# TODO(kevin): Make this configurable if needed.
|
|
1356
|
+
self.port = 22
|
|
1357
|
+
|
|
1358
|
+
def rsync(
|
|
1359
|
+
self,
|
|
1360
|
+
source: str,
|
|
1361
|
+
target: str,
|
|
1362
|
+
*,
|
|
1363
|
+
up: bool,
|
|
1364
|
+
log_path: str = os.devnull,
|
|
1365
|
+
stream_logs: bool = True,
|
|
1366
|
+
max_retry: int = 1,
|
|
1367
|
+
) -> None:
|
|
1368
|
+
"""Rsyncs files directly to the Slurm compute node,
|
|
1369
|
+
by proxying through the Slurm login node.
|
|
1370
|
+
|
|
1371
|
+
For Slurm, files need to be accessible by compute nodes where jobs
|
|
1372
|
+
execute via srun. This means either it has to be on the compute node's
|
|
1373
|
+
local filesystem, or on a shared filesystem.
|
|
1374
|
+
"""
|
|
1375
|
+
# TODO(kevin): We can probably optimize this to skip the proxying
|
|
1376
|
+
# if the target dir is in a shared filesystem, since it will
|
|
1377
|
+
# be accessible by the compute node.
|
|
1378
|
+
|
|
1379
|
+
# Build SSH options for rsync using the ProxyCommand set up in __init__
|
|
1380
|
+
# to reach the compute node through the login node.
|
|
1381
|
+
ssh_options = ' '.join(
|
|
1382
|
+
ssh_options_list(
|
|
1383
|
+
# Assume nothing and rely on default SSH behavior when -i is
|
|
1384
|
+
# not specified.
|
|
1385
|
+
None,
|
|
1386
|
+
None,
|
|
1387
|
+
ssh_proxy_command=self._ssh_proxy_command,
|
|
1388
|
+
disable_control_master=True))
|
|
1389
|
+
rsh_option = f'ssh {ssh_options}'
|
|
1390
|
+
|
|
1391
|
+
self._rsync(
|
|
1392
|
+
source,
|
|
1393
|
+
target,
|
|
1394
|
+
# Compute node
|
|
1395
|
+
node_destination=f'{self.ssh_user}@{self.slurm_node}',
|
|
1396
|
+
up=up,
|
|
1397
|
+
rsh_option=rsh_option,
|
|
1398
|
+
log_path=log_path,
|
|
1399
|
+
stream_logs=stream_logs,
|
|
1400
|
+
max_retry=max_retry,
|
|
1401
|
+
get_remote_home_dir=lambda: self.sky_dir)
|
|
1402
|
+
|
|
1403
|
+
@timeline.event
|
|
1404
|
+
@context_utils.cancellation_guard
|
|
1405
|
+
def run(self, cmd: Union[str, List[str]],
|
|
1406
|
+
**kwargs) -> Union[int, Tuple[int, str, str]]:
|
|
1407
|
+
"""Run Slurm-supported user commands over an SSH connection.
|
|
1408
|
+
|
|
1409
|
+
Args:
|
|
1410
|
+
cmd: The Slurm-supported user command to run.
|
|
1411
|
+
|
|
1412
|
+
Returns:
|
|
1413
|
+
returncode
|
|
1414
|
+
or
|
|
1415
|
+
A tuple of (returncode, stdout, stderr).
|
|
1416
|
+
"""
|
|
1417
|
+
# Override $HOME so that each SkyPilot cluster's state is isolated
|
|
1418
|
+
# from one another. We rely on the assumption that ~ is exclusively
|
|
1419
|
+
# used by a cluster, and in Slurm that is not the case, as $HOME
|
|
1420
|
+
# could be part of a shared filesystem.
|
|
1421
|
+
# And similarly for SKY_RUNTIME_DIR. See constants.\
|
|
1422
|
+
# SKY_RUNTIME_DIR_ENV_VAR_KEY for more details.
|
|
1423
|
+
#
|
|
1424
|
+
# SSH directly to the compute node instead of using srun.
|
|
1425
|
+
# This avoids Slurm's proctrack/cgroup which kills all processes
|
|
1426
|
+
# when the job step ends (including child processes launched as
|
|
1427
|
+
# a separate process group), breaking background process spawning
|
|
1428
|
+
# (e.g., JobScheduler._run_job which uses launch_new_process_tree).
|
|
1429
|
+
# Note: proctrack/cgroup is enabled by default on Nebius'
|
|
1430
|
+
# Managed Soperator.
|
|
1431
|
+
cmd = (
|
|
1432
|
+
f'export {constants.SKY_RUNTIME_DIR_ENV_VAR_KEY}='
|
|
1433
|
+
f'"{self.skypilot_runtime_dir}" && '
|
|
1434
|
+
# Set the uv cache directory to /tmp/uv_cache_$(id -u) to speed up
|
|
1435
|
+
# package installation while avoiding permission conflicts when
|
|
1436
|
+
# multiple users share the same host. Otherwise it defaults to
|
|
1437
|
+
# ~/.cache/uv.
|
|
1438
|
+
f'export UV_CACHE_DIR=/tmp/uv_cache_$(id -u) && '
|
|
1439
|
+
f'cd {self.sky_dir} && export HOME=$(pwd) && {cmd}')
|
|
1440
|
+
|
|
1441
|
+
return super().run(cmd, **kwargs)
|
sky/utils/command_runner.pyi
CHANGED
|
@@ -6,7 +6,7 @@ determine the return type based on the value of require_outputs.
|
|
|
6
6
|
"""
|
|
7
7
|
import enum
|
|
8
8
|
import typing
|
|
9
|
-
from typing import Any, Iterable, List, Optional, Tuple, Union
|
|
9
|
+
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
|
|
10
10
|
|
|
11
11
|
from typing_extensions import Literal
|
|
12
12
|
|
|
@@ -130,7 +130,7 @@ class SSHCommandRunner(CommandRunner):
|
|
|
130
130
|
ip: str
|
|
131
131
|
port: int
|
|
132
132
|
ssh_user: str
|
|
133
|
-
ssh_private_key: str
|
|
133
|
+
ssh_private_key: Optional[str]
|
|
134
134
|
ssh_control_name: Optional[str]
|
|
135
135
|
docker_user: str
|
|
136
136
|
disable_control_master: Optional[bool]
|
|
@@ -140,7 +140,7 @@ class SSHCommandRunner(CommandRunner):
|
|
|
140
140
|
self,
|
|
141
141
|
node: Tuple[str, int],
|
|
142
142
|
ssh_user: str,
|
|
143
|
-
ssh_private_key: str,
|
|
143
|
+
ssh_private_key: Optional[str],
|
|
144
144
|
ssh_control_name: Optional[str] = ...,
|
|
145
145
|
ssh_proxy_command: Optional[str] = ...,
|
|
146
146
|
docker_user: Optional[str] = ...,
|
|
@@ -216,7 +216,8 @@ class SSHCommandRunner(CommandRunner):
|
|
|
216
216
|
up: bool,
|
|
217
217
|
log_path: str = ...,
|
|
218
218
|
stream_logs: bool = ...,
|
|
219
|
-
max_retry: int =
|
|
219
|
+
max_retry: int = ...,
|
|
220
|
+
get_remote_home_dir: Callable[[], str] = ...) -> None:
|
|
220
221
|
...
|
|
221
222
|
|
|
222
223
|
def port_forward_command(
|
|
@@ -306,6 +307,28 @@ class KubernetesCommandRunner(CommandRunner):
|
|
|
306
307
|
...
|
|
307
308
|
|
|
308
309
|
|
|
310
|
+
class SlurmCommandRunner(SSHCommandRunner):
|
|
311
|
+
"""Runner for Slurm commands."""
|
|
312
|
+
sky_dir: str
|
|
313
|
+
skypilot_runtime_dir: str
|
|
314
|
+
job_id: str
|
|
315
|
+
slurm_node: str
|
|
316
|
+
|
|
317
|
+
def __init__(
|
|
318
|
+
self,
|
|
319
|
+
node: Tuple[str, int],
|
|
320
|
+
ssh_user: str,
|
|
321
|
+
ssh_private_key: Optional[str],
|
|
322
|
+
*,
|
|
323
|
+
sky_dir: str,
|
|
324
|
+
skypilot_runtime_dir: str,
|
|
325
|
+
job_id: str,
|
|
326
|
+
slurm_node: str,
|
|
327
|
+
**kwargs,
|
|
328
|
+
) -> None:
|
|
329
|
+
...
|
|
330
|
+
|
|
331
|
+
|
|
309
332
|
class LocalProcessCommandRunner(CommandRunner):
|
|
310
333
|
|
|
311
334
|
def __init__(self) -> None:
|
sky/utils/common_utils.py
CHANGED
|
@@ -300,6 +300,7 @@ _current_user: Optional['models.User'] = None
|
|
|
300
300
|
_current_request_id: Optional[str] = None
|
|
301
301
|
|
|
302
302
|
|
|
303
|
+
# TODO(aylei,hailong): request context should be contextual
|
|
303
304
|
def set_request_context(client_entrypoint: Optional[str],
|
|
304
305
|
client_command: Optional[str],
|
|
305
306
|
using_remote_api_server: bool,
|
|
@@ -341,19 +342,32 @@ def get_current_command() -> str:
|
|
|
341
342
|
|
|
342
343
|
|
|
343
344
|
def get_current_user() -> 'models.User':
|
|
344
|
-
"""Returns the current
|
|
345
|
+
"""Returns the user in current server session."""
|
|
345
346
|
if _current_user is not None:
|
|
346
347
|
return _current_user
|
|
347
348
|
return models.User.get_current_user()
|
|
348
349
|
|
|
349
350
|
|
|
350
351
|
def get_current_user_name() -> str:
|
|
351
|
-
"""Returns the current
|
|
352
|
+
"""Returns the user name in current server session."""
|
|
352
353
|
name = get_current_user().name
|
|
353
354
|
assert name is not None
|
|
354
355
|
return name
|
|
355
356
|
|
|
356
357
|
|
|
358
|
+
def get_local_user_name() -> str:
|
|
359
|
+
"""Returns the user name in local environment.
|
|
360
|
+
|
|
361
|
+
This is for backward compatibility where anonymous access is implicitly
|
|
362
|
+
allowed when no authentication method at server-side is configured and
|
|
363
|
+
the username from client environment variable will be used to identify the
|
|
364
|
+
user.
|
|
365
|
+
"""
|
|
366
|
+
name = os.getenv(constants.USER_ENV_VAR, getpass.getuser())
|
|
367
|
+
assert name is not None
|
|
368
|
+
return name
|
|
369
|
+
|
|
370
|
+
|
|
357
371
|
def set_current_user(user: 'models.User'):
|
|
358
372
|
"""Sets the current user."""
|
|
359
373
|
global _current_user
|
|
@@ -724,7 +738,8 @@ def find_free_port(start_port: int) -> int:
|
|
|
724
738
|
try:
|
|
725
739
|
s.bind(('', port))
|
|
726
740
|
return port
|
|
727
|
-
except OSError:
|
|
741
|
+
except OSError as e:
|
|
742
|
+
logger.debug(f'Error binding port {port}: {e}')
|
|
728
743
|
pass
|
|
729
744
|
raise OSError('No free ports available.')
|
|
730
745
|
|