skypilot-nightly 1.0.0.dev20250606__py3-none-any.whl → 1.0.0.dev20250609__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/backends/backend_utils.py +21 -3
- sky/check.py +18 -22
- sky/cli.py +5 -8
- sky/client/cli.py +5 -8
- sky/client/sdk.py +2 -1
- sky/clouds/cloud.py +4 -0
- sky/clouds/nebius.py +44 -4
- sky/core.py +3 -2
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/chunks/236-619ed0248fb6fdd9.js +6 -0
- sky/dashboard/out/_next/static/chunks/470-680c19413b8f808b.js +1 -0
- sky/dashboard/out/_next/static/chunks/{614-635a84e87800f99e.js → 63-e2d7b1e75e67c713.js} +8 -8
- sky/dashboard/out/_next/static/chunks/843-16c7194621b2b512.js +11 -0
- sky/dashboard/out/_next/static/chunks/969-2c584e28e6b4b106.js +1 -0
- sky/dashboard/out/_next/static/chunks/973-aed916d5b02d2d63.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/{[job]-65d04d5d77cbb6b6.js → [job]-d31688d3e52736dd.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-35cbeb5214fd4036.js → [cluster]-e7d8710a9b0491e5.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{clusters-5549a350f97d7ef3.js → clusters-3c674e5d970e05cb.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/config-3aac7a015c6eede1.js +6 -0
- sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-b68ddeed712d45b5.js → [context]-46d2e4ad6c487260.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{infra-13b117a831702196.js → infra-7013d816a2a0e76c.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-f7f0c9e156d328bc.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/{jobs-a76b2700eca236f7.js → jobs-87e60396c376292f.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/users-9355a0f13d1db61d.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/workspace/{new-c7516f2b4c3727c0.js → new-9a749cca1813bd27.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-7799de9e691e35d8.js → [name]-8eeb628e03902f1b.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-8fbcc5ab4af316d0.js +1 -0
- sky/dashboard/out/_next/static/css/8b1c8321d4c02372.css +3 -0
- sky/dashboard/out/_next/static/xos0euNCptbGAM7_Q3Acl/_buildManifest.js +1 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/exceptions.py +5 -0
- sky/global_user_state.py +11 -6
- sky/jobs/scheduler.py +9 -4
- sky/jobs/server/core.py +23 -2
- sky/jobs/server/server.py +0 -95
- sky/jobs/state.py +18 -15
- sky/jobs/utils.py +2 -1
- sky/models.py +18 -0
- sky/provision/kubernetes/utils.py +12 -5
- sky/provision/nebius/constants.py +47 -0
- sky/provision/nebius/instance.py +2 -1
- sky/provision/nebius/utils.py +28 -7
- sky/serve/server/core.py +1 -1
- sky/server/common.py +4 -2
- sky/server/constants.py +0 -2
- sky/server/requests/executor.py +10 -2
- sky/server/requests/requests.py +4 -3
- sky/server/server.py +22 -5
- sky/skylet/constants.py +4 -0
- sky/skylet/job_lib.py +2 -1
- sky/skypilot_config.py +13 -1
- sky/templates/jobs-controller.yaml.j2 +3 -1
- sky/templates/nebius-ray.yml.j2 +6 -0
- sky/users/model.conf +1 -1
- sky/users/permission.py +148 -31
- sky/users/rbac.py +26 -0
- sky/users/server.py +14 -13
- sky/utils/common.py +6 -1
- sky/utils/common_utils.py +21 -3
- sky/utils/kubernetes/deploy_remote_cluster.py +5 -3
- sky/utils/resources_utils.py +3 -1
- sky/utils/schemas.py +9 -0
- sky/workspaces/core.py +100 -8
- sky/workspaces/server.py +15 -2
- sky/workspaces/utils.py +56 -0
- {skypilot_nightly-1.0.0.dev20250606.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/METADATA +1 -1
- {skypilot_nightly-1.0.0.dev20250606.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/RECORD +89 -87
- sky/dashboard/out/_next/static/99m-BAySO8Q7J-ul1jZVL/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/236-a90f0a9753a10420.js +0 -6
- sky/dashboard/out/_next/static/chunks/470-9e7a479cc8303baa.js +0 -1
- sky/dashboard/out/_next/static/chunks/843-c296541442d4af88.js +0 -11
- sky/dashboard/out/_next/static/chunks/969-c7abda31c10440ac.js +0 -1
- sky/dashboard/out/_next/static/chunks/973-1a09cac61cfcc1e1.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/config-1a1eeb949dab8897.js +0 -6
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-2d23a9c7571e6320.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/users-262aab38b9baaf3a.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/workspaces-384ea5fa0cea8f28.js +0 -1
- sky/dashboard/out/_next/static/css/667d941a2888ce6e.css +0 -3
- /sky/dashboard/out/_next/static/chunks/{37-beedd583fea84cc8.js → 37-600191c5804dcae2.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/{682-6647f0417d5662f0.js → 682-b60cfdacc15202e8.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/{856-3a32da4b84176f6d.js → 856-affc52adf5403a3a.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/pages/{_app-cb81dc4d27f4d009.js → _app-5f16aba5794ee8e7.js} +0 -0
- /sky/dashboard/out/_next/static/{99m-BAySO8Q7J-ul1jZVL → xos0euNCptbGAM7_Q3Acl}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250606.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250606.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250606.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250606.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/top_level.txt +0 -0
sky/jobs/server/core.py
CHANGED
@@ -37,6 +37,7 @@ from sky.utils import status_lib
|
|
37
37
|
from sky.utils import subprocess_utils
|
38
38
|
from sky.utils import timeline
|
39
39
|
from sky.utils import ux_utils
|
40
|
+
from sky.workspaces import core as workspaces_core
|
40
41
|
|
41
42
|
if typing.TYPE_CHECKING:
|
42
43
|
import sky
|
@@ -88,6 +89,9 @@ def launch(
|
|
88
89
|
raise ValueError('Only single-task or chain DAG is '
|
89
90
|
f'allowed for job_launch. Dag: {dag}')
|
90
91
|
dag.validate()
|
92
|
+
|
93
|
+
user_dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
|
94
|
+
|
91
95
|
dag_utils.maybe_infer_and_fill_dag_and_task_names(dag)
|
92
96
|
|
93
97
|
task_names = set()
|
@@ -175,12 +179,20 @@ def launch(
|
|
175
179
|
controller_utils.translate_local_file_mounts_to_two_hop(
|
176
180
|
task_))
|
177
181
|
|
182
|
+
# Has to use `\` to avoid yapf issue.
|
178
183
|
with tempfile.NamedTemporaryFile(prefix=f'managed-dag-{dag.name}-',
|
179
|
-
mode='w') as f
|
184
|
+
mode='w') as f, \
|
185
|
+
tempfile.NamedTemporaryFile(prefix=f'managed-user-dag-{dag.name}-',
|
186
|
+
mode='w') as original_user_yaml_path:
|
187
|
+
original_user_yaml_path.write(user_dag_str)
|
188
|
+
original_user_yaml_path.flush()
|
189
|
+
|
180
190
|
dag_utils.dump_chain_dag_to_yaml(dag, f.name)
|
181
191
|
controller = controller_utils.Controllers.JOBS_CONTROLLER
|
182
192
|
controller_name = controller.value.cluster_name
|
183
193
|
prefix = managed_job_constants.JOBS_TASK_YAML_PREFIX
|
194
|
+
remote_original_user_yaml_path = (
|
195
|
+
f'{prefix}/{dag.name}-{dag_uuid}.original_user_yaml')
|
184
196
|
remote_user_yaml_path = f'{prefix}/{dag.name}-{dag_uuid}.yaml'
|
185
197
|
remote_user_config_path = f'{prefix}/{dag.name}-{dag_uuid}.config_yaml'
|
186
198
|
remote_env_file_path = f'{prefix}/{dag.name}-{dag_uuid}.env'
|
@@ -189,6 +201,8 @@ def launch(
|
|
189
201
|
task_resources=sum([list(t.resources) for t in dag.tasks], []))
|
190
202
|
|
191
203
|
vars_to_fill = {
|
204
|
+
'remote_original_user_yaml_path': remote_original_user_yaml_path,
|
205
|
+
'original_user_dag_path': original_user_yaml_path.name,
|
192
206
|
'remote_user_yaml_path': remote_user_yaml_path,
|
193
207
|
'user_yaml_path': f.name,
|
194
208
|
'local_to_controller_file_mounts': local_to_controller_file_mounts,
|
@@ -231,7 +245,7 @@ def launch(
|
|
231
245
|
|
232
246
|
# Launch with the api server's user hash, so that sky status does not
|
233
247
|
# show the owner of the controller as whatever user launched it first.
|
234
|
-
with common.
|
248
|
+
with common.with_server_user():
|
235
249
|
# Always launch the controller in the default workspace.
|
236
250
|
with skypilot_config.local_active_workspace_ctx(
|
237
251
|
skylet_constants.SKYPILOT_DEFAULT_WORKSPACE):
|
@@ -442,6 +456,13 @@ def queue(refresh: bool,
|
|
442
456
|
|
443
457
|
jobs = list(filter(user_hash_matches_or_missing, jobs))
|
444
458
|
|
459
|
+
accessible_workspaces = workspaces_core.get_workspaces()
|
460
|
+
jobs = list(
|
461
|
+
filter(
|
462
|
+
lambda job: job.get('workspace', skylet_constants.
|
463
|
+
SKYPILOT_DEFAULT_WORKSPACE) in
|
464
|
+
accessible_workspaces, jobs))
|
465
|
+
|
445
466
|
if skip_finished:
|
446
467
|
# Filter out the finished jobs. If a multi-task job is partially
|
447
468
|
# finished, we will include all its tasks.
|
sky/jobs/server/server.py
CHANGED
@@ -1,12 +1,9 @@
|
|
1
1
|
"""REST API for managed jobs."""
|
2
|
-
import os
|
3
2
|
|
4
3
|
import fastapi
|
5
|
-
import httpx
|
6
4
|
|
7
5
|
from sky import sky_logging
|
8
6
|
from sky.jobs.server import core
|
9
|
-
from sky.jobs.server import dashboard_utils
|
10
7
|
from sky.server import common as server_common
|
11
8
|
from sky.server import stream_utils
|
12
9
|
from sky.server.requests import executor
|
@@ -14,7 +11,6 @@ from sky.server.requests import payloads
|
|
14
11
|
from sky.server.requests import requests as api_requests
|
15
12
|
from sky.skylet import constants
|
16
13
|
from sky.utils import common
|
17
|
-
from sky.utils import common_utils
|
18
14
|
|
19
15
|
logger = sky_logging.init_logger(__name__)
|
20
16
|
|
@@ -110,94 +106,3 @@ async def download_logs(
|
|
110
106
|
if jobs_download_logs_body.refresh else api_requests.ScheduleType.SHORT,
|
111
107
|
request_cluster_name=common.JOB_CONTROLLER_NAME,
|
112
108
|
)
|
113
|
-
|
114
|
-
|
115
|
-
@router.get('/dashboard')
|
116
|
-
async def dashboard(request: fastapi.Request,
|
117
|
-
user_hash: str) -> fastapi.Response:
|
118
|
-
# TODO(cooperc): Support showing only jobs for a specific user.
|
119
|
-
|
120
|
-
# FIX(zhwu/cooperc/eric): Fix log downloading (assumes global
|
121
|
-
# /download_log/xx route)
|
122
|
-
|
123
|
-
# Note: before #4717, each user had their own controller, and thus their own
|
124
|
-
# dashboard. Now, all users share the same controller, so this isn't really
|
125
|
-
# necessary. TODO(cooperc): clean up.
|
126
|
-
|
127
|
-
# TODO: Put this in an executor to avoid blocking the main server thread.
|
128
|
-
# It can take a long time if it needs to check the controller status.
|
129
|
-
|
130
|
-
# Find the port for the dashboard of the user
|
131
|
-
os.environ[constants.USER_ID_ENV_VAR] = user_hash
|
132
|
-
server_common.reload_for_new_request(client_entrypoint=None,
|
133
|
-
client_command=None,
|
134
|
-
using_remote_api_server=False)
|
135
|
-
logger.info(f'Starting dashboard for user hash: {user_hash}')
|
136
|
-
|
137
|
-
with dashboard_utils.get_dashboard_lock_for_user(user_hash):
|
138
|
-
max_retries = 3
|
139
|
-
for attempt in range(max_retries):
|
140
|
-
port, pid = dashboard_utils.get_dashboard_session(user_hash)
|
141
|
-
if port == 0 or attempt > 0:
|
142
|
-
# Let the client know that we are waiting for starting the
|
143
|
-
# dashboard.
|
144
|
-
try:
|
145
|
-
port, pid = core.start_dashboard_forwarding()
|
146
|
-
except Exception as e: # pylint: disable=broad-except
|
147
|
-
# We catch all exceptions to gracefully handle unknown
|
148
|
-
# errors and raise an HTTPException to the client.
|
149
|
-
msg = (
|
150
|
-
'Dashboard failed to start: '
|
151
|
-
f'{common_utils.format_exception(e, use_bracket=True)}')
|
152
|
-
logger.error(msg)
|
153
|
-
raise fastapi.HTTPException(status_code=503, detail=msg)
|
154
|
-
dashboard_utils.add_dashboard_session(user_hash, port, pid)
|
155
|
-
|
156
|
-
# Assuming the dashboard is forwarded to localhost on the API server
|
157
|
-
dashboard_url = f'http://localhost:{port}'
|
158
|
-
try:
|
159
|
-
# Ping the dashboard to check if it's still running
|
160
|
-
async with httpx.AsyncClient() as client:
|
161
|
-
response = await client.request('GET',
|
162
|
-
dashboard_url,
|
163
|
-
timeout=5)
|
164
|
-
if response.is_success:
|
165
|
-
break # Connection successful, proceed with the request
|
166
|
-
# Raise an HTTPException here which will be caught by the
|
167
|
-
# following except block to retry with new connection
|
168
|
-
response.raise_for_status()
|
169
|
-
except Exception as e: # pylint: disable=broad-except
|
170
|
-
# We catch all exceptions to gracefully handle unknown
|
171
|
-
# errors and retry or raise an HTTPException to the client.
|
172
|
-
# Assume an exception indicates that the dashboard connection
|
173
|
-
# is stale - remove it so that a new one is created.
|
174
|
-
dashboard_utils.remove_dashboard_session(user_hash)
|
175
|
-
msg = (
|
176
|
-
f'Dashboard connection attempt {attempt + 1} failed with '
|
177
|
-
f'{common_utils.format_exception(e, use_bracket=True)}')
|
178
|
-
logger.info(msg)
|
179
|
-
if attempt == max_retries - 1:
|
180
|
-
raise fastapi.HTTPException(status_code=503, detail=msg)
|
181
|
-
|
182
|
-
# Create a client session to forward the request
|
183
|
-
try:
|
184
|
-
async with httpx.AsyncClient() as client:
|
185
|
-
# Make the request and get the response
|
186
|
-
response = await client.request(
|
187
|
-
method='GET',
|
188
|
-
url=f'{dashboard_url}',
|
189
|
-
headers=request.headers.raw,
|
190
|
-
)
|
191
|
-
|
192
|
-
# Create a new response with the content already read
|
193
|
-
content = await response.aread()
|
194
|
-
return fastapi.Response(
|
195
|
-
content=content,
|
196
|
-
status_code=response.status_code,
|
197
|
-
headers=dict(response.headers),
|
198
|
-
media_type=response.headers.get('content-type'))
|
199
|
-
except Exception as e:
|
200
|
-
msg = (f'Failed to forward request to dashboard: '
|
201
|
-
f'{common_utils.format_exception(e, use_bracket=True)}')
|
202
|
-
logger.error(msg)
|
203
|
-
raise fastapi.HTTPException(status_code=502, detail=msg)
|
sky/jobs/state.py
CHANGED
@@ -122,7 +122,8 @@ def create_table(cursor, conn):
|
|
122
122
|
user_hash TEXT,
|
123
123
|
workspace TEXT DEFAULT NULL,
|
124
124
|
priority INTEGER DEFAULT 500,
|
125
|
-
entrypoint TEXT DEFAULT NULL
|
125
|
+
entrypoint TEXT DEFAULT NULL,
|
126
|
+
original_user_yaml_path TEXT DEFAULT NULL)""")
|
126
127
|
|
127
128
|
db_utils.add_column_to_table(cursor, conn, 'job_info', 'schedule_state',
|
128
129
|
'TEXT')
|
@@ -153,6 +154,8 @@ def create_table(cursor, conn):
|
|
153
154
|
value_to_replace_existing_entries=500)
|
154
155
|
|
155
156
|
db_utils.add_column_to_table(cursor, conn, 'job_info', 'entrypoint', 'TEXT')
|
157
|
+
db_utils.add_column_to_table(cursor, conn, 'job_info',
|
158
|
+
'original_user_yaml_path', 'TEXT')
|
156
159
|
conn.commit()
|
157
160
|
|
158
161
|
|
@@ -212,6 +215,7 @@ columns = [
|
|
212
215
|
'workspace',
|
213
216
|
'priority',
|
214
217
|
'entrypoint',
|
218
|
+
'original_user_yaml_path',
|
215
219
|
]
|
216
220
|
|
217
221
|
|
@@ -1013,19 +1017,16 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
|
|
1013
1017
|
if job_dict['job_name'] is None:
|
1014
1018
|
job_dict['job_name'] = job_dict['task_name']
|
1015
1019
|
|
1016
|
-
# Add YAML content
|
1017
|
-
|
1018
|
-
if
|
1020
|
+
# Add user YAML content for managed jobs.
|
1021
|
+
yaml_path = job_dict.get('original_user_yaml_path')
|
1022
|
+
if yaml_path:
|
1019
1023
|
try:
|
1020
|
-
with open(
|
1021
|
-
job_dict['
|
1024
|
+
with open(yaml_path, 'r', encoding='utf-8') as f:
|
1025
|
+
job_dict['user_yaml'] = f.read()
|
1022
1026
|
except (FileNotFoundError, IOError, OSError):
|
1023
|
-
job_dict['
|
1024
|
-
|
1025
|
-
# Generate a command that could be used to launch this job
|
1026
|
-
# Format: sky jobs launch <yaml_path>
|
1027
|
+
job_dict['user_yaml'] = None
|
1027
1028
|
else:
|
1028
|
-
job_dict['
|
1029
|
+
job_dict['user_yaml'] = None
|
1029
1030
|
|
1030
1031
|
jobs.append(job_dict)
|
1031
1032
|
return jobs
|
@@ -1083,18 +1084,20 @@ def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
|
|
1083
1084
|
# scheduler lock to work correctly.
|
1084
1085
|
|
1085
1086
|
|
1086
|
-
def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
|
1087
|
+
def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
|
1088
|
+
original_user_yaml_path: str, env_file_path: str,
|
1087
1089
|
user_hash: str, priority: int) -> None:
|
1088
1090
|
"""Do not call without holding the scheduler lock."""
|
1089
1091
|
with db_utils.safe_cursor(_DB_PATH) as cursor:
|
1090
1092
|
updated_count = cursor.execute(
|
1091
1093
|
'UPDATE job_info SET '
|
1092
|
-
'schedule_state = (?), dag_yaml_path = (?),
|
1094
|
+
'schedule_state = (?), dag_yaml_path = (?), '
|
1095
|
+
'original_user_yaml_path = (?), env_file_path = (?), '
|
1093
1096
|
' user_hash = (?), priority = (?) '
|
1094
1097
|
'WHERE spot_job_id = (?) AND schedule_state = (?)',
|
1095
1098
|
(ManagedJobScheduleState.WAITING.value, dag_yaml_path,
|
1096
|
-
env_file_path, user_hash, priority,
|
1097
|
-
ManagedJobScheduleState.INACTIVE.value)).rowcount
|
1099
|
+
original_user_yaml_path, env_file_path, user_hash, priority,
|
1100
|
+
job_id, ManagedJobScheduleState.INACTIVE.value)).rowcount
|
1098
1101
|
assert updated_count == 1, (job_id, updated_count)
|
1099
1102
|
|
1100
1103
|
|
sky/jobs/utils.py
CHANGED
@@ -1025,7 +1025,8 @@ def load_managed_job_queue(payload: str) -> List[Dict[str, Any]]:
|
|
1025
1025
|
if 'user_hash' in job and job['user_hash'] is not None:
|
1026
1026
|
# Skip jobs that do not have user_hash info.
|
1027
1027
|
# TODO(cooperc): Remove check before 0.12.0.
|
1028
|
-
|
1028
|
+
user = global_user_state.get_user(job['user_hash'])
|
1029
|
+
job['user_name'] = user.name if user is not None else None
|
1029
1030
|
return jobs
|
1030
1031
|
|
1031
1032
|
|
sky/models.py
CHANGED
@@ -2,8 +2,13 @@
|
|
2
2
|
|
3
3
|
import collections
|
4
4
|
import dataclasses
|
5
|
+
import getpass
|
6
|
+
import os
|
5
7
|
from typing import Any, Dict, Optional
|
6
8
|
|
9
|
+
from sky.skylet import constants
|
10
|
+
from sky.utils import common_utils
|
11
|
+
|
7
12
|
|
8
13
|
@dataclasses.dataclass
|
9
14
|
class User:
|
@@ -16,6 +21,19 @@ class User:
|
|
16
21
|
def to_dict(self) -> Dict[str, Any]:
|
17
22
|
return {'id': self.id, 'name': self.name}
|
18
23
|
|
24
|
+
def to_env_vars(self) -> Dict[str, Any]:
|
25
|
+
return {
|
26
|
+
constants.USER_ID_ENV_VAR: self.id,
|
27
|
+
constants.USER_ENV_VAR: self.name,
|
28
|
+
}
|
29
|
+
|
30
|
+
@classmethod
|
31
|
+
def get_current_user(cls) -> 'User':
|
32
|
+
"""Returns the current user."""
|
33
|
+
user_name = os.getenv(constants.USER_ENV_VAR, getpass.getuser())
|
34
|
+
user_hash = common_utils.get_user_hash()
|
35
|
+
return User(id=user_hash, name=user_name)
|
36
|
+
|
19
37
|
|
20
38
|
RealtimeGpuAvailability = collections.namedtuple(
|
21
39
|
'RealtimeGpuAvailability', ['gpu', 'counts', 'capacity', 'available'])
|
@@ -2342,6 +2342,7 @@ def get_endpoint_debug_message() -> str:
|
|
2342
2342
|
def combine_pod_config_fields(
|
2343
2343
|
cluster_yaml_path: str,
|
2344
2344
|
cluster_config_overrides: Dict[str, Any],
|
2345
|
+
cloud: Optional[clouds.Cloud] = None,
|
2345
2346
|
) -> None:
|
2346
2347
|
"""Adds or updates fields in the YAML with fields from the
|
2347
2348
|
~/.sky/config.yaml's kubernetes.pod_spec dict.
|
@@ -2386,11 +2387,17 @@ def combine_pod_config_fields(
|
|
2386
2387
|
yaml_obj = yaml.safe_load(yaml_content)
|
2387
2388
|
# We don't use override_configs in `skypilot_config.get_nested`, as merging
|
2388
2389
|
# the pod config requires special handling.
|
2389
|
-
|
2390
|
-
|
2391
|
-
|
2392
|
-
|
2393
|
-
'
|
2390
|
+
if isinstance(cloud, clouds.SSH):
|
2391
|
+
kubernetes_config = skypilot_config.get_nested(('ssh', 'pod_config'),
|
2392
|
+
default_value={},
|
2393
|
+
override_configs={})
|
2394
|
+
override_pod_config = (cluster_config_overrides.get('ssh', {}).get(
|
2395
|
+
'pod_config', {}))
|
2396
|
+
else:
|
2397
|
+
kubernetes_config = skypilot_config.get_nested(
|
2398
|
+
('kubernetes', 'pod_config'), default_value={}, override_configs={})
|
2399
|
+
override_pod_config = (cluster_config_overrides.get(
|
2400
|
+
'kubernetes', {}).get('pod_config', {}))
|
2394
2401
|
config_utils.merge_k8s_configs(kubernetes_config, override_pod_config)
|
2395
2402
|
|
2396
2403
|
# Merge the kubernetes config into the YAML for both head and worker nodes.
|
@@ -0,0 +1,47 @@
|
|
1
|
+
"""Constants used by the Nebius provisioner."""
|
2
|
+
|
3
|
+
VERSION = 'v1'
|
4
|
+
|
5
|
+
# InfiniBand-capable instance platforms
|
6
|
+
INFINIBAND_INSTANCE_PLATFORMS = [
|
7
|
+
'gpu-h100-sxm',
|
8
|
+
'gpu-h200-sxm',
|
9
|
+
]
|
10
|
+
|
11
|
+
# InfiniBand environment variables for NCCL and UCX
|
12
|
+
INFINIBAND_ENV_VARS = {
|
13
|
+
'NCCL_IB_HCA': 'mlx5',
|
14
|
+
'UCX_NET_DEVICES': ('mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,'
|
15
|
+
'mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1')
|
16
|
+
}
|
17
|
+
|
18
|
+
# Docker run options for InfiniBand support
|
19
|
+
INFINIBAND_DOCKER_OPTIONS = ['--device=/dev/infiniband', '--cap-add=IPC_LOCK']
|
20
|
+
|
21
|
+
# InfiniBand fabric mapping by platform and region
|
22
|
+
# Based on Nebius documentation
|
23
|
+
INFINIBAND_FABRIC_MAPPING = {
|
24
|
+
# H100 platforms
|
25
|
+
('gpu-h100-sxm', 'eu-north1'): [
|
26
|
+
'fabric-2', 'fabric-3', 'fabric-4', 'fabric-6'
|
27
|
+
],
|
28
|
+
|
29
|
+
# H200 platforms
|
30
|
+
('gpu-h200-sxm', 'eu-north1'): ['fabric-7'],
|
31
|
+
('gpu-h200-sxm', 'eu-west1'): ['fabric-5'],
|
32
|
+
('gpu-h200-sxm', 'us-central1'): ['us-central1-a'],
|
33
|
+
}
|
34
|
+
|
35
|
+
|
36
|
+
def get_default_fabric(platform: str, region: str) -> str:
|
37
|
+
"""Get the default (first) fabric for a given platform and region."""
|
38
|
+
fabrics = INFINIBAND_FABRIC_MAPPING.get((platform, region), [])
|
39
|
+
if not fabrics:
|
40
|
+
# Select north europe region as default
|
41
|
+
fabrics = INFINIBAND_FABRIC_MAPPING.get(('gpu-h100-sxm', 'eu-north1'),
|
42
|
+
[])
|
43
|
+
if not fabrics:
|
44
|
+
raise ValueError(
|
45
|
+
f'No InfiniBand fabric available for platform {platform} '
|
46
|
+
f'in region {region}')
|
47
|
+
return fabrics[0]
|
sky/provision/nebius/instance.py
CHANGED
@@ -124,6 +124,7 @@ def run_instances(region: str, cluster_name_on_cloud: str,
|
|
124
124
|
node_type = 'head' if head_instance_id is None else 'worker'
|
125
125
|
try:
|
126
126
|
platform, preset = config.node_config['InstanceType'].split('_')
|
127
|
+
|
127
128
|
instance_id = utils.launch(
|
128
129
|
cluster_name_on_cloud=cluster_name_on_cloud,
|
129
130
|
node_type=node_type,
|
@@ -136,7 +137,7 @@ def run_instances(region: str, cluster_name_on_cloud: str,
|
|
136
137
|
associate_public_ip_address=(
|
137
138
|
not config.provider_config['use_internal_ips']),
|
138
139
|
filesystems=config.node_config.get('filesystems', []),
|
139
|
-
|
140
|
+
network_tier=config.node_config.get('network_tier'))
|
140
141
|
except Exception as e: # pylint: disable=broad-except
|
141
142
|
logger.warning(f'run_instances error: {e}')
|
142
143
|
raise
|
sky/provision/nebius/utils.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1
1
|
"""Nebius library wrapper for SkyPilot."""
|
2
2
|
import time
|
3
|
-
from typing import Any, Dict, List
|
3
|
+
from typing import Any, Dict, List, Optional
|
4
4
|
import uuid
|
5
5
|
|
6
6
|
from sky import sky_logging
|
7
7
|
from sky import skypilot_config
|
8
8
|
from sky.adaptors import nebius
|
9
|
+
from sky.provision.nebius import constants as nebius_constants
|
9
10
|
from sky.utils import common_utils
|
11
|
+
from sky.utils import resources_utils
|
10
12
|
|
11
13
|
logger = sky_logging.init_logger(__name__)
|
12
14
|
|
@@ -156,10 +158,17 @@ def start(instance_id: str) -> None:
|
|
156
158
|
f' to be ready.')
|
157
159
|
|
158
160
|
|
159
|
-
def launch(cluster_name_on_cloud: str,
|
160
|
-
|
161
|
-
|
162
|
-
|
161
|
+
def launch(cluster_name_on_cloud: str,
|
162
|
+
node_type: str,
|
163
|
+
platform: str,
|
164
|
+
preset: str,
|
165
|
+
region: str,
|
166
|
+
image_family: str,
|
167
|
+
disk_size: int,
|
168
|
+
user_data: str,
|
169
|
+
associate_public_ip_address: bool,
|
170
|
+
filesystems: List[Dict[str, Any]],
|
171
|
+
network_tier: Optional[resources_utils.NetworkTier] = None) -> str:
|
163
172
|
# Each node must have a unique name to avoid conflicts between
|
164
173
|
# multiple worker VMs. To ensure uniqueness,a UUID is appended
|
165
174
|
# to the node name.
|
@@ -173,11 +182,23 @@ def launch(cluster_name_on_cloud: str, node_type: str, platform: str,
|
|
173
182
|
# 8 GPU virtual machines can be grouped into a GPU cluster.
|
174
183
|
# The GPU clusters are built with InfiniBand secure high-speed networking.
|
175
184
|
# https://docs.nebius.com/compute/clusters/gpu
|
176
|
-
if platform in
|
185
|
+
if platform in nebius_constants.INFINIBAND_INSTANCE_PLATFORMS:
|
177
186
|
if preset == '8gpu-128vcpu-1600gb':
|
178
|
-
# Check is there fabric in config
|
179
187
|
fabric = skypilot_config.get_nested(('nebius', region, 'fabric'),
|
180
188
|
None)
|
189
|
+
|
190
|
+
# Auto-select fabric if network_tier=best and no fabric configured
|
191
|
+
if (fabric is None and
|
192
|
+
str(network_tier) == str(resources_utils.NetworkTier.BEST)):
|
193
|
+
try:
|
194
|
+
fabric = nebius_constants.get_default_fabric(
|
195
|
+
platform, region)
|
196
|
+
logger.info(f'Auto-selected InfiniBand fabric {fabric} '
|
197
|
+
f'for {platform} in {region}')
|
198
|
+
except ValueError as e:
|
199
|
+
logger.warning(
|
200
|
+
f'InfiniBand fabric auto-selection failed: {e}')
|
201
|
+
|
181
202
|
if fabric is None:
|
182
203
|
logger.warning(
|
183
204
|
f'Set up fabric for region {region} in ~/.sky/config.yaml '
|
sky/serve/server/core.py
CHANGED
@@ -221,7 +221,7 @@ def up(
|
|
221
221
|
# for the first time; otherwise it is a name conflict.
|
222
222
|
# Since the controller may be shared among multiple users, launch the
|
223
223
|
# controller with the API server's user hash.
|
224
|
-
with common.
|
224
|
+
with common.with_server_user():
|
225
225
|
with skypilot_config.local_active_workspace_ctx(
|
226
226
|
constants.SKYPILOT_DEFAULT_WORKSPACE):
|
227
227
|
controller_job_id, controller_handle = execution.launch(
|
sky/server/common.py
CHANGED
@@ -39,6 +39,7 @@ if typing.TYPE_CHECKING:
|
|
39
39
|
import requests
|
40
40
|
|
41
41
|
from sky import dag as dag_lib
|
42
|
+
from sky import models
|
42
43
|
else:
|
43
44
|
pydantic = adaptors_common.LazyImport('pydantic')
|
44
45
|
requests = adaptors_common.LazyImport('requests')
|
@@ -710,7 +711,7 @@ def request_body_to_params(body: 'pydantic.BaseModel') -> Dict[str, Any]:
|
|
710
711
|
|
711
712
|
def reload_for_new_request(client_entrypoint: Optional[str],
|
712
713
|
client_command: Optional[str],
|
713
|
-
using_remote_api_server: bool):
|
714
|
+
using_remote_api_server: bool, user: 'models.User'):
|
714
715
|
"""Reload modules, global variables, and usage message for a new request."""
|
715
716
|
# This should be called first to make sure the logger is up-to-date.
|
716
717
|
sky_logging.reload_logger()
|
@@ -719,10 +720,11 @@ def reload_for_new_request(client_entrypoint: Optional[str],
|
|
719
720
|
skypilot_config.safe_reload_config()
|
720
721
|
|
721
722
|
# Reset the client entrypoint and command for the usage message.
|
722
|
-
common_utils.
|
723
|
+
common_utils.set_request_context(
|
723
724
|
client_entrypoint=client_entrypoint,
|
724
725
|
client_command=client_command,
|
725
726
|
using_remote_api_server=using_remote_api_server,
|
727
|
+
user=user,
|
726
728
|
)
|
727
729
|
|
728
730
|
# Clear cache should be called before reload_logger and usage reset,
|
sky/server/constants.py
CHANGED
@@ -11,8 +11,6 @@ API_VERSION = '9'
|
|
11
11
|
|
12
12
|
# Prefix for API request names.
|
13
13
|
REQUEST_NAME_PREFIX = 'sky.'
|
14
|
-
# The user ID of the SkyPilot system.
|
15
|
-
SKYPILOT_SYSTEM_USER_ID = 'skypilot-system'
|
16
14
|
# The memory (GB) that SkyPilot tries to not use to prevent OOM.
|
17
15
|
MIN_AVAIL_MEM_GB = 2
|
18
16
|
# Default encoder/decoder handler name.
|
sky/server/requests/executor.py
CHANGED
@@ -53,6 +53,7 @@ from sky.utils import context
|
|
53
53
|
from sky.utils import context_utils
|
54
54
|
from sky.utils import subprocess_utils
|
55
55
|
from sky.utils import timeline
|
56
|
+
from sky.workspaces import core as workspaces_core
|
56
57
|
|
57
58
|
if typing.TYPE_CHECKING:
|
58
59
|
import types
|
@@ -229,6 +230,9 @@ def override_request_env_and_config(
|
|
229
230
|
original_env = os.environ.copy()
|
230
231
|
os.environ.update(request_body.env_vars)
|
231
232
|
# Note: may be overridden by AuthProxyMiddleware.
|
233
|
+
# TODO(zhwu): we need to make the entire request a context available to the
|
234
|
+
# entire request execution, so that we can access info like user through
|
235
|
+
# the execution.
|
232
236
|
user = models.User(id=request_body.env_vars[constants.USER_ID_ENV_VAR],
|
233
237
|
name=request_body.env_vars[constants.USER_ENV_VAR])
|
234
238
|
global_user_state.add_or_update_user(user)
|
@@ -237,13 +241,17 @@ def override_request_env_and_config(
|
|
237
241
|
server_common.reload_for_new_request(
|
238
242
|
client_entrypoint=request_body.entrypoint,
|
239
243
|
client_command=request_body.entrypoint_command,
|
240
|
-
using_remote_api_server=request_body.using_remote_api_server
|
244
|
+
using_remote_api_server=request_body.using_remote_api_server,
|
245
|
+
user=user)
|
241
246
|
try:
|
242
247
|
logger.debug(
|
243
248
|
f'override path: {request_body.override_skypilot_config_path}')
|
244
249
|
with skypilot_config.override_skypilot_config(
|
245
250
|
request_body.override_skypilot_config,
|
246
251
|
request_body.override_skypilot_config_path):
|
252
|
+
# Rejecting requests to workspaces that the user does not have
|
253
|
+
# permission to access.
|
254
|
+
workspaces_core.reject_request_for_unauthorized_workspace(user)
|
247
255
|
yield
|
248
256
|
finally:
|
249
257
|
# We need to call the save_timeline() since atexit will not be
|
@@ -433,7 +441,7 @@ def prepare_request(
|
|
433
441
|
"""Prepare a request for execution."""
|
434
442
|
user_id = request_body.env_vars[constants.USER_ID_ENV_VAR]
|
435
443
|
if is_skypilot_system:
|
436
|
-
user_id =
|
444
|
+
user_id = constants.SKYPILOT_SYSTEM_USER_ID
|
437
445
|
global_user_state.add_or_update_user(
|
438
446
|
models.User(id=user_id, name=user_id))
|
439
447
|
request = api_requests.Request(request_id=request_id,
|
sky/server/requests/requests.py
CHANGED
@@ -11,7 +11,7 @@ import signal
|
|
11
11
|
import sqlite3
|
12
12
|
import time
|
13
13
|
import traceback
|
14
|
-
from typing import Any, Callable, Dict, List, Optional, Tuple
|
14
|
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
15
15
|
|
16
16
|
import colorama
|
17
17
|
import filelock
|
@@ -204,7 +204,8 @@ class Request:
|
|
204
204
|
"""
|
205
205
|
assert isinstance(self.request_body,
|
206
206
|
payloads.RequestBody), (self.name, self.request_body)
|
207
|
-
|
207
|
+
user = global_user_state.get_user(self.user_id)
|
208
|
+
user_name = user.name if user is not None else None
|
208
209
|
return RequestPayload(
|
209
210
|
request_id=self.request_id,
|
210
211
|
name=self.name,
|
@@ -464,7 +465,7 @@ def request_lock_path(request_id: str) -> str:
|
|
464
465
|
|
465
466
|
@contextlib.contextmanager
|
466
467
|
@init_db
|
467
|
-
def update_request(request_id: str):
|
468
|
+
def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
|
468
469
|
"""Get a SkyPilot API request."""
|
469
470
|
request = _get_request_no_lock(request_id)
|
470
471
|
yield request
|
sky/server/server.py
CHANGED
@@ -49,6 +49,7 @@ from sky.server.requests import preconditions
|
|
49
49
|
from sky.server.requests import requests as requests_lib
|
50
50
|
from sky.skylet import constants
|
51
51
|
from sky.usage import usage_lib
|
52
|
+
from sky.users import permission
|
52
53
|
from sky.users import server as users_rest
|
53
54
|
from sky.utils import admin_policy_utils
|
54
55
|
from sky.utils import common as common_lib
|
@@ -105,17 +106,21 @@ class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
105
106
|
"""Middleware to handle RBAC."""
|
106
107
|
|
107
108
|
async def dispatch(self, request: fastapi.Request, call_next):
|
108
|
-
|
109
|
+
# TODO(hailong): should have a list of paths
|
110
|
+
# that are not checked for RBAC
|
111
|
+
if (request.url.path.startswith('/dashboard/') or
|
112
|
+
request.url.path.startswith('/api/')):
|
109
113
|
return await call_next(request)
|
110
114
|
|
111
115
|
auth_user = _get_auth_user_header(request)
|
112
116
|
if auth_user is None:
|
113
117
|
return await call_next(request)
|
114
118
|
|
115
|
-
permission_service =
|
119
|
+
permission_service = permission.permission_service
|
116
120
|
# Check the role permission
|
117
|
-
if permission_service.
|
118
|
-
|
121
|
+
if permission_service.check_endpoint_permission(auth_user.id,
|
122
|
+
request.url.path,
|
123
|
+
request.method):
|
119
124
|
return fastapi.responses.JSONResponse(
|
120
125
|
status_code=403, content={'detail': 'Forbidden'})
|
121
126
|
|
@@ -154,9 +159,15 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
154
159
|
if auth_user is not None:
|
155
160
|
newly_added = global_user_state.add_or_update_user(auth_user)
|
156
161
|
if newly_added:
|
157
|
-
|
162
|
+
permission.permission_service.add_user_if_not_exists(
|
158
163
|
auth_user.id)
|
159
164
|
|
165
|
+
# Store user info in request.state for access by GET endpoints
|
166
|
+
if auth_user is not None:
|
167
|
+
request.state.auth_user = auth_user
|
168
|
+
else:
|
169
|
+
request.state.auth_user = None
|
170
|
+
|
160
171
|
body = await request.body()
|
161
172
|
if auth_user and body:
|
162
173
|
try:
|
@@ -177,6 +188,12 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
177
188
|
f'"env_vars" in request body is not a dictionary '
|
178
189
|
f'for request {request.state.request_id}. '
|
179
190
|
'Skipping user info injection into body.')
|
191
|
+
else:
|
192
|
+
original_json['env_vars'] = {}
|
193
|
+
original_json['env_vars'][
|
194
|
+
constants.USER_ID_ENV_VAR] = auth_user.id
|
195
|
+
original_json['env_vars'][
|
196
|
+
constants.USER_ENV_VAR] = auth_user.name
|
180
197
|
request._body = json.dumps(original_json).encode('utf-8') # pylint: disable=protected-access
|
181
198
|
return await call_next(request)
|
182
199
|
|