skypilot-nightly 1.0.0.dev20251203__py3-none-any.whl → 1.0.0.dev20260112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +6 -2
- sky/adaptors/aws.py +1 -61
- sky/adaptors/slurm.py +565 -0
- sky/backends/backend_utils.py +95 -12
- sky/backends/cloud_vm_ray_backend.py +224 -65
- sky/backends/task_codegen.py +380 -4
- sky/catalog/__init__.py +0 -3
- sky/catalog/data_fetchers/fetch_gcp.py +9 -1
- sky/catalog/data_fetchers/fetch_nebius.py +1 -1
- sky/catalog/data_fetchers/fetch_vast.py +4 -2
- sky/catalog/kubernetes_catalog.py +12 -4
- sky/catalog/seeweb_catalog.py +30 -15
- sky/catalog/shadeform_catalog.py +5 -2
- sky/catalog/slurm_catalog.py +236 -0
- sky/catalog/vast_catalog.py +30 -6
- sky/check.py +25 -11
- sky/client/cli/command.py +391 -32
- sky/client/interactive_utils.py +190 -0
- sky/client/sdk.py +64 -2
- sky/client/sdk_async.py +9 -0
- sky/clouds/__init__.py +2 -0
- sky/clouds/aws.py +60 -2
- sky/clouds/azure.py +2 -0
- sky/clouds/cloud.py +7 -0
- sky/clouds/kubernetes.py +2 -0
- sky/clouds/runpod.py +38 -7
- sky/clouds/slurm.py +610 -0
- sky/clouds/ssh.py +3 -2
- sky/clouds/vast.py +39 -16
- sky/core.py +197 -37
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/3nu-b8raeKRNABZ2d4GAG/_buildManifest.js +1 -0
- sky/dashboard/out/_next/static/chunks/1871-0565f8975a7dcd10.js +6 -0
- sky/dashboard/out/_next/static/chunks/2109-55a1546d793574a7.js +11 -0
- sky/dashboard/out/_next/static/chunks/2521-099b07cd9e4745bf.js +26 -0
- sky/dashboard/out/_next/static/chunks/2755.a636e04a928a700e.js +31 -0
- sky/dashboard/out/_next/static/chunks/3495.05eab4862217c1a5.js +6 -0
- sky/dashboard/out/_next/static/chunks/3785.cfc5dcc9434fd98c.js +1 -0
- sky/dashboard/out/_next/static/chunks/3850-fd5696f3bbbaddae.js +1 -0
- sky/dashboard/out/_next/static/chunks/3981.645d01bf9c8cad0c.js +21 -0
- sky/dashboard/out/_next/static/chunks/4083-0115d67c1fb57d6c.js +21 -0
- sky/dashboard/out/_next/static/chunks/{8640.5b9475a2d18c5416.js → 429.a58e9ba9742309ed.js} +2 -2
- sky/dashboard/out/_next/static/chunks/4555.8e221537181b5dc1.js +6 -0
- sky/dashboard/out/_next/static/chunks/4725.937865b81fdaaebb.js +6 -0
- sky/dashboard/out/_next/static/chunks/6082-edabd8f6092300ce.js +25 -0
- sky/dashboard/out/_next/static/chunks/6989-49cb7dca83a7a62d.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-630bd2a2257275f8.js +1 -0
- sky/dashboard/out/_next/static/chunks/7248-a99800d4db8edabd.js +1 -0
- sky/dashboard/out/_next/static/chunks/754-cfc5d4ad1b843d29.js +18 -0
- sky/dashboard/out/_next/static/chunks/8050-dd8aa107b17dce00.js +16 -0
- sky/dashboard/out/_next/static/chunks/8056-d4ae1e0cb81e7368.js +1 -0
- sky/dashboard/out/_next/static/chunks/8555.011023e296c127b3.js +6 -0
- sky/dashboard/out/_next/static/chunks/8821-93c25df904a8362b.js +1 -0
- sky/dashboard/out/_next/static/chunks/8969-0662594b69432ade.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.f15c91c97d124a5f.js +6 -0
- sky/dashboard/out/_next/static/chunks/9353-7ad6bd01858556f1.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/_app-5a86569acad99764.js +34 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8297476714acb4ac.js +6 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-337c3ba1085f1210.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/{clusters-ee39056f9851a3ff.js → clusters-57632ff3684a8b5c.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{config-dfb9bf07b13045f4.js → config-718cdc365de82689.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-5fd3a453c079c2ea.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra-9f85c02c9c6cae9e.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-90f16972cbecf354.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-2dd42fc37aad427a.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs-ed806aeace26b972.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/plugins/[...slug]-449a9f5a3bb20fb3.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/users-bec34706b36f3524.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/{volumes-b84b948ff357c43e.js → volumes-a83ba9b38dff7ea9.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-84a40f8c7c627fe4.js → [name]-c781e9c3e52ef9fc.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-91e0942f47310aae.js +1 -0
- sky/dashboard/out/_next/static/chunks/webpack-cfe59cf684ee13b9.js +1 -0
- sky/dashboard/out/_next/static/css/b0dbca28f027cc19.css +3 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/plugins/[...slug].html +1 -0
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/volumes.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/data/data_utils.py +26 -12
- sky/data/mounting_utils.py +44 -5
- sky/global_user_state.py +111 -19
- sky/jobs/client/sdk.py +8 -3
- sky/jobs/controller.py +191 -31
- sky/jobs/recovery_strategy.py +109 -11
- sky/jobs/server/core.py +81 -4
- sky/jobs/server/server.py +14 -0
- sky/jobs/state.py +417 -19
- sky/jobs/utils.py +73 -80
- sky/models.py +11 -0
- sky/optimizer.py +8 -6
- sky/provision/__init__.py +12 -9
- sky/provision/common.py +20 -0
- sky/provision/docker_utils.py +15 -2
- sky/provision/kubernetes/utils.py +163 -20
- sky/provision/kubernetes/volume.py +52 -17
- sky/provision/provisioner.py +17 -7
- sky/provision/runpod/instance.py +3 -1
- sky/provision/runpod/utils.py +13 -1
- sky/provision/runpod/volume.py +25 -9
- sky/provision/slurm/__init__.py +12 -0
- sky/provision/slurm/config.py +13 -0
- sky/provision/slurm/instance.py +618 -0
- sky/provision/slurm/utils.py +689 -0
- sky/provision/vast/instance.py +4 -1
- sky/provision/vast/utils.py +11 -6
- sky/resources.py +135 -13
- sky/schemas/api/responses.py +4 -0
- sky/schemas/db/global_user_state/010_save_ssh_key.py +1 -1
- sky/schemas/db/spot_jobs/008_add_full_resources.py +34 -0
- sky/schemas/db/spot_jobs/009_job_events.py +32 -0
- sky/schemas/db/spot_jobs/010_job_events_timestamp_with_timezone.py +43 -0
- sky/schemas/db/spot_jobs/011_add_links.py +34 -0
- sky/schemas/generated/jobsv1_pb2.py +9 -5
- sky/schemas/generated/jobsv1_pb2.pyi +12 -0
- sky/schemas/generated/jobsv1_pb2_grpc.py +44 -0
- sky/schemas/generated/managed_jobsv1_pb2.py +32 -28
- sky/schemas/generated/managed_jobsv1_pb2.pyi +11 -2
- sky/serve/serve_utils.py +232 -40
- sky/serve/server/impl.py +1 -1
- sky/server/common.py +17 -0
- sky/server/constants.py +1 -1
- sky/server/metrics.py +6 -3
- sky/server/plugins.py +238 -0
- sky/server/requests/executor.py +5 -2
- sky/server/requests/payloads.py +30 -1
- sky/server/requests/request_names.py +4 -0
- sky/server/requests/requests.py +33 -11
- sky/server/requests/serializers/encoders.py +22 -0
- sky/server/requests/serializers/return_value_serializers.py +70 -0
- sky/server/server.py +506 -109
- sky/server/server_utils.py +30 -0
- sky/server/uvicorn.py +5 -0
- sky/setup_files/MANIFEST.in +1 -0
- sky/setup_files/dependencies.py +22 -9
- sky/sky_logging.py +2 -1
- sky/skylet/attempt_skylet.py +13 -3
- sky/skylet/constants.py +55 -13
- sky/skylet/events.py +10 -4
- sky/skylet/executor/__init__.py +1 -0
- sky/skylet/executor/slurm.py +187 -0
- sky/skylet/job_lib.py +91 -5
- sky/skylet/log_lib.py +22 -6
- sky/skylet/log_lib.pyi +8 -6
- sky/skylet/services.py +18 -3
- sky/skylet/skylet.py +5 -1
- sky/skylet/subprocess_daemon.py +2 -1
- sky/ssh_node_pools/constants.py +12 -0
- sky/ssh_node_pools/core.py +40 -3
- sky/ssh_node_pools/deploy/__init__.py +4 -0
- sky/{utils/kubernetes/deploy_ssh_node_pools.py → ssh_node_pools/deploy/deploy.py} +279 -504
- sky/ssh_node_pools/deploy/tunnel/ssh-tunnel.sh +379 -0
- sky/ssh_node_pools/deploy/tunnel_utils.py +199 -0
- sky/ssh_node_pools/deploy/utils.py +173 -0
- sky/ssh_node_pools/server.py +11 -13
- sky/{utils/kubernetes/ssh_utils.py → ssh_node_pools/utils.py} +9 -6
- sky/templates/kubernetes-ray.yml.j2 +12 -6
- sky/templates/slurm-ray.yml.j2 +115 -0
- sky/templates/vast-ray.yml.j2 +1 -0
- sky/templates/websocket_proxy.py +18 -41
- sky/users/model.conf +1 -1
- sky/users/permission.py +85 -52
- sky/users/rbac.py +31 -3
- sky/utils/annotations.py +108 -8
- sky/utils/auth_utils.py +42 -0
- sky/utils/cli_utils/status_utils.py +19 -5
- sky/utils/cluster_utils.py +10 -3
- sky/utils/command_runner.py +389 -35
- sky/utils/command_runner.pyi +43 -4
- sky/utils/common_utils.py +47 -31
- sky/utils/context.py +32 -0
- sky/utils/db/db_utils.py +36 -6
- sky/utils/db/migration_utils.py +41 -21
- sky/utils/infra_utils.py +5 -1
- sky/utils/instance_links.py +139 -0
- sky/utils/interactive_utils.py +49 -0
- sky/utils/kubernetes/generate_kubeconfig.sh +42 -33
- sky/utils/kubernetes/kubernetes_deploy_utils.py +2 -94
- sky/utils/kubernetes/rsync_helper.sh +5 -1
- sky/utils/kubernetes/ssh-tunnel.sh +7 -376
- sky/utils/plugin_extensions/__init__.py +14 -0
- sky/utils/plugin_extensions/external_failure_source.py +176 -0
- sky/utils/resources_utils.py +10 -8
- sky/utils/rich_utils.py +9 -11
- sky/utils/schemas.py +93 -19
- sky/utils/status_lib.py +7 -0
- sky/utils/subprocess_utils.py +17 -0
- sky/volumes/client/sdk.py +6 -3
- sky/volumes/server/core.py +65 -27
- sky_templates/ray/start_cluster +8 -4
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/METADATA +67 -59
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/RECORD +208 -180
- sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +0 -11
- sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +0 -6
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +0 -1
- sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +0 -1
- sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +0 -15
- sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +0 -26
- sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +0 -1
- sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +0 -1
- sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +0 -1
- sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +0 -1
- sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +0 -1
- sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +0 -15
- sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +0 -13
- sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +0 -1
- sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +0 -1
- sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +0 -1
- sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +0 -30
- sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +0 -41
- sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +0 -1
- sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +0 -1
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +0 -6
- sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +0 -1
- sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +0 -31
- sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +0 -30
- sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +0 -34
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +0 -21
- sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +0 -1
- sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +0 -1
- sky/dashboard/out/_next/static/css/0748ce22df867032.css +0 -3
- /sky/dashboard/out/_next/static/{96_E2yl3QAiIJGOYCkSpB → 3nu-b8raeKRNABZ2d4GAG}/_ssgManifest.js +0 -0
- /sky/{utils/kubernetes → ssh_node_pools/deploy/tunnel}/cleanup-tunnel.sh +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20251203.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,173 @@
+"""Utilities for SSH Node Pools Deployment"""
+import os
+import subprocess
+from typing import List, Optional
+
+import colorama
+
+from sky import sky_logging
+from sky.utils import ux_utils
+
+logger = sky_logging.init_logger(__name__)
+
+
+def check_ssh_cluster_dependencies(
+        raise_error: bool = True) -> Optional[List[str]]:
+    """Checks if the dependencies for ssh cluster are installed.
+
+    Args:
+        raise_error: set to true when the dependency needs to be present.
+          set to false for `sky check`, where reason strings are compiled
+          at the end.
+
+    Returns: the reasons list if there are missing dependencies.
+    """
+    # error message
+    jq_message = ('`jq` is required to setup ssh cluster.')
+
+    # save
+    reasons = []
+    required_binaries = []
+
+    # Ensure jq is installed
+    try:
+        subprocess.run(['jq', '--version'],
+                       stdout=subprocess.DEVNULL,
+                       stderr=subprocess.DEVNULL,
+                       check=True)
+    except (FileNotFoundError, subprocess.CalledProcessError):
+        required_binaries.append('jq')
+        reasons.append(jq_message)
+
+    if required_binaries:
+        reasons.extend([
+            'On Debian/Ubuntu, install the missing dependenc(ies) with:',
+            f'  $ sudo apt install {" ".join(required_binaries)}',
+            'On MacOS, install with: ',
+            f'  $ brew install {" ".join(required_binaries)}',
+        ])
+        if raise_error:
+            with ux_utils.print_exception_no_traceback():
+                raise RuntimeError('\n'.join(reasons))
+        return reasons
+    return None
+
+
+def run_command(cmd, shell=False, silent=False):
+    """Run a local command and return the output."""
+    process = subprocess.run(cmd,
+                             shell=shell,
+                             capture_output=True,
+                             text=True,
+                             check=False)
+    if process.returncode != 0:
+        if not silent:
+            logger.error(f'{colorama.Fore.RED}Error executing command: {cmd}\n'
+                         f'{colorama.Style.RESET_ALL}STDOUT: {process.stdout}\n'
+                         f'STDERR: {process.stderr}')
+        return None
+    return process.stdout.strip()
+
+
+def get_effective_host_ip(hostname: str) -> str:
+    """Get the effective IP for a hostname from SSH config."""
+    try:
+        result = subprocess.run(['ssh', '-G', hostname],
+                                capture_output=True,
+                                text=True,
+                                check=False)
+        if result.returncode == 0:
+            for line in result.stdout.splitlines():
+                if line.startswith('hostname '):
+                    return line.split(' ', 1)[1].strip()
+    except Exception:  # pylint: disable=broad-except
+        pass
+    return hostname  # Return the original hostname if lookup fails
+
+
+def run_remote(node,
+               cmd,
+               user='',
+               ssh_key='',
+               connect_timeout=30,
+               use_ssh_config=False,
+               print_output=False,
+               use_shell=False,
+               silent=False):
+    """Run a command on a remote machine via SSH."""
+    ssh_cmd: List[str]
+    if use_ssh_config:
+        # Use SSH config for connection parameters
+        ssh_cmd = ['ssh', node, cmd]
+    else:
+        # Use explicit parameters
+        ssh_cmd = [
+            'ssh', '-o', 'StrictHostKeyChecking=no', '-o', 'IdentitiesOnly=yes',
+            '-o', f'ConnectTimeout={connect_timeout}', '-o',
+            'ServerAliveInterval=10', '-o', 'ServerAliveCountMax=3'
+        ]
+
+        if ssh_key:
+            if not os.path.isfile(ssh_key):
+                raise ValueError(f'SSH key not found: {ssh_key}')
+            ssh_cmd.extend(['-i', ssh_key])
+
+        ssh_cmd.append(f'{user}@{node}' if user else node)
+        ssh_cmd.append(cmd)
+
+    subprocess_cmd = ' '.join(ssh_cmd) if use_shell else ssh_cmd
+    process = subprocess.run(subprocess_cmd,
+                             capture_output=True,
+                             text=True,
+                             check=False,
+                             shell=use_shell)
+    if process.returncode != 0:
+        if not silent:
+            logger.error(f'{colorama.Fore.RED}Error executing command {cmd} on '
+                         f'{node}:{colorama.Style.RESET_ALL} {process.stderr}')
+        return None
+    if print_output:
+        logger.info(process.stdout)
+    return process.stdout.strip()
+
+
+def ensure_directory_exists(path):
+    """Ensure the directory for the specified file path exists."""
+    directory = os.path.dirname(path)
+    if directory and not os.path.exists(directory):
+        os.makedirs(directory, exist_ok=True)
+
+
+def check_gpu(node, user, ssh_key, use_ssh_config=False, is_head=False):
+    """Check if a node has a GPU."""
+    cmd = ('command -v nvidia-smi &> /dev/null && '
+           'nvidia-smi --query-gpu=gpu_name --format=csv,noheader')
+    result = run_remote(node,
+                        cmd,
+                        user,
+                        ssh_key,
+                        use_ssh_config=use_ssh_config,
+                        silent=True)
+    if result is not None:
+        # Check that all GPUs have the same type.
+        # Currently, SkyPilot does not support heterogeneous GPU node
+        # (i.e. more than one GPU type on the same node).
+        gpu_names = {
+            line.strip() for line in result.splitlines() if line.strip()
+        }
+        if not gpu_names:
+            # This can happen if nvidia-smi returns only whitespace.
+            # Set result to None to ensure this function returns False.
+            result = None
+        elif len(gpu_names) > 1:
+            # Sort for a deterministic error message.
+            sorted_gpu_names = sorted(list(gpu_names))
+            raise RuntimeError(
+                f'Node {node} has more than one GPU types '
+                f'({", ".join(sorted_gpu_names)}). '
+                'SkyPilot does not support a node with multiple GPU types.')
+        else:
+            logger.info(f'{colorama.Fore.YELLOW}➜ GPU {list(gpu_names)[0]} '
+                        f'detected on {"head" if is_head else "worker"} '
+                        f'node ({node}).{colorama.Style.RESET_ALL}')
+    return result is not None
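A minimal sketch of how these new deployment helpers might be composed, assuming the module is importable as sky.ssh_node_pools.deploy.utils (per the file list above). The node alias, user, and key path are hypothetical placeholders, not values from the SkyPilot codebase:

# Hypothetical usage of the helpers above; host, user, and key path are placeholders.
import os

from sky.ssh_node_pools.deploy import utils as deploy_utils

# Fail fast if a required local binary (currently `jq`) is missing.
deploy_utils.check_ssh_cluster_dependencies(raise_error=True)

node = 'my-node'                                # hypothetical ~/.ssh/config alias
user = 'ubuntu'                                 # hypothetical SSH user
ssh_key = os.path.expanduser('~/.ssh/sky-key')  # hypothetical key path (must exist)

# Resolve the address behind the SSH config alias, falling back to the alias itself.
host = deploy_utils.get_effective_host_ip(node)

# run_remote returns stripped stdout on success and None on failure.
uptime = deploy_utils.run_remote(host, 'uptime', user=user, ssh_key=ssh_key)

# check_gpu raises if the node mixes GPU types; otherwise returns a bool.
has_gpu = deploy_utils.check_gpu(host, user, ssh_key, is_head=True)
print(uptime, has_gpu)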
sky/ssh_node_pools/server.py
CHANGED
@@ -4,12 +4,11 @@ from typing import Any, Dict, List
 
 import fastapi
 
-from sky import core as sky_core
 from sky.server.requests import executor
 from sky.server.requests import payloads
 from sky.server.requests import request_names
 from sky.server.requests import requests as requests_lib
-from sky.ssh_node_pools import core
+from sky.ssh_node_pools import core
 from sky.utils import common_utils
 
 router = fastapi.APIRouter()
@@ -19,7 +18,7 @@ router = fastapi.APIRouter()
 def get_ssh_node_pools() -> Dict[str, Any]:
     """Get all SSH Node Pool configurations."""
     try:
-        return
+        return core.get_all_pools()
     except Exception as e:
         raise fastapi.HTTPException(
             status_code=500,
@@ -31,7 +30,7 @@ def get_ssh_node_pools() -> Dict[str, Any]:
 def update_ssh_node_pools(pools_config: Dict[str, Any]) -> Dict[str, str]:
     """Update SSH Node Pool configurations."""
     try:
-
+        core.update_pools(pools_config)
         return {'status': 'success'}
     except Exception as e:
         raise fastapi.HTTPException(status_code=400,
@@ -43,7 +42,7 @@ def update_ssh_node_pools(pools_config: Dict[str, Any]) -> Dict[str, str]:
 def delete_ssh_node_pool(pool_name: str) -> Dict[str, str]:
     """Delete a SSH Node Pool configuration."""
     try:
-        if
+        if core.delete_pool(pool_name):
             return {'status': 'success'}
         else:
             raise fastapi.HTTPException(
@@ -70,8 +69,7 @@ async def upload_ssh_key(request: fastapi.Request) -> Dict[str, str]:
                 detail='Missing key_name or key_file')
 
         key_content = await key_file.read()
-        key_path =
-            key_content.decode())
+        key_path = core.upload_ssh_key(key_name, key_content.decode())
 
         return {'status': 'success', 'key_path': key_path}
     except fastapi.HTTPException:
@@ -87,7 +85,7 @@ async def upload_ssh_key(request: fastapi.Request) -> Dict[str, str]:
 def list_ssh_keys() -> List[str]:
     """List available SSH keys."""
     try:
-        return
+        return core.list_ssh_keys()
     except Exception as e:
         exception_msg = common_utils.format_exception(e)
         raise fastapi.HTTPException(
@@ -104,7 +102,7 @@ async def deploy_ssh_node_pool(request: fastapi.Request,
         request_id=request.state.request_id,
         request_name=request_names.RequestName.SSH_NODE_POOLS_UP,
         request_body=ssh_up_body,
-        func=
+        func=core.ssh_up,
         schedule_type=requests_lib.ScheduleType.LONG,
     )
 
@@ -129,7 +127,7 @@ async def deploy_ssh_node_pool_general(
         request_id=request.state.request_id,
         request_name=request_names.RequestName.SSH_NODE_POOLS_UP,
         request_body=ssh_up_body,
-        func=
+        func=core.ssh_up,
         schedule_type=requests_lib.ScheduleType.LONG,
     )
 
@@ -155,7 +153,7 @@ async def down_ssh_node_pool(request: fastapi.Request,
         request_id=request.state.request_id,
         request_name=request_names.RequestName.SSH_NODE_POOLS_DOWN,
         request_body=ssh_up_body,
-        func=
+        func=core.ssh_up,  # Reuse ssh_up function with cleanup=True
         schedule_type=requests_lib.ScheduleType.LONG,
     )
 
@@ -183,7 +181,7 @@ async def down_ssh_node_pool_general(
         request_id=request.state.request_id,
         request_name=request_names.RequestName.SSH_NODE_POOLS_DOWN,
         request_body=ssh_up_body,
-        func=
+        func=core.ssh_up,  # Reuse ssh_up function with cleanup=True
         schedule_type=requests_lib.ScheduleType.LONG,
    )
 
@@ -206,7 +204,7 @@ def get_ssh_node_pool_status(pool_name: str) -> Dict[str, str]:
     try:
         # Call ssh_status to check the context
         context_name = f'ssh-{pool_name}'
-        is_ready, reason =
+        is_ready, reason = core.ssh_status(context_name)
 
         # Strip ANSI escape codes from the reason text
         def strip_ansi_codes(text):
sky/{utils/kubernetes/ssh_utils.py → ssh_node_pools/utils.py}
CHANGED
@@ -5,13 +5,14 @@ import subprocess
 from typing import Any, Callable, Dict, List, Optional
 import uuid
 
+import colorama
 import yaml
 
+from sky import sky_logging
+from sky.ssh_node_pools import constants
 from sky.utils import ux_utils
 
-
-RED = '\033[0;31m'
-NC = '\033[0m'  # No color
+logger = sky_logging.init_logger(__name__)
 
 
 def check_host_in_ssh_config(hostname: str) -> bool:
@@ -92,7 +93,8 @@ def load_ssh_targets(file_path: str) -> Dict[str, Any]:
 def get_cluster_config(
         targets: Dict[str, Any],
         cluster_name: Optional[str] = None,
-        file_path: str = DEFAULT_SSH_NODE_POOLS_PATH
+        file_path: str = constants.DEFAULT_SSH_NODE_POOLS_PATH
+) -> Dict[str, Any]:
     """Get configuration for specific clusters or all clusters."""
     if not targets:
         with ux_utils.print_exception_no_traceback():
@@ -186,8 +188,9 @@ def prepare_hosts_info(
         else:
             # It's a dict with potential overrides
             if 'ip' not in host:
-
-
+                logger.warning(f'{colorama.Fore.RED}Warning: Host missing'
+                               f'\'ip\' field, skipping: {host}'
+                               f'{colorama.Style.RESET_ALL}')
                 continue
 
             # Check if this is an SSH config hostname
sky/templates/kubernetes-ray.yml.j2
CHANGED
@@ -523,6 +523,14 @@ available_node_types:
                     resourceFieldRef:
                       containerName: ray-node
                       resource: requests.memory
+            # Disable Ray memory monitor to prevent Ray's memory manager
+            # from interfering with kubernetes resource manager.
+            # If ray memory monitor is enabled, the ray memory monitor kills
+            # the running job is the job uses more than 95% of allocated memory,
+            # even if the job is not misbehaving or using its full allocated memory.
+            # This behavior does not give a chance for k8s scheduler to evict the pod.
+            - name: RAY_memory_monitor_refresh_ms
+              value: "0"
             {% for key, value in k8s_env_vars.items() if k8s_env_vars is not none %}
             - name: {{ key }}
               value: {{ value }}
@@ -912,19 +920,17 @@ available_node_types:
           {{ ray_installation_commands }}
 
           # set UV_SYSTEM_PYTHON to false in case the user provided docker image set it to true.
-          # unset PYTHONPATH
-          VIRTUAL_ENV=~/skypilot-runtime UV_SYSTEM_PYTHON=false
+          # unset PYTHONPATH and set CWD to $HOME to avoid user image interfering with SkyPilot runtime.
+          VIRTUAL_ENV=~/skypilot-runtime UV_SYSTEM_PYTHON=false {{sky_unset_pythonpath_and_set_cwd}} ~/.local/bin/uv pip install skypilot[kubernetes,remote]
           # Wait for `patch` package to be installed before applying ray patches
           until dpkg -l | grep -q "^ii patch "; do
             sleep 0.1
             echo "Waiting for patch package to be installed..."
           done
           # Apply Ray patches for progress bar fix
-          # set UV_SYSTEM_PYTHON to false in case the user provided docker image set it to true.
-          # unset PYTHONPATH in case the user provided docker image set it.
           # ~/.sky/python_path is seeded by conda_installation_commands
-          VIRTUAL_ENV=~/skypilot-runtime UV_SYSTEM_PYTHON=false
-
+          VIRTUAL_ENV=~/skypilot-runtime UV_SYSTEM_PYTHON=false {{sky_unset_pythonpath_and_set_cwd}} ~/.local/bin/uv pip list | grep "ray " | grep 2.9.3 2>&1 > /dev/null && {
+            {{sky_unset_pythonpath_and_set_cwd}} $(cat ~/.sky/python_path) -c "from sky.skylet.ray_patches import patch; patch()" || exit 1;
           }
           touch /tmp/ray_skypilot_installation_complete
           echo "=== Ray and skypilot installation completed ==="
sky/templates/slurm-ray.yml.j2
ADDED
@@ -0,0 +1,115 @@
+cluster_name: {{cluster_name_on_cloud}}
+
+# The maximum number of workers nodes to launch in addition to the head node.
+max_workers: {{num_nodes - 1}}
+upscaling_speed: {{num_nodes - 1}}
+idle_timeout_minutes: 60
+
+provider:
+  type: external
+  module: sky.provision.slurm
+
+  cluster: {{slurm_cluster}}
+  partition: {{slurm_partition}}
+
+  ssh:
+    hostname: {{ssh_hostname}}
+    port: {{ssh_port}}
+    user: {{ssh_user}}
+    private_key: {{slurm_private_key}}
+    {% if slurm_proxy_command is not none %}
+    proxycommand: {{slurm_proxy_command | tojson }}
+    {% endif %}
+    {% if slurm_proxy_jump is not none %}
+    proxyjump: {{slurm_proxy_jump | tojson }}
+    {% endif %}
+
+auth:
+  ssh_user: {{ssh_user}}
+  # TODO(jwj,kevin): Modify this tmp workaround.
+  # Right now there's a chicken-and-egg problem:
+  # 1. ssh_credential_from_yaml reads from the auth.ssh_private_key: ~/.sky/clients/.../ssh/sky-key
+  # 2. This is SkyPilot's generated key, not the Slurm cluster's key
+  # 3. The internal_file_mounts stage tries to rsync using sky-key, but its public key isn't on the remote yet
+  # 4. The public key only gets added by setup_commands, which runs AFTER file_mounts
+  # ssh_private_key: {{ssh_private_key}}
+  ssh_private_key: {{slurm_private_key}}
+  ssh_proxy_command: {{slurm_proxy_command | tojson }}
+
+available_node_types:
+  ray_head_default:
+    resources: {}
+    node_config:
+      # From clouds/slurm.py::Slurm.make_deploy_resources_variables.
+      instance_type: {{instance_type}}
+      disk_size: {{disk_size}}
+      cpus: {{cpus}}
+      memory: {{memory}}
+      accelerator_type: {{accelerator_type}}
+      accelerator_count: {{accelerator_count}}
+
+      # TODO: more configs that is required by the provisioner to create new
+      # instances on the FluffyCloud:
+      # sky/provision/fluffycloud/instance.py::run_instances
+
+head_node_type: ray_head_default
+
+# Format: `REMOTE_PATH : LOCAL_PATH`
+file_mounts: {
+  "{{sky_ray_yaml_remote_path}}": "{{sky_ray_yaml_local_path}}",
+  "{{sky_remote_path}}/{{sky_wheel_hash}}": "{{sky_local_path}}",
+{%- for remote_path, local_path in credentials.items() %}
+  "{{remote_path}}": "{{local_path}}",
+{%- endfor %}
+}
+
+rsync_exclude: []
+
+initialization_commands: []
+
+# List of shell commands to run to set up nodes.
+# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH
+# connection, which is expensive. Try your best to co-locate commands into fewer
+# items!
+#
+# Increment the following for catching performance bugs easier:
+# current num items (num SSH connections): 1
+setup_commands:
+  - |
+    {%- for initial_setup_command in initial_setup_commands %}
+    {{ initial_setup_command }}
+    {%- endfor %}
+    # Generate host key for sshd -i if not exists
+    mkdir -p ~{{ssh_user}}/.ssh && chmod 700 ~{{ssh_user}}/.ssh
+    [ -f ~{{ssh_user}}/.ssh/{{slurm_sshd_host_key_filename}} ] || ssh-keygen -t ed25519 -f ~{{ssh_user}}/.ssh/{{slurm_sshd_host_key_filename}} -N "" -q
+    # Add public key to user's authorized_keys if not already present
+    grep -qF 'skypilot:ssh_public_key_content' ~{{ssh_user}}/.ssh/authorized_keys 2>/dev/null || cat >> ~{{ssh_user}}/.ssh/authorized_keys <<'SKYPILOT_SSH_KEY_EOF'
+    skypilot:ssh_public_key_content
+    SKYPILOT_SSH_KEY_EOF
+    chmod 600 ~{{ssh_user}}/.ssh/authorized_keys
+
+    mkdir -p ~{{ssh_user}}/.sky
+    cat > ~{{ssh_user}}/.sky_ssh_rc <<'SKYPILOT_SSH_RC'
+    # Added by SkyPilot: override HOME for Slurm interactive sessions
+    if [ -n "${{slurm_cluster_name_env_var}}" ]; then
+      CLUSTER_DIR=~/.sky_clusters/${{slurm_cluster_name_env_var}}
+      if [ -d "$CLUSTER_DIR" ]; then
+        cd "$CLUSTER_DIR"
+        export HOME=$(pwd)
+      fi
+    fi
+    SKYPILOT_SSH_RC
+    grep -q "source ~/.sky_ssh_rc" ~{{ssh_user}}/.bashrc 2>/dev/null || (echo "" >> ~{{ssh_user}}/.bashrc && echo "source ~/.sky_ssh_rc" >> ~{{ssh_user}}/.bashrc)
+    {{ setup_sky_dirs_commands }}
+    {{ conda_installation_commands }}
+    {{ skypilot_wheel_installation_commands }}
+    {{ copy_skypilot_templates_commands }}
+
+head_node: {}
+worker_nodes: {}
+
+# These fields are required for external cloud providers.
+head_setup_commands: []
+worker_setup_commands: []
+cluster_synced_files: []
+file_mounts_sync_continuously: False
sky/templates/vast-ray.yml.j2
CHANGED
sky/templates/websocket_proxy.py
CHANGED
@@ -9,13 +9,11 @@
 This script is useful for users who do not have local Kubernetes credentials.
 """
 import asyncio
-from http.cookiejar import MozillaCookieJar
 import os
 import struct
 import sys
 import time
 from typing import Dict, Optional
-from urllib.request import Request
 
 import requests
 import websockets
@@ -24,46 +22,19 @@ from websockets.asyncio.client import connect
 
 from sky import exceptions
 from sky.client import service_account_auth
+from sky.server import common as server_common
 from sky.server import constants
-from sky.server.server import
+from sky.server.server import SSHMessageType
 from sky.skylet import constants as skylet_constants
 
 BUFFER_SIZE = 2**16  # 64KB
 HEARTBEAT_INTERVAL_SECONDS = 10
-
-# Environment variable for a file path to the API cookie file.
-# Keep in sync with server/constants.py
-API_COOKIE_FILE_ENV_VAR = 'SKYPILOT_API_COOKIE_FILE'
-# Default file if unset.
-# Keep in sync with server/constants.py
-API_COOKIE_FILE_DEFAULT_LOCATION = '~/.sky/cookies.txt'
-
 MAX_UNANSWERED_PINGS = 100
 
 
-def _get_cookie_header(url: str) -> Dict[str, str]:
-    """Extract Cookie header value from a cookie jar for a specific URL"""
-    cookie_path = os.environ.get(API_COOKIE_FILE_ENV_VAR)
-    if cookie_path is None:
-        cookie_path = API_COOKIE_FILE_DEFAULT_LOCATION
-    cookie_path = os.path.expanduser(cookie_path)
-    if not os.path.exists(cookie_path):
-        return {}
-
-    request = Request(url)
-    cookie_jar = MozillaCookieJar(os.path.expanduser(cookie_path))
-    cookie_jar.load(ignore_discard=True, ignore_expires=True)
-    cookie_jar.add_cookie_header(request)
-    cookie_header = request.get_header('Cookie')
-    # if cookie file is empty, return empty dict
-    if cookie_header is None:
-        return {}
-    return {'Cookie': cookie_header}
-
-
 async def main(url: str, timestamps_supported: bool, login_url: str) -> None:
     headers = {}
-    headers.update(
+    headers.update(server_common.get_cookie_header_for_url(url))
     headers.update(service_account_auth.get_service_account_headers())
     try:
         async with connect(url, ping_interval=None,
@@ -142,8 +113,9 @@ async def latency_monitor(websocket: ClientConnection,
         ping_time = time.time()
         next_id += 1
         last_ping_time_dict[next_id] = ping_time
-        message_header_bytes = struct.pack(
-
+        message_header_bytes = struct.pack('!BI',
+                                           SSHMessageType.PINGPONG.value,
+                                           next_id)
         try:
             async with websocket_lock:
                 await websocket.send(message_header_bytes)
@@ -176,7 +148,7 @@ async def stdin_to_websocket(reader: asyncio.StreamReader,
         if timestamps_supported:
             # Send message with type 0 to indicate data.
             message_type_bytes = struct.pack(
-                '!B',
+                '!B', SSHMessageType.REGULAR_DATA.value)
             data = message_type_bytes + data
         async with websocket_lock:
             await websocket.send(data)
@@ -201,10 +173,10 @@ async def websocket_to_stdout(websocket: ClientConnection,
             if (timestamps_supported and len(message) > 0 and
                     last_ping_time_dict is not None):
                 message_type = struct.unpack('!B', message[:1])[0]
-                if message_type ==
+                if message_type == SSHMessageType.REGULAR_DATA.value:
                     # Regular data - strip type byte and write to stdout
                     message = message[1:]
-                elif message_type ==
+                elif message_type == SSHMessageType.PINGPONG.value:
                     # PONG response - calculate latency and send measurement
                     if not len(message) == struct.calcsize('!BI'):
                         raise ValueError(
@@ -222,8 +194,7 @@ async def websocket_to_stdout(websocket: ClientConnection,
 
                     # Send latency measurement (type 2)
                     message_type_bytes = struct.pack(
-                        '!B',
-                        KubernetesSSHMessageType.LATENCY_MEASUREMENT.value)
+                        '!B', SSHMessageType.LATENCY_MEASUREMENT.value)
                     latency_bytes = struct.pack('!Q', latency_ms)
                     message = message_type_bytes + latency_bytes
                     # Send to server.
@@ -255,7 +226,7 @@ if __name__ == '__main__':
     # TODO(aylei): remove the separate /api/health call and use the header
     # during websocket handshake to determine the server version.
    health_url = f'{server_url}/api/health'
-    cookie_hdr =
+    cookie_hdr = server_common.get_cookie_header_for_url(health_url)
     health_response = requests.get(health_url, headers=cookie_hdr)
     health_data = health_response.json()
     timestamps_are_supported = int(health_data.get('api_version', 0)) > 21
@@ -272,7 +243,13 @@ if __name__ == '__main__':
     client_version_str = (f'&client_version={constants.API_VERSION}'
                           if timestamps_are_supported else '')
 
-
+    # For backwards compatibility, fallback to kubernetes-pod-ssh-proxy if
+    # no endpoint is provided.
+    endpoint = sys.argv[3] if len(sys.argv) > 3 else 'kubernetes-pod-ssh-proxy'
+    # Worker index for Slurm.
+    worker_idx = sys.argv[4] if len(sys.argv) > 4 else '0'
+    websocket_url = (f'{server_url}/{endpoint}'
                      f'?cluster_name={sys.argv[2]}'
+                     f'&worker={worker_idx}'
                      f'{client_version_str}')
     asyncio.run(main(websocket_url, timestamps_are_supported, _login_url))
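A hypothetical invocation sketch for the updated proxy script, based only on the argv handling visible above: argv[2] is the cluster name, argv[3] an optional websocket endpoint (defaulting to kubernetes-pod-ssh-proxy), and argv[4] an optional worker index. The first argument is assumed to be the API server URL, and the endpoint name used here is a placeholder, not confirmed by this diff:

# Hypothetical example; the server URL, cluster name, and endpoint are placeholders,
# and the meaning of argv[1] is an assumption not shown in this diff.
import subprocess
import sys

cmd = [
    sys.executable,
    'websocket_proxy.py',
    'http://127.0.0.1:46580',  # assumed API server URL (argv[1])
    'my-slurm-cluster',        # cluster name (argv[2])
    'slurm-ssh-proxy',         # hypothetical endpoint (argv[3]); omit to fall back
                               # to 'kubernetes-pod-ssh-proxy'
    '1',                       # worker index (argv[4]); defaults to '0'
]
# In practice this script is wired up as an SSH ProxyCommand; running it directly
# here only illustrates the new optional arguments.
subprocess.run(cmd, check=False)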
sky/users/model.conf
CHANGED