skypilot-nightly 1.0.0.dev20251203__py3-none-any.whl → 1.0.0.dev20251210__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. sky/__init__.py +4 -2
  2. sky/adaptors/aws.py +1 -61
  3. sky/adaptors/slurm.py +478 -0
  4. sky/backends/backend_utils.py +45 -4
  5. sky/backends/cloud_vm_ray_backend.py +32 -33
  6. sky/backends/task_codegen.py +340 -2
  7. sky/catalog/__init__.py +0 -3
  8. sky/catalog/kubernetes_catalog.py +12 -4
  9. sky/catalog/slurm_catalog.py +243 -0
  10. sky/check.py +14 -3
  11. sky/client/cli/command.py +329 -22
  12. sky/client/sdk.py +56 -2
  13. sky/clouds/__init__.py +2 -0
  14. sky/clouds/cloud.py +7 -0
  15. sky/clouds/slurm.py +578 -0
  16. sky/clouds/ssh.py +2 -1
  17. sky/clouds/vast.py +10 -0
  18. sky/core.py +128 -36
  19. sky/dashboard/out/404.html +1 -1
  20. sky/dashboard/out/_next/static/KYAhEFa3FTfq4JyKVgo-s/_buildManifest.js +1 -0
  21. sky/dashboard/out/_next/static/chunks/3294.ddda8c6c6f9f24dc.js +1 -0
  22. sky/dashboard/out/_next/static/chunks/3850-fd5696f3bbbaddae.js +1 -0
  23. sky/dashboard/out/_next/static/chunks/6856-da20c5fd999f319c.js +1 -0
  24. sky/dashboard/out/_next/static/chunks/6990-09cbf02d3cd518c3.js +1 -0
  25. sky/dashboard/out/_next/static/chunks/9353-8369df1cf105221c.js +1 -0
  26. sky/dashboard/out/_next/static/chunks/pages/_app-68b647e26f9d2793.js +34 -0
  27. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-33f525539665fdfd.js +16 -0
  28. sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-abfcac9c137aa543.js → [cluster]-a7565f586ef86467.js} +1 -1
  29. sky/dashboard/out/_next/static/chunks/pages/{clusters-ee39056f9851a3ff.js → clusters-9e5d47818b9bdadd.js} +1 -1
  30. sky/dashboard/out/_next/static/chunks/pages/{config-dfb9bf07b13045f4.js → config-718cdc365de82689.js} +1 -1
  31. sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-c0b5935149902e6f.js → [context]-12c559ec4d81fdbd.js} +1 -1
  32. sky/dashboard/out/_next/static/chunks/pages/{infra-aed0ea19df7cf961.js → infra-d187cd0413d72475.js} +1 -1
  33. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-895847b6cf200b04.js +16 -0
  34. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/{[pool]-9faf940b253e3e06.js → [pool]-8d0f4655400b4eb9.js} +2 -2
  35. sky/dashboard/out/_next/static/chunks/pages/{jobs-2072b48b617989c9.js → jobs-e5a98f17f8513a96.js} +1 -1
  36. sky/dashboard/out/_next/static/chunks/pages/plugins/[...slug]-4f46050ca065d8f8.js +1 -0
  37. sky/dashboard/out/_next/static/chunks/pages/{users-f42674164aa73423.js → users-2f7646eb77785a2c.js} +1 -1
  38. sky/dashboard/out/_next/static/chunks/pages/{volumes-b84b948ff357c43e.js → volumes-ef19d49c6d0e8500.js} +1 -1
  39. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-84a40f8c7c627fe4.js → [name]-96e0f298308da7e2.js} +1 -1
  40. sky/dashboard/out/_next/static/chunks/pages/{workspaces-531b2f8c4bf89f82.js → workspaces-cb4da3abe08ebf19.js} +1 -1
  41. sky/dashboard/out/_next/static/chunks/{webpack-64e05f17bf2cf8ce.js → webpack-fba3de387ff6bb08.js} +1 -1
  42. sky/dashboard/out/_next/static/css/c5a4cfd2600fc715.css +3 -0
  43. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  44. sky/dashboard/out/clusters/[cluster].html +1 -1
  45. sky/dashboard/out/clusters.html +1 -1
  46. sky/dashboard/out/config.html +1 -1
  47. sky/dashboard/out/index.html +1 -1
  48. sky/dashboard/out/infra/[context].html +1 -1
  49. sky/dashboard/out/infra.html +1 -1
  50. sky/dashboard/out/jobs/[job].html +1 -1
  51. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  52. sky/dashboard/out/jobs.html +1 -1
  53. sky/dashboard/out/plugins/[...slug].html +1 -0
  54. sky/dashboard/out/users.html +1 -1
  55. sky/dashboard/out/volumes.html +1 -1
  56. sky/dashboard/out/workspace/new.html +1 -1
  57. sky/dashboard/out/workspaces/[name].html +1 -1
  58. sky/dashboard/out/workspaces.html +1 -1
  59. sky/data/mounting_utils.py +16 -2
  60. sky/global_user_state.py +3 -3
  61. sky/models.py +2 -0
  62. sky/optimizer.py +6 -5
  63. sky/provision/__init__.py +1 -0
  64. sky/provision/common.py +20 -0
  65. sky/provision/docker_utils.py +15 -2
  66. sky/provision/kubernetes/utils.py +42 -6
  67. sky/provision/provisioner.py +15 -6
  68. sky/provision/slurm/__init__.py +12 -0
  69. sky/provision/slurm/config.py +13 -0
  70. sky/provision/slurm/instance.py +572 -0
  71. sky/provision/slurm/utils.py +583 -0
  72. sky/provision/vast/instance.py +4 -1
  73. sky/provision/vast/utils.py +10 -6
  74. sky/serve/server/impl.py +1 -1
  75. sky/server/constants.py +1 -1
  76. sky/server/plugins.py +222 -0
  77. sky/server/requests/executor.py +5 -2
  78. sky/server/requests/payloads.py +12 -1
  79. sky/server/requests/request_names.py +2 -0
  80. sky/server/requests/requests.py +5 -1
  81. sky/server/requests/serializers/encoders.py +17 -0
  82. sky/server/requests/serializers/return_value_serializers.py +60 -0
  83. sky/server/server.py +78 -8
  84. sky/server/server_utils.py +30 -0
  85. sky/setup_files/dependencies.py +2 -0
  86. sky/skylet/attempt_skylet.py +13 -3
  87. sky/skylet/constants.py +34 -9
  88. sky/skylet/events.py +10 -4
  89. sky/skylet/executor/__init__.py +1 -0
  90. sky/skylet/executor/slurm.py +189 -0
  91. sky/skylet/job_lib.py +2 -1
  92. sky/skylet/log_lib.py +22 -6
  93. sky/skylet/log_lib.pyi +8 -6
  94. sky/skylet/skylet.py +5 -1
  95. sky/skylet/subprocess_daemon.py +2 -1
  96. sky/ssh_node_pools/constants.py +12 -0
  97. sky/ssh_node_pools/core.py +40 -3
  98. sky/ssh_node_pools/deploy/__init__.py +4 -0
  99. sky/{utils/kubernetes/deploy_ssh_node_pools.py → ssh_node_pools/deploy/deploy.py} +279 -504
  100. sky/ssh_node_pools/deploy/tunnel_utils.py +199 -0
  101. sky/ssh_node_pools/deploy/utils.py +173 -0
  102. sky/ssh_node_pools/server.py +11 -13
  103. sky/{utils/kubernetes/ssh_utils.py → ssh_node_pools/utils.py} +9 -6
  104. sky/templates/kubernetes-ray.yml.j2 +8 -0
  105. sky/templates/slurm-ray.yml.j2 +85 -0
  106. sky/templates/vast-ray.yml.j2 +1 -0
  107. sky/users/model.conf +1 -1
  108. sky/users/permission.py +24 -1
  109. sky/users/rbac.py +31 -3
  110. sky/utils/annotations.py +108 -8
  111. sky/utils/command_runner.py +197 -5
  112. sky/utils/command_runner.pyi +27 -4
  113. sky/utils/common_utils.py +18 -3
  114. sky/utils/kubernetes/kubernetes_deploy_utils.py +2 -94
  115. sky/utils/kubernetes/ssh-tunnel.sh +7 -376
  116. sky/utils/schemas.py +31 -0
  117. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/METADATA +48 -36
  118. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/RECORD +125 -107
  119. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +0 -1
  120. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +0 -1
  121. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +0 -1
  122. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +0 -1
  123. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +0 -1
  124. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +0 -1
  125. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +0 -34
  126. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +0 -16
  127. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +0 -16
  128. sky/dashboard/out/_next/static/css/0748ce22df867032.css +0 -3
  129. sky/utils/kubernetes/cleanup-tunnel.sh +0 -62
  130. /sky/dashboard/out/_next/static/{96_E2yl3QAiIJGOYCkSpB → KYAhEFa3FTfq4JyKVgo-s}/_ssgManifest.js +0 -0
  131. /sky/dashboard/out/_next/static/chunks/{1141-e6aa9ab418717c59.js → 1141-9c810f01ff4f398a.js} +0 -0
  132. /sky/dashboard/out/_next/static/chunks/{3800-7b45f9fbb6308557.js → 3800-b589397dc09c5b4e.js} +0 -0
  133. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/WHEEL +0 -0
  134. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/entry_points.txt +0 -0
  135. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/licenses/LICENSE +0 -0
  136. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20251210.dist-info}/top_level.txt +0 -0
sky/users/rbac.py CHANGED
@@ -1,7 +1,7 @@
 """RBAC (Role-Based Access Control) functionality for SkyPilot API Server."""
 
 import enum
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 from sky import sky_logging
 from sky import skypilot_config
@@ -55,8 +55,13 @@ def get_default_role() -> str:
 
 
 def get_role_permissions(
+        plugin_rules: Optional[Dict[str, List[Dict[str, str]]]] = None
 ) -> Dict[str, Dict[str, Dict[str, List[Dict[str, str]]]]]:
-    """Get all role permissions from config.
+    """Get all role permissions from config and plugins.
+
+    Args:
+        plugin_rules: Optional dictionary of plugin RBAC rules to merge.
+            Format: {'user': [{'path': '...', 'method': '...'}]}
 
     Returns:
         Dictionary containing all roles and their permissions configuration.
@@ -91,9 +96,32 @@ def get_role_permissions(
     if 'user' not in config_permissions:
         config_permissions['user'] = {
             'permissions': {
-                'blocklist': _DEFAULT_USER_BLOCKLIST
+                'blocklist': _DEFAULT_USER_BLOCKLIST.copy()
             }
         }
+
+    # Merge plugin rules into the appropriate roles
+    if plugin_rules:
+        for role, rules in plugin_rules.items():
+            if role not in supported_roles:
+                logger.warning(f'Plugin specified invalid role: {role}')
+                continue
+            if role not in config_permissions:
+                config_permissions[role] = {'permissions': {'blocklist': []}}
+            if 'permissions' not in config_permissions[role]:
+                config_permissions[role]['permissions'] = {'blocklist': []}
+            if 'blocklist' not in config_permissions[role]['permissions']:
+                config_permissions[role]['permissions']['blocklist'] = []
+
+            # Merge plugin rules, avoiding duplicates
+            existing_rules = config_permissions[role]['permissions'][
+                'blocklist']
+            for rule in rules:
+                if rule not in existing_rules:
+                    existing_rules.append(rule)
+                    logger.debug(f'Added plugin RBAC rule for {role}: '
+                                 f'{rule["method"]} {rule["path"]}')
+
     return config_permissions
 
 
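For context, a minimal sketch of how the new plugin_rules parameter is consumed. Only the function signature and the rule format come from the diff; the plugin path and method below are made up for illustration:

from sky.users import rbac

# Hypothetical plugin rule; the dict format matches the docstring above.
plugin_rules = {
    'user': [{
        'path': '/plugins/example/admin',
        'method': 'POST'
    }],
}

permissions = rbac.get_role_permissions(plugin_rules=plugin_rules)
# The plugin rule is appended to the 'user' role's blocklist, with
# duplicate rules skipped.
print(permissions['user']['permissions']['blocklist'])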
sky/utils/annotations.py CHANGED
@@ -1,14 +1,20 @@
 """Annotations for public APIs."""
 
 import functools
-from typing import Callable, Literal, TypeVar
+import threading
+import time
+from typing import Callable, List, Literal, TypeVar
+import weakref
 
 import cachetools
 from typing_extensions import ParamSpec
 
 # Whether the current process is a SkyPilot API server process.
 is_on_api_server = True
-_FUNCTIONS_NEED_RELOAD_CACHE = []
+_FUNCTIONS_NEED_RELOAD_CACHE_LOCK = threading.Lock()
+# Caches can be thread-local, use weakref to avoid blocking the GC when the
+# thread is destroyed.
+_FUNCTIONS_NEED_RELOAD_CACHE: List[weakref.ReferenceType] = []
 
 T = TypeVar('T')
 P = ParamSpec('P')
@@ -30,6 +36,94 @@ def client_api(func: Callable[P, T]) -> Callable[P, T]:
     return wrapper
 
 
+def _register_functions_need_reload_cache(func: Callable) -> Callable:
+    """Register a cached function that needs to be reloaded for a new request.
+
+    The function will be registered as a weak reference to avoid blocking GC.
+    """
+    assert hasattr(func, 'cache_clear'), f'{func.__name__} is not cacheable'
+    wrapped_fn = func
+    try:
+        func_ref = weakref.ref(func)
+    except TypeError:
+        # The function might be not weakrefable (e.g. functools.lru_cache),
+        # wrap it in this case.
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        wrapper.cache_clear = func.cache_clear  # type: ignore[attr-defined]
+        func_ref = weakref.ref(wrapper)
+        wrapped_fn = wrapper
+    with _FUNCTIONS_NEED_RELOAD_CACHE_LOCK:
+        _FUNCTIONS_NEED_RELOAD_CACHE.append(func_ref)
+    return wrapped_fn
+
+
+class ThreadLocalTTLCache(threading.local):
+    """Thread-local storage for _thread_local_lru_cache decorator."""
+
+    def __init__(self, func, maxsize: int, ttl: int):
+        super().__init__()
+        self.func = func
+        self.maxsize = maxsize
+        self.ttl = ttl
+
+    def get_cache(self):
+        if not hasattr(self, 'cache'):
+            self.cache = ttl_cache(scope='request',
+                                   maxsize=self.maxsize,
+                                   ttl=self.ttl,
+                                   timer=time.time)(self.func)
+        return self.cache
+
+    def __del__(self):
+        if hasattr(self, 'cache'):
+            self.cache.cache_clear()
+            self.cache = None
+
+
+def thread_local_ttl_cache(maxsize=32, ttl=60 * 55):
+    """Thread-local TTL cache decorator.
+
+    Args:
+        maxsize: Maximum size of the cache.
+        ttl: Time to live for the cache in seconds.
+            Default is 55 minutes, a bit less than 1 hour
+            default lifetime of an STS token.
+    """
+
+    def decorator(func):
+        # Create thread-local storage for the LRU cache
+        local_cache = ThreadLocalTTLCache(func, maxsize, ttl)
+
+        # We can't apply the lru_cache here, because this runs at import time
+        # so we will always have the main thread's cache.
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            # We are within the actual function call, which may be on a thread,
+            # so local_cache.cache will return the correct thread-local cache,
+            # which we can now apply and immediately call.
+            return local_cache.get_cache()(*args, **kwargs)
+
+        def cache_info():
+            # Note that this will only give the cache info for the current
+            # thread's cache.
+            return local_cache.get_cache().cache_info()
+
+        def cache_clear():
+            # Note that this will only clear the cache for the current thread.
+            local_cache.get_cache().cache_clear()
+
+        wrapper.cache_info = cache_info  # type: ignore[attr-defined]
+        wrapper.cache_clear = cache_clear  # type: ignore[attr-defined]
+
+        return wrapper
+
+    return decorator
+
+
 def lru_cache(scope: Literal['global', 'request'], *lru_cache_args,
               **lru_cache_kwargs) -> Callable:
     """LRU cache decorator for functions.
@@ -51,8 +145,7 @@ def lru_cache(scope: Literal['global', 'request'], *lru_cache_args,
         else:
            cached_func = functools.lru_cache(*lru_cache_args,
                                              **lru_cache_kwargs)(func)
-            _FUNCTIONS_NEED_RELOAD_CACHE.append(cached_func)
-            return cached_func
+            return _register_functions_need_reload_cache(cached_func)
 
     return decorator
 
@@ -72,13 +165,20 @@ def ttl_cache(scope: Literal['global', 'request'], *ttl_cache_args,
         else:
            cached_func = cachetools.cached(
                cachetools.TTLCache(*ttl_cache_args, **ttl_cache_kwargs))(func)
-            _FUNCTIONS_NEED_RELOAD_CACHE.append(cached_func)
-            return cached_func
+            return _register_functions_need_reload_cache(cached_func)
 
     return decorator
 
 
 def clear_request_level_cache():
     """Clear the request-level cache."""
-    for func in _FUNCTIONS_NEED_RELOAD_CACHE:
-        func.cache_clear()
+    alive_entries = []
+    with _FUNCTIONS_NEED_RELOAD_CACHE_LOCK:
+        for entry in _FUNCTIONS_NEED_RELOAD_CACHE:
+            func = entry()
+            if func is None:
+                # Has been GC'ed, drop the reference.
+                continue
+            func.cache_clear()
+            alive_entries.append(entry)
+        _FUNCTIONS_NEED_RELOAD_CACHE[:] = alive_entries
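A short usage sketch of the new thread_local_ttl_cache decorator. The decorator and its maxsize/ttl arguments come from the diff above; the decorated function and its body are invented for illustration:

import threading

from sky.utils import annotations


@annotations.thread_local_ttl_cache(maxsize=8, ttl=60 * 55)
def fetch_token(profile: str) -> str:
    # Stand-in for an expensive call (e.g. fetching an STS token); the
    # result is cached per thread for up to ~55 minutes.
    return f'token-for-{profile}'


def worker():
    fetch_token('default')
    fetch_token('default')  # Second call is served from this thread's cache.


threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

Each thread lazily builds its own TTL cache on first use, so cache_clear() and cache_info() on the wrapper only affect the calling thread's cache.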
sky/utils/command_runner.py CHANGED
@@ -63,6 +63,22 @@ def _ssh_control_path(ssh_control_filename: Optional[str]) -> Optional[str]:
     return path
 
 
+def _is_skypilot_managed_key(key_path: str) -> bool:
+    """Check if SSH key follows SkyPilot's managed key format.
+
+    SkyPilot-managed keys follow the pattern: ~/.sky/clients/<hash>/ssh/sky-key
+    External keys (like ~/.ssh/id_rsa) do not follow this pattern.
+
+    Args:
+        key_path: Path to the SSH private key.
+
+    Returns:
+        True if the key follows SkyPilot's managed format, False otherwise.
+    """
+    parts = os.path.normpath(key_path).split(os.path.sep)
+    return len(parts) >= 2 and parts[-1] == 'sky-key' and parts[-2] == 'ssh'
+
+
 # Disable sudo for root user. This is useful when the command is running in a
 # docker container, i.e. image_id is a docker image.
 ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD = (
@@ -603,7 +619,7 @@ class SSHCommandRunner(CommandRunner):
         self,
         node: Tuple[str, int],
         ssh_user: str,
-        ssh_private_key: str,
+        ssh_private_key: Optional[str],
         ssh_control_name: Optional[str] = '__default__',
         ssh_proxy_command: Optional[str] = None,
         docker_user: Optional[str] = None,
@@ -613,7 +629,7 @@
         """Initialize SSHCommandRunner.
 
         Example Usage:
-            runner = SSHCommandRunner(ip, ssh_user, ssh_private_key)
+            runner = SSHCommandRunner((ip, port), ssh_user, ssh_private_key)
             runner.run('ls -l', mode=SshMode.NON_INTERACTIVE)
             runner.rsync(source, target, up=True)
 
@@ -650,8 +666,17 @@
         self.disable_control_master = (
             disable_control_master or
             control_master_utils.should_disable_control_master())
-        # ensure the ssh key files are created from the database
-        auth_utils.create_ssh_key_files_from_db(ssh_private_key)
+        # Ensure SSH key is available. For SkyPilot-managed keys, create from
+        # database. For external keys (e.g., Slurm clusters), verify existence.
+        if ssh_private_key is not None and _is_skypilot_managed_key(
+                ssh_private_key):
+            auth_utils.create_ssh_key_files_from_db(ssh_private_key)
+        elif ssh_private_key is not None:
+            # Externally managed key - just verify it exists
+            expanded_key_path = os.path.expanduser(ssh_private_key)
+            if not os.path.exists(expanded_key_path):
+                raise FileNotFoundError(
+                    f'SSH private key not found: {expanded_key_path}')
         if docker_user is not None:
             assert port is None or port == 22, (
                 f'port must be None or 22 for docker_user, got {port}.')
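The new helper's behavior can be illustrated with a few standalone checks (paths below are examples, not from the diff):

import os


def _is_skypilot_managed_key(key_path: str) -> bool:
    # Same check as the helper added in this file: a managed key always
    # ends in .../ssh/sky-key, e.g. ~/.sky/clients/<hash>/ssh/sky-key.
    parts = os.path.normpath(key_path).split(os.path.sep)
    return len(parts) >= 2 and parts[-1] == 'sky-key' and parts[-2] == 'ssh'


assert _is_skypilot_managed_key('/home/alice/.sky/clients/abc123/ssh/sky-key')
assert not _is_skypilot_managed_key('/home/alice/.ssh/id_rsa')
assert not _is_skypilot_managed_key('/etc/slurm/keys/sky-key')  # not under 'ssh/'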
@@ -867,6 +892,7 @@
         log_path: str = os.devnull,
         stream_logs: bool = True,
         max_retry: int = 1,
+        get_remote_home_dir: Callable[[], str] = lambda: '~',
     ) -> None:
         """Uses 'rsync' to sync 'source' to 'target'.
 
@@ -879,6 +905,8 @@
             stream_logs: Stream logs to the stdout/stderr.
             max_retry: The maximum number of retries for the rsync command.
                 This value should be non-negative.
+            get_remote_home_dir: A callable that returns the remote home
+                directory. Defaults to '~'.
 
         Raises:
             exceptions.CommandError: rsync command failed.
@@ -903,7 +931,8 @@
                     rsh_option=rsh_option,
                     log_path=log_path,
                     stream_logs=stream_logs,
-                    max_retry=max_retry)
+                    max_retry=max_retry,
+                    get_remote_home_dir=get_remote_home_dir)
 
 
 class KubernetesCommandRunner(CommandRunner):
@@ -1247,3 +1276,166 @@ class LocalProcessCommandRunner(CommandRunner):
                     log_path=log_path,
                     stream_logs=stream_logs,
                     max_retry=max_retry)
+
+
+class SlurmCommandRunner(SSHCommandRunner):
+    """Runner for Slurm commands.
+
+    SlurmCommandRunner sends commands over an SSH connection through the Slurm
+    controller, to the virtual instances.
+    """
+
+    def __init__(
+        self,
+        node: Tuple[str, int],
+        ssh_user: str,
+        ssh_private_key: Optional[str],
+        *,
+        sky_dir: str,
+        skypilot_runtime_dir: str,
+        job_id: str,
+        slurm_node: str,
+        **kwargs,
+    ):
+        """Initialize SlurmCommandRunner.
+
+        Example Usage:
+            runner = SlurmCommandRunner(
+                (ip, port),
+                ssh_user,
+                ssh_private_key,
+                sky_dir=sky_dir,
+                skypilot_runtime_dir=skypilot_runtime_dir,
+                job_id=job_id,
+                slurm_node=slurm_node)
+            runner.run('ls -l', mode=SshMode.NON_INTERACTIVE)
+            runner.rsync(source, target, up=True)
+
+        Args:
+            node: (ip, port) The IP address and port of the remote machine
+                (login node).
+            ssh_user: SSH username.
+            ssh_private_key: Path to SSH private key.
+            sky_dir: The private directory for the SkyPilot cluster on the
+                Slurm cluster.
+            skypilot_runtime_dir: The directory for the SkyPilot runtime
+                on the Slurm cluster.
+            job_id: The Slurm job ID for this instance.
+            slurm_node: The Slurm node hostname for this instance
+                (compute node).
+            **kwargs: Additional arguments forwarded to SSHCommandRunner
+                (e.g., ssh_proxy_command).
+        """
+        super().__init__(node, ssh_user, ssh_private_key, **kwargs)
+        self.sky_dir = sky_dir
+        self.skypilot_runtime_dir = skypilot_runtime_dir
+        self.job_id = job_id
+        self.slurm_node = slurm_node
+
+        # Build a chained ProxyCommand that goes through the login node to reach
+        # the compute node where the job is running.
+
+        # First, build SSH options to reach the login node, using the user's
+        # existing proxy command if provided.
+        proxy_ssh_options = ' '.join(
+            ssh_options_list(self.ssh_private_key,
+                             None,
+                             ssh_proxy_command=self._ssh_proxy_command,
+                             port=self.port,
+                             disable_control_master=True))
+        login_node_proxy_command = (f'ssh {proxy_ssh_options} '
+                                    f'-W %h:%p {self.ssh_user}@{self.ip}')
+
+        # Update the proxy command to be the login node proxy, which will
+        # be used by super().run() to reach the compute node.
+        self._ssh_proxy_command = login_node_proxy_command
+        # Update self.ip to target the compute node.
+        self.ip = slurm_node
+        # Assume the compute node's SSH port is 22.
+        # TODO(kevin): Make this configurable if needed.
+        self.port = 22
+
+    def rsync(
+        self,
+        source: str,
+        target: str,
+        *,
+        up: bool,
+        log_path: str = os.devnull,
+        stream_logs: bool = True,
+        max_retry: int = 1,
+    ) -> None:
+        """Rsyncs files directly to the Slurm compute node,
+        by proxying through the Slurm login node.
+
+        For Slurm, files need to be accessible by compute nodes where jobs
+        execute via srun. This means either it has to be on the compute node's
+        local filesystem, or on a shared filesystem.
+        """
+        # TODO(kevin): We can probably optimize this to skip the proxying
+        # if the target dir is in a shared filesystem, since it will
+        # be accessible by the compute node.
+
+        # Build SSH options for rsync using the ProxyCommand set up in __init__
+        # to reach the compute node through the login node.
+        ssh_options = ' '.join(
+            ssh_options_list(
+                # Assume nothing and rely on default SSH behavior when -i is
+                # not specified.
+                None,
+                None,
+                ssh_proxy_command=self._ssh_proxy_command,
+                disable_control_master=True))
+        rsh_option = f'ssh {ssh_options}'
+
+        self._rsync(
+            source,
+            target,
+            # Compute node
+            node_destination=f'{self.ssh_user}@{self.slurm_node}',
+            up=up,
+            rsh_option=rsh_option,
+            log_path=log_path,
+            stream_logs=stream_logs,
+            max_retry=max_retry,
+            get_remote_home_dir=lambda: self.sky_dir)
+
+    @timeline.event
+    @context_utils.cancellation_guard
+    def run(self, cmd: Union[str, List[str]],
+            **kwargs) -> Union[int, Tuple[int, str, str]]:
+        """Run Slurm-supported user commands over an SSH connection.
+
+        Args:
+            cmd: The Slurm-supported user command to run.
+
+        Returns:
+            returncode
+            or
+            A tuple of (returncode, stdout, stderr).
+        """
+        # Override $HOME so that each SkyPilot cluster's state is isolated
+        # from one another. We rely on the assumption that ~ is exclusively
+        # used by a cluster, and in Slurm that is not the case, as $HOME
+        # could be part of a shared filesystem.
+        # And similarly for SKY_RUNTIME_DIR. See constants.\
+        # SKY_RUNTIME_DIR_ENV_VAR_KEY for more details.
+        #
+        # SSH directly to the compute node instead of using srun.
+        # This avoids Slurm's proctrack/cgroup which kills all processes
+        # when the job step ends (including child processes launched as
+        # a separate process group), breaking background process spawning
+        # (e.g., JobScheduler._run_job which uses launch_new_process_tree).
+        # Note: proctrack/cgroup is enabled by default on Nebius'
+        # Managed Soperator.
+        cmd = (
+            f'export {constants.SKY_RUNTIME_DIR_ENV_VAR_KEY}='
+            f'"{self.skypilot_runtime_dir}" && '
+            # Set the uv cache directory to /tmp/uv_cache_$(id -u) to speed up
+            # package installation while avoiding permission conflicts when
+            # multiple users share the same host. Otherwise it defaults to
+            # ~/.cache/uv.
+            f'export UV_CACHE_DIR=/tmp/uv_cache_$(id -u) && '
+            f'cd {self.sky_dir} && export HOME=$(pwd) && {cmd}')
+
+        return super().run(cmd, **kwargs)
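To make the two-hop SSH setup in SlurmCommandRunner.__init__ concrete, here is a rough standalone sketch. Hostnames, user, and key path are placeholders, and the real code builds its option string via ssh_options_list rather than raw strings:

# Hop 1: a ProxyCommand that tunnels through the Slurm login node.
login_node = 'slurm-login.example.com'   # placeholder
compute_node = 'gpu-node-07'             # placeholder Slurm node hostname
ssh_user = 'alice'                       # placeholder
key = '~/.ssh/slurm_key'                 # placeholder

login_proxy = f'ssh -i {key} -W %h:%p {ssh_user}@{login_node}'

# Hop 2: the actual connection targets the compute node and rides the proxy,
# which is what super().run() effectively does after __init__ rewrites
# self.ip and self._ssh_proxy_command.
cmd = (f'ssh -i {key} -o ProxyCommand="{login_proxy}" '
       f'{ssh_user}@{compute_node} -- hostname')
print(cmd)

The overridden rsync reuses the same proxy by passing it to rsync's remote-shell option (rsh_option above).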
sky/utils/command_runner.pyi CHANGED
@@ -6,7 +6,7 @@ determine the return type based on the value of require_outputs.
 """
 import enum
 import typing
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
 
 from typing_extensions import Literal
 
@@ -130,7 +130,7 @@ class SSHCommandRunner(CommandRunner):
     ip: str
     port: int
     ssh_user: str
-    ssh_private_key: str
+    ssh_private_key: Optional[str]
     ssh_control_name: Optional[str]
     docker_user: str
     disable_control_master: Optional[bool]
@@ -140,7 +140,7 @@ class SSHCommandRunner(CommandRunner):
         self,
         node: Tuple[str, int],
         ssh_user: str,
-        ssh_private_key: str,
+        ssh_private_key: Optional[str],
         ssh_control_name: Optional[str] = ...,
         ssh_proxy_command: Optional[str] = ...,
         docker_user: Optional[str] = ...,
@@ -216,7 +216,8 @@ class SSHCommandRunner(CommandRunner):
              up: bool,
              log_path: str = ...,
              stream_logs: bool = ...,
-             max_retry: int = ...) -> None:
+             max_retry: int = ...,
+             get_remote_home_dir: Callable[[], str] = ...) -> None:
        ...
 
    def port_forward_command(
@@ -306,6 +307,28 @@ class KubernetesCommandRunner(CommandRunner):
        ...
 
 
+class SlurmCommandRunner(SSHCommandRunner):
+    """Runner for Slurm commands."""
+    sky_dir: str
+    skypilot_runtime_dir: str
+    job_id: str
+    slurm_node: str
+
+    def __init__(
+        self,
+        node: Tuple[str, int],
+        ssh_user: str,
+        ssh_private_key: Optional[str],
+        *,
+        sky_dir: str,
+        skypilot_runtime_dir: str,
+        job_id: str,
+        slurm_node: str,
+        **kwargs,
+    ) -> None:
+        ...
+
+
 class LocalProcessCommandRunner(CommandRunner):
 
     def __init__(self) -> None:
sky/utils/common_utils.py CHANGED
@@ -300,6 +300,7 @@ _current_user: Optional['models.User'] = None
 _current_request_id: Optional[str] = None
 
 
+# TODO(aylei,hailong): request context should be contextual
 def set_request_context(client_entrypoint: Optional[str],
                         client_command: Optional[str],
                         using_remote_api_server: bool,
@@ -341,19 +342,32 @@ def get_current_command() -> str:
 
 
 def get_current_user() -> 'models.User':
-    """Returns the current user."""
+    """Returns the user in current server session."""
     if _current_user is not None:
         return _current_user
     return models.User.get_current_user()
 
 
 def get_current_user_name() -> str:
-    """Returns the current user name."""
+    """Returns the user name in current server session."""
     name = get_current_user().name
     assert name is not None
     return name
 
 
+def get_local_user_name() -> str:
+    """Returns the user name in local environment.
+
+    This is for backward compatibility where anonymous access is implicitly
+    allowed when no authentication method at server-side is configured and
+    the username from client environment variable will be used to identify the
+    user.
+    """
+    name = os.getenv(constants.USER_ENV_VAR, getpass.getuser())
+    assert name is not None
+    return name
+
+
 def set_current_user(user: 'models.User'):
     """Sets the current user."""
     global _current_user
@@ -724,7 +738,8 @@ def find_free_port(start_port: int) -> int:
             try:
                 s.bind(('', port))
                 return port
-            except OSError:
+            except OSError as e:
+                logger.debug(f'Error binding port {port}: {e}')
                 pass
     raise OSError('No free ports available.')
 
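Finally, the fallback logic of the new get_local_user_name helper boils down to the following sketch. The environment variable name here is a stand-in; the real one is constants.USER_ENV_VAR in the SkyPilot source:

import getpass
import os

USER_ENV_VAR = 'SKYPILOT_USER'  # stand-in for constants.USER_ENV_VAR


def get_local_user_name() -> str:
    # Prefer the client-supplied environment variable; otherwise fall back
    # to the OS-level user, mirroring the helper added in common_utils.py.
    return os.getenv(USER_ENV_VAR, getpass.getuser())


print(get_local_user_name())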