skypilot-nightly 1.0.0.dev20251203__py3-none-any.whl → 1.0.0.dev20260112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245) hide show
  1. sky/__init__.py +6 -2
  2. sky/adaptors/aws.py +1 -61
  3. sky/adaptors/slurm.py +565 -0
  4. sky/backends/backend_utils.py +95 -12
  5. sky/backends/cloud_vm_ray_backend.py +224 -65
  6. sky/backends/task_codegen.py +380 -4
  7. sky/catalog/__init__.py +0 -3
  8. sky/catalog/data_fetchers/fetch_gcp.py +9 -1
  9. sky/catalog/data_fetchers/fetch_nebius.py +1 -1
  10. sky/catalog/data_fetchers/fetch_vast.py +4 -2
  11. sky/catalog/kubernetes_catalog.py +12 -4
  12. sky/catalog/seeweb_catalog.py +30 -15
  13. sky/catalog/shadeform_catalog.py +5 -2
  14. sky/catalog/slurm_catalog.py +236 -0
  15. sky/catalog/vast_catalog.py +30 -6
  16. sky/check.py +25 -11
  17. sky/client/cli/command.py +391 -32
  18. sky/client/interactive_utils.py +190 -0
  19. sky/client/sdk.py +64 -2
  20. sky/client/sdk_async.py +9 -0
  21. sky/clouds/__init__.py +2 -0
  22. sky/clouds/aws.py +60 -2
  23. sky/clouds/azure.py +2 -0
  24. sky/clouds/cloud.py +7 -0
  25. sky/clouds/kubernetes.py +2 -0
  26. sky/clouds/runpod.py +38 -7
  27. sky/clouds/slurm.py +610 -0
  28. sky/clouds/ssh.py +3 -2
  29. sky/clouds/vast.py +39 -16
  30. sky/core.py +197 -37
  31. sky/dashboard/out/404.html +1 -1
  32. sky/dashboard/out/_next/static/3nu-b8raeKRNABZ2d4GAG/_buildManifest.js +1 -0
  33. sky/dashboard/out/_next/static/chunks/1871-0565f8975a7dcd10.js +6 -0
  34. sky/dashboard/out/_next/static/chunks/2109-55a1546d793574a7.js +11 -0
  35. sky/dashboard/out/_next/static/chunks/2521-099b07cd9e4745bf.js +26 -0
  36. sky/dashboard/out/_next/static/chunks/2755.a636e04a928a700e.js +31 -0
  37. sky/dashboard/out/_next/static/chunks/3495.05eab4862217c1a5.js +6 -0
  38. sky/dashboard/out/_next/static/chunks/3785.cfc5dcc9434fd98c.js +1 -0
  39. sky/dashboard/out/_next/static/chunks/3850-fd5696f3bbbaddae.js +1 -0
  40. sky/dashboard/out/_next/static/chunks/3981.645d01bf9c8cad0c.js +21 -0
  41. sky/dashboard/out/_next/static/chunks/4083-0115d67c1fb57d6c.js +21 -0
  42. sky/dashboard/out/_next/static/chunks/{8640.5b9475a2d18c5416.js → 429.a58e9ba9742309ed.js} +2 -2
  43. sky/dashboard/out/_next/static/chunks/4555.8e221537181b5dc1.js +6 -0
  44. sky/dashboard/out/_next/static/chunks/4725.937865b81fdaaebb.js +6 -0
  45. sky/dashboard/out/_next/static/chunks/6082-edabd8f6092300ce.js +25 -0
  46. sky/dashboard/out/_next/static/chunks/6989-49cb7dca83a7a62d.js +1 -0
  47. sky/dashboard/out/_next/static/chunks/6990-630bd2a2257275f8.js +1 -0
  48. sky/dashboard/out/_next/static/chunks/7248-a99800d4db8edabd.js +1 -0
  49. sky/dashboard/out/_next/static/chunks/754-cfc5d4ad1b843d29.js +18 -0
  50. sky/dashboard/out/_next/static/chunks/8050-dd8aa107b17dce00.js +16 -0
  51. sky/dashboard/out/_next/static/chunks/8056-d4ae1e0cb81e7368.js +1 -0
  52. sky/dashboard/out/_next/static/chunks/8555.011023e296c127b3.js +6 -0
  53. sky/dashboard/out/_next/static/chunks/8821-93c25df904a8362b.js +1 -0
  54. sky/dashboard/out/_next/static/chunks/8969-0662594b69432ade.js +1 -0
  55. sky/dashboard/out/_next/static/chunks/9025.f15c91c97d124a5f.js +6 -0
  56. sky/dashboard/out/_next/static/chunks/9353-7ad6bd01858556f1.js +1 -0
  57. sky/dashboard/out/_next/static/chunks/pages/_app-5a86569acad99764.js +34 -0
  58. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8297476714acb4ac.js +6 -0
  59. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-337c3ba1085f1210.js +1 -0
  60. sky/dashboard/out/_next/static/chunks/pages/{clusters-ee39056f9851a3ff.js → clusters-57632ff3684a8b5c.js} +1 -1
  61. sky/dashboard/out/_next/static/chunks/pages/{config-dfb9bf07b13045f4.js → config-718cdc365de82689.js} +1 -1
  62. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-5fd3a453c079c2ea.js +1 -0
  63. sky/dashboard/out/_next/static/chunks/pages/infra-9f85c02c9c6cae9e.js +1 -0
  64. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-90f16972cbecf354.js +1 -0
  65. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-2dd42fc37aad427a.js +16 -0
  66. sky/dashboard/out/_next/static/chunks/pages/jobs-ed806aeace26b972.js +1 -0
  67. sky/dashboard/out/_next/static/chunks/pages/plugins/[...slug]-449a9f5a3bb20fb3.js +1 -0
  68. sky/dashboard/out/_next/static/chunks/pages/users-bec34706b36f3524.js +1 -0
  69. sky/dashboard/out/_next/static/chunks/pages/{volumes-b84b948ff357c43e.js → volumes-a83ba9b38dff7ea9.js} +1 -1
  70. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-84a40f8c7c627fe4.js → [name]-c781e9c3e52ef9fc.js} +1 -1
  71. sky/dashboard/out/_next/static/chunks/pages/workspaces-91e0942f47310aae.js +1 -0
  72. sky/dashboard/out/_next/static/chunks/webpack-cfe59cf684ee13b9.js +1 -0
  73. sky/dashboard/out/_next/static/css/b0dbca28f027cc19.css +3 -0
  74. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  75. sky/dashboard/out/clusters/[cluster].html +1 -1
  76. sky/dashboard/out/clusters.html +1 -1
  77. sky/dashboard/out/config.html +1 -1
  78. sky/dashboard/out/index.html +1 -1
  79. sky/dashboard/out/infra/[context].html +1 -1
  80. sky/dashboard/out/infra.html +1 -1
  81. sky/dashboard/out/jobs/[job].html +1 -1
  82. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  83. sky/dashboard/out/jobs.html +1 -1
  84. sky/dashboard/out/plugins/[...slug].html +1 -0
  85. sky/dashboard/out/users.html +1 -1
  86. sky/dashboard/out/volumes.html +1 -1
  87. sky/dashboard/out/workspace/new.html +1 -1
  88. sky/dashboard/out/workspaces/[name].html +1 -1
  89. sky/dashboard/out/workspaces.html +1 -1
  90. sky/data/data_utils.py +26 -12
  91. sky/data/mounting_utils.py +44 -5
  92. sky/global_user_state.py +111 -19
  93. sky/jobs/client/sdk.py +8 -3
  94. sky/jobs/controller.py +191 -31
  95. sky/jobs/recovery_strategy.py +109 -11
  96. sky/jobs/server/core.py +81 -4
  97. sky/jobs/server/server.py +14 -0
  98. sky/jobs/state.py +417 -19
  99. sky/jobs/utils.py +73 -80
  100. sky/models.py +11 -0
  101. sky/optimizer.py +8 -6
  102. sky/provision/__init__.py +12 -9
  103. sky/provision/common.py +20 -0
  104. sky/provision/docker_utils.py +15 -2
  105. sky/provision/kubernetes/utils.py +163 -20
  106. sky/provision/kubernetes/volume.py +52 -17
  107. sky/provision/provisioner.py +17 -7
  108. sky/provision/runpod/instance.py +3 -1
  109. sky/provision/runpod/utils.py +13 -1
  110. sky/provision/runpod/volume.py +25 -9
  111. sky/provision/slurm/__init__.py +12 -0
  112. sky/provision/slurm/config.py +13 -0
  113. sky/provision/slurm/instance.py +618 -0
  114. sky/provision/slurm/utils.py +689 -0
  115. sky/provision/vast/instance.py +4 -1
  116. sky/provision/vast/utils.py +11 -6
  117. sky/resources.py +135 -13
  118. sky/schemas/api/responses.py +4 -0
  119. sky/schemas/db/global_user_state/010_save_ssh_key.py +1 -1
  120. sky/schemas/db/spot_jobs/008_add_full_resources.py +34 -0
  121. sky/schemas/db/spot_jobs/009_job_events.py +32 -0
  122. sky/schemas/db/spot_jobs/010_job_events_timestamp_with_timezone.py +43 -0
  123. sky/schemas/db/spot_jobs/011_add_links.py +34 -0
  124. sky/schemas/generated/jobsv1_pb2.py +9 -5
  125. sky/schemas/generated/jobsv1_pb2.pyi +12 -0
  126. sky/schemas/generated/jobsv1_pb2_grpc.py +44 -0
  127. sky/schemas/generated/managed_jobsv1_pb2.py +32 -28
  128. sky/schemas/generated/managed_jobsv1_pb2.pyi +11 -2
  129. sky/serve/serve_utils.py +232 -40
  130. sky/serve/server/impl.py +1 -1
  131. sky/server/common.py +17 -0
  132. sky/server/constants.py +1 -1
  133. sky/server/metrics.py +6 -3
  134. sky/server/plugins.py +238 -0
  135. sky/server/requests/executor.py +5 -2
  136. sky/server/requests/payloads.py +30 -1
  137. sky/server/requests/request_names.py +4 -0
  138. sky/server/requests/requests.py +33 -11
  139. sky/server/requests/serializers/encoders.py +22 -0
  140. sky/server/requests/serializers/return_value_serializers.py +70 -0
  141. sky/server/server.py +506 -109
  142. sky/server/server_utils.py +30 -0
  143. sky/server/uvicorn.py +5 -0
  144. sky/setup_files/MANIFEST.in +1 -0
  145. sky/setup_files/dependencies.py +22 -9
  146. sky/sky_logging.py +2 -1
  147. sky/skylet/attempt_skylet.py +13 -3
  148. sky/skylet/constants.py +55 -13
  149. sky/skylet/events.py +10 -4
  150. sky/skylet/executor/__init__.py +1 -0
  151. sky/skylet/executor/slurm.py +187 -0
  152. sky/skylet/job_lib.py +91 -5
  153. sky/skylet/log_lib.py +22 -6
  154. sky/skylet/log_lib.pyi +8 -6
  155. sky/skylet/services.py +18 -3
  156. sky/skylet/skylet.py +5 -1
  157. sky/skylet/subprocess_daemon.py +2 -1
  158. sky/ssh_node_pools/constants.py +12 -0
  159. sky/ssh_node_pools/core.py +40 -3
  160. sky/ssh_node_pools/deploy/__init__.py +4 -0
  161. sky/{utils/kubernetes/deploy_ssh_node_pools.py → ssh_node_pools/deploy/deploy.py} +279 -504
  162. sky/ssh_node_pools/deploy/tunnel/ssh-tunnel.sh +379 -0
  163. sky/ssh_node_pools/deploy/tunnel_utils.py +199 -0
  164. sky/ssh_node_pools/deploy/utils.py +173 -0
  165. sky/ssh_node_pools/server.py +11 -13
  166. sky/{utils/kubernetes/ssh_utils.py → ssh_node_pools/utils.py} +9 -6
  167. sky/templates/kubernetes-ray.yml.j2 +12 -6
  168. sky/templates/slurm-ray.yml.j2 +115 -0
  169. sky/templates/vast-ray.yml.j2 +1 -0
  170. sky/templates/websocket_proxy.py +18 -41
  171. sky/users/model.conf +1 -1
  172. sky/users/permission.py +85 -52
  173. sky/users/rbac.py +31 -3
  174. sky/utils/annotations.py +108 -8
  175. sky/utils/auth_utils.py +42 -0
  176. sky/utils/cli_utils/status_utils.py +19 -5
  177. sky/utils/cluster_utils.py +10 -3
  178. sky/utils/command_runner.py +389 -35
  179. sky/utils/command_runner.pyi +43 -4
  180. sky/utils/common_utils.py +47 -31
  181. sky/utils/context.py +32 -0
  182. sky/utils/db/db_utils.py +36 -6
  183. sky/utils/db/migration_utils.py +41 -21
  184. sky/utils/infra_utils.py +5 -1
  185. sky/utils/instance_links.py +139 -0
  186. sky/utils/interactive_utils.py +49 -0
  187. sky/utils/kubernetes/generate_kubeconfig.sh +42 -33
  188. sky/utils/kubernetes/kubernetes_deploy_utils.py +2 -94
  189. sky/utils/kubernetes/rsync_helper.sh +5 -1
  190. sky/utils/kubernetes/ssh-tunnel.sh +7 -376
  191. sky/utils/plugin_extensions/__init__.py +14 -0
  192. sky/utils/plugin_extensions/external_failure_source.py +176 -0
  193. sky/utils/resources_utils.py +10 -8
  194. sky/utils/rich_utils.py +9 -11
  195. sky/utils/schemas.py +93 -19
  196. sky/utils/status_lib.py +7 -0
  197. sky/utils/subprocess_utils.py +17 -0
  198. sky/volumes/client/sdk.py +6 -3
  199. sky/volumes/server/core.py +65 -27
  200. sky_templates/ray/start_cluster +8 -4
  201. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/METADATA +67 -59
  202. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/RECORD +208 -180
  203. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +0 -1
  204. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +0 -11
  205. sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +0 -6
  206. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +0 -1
  207. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +0 -1
  208. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +0 -15
  209. sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +0 -26
  210. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +0 -1
  211. sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +0 -1
  212. sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +0 -1
  213. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +0 -1
  214. sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +0 -1
  215. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +0 -15
  216. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +0 -13
  217. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +0 -1
  218. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +0 -1
  219. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +0 -1
  220. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +0 -30
  221. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +0 -41
  222. sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +0 -1
  223. sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +0 -1
  224. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +0 -6
  225. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +0 -1
  226. sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +0 -31
  227. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +0 -30
  228. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +0 -34
  229. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +0 -16
  230. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +0 -1
  231. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +0 -1
  232. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +0 -1
  233. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +0 -16
  234. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +0 -21
  235. sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +0 -1
  236. sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +0 -1
  237. sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +0 -1
  238. sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +0 -1
  239. sky/dashboard/out/_next/static/css/0748ce22df867032.css +0 -3
  240. /sky/dashboard/out/_next/static/{96_E2yl3QAiIJGOYCkSpB → 3nu-b8raeKRNABZ2d4GAG}/_ssgManifest.js +0 -0
  241. /sky/{utils/kubernetes → ssh_node_pools/deploy/tunnel}/cleanup-tunnel.sh +0 -0
  242. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/WHEEL +0 -0
  243. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/entry_points.txt +0 -0
  244. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/licenses/LICENSE +0 -0
  245. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/top_level.txt +0 -0
