skypilot-nightly 1.0.0.dev20251203__py3-none-any.whl → 1.0.0.dev20260112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +6 -2
- sky/adaptors/aws.py +1 -61
- sky/adaptors/slurm.py +565 -0
- sky/backends/backend_utils.py +95 -12
- sky/backends/cloud_vm_ray_backend.py +224 -65
- sky/backends/task_codegen.py +380 -4
- sky/catalog/__init__.py +0 -3
- sky/catalog/data_fetchers/fetch_gcp.py +9 -1
- sky/catalog/data_fetchers/fetch_nebius.py +1 -1
- sky/catalog/data_fetchers/fetch_vast.py +4 -2
- sky/catalog/kubernetes_catalog.py +12 -4
- sky/catalog/seeweb_catalog.py +30 -15
- sky/catalog/shadeform_catalog.py +5 -2
- sky/catalog/slurm_catalog.py +236 -0
- sky/catalog/vast_catalog.py +30 -6
- sky/check.py +25 -11
- sky/client/cli/command.py +391 -32
- sky/client/interactive_utils.py +190 -0
- sky/client/sdk.py +64 -2
- sky/client/sdk_async.py +9 -0
- sky/clouds/__init__.py +2 -0
- sky/clouds/aws.py +60 -2
- sky/clouds/azure.py +2 -0
- sky/clouds/cloud.py +7 -0
- sky/clouds/kubernetes.py +2 -0
- sky/clouds/runpod.py +38 -7
- sky/clouds/slurm.py +610 -0
- sky/clouds/ssh.py +3 -2
- sky/clouds/vast.py +39 -16
- sky/core.py +197 -37
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/3nu-b8raeKRNABZ2d4GAG/_buildManifest.js +1 -0
- sky/dashboard/out/_next/static/chunks/1871-0565f8975a7dcd10.js +6 -0
- sky/dashboard/out/_next/static/chunks/2109-55a1546d793574a7.js +11 -0
- sky/dashboard/out/_next/static/chunks/2521-099b07cd9e4745bf.js +26 -0
- sky/dashboard/out/_next/static/chunks/2755.a636e04a928a700e.js +31 -0
- sky/dashboard/out/_next/static/chunks/3495.05eab4862217c1a5.js +6 -0
- sky/dashboard/out/_next/static/chunks/3785.cfc5dcc9434fd98c.js +1 -0
- sky/dashboard/out/_next/static/chunks/3850-fd5696f3bbbaddae.js +1 -0
- sky/dashboard/out/_next/static/chunks/3981.645d01bf9c8cad0c.js +21 -0
- sky/dashboard/out/_next/static/chunks/4083-0115d67c1fb57d6c.js +21 -0
- sky/dashboard/out/_next/static/chunks/{8640.5b9475a2d18c5416.js → 429.a58e9ba9742309ed.js} +2 -2
- sky/dashboard/out/_next/static/chunks/4555.8e221537181b5dc1.js +6 -0
- sky/dashboard/out/_next/static/chunks/4725.937865b81fdaaebb.js +6 -0
- sky/dashboard/out/_next/static/chunks/6082-edabd8f6092300ce.js +25 -0
- sky/dashboard/out/_next/static/chunks/6989-49cb7dca83a7a62d.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-630bd2a2257275f8.js +1 -0
- sky/dashboard/out/_next/static/chunks/7248-a99800d4db8edabd.js +1 -0
- sky/dashboard/out/_next/static/chunks/754-cfc5d4ad1b843d29.js +18 -0
- sky/dashboard/out/_next/static/chunks/8050-dd8aa107b17dce00.js +16 -0
- sky/dashboard/out/_next/static/chunks/8056-d4ae1e0cb81e7368.js +1 -0
- sky/dashboard/out/_next/static/chunks/8555.011023e296c127b3.js +6 -0
- sky/dashboard/out/_next/static/chunks/8821-93c25df904a8362b.js +1 -0
- sky/dashboard/out/_next/static/chunks/8969-0662594b69432ade.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.f15c91c97d124a5f.js +6 -0
- sky/dashboard/out/_next/static/chunks/9353-7ad6bd01858556f1.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/_app-5a86569acad99764.js +34 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8297476714acb4ac.js +6 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-337c3ba1085f1210.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/{clusters-ee39056f9851a3ff.js → clusters-57632ff3684a8b5c.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{config-dfb9bf07b13045f4.js → config-718cdc365de82689.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-5fd3a453c079c2ea.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra-9f85c02c9c6cae9e.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-90f16972cbecf354.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-2dd42fc37aad427a.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs-ed806aeace26b972.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/plugins/[...slug]-449a9f5a3bb20fb3.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/users-bec34706b36f3524.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/{volumes-b84b948ff357c43e.js → volumes-a83ba9b38dff7ea9.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-84a40f8c7c627fe4.js → [name]-c781e9c3e52ef9fc.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-91e0942f47310aae.js +1 -0
- sky/dashboard/out/_next/static/chunks/webpack-cfe59cf684ee13b9.js +1 -0
- sky/dashboard/out/_next/static/css/b0dbca28f027cc19.css +3 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/plugins/[...slug].html +1 -0
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/volumes.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/data/data_utils.py +26 -12
- sky/data/mounting_utils.py +44 -5
- sky/global_user_state.py +111 -19
- sky/jobs/client/sdk.py +8 -3
- sky/jobs/controller.py +191 -31
- sky/jobs/recovery_strategy.py +109 -11
- sky/jobs/server/core.py +81 -4
- sky/jobs/server/server.py +14 -0
- sky/jobs/state.py +417 -19
- sky/jobs/utils.py +73 -80
- sky/models.py +11 -0
- sky/optimizer.py +8 -6
- sky/provision/__init__.py +12 -9
- sky/provision/common.py +20 -0
- sky/provision/docker_utils.py +15 -2
- sky/provision/kubernetes/utils.py +163 -20
- sky/provision/kubernetes/volume.py +52 -17
- sky/provision/provisioner.py +17 -7
- sky/provision/runpod/instance.py +3 -1
- sky/provision/runpod/utils.py +13 -1
- sky/provision/runpod/volume.py +25 -9
- sky/provision/slurm/__init__.py +12 -0
- sky/provision/slurm/config.py +13 -0
- sky/provision/slurm/instance.py +618 -0
- sky/provision/slurm/utils.py +689 -0
- sky/provision/vast/instance.py +4 -1
- sky/provision/vast/utils.py +11 -6
- sky/resources.py +135 -13
- sky/schemas/api/responses.py +4 -0
- sky/schemas/db/global_user_state/010_save_ssh_key.py +1 -1
- sky/schemas/db/spot_jobs/008_add_full_resources.py +34 -0
- sky/schemas/db/spot_jobs/009_job_events.py +32 -0
- sky/schemas/db/spot_jobs/010_job_events_timestamp_with_timezone.py +43 -0
- sky/schemas/db/spot_jobs/011_add_links.py +34 -0
- sky/schemas/generated/jobsv1_pb2.py +9 -5
- sky/schemas/generated/jobsv1_pb2.pyi +12 -0
- sky/schemas/generated/jobsv1_pb2_grpc.py +44 -0
- sky/schemas/generated/managed_jobsv1_pb2.py +32 -28
- sky/schemas/generated/managed_jobsv1_pb2.pyi +11 -2
- sky/serve/serve_utils.py +232 -40
- sky/serve/server/impl.py +1 -1
- sky/server/common.py +17 -0
- sky/server/constants.py +1 -1
- sky/server/metrics.py +6 -3
- sky/server/plugins.py +238 -0
- sky/server/requests/executor.py +5 -2
- sky/server/requests/payloads.py +30 -1
- sky/server/requests/request_names.py +4 -0
- sky/server/requests/requests.py +33 -11
- sky/server/requests/serializers/encoders.py +22 -0
- sky/server/requests/serializers/return_value_serializers.py +70 -0
- sky/server/server.py +506 -109
- sky/server/server_utils.py +30 -0
- sky/server/uvicorn.py +5 -0
- sky/setup_files/MANIFEST.in +1 -0
- sky/setup_files/dependencies.py +22 -9
- sky/sky_logging.py +2 -1
- sky/skylet/attempt_skylet.py +13 -3
- sky/skylet/constants.py +55 -13
- sky/skylet/events.py +10 -4
- sky/skylet/executor/__init__.py +1 -0
- sky/skylet/executor/slurm.py +187 -0
- sky/skylet/job_lib.py +91 -5
- sky/skylet/log_lib.py +22 -6
- sky/skylet/log_lib.pyi +8 -6
- sky/skylet/services.py +18 -3
- sky/skylet/skylet.py +5 -1
- sky/skylet/subprocess_daemon.py +2 -1
- sky/ssh_node_pools/constants.py +12 -0
- sky/ssh_node_pools/core.py +40 -3
- sky/ssh_node_pools/deploy/__init__.py +4 -0
- sky/{utils/kubernetes/deploy_ssh_node_pools.py → ssh_node_pools/deploy/deploy.py} +279 -504
- sky/ssh_node_pools/deploy/tunnel/ssh-tunnel.sh +379 -0
- sky/ssh_node_pools/deploy/tunnel_utils.py +199 -0
- sky/ssh_node_pools/deploy/utils.py +173 -0
- sky/ssh_node_pools/server.py +11 -13
- sky/{utils/kubernetes/ssh_utils.py → ssh_node_pools/utils.py} +9 -6
- sky/templates/kubernetes-ray.yml.j2 +12 -6
- sky/templates/slurm-ray.yml.j2 +115 -0
- sky/templates/vast-ray.yml.j2 +1 -0
- sky/templates/websocket_proxy.py +18 -41
- sky/users/model.conf +1 -1
- sky/users/permission.py +85 -52
- sky/users/rbac.py +31 -3
- sky/utils/annotations.py +108 -8
- sky/utils/auth_utils.py +42 -0
- sky/utils/cli_utils/status_utils.py +19 -5
- sky/utils/cluster_utils.py +10 -3
- sky/utils/command_runner.py +389 -35
- sky/utils/command_runner.pyi +43 -4
- sky/utils/common_utils.py +47 -31
- sky/utils/context.py +32 -0
- sky/utils/db/db_utils.py +36 -6
- sky/utils/db/migration_utils.py +41 -21
- sky/utils/infra_utils.py +5 -1
- sky/utils/instance_links.py +139 -0
- sky/utils/interactive_utils.py +49 -0
- sky/utils/kubernetes/generate_kubeconfig.sh +42 -33
- sky/utils/kubernetes/kubernetes_deploy_utils.py +2 -94
- sky/utils/kubernetes/rsync_helper.sh +5 -1
- sky/utils/kubernetes/ssh-tunnel.sh +7 -376
- sky/utils/plugin_extensions/__init__.py +14 -0
- sky/utils/plugin_extensions/external_failure_source.py +176 -0
- sky/utils/resources_utils.py +10 -8
- sky/utils/rich_utils.py +9 -11
- sky/utils/schemas.py +93 -19
- sky/utils/status_lib.py +7 -0
- sky/utils/subprocess_utils.py +17 -0
- sky/volumes/client/sdk.py +6 -3
- sky/volumes/server/core.py +65 -27
- sky_templates/ray/start_cluster +8 -4
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/METADATA +67 -59
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/RECORD +208 -180
- sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +0 -11
- sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +0 -6
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +0 -1
- sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +0 -1
- sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +0 -15
- sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +0 -26
- sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +0 -1
- sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +0 -1
- sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +0 -1
- sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +0 -1
- sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +0 -1
- sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +0 -15
- sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +0 -13
- sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +0 -1
- sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +0 -1
- sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +0 -1
- sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +0 -30
- sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +0 -41
- sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +0 -1
- sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +0 -1
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +0 -6
- sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +0 -1
- sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +0 -31
- sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +0 -30
- sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +0 -34
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +0 -21
- sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +0 -1
- sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +0 -1
- sky/dashboard/out/_next/static/css/0748ce22df867032.css +0 -3
- /sky/dashboard/out/_next/static/{96_E2yl3QAiIJGOYCkSpB → 3nu-b8raeKRNABZ2d4GAG}/_ssgManifest.js +0 -0
- /sky/{utils/kubernetes → ssh_node_pools/deploy/tunnel}/cleanup-tunnel.sh +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/top_level.txt +0 -0
sky/server/server.py
CHANGED
|
@@ -15,12 +15,16 @@ import pathlib
|
|
|
15
15
|
import posixpath
|
|
16
16
|
import re
|
|
17
17
|
import resource
|
|
18
|
+
import shlex
|
|
18
19
|
import shutil
|
|
20
|
+
import socket
|
|
19
21
|
import struct
|
|
20
22
|
import sys
|
|
21
23
|
import threading
|
|
22
24
|
import traceback
|
|
23
|
-
|
|
25
|
+
import typing
|
|
26
|
+
from typing import (Any, Awaitable, Callable, Dict, List, Literal, Optional,
|
|
27
|
+
Set, Tuple, Type)
|
|
24
28
|
import uuid
|
|
25
29
|
import zipfile
|
|
26
30
|
|
|
@@ -43,11 +47,13 @@ from sky import global_user_state
|
|
|
43
47
|
from sky import models
|
|
44
48
|
from sky import sky_logging
|
|
45
49
|
from sky.data import storage_utils
|
|
50
|
+
from sky.jobs import state as managed_job_state
|
|
46
51
|
from sky.jobs import utils as managed_job_utils
|
|
47
52
|
from sky.jobs.server import server as jobs_rest
|
|
48
53
|
from sky.metrics import utils as metrics_utils
|
|
49
54
|
from sky.provision import metadata_utils
|
|
50
55
|
from sky.provision.kubernetes import utils as kubernetes_utils
|
|
56
|
+
from sky.provision.slurm import utils as slurm_utils
|
|
51
57
|
from sky.schemas.api import responses
|
|
52
58
|
from sky.serve.server import server as serve_rest
|
|
53
59
|
from sky.server import common
|
|
@@ -56,6 +62,8 @@ from sky.server import constants as server_constants
|
|
|
56
62
|
from sky.server import daemons
|
|
57
63
|
from sky.server import metrics
|
|
58
64
|
from sky.server import middleware_utils
|
|
65
|
+
from sky.server import plugins
|
|
66
|
+
from sky.server import server_utils
|
|
59
67
|
from sky.server import state
|
|
60
68
|
from sky.server import stream_utils
|
|
61
69
|
from sky.server import versions
|
|
@@ -73,6 +81,7 @@ from sky.usage import usage_lib
|
|
|
73
81
|
from sky.users import permission
|
|
74
82
|
from sky.users import server as users_rest
|
|
75
83
|
from sky.utils import admin_policy_utils
|
|
84
|
+
from sky.utils import command_runner
|
|
76
85
|
from sky.utils import common as common_lib
|
|
77
86
|
from sky.utils import common_utils
|
|
78
87
|
from sky.utils import context
|
|
@@ -80,6 +89,7 @@ from sky.utils import context_utils
|
|
|
80
89
|
from sky.utils import controller_utils
|
|
81
90
|
from sky.utils import dag_utils
|
|
82
91
|
from sky.utils import env_options
|
|
92
|
+
from sky.utils import interactive_utils
|
|
83
93
|
from sky.utils import perf_utils
|
|
84
94
|
from sky.utils import status_lib
|
|
85
95
|
from sky.utils import subprocess_utils
|
|
@@ -88,6 +98,9 @@ from sky.utils.db import db_utils
|
|
|
88
98
|
from sky.volumes.server import server as volumes_rest
|
|
89
99
|
from sky.workspaces import server as workspaces_rest
|
|
90
100
|
|
|
101
|
+
if typing.TYPE_CHECKING:
|
|
102
|
+
from sky import backends
|
|
103
|
+
|
|
91
104
|
# pylint: disable=ungrouped-imports
|
|
92
105
|
if sys.version_info >= (3, 10):
|
|
93
106
|
from typing import ParamSpec
|
|
@@ -205,6 +218,10 @@ class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
|
205
218
|
"""Middleware to handle HTTP Basic Auth."""
|
|
206
219
|
|
|
207
220
|
async def dispatch(self, request: fastapi.Request, call_next):
|
|
221
|
+
# If a previous middleware already authenticated the user, pass through
|
|
222
|
+
if request.state.auth_user is not None:
|
|
223
|
+
return await call_next(request)
|
|
224
|
+
|
|
208
225
|
if managed_job_utils.is_consolidation_mode(
|
|
209
226
|
) and loopback.is_loopback_request(request):
|
|
210
227
|
return await call_next(request)
|
|
@@ -272,6 +289,10 @@ class BearerTokenMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
|
272
289
|
X-Skypilot-Auth-Mode header. The auth proxy should either validate the
|
|
273
290
|
auth or set the header X-Skypilot-Auth-Mode: token.
|
|
274
291
|
"""
|
|
292
|
+
# If a previous middleware already authenticated the user, pass through
|
|
293
|
+
if request.state.auth_user is not None:
|
|
294
|
+
return await call_next(request)
|
|
295
|
+
|
|
275
296
|
has_skypilot_auth_header = (
|
|
276
297
|
request.headers.get('X-Skypilot-Auth-Mode') == 'token')
|
|
277
298
|
auth_header = request.headers.get('authorization')
|
|
@@ -470,7 +491,8 @@ async def schedule_on_boot_check_async():
|
|
|
470
491
|
await executor.schedule_request_async(
|
|
471
492
|
request_id='skypilot-server-on-boot-check',
|
|
472
493
|
request_name=request_names.RequestName.CHECK,
|
|
473
|
-
request_body=
|
|
494
|
+
request_body=server_utils.build_body_at_server(
|
|
495
|
+
request=None, body_type=payloads.CheckBody),
|
|
474
496
|
func=sky_check.check,
|
|
475
497
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
476
498
|
is_skypilot_system=True,
|
|
@@ -493,7 +515,8 @@ async def lifespan(app: fastapi.FastAPI): # pylint: disable=redefined-outer-nam
|
|
|
493
515
|
await executor.schedule_request_async(
|
|
494
516
|
request_id=event.id,
|
|
495
517
|
request_name=event.name,
|
|
496
|
-
request_body=
|
|
518
|
+
request_body=server_utils.build_body_at_server(
|
|
519
|
+
request=None, body_type=payloads.RequestBody),
|
|
497
520
|
func=event.run_event,
|
|
498
521
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
499
522
|
is_skypilot_system=True,
|
|
@@ -652,6 +675,17 @@ app.add_middleware(BearerTokenMiddleware)
|
|
|
652
675
|
# middleware above.
|
|
653
676
|
app.add_middleware(InitializeRequestAuthUserMiddleware)
|
|
654
677
|
app.add_middleware(RequestIDMiddleware)
|
|
678
|
+
|
|
679
|
+
# Load plugins after all the middlewares are added, to keep the core
|
|
680
|
+
# middleware stack intact if a plugin adds new middlewares.
|
|
681
|
+
# Note: server.py will be imported twice in server process, once as
|
|
682
|
+
# the top-level entrypoint module and once imported by uvicorn, we only
|
|
683
|
+
# load the plugin when imported by uvicorn for server process.
|
|
684
|
+
# TODO(aylei): move uvicorn app out of the top-level module to avoid
|
|
685
|
+
# duplicate app initialization.
|
|
686
|
+
if __name__ == 'sky.server.server':
|
|
687
|
+
plugins.load_plugins(plugins.ExtensionContext(app=app))
|
|
688
|
+
|
|
655
689
|
app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
|
|
656
690
|
app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
|
|
657
691
|
app.include_router(users_rest.router, prefix='/users', tags=['users'])
|
|
@@ -746,8 +780,11 @@ async def enabled_clouds(request: fastapi.Request,
|
|
|
746
780
|
await executor.schedule_request_async(
|
|
747
781
|
request_id=request.state.request_id,
|
|
748
782
|
request_name=request_names.RequestName.ENABLED_CLOUDS,
|
|
749
|
-
request_body=
|
|
750
|
-
|
|
783
|
+
request_body=server_utils.build_body_at_server(
|
|
784
|
+
request=request,
|
|
785
|
+
body_type=payloads.EnabledCloudsBody,
|
|
786
|
+
workspace=workspace,
|
|
787
|
+
expand=expand),
|
|
751
788
|
func=core.enabled_clouds,
|
|
752
789
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
753
790
|
)
|
|
@@ -784,6 +821,36 @@ async def kubernetes_node_info(
|
|
|
784
821
|
)
|
|
785
822
|
|
|
786
823
|
|
|
824
|
+
@app.post('/slurm_gpu_availability')
|
|
825
|
+
async def slurm_gpu_availability(
|
|
826
|
+
request: fastapi.Request,
|
|
827
|
+
slurm_gpu_availability_body: payloads.SlurmGpuAvailabilityRequestBody
|
|
828
|
+
) -> None:
|
|
829
|
+
"""Gets real-time Slurm GPU availability."""
|
|
830
|
+
await executor.schedule_request_async(
|
|
831
|
+
request_id=request.state.request_id,
|
|
832
|
+
request_name=request_names.RequestName.REALTIME_SLURM_GPU_AVAILABILITY,
|
|
833
|
+
request_body=slurm_gpu_availability_body,
|
|
834
|
+
func=core.realtime_slurm_gpu_availability,
|
|
835
|
+
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
836
|
+
)
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
# Keep the GET method for backwards compatibility
|
|
840
|
+
@app.api_route('/slurm_node_info', methods=['GET', 'POST'])
|
|
841
|
+
async def slurm_node_info(
|
|
842
|
+
request: fastapi.Request,
|
|
843
|
+
slurm_node_info_body: payloads.SlurmNodeInfoRequestBody) -> None:
|
|
844
|
+
"""Gets detailed information for each node in the Slurm cluster."""
|
|
845
|
+
await executor.schedule_request_async(
|
|
846
|
+
request_id=request.state.request_id,
|
|
847
|
+
request_name=request_names.RequestName.SLURM_NODE_INFO,
|
|
848
|
+
request_body=slurm_node_info_body,
|
|
849
|
+
func=slurm_utils.slurm_node_info,
|
|
850
|
+
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
851
|
+
)
|
|
852
|
+
|
|
853
|
+
|
|
787
854
|
@app.get('/status_kubernetes')
|
|
788
855
|
async def status_kubernetes(request: fastapi.Request) -> None:
|
|
789
856
|
"""[Experimental] Get all SkyPilot resources (including from other '
|
|
@@ -791,7 +858,8 @@ async def status_kubernetes(request: fastapi.Request) -> None:
|
|
|
791
858
|
await executor.schedule_request_async(
|
|
792
859
|
request_id=request.state.request_id,
|
|
793
860
|
request_name=request_names.RequestName.STATUS_KUBERNETES,
|
|
794
|
-
request_body=
|
|
861
|
+
request_body=server_utils.build_body_at_server(
|
|
862
|
+
request=request, body_type=payloads.RequestBody),
|
|
795
863
|
func=core.status_kubernetes,
|
|
796
864
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
797
865
|
)
|
|
@@ -1454,13 +1522,29 @@ async def cost_report(request: fastapi.Request,
|
|
|
1454
1522
|
)
|
|
1455
1523
|
|
|
1456
1524
|
|
|
1525
|
+
@app.post('/cluster_events')
|
|
1526
|
+
async def cluster_events(
|
|
1527
|
+
request: fastapi.Request,
|
|
1528
|
+
cluster_events_body: payloads.ClusterEventsBody) -> None:
|
|
1529
|
+
"""Gets events for a cluster."""
|
|
1530
|
+
await executor.schedule_request_async(
|
|
1531
|
+
request_id=request.state.request_id,
|
|
1532
|
+
request_name=request_names.RequestName.CLUSTER_EVENTS,
|
|
1533
|
+
request_body=cluster_events_body,
|
|
1534
|
+
func=core.get_cluster_events,
|
|
1535
|
+
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
1536
|
+
request_cluster_name=cluster_events_body.cluster_name or '',
|
|
1537
|
+
)
|
|
1538
|
+
|
|
1539
|
+
|
|
1457
1540
|
@app.get('/storage/ls')
|
|
1458
1541
|
async def storage_ls(request: fastapi.Request) -> None:
|
|
1459
1542
|
"""Gets the storages."""
|
|
1460
1543
|
await executor.schedule_request_async(
|
|
1461
1544
|
request_id=request.state.request_id,
|
|
1462
1545
|
request_name=request_names.RequestName.STORAGE_LS,
|
|
1463
|
-
request_body=
|
|
1546
|
+
request_body=server_utils.build_body_at_server(
|
|
1547
|
+
request=request, body_type=payloads.RequestBody),
|
|
1464
1548
|
func=core.storage_ls,
|
|
1465
1549
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
1466
1550
|
)
|
|
@@ -1752,6 +1836,22 @@ async def api_status(
|
|
|
1752
1836
|
return encoded_request_tasks
|
|
1753
1837
|
|
|
1754
1838
|
|
|
1839
|
+
@app.get('/api/plugins', response_class=fastapi_responses.ORJSONResponse)
|
|
1840
|
+
async def list_plugins() -> Dict[str, List[Dict[str, Any]]]:
|
|
1841
|
+
"""Return metadata about loaded backend plugins."""
|
|
1842
|
+
plugin_infos = []
|
|
1843
|
+
for plugin_info in plugins.get_plugins():
|
|
1844
|
+
info = {
|
|
1845
|
+
'js_extension_path': plugin_info.js_extension_path,
|
|
1846
|
+
}
|
|
1847
|
+
for attr in ('name', 'version', 'commit'):
|
|
1848
|
+
value = getattr(plugin_info, attr, None)
|
|
1849
|
+
if value is not None:
|
|
1850
|
+
info[attr] = value
|
|
1851
|
+
plugin_infos.append(info)
|
|
1852
|
+
return {'plugins': plugin_infos}
|
|
1853
|
+
|
|
1854
|
+
|
|
1755
1855
|
@app.get(
|
|
1756
1856
|
'/api/health',
|
|
1757
1857
|
# response_model_exclude_unset omits unset fields
|
|
@@ -1823,12 +1923,149 @@ async def health(request: fastapi.Request) -> responses.APIHealthResponse:
|
|
|
1823
1923
|
)
|
|
1824
1924
|
|
|
1825
1925
|
|
|
1826
|
-
class
|
|
1926
|
+
class SSHMessageType(IntEnum):
|
|
1827
1927
|
REGULAR_DATA = 0
|
|
1828
1928
|
PINGPONG = 1
|
|
1829
1929
|
LATENCY_MEASUREMENT = 2
|
|
1830
1930
|
|
|
1831
1931
|
|
|
1932
|
+
async def _get_cluster_and_validate(
|
|
1933
|
+
cluster_name: str,
|
|
1934
|
+
cloud_type: Type[clouds.Cloud],
|
|
1935
|
+
) -> 'backends.CloudVmRayResourceHandle':
|
|
1936
|
+
"""Fetch cluster status and validate it's UP and correct cloud type."""
|
|
1937
|
+
# Run core.status in another thread to avoid blocking the event loop.
|
|
1938
|
+
# TODO(aylei): core.status() will be called with server user, which has
|
|
1939
|
+
# permission to all workspaces, this will break workspace isolation.
|
|
1940
|
+
# It is ok for now, as users with limited access will not get the ssh config
|
|
1941
|
+
# for the clusters in non-accessible workspaces.
|
|
1942
|
+
with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
|
|
1943
|
+
cluster_records = await context_utils.to_thread_with_executor(
|
|
1944
|
+
thread_pool_executor, core.status, cluster_name, all_users=True)
|
|
1945
|
+
cluster_record = cluster_records[0]
|
|
1946
|
+
if cluster_record['status'] != status_lib.ClusterStatus.UP:
|
|
1947
|
+
raise fastapi.HTTPException(
|
|
1948
|
+
status_code=400, detail=f'Cluster {cluster_name} is not running')
|
|
1949
|
+
|
|
1950
|
+
handle: Optional['backends.CloudVmRayResourceHandle'] = cluster_record[
|
|
1951
|
+
'handle']
|
|
1952
|
+
assert handle is not None, 'Cluster handle is None'
|
|
1953
|
+
if not isinstance(handle.launched_resources.cloud, cloud_type):
|
|
1954
|
+
raise fastapi.HTTPException(
|
|
1955
|
+
status_code=400,
|
|
1956
|
+
detail=f'Cluster {cluster_name} is not a {str(cloud_type())} '
|
|
1957
|
+
'cluster. Use ssh to connect to the cluster instead.')
|
|
1958
|
+
|
|
1959
|
+
return handle
|
|
1960
|
+
|
|
1961
|
+
|
|
1962
|
+
async def _run_websocket_proxy(
|
|
1963
|
+
websocket: fastapi.WebSocket,
|
|
1964
|
+
read_from_backend: Callable[[], Awaitable[bytes]],
|
|
1965
|
+
write_to_backend: Callable[[bytes], Awaitable[None]],
|
|
1966
|
+
close_backend: Callable[[], Awaitable[None]],
|
|
1967
|
+
timestamps_supported: bool,
|
|
1968
|
+
) -> bool:
|
|
1969
|
+
"""Run bidirectional WebSocket-to-backend proxy.
|
|
1970
|
+
|
|
1971
|
+
Args:
|
|
1972
|
+
websocket: FastAPI WebSocket connection
|
|
1973
|
+
read_from_backend: Async callable to read bytes from backend
|
|
1974
|
+
write_to_backend: Async callable to write bytes to backend
|
|
1975
|
+
close_backend: Async callable to close backend connection
|
|
1976
|
+
timestamps_supported: Whether to use message type framing
|
|
1977
|
+
|
|
1978
|
+
Returns:
|
|
1979
|
+
True if SSH failed, False otherwise
|
|
1980
|
+
"""
|
|
1981
|
+
ssh_failed = False
|
|
1982
|
+
websocket_closed = False
|
|
1983
|
+
|
|
1984
|
+
async def websocket_to_backend():
|
|
1985
|
+
try:
|
|
1986
|
+
async for message in websocket.iter_bytes():
|
|
1987
|
+
if timestamps_supported:
|
|
1988
|
+
type_size = struct.calcsize('!B')
|
|
1989
|
+
message_type = struct.unpack('!B', message[:type_size])[0]
|
|
1990
|
+
if message_type == SSHMessageType.REGULAR_DATA:
|
|
1991
|
+
# Regular data - strip type byte and forward to backend
|
|
1992
|
+
message = message[type_size:]
|
|
1993
|
+
elif message_type == SSHMessageType.PINGPONG:
|
|
1994
|
+
# PING message - respond with PONG
|
|
1995
|
+
ping_id_size = struct.calcsize('!I')
|
|
1996
|
+
if len(message) != type_size + ping_id_size:
|
|
1997
|
+
raise ValueError(
|
|
1998
|
+
f'Invalid PING message length: {len(message)}')
|
|
1999
|
+
# Return the same PING message for latency measurement
|
|
2000
|
+
await websocket.send_bytes(message)
|
|
2001
|
+
continue
|
|
2002
|
+
elif message_type == SSHMessageType.LATENCY_MEASUREMENT:
|
|
2003
|
+
# Latency measurement from client
|
|
2004
|
+
latency_size = struct.calcsize('!Q')
|
|
2005
|
+
if len(message) != type_size + latency_size:
|
|
2006
|
+
raise ValueError('Invalid latency measurement '
|
|
2007
|
+
f'message length: {len(message)}')
|
|
2008
|
+
avg_latency_ms = struct.unpack(
|
|
2009
|
+
'!Q',
|
|
2010
|
+
message[type_size:type_size + latency_size])[0]
|
|
2011
|
+
latency_seconds = avg_latency_ms / 1000
|
|
2012
|
+
metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels( # pylint: disable=line-too-long
|
|
2013
|
+
pid=os.getpid()).observe(latency_seconds)
|
|
2014
|
+
continue
|
|
2015
|
+
else:
|
|
2016
|
+
raise ValueError(
|
|
2017
|
+
f'Unknown message type: {message_type}')
|
|
2018
|
+
|
|
2019
|
+
try:
|
|
2020
|
+
await write_to_backend(message)
|
|
2021
|
+
except Exception as e: # pylint: disable=broad-except
|
|
2022
|
+
# Typically we will not reach here, if the conn to backend
|
|
2023
|
+
# is disconnected, backend_to_websocket will exit first.
|
|
2024
|
+
# But just in case.
|
|
2025
|
+
logger.error(f'Failed to write to backend through '
|
|
2026
|
+
f'connection: {e}')
|
|
2027
|
+
nonlocal ssh_failed
|
|
2028
|
+
ssh_failed = True
|
|
2029
|
+
break
|
|
2030
|
+
except fastapi.WebSocketDisconnect:
|
|
2031
|
+
pass
|
|
2032
|
+
nonlocal websocket_closed
|
|
2033
|
+
websocket_closed = True
|
|
2034
|
+
await close_backend()
|
|
2035
|
+
|
|
2036
|
+
async def backend_to_websocket():
|
|
2037
|
+
try:
|
|
2038
|
+
while True:
|
|
2039
|
+
data = await read_from_backend()
|
|
2040
|
+
if not data:
|
|
2041
|
+
if not websocket_closed:
|
|
2042
|
+
logger.warning(
|
|
2043
|
+
'SSH connection to backend is disconnected '
|
|
2044
|
+
'before websocket connection is closed')
|
|
2045
|
+
nonlocal ssh_failed
|
|
2046
|
+
ssh_failed = True
|
|
2047
|
+
break
|
|
2048
|
+
if timestamps_supported:
|
|
2049
|
+
# Prepend message type byte (0 = regular data)
|
|
2050
|
+
message_type_bytes = struct.pack(
|
|
2051
|
+
'!B', SSHMessageType.REGULAR_DATA.value)
|
|
2052
|
+
data = message_type_bytes + data
|
|
2053
|
+
await websocket.send_bytes(data)
|
|
2054
|
+
except Exception: # pylint: disable=broad-except
|
|
2055
|
+
pass
|
|
2056
|
+
try:
|
|
2057
|
+
await websocket.close()
|
|
2058
|
+
except Exception: # pylint: disable=broad-except
|
|
2059
|
+
# The websocket might have been closed by the client
|
|
2060
|
+
pass
|
|
2061
|
+
|
|
2062
|
+
await asyncio.gather(websocket_to_backend(),
|
|
2063
|
+
backend_to_websocket(),
|
|
2064
|
+
return_exceptions=True)
|
|
2065
|
+
|
|
2066
|
+
return ssh_failed
|
|
2067
|
+
|
|
2068
|
+
|
|
1832
2069
|
@app.websocket('/kubernetes-pod-ssh-proxy')
|
|
1833
2070
|
async def kubernetes_pod_ssh_proxy(
|
|
1834
2071
|
websocket: fastapi.WebSocket,
|
|
@@ -1842,22 +2079,7 @@ async def kubernetes_pod_ssh_proxy(
|
|
|
1842
2079
|
logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
|
|
1843
2080
|
client_version = {client_version}')
|
|
1844
2081
|
|
|
1845
|
-
|
|
1846
|
-
with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
|
|
1847
|
-
cluster_records = await context_utils.to_thread_with_executor(
|
|
1848
|
-
thread_pool_executor, core.status, cluster_name, all_users=True)
|
|
1849
|
-
cluster_record = cluster_records[0]
|
|
1850
|
-
if cluster_record['status'] != status_lib.ClusterStatus.UP:
|
|
1851
|
-
raise fastapi.HTTPException(
|
|
1852
|
-
status_code=400, detail=f'Cluster {cluster_name} is not running')
|
|
1853
|
-
|
|
1854
|
-
handle = cluster_record['handle']
|
|
1855
|
-
assert handle is not None, 'Cluster handle is None'
|
|
1856
|
-
if not isinstance(handle.launched_resources.cloud, clouds.Kubernetes):
|
|
1857
|
-
raise fastapi.HTTPException(
|
|
1858
|
-
status_code=400,
|
|
1859
|
-
detail=f'Cluster {cluster_name} is not a Kubernetes cluster'
|
|
1860
|
-
'Use ssh to connect to the cluster instead.')
|
|
2082
|
+
handle = await _get_cluster_and_validate(cluster_name, clouds.Kubernetes)
|
|
1861
2083
|
|
|
1862
2084
|
kubectl_cmd = handle.get_command_runners()[0].port_forward_command(
|
|
1863
2085
|
port_forward=[(None, 22)])
|
|
@@ -1887,96 +2109,25 @@ async def kubernetes_pod_ssh_proxy(
|
|
|
1887
2109
|
conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
|
|
1888
2110
|
pid=os.getpid())
|
|
1889
2111
|
ssh_failed = False
|
|
1890
|
-
websocket_closed = False
|
|
1891
2112
|
try:
|
|
1892
2113
|
conn_gauge.inc()
|
|
1893
2114
|
# Connect to the local port
|
|
1894
2115
|
reader, writer = await asyncio.open_connection('127.0.0.1', local_port)
|
|
1895
2116
|
|
|
1896
|
-
async def
|
|
1897
|
-
|
|
1898
|
-
|
|
1899
|
-
if timestamps_supported:
|
|
1900
|
-
type_size = struct.calcsize('!B')
|
|
1901
|
-
message_type = struct.unpack('!B',
|
|
1902
|
-
message[:type_size])[0]
|
|
1903
|
-
if (message_type ==
|
|
1904
|
-
KubernetesSSHMessageType.REGULAR_DATA):
|
|
1905
|
-
# Regular data - strip type byte and forward to SSH
|
|
1906
|
-
message = message[type_size:]
|
|
1907
|
-
elif message_type == KubernetesSSHMessageType.PINGPONG:
|
|
1908
|
-
# PING message - respond with PONG (type 1)
|
|
1909
|
-
ping_id_size = struct.calcsize('!I')
|
|
1910
|
-
if len(message) != type_size + ping_id_size:
|
|
1911
|
-
raise ValueError('Invalid PING message '
|
|
1912
|
-
f'length: {len(message)}')
|
|
1913
|
-
# Return the same PING message, so that the client
|
|
1914
|
-
# can measure the latency.
|
|
1915
|
-
await websocket.send_bytes(message)
|
|
1916
|
-
continue
|
|
1917
|
-
elif (message_type ==
|
|
1918
|
-
KubernetesSSHMessageType.LATENCY_MEASUREMENT):
|
|
1919
|
-
# Latency measurement from client
|
|
1920
|
-
latency_size = struct.calcsize('!Q')
|
|
1921
|
-
if len(message) != type_size + latency_size:
|
|
1922
|
-
raise ValueError(
|
|
1923
|
-
'Invalid latency measurement '
|
|
1924
|
-
f'message length: {len(message)}')
|
|
1925
|
-
avg_latency_ms = struct.unpack(
|
|
1926
|
-
'!Q',
|
|
1927
|
-
message[type_size:type_size + latency_size])[0]
|
|
1928
|
-
latency_seconds = avg_latency_ms / 1000
|
|
1929
|
-
metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(pid=os.getpid()).observe(latency_seconds) # pylint: disable=line-too-long
|
|
1930
|
-
continue
|
|
1931
|
-
else:
|
|
1932
|
-
# Unknown message type.
|
|
1933
|
-
raise ValueError(
|
|
1934
|
-
f'Unknown message type: {message_type}')
|
|
1935
|
-
writer.write(message)
|
|
1936
|
-
try:
|
|
1937
|
-
await writer.drain()
|
|
1938
|
-
except Exception as e: # pylint: disable=broad-except
|
|
1939
|
-
# Typically we will not reach here, if the ssh to pod
|
|
1940
|
-
# is disconnected, ssh_to_websocket will exit first.
|
|
1941
|
-
# But just in case.
|
|
1942
|
-
logger.error('Failed to write to pod through '
|
|
1943
|
-
f'port-forward connection: {e}')
|
|
1944
|
-
nonlocal ssh_failed
|
|
1945
|
-
ssh_failed = True
|
|
1946
|
-
break
|
|
1947
|
-
except fastapi.WebSocketDisconnect:
|
|
1948
|
-
pass
|
|
1949
|
-
nonlocal websocket_closed
|
|
1950
|
-
websocket_closed = True
|
|
1951
|
-
writer.close()
|
|
2117
|
+
async def write_and_drain(data: bytes) -> None:
|
|
2118
|
+
writer.write(data)
|
|
2119
|
+
await writer.drain()
|
|
1952
2120
|
|
|
1953
|
-
async def
|
|
1954
|
-
|
|
1955
|
-
while True:
|
|
1956
|
-
data = await reader.read(1024)
|
|
1957
|
-
if not data:
|
|
1958
|
-
if not websocket_closed:
|
|
1959
|
-
logger.warning('SSH connection to pod is '
|
|
1960
|
-
'disconnected before websocket '
|
|
1961
|
-
'connection is closed')
|
|
1962
|
-
nonlocal ssh_failed
|
|
1963
|
-
ssh_failed = True
|
|
1964
|
-
break
|
|
1965
|
-
if timestamps_supported:
|
|
1966
|
-
# Prepend message type byte (0 = regular data)
|
|
1967
|
-
message_type_bytes = struct.pack(
|
|
1968
|
-
'!B', KubernetesSSHMessageType.REGULAR_DATA.value)
|
|
1969
|
-
data = message_type_bytes + data
|
|
1970
|
-
await websocket.send_bytes(data)
|
|
1971
|
-
except Exception: # pylint: disable=broad-except
|
|
1972
|
-
pass
|
|
1973
|
-
try:
|
|
1974
|
-
await websocket.close()
|
|
1975
|
-
except Exception: # pylint: disable=broad-except
|
|
1976
|
-
# The websocket might has been closed by the client.
|
|
1977
|
-
pass
|
|
2121
|
+
async def close_writer() -> None:
|
|
2122
|
+
writer.close()
|
|
1978
2123
|
|
|
1979
|
-
await
|
|
2124
|
+
ssh_failed = await _run_websocket_proxy(
|
|
2125
|
+
websocket,
|
|
2126
|
+
read_from_backend=lambda: reader.read(1024),
|
|
2127
|
+
write_to_backend=write_and_drain,
|
|
2128
|
+
close_backend=close_writer,
|
|
2129
|
+
timestamps_supported=timestamps_supported,
|
|
2130
|
+
)
|
|
1980
2131
|
finally:
|
|
1981
2132
|
conn_gauge.dec()
|
|
1982
2133
|
reason = ''
|
|
@@ -1990,7 +2141,7 @@ async def kubernetes_pod_ssh_proxy(
|
|
|
1990
2141
|
f'output: {str(stdout)}')
|
|
1991
2142
|
reason = 'KubectlPortForwardExit'
|
|
1992
2143
|
metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
|
|
1993
|
-
pid=os.getpid(), reason=
|
|
2144
|
+
pid=os.getpid(), reason=reason).inc()
|
|
1994
2145
|
else:
|
|
1995
2146
|
if ssh_failed:
|
|
1996
2147
|
reason = 'SSHToPodDisconnected'
|
|
@@ -2000,6 +2151,235 @@ async def kubernetes_pod_ssh_proxy(
|
|
|
2000
2151
|
pid=os.getpid(), reason=reason).inc()
|
|
2001
2152
|
|
|
2002
2153
|
|
|
2154
|
+
@app.websocket('/slurm-job-ssh-proxy')
async def slurm_job_ssh_proxy(websocket: fastapi.WebSocket,
                              cluster_name: str,
                              worker: int = 0,
                              client_version: Optional[int] = None) -> None:
    """Proxies SSH to the Slurm job via sshd inside srun.

    The websocket is bridged to the stdin/stdout of an ssh process that
    connects to the Slurm login node and runs the command produced by
    ``slurm_utils.srun_sshd_command`` for the selected node.

    Args:
        websocket: The client websocket connection.
        cluster_name: Name of the cluster to connect to; validated to be a
            Slurm cluster by ``_get_cluster_and_validate``.
        worker: Zero-based index into the cluster's instances (the head is
            expected first, then workers). Must be within
            ``[0, number of nodes)``.
        client_version: API version reported by the client. Versions above
            21 use the framed protocol with a message type byte.

    Raises:
        fastapi.HTTPException: If ``worker`` is out of range.
    """
    await websocket.accept()
    logger.info(f'WebSocket connection accepted for cluster: '
                f'{cluster_name}')

    timestamps_supported = client_version is not None and client_version > 21
    # Use implicit string concatenation instead of a backslash continuation
    # inside the literal, which would embed the source indentation in the
    # log message.
    logger.info(f'Websocket timestamps supported: {timestamps_supported}, '
                f'client_version = {client_version}')

    handle = await _get_cluster_and_validate(cluster_name, clouds.Slurm)

    assert handle.cached_cluster_info is not None, 'Cached cluster info is None'
    provider_config = handle.cached_cluster_info.provider_config
    assert provider_config is not None, 'Provider config is None'
    login_node_ssh_config = provider_config['ssh']
    login_node_host = login_node_ssh_config['hostname']
    login_node_port = int(login_node_ssh_config['port'])
    login_node_user = login_node_ssh_config['user']
    login_node_key = login_node_ssh_config['private_key']
    login_node_proxy_command = login_node_ssh_config.get('proxycommand', None)
    login_node_proxy_jump = login_node_ssh_config.get('proxyjump', None)

    login_node_runner = command_runner.SSHCommandRunner(
        (login_node_host, login_node_port),
        login_node_user,
        login_node_key,
        ssh_proxy_command=login_node_proxy_command,
        ssh_proxy_jump=login_node_proxy_jump,
    )

    ssh_cmd = login_node_runner.ssh_base_command(
        ssh_mode=command_runner.SshMode.NON_INTERACTIVE,
        port_forward=None,
        connect_timeout=None)

    # There can only be one InstanceInfo per instance_id.
    head_instance = handle.cached_cluster_info.get_head_instance()
    assert head_instance is not None, 'Head instance is None'
    job_id = head_instance.tags['job_id']

    # Instances are ordered: head first, then workers
    instances = handle.cached_cluster_info.instances
    node_hostnames = [inst[0].tags['node'] for inst in instances.values()]
    # Reject negative indices as well: Python would otherwise silently
    # resolve e.g. -1 to the last node instead of reporting a bad request.
    if not 0 <= worker < len(node_hostnames):
        raise fastapi.HTTPException(
            status_code=400,
            detail=f'Worker index {worker} out of range. '
            f'Cluster has {len(node_hostnames)} nodes.')
    target_node = node_hostnames[worker]

    # Run sshd inside the Slurm job "container" via srun, such that it inherits
    # the resource constraints of the Slurm job.
    ssh_cmd += [
        shlex.quote(
            slurm_utils.srun_sshd_command(job_id, target_node, login_node_user))
    ]

    proc = await asyncio.create_subprocess_shell(
        ' '.join(ssh_cmd),
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,  # Capture stderr separately for logging
    )
    assert proc.stdin is not None
    assert proc.stdout is not None
    assert proc.stderr is not None

    stdin = proc.stdin
    stdout = proc.stdout
    stderr = proc.stderr

    show_debug_info = env_options.Options.SHOW_DEBUG_INFO.get()

    async def log_stderr():
        """Drains the child's stderr, logging it when debug info is on."""
        while True:
            line = await stderr.readline()
            if not line:
                break
            if show_debug_info:
                # errors='replace' keeps the drain loop alive on non-UTF-8
                # output instead of raising UnicodeDecodeError.
                decoded = line.decode(errors='replace').rstrip()
                logger.debug(f'srun stderr: {decoded}')

    # Always drain stderr: it is a PIPE, so leaving it unread would let the
    # child block once the pipe buffer fills up.
    stderr_task = asyncio.create_task(log_stderr())
    conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
        pid=os.getpid())
    ssh_failed = False
    try:
        conn_gauge.inc()

        async def write_and_drain(data: bytes) -> None:
            stdin.write(data)
            await stdin.drain()

        async def close_stdin() -> None:
            stdin.close()

        ssh_failed = await _run_websocket_proxy(
            websocket,
            read_from_backend=lambda: stdout.read(4096),
            write_to_backend=write_and_drain,
            close_backend=close_stdin,
            timestamps_supported=timestamps_supported,
        )

    finally:
        conn_gauge.dec()
        reason = ''
        try:
            logger.info('Terminating srun process')
            proc.terminate()
        except ProcessLookupError:
            # The process is already gone; record whatever it left on stdout.
            stdout_data = await stdout.read()
            logger.error('srun process was terminated before the '
                         'ssh websocket connection was closed. Remaining '
                         f'output: {str(stdout_data)}')
            reason = 'SrunProcessExit'
            metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
                pid=os.getpid(), reason=reason).inc()
        else:
            if ssh_failed:
                reason = 'SSHToSlurmJobDisconnected'
            else:
                reason = 'ClientClosed'

            # Kept inside the else branch so that each connection close is
            # counted exactly once (the except branch counts on its own).
            metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
                pid=os.getpid(), reason=reason).inc()

        # Cancel the stderr logging task if it's still running
        if not stderr_task.done():
            stderr_task.cancel()
            try:
                await stderr_task
            except asyncio.CancelledError:
                pass
|
|
2291
|
+
|
|
2292
|
+
|
|
2293
|
+
@app.websocket('/ssh-interactive-auth')
async def ssh_interactive_auth(websocket: fastapi.WebSocket,
                               session_id: str) -> None:
    """Proxies PTY for SSH interactive authentication via websocket.

    This endpoint receives a PTY file descriptor from a worker process
    (over the Unix socket returned by
    ``interactive_utils.get_pty_socket_path``) and bridges it bidirectionally
    with a websocket connection, allowing the client to handle interactive
    SSH authentication (e.g., 2FA).

    The bridge runs until the PTY read returns EOF or fails, or a forwarding
    direction hits an error; the PTY fd and the Unix socket are always closed
    on exit.

    NOTE(review): no explicit detection of auth completion (e.g. via terminal
    echo state) is performed in this function; presumably the worker side
    closes the PTY when authentication is done - confirm.

    Args:
        websocket: The client websocket connection.
        session_id: Identifier of the interactive auth session, used to
            locate the worker's FD-passing socket.
    """
    await websocket.accept()
    logger.info(f'WebSocket connection accepted for SSH auth session: '
                f'{session_id}')

    loop = asyncio.get_running_loop()

    # Connect to worker process to receive PTY file descriptor
    fd_socket_path = interactive_utils.get_pty_socket_path(session_id)
    fd_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    master_fd = -1
    try:
        # Connect to worker's FD-passing socket
        await loop.sock_connect(fd_sock, fd_socket_path)
        master_fd = await loop.run_in_executor(None, interactive_utils.recv_fd,
                                               fd_sock)
        logger.debug(f'Received PTY master fd {master_fd} for session '
                     f'{session_id}')

        def write_all_to_pty(payload: bytes) -> None:
            """Writes the whole payload, retrying after short writes."""
            # os.write may write fewer bytes than requested; keep going until
            # the entire message has been handed to the PTY.
            remaining = memoryview(payload)
            while remaining:
                written = os.write(master_fd, remaining)
                remaining = remaining[written:]

        # Bridge PTY <-> websocket bidirectionally
        async def websocket_to_pty():
            """Forward websocket messages to PTY."""
            try:
                async for message in websocket.iter_bytes():
                    await loop.run_in_executor(None, write_all_to_pty, message)
            except fastapi.WebSocketDisconnect:
                logger.debug(f'WebSocket disconnected for session {session_id}')
            except asyncio.CancelledError:
                pass
            except Exception as e:  # pylint: disable=broad-except
                logger.error(f'Error in websocket_to_pty: {e}')

        async def pty_to_websocket():
            """Forward PTY output to the websocket until the PTY ends.

            Stops when reading the PTY fails with OSError or returns no
            data, then closes the websocket.
            """
            try:
                while True:
                    try:
                        data = await loop.run_in_executor(
                            None, os.read, master_fd, 4096)
                    except OSError as e:
                        logger.error(f'PTY read error (likely closed): {e}')
                        break

                    if not data:
                        break

                    await websocket.send_bytes(data)
            except asyncio.CancelledError:
                pass
            except Exception as e:  # pylint: disable=broad-except
                logger.error(f'Error in pty_to_websocket: {e}')
            finally:
                try:
                    await websocket.close()
                except Exception:  # pylint: disable=broad-except
                    # The websocket might have been closed by the client.
                    pass

        await asyncio.gather(websocket_to_pty(), pty_to_websocket())

    except Exception as e:  # pylint: disable=broad-except
        logger.error(f'Error in SSH interactive auth websocket: {e}')
        raise
    finally:
        # Clean up
        if master_fd >= 0:
            try:
                os.close(master_fd)
            except OSError:
                pass
        fd_sock.close()
        logger.debug(f'SSH interactive auth session {session_id} completed')
|
|
2381
|
+
|
|
2382
|
+
|
|
2003
2383
|
@app.get('/all_contexts')
|
|
2004
2384
|
async def all_contexts(request: fastapi.Request) -> None:
|
|
2005
2385
|
"""Gets all Kubernetes and SSH node pool contexts."""
|
|
@@ -2007,7 +2387,8 @@ async def all_contexts(request: fastapi.Request) -> None:
|
|
|
2007
2387
|
await executor.schedule_request_async(
|
|
2008
2388
|
request_id=request.state.request_id,
|
|
2009
2389
|
request_name=request_names.RequestName.ALL_CONTEXTS,
|
|
2010
|
-
request_body=
|
|
2390
|
+
request_body=server_utils.build_body_at_server(
|
|
2391
|
+
request=request, body_type=payloads.RequestBody),
|
|
2011
2392
|
func=core.get_all_contexts,
|
|
2012
2393
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
|
2013
2394
|
)
|
|
@@ -2057,6 +2438,14 @@ async def serve_dashboard(full_path: str):
|
|
|
2057
2438
|
if os.path.isfile(file_path):
|
|
2058
2439
|
return fastapi.responses.FileResponse(file_path)
|
|
2059
2440
|
|
|
2441
|
+
# Serve plugin catch-all page for any /plugins/* paths so client-side
|
|
2442
|
+
# routing can bootstrap correctly.
|
|
2443
|
+
if full_path == 'plugins' or full_path.startswith('plugins/'):
|
|
2444
|
+
plugin_catchall = os.path.join(server_constants.DASHBOARD_DIR,
|
|
2445
|
+
'plugins', '[...slug].html')
|
|
2446
|
+
if os.path.isfile(plugin_catchall):
|
|
2447
|
+
return fastapi.responses.FileResponse(plugin_catchall)
|
|
2448
|
+
|
|
2060
2449
|
# Serve index.html for client-side routing
|
|
2061
2450
|
# e.g. /clusters, /jobs
|
|
2062
2451
|
index_path = os.path.join(server_constants.DASHBOARD_DIR, 'index.html')
|
|
@@ -2161,6 +2550,9 @@ if __name__ == '__main__':
|
|
|
2161
2550
|
# Restore the server user hash
|
|
2162
2551
|
logger.info('Initializing server user hash')
|
|
2163
2552
|
_init_or_restore_server_user_hash()
|
|
2553
|
+
logger.info('Initializing permission service')
|
|
2554
|
+
permission.permission_service.initialize()
|
|
2555
|
+
logger.info('Permission service initialized')
|
|
2164
2556
|
|
|
2165
2557
|
max_db_connections = global_user_state.get_max_db_connections()
|
|
2166
2558
|
logger.info(f'Max db connections: {max_db_connections}')
|
|
@@ -2197,6 +2589,9 @@ if __name__ == '__main__':
|
|
|
2197
2589
|
global_tasks.append(
|
|
2198
2590
|
background.create_task(
|
|
2199
2591
|
global_user_state.cluster_event_retention_daemon()))
|
|
2592
|
+
global_tasks.append(
|
|
2593
|
+
background.create_task(
|
|
2594
|
+
managed_job_state.job_event_retention_daemon()))
|
|
2200
2595
|
threading.Thread(target=background.run_forever, daemon=True).start()
|
|
2201
2596
|
|
|
2202
2597
|
queue_server, workers = executor.start(config)
|
|
@@ -2220,6 +2615,8 @@ if __name__ == '__main__':
|
|
|
2220
2615
|
|
|
2221
2616
|
for gt in global_tasks:
|
|
2222
2617
|
gt.cancel()
|
|
2618
|
+
for plugin in plugins.get_plugins():
|
|
2619
|
+
plugin.shutdown()
|
|
2223
2620
|
subprocess_utils.run_in_parallel(lambda worker: worker.cancel(),
|
|
2224
2621
|
workers,
|
|
2225
2622
|
num_threads=len(workers))
|