skypilot-nightly 1.0.0.dev20251203__py3-none-any.whl → 1.0.0.dev20260112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245) hide show
  1. sky/__init__.py +6 -2
  2. sky/adaptors/aws.py +1 -61
  3. sky/adaptors/slurm.py +565 -0
  4. sky/backends/backend_utils.py +95 -12
  5. sky/backends/cloud_vm_ray_backend.py +224 -65
  6. sky/backends/task_codegen.py +380 -4
  7. sky/catalog/__init__.py +0 -3
  8. sky/catalog/data_fetchers/fetch_gcp.py +9 -1
  9. sky/catalog/data_fetchers/fetch_nebius.py +1 -1
  10. sky/catalog/data_fetchers/fetch_vast.py +4 -2
  11. sky/catalog/kubernetes_catalog.py +12 -4
  12. sky/catalog/seeweb_catalog.py +30 -15
  13. sky/catalog/shadeform_catalog.py +5 -2
  14. sky/catalog/slurm_catalog.py +236 -0
  15. sky/catalog/vast_catalog.py +30 -6
  16. sky/check.py +25 -11
  17. sky/client/cli/command.py +391 -32
  18. sky/client/interactive_utils.py +190 -0
  19. sky/client/sdk.py +64 -2
  20. sky/client/sdk_async.py +9 -0
  21. sky/clouds/__init__.py +2 -0
  22. sky/clouds/aws.py +60 -2
  23. sky/clouds/azure.py +2 -0
  24. sky/clouds/cloud.py +7 -0
  25. sky/clouds/kubernetes.py +2 -0
  26. sky/clouds/runpod.py +38 -7
  27. sky/clouds/slurm.py +610 -0
  28. sky/clouds/ssh.py +3 -2
  29. sky/clouds/vast.py +39 -16
  30. sky/core.py +197 -37
  31. sky/dashboard/out/404.html +1 -1
  32. sky/dashboard/out/_next/static/3nu-b8raeKRNABZ2d4GAG/_buildManifest.js +1 -0
  33. sky/dashboard/out/_next/static/chunks/1871-0565f8975a7dcd10.js +6 -0
  34. sky/dashboard/out/_next/static/chunks/2109-55a1546d793574a7.js +11 -0
  35. sky/dashboard/out/_next/static/chunks/2521-099b07cd9e4745bf.js +26 -0
  36. sky/dashboard/out/_next/static/chunks/2755.a636e04a928a700e.js +31 -0
  37. sky/dashboard/out/_next/static/chunks/3495.05eab4862217c1a5.js +6 -0
  38. sky/dashboard/out/_next/static/chunks/3785.cfc5dcc9434fd98c.js +1 -0
  39. sky/dashboard/out/_next/static/chunks/3850-fd5696f3bbbaddae.js +1 -0
  40. sky/dashboard/out/_next/static/chunks/3981.645d01bf9c8cad0c.js +21 -0
  41. sky/dashboard/out/_next/static/chunks/4083-0115d67c1fb57d6c.js +21 -0
  42. sky/dashboard/out/_next/static/chunks/{8640.5b9475a2d18c5416.js → 429.a58e9ba9742309ed.js} +2 -2
  43. sky/dashboard/out/_next/static/chunks/4555.8e221537181b5dc1.js +6 -0
  44. sky/dashboard/out/_next/static/chunks/4725.937865b81fdaaebb.js +6 -0
  45. sky/dashboard/out/_next/static/chunks/6082-edabd8f6092300ce.js +25 -0
  46. sky/dashboard/out/_next/static/chunks/6989-49cb7dca83a7a62d.js +1 -0
  47. sky/dashboard/out/_next/static/chunks/6990-630bd2a2257275f8.js +1 -0
  48. sky/dashboard/out/_next/static/chunks/7248-a99800d4db8edabd.js +1 -0
  49. sky/dashboard/out/_next/static/chunks/754-cfc5d4ad1b843d29.js +18 -0
  50. sky/dashboard/out/_next/static/chunks/8050-dd8aa107b17dce00.js +16 -0
  51. sky/dashboard/out/_next/static/chunks/8056-d4ae1e0cb81e7368.js +1 -0
  52. sky/dashboard/out/_next/static/chunks/8555.011023e296c127b3.js +6 -0
  53. sky/dashboard/out/_next/static/chunks/8821-93c25df904a8362b.js +1 -0
  54. sky/dashboard/out/_next/static/chunks/8969-0662594b69432ade.js +1 -0
  55. sky/dashboard/out/_next/static/chunks/9025.f15c91c97d124a5f.js +6 -0
  56. sky/dashboard/out/_next/static/chunks/9353-7ad6bd01858556f1.js +1 -0
  57. sky/dashboard/out/_next/static/chunks/pages/_app-5a86569acad99764.js +34 -0
  58. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8297476714acb4ac.js +6 -0
  59. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-337c3ba1085f1210.js +1 -0
  60. sky/dashboard/out/_next/static/chunks/pages/{clusters-ee39056f9851a3ff.js → clusters-57632ff3684a8b5c.js} +1 -1
  61. sky/dashboard/out/_next/static/chunks/pages/{config-dfb9bf07b13045f4.js → config-718cdc365de82689.js} +1 -1
  62. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-5fd3a453c079c2ea.js +1 -0
  63. sky/dashboard/out/_next/static/chunks/pages/infra-9f85c02c9c6cae9e.js +1 -0
  64. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-90f16972cbecf354.js +1 -0
  65. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-2dd42fc37aad427a.js +16 -0
  66. sky/dashboard/out/_next/static/chunks/pages/jobs-ed806aeace26b972.js +1 -0
  67. sky/dashboard/out/_next/static/chunks/pages/plugins/[...slug]-449a9f5a3bb20fb3.js +1 -0
  68. sky/dashboard/out/_next/static/chunks/pages/users-bec34706b36f3524.js +1 -0
  69. sky/dashboard/out/_next/static/chunks/pages/{volumes-b84b948ff357c43e.js → volumes-a83ba9b38dff7ea9.js} +1 -1
  70. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-84a40f8c7c627fe4.js → [name]-c781e9c3e52ef9fc.js} +1 -1
  71. sky/dashboard/out/_next/static/chunks/pages/workspaces-91e0942f47310aae.js +1 -0
  72. sky/dashboard/out/_next/static/chunks/webpack-cfe59cf684ee13b9.js +1 -0
  73. sky/dashboard/out/_next/static/css/b0dbca28f027cc19.css +3 -0
  74. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  75. sky/dashboard/out/clusters/[cluster].html +1 -1
  76. sky/dashboard/out/clusters.html +1 -1
  77. sky/dashboard/out/config.html +1 -1
  78. sky/dashboard/out/index.html +1 -1
  79. sky/dashboard/out/infra/[context].html +1 -1
  80. sky/dashboard/out/infra.html +1 -1
  81. sky/dashboard/out/jobs/[job].html +1 -1
  82. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  83. sky/dashboard/out/jobs.html +1 -1
  84. sky/dashboard/out/plugins/[...slug].html +1 -0
  85. sky/dashboard/out/users.html +1 -1
  86. sky/dashboard/out/volumes.html +1 -1
  87. sky/dashboard/out/workspace/new.html +1 -1
  88. sky/dashboard/out/workspaces/[name].html +1 -1
  89. sky/dashboard/out/workspaces.html +1 -1
  90. sky/data/data_utils.py +26 -12
  91. sky/data/mounting_utils.py +44 -5
  92. sky/global_user_state.py +111 -19
  93. sky/jobs/client/sdk.py +8 -3
  94. sky/jobs/controller.py +191 -31
  95. sky/jobs/recovery_strategy.py +109 -11
  96. sky/jobs/server/core.py +81 -4
  97. sky/jobs/server/server.py +14 -0
  98. sky/jobs/state.py +417 -19
  99. sky/jobs/utils.py +73 -80
  100. sky/models.py +11 -0
  101. sky/optimizer.py +8 -6
  102. sky/provision/__init__.py +12 -9
  103. sky/provision/common.py +20 -0
  104. sky/provision/docker_utils.py +15 -2
  105. sky/provision/kubernetes/utils.py +163 -20
  106. sky/provision/kubernetes/volume.py +52 -17
  107. sky/provision/provisioner.py +17 -7
  108. sky/provision/runpod/instance.py +3 -1
  109. sky/provision/runpod/utils.py +13 -1
  110. sky/provision/runpod/volume.py +25 -9
  111. sky/provision/slurm/__init__.py +12 -0
  112. sky/provision/slurm/config.py +13 -0
  113. sky/provision/slurm/instance.py +618 -0
  114. sky/provision/slurm/utils.py +689 -0
  115. sky/provision/vast/instance.py +4 -1
  116. sky/provision/vast/utils.py +11 -6
  117. sky/resources.py +135 -13
  118. sky/schemas/api/responses.py +4 -0
  119. sky/schemas/db/global_user_state/010_save_ssh_key.py +1 -1
  120. sky/schemas/db/spot_jobs/008_add_full_resources.py +34 -0
  121. sky/schemas/db/spot_jobs/009_job_events.py +32 -0
  122. sky/schemas/db/spot_jobs/010_job_events_timestamp_with_timezone.py +43 -0
  123. sky/schemas/db/spot_jobs/011_add_links.py +34 -0
  124. sky/schemas/generated/jobsv1_pb2.py +9 -5
  125. sky/schemas/generated/jobsv1_pb2.pyi +12 -0
  126. sky/schemas/generated/jobsv1_pb2_grpc.py +44 -0
  127. sky/schemas/generated/managed_jobsv1_pb2.py +32 -28
  128. sky/schemas/generated/managed_jobsv1_pb2.pyi +11 -2
  129. sky/serve/serve_utils.py +232 -40
  130. sky/serve/server/impl.py +1 -1
  131. sky/server/common.py +17 -0
  132. sky/server/constants.py +1 -1
  133. sky/server/metrics.py +6 -3
  134. sky/server/plugins.py +238 -0
  135. sky/server/requests/executor.py +5 -2
  136. sky/server/requests/payloads.py +30 -1
  137. sky/server/requests/request_names.py +4 -0
  138. sky/server/requests/requests.py +33 -11
  139. sky/server/requests/serializers/encoders.py +22 -0
  140. sky/server/requests/serializers/return_value_serializers.py +70 -0
  141. sky/server/server.py +506 -109
  142. sky/server/server_utils.py +30 -0
  143. sky/server/uvicorn.py +5 -0
  144. sky/setup_files/MANIFEST.in +1 -0
  145. sky/setup_files/dependencies.py +22 -9
  146. sky/sky_logging.py +2 -1
  147. sky/skylet/attempt_skylet.py +13 -3
  148. sky/skylet/constants.py +55 -13
  149. sky/skylet/events.py +10 -4
  150. sky/skylet/executor/__init__.py +1 -0
  151. sky/skylet/executor/slurm.py +187 -0
  152. sky/skylet/job_lib.py +91 -5
  153. sky/skylet/log_lib.py +22 -6
  154. sky/skylet/log_lib.pyi +8 -6
  155. sky/skylet/services.py +18 -3
  156. sky/skylet/skylet.py +5 -1
  157. sky/skylet/subprocess_daemon.py +2 -1
  158. sky/ssh_node_pools/constants.py +12 -0
  159. sky/ssh_node_pools/core.py +40 -3
  160. sky/ssh_node_pools/deploy/__init__.py +4 -0
  161. sky/{utils/kubernetes/deploy_ssh_node_pools.py → ssh_node_pools/deploy/deploy.py} +279 -504
  162. sky/ssh_node_pools/deploy/tunnel/ssh-tunnel.sh +379 -0
  163. sky/ssh_node_pools/deploy/tunnel_utils.py +199 -0
  164. sky/ssh_node_pools/deploy/utils.py +173 -0
  165. sky/ssh_node_pools/server.py +11 -13
  166. sky/{utils/kubernetes/ssh_utils.py → ssh_node_pools/utils.py} +9 -6
  167. sky/templates/kubernetes-ray.yml.j2 +12 -6
  168. sky/templates/slurm-ray.yml.j2 +115 -0
  169. sky/templates/vast-ray.yml.j2 +1 -0
  170. sky/templates/websocket_proxy.py +18 -41
  171. sky/users/model.conf +1 -1
  172. sky/users/permission.py +85 -52
  173. sky/users/rbac.py +31 -3
  174. sky/utils/annotations.py +108 -8
  175. sky/utils/auth_utils.py +42 -0
  176. sky/utils/cli_utils/status_utils.py +19 -5
  177. sky/utils/cluster_utils.py +10 -3
  178. sky/utils/command_runner.py +389 -35
  179. sky/utils/command_runner.pyi +43 -4
  180. sky/utils/common_utils.py +47 -31
  181. sky/utils/context.py +32 -0
  182. sky/utils/db/db_utils.py +36 -6
  183. sky/utils/db/migration_utils.py +41 -21
  184. sky/utils/infra_utils.py +5 -1
  185. sky/utils/instance_links.py +139 -0
  186. sky/utils/interactive_utils.py +49 -0
  187. sky/utils/kubernetes/generate_kubeconfig.sh +42 -33
  188. sky/utils/kubernetes/kubernetes_deploy_utils.py +2 -94
  189. sky/utils/kubernetes/rsync_helper.sh +5 -1
  190. sky/utils/kubernetes/ssh-tunnel.sh +7 -376
  191. sky/utils/plugin_extensions/__init__.py +14 -0
  192. sky/utils/plugin_extensions/external_failure_source.py +176 -0
  193. sky/utils/resources_utils.py +10 -8
  194. sky/utils/rich_utils.py +9 -11
  195. sky/utils/schemas.py +93 -19
  196. sky/utils/status_lib.py +7 -0
  197. sky/utils/subprocess_utils.py +17 -0
  198. sky/volumes/client/sdk.py +6 -3
  199. sky/volumes/server/core.py +65 -27
  200. sky_templates/ray/start_cluster +8 -4
  201. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/METADATA +67 -59
  202. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/RECORD +208 -180
  203. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +0 -1
  204. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +0 -11
  205. sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +0 -6
  206. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +0 -1
  207. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +0 -1
  208. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +0 -15
  209. sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +0 -26
  210. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +0 -1
  211. sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +0 -1
  212. sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +0 -1
  213. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +0 -1
  214. sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +0 -1
  215. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +0 -15
  216. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +0 -13
  217. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +0 -1
  218. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +0 -1
  219. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +0 -1
  220. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +0 -30
  221. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +0 -41
  222. sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +0 -1
  223. sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +0 -1
  224. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +0 -6
  225. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +0 -1
  226. sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +0 -31
  227. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +0 -30
  228. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +0 -34
  229. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +0 -16
  230. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +0 -1
  231. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +0 -1
  232. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +0 -1
  233. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +0 -16
  234. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +0 -21
  235. sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +0 -1
  236. sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +0 -1
  237. sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +0 -1
  238. sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +0 -1
  239. sky/dashboard/out/_next/static/css/0748ce22df867032.css +0 -3
  240. /sky/dashboard/out/_next/static/{96_E2yl3QAiIJGOYCkSpB → 3nu-b8raeKRNABZ2d4GAG}/_ssgManifest.js +0 -0
  241. /sky/{utils/kubernetes → ssh_node_pools/deploy/tunnel}/cleanup-tunnel.sh +0 -0
  242. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/WHEEL +0 -0
  243. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/entry_points.txt +0 -0
  244. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/licenses/LICENSE +0 -0
  245. {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/top_level.txt +0 -0
sky/server/server.py CHANGED
@@ -15,12 +15,16 @@ import pathlib
15
15
  import posixpath
16
16
  import re
17
17
  import resource
18
+ import shlex
18
19
  import shutil
20
+ import socket
19
21
  import struct
20
22
  import sys
21
23
  import threading
22
24
  import traceback
23
- from typing import Dict, List, Literal, Optional, Set, Tuple
25
+ import typing
26
+ from typing import (Any, Awaitable, Callable, Dict, List, Literal, Optional,
27
+ Set, Tuple, Type)
24
28
  import uuid
25
29
  import zipfile
26
30
 
@@ -43,11 +47,13 @@ from sky import global_user_state
43
47
  from sky import models
44
48
  from sky import sky_logging
45
49
  from sky.data import storage_utils
50
+ from sky.jobs import state as managed_job_state
46
51
  from sky.jobs import utils as managed_job_utils
47
52
  from sky.jobs.server import server as jobs_rest
48
53
  from sky.metrics import utils as metrics_utils
49
54
  from sky.provision import metadata_utils
50
55
  from sky.provision.kubernetes import utils as kubernetes_utils
56
+ from sky.provision.slurm import utils as slurm_utils
51
57
  from sky.schemas.api import responses
52
58
  from sky.serve.server import server as serve_rest
53
59
  from sky.server import common
@@ -56,6 +62,8 @@ from sky.server import constants as server_constants
56
62
  from sky.server import daemons
57
63
  from sky.server import metrics
58
64
  from sky.server import middleware_utils
65
+ from sky.server import plugins
66
+ from sky.server import server_utils
59
67
  from sky.server import state
60
68
  from sky.server import stream_utils
61
69
  from sky.server import versions
@@ -73,6 +81,7 @@ from sky.usage import usage_lib
73
81
  from sky.users import permission
74
82
  from sky.users import server as users_rest
75
83
  from sky.utils import admin_policy_utils
84
+ from sky.utils import command_runner
76
85
  from sky.utils import common as common_lib
77
86
  from sky.utils import common_utils
78
87
  from sky.utils import context
@@ -80,6 +89,7 @@ from sky.utils import context_utils
80
89
  from sky.utils import controller_utils
81
90
  from sky.utils import dag_utils
82
91
  from sky.utils import env_options
92
+ from sky.utils import interactive_utils
83
93
  from sky.utils import perf_utils
84
94
  from sky.utils import status_lib
85
95
  from sky.utils import subprocess_utils
@@ -88,6 +98,9 @@ from sky.utils.db import db_utils
88
98
  from sky.volumes.server import server as volumes_rest
89
99
  from sky.workspaces import server as workspaces_rest
90
100
 
101
+ if typing.TYPE_CHECKING:
102
+ from sky import backends
103
+
91
104
  # pylint: disable=ungrouped-imports
92
105
  if sys.version_info >= (3, 10):
93
106
  from typing import ParamSpec
@@ -205,6 +218,10 @@ class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
205
218
  """Middleware to handle HTTP Basic Auth."""
206
219
 
207
220
  async def dispatch(self, request: fastapi.Request, call_next):
221
+ # If a previous middleware already authenticated the user, pass through
222
+ if request.state.auth_user is not None:
223
+ return await call_next(request)
224
+
208
225
  if managed_job_utils.is_consolidation_mode(
209
226
  ) and loopback.is_loopback_request(request):
210
227
  return await call_next(request)
@@ -272,6 +289,10 @@ class BearerTokenMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
272
289
  X-Skypilot-Auth-Mode header. The auth proxy should either validate the
273
290
  auth or set the header X-Skypilot-Auth-Mode: token.
274
291
  """
292
+ # If a previous middleware already authenticated the user, pass through
293
+ if request.state.auth_user is not None:
294
+ return await call_next(request)
295
+
275
296
  has_skypilot_auth_header = (
276
297
  request.headers.get('X-Skypilot-Auth-Mode') == 'token')
277
298
  auth_header = request.headers.get('authorization')
@@ -470,7 +491,8 @@ async def schedule_on_boot_check_async():
470
491
  await executor.schedule_request_async(
471
492
  request_id='skypilot-server-on-boot-check',
472
493
  request_name=request_names.RequestName.CHECK,
473
- request_body=payloads.CheckBody(),
494
+ request_body=server_utils.build_body_at_server(
495
+ request=None, body_type=payloads.CheckBody),
474
496
  func=sky_check.check,
475
497
  schedule_type=requests_lib.ScheduleType.SHORT,
476
498
  is_skypilot_system=True,
@@ -493,7 +515,8 @@ async def lifespan(app: fastapi.FastAPI): # pylint: disable=redefined-outer-nam
493
515
  await executor.schedule_request_async(
494
516
  request_id=event.id,
495
517
  request_name=event.name,
496
- request_body=payloads.RequestBody(),
518
+ request_body=server_utils.build_body_at_server(
519
+ request=None, body_type=payloads.RequestBody),
497
520
  func=event.run_event,
498
521
  schedule_type=requests_lib.ScheduleType.SHORT,
499
522
  is_skypilot_system=True,
@@ -652,6 +675,17 @@ app.add_middleware(BearerTokenMiddleware)
652
675
  # middleware above.
653
676
  app.add_middleware(InitializeRequestAuthUserMiddleware)
654
677
  app.add_middleware(RequestIDMiddleware)
678
+
679
+ # Load plugins after all the middlewares are added, to keep the core
680
+ # middleware stack intact if a plugin adds new middlewares.
681
+ # Note: server.py will be imported twice in server process, once as
682
+ # the top-level entrypoint module and once imported by uvicorn, we only
683
+ # load the plugin when imported by uvicorn for server process.
684
+ # TODO(aylei): move uvicorn app out of the top-level module to avoid
685
+ # duplicate app initialization.
686
+ if __name__ == 'sky.server.server':
687
+ plugins.load_plugins(plugins.ExtensionContext(app=app))
688
+
655
689
  app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
656
690
  app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
657
691
  app.include_router(users_rest.router, prefix='/users', tags=['users'])
@@ -746,8 +780,11 @@ async def enabled_clouds(request: fastapi.Request,
746
780
  await executor.schedule_request_async(
747
781
  request_id=request.state.request_id,
748
782
  request_name=request_names.RequestName.ENABLED_CLOUDS,
749
- request_body=payloads.EnabledCloudsBody(workspace=workspace,
750
- expand=expand),
783
+ request_body=server_utils.build_body_at_server(
784
+ request=request,
785
+ body_type=payloads.EnabledCloudsBody,
786
+ workspace=workspace,
787
+ expand=expand),
751
788
  func=core.enabled_clouds,
752
789
  schedule_type=requests_lib.ScheduleType.SHORT,
753
790
  )
@@ -784,6 +821,36 @@ async def kubernetes_node_info(
784
821
  )
785
822
 
786
823
 
824
+ @app.post('/slurm_gpu_availability')
825
+ async def slurm_gpu_availability(
826
+ request: fastapi.Request,
827
+ slurm_gpu_availability_body: payloads.SlurmGpuAvailabilityRequestBody
828
+ ) -> None:
829
+ """Gets real-time Slurm GPU availability."""
830
+ await executor.schedule_request_async(
831
+ request_id=request.state.request_id,
832
+ request_name=request_names.RequestName.REALTIME_SLURM_GPU_AVAILABILITY,
833
+ request_body=slurm_gpu_availability_body,
834
+ func=core.realtime_slurm_gpu_availability,
835
+ schedule_type=requests_lib.ScheduleType.SHORT,
836
+ )
837
+
838
+
839
+ # Keep the GET method for backwards compatibility
840
+ @app.api_route('/slurm_node_info', methods=['GET', 'POST'])
841
+ async def slurm_node_info(
842
+ request: fastapi.Request,
843
+ slurm_node_info_body: payloads.SlurmNodeInfoRequestBody) -> None:
844
+ """Gets detailed information for each node in the Slurm cluster."""
845
+ await executor.schedule_request_async(
846
+ request_id=request.state.request_id,
847
+ request_name=request_names.RequestName.SLURM_NODE_INFO,
848
+ request_body=slurm_node_info_body,
849
+ func=slurm_utils.slurm_node_info,
850
+ schedule_type=requests_lib.ScheduleType.SHORT,
851
+ )
852
+
853
+
787
854
  @app.get('/status_kubernetes')
788
855
  async def status_kubernetes(request: fastapi.Request) -> None:
789
856
  """[Experimental] Get all SkyPilot resources (including from other '
@@ -791,7 +858,8 @@ async def status_kubernetes(request: fastapi.Request) -> None:
791
858
  await executor.schedule_request_async(
792
859
  request_id=request.state.request_id,
793
860
  request_name=request_names.RequestName.STATUS_KUBERNETES,
794
- request_body=payloads.RequestBody(),
861
+ request_body=server_utils.build_body_at_server(
862
+ request=request, body_type=payloads.RequestBody),
795
863
  func=core.status_kubernetes,
796
864
  schedule_type=requests_lib.ScheduleType.SHORT,
797
865
  )
@@ -1454,13 +1522,29 @@ async def cost_report(request: fastapi.Request,
1454
1522
  )
1455
1523
 
1456
1524
 
1525
+ @app.post('/cluster_events')
1526
+ async def cluster_events(
1527
+ request: fastapi.Request,
1528
+ cluster_events_body: payloads.ClusterEventsBody) -> None:
1529
+ """Gets events for a cluster."""
1530
+ await executor.schedule_request_async(
1531
+ request_id=request.state.request_id,
1532
+ request_name=request_names.RequestName.CLUSTER_EVENTS,
1533
+ request_body=cluster_events_body,
1534
+ func=core.get_cluster_events,
1535
+ schedule_type=requests_lib.ScheduleType.SHORT,
1536
+ request_cluster_name=cluster_events_body.cluster_name or '',
1537
+ )
1538
+
1539
+
1457
1540
  @app.get('/storage/ls')
1458
1541
  async def storage_ls(request: fastapi.Request) -> None:
1459
1542
  """Gets the storages."""
1460
1543
  await executor.schedule_request_async(
1461
1544
  request_id=request.state.request_id,
1462
1545
  request_name=request_names.RequestName.STORAGE_LS,
1463
- request_body=payloads.RequestBody(),
1546
+ request_body=server_utils.build_body_at_server(
1547
+ request=request, body_type=payloads.RequestBody),
1464
1548
  func=core.storage_ls,
1465
1549
  schedule_type=requests_lib.ScheduleType.SHORT,
1466
1550
  )
@@ -1752,6 +1836,22 @@ async def api_status(
1752
1836
  return encoded_request_tasks
1753
1837
 
1754
1838
 
1839
+ @app.get('/api/plugins', response_class=fastapi_responses.ORJSONResponse)
1840
+ async def list_plugins() -> Dict[str, List[Dict[str, Any]]]:
1841
+ """Return metadata about loaded backend plugins."""
1842
+ plugin_infos = []
1843
+ for plugin_info in plugins.get_plugins():
1844
+ info = {
1845
+ 'js_extension_path': plugin_info.js_extension_path,
1846
+ }
1847
+ for attr in ('name', 'version', 'commit'):
1848
+ value = getattr(plugin_info, attr, None)
1849
+ if value is not None:
1850
+ info[attr] = value
1851
+ plugin_infos.append(info)
1852
+ return {'plugins': plugin_infos}
1853
+
1854
+
1755
1855
  @app.get(
1756
1856
  '/api/health',
1757
1857
  # response_model_exclude_unset omits unset fields
@@ -1823,12 +1923,149 @@ async def health(request: fastapi.Request) -> responses.APIHealthResponse:
1823
1923
  )
1824
1924
 
1825
1925
 
1826
- class KubernetesSSHMessageType(IntEnum):
1926
+ class SSHMessageType(IntEnum):
1827
1927
  REGULAR_DATA = 0
1828
1928
  PINGPONG = 1
1829
1929
  LATENCY_MEASUREMENT = 2
1830
1930
 
1831
1931
 
1932
+ async def _get_cluster_and_validate(
1933
+ cluster_name: str,
1934
+ cloud_type: Type[clouds.Cloud],
1935
+ ) -> 'backends.CloudVmRayResourceHandle':
1936
+ """Fetch cluster status and validate it's UP and correct cloud type."""
1937
+ # Run core.status in another thread to avoid blocking the event loop.
1938
+ # TODO(aylei): core.status() will be called with server user, which has
1939
+ # permission to all workspaces, this will break workspace isolation.
1940
+ # It is ok for now, as users with limited access will not get the ssh config
1941
+ # for the clusters in non-accessible workspaces.
1942
+ with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
1943
+ cluster_records = await context_utils.to_thread_with_executor(
1944
+ thread_pool_executor, core.status, cluster_name, all_users=True)
1945
+ cluster_record = cluster_records[0]
1946
+ if cluster_record['status'] != status_lib.ClusterStatus.UP:
1947
+ raise fastapi.HTTPException(
1948
+ status_code=400, detail=f'Cluster {cluster_name} is not running')
1949
+
1950
+ handle: Optional['backends.CloudVmRayResourceHandle'] = cluster_record[
1951
+ 'handle']
1952
+ assert handle is not None, 'Cluster handle is None'
1953
+ if not isinstance(handle.launched_resources.cloud, cloud_type):
1954
+ raise fastapi.HTTPException(
1955
+ status_code=400,
1956
+ detail=f'Cluster {cluster_name} is not a {str(cloud_type())} '
1957
+ 'cluster. Use ssh to connect to the cluster instead.')
1958
+
1959
+ return handle
1960
+
1961
+
1962
+ async def _run_websocket_proxy(
1963
+ websocket: fastapi.WebSocket,
1964
+ read_from_backend: Callable[[], Awaitable[bytes]],
1965
+ write_to_backend: Callable[[bytes], Awaitable[None]],
1966
+ close_backend: Callable[[], Awaitable[None]],
1967
+ timestamps_supported: bool,
1968
+ ) -> bool:
1969
+ """Run bidirectional WebSocket-to-backend proxy.
1970
+
1971
+ Args:
1972
+ websocket: FastAPI WebSocket connection
1973
+ read_from_backend: Async callable to read bytes from backend
1974
+ write_to_backend: Async callable to write bytes to backend
1975
+ close_backend: Async callable to close backend connection
1976
+ timestamps_supported: Whether to use message type framing
1977
+
1978
+ Returns:
1979
+ True if SSH failed, False otherwise
1980
+ """
1981
+ ssh_failed = False
1982
+ websocket_closed = False
1983
+
1984
+ async def websocket_to_backend():
1985
+ try:
1986
+ async for message in websocket.iter_bytes():
1987
+ if timestamps_supported:
1988
+ type_size = struct.calcsize('!B')
1989
+ message_type = struct.unpack('!B', message[:type_size])[0]
1990
+ if message_type == SSHMessageType.REGULAR_DATA:
1991
+ # Regular data - strip type byte and forward to backend
1992
+ message = message[type_size:]
1993
+ elif message_type == SSHMessageType.PINGPONG:
1994
+ # PING message - respond with PONG
1995
+ ping_id_size = struct.calcsize('!I')
1996
+ if len(message) != type_size + ping_id_size:
1997
+ raise ValueError(
1998
+ f'Invalid PING message length: {len(message)}')
1999
+ # Return the same PING message for latency measurement
2000
+ await websocket.send_bytes(message)
2001
+ continue
2002
+ elif message_type == SSHMessageType.LATENCY_MEASUREMENT:
2003
+ # Latency measurement from client
2004
+ latency_size = struct.calcsize('!Q')
2005
+ if len(message) != type_size + latency_size:
2006
+ raise ValueError('Invalid latency measurement '
2007
+ f'message length: {len(message)}')
2008
+ avg_latency_ms = struct.unpack(
2009
+ '!Q',
2010
+ message[type_size:type_size + latency_size])[0]
2011
+ latency_seconds = avg_latency_ms / 1000
2012
+ metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels( # pylint: disable=line-too-long
2013
+ pid=os.getpid()).observe(latency_seconds)
2014
+ continue
2015
+ else:
2016
+ raise ValueError(
2017
+ f'Unknown message type: {message_type}')
2018
+
2019
+ try:
2020
+ await write_to_backend(message)
2021
+ except Exception as e: # pylint: disable=broad-except
2022
+ # Typically we will not reach here, if the conn to backend
2023
+ # is disconnected, backend_to_websocket will exit first.
2024
+ # But just in case.
2025
+ logger.error(f'Failed to write to backend through '
2026
+ f'connection: {e}')
2027
+ nonlocal ssh_failed
2028
+ ssh_failed = True
2029
+ break
2030
+ except fastapi.WebSocketDisconnect:
2031
+ pass
2032
+ nonlocal websocket_closed
2033
+ websocket_closed = True
2034
+ await close_backend()
2035
+
2036
+ async def backend_to_websocket():
2037
+ try:
2038
+ while True:
2039
+ data = await read_from_backend()
2040
+ if not data:
2041
+ if not websocket_closed:
2042
+ logger.warning(
2043
+ 'SSH connection to backend is disconnected '
2044
+ 'before websocket connection is closed')
2045
+ nonlocal ssh_failed
2046
+ ssh_failed = True
2047
+ break
2048
+ if timestamps_supported:
2049
+ # Prepend message type byte (0 = regular data)
2050
+ message_type_bytes = struct.pack(
2051
+ '!B', SSHMessageType.REGULAR_DATA.value)
2052
+ data = message_type_bytes + data
2053
+ await websocket.send_bytes(data)
2054
+ except Exception: # pylint: disable=broad-except
2055
+ pass
2056
+ try:
2057
+ await websocket.close()
2058
+ except Exception: # pylint: disable=broad-except
2059
+ # The websocket might have been closed by the client
2060
+ pass
2061
+
2062
+ await asyncio.gather(websocket_to_backend(),
2063
+ backend_to_websocket(),
2064
+ return_exceptions=True)
2065
+
2066
+ return ssh_failed
2067
+
2068
+
1832
2069
  @app.websocket('/kubernetes-pod-ssh-proxy')
1833
2070
  async def kubernetes_pod_ssh_proxy(
1834
2071
  websocket: fastapi.WebSocket,
@@ -1842,22 +2079,7 @@ async def kubernetes_pod_ssh_proxy(
1842
2079
  logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
1843
2080
  client_version = {client_version}')
1844
2081
 
1845
- # Run core.status in another thread to avoid blocking the event loop.
1846
- with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
1847
- cluster_records = await context_utils.to_thread_with_executor(
1848
- thread_pool_executor, core.status, cluster_name, all_users=True)
1849
- cluster_record = cluster_records[0]
1850
- if cluster_record['status'] != status_lib.ClusterStatus.UP:
1851
- raise fastapi.HTTPException(
1852
- status_code=400, detail=f'Cluster {cluster_name} is not running')
1853
-
1854
- handle = cluster_record['handle']
1855
- assert handle is not None, 'Cluster handle is None'
1856
- if not isinstance(handle.launched_resources.cloud, clouds.Kubernetes):
1857
- raise fastapi.HTTPException(
1858
- status_code=400,
1859
- detail=f'Cluster {cluster_name} is not a Kubernetes cluster'
1860
- 'Use ssh to connect to the cluster instead.')
2082
+ handle = await _get_cluster_and_validate(cluster_name, clouds.Kubernetes)
1861
2083
 
1862
2084
  kubectl_cmd = handle.get_command_runners()[0].port_forward_command(
1863
2085
  port_forward=[(None, 22)])
@@ -1887,96 +2109,25 @@ async def kubernetes_pod_ssh_proxy(
1887
2109
  conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
1888
2110
  pid=os.getpid())
1889
2111
  ssh_failed = False
1890
- websocket_closed = False
1891
2112
  try:
1892
2113
  conn_gauge.inc()
1893
2114
  # Connect to the local port
1894
2115
  reader, writer = await asyncio.open_connection('127.0.0.1', local_port)
1895
2116
 
1896
- async def websocket_to_ssh():
1897
- try:
1898
- async for message in websocket.iter_bytes():
1899
- if timestamps_supported:
1900
- type_size = struct.calcsize('!B')
1901
- message_type = struct.unpack('!B',
1902
- message[:type_size])[0]
1903
- if (message_type ==
1904
- KubernetesSSHMessageType.REGULAR_DATA):
1905
- # Regular data - strip type byte and forward to SSH
1906
- message = message[type_size:]
1907
- elif message_type == KubernetesSSHMessageType.PINGPONG:
1908
- # PING message - respond with PONG (type 1)
1909
- ping_id_size = struct.calcsize('!I')
1910
- if len(message) != type_size + ping_id_size:
1911
- raise ValueError('Invalid PING message '
1912
- f'length: {len(message)}')
1913
- # Return the same PING message, so that the client
1914
- # can measure the latency.
1915
- await websocket.send_bytes(message)
1916
- continue
1917
- elif (message_type ==
1918
- KubernetesSSHMessageType.LATENCY_MEASUREMENT):
1919
- # Latency measurement from client
1920
- latency_size = struct.calcsize('!Q')
1921
- if len(message) != type_size + latency_size:
1922
- raise ValueError(
1923
- 'Invalid latency measurement '
1924
- f'message length: {len(message)}')
1925
- avg_latency_ms = struct.unpack(
1926
- '!Q',
1927
- message[type_size:type_size + latency_size])[0]
1928
- latency_seconds = avg_latency_ms / 1000
1929
- metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(pid=os.getpid()).observe(latency_seconds) # pylint: disable=line-too-long
1930
- continue
1931
- else:
1932
- # Unknown message type.
1933
- raise ValueError(
1934
- f'Unknown message type: {message_type}')
1935
- writer.write(message)
1936
- try:
1937
- await writer.drain()
1938
- except Exception as e: # pylint: disable=broad-except
1939
- # Typically we will not reach here, if the ssh to pod
1940
- # is disconnected, ssh_to_websocket will exit first.
1941
- # But just in case.
1942
- logger.error('Failed to write to pod through '
1943
- f'port-forward connection: {e}')
1944
- nonlocal ssh_failed
1945
- ssh_failed = True
1946
- break
1947
- except fastapi.WebSocketDisconnect:
1948
- pass
1949
- nonlocal websocket_closed
1950
- websocket_closed = True
1951
- writer.close()
2117
+ async def write_and_drain(data: bytes) -> None:
2118
+ writer.write(data)
2119
+ await writer.drain()
1952
2120
 
1953
- async def ssh_to_websocket():
1954
- try:
1955
- while True:
1956
- data = await reader.read(1024)
1957
- if not data:
1958
- if not websocket_closed:
1959
- logger.warning('SSH connection to pod is '
1960
- 'disconnected before websocket '
1961
- 'connection is closed')
1962
- nonlocal ssh_failed
1963
- ssh_failed = True
1964
- break
1965
- if timestamps_supported:
1966
- # Prepend message type byte (0 = regular data)
1967
- message_type_bytes = struct.pack(
1968
- '!B', KubernetesSSHMessageType.REGULAR_DATA.value)
1969
- data = message_type_bytes + data
1970
- await websocket.send_bytes(data)
1971
- except Exception: # pylint: disable=broad-except
1972
- pass
1973
- try:
1974
- await websocket.close()
1975
- except Exception: # pylint: disable=broad-except
1976
- # The websocket might has been closed by the client.
1977
- pass
2121
+ async def close_writer() -> None:
2122
+ writer.close()
1978
2123
 
1979
- await asyncio.gather(websocket_to_ssh(), ssh_to_websocket())
2124
+ ssh_failed = await _run_websocket_proxy(
2125
+ websocket,
2126
+ read_from_backend=lambda: reader.read(1024),
2127
+ write_to_backend=write_and_drain,
2128
+ close_backend=close_writer,
2129
+ timestamps_supported=timestamps_supported,
2130
+ )
1980
2131
  finally:
1981
2132
  conn_gauge.dec()
1982
2133
  reason = ''
@@ -1990,7 +2141,7 @@ async def kubernetes_pod_ssh_proxy(
1990
2141
  f'output: {str(stdout)}')
1991
2142
  reason = 'KubectlPortForwardExit'
1992
2143
  metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
1993
- pid=os.getpid(), reason='KubectlPortForwardExit').inc()
2144
+ pid=os.getpid(), reason=reason).inc()
1994
2145
  else:
1995
2146
  if ssh_failed:
1996
2147
  reason = 'SSHToPodDisconnected'
@@ -2000,6 +2151,235 @@ async def kubernetes_pod_ssh_proxy(
2000
2151
  pid=os.getpid(), reason=reason).inc()
2001
2152
 
2002
2153
 
2154
@app.websocket('/slurm-job-ssh-proxy')
async def slurm_job_ssh_proxy(websocket: fastapi.WebSocket,
                              cluster_name: str,
                              worker: int = 0,
                              client_version: Optional[int] = None) -> None:
    """Proxies SSH to the Slurm job via sshd inside srun.

    An ssh command to the cluster's login node is spawned as a subprocess; it
    runs the command produced by `slurm_utils.srun_sshd_command` for the
    selected node, and the subprocess's stdin/stdout are bridged with the
    websocket through `_run_websocket_proxy`.

    Args:
        websocket: The client websocket connection.
        cluster_name: Name of the cluster; validated to be a Slurm cluster.
        worker: Zero-based index into the cluster's instances (0 selects the
            first instance). Must satisfy 0 <= worker < number of nodes.
        client_version: Client API version; versions above 21 use the
            type-prefixed (timestamp-capable) message framing.

    Raises:
        fastapi.HTTPException: If `worker` is outside the valid range.
    """
    await websocket.accept()
    logger.info(f'WebSocket connection accepted for cluster: '
                f'{cluster_name}')

    timestamps_supported = client_version is not None and client_version > 21
    # Implicit concatenation instead of a backslash continuation inside the
    # literal, which would embed the source indentation in the log message.
    logger.info(f'Websocket timestamps supported: {timestamps_supported}, '
                f'client_version = {client_version}')

    handle = await _get_cluster_and_validate(cluster_name, clouds.Slurm)

    assert handle.cached_cluster_info is not None, 'Cached cluster info is None'
    cluster_info = handle.cached_cluster_info
    provider_config = cluster_info.provider_config
    assert provider_config is not None, 'Provider config is None'
    login_node_ssh_config = provider_config['ssh']
    login_node_user = login_node_ssh_config['user']

    login_node_runner = command_runner.SSHCommandRunner(
        (login_node_ssh_config['hostname'], int(login_node_ssh_config['port'])),
        login_node_user,
        login_node_ssh_config['private_key'],
        ssh_proxy_command=login_node_ssh_config.get('proxycommand', None),
        ssh_proxy_jump=login_node_ssh_config.get('proxyjump', None),
    )

    ssh_cmd = login_node_runner.ssh_base_command(
        ssh_mode=command_runner.SshMode.NON_INTERACTIVE,
        port_forward=None,
        connect_timeout=None)

    # There can only be one InstanceInfo per instance_id.
    head_instance = cluster_info.get_head_instance()
    assert head_instance is not None, 'Head instance is None'
    job_id = head_instance.tags['job_id']

    # Instances are ordered: head first, then workers
    node_hostnames = [
        inst[0].tags['node'] for inst in cluster_info.instances.values()
    ]
    # Reject negative indices as well: Python would otherwise silently wrap
    # them around and connect to a node counted from the end of the list.
    if not 0 <= worker < len(node_hostnames):
        raise fastapi.HTTPException(
            status_code=400,
            detail=f'Worker index {worker} out of range. '
            f'Cluster has {len(node_hostnames)} nodes.')
    target_node = node_hostnames[worker]

    # Run sshd inside the Slurm job "container" via srun, such that it inherits
    # the resource constraints of the Slurm job.
    ssh_cmd += [
        shlex.quote(
            slurm_utils.srun_sshd_command(job_id, target_node, login_node_user))
    ]

    proc = await asyncio.create_subprocess_shell(
        ' '.join(ssh_cmd),
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,  # Capture stderr separately for logging
    )
    assert proc.stdin is not None
    assert proc.stdout is not None
    assert proc.stderr is not None

    stdin = proc.stdin
    stdout = proc.stdout
    stderr = proc.stderr

    async def log_stderr() -> None:
        """Drains the subprocess's stderr, logging each line at debug level."""
        while True:
            line = await stderr.readline()
            if not line:
                break
            # errors='replace' so undecodable bytes cannot kill the drain loop.
            logger.debug(
                f'srun stderr: {line.decode(errors="replace").rstrip()}')

    conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
        pid=os.getpid())
    ssh_failed = False
    # Always drain stderr, not only when debug output is enabled: an unread
    # PIPE eventually fills up and blocks the child process, which would stall
    # the proxied SSH session. Lines are only emitted at debug level.
    stderr_task = asyncio.create_task(log_stderr())
    try:
        conn_gauge.inc()

        async def write_and_drain(data: bytes) -> None:
            stdin.write(data)
            await stdin.drain()

        async def close_stdin() -> None:
            stdin.close()

        ssh_failed = await _run_websocket_proxy(
            websocket,
            read_from_backend=lambda: stdout.read(4096),
            write_to_backend=write_and_drain,
            close_backend=close_stdin,
            timestamps_supported=timestamps_supported,
        )

    finally:
        conn_gauge.dec()
        try:
            logger.info('Terminating srun process')
            proc.terminate()
        except ProcessLookupError:
            # The process was already gone when we tried to terminate it.
            stdout_data = await stdout.read()
            logger.error('srun process was terminated before the '
                         'ssh websocket connection was closed. Remaining '
                         f'output: {str(stdout_data)}')
            reason = 'SrunProcessExit'
        else:
            if ssh_failed:
                reason = 'SSHToSlurmJobDisconnected'
            else:
                reason = 'ClientClosed'
        # Exactly one close reason is recorded per connection.
        metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
            pid=os.getpid(), reason=reason).inc()

        # Cancel the stderr logging task if it's still running
        if not stderr_task.done():
            stderr_task.cancel()
            try:
                await stderr_task
            except asyncio.CancelledError:
                pass
2291
+
2292
+
2293
@app.websocket('/ssh-interactive-auth')
async def ssh_interactive_auth(websocket: fastapi.WebSocket,
                               session_id: str) -> None:
    """Proxies PTY for SSH interactive authentication via websocket.

    This endpoint receives a PTY file descriptor from a worker process
    (over the Unix socket at `interactive_utils.get_pty_socket_path`) and
    bridges it bidirectionally with a websocket connection, allowing the
    client to handle interactive SSH authentication (e.g., 2FA).

    The bridge ends when reading the PTY returns EOF or fails; the websocket
    is then closed, which also ends the websocket -> PTY direction.

    NOTE(review): no auth-completion detection (e.g. terminal echo-state
    monitoring) happens in this handler; presumably the owner of the PTY
    closes it once authentication finishes - confirm. If the client
    disconnects first, the handler keeps waiting on the PTY read until the
    PTY produces EOF or an error.

    Args:
        websocket: The client websocket connection.
        session_id: Identifies the auth session; used to locate the worker's
            FD-passing socket.
    """
    await websocket.accept()
    logger.info(f'WebSocket connection accepted for SSH auth session: '
                f'{session_id}')

    loop = asyncio.get_running_loop()

    # Connect to worker process to receive PTY file descriptor
    fd_socket_path = interactive_utils.get_pty_socket_path(session_id)
    fd_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    master_fd = -1
    try:
        # Connect to worker's FD-passing socket. fd_sock is a blocking socket
        # (recv_fd below uses it from a worker thread), so the connect is run
        # in the executor too: loop.sock_connect() requires a non-blocking
        # socket and raises ValueError in asyncio debug mode otherwise.
        await loop.run_in_executor(None, fd_sock.connect, fd_socket_path)
        master_fd = await loop.run_in_executor(None, interactive_utils.recv_fd,
                                               fd_sock)
        logger.debug(f'Received PTY master fd {master_fd} for session '
                     f'{session_id}')

        def write_all(data: bytes) -> None:
            """Writes all of `data` to the PTY (blocking; run in executor)."""
            # os.write may write fewer bytes than requested; keep going until
            # the whole message has been delivered.
            remaining = memoryview(data)
            while remaining:
                written = os.write(master_fd, remaining)
                remaining = remaining[written:]

        # Bridge PTY <-> websocket bidirectionally
        async def websocket_to_pty() -> None:
            """Forward websocket messages to PTY."""
            try:
                async for message in websocket.iter_bytes():
                    await loop.run_in_executor(None, write_all, message)
            except fastapi.WebSocketDisconnect:
                logger.debug(f'WebSocket disconnected for session {session_id}')
            except asyncio.CancelledError:
                pass
            except Exception as e:  # pylint: disable=broad-except
                logger.error(f'Error in websocket_to_pty: {e}')

        async def pty_to_websocket() -> None:
            """Forward PTY output to the websocket until the PTY is closed.

            Stops on EOF or on an OSError from the read, then closes the
            websocket (best effort).
            """
            try:
                while True:
                    try:
                        data = await loop.run_in_executor(
                            None, os.read, master_fd, 4096)
                    except OSError as e:
                        logger.error(f'PTY read error (likely closed): {e}')
                        break

                    if not data:
                        break

                    await websocket.send_bytes(data)
            except asyncio.CancelledError:
                pass
            except Exception as e:  # pylint: disable=broad-except
                logger.error(f'Error in pty_to_websocket: {e}')
            finally:
                try:
                    await websocket.close()
                except Exception:  # pylint: disable=broad-except
                    # The websocket might have been closed by the client.
                    pass

        await asyncio.gather(websocket_to_pty(), pty_to_websocket())

    except Exception as e:  # pylint: disable=broad-except
        logger.error(f'Error in SSH interactive auth websocket: {e}')
        raise
    finally:
        # Clean up
        if master_fd >= 0:
            try:
                os.close(master_fd)
            except OSError:
                pass
        fd_sock.close()
        logger.debug(f'SSH interactive auth session {session_id} completed')
2381
+
2382
+
2003
2383
  @app.get('/all_contexts')
2004
2384
  async def all_contexts(request: fastapi.Request) -> None:
2005
2385
  """Gets all Kubernetes and SSH node pool contexts."""
@@ -2007,7 +2387,8 @@ async def all_contexts(request: fastapi.Request) -> None:
2007
2387
  await executor.schedule_request_async(
2008
2388
  request_id=request.state.request_id,
2009
2389
  request_name=request_names.RequestName.ALL_CONTEXTS,
2010
- request_body=payloads.RequestBody(),
2390
+ request_body=server_utils.build_body_at_server(
2391
+ request=request, body_type=payloads.RequestBody),
2011
2392
  func=core.get_all_contexts,
2012
2393
  schedule_type=requests_lib.ScheduleType.SHORT,
2013
2394
  )
@@ -2057,6 +2438,14 @@ async def serve_dashboard(full_path: str):
2057
2438
  if os.path.isfile(file_path):
2058
2439
  return fastapi.responses.FileResponse(file_path)
2059
2440
 
2441
+ # Serve plugin catch-all page for any /plugins/* paths so client-side
2442
+ # routing can bootstrap correctly.
2443
+ if full_path == 'plugins' or full_path.startswith('plugins/'):
2444
+ plugin_catchall = os.path.join(server_constants.DASHBOARD_DIR,
2445
+ 'plugins', '[...slug].html')
2446
+ if os.path.isfile(plugin_catchall):
2447
+ return fastapi.responses.FileResponse(plugin_catchall)
2448
+
2060
2449
  # Serve index.html for client-side routing
2061
2450
  # e.g. /clusters, /jobs
2062
2451
  index_path = os.path.join(server_constants.DASHBOARD_DIR, 'index.html')
@@ -2161,6 +2550,9 @@ if __name__ == '__main__':
2161
2550
  # Restore the server user hash
2162
2551
  logger.info('Initializing server user hash')
2163
2552
  _init_or_restore_server_user_hash()
2553
+ logger.info('Initializing permission service')
2554
+ permission.permission_service.initialize()
2555
+ logger.info('Permission service initialized')
2164
2556
 
2165
2557
  max_db_connections = global_user_state.get_max_db_connections()
2166
2558
  logger.info(f'Max db connections: {max_db_connections}')
@@ -2197,6 +2589,9 @@ if __name__ == '__main__':
2197
2589
  global_tasks.append(
2198
2590
  background.create_task(
2199
2591
  global_user_state.cluster_event_retention_daemon()))
2592
+ global_tasks.append(
2593
+ background.create_task(
2594
+ managed_job_state.job_event_retention_daemon()))
2200
2595
  threading.Thread(target=background.run_forever, daemon=True).start()
2201
2596
 
2202
2597
  queue_server, workers = executor.start(config)
@@ -2220,6 +2615,8 @@ if __name__ == '__main__':
2220
2615
 
2221
2616
  for gt in global_tasks:
2222
2617
  gt.cancel()
2618
+ for plugin in plugins.get_plugins():
2619
+ plugin.shutdown()
2223
2620
  subprocess_utils.run_in_parallel(lambda worker: worker.cancel(),
2224
2621
  workers,
2225
2622
  num_threads=len(workers))