sky/__init__.py CHANGED
@@ -7,7 +7,7 @@ import urllib.request
7
7
  from sky.utils import directory_utils
8
8
 
9
9
  # Replaced with the current commit when building the wheels.
10
- _SKYPILOT_COMMIT_SHA = '3ff39aba6d4752d5c3b09e3fa7d778cefea39370'
10
+ _SKYPILOT_COMMIT_SHA = '5f4cd3b33375c055093474b95f219d26018b7343'
11
11
 
12
12
 
13
13
  def _get_git_commit():
@@ -37,7 +37,7 @@ def _get_git_commit():
37
37
 
38
38
 
39
39
  __commit__ = _get_git_commit()
40
- __version__ = '1.0.0.dev20251203'
40
+ __version__ = '1.0.0.dev20260112'
41
41
  __root_dir__ = directory_utils.get_sky_dir()
42
42
 
43
43
 
@@ -140,8 +140,10 @@ Cudo = clouds.Cudo
140
140
  GCP = clouds.GCP
141
141
  Lambda = clouds.Lambda
142
142
  SCP = clouds.SCP
143
+ Slurm = clouds.Slurm
143
144
  Kubernetes = clouds.Kubernetes
144
145
  K8s = Kubernetes
146
+ SSH = clouds.SSH
145
147
  OCI = clouds.OCI
