skypilot-nightly 1.0.0.dev20251210__py3-none-any.whl → 1.0.0.dev20260112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +4 -2
- sky/adaptors/slurm.py +159 -72
- sky/backends/backend_utils.py +52 -10
- sky/backends/cloud_vm_ray_backend.py +192 -32
- sky/backends/task_codegen.py +40 -2
- sky/catalog/data_fetchers/fetch_gcp.py +9 -1
- sky/catalog/data_fetchers/fetch_nebius.py +1 -1
- sky/catalog/data_fetchers/fetch_vast.py +4 -2
- sky/catalog/seeweb_catalog.py +30 -15
- sky/catalog/shadeform_catalog.py +5 -2
- sky/catalog/slurm_catalog.py +0 -7
- sky/catalog/vast_catalog.py +30 -6
- sky/check.py +11 -8
- sky/client/cli/command.py +106 -54
- sky/client/interactive_utils.py +190 -0
- sky/client/sdk.py +8 -0
- sky/client/sdk_async.py +9 -0
- sky/clouds/aws.py +60 -2
- sky/clouds/azure.py +2 -0
- sky/clouds/kubernetes.py +2 -0
- sky/clouds/runpod.py +38 -7
- sky/clouds/slurm.py +44 -12
- sky/clouds/ssh.py +1 -1
- sky/clouds/vast.py +30 -17
- sky/core.py +69 -1
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/3nu-b8raeKRNABZ2d4GAG/_buildManifest.js +1 -0
- sky/dashboard/out/_next/static/chunks/1871-0565f8975a7dcd10.js +6 -0
- sky/dashboard/out/_next/static/chunks/2109-55a1546d793574a7.js +11 -0
- sky/dashboard/out/_next/static/chunks/2521-099b07cd9e4745bf.js +26 -0
- sky/dashboard/out/_next/static/chunks/2755.a636e04a928a700e.js +31 -0
- sky/dashboard/out/_next/static/chunks/3495.05eab4862217c1a5.js +6 -0
- sky/dashboard/out/_next/static/chunks/3785.cfc5dcc9434fd98c.js +1 -0
- sky/dashboard/out/_next/static/chunks/3981.645d01bf9c8cad0c.js +21 -0
- sky/dashboard/out/_next/static/chunks/4083-0115d67c1fb57d6c.js +21 -0
- sky/dashboard/out/_next/static/chunks/{8640.5b9475a2d18c5416.js → 429.a58e9ba9742309ed.js} +2 -2
- sky/dashboard/out/_next/static/chunks/4555.8e221537181b5dc1.js +6 -0
- sky/dashboard/out/_next/static/chunks/4725.937865b81fdaaebb.js +6 -0
- sky/dashboard/out/_next/static/chunks/6082-edabd8f6092300ce.js +25 -0
- sky/dashboard/out/_next/static/chunks/6989-49cb7dca83a7a62d.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-630bd2a2257275f8.js +1 -0
- sky/dashboard/out/_next/static/chunks/7248-a99800d4db8edabd.js +1 -0
- sky/dashboard/out/_next/static/chunks/754-cfc5d4ad1b843d29.js +18 -0
- sky/dashboard/out/_next/static/chunks/8050-dd8aa107b17dce00.js +16 -0
- sky/dashboard/out/_next/static/chunks/8056-d4ae1e0cb81e7368.js +1 -0
- sky/dashboard/out/_next/static/chunks/8555.011023e296c127b3.js +6 -0
- sky/dashboard/out/_next/static/chunks/8821-93c25df904a8362b.js +1 -0
- sky/dashboard/out/_next/static/chunks/8969-0662594b69432ade.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.f15c91c97d124a5f.js +6 -0
- sky/dashboard/out/_next/static/chunks/{9353-8369df1cf105221c.js → 9353-7ad6bd01858556f1.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/_app-5a86569acad99764.js +34 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8297476714acb4ac.js +6 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-337c3ba1085f1210.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/{clusters-9e5d47818b9bdadd.js → clusters-57632ff3684a8b5c.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-5fd3a453c079c2ea.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra-9f85c02c9c6cae9e.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-90f16972cbecf354.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-2dd42fc37aad427a.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs-ed806aeace26b972.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/users-bec34706b36f3524.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/{volumes-ef19d49c6d0e8500.js → volumes-a83ba9b38dff7ea9.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-96e0f298308da7e2.js → [name]-c781e9c3e52ef9fc.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-91e0942f47310aae.js +1 -0
- sky/dashboard/out/_next/static/chunks/webpack-cfe59cf684ee13b9.js +1 -0
- sky/dashboard/out/_next/static/css/b0dbca28f027cc19.css +3 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/plugins/[...slug].html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/volumes.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/data/data_utils.py +26 -12
- sky/data/mounting_utils.py +29 -4
- sky/global_user_state.py +108 -16
- sky/jobs/client/sdk.py +8 -3
- sky/jobs/controller.py +191 -31
- sky/jobs/recovery_strategy.py +109 -11
- sky/jobs/server/core.py +81 -4
- sky/jobs/server/server.py +14 -0
- sky/jobs/state.py +417 -19
- sky/jobs/utils.py +73 -80
- sky/models.py +9 -0
- sky/optimizer.py +2 -1
- sky/provision/__init__.py +11 -9
- sky/provision/kubernetes/utils.py +122 -15
- sky/provision/kubernetes/volume.py +52 -17
- sky/provision/provisioner.py +2 -1
- sky/provision/runpod/instance.py +3 -1
- sky/provision/runpod/utils.py +13 -1
- sky/provision/runpod/volume.py +25 -9
- sky/provision/slurm/instance.py +75 -29
- sky/provision/slurm/utils.py +213 -107
- sky/provision/vast/utils.py +1 -0
- sky/resources.py +135 -13
- sky/schemas/api/responses.py +4 -0
- sky/schemas/db/global_user_state/010_save_ssh_key.py +1 -1
- sky/schemas/db/spot_jobs/008_add_full_resources.py +34 -0
- sky/schemas/db/spot_jobs/009_job_events.py +32 -0
- sky/schemas/db/spot_jobs/010_job_events_timestamp_with_timezone.py +43 -0
- sky/schemas/db/spot_jobs/011_add_links.py +34 -0
- sky/schemas/generated/jobsv1_pb2.py +9 -5
- sky/schemas/generated/jobsv1_pb2.pyi +12 -0
- sky/schemas/generated/jobsv1_pb2_grpc.py +44 -0
- sky/schemas/generated/managed_jobsv1_pb2.py +32 -28
- sky/schemas/generated/managed_jobsv1_pb2.pyi +11 -2
- sky/serve/serve_utils.py +232 -40
- sky/server/common.py +17 -0
- sky/server/constants.py +1 -1
- sky/server/metrics.py +6 -3
- sky/server/plugins.py +16 -0
- sky/server/requests/payloads.py +18 -0
- sky/server/requests/request_names.py +2 -0
- sky/server/requests/requests.py +28 -10
- sky/server/requests/serializers/encoders.py +5 -0
- sky/server/requests/serializers/return_value_serializers.py +14 -4
- sky/server/server.py +434 -107
- sky/server/uvicorn.py +5 -0
- sky/setup_files/MANIFEST.in +1 -0
- sky/setup_files/dependencies.py +21 -10
- sky/sky_logging.py +2 -1
- sky/skylet/constants.py +22 -5
- sky/skylet/executor/slurm.py +4 -6
- sky/skylet/job_lib.py +89 -4
- sky/skylet/services.py +18 -3
- sky/ssh_node_pools/deploy/tunnel/cleanup-tunnel.sh +62 -0
- sky/ssh_node_pools/deploy/tunnel/ssh-tunnel.sh +379 -0
- sky/templates/kubernetes-ray.yml.j2 +4 -6
- sky/templates/slurm-ray.yml.j2 +32 -2
- sky/templates/websocket_proxy.py +18 -41
- sky/users/permission.py +61 -51
- sky/utils/auth_utils.py +42 -0
- sky/utils/cli_utils/status_utils.py +19 -5
- sky/utils/cluster_utils.py +10 -3
- sky/utils/command_runner.py +256 -94
- sky/utils/command_runner.pyi +16 -0
- sky/utils/common_utils.py +30 -29
- sky/utils/context.py +32 -0
- sky/utils/db/db_utils.py +36 -6
- sky/utils/db/migration_utils.py +41 -21
- sky/utils/infra_utils.py +5 -1
- sky/utils/instance_links.py +139 -0
- sky/utils/interactive_utils.py +49 -0
- sky/utils/kubernetes/generate_kubeconfig.sh +42 -33
- sky/utils/kubernetes/rsync_helper.sh +5 -1
- sky/utils/plugin_extensions/__init__.py +14 -0
- sky/utils/plugin_extensions/external_failure_source.py +176 -0
- sky/utils/resources_utils.py +10 -8
- sky/utils/rich_utils.py +9 -11
- sky/utils/schemas.py +63 -20
- sky/utils/status_lib.py +7 -0
- sky/utils/subprocess_utils.py +17 -0
- sky/volumes/client/sdk.py +6 -3
- sky/volumes/server/core.py +65 -27
- sky_templates/ray/start_cluster +8 -4
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/METADATA +53 -57
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/RECORD +172 -162
- sky/dashboard/out/_next/static/KYAhEFa3FTfq4JyKVgo-s/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/1141-9c810f01ff4f398a.js +0 -11
- sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +0 -6
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +0 -1
- sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +0 -1
- sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +0 -15
- sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +0 -26
- sky/dashboard/out/_next/static/chunks/3294.ddda8c6c6f9f24dc.js +0 -1
- sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +0 -1
- sky/dashboard/out/_next/static/chunks/3800-b589397dc09c5b4e.js +0 -1
- sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +0 -1
- sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +0 -15
- sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +0 -13
- sky/dashboard/out/_next/static/chunks/6856-da20c5fd999f319c.js +0 -1
- sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +0 -1
- sky/dashboard/out/_next/static/chunks/6990-09cbf02d3cd518c3.js +0 -1
- sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +0 -30
- sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +0 -41
- sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +0 -1
- sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +0 -1
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +0 -6
- sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +0 -31
- sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +0 -30
- sky/dashboard/out/_next/static/chunks/pages/_app-68b647e26f9d2793.js +0 -34
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-33f525539665fdfd.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-a7565f586ef86467.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-12c559ec4d81fdbd.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/infra-d187cd0413d72475.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-895847b6cf200b04.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-8d0f4655400b4eb9.js +0 -21
- sky/dashboard/out/_next/static/chunks/pages/jobs-e5a98f17f8513a96.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/users-2f7646eb77785a2c.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-cb4da3abe08ebf19.js +0 -1
- sky/dashboard/out/_next/static/chunks/webpack-fba3de387ff6bb08.js +0 -1
- sky/dashboard/out/_next/static/css/c5a4cfd2600fc715.css +0 -3
- /sky/dashboard/out/_next/static/{KYAhEFa3FTfq4JyKVgo-s → 3nu-b8raeKRNABZ2d4GAG}/_ssgManifest.js +0 -0
- /sky/dashboard/out/_next/static/chunks/pages/plugins/{[...slug]-4f46050ca065d8f8.js → [...slug]-449a9f5a3bb20fb3.js} +0 -0
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/top_level.txt +0 -0
sky/server/server.py
CHANGED
@@ -15,12 +15,16 @@ import pathlib
 import posixpath
 import re
 import resource
+import shlex
 import shutil
+import socket
 import struct
 import sys
 import threading
 import traceback
-
+import typing
+from typing import (Any, Awaitable, Callable, Dict, List, Literal, Optional,
+                    Set, Tuple, Type)
 import uuid
 import zipfile
 
@@ -43,6 +47,7 @@ from sky import global_user_state
 from sky import models
 from sky import sky_logging
 from sky.data import storage_utils
+from sky.jobs import state as managed_job_state
 from sky.jobs import utils as managed_job_utils
 from sky.jobs.server import server as jobs_rest
 from sky.metrics import utils as metrics_utils
@@ -76,6 +81,7 @@ from sky.usage import usage_lib
 from sky.users import permission
 from sky.users import server as users_rest
 from sky.utils import admin_policy_utils
+from sky.utils import command_runner
 from sky.utils import common as common_lib
 from sky.utils import common_utils
 from sky.utils import context
@@ -83,6 +89,7 @@ from sky.utils import context_utils
 from sky.utils import controller_utils
 from sky.utils import dag_utils
 from sky.utils import env_options
+from sky.utils import interactive_utils
 from sky.utils import perf_utils
 from sky.utils import status_lib
 from sky.utils import subprocess_utils
@@ -91,6 +98,9 @@ from sky.utils.db import db_utils
 from sky.volumes.server import server as volumes_rest
 from sky.workspaces import server as workspaces_rest
 
+if typing.TYPE_CHECKING:
+    from sky import backends
+
 # pylint: disable=ungrouped-imports
 if sys.version_info >= (3, 10):
     from typing import ParamSpec
@@ -208,6 +218,10 @@ class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
     """Middleware to handle HTTP Basic Auth."""
 
     async def dispatch(self, request: fastapi.Request, call_next):
+        # If a previous middleware already authenticated the user, pass through
+        if request.state.auth_user is not None:
+            return await call_next(request)
+
         if managed_job_utils.is_consolidation_mode(
         ) and loopback.is_loopback_request(request):
             return await call_next(request)
@@ -275,6 +289,10 @@ class BearerTokenMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
        X-Skypilot-Auth-Mode header. The auth proxy should either validate the
        auth or set the header X-Skypilot-Auth-Mode: token.
        """
+        # If a previous middleware already authenticated the user, pass through
+        if request.state.auth_user is not None:
+            return await call_next(request)
+
        has_skypilot_auth_header = (
            request.headers.get('X-Skypilot-Auth-Mode') == 'token')
        auth_header = request.headers.get('authorization')
@@ -818,7 +836,8 @@ async def slurm_gpu_availability(
     )
 
 
-@app.get('/slurm_node_info')
+# Keep the GET method for backwards compatibility
+@app.api_route('/slurm_node_info', methods=['GET', 'POST'])
 async def slurm_node_info(
         request: fastapi.Request,
         slurm_node_info_body: payloads.SlurmNodeInfoRequestBody) -> None:
@@ -1503,6 +1522,21 @@ async def cost_report(request: fastapi.Request,
     )
 
 
+@app.post('/cluster_events')
+async def cluster_events(
+        request: fastapi.Request,
+        cluster_events_body: payloads.ClusterEventsBody) -> None:
+    """Gets events for a cluster."""
+    await executor.schedule_request_async(
+        request_id=request.state.request_id,
+        request_name=request_names.RequestName.CLUSTER_EVENTS,
+        request_body=cluster_events_body,
+        func=core.get_cluster_events,
+        schedule_type=requests_lib.ScheduleType.SHORT,
+        request_cluster_name=cluster_events_body.cluster_name or '',
+    )
+
+
 @app.get('/storage/ls')
 async def storage_ls(request: fastapi.Request) -> None:
     """Gets the storages."""
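Note: the new /cluster_events endpoint follows the server's standard pattern of scheduling the work asynchronously rather than answering inline. A minimal sketch of calling it directly, assuming the default local API server port (46580) and that a JSON body with `cluster_name` satisfies payloads.ClusterEventsBody (the handler only reads `cluster_events_body.cluster_name` here; other fields and the auth setup of a real deployment may differ):

    import requests

    # The response acknowledges scheduling; the result is fetched later via
    # the request-polling API, as with other scheduled endpoints.
    resp = requests.post('http://127.0.0.1:46580/cluster_events',
                         json={'cluster_name': 'my-cluster'})
    resp.raise_for_status()
    print(resp.status_code)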
@@ -1805,10 +1839,17 @@ async def api_status(
 @app.get('/api/plugins', response_class=fastapi_responses.ORJSONResponse)
 async def list_plugins() -> Dict[str, List[Dict[str, Any]]]:
     """Return metadata about loaded backend plugins."""
-
-
-
-
+    plugin_infos = []
+    for plugin_info in plugins.get_plugins():
+        info = {
+            'js_extension_path': plugin_info.js_extension_path,
+        }
+        for attr in ('name', 'version', 'commit'):
+            value = getattr(plugin_info, attr, None)
+            if value is not None:
+                info[attr] = value
+        plugin_infos.append(info)
+    return {'plugins': plugin_infos}
 
 
 @app.get(
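Note: from the loop above, the /api/plugins response now has a concrete shape. A sketch of the payload as a Python literal (values are illustrative; `name`, `version`, and `commit` appear only when the plugin defines them):

    {
        'plugins': [{
            'js_extension_path': '/plugins/example/extension.js',  # always set
            'name': 'example-plugin',  # optional
            'version': '0.1.0',        # optional
            'commit': 'abc1234',       # optional
        }]
    }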
@@ -1882,12 +1923,149 @@ async def health(request: fastapi.Request) -> responses.APIHealthResponse:
     )
 
 
-class KubernetesSSHMessageType(IntEnum):
+class SSHMessageType(IntEnum):
     REGULAR_DATA = 0
     PINGPONG = 1
     LATENCY_MEASUREMENT = 2
 
 
+async def _get_cluster_and_validate(
+    cluster_name: str,
+    cloud_type: Type[clouds.Cloud],
+) -> 'backends.CloudVmRayResourceHandle':
+    """Fetch cluster status and validate it's UP and correct cloud type."""
+    # Run core.status in another thread to avoid blocking the event loop.
+    # TODO(aylei): core.status() will be called with server user, which has
+    # permission to all workspaces, this will break workspace isolation.
+    # It is ok for now, as users with limited access will not get the ssh config
+    # for the clusters in non-accessible workspaces.
+    with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
+        cluster_records = await context_utils.to_thread_with_executor(
+            thread_pool_executor, core.status, cluster_name, all_users=True)
+    cluster_record = cluster_records[0]
+    if cluster_record['status'] != status_lib.ClusterStatus.UP:
+        raise fastapi.HTTPException(
+            status_code=400, detail=f'Cluster {cluster_name} is not running')
+
+    handle: Optional['backends.CloudVmRayResourceHandle'] = cluster_record[
+        'handle']
+    assert handle is not None, 'Cluster handle is None'
+    if not isinstance(handle.launched_resources.cloud, cloud_type):
+        raise fastapi.HTTPException(
+            status_code=400,
+            detail=f'Cluster {cluster_name} is not a {str(cloud_type())} '
+            'cluster. Use ssh to connect to the cluster instead.')
+
+    return handle
+
+
+async def _run_websocket_proxy(
+    websocket: fastapi.WebSocket,
+    read_from_backend: Callable[[], Awaitable[bytes]],
+    write_to_backend: Callable[[bytes], Awaitable[None]],
+    close_backend: Callable[[], Awaitable[None]],
+    timestamps_supported: bool,
+) -> bool:
+    """Run bidirectional WebSocket-to-backend proxy.
+
+    Args:
+        websocket: FastAPI WebSocket connection
+        read_from_backend: Async callable to read bytes from backend
+        write_to_backend: Async callable to write bytes to backend
+        close_backend: Async callable to close backend connection
+        timestamps_supported: Whether to use message type framing
+
+    Returns:
+        True if SSH failed, False otherwise
+    """
+    ssh_failed = False
+    websocket_closed = False
+
+    async def websocket_to_backend():
+        try:
+            async for message in websocket.iter_bytes():
+                if timestamps_supported:
+                    type_size = struct.calcsize('!B')
+                    message_type = struct.unpack('!B', message[:type_size])[0]
+                    if message_type == SSHMessageType.REGULAR_DATA:
+                        # Regular data - strip type byte and forward to backend
+                        message = message[type_size:]
+                    elif message_type == SSHMessageType.PINGPONG:
+                        # PING message - respond with PONG
+                        ping_id_size = struct.calcsize('!I')
+                        if len(message) != type_size + ping_id_size:
+                            raise ValueError(
+                                f'Invalid PING message length: {len(message)}')
+                        # Return the same PING message for latency measurement
+                        await websocket.send_bytes(message)
+                        continue
+                    elif message_type == SSHMessageType.LATENCY_MEASUREMENT:
+                        # Latency measurement from client
+                        latency_size = struct.calcsize('!Q')
+                        if len(message) != type_size + latency_size:
+                            raise ValueError('Invalid latency measurement '
+                                             f'message length: {len(message)}')
+                        avg_latency_ms = struct.unpack(
+                            '!Q',
+                            message[type_size:type_size + latency_size])[0]
+                        latency_seconds = avg_latency_ms / 1000
+                        metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(  # pylint: disable=line-too-long
+                            pid=os.getpid()).observe(latency_seconds)
+                        continue
+                    else:
+                        raise ValueError(
+                            f'Unknown message type: {message_type}')
+
+                try:
+                    await write_to_backend(message)
+                except Exception as e:  # pylint: disable=broad-except
+                    # Typically we will not reach here, if the conn to backend
+                    # is disconnected, backend_to_websocket will exit first.
+                    # But just in case.
+                    logger.error(f'Failed to write to backend through '
+                                 f'connection: {e}')
+                    nonlocal ssh_failed
+                    ssh_failed = True
+                    break
+        except fastapi.WebSocketDisconnect:
+            pass
+        nonlocal websocket_closed
+        websocket_closed = True
+        await close_backend()
+
+    async def backend_to_websocket():
+        try:
+            while True:
+                data = await read_from_backend()
+                if not data:
+                    if not websocket_closed:
+                        logger.warning(
+                            'SSH connection to backend is disconnected '
+                            'before websocket connection is closed')
+                        nonlocal ssh_failed
+                        ssh_failed = True
+                    break
+                if timestamps_supported:
+                    # Prepend message type byte (0 = regular data)
+                    message_type_bytes = struct.pack(
+                        '!B', SSHMessageType.REGULAR_DATA.value)
+                    data = message_type_bytes + data
+                await websocket.send_bytes(data)
+        except Exception:  # pylint: disable=broad-except
+            pass
+        try:
+            await websocket.close()
+        except Exception:  # pylint: disable=broad-except
+            # The websocket might have been closed by the client
+            pass
+
+    await asyncio.gather(websocket_to_backend(),
+                         backend_to_websocket(),
+                         return_exceptions=True)
+
+    return ssh_failed
+
+
 @app.websocket('/kubernetes-pod-ssh-proxy')
 async def kubernetes_pod_ssh_proxy(
     websocket: fastapi.WebSocket,
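Note: _run_websocket_proxy generalizes the message framing that was previously hard-coded in the Kubernetes proxy. A minimal client-side sketch of that framing, derived only from the struct formats visible above ('!B' type byte, '!I' ping id, '!Q' latency in ms); the helper names are hypothetical:

    import struct

    # Values mirror SSHMessageType above.
    REGULAR_DATA, PINGPONG, LATENCY_MEASUREMENT = 0, 1, 2

    def frame_data(payload: bytes) -> bytes:
        # Type byte 0 followed by raw SSH bytes; the server strips the byte.
        return struct.pack('!B', REGULAR_DATA) + payload

    def frame_ping(ping_id: int) -> bytes:
        # Type byte 1 plus a 4-byte unsigned id; the server echoes the whole
        # message back so the client can measure round-trip latency.
        return struct.pack('!BI', PINGPONG, ping_id)

    def frame_latency_report(avg_latency_ms: int) -> bytes:
        # Type byte 2 plus the averaged latency (ms) as an 8-byte unsigned
        # integer; the server records it in a histogram and sends no reply.
        return struct.pack('!BQ', LATENCY_MEASUREMENT, avg_latency_ms)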
@@ -1901,22 +2079,7 @@ async def kubernetes_pod_ssh_proxy(
     logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
 client_version = {client_version}')
 
-
-    with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
-        cluster_records = await context_utils.to_thread_with_executor(
-            thread_pool_executor, core.status, cluster_name, all_users=True)
-    cluster_record = cluster_records[0]
-    if cluster_record['status'] != status_lib.ClusterStatus.UP:
-        raise fastapi.HTTPException(
-            status_code=400, detail=f'Cluster {cluster_name} is not running')
-
-    handle = cluster_record['handle']
-    assert handle is not None, 'Cluster handle is None'
-    if not isinstance(handle.launched_resources.cloud, clouds.Kubernetes):
-        raise fastapi.HTTPException(
-            status_code=400,
-            detail=f'Cluster {cluster_name} is not a Kubernetes cluster'
-            'Use ssh to connect to the cluster instead.')
+    handle = await _get_cluster_and_validate(cluster_name, clouds.Kubernetes)
 
     kubectl_cmd = handle.get_command_runners()[0].port_forward_command(
         port_forward=[(None, 22)])
@@ -1946,96 +2109,25 @@ async def kubernetes_pod_ssh_proxy(
     conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
         pid=os.getpid())
     ssh_failed = False
-    websocket_closed = False
     try:
         conn_gauge.inc()
         # Connect to the local port
         reader, writer = await asyncio.open_connection('127.0.0.1', local_port)
 
-        async def websocket_to_ssh():
-            try:
-                async for message in websocket.iter_bytes():
-                    if timestamps_supported:
-                        type_size = struct.calcsize('!B')
-                        message_type = struct.unpack('!B',
-                                                     message[:type_size])[0]
-                        if (message_type ==
-                                KubernetesSSHMessageType.REGULAR_DATA):
-                            # Regular data - strip type byte and forward to SSH
-                            message = message[type_size:]
-                        elif message_type == KubernetesSSHMessageType.PINGPONG:
-                            # PING message - respond with PONG (type 1)
-                            ping_id_size = struct.calcsize('!I')
-                            if len(message) != type_size + ping_id_size:
-                                raise ValueError('Invalid PING message '
-                                                 f'length: {len(message)}')
-                            # Return the same PING message, so that the client
-                            # can measure the latency.
-                            await websocket.send_bytes(message)
-                            continue
-                        elif (message_type ==
-                              KubernetesSSHMessageType.LATENCY_MEASUREMENT):
-                            # Latency measurement from client
-                            latency_size = struct.calcsize('!Q')
-                            if len(message) != type_size + latency_size:
-                                raise ValueError(
-                                    'Invalid latency measurement '
-                                    f'message length: {len(message)}')
-                            avg_latency_ms = struct.unpack(
-                                '!Q',
-                                message[type_size:type_size + latency_size])[0]
-                            latency_seconds = avg_latency_ms / 1000
-                            metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(pid=os.getpid()).observe(latency_seconds)  # pylint: disable=line-too-long
-                            continue
-                        else:
-                            # Unknown message type.
-                            raise ValueError(
-                                f'Unknown message type: {message_type}')
-                    writer.write(message)
-                    try:
-                        await writer.drain()
-                    except Exception as e:  # pylint: disable=broad-except
-                        # Typically we will not reach here, if the ssh to pod
-                        # is disconnected, ssh_to_websocket will exit first.
-                        # But just in case.
-                        logger.error('Failed to write to pod through '
-                                     f'port-forward connection: {e}')
-                        nonlocal ssh_failed
-                        ssh_failed = True
-                        break
-            except fastapi.WebSocketDisconnect:
-                pass
-            nonlocal websocket_closed
-            websocket_closed = True
-            writer.close()
+        async def write_and_drain(data: bytes) -> None:
+            writer.write(data)
+            await writer.drain()
 
-        async def ssh_to_websocket():
-            try:
-                while True:
-                    data = await reader.read(1024)
-                    if not data:
-                        if not websocket_closed:
-                            logger.warning('SSH connection to pod is '
-                                           'disconnected before websocket '
-                                           'connection is closed')
-                            nonlocal ssh_failed
-                            ssh_failed = True
-                        break
-                    if timestamps_supported:
-                        # Prepend message type byte (0 = regular data)
-                        message_type_bytes = struct.pack(
-                            '!B', KubernetesSSHMessageType.REGULAR_DATA.value)
-                        data = message_type_bytes + data
-                    await websocket.send_bytes(data)
-            except Exception:  # pylint: disable=broad-except
-                pass
-            try:
-                await websocket.close()
-            except Exception:  # pylint: disable=broad-except
-                # The websocket might has been closed by the client.
-                pass
+        async def close_writer() -> None:
+            writer.close()
 
-        await asyncio.gather(websocket_to_ssh(), ssh_to_websocket())
+        ssh_failed = await _run_websocket_proxy(
+            websocket,
+            read_from_backend=lambda: reader.read(1024),
+            write_to_backend=write_and_drain,
+            close_backend=close_writer,
+            timestamps_supported=timestamps_supported,
+        )
     finally:
         conn_gauge.dec()
         reason = ''
@@ -2049,7 +2141,7 @@ async def kubernetes_pod_ssh_proxy(
                          f'output: {str(stdout)}')
             reason = 'KubectlPortForwardExit'
             metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
-                pid=os.getpid(), reason=
+                pid=os.getpid(), reason=reason).inc()
         else:
             if ssh_failed:
                 reason = 'SSHToPodDisconnected'
@@ -2059,6 +2151,235 @@ async def kubernetes_pod_ssh_proxy(
                 pid=os.getpid(), reason=reason).inc()
 
 
+@app.websocket('/slurm-job-ssh-proxy')
+async def slurm_job_ssh_proxy(websocket: fastapi.WebSocket,
+                              cluster_name: str,
+                              worker: int = 0,
+                              client_version: Optional[int] = None) -> None:
+    """Proxies SSH to the Slurm job via sshd inside srun."""
+    await websocket.accept()
+    logger.info(f'WebSocket connection accepted for cluster: '
+                f'{cluster_name}')
+
+    timestamps_supported = client_version is not None and client_version > 21
+    logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
+client_version = {client_version}')
+
+    handle = await _get_cluster_and_validate(cluster_name, clouds.Slurm)
+
+    assert handle.cached_cluster_info is not None, 'Cached cluster info is None'
+    provider_config = handle.cached_cluster_info.provider_config
+    assert provider_config is not None, 'Provider config is None'
+    login_node_ssh_config = provider_config['ssh']
+    login_node_host = login_node_ssh_config['hostname']
+    login_node_port = int(login_node_ssh_config['port'])
+    login_node_user = login_node_ssh_config['user']
+    login_node_key = login_node_ssh_config['private_key']
+    login_node_proxy_command = login_node_ssh_config.get('proxycommand', None)
+    login_node_proxy_jump = login_node_ssh_config.get('proxyjump', None)
+
+    login_node_runner = command_runner.SSHCommandRunner(
+        (login_node_host, login_node_port),
+        login_node_user,
+        login_node_key,
+        ssh_proxy_command=login_node_proxy_command,
+        ssh_proxy_jump=login_node_proxy_jump,
+    )
+
+    ssh_cmd = login_node_runner.ssh_base_command(
+        ssh_mode=command_runner.SshMode.NON_INTERACTIVE,
+        port_forward=None,
+        connect_timeout=None)
+
+    # There can only be one InstanceInfo per instance_id.
+    head_instance = handle.cached_cluster_info.get_head_instance()
+    assert head_instance is not None, 'Head instance is None'
+    job_id = head_instance.tags['job_id']
+
+    # Instances are ordered: head first, then workers
+    instances = handle.cached_cluster_info.instances
+    node_hostnames = [inst[0].tags['node'] for inst in instances.values()]
+    if worker >= len(node_hostnames):
+        raise fastapi.HTTPException(
+            status_code=400,
+            detail=f'Worker index {worker} out of range. '
+            f'Cluster has {len(node_hostnames)} nodes.')
+    target_node = node_hostnames[worker]
+
+    # Run sshd inside the Slurm job "container" via srun, such that it inherits
+    # the resource constraints of the Slurm job.
+    ssh_cmd += [
+        shlex.quote(
+            slurm_utils.srun_sshd_command(job_id, target_node, login_node_user))
+    ]
+
+    proc = await asyncio.create_subprocess_shell(
+        ' '.join(ssh_cmd),
+        stdin=asyncio.subprocess.PIPE,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.PIPE,  # Capture stderr separately for logging
+    )
+    assert proc.stdin is not None
+    assert proc.stdout is not None
+    assert proc.stderr is not None
+
+    stdin = proc.stdin
+    stdout = proc.stdout
+    stderr = proc.stderr
+
+    async def log_stderr():
+        while True:
+            line = await stderr.readline()
+            if not line:
+                break
+            logger.debug(f'srun stderr: {line.decode().rstrip()}')
+
+    stderr_task = None
+    if env_options.Options.SHOW_DEBUG_INFO.get():
+        stderr_task = asyncio.create_task(log_stderr())
+    conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
+        pid=os.getpid())
+    ssh_failed = False
+    try:
+        conn_gauge.inc()
+
+        async def write_and_drain(data: bytes) -> None:
+            stdin.write(data)
+            await stdin.drain()
+
+        async def close_stdin() -> None:
+            stdin.close()
+
+        ssh_failed = await _run_websocket_proxy(
+            websocket,
+            read_from_backend=lambda: stdout.read(4096),
+            write_to_backend=write_and_drain,
+            close_backend=close_stdin,
+            timestamps_supported=timestamps_supported,
+        )
+
+    finally:
+        conn_gauge.dec()
+        reason = ''
+        try:
+            logger.info('Terminating srun process')
+            proc.terminate()
+        except ProcessLookupError:
+            stdout_data = await stdout.read()
+            logger.error('srun process was terminated before the '
+                         'ssh websocket connection was closed. Remaining '
+                         f'output: {str(stdout_data)}')
+            reason = 'SrunProcessExit'
+            metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
+                pid=os.getpid(), reason=reason).inc()
+        else:
+            if ssh_failed:
+                reason = 'SSHToSlurmJobDisconnected'
+            else:
+                reason = 'ClientClosed'
+
+            metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
+                pid=os.getpid(), reason=reason).inc()
+
+        # Cancel the stderr logging task if it's still running
+        if stderr_task is not None and not stderr_task.done():
+            stderr_task.cancel()
+            try:
+                await stderr_task
+            except asyncio.CancelledError:
+                pass
+
+
+@app.websocket('/ssh-interactive-auth')
+async def ssh_interactive_auth(websocket: fastapi.WebSocket,
+                               session_id: str) -> None:
+    """Proxies PTY for SSH interactive authentication via websocket.
+
+    This endpoint receives a PTY file descriptor from a worker process
+    and bridges it bidirectionally with a websocket connection, allowing
+    the client to handle interactive SSH authentication (e.g., 2FA).
+
+    Detects auth completion by monitoring terminal echo state and data flow.
+    """
+    await websocket.accept()
+    logger.info(f'WebSocket connection accepted for SSH auth session: '
+                f'{session_id}')
+
+    loop = asyncio.get_running_loop()
+
+    # Connect to worker process to receive PTY file descriptor
+    fd_socket_path = interactive_utils.get_pty_socket_path(session_id)
+    fd_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+    master_fd = -1
+    try:
+        # Connect to worker's FD-passing socket
+        await loop.sock_connect(fd_sock, fd_socket_path)
+        master_fd = await loop.run_in_executor(None, interactive_utils.recv_fd,
+                                               fd_sock)
+        logger.debug(f'Received PTY master fd {master_fd} for session '
+                     f'{session_id}')
+
+        # Bridge PTY ↔ websocket bidirectionally
+        async def websocket_to_pty():
+            """Forward websocket messages to PTY."""
+            try:
+                async for message in websocket.iter_bytes():
+                    await loop.run_in_executor(None, os.write, master_fd,
+                                               message)
+            except fastapi.WebSocketDisconnect:
+                logger.debug(f'WebSocket disconnected for session {session_id}')
+            except asyncio.CancelledError:
+                pass
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error(f'Error in websocket_to_pty: {e}')
+
+        async def pty_to_websocket():
+            """Forward PTY output to websocket and detect auth completion.
+
+            Detects auth completion by monitoring terminal echo state.
+            Echo is disabled during password prompts and enabled after
+            successful authentication. Auth is considered complete when
+            echo has been enabled for a sustained period (1s).
+            """
+            try:
+                while True:
+                    try:
+                        data = await loop.run_in_executor(
+                            None, os.read, master_fd, 4096)
+                    except OSError as e:
+                        logger.error(f'PTY read error (likely closed): {e}')
+                        break
+
+                    if not data:
+                        break
+
+                    await websocket.send_bytes(data)
+            except asyncio.CancelledError:
+                pass
+            except Exception as e:  # pylint: disable=broad-except
+                logger.error(f'Error in pty_to_websocket: {e}')
+            finally:
+                try:
+                    await websocket.close()
+                except Exception:  # pylint: disable=broad-except
+                    pass
+
+        await asyncio.gather(websocket_to_pty(), pty_to_websocket())
+
+    except Exception as e:  # pylint: disable=broad-except
+        logger.error(f'Error in SSH interactive auth websocket: {e}')
+        raise
+    finally:
+        # Clean up
+        if master_fd >= 0:
+            try:
+                os.close(master_fd)
+            except OSError:
+                pass
+        fd_sock.close()
+        logger.debug(f'SSH interactive auth session {session_id} completed')
+
+
 @app.get('/all_contexts')
 async def all_contexts(request: fastapi.Request) -> None:
     """Gets all Kubernetes and SSH node pool contexts."""
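Note: /ssh-interactive-auth relies on interactive_utils.recv_fd to obtain the PTY master descriptor over a Unix socket. A minimal sketch of such an FD receive using SCM_RIGHTS ancillary data; the actual implementation in sky/utils/interactive_utils.py may differ:

    import array
    import socket

    def recv_fd_sketch(sock: socket.socket) -> int:
        """Receive a single file descriptor passed over a Unix socket."""
        fds = array.array('i')
        # One byte of regular data plus ancillary space for one int-sized fd.
        _, ancdata, _, _ = sock.recvmsg(1, socket.CMSG_LEN(fds.itemsize))
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if (cmsg_level == socket.SOL_SOCKET and
                    cmsg_type == socket.SCM_RIGHTS):
                fds.frombytes(cmsg_data[:fds.itemsize])
        if not fds:
            raise RuntimeError('No file descriptor received')
        return fds[0]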
@@ -2229,6 +2550,9 @@ if __name__ == '__main__':
     # Restore the server user hash
     logger.info('Initializing server user hash')
     _init_or_restore_server_user_hash()
+    logger.info('Initializing permission service')
+    permission.permission_service.initialize()
+    logger.info('Permission service initialized')
 
     max_db_connections = global_user_state.get_max_db_connections()
     logger.info(f'Max db connections: {max_db_connections}')
@@ -2265,6 +2589,9 @@ if __name__ == '__main__':
         global_tasks.append(
             background.create_task(
                 global_user_state.cluster_event_retention_daemon()))
+        global_tasks.append(
+            background.create_task(
+                managed_job_state.job_event_retention_daemon()))
         threading.Thread(target=background.run_forever, daemon=True).start()
 
         queue_server, workers = executor.start(config)
sky/server/uvicorn.py
CHANGED
@@ -20,6 +20,7 @@ from uvicorn.supervisors import multiprocess
 from sky import sky_logging
 from sky.server import daemons
 from sky.server import metrics as metrics_lib
+from sky.server import plugins
 from sky.server import state
 from sky.server.requests import requests as requests_lib
 from sky.skylet import constants
@@ -237,6 +238,10 @@ def run(config: uvicorn.Config, max_db_connections: Optional[int] = None):
     server = Server(config=config, max_db_connections=max_db_connections)
     try:
         if config.workers is not None and config.workers > 1:
+            # When workers > 1, uvicorn does not run server app in the main
+            # process. In this case, plugins are not loaded at this point, so
+            # load plugins here without uvicorn app.
+            plugins.load_plugins(plugins.ExtensionContext())
             sock = config.bind_socket()
             SlowStartMultiprocess(config, target=server.run,
                                   sockets=[sock]).run()