skypilot-nightly 1.0.0.dev20251210__py3-none-any.whl → 1.0.0.dev20260112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207):
  1. sky/__init__.py +4 -2
  2. sky/adaptors/slurm.py +159 -72
  3. sky/backends/backend_utils.py +52 -10
  4. sky/backends/cloud_vm_ray_backend.py +192 -32
  5. sky/backends/task_codegen.py +40 -2
  6. sky/catalog/data_fetchers/fetch_gcp.py +9 -1
  7. sky/catalog/data_fetchers/fetch_nebius.py +1 -1
  8. sky/catalog/data_fetchers/fetch_vast.py +4 -2
  9. sky/catalog/seeweb_catalog.py +30 -15
  10. sky/catalog/shadeform_catalog.py +5 -2
  11. sky/catalog/slurm_catalog.py +0 -7
  12. sky/catalog/vast_catalog.py +30 -6
  13. sky/check.py +11 -8
  14. sky/client/cli/command.py +106 -54
  15. sky/client/interactive_utils.py +190 -0
  16. sky/client/sdk.py +8 -0
  17. sky/client/sdk_async.py +9 -0
  18. sky/clouds/aws.py +60 -2
  19. sky/clouds/azure.py +2 -0
  20. sky/clouds/kubernetes.py +2 -0
  21. sky/clouds/runpod.py +38 -7
  22. sky/clouds/slurm.py +44 -12
  23. sky/clouds/ssh.py +1 -1
  24. sky/clouds/vast.py +30 -17
  25. sky/core.py +69 -1
  26. sky/dashboard/out/404.html +1 -1
  27. sky/dashboard/out/_next/static/3nu-b8raeKRNABZ2d4GAG/_buildManifest.js +1 -0
  28. sky/dashboard/out/_next/static/chunks/1871-0565f8975a7dcd10.js +6 -0
  29. sky/dashboard/out/_next/static/chunks/2109-55a1546d793574a7.js +11 -0
  30. sky/dashboard/out/_next/static/chunks/2521-099b07cd9e4745bf.js +26 -0
  31. sky/dashboard/out/_next/static/chunks/2755.a636e04a928a700e.js +31 -0
  32. sky/dashboard/out/_next/static/chunks/3495.05eab4862217c1a5.js +6 -0
  33. sky/dashboard/out/_next/static/chunks/3785.cfc5dcc9434fd98c.js +1 -0
  34. sky/dashboard/out/_next/static/chunks/3981.645d01bf9c8cad0c.js +21 -0
  35. sky/dashboard/out/_next/static/chunks/4083-0115d67c1fb57d6c.js +21 -0
  36. sky/dashboard/out/_next/static/chunks/{8640.5b9475a2d18c5416.js → 429.a58e9ba9742309ed.js} +2 -2
  37. sky/dashboard/out/_next/static/chunks/4555.8e221537181b5dc1.js +6 -0
  38. sky/dashboard/out/_next/static/chunks/4725.937865b81fdaaebb.js +6 -0
  39. sky/dashboard/out/_next/static/chunks/6082-edabd8f6092300ce.js +25 -0
  40. sky/dashboard/out/_next/static/chunks/6989-49cb7dca83a7a62d.js +1 -0
  41. sky/dashboard/out/_next/static/chunks/6990-630bd2a2257275f8.js +1 -0
  42. sky/dashboard/out/_next/static/chunks/7248-a99800d4db8edabd.js +1 -0
  43. sky/dashboard/out/_next/static/chunks/754-cfc5d4ad1b843d29.js +18 -0
  44. sky/dashboard/out/_next/static/chunks/8050-dd8aa107b17dce00.js +16 -0
  45. sky/dashboard/out/_next/static/chunks/8056-d4ae1e0cb81e7368.js +1 -0
  46. sky/dashboard/out/_next/static/chunks/8555.011023e296c127b3.js +6 -0
  47. sky/dashboard/out/_next/static/chunks/8821-93c25df904a8362b.js +1 -0
  48. sky/dashboard/out/_next/static/chunks/8969-0662594b69432ade.js +1 -0
  49. sky/dashboard/out/_next/static/chunks/9025.f15c91c97d124a5f.js +6 -0
  50. sky/dashboard/out/_next/static/chunks/{9353-8369df1cf105221c.js → 9353-7ad6bd01858556f1.js} +1 -1
  51. sky/dashboard/out/_next/static/chunks/pages/_app-5a86569acad99764.js +34 -0
  52. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8297476714acb4ac.js +6 -0
  53. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-337c3ba1085f1210.js +1 -0
  54. sky/dashboard/out/_next/static/chunks/pages/{clusters-9e5d47818b9bdadd.js → clusters-57632ff3684a8b5c.js} +1 -1
  55. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-5fd3a453c079c2ea.js +1 -0
  56. sky/dashboard/out/_next/static/chunks/pages/infra-9f85c02c9c6cae9e.js +1 -0
  57. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-90f16972cbecf354.js +1 -0
  58. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-2dd42fc37aad427a.js +16 -0
  59. sky/dashboard/out/_next/static/chunks/pages/jobs-ed806aeace26b972.js +1 -0
  60. sky/dashboard/out/_next/static/chunks/pages/users-bec34706b36f3524.js +1 -0
  61. sky/dashboard/out/_next/static/chunks/pages/{volumes-ef19d49c6d0e8500.js → volumes-a83ba9b38dff7ea9.js} +1 -1
  62. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-96e0f298308da7e2.js → [name]-c781e9c3e52ef9fc.js} +1 -1
  63. sky/dashboard/out/_next/static/chunks/pages/workspaces-91e0942f47310aae.js +1 -0
  64. sky/dashboard/out/_next/static/chunks/webpack-cfe59cf684ee13b9.js +1 -0
  65. sky/dashboard/out/_next/static/css/b0dbca28f027cc19.css +3 -0
  66. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  67. sky/dashboard/out/clusters/[cluster].html +1 -1
  68. sky/dashboard/out/clusters.html +1 -1
  69. sky/dashboard/out/config.html +1 -1
  70. sky/dashboard/out/index.html +1 -1
  71. sky/dashboard/out/infra/[context].html +1 -1
  72. sky/dashboard/out/infra.html +1 -1
  73. sky/dashboard/out/jobs/[job].html +1 -1
  74. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  75. sky/dashboard/out/jobs.html +1 -1
  76. sky/dashboard/out/plugins/[...slug].html +1 -1
  77. sky/dashboard/out/users.html +1 -1
  78. sky/dashboard/out/volumes.html +1 -1
  79. sky/dashboard/out/workspace/new.html +1 -1
  80. sky/dashboard/out/workspaces/[name].html +1 -1
  81. sky/dashboard/out/workspaces.html +1 -1
  82. sky/data/data_utils.py +26 -12
  83. sky/data/mounting_utils.py +29 -4
  84. sky/global_user_state.py +108 -16
  85. sky/jobs/client/sdk.py +8 -3
  86. sky/jobs/controller.py +191 -31
  87. sky/jobs/recovery_strategy.py +109 -11
  88. sky/jobs/server/core.py +81 -4
  89. sky/jobs/server/server.py +14 -0
  90. sky/jobs/state.py +417 -19
  91. sky/jobs/utils.py +73 -80
  92. sky/models.py +9 -0
  93. sky/optimizer.py +2 -1
  94. sky/provision/__init__.py +11 -9
  95. sky/provision/kubernetes/utils.py +122 -15
  96. sky/provision/kubernetes/volume.py +52 -17
  97. sky/provision/provisioner.py +2 -1
  98. sky/provision/runpod/instance.py +3 -1
  99. sky/provision/runpod/utils.py +13 -1
  100. sky/provision/runpod/volume.py +25 -9
  101. sky/provision/slurm/instance.py +75 -29
  102. sky/provision/slurm/utils.py +213 -107
  103. sky/provision/vast/utils.py +1 -0
  104. sky/resources.py +135 -13
  105. sky/schemas/api/responses.py +4 -0
  106. sky/schemas/db/global_user_state/010_save_ssh_key.py +1 -1
  107. sky/schemas/db/spot_jobs/008_add_full_resources.py +34 -0
  108. sky/schemas/db/spot_jobs/009_job_events.py +32 -0
  109. sky/schemas/db/spot_jobs/010_job_events_timestamp_with_timezone.py +43 -0
  110. sky/schemas/db/spot_jobs/011_add_links.py +34 -0
  111. sky/schemas/generated/jobsv1_pb2.py +9 -5
  112. sky/schemas/generated/jobsv1_pb2.pyi +12 -0
  113. sky/schemas/generated/jobsv1_pb2_grpc.py +44 -0
  114. sky/schemas/generated/managed_jobsv1_pb2.py +32 -28
  115. sky/schemas/generated/managed_jobsv1_pb2.pyi +11 -2
  116. sky/serve/serve_utils.py +232 -40
  117. sky/server/common.py +17 -0
  118. sky/server/constants.py +1 -1
  119. sky/server/metrics.py +6 -3
  120. sky/server/plugins.py +16 -0
  121. sky/server/requests/payloads.py +18 -0
  122. sky/server/requests/request_names.py +2 -0
  123. sky/server/requests/requests.py +28 -10
  124. sky/server/requests/serializers/encoders.py +5 -0
  125. sky/server/requests/serializers/return_value_serializers.py +14 -4
  126. sky/server/server.py +434 -107
  127. sky/server/uvicorn.py +5 -0
  128. sky/setup_files/MANIFEST.in +1 -0
  129. sky/setup_files/dependencies.py +21 -10
  130. sky/sky_logging.py +2 -1
  131. sky/skylet/constants.py +22 -5
  132. sky/skylet/executor/slurm.py +4 -6
  133. sky/skylet/job_lib.py +89 -4
  134. sky/skylet/services.py +18 -3
  135. sky/ssh_node_pools/deploy/tunnel/cleanup-tunnel.sh +62 -0
  136. sky/ssh_node_pools/deploy/tunnel/ssh-tunnel.sh +379 -0
  137. sky/templates/kubernetes-ray.yml.j2 +4 -6
  138. sky/templates/slurm-ray.yml.j2 +32 -2
  139. sky/templates/websocket_proxy.py +18 -41
  140. sky/users/permission.py +61 -51
  141. sky/utils/auth_utils.py +42 -0
  142. sky/utils/cli_utils/status_utils.py +19 -5
  143. sky/utils/cluster_utils.py +10 -3
  144. sky/utils/command_runner.py +256 -94
  145. sky/utils/command_runner.pyi +16 -0
  146. sky/utils/common_utils.py +30 -29
  147. sky/utils/context.py +32 -0
  148. sky/utils/db/db_utils.py +36 -6
  149. sky/utils/db/migration_utils.py +41 -21
  150. sky/utils/infra_utils.py +5 -1
  151. sky/utils/instance_links.py +139 -0
  152. sky/utils/interactive_utils.py +49 -0
  153. sky/utils/kubernetes/generate_kubeconfig.sh +42 -33
  154. sky/utils/kubernetes/rsync_helper.sh +5 -1
  155. sky/utils/plugin_extensions/__init__.py +14 -0
  156. sky/utils/plugin_extensions/external_failure_source.py +176 -0
  157. sky/utils/resources_utils.py +10 -8
  158. sky/utils/rich_utils.py +9 -11
  159. sky/utils/schemas.py +63 -20
  160. sky/utils/status_lib.py +7 -0
  161. sky/utils/subprocess_utils.py +17 -0
  162. sky/volumes/client/sdk.py +6 -3
  163. sky/volumes/server/core.py +65 -27
  164. sky_templates/ray/start_cluster +8 -4
  165. {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/METADATA +53 -57
  166. {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/RECORD +172 -162
  167. sky/dashboard/out/_next/static/KYAhEFa3FTfq4JyKVgo-s/_buildManifest.js +0 -1
  168. sky/dashboard/out/_next/static/chunks/1141-9c810f01ff4f398a.js +0 -11
  169. sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +0 -6
  170. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +0 -1
  171. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +0 -1
  172. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +0 -15
  173. sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +0 -26
  174. sky/dashboard/out/_next/static/chunks/3294.ddda8c6c6f9f24dc.js +0 -1
  175. sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +0 -1
  176. sky/dashboard/out/_next/static/chunks/3800-b589397dc09c5b4e.js +0 -1
  177. sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +0 -1
  178. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +0 -15
  179. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +0 -13
  180. sky/dashboard/out/_next/static/chunks/6856-da20c5fd999f319c.js +0 -1
  181. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +0 -1
  182. sky/dashboard/out/_next/static/chunks/6990-09cbf02d3cd518c3.js +0 -1
  183. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +0 -30
  184. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +0 -41
  185. sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +0 -1
  186. sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +0 -1
  187. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +0 -6
  188. sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +0 -31
  189. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +0 -30
  190. sky/dashboard/out/_next/static/chunks/pages/_app-68b647e26f9d2793.js +0 -34
  191. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-33f525539665fdfd.js +0 -16
  192. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-a7565f586ef86467.js +0 -1
  193. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-12c559ec4d81fdbd.js +0 -1
  194. sky/dashboard/out/_next/static/chunks/pages/infra-d187cd0413d72475.js +0 -1
  195. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-895847b6cf200b04.js +0 -16
  196. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-8d0f4655400b4eb9.js +0 -21
  197. sky/dashboard/out/_next/static/chunks/pages/jobs-e5a98f17f8513a96.js +0 -1
  198. sky/dashboard/out/_next/static/chunks/pages/users-2f7646eb77785a2c.js +0 -1
  199. sky/dashboard/out/_next/static/chunks/pages/workspaces-cb4da3abe08ebf19.js +0 -1
  200. sky/dashboard/out/_next/static/chunks/webpack-fba3de387ff6bb08.js +0 -1
  201. sky/dashboard/out/_next/static/css/c5a4cfd2600fc715.css +0 -3
  202. /sky/dashboard/out/_next/static/{KYAhEFa3FTfq4JyKVgo-s → 3nu-b8raeKRNABZ2d4GAG}/_ssgManifest.js +0 -0
  203. /sky/dashboard/out/_next/static/chunks/pages/plugins/{[...slug]-4f46050ca065d8f8.js → [...slug]-449a9f5a3bb20fb3.js} +0 -0
  204. {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/WHEEL +0 -0
  205. {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/entry_points.txt +0 -0
  206. {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/licenses/LICENSE +0 -0
  207. {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/top_level.txt +0 -0
sky/server/server.py CHANGED
@@ -15,12 +15,16 @@ import pathlib
15
15
  import posixpath
16
16
  import re
17
17
  import resource
18
+ import shlex
18
19
  import shutil
20
+ import socket
19
21
  import struct
20
22
  import sys
21
23
  import threading
22
24
  import traceback
23
- from typing import Any, Dict, List, Literal, Optional, Set, Tuple
25
+ import typing
26
+ from typing import (Any, Awaitable, Callable, Dict, List, Literal, Optional,
27
+ Set, Tuple, Type)
24
28
  import uuid
25
29
  import zipfile
26
30
 
@@ -43,6 +47,7 @@ from sky import global_user_state
43
47
  from sky import models
44
48
  from sky import sky_logging
45
49
  from sky.data import storage_utils
50
+ from sky.jobs import state as managed_job_state
46
51
  from sky.jobs import utils as managed_job_utils
47
52
  from sky.jobs.server import server as jobs_rest
48
53
  from sky.metrics import utils as metrics_utils
@@ -76,6 +81,7 @@ from sky.usage import usage_lib
76
81
  from sky.users import permission
77
82
  from sky.users import server as users_rest
78
83
  from sky.utils import admin_policy_utils
84
+ from sky.utils import command_runner
79
85
  from sky.utils import common as common_lib
80
86
  from sky.utils import common_utils
81
87
  from sky.utils import context
@@ -83,6 +89,7 @@ from sky.utils import context_utils
83
89
  from sky.utils import controller_utils
84
90
  from sky.utils import dag_utils
85
91
  from sky.utils import env_options
92
+ from sky.utils import interactive_utils
86
93
  from sky.utils import perf_utils
87
94
  from sky.utils import status_lib
88
95
  from sky.utils import subprocess_utils
@@ -91,6 +98,9 @@ from sky.utils.db import db_utils
91
98
  from sky.volumes.server import server as volumes_rest
92
99
  from sky.workspaces import server as workspaces_rest
93
100
 
101
+ if typing.TYPE_CHECKING:
102
+ from sky import backends
103
+
94
104
  # pylint: disable=ungrouped-imports
95
105
  if sys.version_info >= (3, 10):
96
106
  from typing import ParamSpec
@@ -208,6 +218,10 @@ class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
208
218
  """Middleware to handle HTTP Basic Auth."""
209
219
 
210
220
  async def dispatch(self, request: fastapi.Request, call_next):
221
+ # If a previous middleware already authenticated the user, pass through
222
+ if request.state.auth_user is not None:
223
+ return await call_next(request)
224
+
211
225
  if managed_job_utils.is_consolidation_mode(
212
226
  ) and loopback.is_loopback_request(request):
213
227
  return await call_next(request)
@@ -275,6 +289,10 @@ class BearerTokenMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
275
289
  X-Skypilot-Auth-Mode header. The auth proxy should either validate the
276
290
  auth or set the header X-Skypilot-Auth-Mode: token.
277
291
  """
292
+ # If a previous middleware already authenticated the user, pass through
293
+ if request.state.auth_user is not None:
294
+ return await call_next(request)
295
+
278
296
  has_skypilot_auth_header = (
279
297
  request.headers.get('X-Skypilot-Auth-Mode') == 'token')
280
298
  auth_header = request.headers.get('authorization')
@@ -818,7 +836,8 @@ async def slurm_gpu_availability(
818
836
  )
819
837
 
820
838
 
821
- @app.get('/slurm_node_info')
839
+ # Keep the GET method for backwards compatibility
840
+ @app.api_route('/slurm_node_info', methods=['GET', 'POST'])
822
841
  async def slurm_node_info(
823
842
  request: fastapi.Request,
824
843
  slurm_node_info_body: payloads.SlurmNodeInfoRequestBody) -> None:
@@ -1503,6 +1522,21 @@ async def cost_report(request: fastapi.Request,
1503
1522
  )
1504
1523
 
1505
1524
 
1525
+ @app.post('/cluster_events')
1526
+ async def cluster_events(
1527
+ request: fastapi.Request,
1528
+ cluster_events_body: payloads.ClusterEventsBody) -> None:
1529
+ """Gets events for a cluster."""
1530
+ await executor.schedule_request_async(
1531
+ request_id=request.state.request_id,
1532
+ request_name=request_names.RequestName.CLUSTER_EVENTS,
1533
+ request_body=cluster_events_body,
1534
+ func=core.get_cluster_events,
1535
+ schedule_type=requests_lib.ScheduleType.SHORT,
1536
+ request_cluster_name=cluster_events_body.cluster_name or '',
1537
+ )
1538
+
1539
+
1506
1540
  @app.get('/storage/ls')
1507
1541
  async def storage_ls(request: fastapi.Request) -> None:
1508
1542
  """Gets the storages."""
@@ -1805,10 +1839,17 @@ async def api_status(
1805
1839
  @app.get('/api/plugins', response_class=fastapi_responses.ORJSONResponse)
1806
1840
  async def list_plugins() -> Dict[str, List[Dict[str, Any]]]:
1807
1841
  """Return metadata about loaded backend plugins."""
1808
- plugin_info = [{
1809
- 'js_extension_path': plugin.js_extension_path,
1810
- } for plugin in plugins.get_plugins()]
1811
- return {'plugins': plugin_info}
1842
+ plugin_infos = []
1843
+ for plugin_info in plugins.get_plugins():
1844
+ info = {
1845
+ 'js_extension_path': plugin_info.js_extension_path,
1846
+ }
1847
+ for attr in ('name', 'version', 'commit'):
1848
+ value = getattr(plugin_info, attr, None)
1849
+ if value is not None:
1850
+ info[attr] = value
1851
+ plugin_infos.append(info)
1852
+ return {'plugins': plugin_infos}
1812
1853
 
1813
1854
 
1814
1855
  @app.get(
@@ -1882,12 +1923,149 @@ async def health(request: fastapi.Request) -> responses.APIHealthResponse:
1882
1923
  )
1883
1924
 
1884
1925
 
1885
- class KubernetesSSHMessageType(IntEnum):
1926
+ class SSHMessageType(IntEnum):
1886
1927
  REGULAR_DATA = 0
1887
1928
  PINGPONG = 1
1888
1929
  LATENCY_MEASUREMENT = 2
1889
1930
 
1890
1931
 
1932
+ async def _get_cluster_and_validate(
1933
+ cluster_name: str,
1934
+ cloud_type: Type[clouds.Cloud],
1935
+ ) -> 'backends.CloudVmRayResourceHandle':
1936
+ """Fetch cluster status and validate it's UP and correct cloud type."""
1937
+ # Run core.status in another thread to avoid blocking the event loop.
1938
+ # TODO(aylei): core.status() will be called with server user, which has
1939
+ # permission to all workspaces, this will break workspace isolation.
1940
+ # It is ok for now, as users with limited access will not get the ssh config
1941
+ # for the clusters in non-accessible workspaces.
1942
+ with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
1943
+ cluster_records = await context_utils.to_thread_with_executor(
1944
+ thread_pool_executor, core.status, cluster_name, all_users=True)
1945
+ cluster_record = cluster_records[0]
1946
+ if cluster_record['status'] != status_lib.ClusterStatus.UP:
1947
+ raise fastapi.HTTPException(
1948
+ status_code=400, detail=f'Cluster {cluster_name} is not running')
1949
+
1950
+ handle: Optional['backends.CloudVmRayResourceHandle'] = cluster_record[
1951
+ 'handle']
1952
+ assert handle is not None, 'Cluster handle is None'
1953
+ if not isinstance(handle.launched_resources.cloud, cloud_type):
1954
+ raise fastapi.HTTPException(
1955
+ status_code=400,
1956
+ detail=f'Cluster {cluster_name} is not a {str(cloud_type())} '
1957
+ 'cluster. Use ssh to connect to the cluster instead.')
1958
+
1959
+ return handle
1960
+
1961
+
1962
+ async def _run_websocket_proxy(
1963
+ websocket: fastapi.WebSocket,
1964
+ read_from_backend: Callable[[], Awaitable[bytes]],
1965
+ write_to_backend: Callable[[bytes], Awaitable[None]],
1966
+ close_backend: Callable[[], Awaitable[None]],
1967
+ timestamps_supported: bool,
1968
+ ) -> bool:
1969
+ """Run bidirectional WebSocket-to-backend proxy.
1970
+
1971
+ Args:
1972
+ websocket: FastAPI WebSocket connection
1973
+ read_from_backend: Async callable to read bytes from backend
1974
+ write_to_backend: Async callable to write bytes to backend
1975
+ close_backend: Async callable to close backend connection
1976
+ timestamps_supported: Whether to use message type framing
1977
+
1978
+ Returns:
1979
+ True if SSH failed, False otherwise
1980
+ """
1981
+ ssh_failed = False
1982
+ websocket_closed = False
1983
+
1984
+ async def websocket_to_backend():
1985
+ try:
1986
+ async for message in websocket.iter_bytes():
1987
+ if timestamps_supported:
1988
+ type_size = struct.calcsize('!B')
1989
+ message_type = struct.unpack('!B', message[:type_size])[0]
1990
+ if message_type == SSHMessageType.REGULAR_DATA:
1991
+ # Regular data - strip type byte and forward to backend
1992
+ message = message[type_size:]
1993
+ elif message_type == SSHMessageType.PINGPONG:
1994
+ # PING message - respond with PONG
1995
+ ping_id_size = struct.calcsize('!I')
1996
+ if len(message) != type_size + ping_id_size:
1997
+ raise ValueError(
1998
+ f'Invalid PING message length: {len(message)}')
1999
+ # Return the same PING message for latency measurement
2000
+ await websocket.send_bytes(message)
2001
+ continue
2002
+ elif message_type == SSHMessageType.LATENCY_MEASUREMENT:
2003
+ # Latency measurement from client
2004
+ latency_size = struct.calcsize('!Q')
2005
+ if len(message) != type_size + latency_size:
2006
+ raise ValueError('Invalid latency measurement '
2007
+ f'message length: {len(message)}')
2008
+ avg_latency_ms = struct.unpack(
2009
+ '!Q',
2010
+ message[type_size:type_size + latency_size])[0]
2011
+ latency_seconds = avg_latency_ms / 1000
2012
+ metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels( # pylint: disable=line-too-long
2013
+ pid=os.getpid()).observe(latency_seconds)
2014
+ continue
2015
+ else:
2016
+ raise ValueError(
2017
+ f'Unknown message type: {message_type}')
2018
+
2019
+ try:
2020
+ await write_to_backend(message)
2021
+ except Exception as e: # pylint: disable=broad-except
2022
+ # Typically we will not reach here, if the conn to backend
2023
+ # is disconnected, backend_to_websocket will exit first.
2024
+ # But just in case.
2025
+ logger.error(f'Failed to write to backend through '
2026
+ f'connection: {e}')
2027
+ nonlocal ssh_failed
2028
+ ssh_failed = True
2029
+ break
2030
+ except fastapi.WebSocketDisconnect:
2031
+ pass
2032
+ nonlocal websocket_closed
2033
+ websocket_closed = True
2034
+ await close_backend()
2035
+
2036
+ async def backend_to_websocket():
2037
+ try:
2038
+ while True:
2039
+ data = await read_from_backend()
2040
+ if not data:
2041
+ if not websocket_closed:
2042
+ logger.warning(
2043
+ 'SSH connection to backend is disconnected '
2044
+ 'before websocket connection is closed')
2045
+ nonlocal ssh_failed
2046
+ ssh_failed = True
2047
+ break
2048
+ if timestamps_supported:
2049
+ # Prepend message type byte (0 = regular data)
2050
+ message_type_bytes = struct.pack(
2051
+ '!B', SSHMessageType.REGULAR_DATA.value)
2052
+ data = message_type_bytes + data
2053
+ await websocket.send_bytes(data)
2054
+ except Exception: # pylint: disable=broad-except
2055
+ pass
2056
+ try:
2057
+ await websocket.close()
2058
+ except Exception: # pylint: disable=broad-except
2059
+ # The websocket might have been closed by the client
2060
+ pass
2061
+
2062
+ await asyncio.gather(websocket_to_backend(),
2063
+ backend_to_websocket(),
2064
+ return_exceptions=True)
2065
+
2066
+ return ssh_failed
2067
+
2068
+
1891
2069
  @app.websocket('/kubernetes-pod-ssh-proxy')
1892
2070
  async def kubernetes_pod_ssh_proxy(
1893
2071
  websocket: fastapi.WebSocket,
@@ -1901,22 +2079,7 @@ async def kubernetes_pod_ssh_proxy(
1901
2079
  logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
1902
2080
  client_version = {client_version}')
1903
2081
 
1904
- # Run core.status in another thread to avoid blocking the event loop.
1905
- with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
1906
- cluster_records = await context_utils.to_thread_with_executor(
1907
- thread_pool_executor, core.status, cluster_name, all_users=True)
1908
- cluster_record = cluster_records[0]
1909
- if cluster_record['status'] != status_lib.ClusterStatus.UP:
1910
- raise fastapi.HTTPException(
1911
- status_code=400, detail=f'Cluster {cluster_name} is not running')
1912
-
1913
- handle = cluster_record['handle']
1914
- assert handle is not None, 'Cluster handle is None'
1915
- if not isinstance(handle.launched_resources.cloud, clouds.Kubernetes):
1916
- raise fastapi.HTTPException(
1917
- status_code=400,
1918
- detail=f'Cluster {cluster_name} is not a Kubernetes cluster'
1919
- 'Use ssh to connect to the cluster instead.')
2082
+ handle = await _get_cluster_and_validate(cluster_name, clouds.Kubernetes)
1920
2083
 
1921
2084
  kubectl_cmd = handle.get_command_runners()[0].port_forward_command(
1922
2085
  port_forward=[(None, 22)])
@@ -1946,96 +2109,25 @@ async def kubernetes_pod_ssh_proxy(
1946
2109
  conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
1947
2110
  pid=os.getpid())
1948
2111
  ssh_failed = False
1949
- websocket_closed = False
1950
2112
  try:
1951
2113
  conn_gauge.inc()
1952
2114
  # Connect to the local port
1953
2115
  reader, writer = await asyncio.open_connection('127.0.0.1', local_port)
1954
2116
 
1955
- async def websocket_to_ssh():
1956
- try:
1957
- async for message in websocket.iter_bytes():
1958
- if timestamps_supported:
1959
- type_size = struct.calcsize('!B')
1960
- message_type = struct.unpack('!B',
1961
- message[:type_size])[0]
1962
- if (message_type ==
1963
- KubernetesSSHMessageType.REGULAR_DATA):
1964
- # Regular data - strip type byte and forward to SSH
1965
- message = message[type_size:]
1966
- elif message_type == KubernetesSSHMessageType.PINGPONG:
1967
- # PING message - respond with PONG (type 1)
1968
- ping_id_size = struct.calcsize('!I')
1969
- if len(message) != type_size + ping_id_size:
1970
- raise ValueError('Invalid PING message '
1971
- f'length: {len(message)}')
1972
- # Return the same PING message, so that the client
1973
- # can measure the latency.
1974
- await websocket.send_bytes(message)
1975
- continue
1976
- elif (message_type ==
1977
- KubernetesSSHMessageType.LATENCY_MEASUREMENT):
1978
- # Latency measurement from client
1979
- latency_size = struct.calcsize('!Q')
1980
- if len(message) != type_size + latency_size:
1981
- raise ValueError(
1982
- 'Invalid latency measurement '
1983
- f'message length: {len(message)}')
1984
- avg_latency_ms = struct.unpack(
1985
- '!Q',
1986
- message[type_size:type_size + latency_size])[0]
1987
- latency_seconds = avg_latency_ms / 1000
1988
- metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(pid=os.getpid()).observe(latency_seconds) # pylint: disable=line-too-long
1989
- continue
1990
- else:
1991
- # Unknown message type.
1992
- raise ValueError(
1993
- f'Unknown message type: {message_type}')
1994
- writer.write(message)
1995
- try:
1996
- await writer.drain()
1997
- except Exception as e: # pylint: disable=broad-except
1998
- # Typically we will not reach here, if the ssh to pod
1999
- # is disconnected, ssh_to_websocket will exit first.
2000
- # But just in case.
2001
- logger.error('Failed to write to pod through '
2002
- f'port-forward connection: {e}')
2003
- nonlocal ssh_failed
2004
- ssh_failed = True
2005
- break
2006
- except fastapi.WebSocketDisconnect:
2007
- pass
2008
- nonlocal websocket_closed
2009
- websocket_closed = True
2010
- writer.close()
2117
+ async def write_and_drain(data: bytes) -> None:
2118
+ writer.write(data)
2119
+ await writer.drain()
2011
2120
 
2012
- async def ssh_to_websocket():
2013
- try:
2014
- while True:
2015
- data = await reader.read(1024)
2016
- if not data:
2017
- if not websocket_closed:
2018
- logger.warning('SSH connection to pod is '
2019
- 'disconnected before websocket '
2020
- 'connection is closed')
2021
- nonlocal ssh_failed
2022
- ssh_failed = True
2023
- break
2024
- if timestamps_supported:
2025
- # Prepend message type byte (0 = regular data)
2026
- message_type_bytes = struct.pack(
2027
- '!B', KubernetesSSHMessageType.REGULAR_DATA.value)
2028
- data = message_type_bytes + data
2029
- await websocket.send_bytes(data)
2030
- except Exception: # pylint: disable=broad-except
2031
- pass
2032
- try:
2033
- await websocket.close()
2034
- except Exception: # pylint: disable=broad-except
2035
- # The websocket might has been closed by the client.
2036
- pass
2121
+ async def close_writer() -> None:
2122
+ writer.close()
2037
2123
 
2038
- await asyncio.gather(websocket_to_ssh(), ssh_to_websocket())
2124
+ ssh_failed = await _run_websocket_proxy(
2125
+ websocket,
2126
+ read_from_backend=lambda: reader.read(1024),
2127
+ write_to_backend=write_and_drain,
2128
+ close_backend=close_writer,
2129
+ timestamps_supported=timestamps_supported,
2130
+ )
2039
2131
  finally:
2040
2132
  conn_gauge.dec()
2041
2133
  reason = ''
@@ -2049,7 +2141,7 @@ async def kubernetes_pod_ssh_proxy(
2049
2141
  f'output: {str(stdout)}')
2050
2142
  reason = 'KubectlPortForwardExit'
2051
2143
  metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
2052
- pid=os.getpid(), reason='KubectlPortForwardExit').inc()
2144
+ pid=os.getpid(), reason=reason).inc()
2053
2145
  else:
2054
2146
  if ssh_failed:
2055
2147
  reason = 'SSHToPodDisconnected'
@@ -2059,6 +2151,235 @@ async def kubernetes_pod_ssh_proxy(
2059
2151
  pid=os.getpid(), reason=reason).inc()
2060
2152
 
2061
2153
 
2154
@app.websocket('/slurm-job-ssh-proxy')
async def slurm_job_ssh_proxy(websocket: fastapi.WebSocket,
                              cluster_name: str,
                              worker: int = 0,
                              client_version: Optional[int] = None) -> None:
    """Proxies SSH to the Slurm job via sshd inside srun.

    An ssh process to the Slurm login node is spawned; it runs the command
    built by ``slurm_utils.srun_sshd_command`` for the target node, and the
    process's stdin/stdout are bridged to the websocket.

    Args:
        websocket: The client websocket; accepted immediately.
        cluster_name: Name of the Slurm cluster to connect to.
        worker: Index into the cluster's instance list selecting the target
            node (0 is expected to be the head node).
        client_version: API version reported by the client. Versions above
            21 enable the timestamped message framing.

    Raises:
        fastapi.HTTPException: If ``worker`` is not a valid node index.
    """
    await websocket.accept()
    logger.info(f'WebSocket connection accepted for cluster: '
                f'{cluster_name}')

    timestamps_supported = client_version is not None and client_version > 21
    # Implicit string concatenation (rather than a backslash continuation
    # inside the literal) keeps source indentation out of the log message.
    logger.info(f'Websocket timestamps supported: {timestamps_supported}, '
                f'client_version = {client_version}')

    handle = await _get_cluster_and_validate(cluster_name, clouds.Slurm)

    assert handle.cached_cluster_info is not None, 'Cached cluster info is None'
    provider_config = handle.cached_cluster_info.provider_config
    assert provider_config is not None, 'Provider config is None'
    login_node_ssh_config = provider_config['ssh']
    login_node_host = login_node_ssh_config['hostname']
    login_node_port = int(login_node_ssh_config['port'])
    login_node_user = login_node_ssh_config['user']
    login_node_key = login_node_ssh_config['private_key']
    login_node_proxy_command = login_node_ssh_config.get('proxycommand', None)
    login_node_proxy_jump = login_node_ssh_config.get('proxyjump', None)

    login_node_runner = command_runner.SSHCommandRunner(
        (login_node_host, login_node_port),
        login_node_user,
        login_node_key,
        ssh_proxy_command=login_node_proxy_command,
        ssh_proxy_jump=login_node_proxy_jump,
    )

    ssh_cmd = login_node_runner.ssh_base_command(
        ssh_mode=command_runner.SshMode.NON_INTERACTIVE,
        port_forward=None,
        connect_timeout=None)

    # There can only be one InstanceInfo per instance_id.
    head_instance = handle.cached_cluster_info.get_head_instance()
    assert head_instance is not None, 'Head instance is None'
    job_id = head_instance.tags['job_id']

    # Instances are ordered: head first, then workers
    instances = handle.cached_cluster_info.instances
    node_hostnames = [inst[0].tags['node'] for inst in instances.values()]
    # Reject negative indices as well: Python would otherwise silently wrap
    # around and pick a node counted from the end of the list.
    if not 0 <= worker < len(node_hostnames):
        raise fastapi.HTTPException(
            status_code=400,
            detail=f'Worker index {worker} out of range. '
            f'Cluster has {len(node_hostnames)} nodes.')
    target_node = node_hostnames[worker]

    # Run sshd inside the Slurm job "container" via srun, such that it inherits
    # the resource constraints of the Slurm job.
    ssh_cmd += [
        shlex.quote(
            slurm_utils.srun_sshd_command(job_id, target_node, login_node_user))
    ]

    proc = await asyncio.create_subprocess_shell(
        ' '.join(ssh_cmd),
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,  # Capture stderr separately for logging
    )
    assert proc.stdin is not None
    assert proc.stdout is not None
    assert proc.stderr is not None

    stdin = proc.stdin
    stdout = proc.stdout
    stderr = proc.stderr

    log_stderr_lines = env_options.Options.SHOW_DEBUG_INFO.get()

    async def drain_stderr() -> None:
        """Consumes the child's stderr, logging it only in debug mode."""
        while True:
            line = await stderr.readline()
            if not line:
                break
            if log_stderr_lines:
                logger.debug(f'srun stderr: {line.decode().rstrip()}')

    # Always drain stderr: it is a pipe, and an unread pipe eventually fills
    # up and blocks the child process on its next write.
    stderr_task = asyncio.create_task(drain_stderr())
    conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
        pid=os.getpid())
    ssh_failed = False
    try:
        conn_gauge.inc()

        async def write_and_drain(data: bytes) -> None:
            stdin.write(data)
            await stdin.drain()

        async def close_stdin() -> None:
            stdin.close()

        ssh_failed = await _run_websocket_proxy(
            websocket,
            read_from_backend=lambda: stdout.read(4096),
            write_to_backend=write_and_drain,
            close_backend=close_stdin,
            timestamps_supported=timestamps_supported,
        )

    finally:
        conn_gauge.dec()
        try:
            logger.info('Terminating srun process')
            proc.terminate()
        except ProcessLookupError:
            # The process was already gone before we tried to terminate it.
            stdout_data = await stdout.read()
            logger.error('srun process was terminated before the '
                         'ssh websocket connection was closed. Remaining '
                         f'output: {str(stdout_data)}')
            reason = 'SrunProcessExit'
        else:
            reason = ('SSHToSlurmJobDisconnected'
                      if ssh_failed else 'ClientClosed')

        # Record exactly one close event per connection, whichever branch
        # determined the reason.
        metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
            pid=os.getpid(), reason=reason).inc()

        # Cancel the stderr draining task if it's still running
        if not stderr_task.done():
            stderr_task.cancel()
            try:
                await stderr_task
            except asyncio.CancelledError:
                pass
2291
+
2292
+
2293
@app.websocket('/ssh-interactive-auth')
async def ssh_interactive_auth(websocket: fastapi.WebSocket,
                               session_id: str) -> None:
    """Proxies PTY for SSH interactive authentication via websocket.

    This endpoint receives a PTY file descriptor from a worker process
    (over the Unix socket located by
    ``interactive_utils.get_pty_socket_path``) and bridges it
    bidirectionally with a websocket connection, allowing the client to
    handle interactive SSH authentication (e.g., 2FA).

    This handler only relays bytes in both directions; it does not inspect
    terminal state itself. The relay ends once both directions have stopped:
    the PTY side on EOF or a read error, the websocket side on disconnect.

    Args:
        websocket: The client websocket; accepted immediately.
        session_id: Identifier used to locate the worker's FD-passing socket.
    """
    await websocket.accept()
    logger.info(f'WebSocket connection accepted for SSH auth session: '
                f'{session_id}')

    loop = asyncio.get_running_loop()

    def write_all(fd: int, payload: bytes) -> None:
        """Writes the whole payload; os.write may write only a prefix."""
        view = memoryview(payload)
        offset = 0
        while offset < len(view):
            offset += os.write(fd, view[offset:])

    # Connect to worker process to receive PTY file descriptor
    fd_socket_path = interactive_utils.get_pty_socket_path(session_id)
    fd_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    master_fd = -1
    try:
        # Connect to worker's FD-passing socket. The socket is left in
        # blocking mode (recv_fd below runs in the executor), so the connect
        # also goes through the executor: loop.sock_connect() requires a
        # non-blocking socket.
        await loop.run_in_executor(None, fd_sock.connect, fd_socket_path)
        master_fd = await loop.run_in_executor(None, interactive_utils.recv_fd,
                                               fd_sock)
        logger.debug(f'Received PTY master fd {master_fd} for session '
                     f'{session_id}')

        # Bridge PTY ↔ websocket bidirectionally
        async def websocket_to_pty():
            """Forward websocket messages to PTY."""
            try:
                async for message in websocket.iter_bytes():
                    await loop.run_in_executor(None, write_all, master_fd,
                                               message)
            except fastapi.WebSocketDisconnect:
                logger.debug(f'WebSocket disconnected for session {session_id}')
            except asyncio.CancelledError:
                pass
            except Exception as e:  # pylint: disable=broad-except
                logger.error(f'Error in websocket_to_pty: {e}')

        async def pty_to_websocket():
            """Forward PTY output to the websocket until the PTY ends.

            Stops on EOF or on an OSError from the read, then closes the
            websocket (best effort) so the other direction terminates too.
            """
            # NOTE(review): the read blocks in an executor thread, so if the
            # client disconnects first this coroutine only finishes once the
            # PTY produces data or is closed by its other end.
            try:
                while True:
                    try:
                        data = await loop.run_in_executor(
                            None, os.read, master_fd, 4096)
                    except OSError as e:
                        logger.error(f'PTY read error (likely closed): {e}')
                        break

                    if not data:
                        break

                    await websocket.send_bytes(data)
            except asyncio.CancelledError:
                pass
            except Exception as e:  # pylint: disable=broad-except
                logger.error(f'Error in pty_to_websocket: {e}')
            finally:
                try:
                    await websocket.close()
                except Exception:  # pylint: disable=broad-except
                    # The websocket may already be closed by the client.
                    pass

        await asyncio.gather(websocket_to_pty(), pty_to_websocket())

    except Exception as e:  # pylint: disable=broad-except
        logger.error(f'Error in SSH interactive auth websocket: {e}')
        raise
    finally:
        # Clean up
        if master_fd >= 0:
            try:
                os.close(master_fd)
            except OSError:
                pass
        fd_sock.close()
        logger.debug(f'SSH interactive auth session {session_id} completed')
2381
+
2382
+
2062
2383
  @app.get('/all_contexts')
2063
2384
  async def all_contexts(request: fastapi.Request) -> None:
2064
2385
  """Gets all Kubernetes and SSH node pool contexts."""
@@ -2229,6 +2550,9 @@ if __name__ == '__main__':
2229
2550
  # Restore the server user hash
2230
2551
  logger.info('Initializing server user hash')
2231
2552
  _init_or_restore_server_user_hash()
2553
+ logger.info('Initializing permission service')
2554
+ permission.permission_service.initialize()
2555
+ logger.info('Permission service initialized')
2232
2556
 
2233
2557
  max_db_connections = global_user_state.get_max_db_connections()
2234
2558
  logger.info(f'Max db connections: {max_db_connections}')
@@ -2265,6 +2589,9 @@ if __name__ == '__main__':
2265
2589
  global_tasks.append(
2266
2590
  background.create_task(
2267
2591
  global_user_state.cluster_event_retention_daemon()))
2592
+ global_tasks.append(
2593
+ background.create_task(
2594
+ managed_job_state.job_event_retention_daemon()))
2268
2595
  threading.Thread(target=background.run_forever, daemon=True).start()
2269
2596
 
2270
2597
  queue_server, workers = executor.start(config)
sky/server/uvicorn.py CHANGED
@@ -20,6 +20,7 @@ from uvicorn.supervisors import multiprocess
20
20
  from sky import sky_logging
21
21
  from sky.server import daemons
22
22
  from sky.server import metrics as metrics_lib
23
+ from sky.server import plugins
23
24
  from sky.server import state
24
25
  from sky.server.requests import requests as requests_lib
25
26
  from sky.skylet import constants
@@ -237,6 +238,10 @@ def run(config: uvicorn.Config, max_db_connections: Optional[int] = None):
237
238
  server = Server(config=config, max_db_connections=max_db_connections)
238
239
  try:
239
240
  if config.workers is not None and config.workers > 1:
241
+ # When workers > 1, uvicorn does not run server app in the main
242
+ # process. In this case, plugins are not loaded at this point, so
243
+ # load plugins here without uvicorn app.
244
+ plugins.load_plugins(plugins.ExtensionContext())
240
245
  sock = config.bind_socket()
241
246
  SlowStartMultiprocess(config, target=server.run,
242
247
  sockets=[sock]).run()