146
148
  Paperspace = clouds.Paperspace
147
149
  PrimeIntellect = clouds.PrimeIntellect
@@ -163,6 +165,7 @@ __all__ = [
163
165
  'IBM',
164
166
  'Kubernetes',
165
167
  'K8s',
168
+ 'SSH',
166
169
  'Lambda',
167
170
  'OCI',
168
171
  'Paperspace',
@@ -170,6 +173,7 @@ __all__ = [
170
173
  'RunPod',
171
174
  'Vast',
172
175
  'SCP',
176
+ 'Slurm',
173
177
  'Vsphere',
174
178
  'Fluidstack',
175
179
  'Nebius',
sky/adaptors/aws.py CHANGED
@@ -28,7 +28,6 @@ This is informed by the following boto3 docs:
28
28
 
29
29
  # pylint: disable=import-outside-toplevel
30
30
 
31
- import functools
32
31
  import logging
33
32
  import threading
34
33
  import time
@@ -69,65 +68,6 @@ version = 1
69
68
  _MAX_ATTEMPT_FOR_CREATION = 5
70
69
 
71
70
 
72
- class _ThreadLocalTTLCache(threading.local):
73
- """Thread-local storage for _thread_local_lru_cache decorator."""
74
-
75
- def __init__(self, func, maxsize: int, ttl: int):
76
- super().__init__()
77
- self.func = func
78
- self.maxsize = maxsize
79
- self.ttl = ttl
80
-
81
- def get_cache(self):
82
- if not hasattr(self, 'cache'):
83
- self.cache = annotations.ttl_cache(scope='request',
84
- maxsize=self.maxsize,
85
- ttl=self.ttl,
86
- timer=time.time)(self.func)
87
- return self.cache
88
-
89
-
90
- def _thread_local_ttl_cache(maxsize=32, ttl=60 * 55):
91
- """Thread-local TTL cache decorator.
92
-
93
- Args:
94
- maxsize: Maximum size of the cache.
95
- ttl: Time to live for the cache in seconds.
96
- Default is 55 minutes, a bit less than 1 hour
97
- default lifetime of an STS token.
98
- """
99
-
100
- def decorator(func):
101
- # Create thread-local storage for the LRU cache
102
- local_cache = _ThreadLocalTTLCache(func, maxsize, ttl)
103
-
104
- # We can't apply the lru_cache here, because this runs at import time
105
- # so we will always have the main thread's cache.
106
-
107
- @functools.wraps(func)
108
- def wrapper(*args, **kwargs):
109
- # We are within the actual function call, which may be on a thread,
110
- # so local_cache.cache will return the correct thread-local cache,
111
- # which we can now apply and immediately call.
112
- return local_cache.get_cache()(*args, **kwargs)
113
-
114
- def cache_info():
115
- # Note that this will only give the cache info for the current
116
- # thread's cache.
117
- return local_cache.get_cache().cache_info()
118
-
119
- def cache_clear():
120
- # Note that this will only clear the cache for the current thread.
121
- local_cache.get_cache().cache_clear()
122
-
123
- wrapper.cache_info = cache_info # type: ignore[attr-defined]
124
- wrapper.cache_clear = cache_clear # type: ignore[attr-defined]
125
-
126
- return wrapper
127
-
128
- return decorator
129
-
130
-
131
71
  def _assert_kwargs_builtin_type(kwargs):
132
72
  assert all(isinstance(v, (int, float, str)) for v in kwargs.values()), (
133
73
  f'kwargs should not contain none built-in types: {kwargs}')
@@ -174,7 +114,7 @@ def get_workspace_profile() -> Optional[str]:
174
114
 
175
115
  # The TTL cache needs to be thread-local to avoid multiple threads sharing the
176
116
  # same session object, which is not guaranteed to be thread-safe.
177
- @_thread_local_ttl_cache()
117
+ @annotations.thread_local_ttl_cache()
178
118
  def session(check_credentials: bool = True, profile: Optional[str] = None):
179
119
  """Create an AWS session.
180
120
 
sky/adaptors/slurm.py ADDED
@@ -0,0 +1,565 @@
1
+ """Slurm adaptor for SkyPilot."""
2
+
3
+ import ipaddress
4
+ import logging
5
+ import re
6
+ import socket
7
+ import time
8
+ from typing import Dict, List, NamedTuple, Optional, Tuple
9
+
10
+ from sky.adaptors import common
11
+ from sky.utils import command_runner
12
+ from sky.utils import common_utils
13
+ from sky.utils import subprocess_utils
14
+ from sky.utils import timeline
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # ASCII Unit Separator (\x1f) to handle values with spaces
19
+ # and other special characters.
20
+ SEP = r'\x1f'
21
+
22
+ # Regex pattern to extract partition names from scontrol output
23
+ # Matches PartitionName=<name> and captures until the next field
24
+ _PARTITION_NAME_REGEX = re.compile(r'PartitionName=(.+?)(?:\s+\w+=|$)')
25
+
26
+ # Default timeout for waiting for job nodes to be allocated, in seconds.
27
+ _SLURM_DEFAULT_PROVISION_TIMEOUT = 10
28
+
29
+ _IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for Slurm. '
30
+ 'Try running: pip install "skypilot[slurm]"')
31
+ hostlist = common.LazyImport('hostlist',
32
+ import_error_message=_IMPORT_ERROR_MESSAGE)
33
+
34
+
35
+ class SlurmPartition(NamedTuple):
36
+ """Information about the Slurm partitions."""
37
+ name: str
38
+ is_default: bool
39
+
40
+
41
+ # TODO(kevin): Add more API types for other client functions.
42
+ class NodeInfo(NamedTuple):
43
+ """Information about a Slurm node from sinfo."""
44
+ node: str
45
+ state: str
46
+ gres: str
47
+ cpus: int
48
+ memory_gb: float
49
+ # The default partition contains a '*' at the end of the name.
50
+ # It is the caller's responsibility to strip the '*' if needed.
51
+ partition: str
52
+
53
+
54
+ class SlurmClient:
55
+ """Client for Slurm control plane operations."""
56
+
57
+ def __init__(
58
+ self,
59
+ ssh_host: Optional[str] = None,
60
+ ssh_port: Optional[int] = None,
61
+ ssh_user: Optional[str] = None,
62
+ ssh_key: Optional[str] = None,
63
+ ssh_proxy_command: Optional[str] = None,
64
+ ssh_proxy_jump: Optional[str] = None,
65
+ is_inside_slurm_cluster: bool = False,
66
+ ):
67
+ """Initialize SlurmClient.
68
+
69
+ Args:
70
+ ssh_host: Hostname of the Slurm controller.
71
+ ssh_port: SSH port on the controller.
72
+ ssh_user: SSH username.
73
+ ssh_key: Path to SSH private key, or None for keyless SSH.
74
+ ssh_proxy_command: Optional SSH proxy command.
75
+ ssh_proxy_jump: Optional SSH proxy jump destination.
76
+ is_inside_slurm_cluster: If True, uses local execution mode (for
77
+ when running on the Slurm cluster itself). Defaults to False.
78
+ """
79
+ self.ssh_host = ssh_host
80
+ self.ssh_port = ssh_port
81
+ self.ssh_user = ssh_user
82
+ self.ssh_key = ssh_key
83
+ self.ssh_proxy_command = ssh_proxy_command
84
+ self.ssh_proxy_jump = ssh_proxy_jump
85
+
86
+ self._runner: command_runner.CommandRunner
87
+
88
+ if is_inside_slurm_cluster:
89
+ # Local execution mode - for running on the Slurm cluster itself
90
+ # (e.g., autodown from skylet).
91
+ self._runner = command_runner.LocalProcessCommandRunner()
92
+ else:
93
+ # Remote execution via SSH
94
+ assert ssh_host is not None
95
+ assert ssh_port is not None
96
+ assert ssh_user is not None
97
+ self._runner = command_runner.SSHCommandRunner(
98
+ (ssh_host, ssh_port),
99
+ ssh_user,
100
+ ssh_key,
101
+ ssh_proxy_command=ssh_proxy_command,
102
+ ssh_proxy_jump=ssh_proxy_jump,
103
+ enable_interactive_auth=True,
104
+ )
105
+
106
+ def _run_slurm_cmd(self, cmd: str) -> Tuple[int, str, str]:
107
+ return self._runner.run(cmd,
108
+ require_outputs=True,
109
+ separate_stderr=True,
110
+ stream_logs=False)
111
+
112
+ def query_jobs(
113
+ self,
114
+ job_name: Optional[str] = None,
115
+ state_filters: Optional[List[str]] = None,
116
+ ) -> List[str]:
117
+ """Query Slurm jobs by state and optional name.
118
+
119
+ Args:
120
+ job_name: Optional job name to filter by.
121
+ state_filters: List of job states to filter by
122
+ (e.g., ['running', 'pending']). If None, returns all jobs.
123
+
124
+ Returns:
125
+ List of job IDs matching the filters.
126
+ """
127
+ cmd = 'squeue --me -h -o "%i"'
128
+ if state_filters is not None:
129
+ state_filters_str = ','.join(state_filters)
130
+ cmd += f' --states {state_filters_str}'
131
+ if job_name is not None:
132
+ cmd += f' --name {job_name}'
133
+
134
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
135
+ subprocess_utils.handle_returncode(rc,
136
+ cmd,
137
+ 'Failed to query Slurm jobs.',
138
+ stderr=f'{stdout}\n{stderr}')
139
+
140
+ job_ids = stdout.strip().splitlines()
141
+ return job_ids
142
+
143
+ def cancel_jobs_by_name(self,
144
+ job_name: str,
145
+ signal: Optional[str] = None,
146
+ full: bool = False) -> None:
147
+ """Cancel Slurm job(s) by name.
148
+
149
+ Args:
150
+ job_name: Name of the job(s) to cancel.
151
+ signal: Optional signal to send to the job(s).
152
+ full: If True, signals the batch script and its children processes.
153
+ By default, signals other than SIGKILL are not sent to the
154
+ batch step (the shell script).
155
+ """
156
+ cmd = f'scancel --name {job_name}'
157
+ if signal is not None:
158
+ cmd += f' --signal {signal}'
159
+ if full:
160
+ cmd += ' --full'
161
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
162
+ subprocess_utils.handle_returncode(rc,
163
+ cmd,
164
+ f'Failed to cancel job {job_name}.',
165
+ stderr=f'{stdout}\n{stderr}')
166
+ logger.debug(f'Successfully cancelled job {job_name}: {stdout}')
167
+
168
+ def info(self) -> str:
169
+ """Get Slurm cluster information.
170
+
171
+ This is useful for checking if the cluster is accessible and
172
+ retrieving node information.
173
+
174
+ Returns:
175
+ The stdout output from sinfo.
176
+ """
177
+ cmd = 'sinfo'
178
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
179
+ subprocess_utils.handle_returncode(
180
+ rc,
181
+ cmd,
182
+ 'Failed to get Slurm cluster information.',
183
+ stderr=f'{stdout}\n{stderr}')
184
+ return stdout
185
+
186
+ def info_nodes(self) -> List[NodeInfo]:
187
+ """Get Slurm node information.
188
+
189
+ Returns node names, states, GRES (generic resources like GPUs),
190
+ CPUs, memory (MB), and partitions.
191
+ """
192
+ cmd = (f'sinfo -h --Node -o '
193
+ f'"%N{SEP}%t{SEP}%G{SEP}%c{SEP}%m{SEP}%P"')
194
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
195
+ subprocess_utils.handle_returncode(
196
+ rc,
197
+ cmd,
198
+ 'Failed to get Slurm node information.',
199
+ stderr=f'{stdout}\n{stderr}')
200
+
201
+ nodes = []
202
+ for line in stdout.splitlines():
203
+ parts = line.split(SEP)
204
+ if len(parts) != 6:
205
+ raise RuntimeError(
206
+ f'Unexpected output format from sinfo: {line!r}')
207
+ try:
208
+ node_info = NodeInfo(node=parts[0],
209
+ state=parts[1],
210
+ gres=parts[2],
211
+ cpus=int(parts[3]),
212
+ memory_gb=int(parts[4]) / 1024.0,
213
+ partition=parts[5])
214
+ nodes.append(node_info)
215
+ except ValueError as e:
216
+ raise RuntimeError(
217
+ f'Failed to parse node info from line: {line!r}. '
218
+ f'Error: {e}') from e
219
+
220
+ return nodes
221
+
222
+ def node_details(self, node_name: str) -> Dict[str, str]:
223
+ """Get detailed Slurm node information.
224
+
225
+ Returns:
226
+ A dictionary of node attributes.
227
+ """
228
+
229
+ def _parse_scontrol_node_output(output: str) -> Dict[str, str]:
230
+ """Parses the key=value output of 'scontrol show node'."""
231
+ node_info = {}
232
+ # Split by space, handling values that might have spaces
233
+ # if quoted. This is simplified; scontrol can be complex.
234
+ parts = output.split()
235
+ for part in parts:
236
+ if '=' in part:
237
+ key, value = part.split('=', 1)
238
+ # Simple quote removal, might need refinement
239
+ value = value.strip('\'"')
240
+ node_info[key] = value
241
+ return node_info
242
+
243
+ cmd = f'scontrol show node {node_name}'
244
+ rc, node_details, stderr = self._run_slurm_cmd(cmd)
245
+ subprocess_utils.handle_returncode(
246
+ rc,
247
+ cmd,
248
+ f'Failed to get detailed node information for {node_name}.',
249
+ stderr=f'{node_details}\n{stderr}')
250
+ node_info = _parse_scontrol_node_output(node_details)
251
+ return node_info
252
+
253
+ def get_jobs_gres(self, node_name: str) -> List[str]:
254
+ """Get the list of jobs GRES for a given node name.
255
+
256
+ Returns:
257
+ A list of GRES specs (e.g., 'gres/gpu:h100:4')
258
+ for jobs on the node.
259
+ """
260
+ cmd = f'squeue -h --nodelist {node_name} -o "%b"'
261
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
262
+ subprocess_utils.handle_returncode(
263
+ rc,
264
+ cmd,
265
+ f'Failed to get jobs for node {node_name}.',
266
+ stderr=f'{stdout}\n{stderr}')
267
+ return stdout.splitlines()
268
+
269
+ def get_all_jobs_gres(self) -> Dict[str, List[str]]:
270
+ """Get GRES allocation for all running jobs, grouped by node.
271
+
272
+ Returns:
273
+ Dict mapping node_name -> list of GRES strings for jobs on that
274
+ node.
275
+ """
276
+ cmd = f'squeue -h --states=running,completing -o "%N{SEP}%b"'
277
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
278
+ subprocess_utils.handle_returncode(rc,
279
+ cmd,
280
+ 'Failed to get all jobs GRES.',
281
+ stderr=f'{stdout}\n{stderr}')
282
+
283
+ nodes_to_gres: Dict[str, List[str]] = {}
284
+ for line in stdout.splitlines():
285
+ line = line.strip()
286
+ if not line:
287
+ continue
288
+ parts = line.split(SEP)
289
+ if len(parts) != 2:
290
+ # We should never reach here, but just in case.
291
+ continue
292
+ nodelist_str, gres_str = parts
293
+ if not gres_str or gres_str == 'N/A':
294
+ continue
295
+
296
+ for node in hostlist.expand_hostlist(nodelist_str):
297
+ nodes_to_gres.setdefault(node, []).append(gres_str)
298
+
299
+ return nodes_to_gres
300
+
301
+ def get_job_state(self, job_id: str) -> Optional[str]:
302
+ """Get the state of a Slurm job.
303
+
304
+ Args:
305
+ job_id: The Slurm job ID.
306
+
307
+ Returns:
308
+ The job state (e.g., 'PENDING', 'RUNNING', 'COMPLETED', etc.),
309
+ or None if the job is not found.
310
+ """
311
+ # Use --only-job-state since we only need the job state.
312
+ # This reduces the work required by slurmctld.
313
+ cmd = f'squeue -h --only-job-state --jobs {job_id} -o "%T"'
314
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
315
+ subprocess_utils.handle_returncode(
316
+ rc,
317
+ cmd,
318
+ f'Failed to get job state for job {job_id}.',
319
+ stderr=f'{stdout}\n{stderr}')
320
+
321
+ state = stdout.strip()
322
+ return state if state else None
323
+
324
+ def get_jobs_state_by_name(self, job_name: str) -> List[str]:
325
+ """Get the states of all Slurm jobs by name.
326
+ """
327
+ cmd = f'squeue -h --name {job_name} -o "%T"'
328
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
329
+ subprocess_utils.handle_returncode(
330
+ rc,
331
+ cmd,
332
+ f'Failed to get job state for job {job_name}.',
333
+ stderr=f'{stdout}\n{stderr}')
334
+
335
+ states = stdout.splitlines()
336
+ return states
337
+
338
+ @timeline.event
339
+ def get_job_reason(self, job_id: str) -> Optional[str]:
340
+ """Get the reason a job is in its current state
341
+
342
+ Args:
343
+ job_id: The Slurm job ID.
344
+ """
345
+ # Without --states all, squeue omits terminated jobs.
346
+ cmd = f'squeue -h --jobs {job_id} --states all -o "%r"'
347
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
348
+ subprocess_utils.handle_returncode(
349
+ rc,
350
+ cmd,
351
+ f'Failed to get job reason for job {job_id}.',
352
+ stderr=f'{stdout}\n{stderr}')
353
+
354
+ output = stdout.strip()
355
+ if not output:
356
+ return None
357
+
358
+ return output if output != 'None' else None
359
+
360
    @timeline.event
    def wait_for_job_nodes(self, job_id: str, timeout: int) -> None:
        """Wait for a Slurm job to have nodes allocated.

        Polls the job state and the squeue node list every 2 seconds until
        squeue reports a non-empty node list for the job.

        Args:
            job_id: The Slurm job ID.
            timeout: Maximum time to wait in seconds.

        Raises:
            RuntimeError: If the job is no longer known to Slurm, or reaches
                a terminal state before nodes were allocated.
            TimeoutError: If nodes are not allocated within `timeout` seconds.
        """
        start_time = time.time()
        last_state = None

        while time.time() - start_time < timeout:
            state = self.get_job_state(job_id)

            # Only log on state transitions so long pending periods do not
            # flood the debug log.
            if state != last_state:
                logger.debug(f'Job {job_id} state: {state}')
                last_state = state

            if state is None:
                raise RuntimeError(f'Job {job_id} not found. It may have been '
                                   'cancelled or failed.')

            # NOTE(review): other terminal states (e.g. NODE_FAIL, PREEMPTED,
            # BOOT_FAIL, OUT_OF_MEMORY) are not listed here, so a job ending
            # in one of them would keep polling until the timeout — confirm
            # whether get_job_state can return those values.
            if state in ('COMPLETED', 'CANCELLED', 'FAILED', 'TIMEOUT'):
                raise RuntimeError(
                    f'Job {job_id} terminated with state {state} '
                    'before nodes were allocated.')
            # TODO(kevin): Log reason for pending.

            # Check if nodes are allocated by trying to get node list
            cmd = f'squeue -h --jobs {job_id} -o "%N"'
            rc, stdout, stderr = self._run_slurm_cmd(cmd)

            if rc == 0 and stdout.strip():
                # Nodes are allocated
                logger.debug(
                    f'Job {job_id} has nodes allocated: {stdout.strip()}')
                return
            elif rc != 0:
                # Treated as transient: log and keep polling rather than
                # abort the wait.
                logger.debug(f'Failed to get nodes for job {job_id}: '
                             f'{stdout}\n{stderr}')

            # Wait before checking again
            time.sleep(2)

        raise TimeoutError(f'Job {job_id} did not get nodes allocated within '
                           f'{timeout} seconds. Last state: {last_state}')
+ @timeline.event
408
+ def get_job_nodes(
409
+ self,
410
+ job_id: str,
411
+ wait: bool = True,
412
+ timeout: Optional[int] = None) -> Tuple[List[str], List[str]]:
413
+ """Get the list of nodes and their IPs for a given job ID.
414
+
415
+ The ordering is guaranteed to be stable for the lifetime of the job.
416
+
417
+ Args:
418
+ job_id: The Slurm job ID.
419
+ wait: If True, wait for nodes to be allocated before returning.
420
+ timeout: Maximum time to wait in seconds. Only used when wait=True.
421
+
422
+ Returns:
423
+ A tuple of (nodes, node_ips) where nodes is a list of node names
424
+ and node_ips is a list of corresponding IP addresses.
425
+ """
426
+ # Wait for nodes to be allocated if requested
427
+ if wait:
428
+ if timeout is None:
429
+ timeout = _SLURM_DEFAULT_PROVISION_TIMEOUT
430
+ self.wait_for_job_nodes(job_id, timeout=timeout)
431
+
432
+ cmd = (
433
+ f'squeue -h --jobs {job_id} -o "%N" | tr \',\' \'\\n\' | '
434
+ f'while read node; do '
435
+ # TODO(kevin): Use json output for more robust parsing.
436
+ f'node_addr=$(scontrol show node=$node | grep NodeAddr= | '
437
+ f'awk -F= \'{{print $2}}\' | awk \'{{print $1}}\'); '
438
+ f'echo "$node $node_addr"; '
439
+ f'done')
440
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
441
+ subprocess_utils.handle_returncode(
442
+ rc,
443
+ cmd,
444
+ f'Failed to get nodes for job {job_id}.',
445
+ stderr=f'{stdout}\n{stderr}')
446
+ logger.debug(f'Successfully got nodes for job {job_id}: {stdout}')
447
+
448
+ node_info = {}
449
+ for line in stdout.strip().splitlines():
450
+ line = line.strip()
451
+ if line:
452
+ parts = line.split()
453
+ if len(parts) >= 2:
454
+ node_name = parts[0]
455
+ node_addr = parts[1]
456
+ # Resolve hostname to IP if node_addr is not already
457
+ # an IP address.
458
+ try:
459
+ ipaddress.ip_address(node_addr)
460
+ # Already an IP address
461
+ node_ip = node_addr
462
+ except ValueError:
463
+ # It's a hostname, resolve it to an IP
464
+ try:
465
+ node_ip = socket.gethostbyname(node_addr)
466
+ except socket.gaierror as e:
467
+ raise RuntimeError(
468
+ f'Failed to resolve hostname {node_addr} to IP '
469
+ f'for node {node_name}: '
470
+ f'{common_utils.format_exception(e)}') from e
471
+
472
+ node_info[node_name] = node_ip
473
+
474
+ nodes = list(node_info.keys())
475
+ node_ips = [node_info[node] for node in nodes]
476
+ if not nodes:
477
+ raise RuntimeError(
478
+ f'No nodes found for job {job_id}. '
479
+ f'The job may have terminated or the output was empty.')
480
+ assert (len(nodes) == len(node_ips)
481
+ ), f'Number of nodes and IPs do not match: {nodes} != {node_ips}'
482
+
483
+ return nodes, node_ips
484
+
485
+ def submit_job(
486
+ self,
487
+ partition: str,
488
+ job_name: str,
489
+ script_path: str,
490
+ ) -> str:
491
+ """Submit a Slurm job script.
492
+
493
+ Args:
494
+ partition: Slurm partition to submit to.
495
+ job_name: Name to give the job.
496
+ script_path: Remote path where the script will be stored.
497
+
498
+ Returns:
499
+ The job ID of the submitted job.
500
+ """
501
+ cmd = f'sbatch --partition={partition} {script_path}'
502
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
503
+ subprocess_utils.handle_returncode(rc,
504
+ cmd,
505
+ 'Failed to submit Slurm job.',
506
+ stderr=f'{stdout}\n{stderr}')
507
+
508
+ # Parse job ID from sbatch output (format: "Submitted batch job 12345")
509
+ job_id_match = re.search(r'Submitted batch job (\d+)', stdout)
510
+ if not job_id_match:
511
+ raise RuntimeError(
512
+ f'Failed to parse job ID from sbatch output: {stdout}')
513
+
514
+ job_id = job_id_match.group(1).strip()
515
+ logger.debug(f'Successfully submitted Slurm job {job_id} with name '
516
+ f'{job_name}: {stdout}')
517
+
518
+ return job_id
519
+
520
+ def get_partitions_info(self) -> List[SlurmPartition]:
521
+ """Get the partitions information for the Slurm cluster.
522
+
523
+ Returns:
524
+ List of SlurmPartition objects.
525
+ """
526
+ cmd = 'scontrol show partitions -o'
527
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
528
+ subprocess_utils.handle_returncode(rc,
529
+ cmd,
530
+ 'Failed to get Slurm partitions.',
531
+ stderr=f'{stdout}\n{stderr}')
532
+
533
+ partitions = []
534
+ for line in stdout.strip().splitlines():
535
+ is_default = False
536
+ match = _PARTITION_NAME_REGEX.search(line)
537
+ if 'Default=YES' in line:
538
+ is_default = True
539
+ if match:
540
+ partition = match.group(1).strip()
541
+ if partition:
542
+ partitions.append(
543
+ SlurmPartition(name=partition, is_default=is_default))
544
+ return partitions
545
+
546
+ def get_default_partition(self) -> Optional[str]:
547
+ """Get the default partition name for the Slurm cluster.
548
+
549
+ Returns:
550
+ The default partition name, or None if it cannot be determined.
551
+ """
552
+ partitions = self.get_partitions_info()
553
+ for partition in partitions:
554
+ if partition.is_default:
555
+ return partition.name
556
+ return None
557
+
558
+ def get_partitions(self) -> List[str]:
559
+ """Get unique partition names in the Slurm cluster.
560
+
561
+ Returns:
562
+ List of partition names. The default partition will not have a '*'
563
+ at the end of the name.
564
+ """
565
+ return [partition.name for partition in self.get_partitions_info()]