skypilot-nightly 1.0.0.dev20250607__py3-none-any.whl → 1.0.0.dev20250609__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/backends/backend_utils.py +18 -2
- sky/check.py +4 -3
- sky/cli.py +5 -6
- sky/client/cli.py +5 -6
- sky/core.py +3 -2
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/chunks/470-680c19413b8f808b.js +1 -0
- sky/dashboard/out/_next/static/chunks/{614-635a84e87800f99e.js → 63-e2d7b1e75e67c713.js} +8 -8
- sky/dashboard/out/_next/static/chunks/843-16c7194621b2b512.js +11 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/{[job]-18aed9b56247d074.js → [job]-d31688d3e52736dd.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-b919a73aecdfa78f.js → [cluster]-e7d8710a9b0491e5.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{clusters-4f6b9dd9abcb33ad.js → clusters-3c674e5d970e05cb.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/config-3aac7a015c6eede1.js +6 -0
- sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-3a18d0eeb5119fe4.js → [context]-46d2e4ad6c487260.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{infra-a1a6abeeb58c1051.js → infra-7013d816a2a0e76c.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/{[job]-1354e28c81eeb686.js → [job]-f7f0c9e156d328bc.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/{jobs-23bfc8bf373423db.js → jobs-87e60396c376292f.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/users-9355a0f13d1db61d.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/workspace/{new-e1f9c0c3ff7ac4bd.js → new-9a749cca1813bd27.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-686590e0ee4b2412.js → [name]-8eeb628e03902f1b.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-8fbcc5ab4af316d0.js +1 -0
- sky/dashboard/out/_next/static/css/8b1c8321d4c02372.css +3 -0
- sky/dashboard/out/_next/static/xos0euNCptbGAM7_Q3Acl/_buildManifest.js +1 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/exceptions.py +5 -0
- sky/global_user_state.py +11 -6
- sky/jobs/server/core.py +9 -1
- sky/jobs/server/server.py +0 -95
- sky/jobs/utils.py +2 -1
- sky/models.py +18 -0
- sky/serve/server/core.py +1 -1
- sky/server/common.py +4 -2
- sky/server/constants.py +0 -2
- sky/server/requests/executor.py +10 -2
- sky/server/requests/requests.py +4 -3
- sky/server/server.py +22 -5
- sky/skylet/constants.py +3 -0
- sky/skylet/job_lib.py +2 -1
- sky/skypilot_config.py +9 -0
- sky/users/model.conf +1 -1
- sky/users/permission.py +148 -31
- sky/users/rbac.py +26 -0
- sky/users/server.py +14 -13
- sky/utils/common.py +6 -1
- sky/utils/common_utils.py +21 -3
- sky/utils/schemas.py +9 -0
- sky/workspaces/core.py +100 -8
- sky/workspaces/server.py +15 -2
- sky/workspaces/utils.py +56 -0
- {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/METADATA +1 -1
- {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/RECORD +72 -71
- sky/dashboard/out/_next/static/1qG0HTmVilJPxQdBk0fX5/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/470-ad1e0db3afcbd9c9.js +0 -1
- sky/dashboard/out/_next/static/chunks/843-c296541442d4af88.js +0 -11
- sky/dashboard/out/_next/static/chunks/pages/config-fe375a56342cf609.js +0 -6
- sky/dashboard/out/_next/static/chunks/pages/users-5800045bd04e69c2.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/workspaces-76b07aa5da91b0df.js +0 -1
- sky/dashboard/out/_next/static/css/667d941a2888ce6e.css +0 -3
- /sky/dashboard/out/_next/static/chunks/{856-3a32da4b84176f6d.js → 856-affc52adf5403a3a.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/{973-6d78a0814682d771.js → 973-aed916d5b02d2d63.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/pages/{_app-cb81dc4d27f4d009.js → _app-5f16aba5794ee8e7.js} +0 -0
- /sky/dashboard/out/_next/static/{1qG0HTmVilJPxQdBk0fX5 → xos0euNCptbGAM7_Q3Acl}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/top_level.txt +0 -0
sky/jobs/server/server.py
CHANGED
@@ -1,12 +1,9 @@
|
|
1
1
|
"""REST API for managed jobs."""
|
2
|
-
import os
|
3
2
|
|
4
3
|
import fastapi
|
5
|
-
import httpx
|
6
4
|
|
7
5
|
from sky import sky_logging
|
8
6
|
from sky.jobs.server import core
|
9
|
-
from sky.jobs.server import dashboard_utils
|
10
7
|
from sky.server import common as server_common
|
11
8
|
from sky.server import stream_utils
|
12
9
|
from sky.server.requests import executor
|
@@ -14,7 +11,6 @@ from sky.server.requests import payloads
|
|
14
11
|
from sky.server.requests import requests as api_requests
|
15
12
|
from sky.skylet import constants
|
16
13
|
from sky.utils import common
|
17
|
-
from sky.utils import common_utils
|
18
14
|
|
19
15
|
logger = sky_logging.init_logger(__name__)
|
20
16
|
|
@@ -110,94 +106,3 @@ async def download_logs(
|
|
110
106
|
if jobs_download_logs_body.refresh else api_requests.ScheduleType.SHORT,
|
111
107
|
request_cluster_name=common.JOB_CONTROLLER_NAME,
|
112
108
|
)
|
113
|
-
|
114
|
-
|
115
|
-
@router.get('/dashboard')
|
116
|
-
async def dashboard(request: fastapi.Request,
|
117
|
-
user_hash: str) -> fastapi.Response:
|
118
|
-
# TODO(cooperc): Support showing only jobs for a specific user.
|
119
|
-
|
120
|
-
# FIX(zhwu/cooperc/eric): Fix log downloading (assumes global
|
121
|
-
# /download_log/xx route)
|
122
|
-
|
123
|
-
# Note: before #4717, each user had their own controller, and thus their own
|
124
|
-
# dashboard. Now, all users share the same controller, so this isn't really
|
125
|
-
# necessary. TODO(cooperc): clean up.
|
126
|
-
|
127
|
-
# TODO: Put this in an executor to avoid blocking the main server thread.
|
128
|
-
# It can take a long time if it needs to check the controller status.
|
129
|
-
|
130
|
-
# Find the port for the dashboard of the user
|
131
|
-
os.environ[constants.USER_ID_ENV_VAR] = user_hash
|
132
|
-
server_common.reload_for_new_request(client_entrypoint=None,
|
133
|
-
client_command=None,
|
134
|
-
using_remote_api_server=False)
|
135
|
-
logger.info(f'Starting dashboard for user hash: {user_hash}')
|
136
|
-
|
137
|
-
with dashboard_utils.get_dashboard_lock_for_user(user_hash):
|
138
|
-
max_retries = 3
|
139
|
-
for attempt in range(max_retries):
|
140
|
-
port, pid = dashboard_utils.get_dashboard_session(user_hash)
|
141
|
-
if port == 0 or attempt > 0:
|
142
|
-
# Let the client know that we are waiting for starting the
|
143
|
-
# dashboard.
|
144
|
-
try:
|
145
|
-
port, pid = core.start_dashboard_forwarding()
|
146
|
-
except Exception as e: # pylint: disable=broad-except
|
147
|
-
# We catch all exceptions to gracefully handle unknown
|
148
|
-
# errors and raise an HTTPException to the client.
|
149
|
-
msg = (
|
150
|
-
'Dashboard failed to start: '
|
151
|
-
f'{common_utils.format_exception(e, use_bracket=True)}')
|
152
|
-
logger.error(msg)
|
153
|
-
raise fastapi.HTTPException(status_code=503, detail=msg)
|
154
|
-
dashboard_utils.add_dashboard_session(user_hash, port, pid)
|
155
|
-
|
156
|
-
# Assuming the dashboard is forwarded to localhost on the API server
|
157
|
-
dashboard_url = f'http://localhost:{port}'
|
158
|
-
try:
|
159
|
-
# Ping the dashboard to check if it's still running
|
160
|
-
async with httpx.AsyncClient() as client:
|
161
|
-
response = await client.request('GET',
|
162
|
-
dashboard_url,
|
163
|
-
timeout=5)
|
164
|
-
if response.is_success:
|
165
|
-
break # Connection successful, proceed with the request
|
166
|
-
# Raise an HTTPException here which will be caught by the
|
167
|
-
# following except block to retry with new connection
|
168
|
-
response.raise_for_status()
|
169
|
-
except Exception as e: # pylint: disable=broad-except
|
170
|
-
# We catch all exceptions to gracefully handle unknown
|
171
|
-
# errors and retry or raise an HTTPException to the client.
|
172
|
-
# Assume an exception indicates that the dashboard connection
|
173
|
-
# is stale - remove it so that a new one is created.
|
174
|
-
dashboard_utils.remove_dashboard_session(user_hash)
|
175
|
-
msg = (
|
176
|
-
f'Dashboard connection attempt {attempt + 1} failed with '
|
177
|
-
f'{common_utils.format_exception(e, use_bracket=True)}')
|
178
|
-
logger.info(msg)
|
179
|
-
if attempt == max_retries - 1:
|
180
|
-
raise fastapi.HTTPException(status_code=503, detail=msg)
|
181
|
-
|
182
|
-
# Create a client session to forward the request
|
183
|
-
try:
|
184
|
-
async with httpx.AsyncClient() as client:
|
185
|
-
# Make the request and get the response
|
186
|
-
response = await client.request(
|
187
|
-
method='GET',
|
188
|
-
url=f'{dashboard_url}',
|
189
|
-
headers=request.headers.raw,
|
190
|
-
)
|
191
|
-
|
192
|
-
# Create a new response with the content already read
|
193
|
-
content = await response.aread()
|
194
|
-
return fastapi.Response(
|
195
|
-
content=content,
|
196
|
-
status_code=response.status_code,
|
197
|
-
headers=dict(response.headers),
|
198
|
-
media_type=response.headers.get('content-type'))
|
199
|
-
except Exception as e:
|
200
|
-
msg = (f'Failed to forward request to dashboard: '
|
201
|
-
f'{common_utils.format_exception(e, use_bracket=True)}')
|
202
|
-
logger.error(msg)
|
203
|
-
raise fastapi.HTTPException(status_code=502, detail=msg)
|
sky/jobs/utils.py
CHANGED
@@ -1025,7 +1025,8 @@ def load_managed_job_queue(payload: str) -> List[Dict[str, Any]]:
|
|
1025
1025
|
if 'user_hash' in job and job['user_hash'] is not None:
|
1026
1026
|
# Skip jobs that do not have user_hash info.
|
1027
1027
|
# TODO(cooperc): Remove check before 0.12.0.
|
1028
|
-
|
1028
|
+
user = global_user_state.get_user(job['user_hash'])
|
1029
|
+
job['user_name'] = user.name if user is not None else None
|
1029
1030
|
return jobs
|
1030
1031
|
|
1031
1032
|
|
sky/models.py
CHANGED
@@ -2,8 +2,13 @@
|
|
2
2
|
|
3
3
|
import collections
|
4
4
|
import dataclasses
|
5
|
+
import getpass
|
6
|
+
import os
|
5
7
|
from typing import Any, Dict, Optional
|
6
8
|
|
9
|
+
from sky.skylet import constants
|
10
|
+
from sky.utils import common_utils
|
11
|
+
|
7
12
|
|
8
13
|
@dataclasses.dataclass
|
9
14
|
class User:
|
@@ -16,6 +21,19 @@ class User:
|
|
16
21
|
def to_dict(self) -> Dict[str, Any]:
|
17
22
|
return {'id': self.id, 'name': self.name}
|
18
23
|
|
24
|
+
def to_env_vars(self) -> Dict[str, Any]:
|
25
|
+
return {
|
26
|
+
constants.USER_ID_ENV_VAR: self.id,
|
27
|
+
constants.USER_ENV_VAR: self.name,
|
28
|
+
}
|
29
|
+
|
30
|
+
@classmethod
|
31
|
+
def get_current_user(cls) -> 'User':
|
32
|
+
"""Returns the current user."""
|
33
|
+
user_name = os.getenv(constants.USER_ENV_VAR, getpass.getuser())
|
34
|
+
user_hash = common_utils.get_user_hash()
|
35
|
+
return User(id=user_hash, name=user_name)
|
36
|
+
|
19
37
|
|
20
38
|
RealtimeGpuAvailability = collections.namedtuple(
|
21
39
|
'RealtimeGpuAvailability', ['gpu', 'counts', 'capacity', 'available'])
|
sky/serve/server/core.py
CHANGED
@@ -221,7 +221,7 @@ def up(
|
|
221
221
|
# for the first time; otherwise it is a name conflict.
|
222
222
|
# Since the controller may be shared among multiple users, launch the
|
223
223
|
# controller with the API server's user hash.
|
224
|
-
with common.
|
224
|
+
with common.with_server_user():
|
225
225
|
with skypilot_config.local_active_workspace_ctx(
|
226
226
|
constants.SKYPILOT_DEFAULT_WORKSPACE):
|
227
227
|
controller_job_id, controller_handle = execution.launch(
|
sky/server/common.py
CHANGED
@@ -39,6 +39,7 @@ if typing.TYPE_CHECKING:
|
|
39
39
|
import requests
|
40
40
|
|
41
41
|
from sky import dag as dag_lib
|
42
|
+
from sky import models
|
42
43
|
else:
|
43
44
|
pydantic = adaptors_common.LazyImport('pydantic')
|
44
45
|
requests = adaptors_common.LazyImport('requests')
|
@@ -710,7 +711,7 @@ def request_body_to_params(body: 'pydantic.BaseModel') -> Dict[str, Any]:
|
|
710
711
|
|
711
712
|
def reload_for_new_request(client_entrypoint: Optional[str],
|
712
713
|
client_command: Optional[str],
|
713
|
-
using_remote_api_server: bool):
|
714
|
+
using_remote_api_server: bool, user: 'models.User'):
|
714
715
|
"""Reload modules, global variables, and usage message for a new request."""
|
715
716
|
# This should be called first to make sure the logger is up-to-date.
|
716
717
|
sky_logging.reload_logger()
|
@@ -719,10 +720,11 @@ def reload_for_new_request(client_entrypoint: Optional[str],
|
|
719
720
|
skypilot_config.safe_reload_config()
|
720
721
|
|
721
722
|
# Reset the client entrypoint and command for the usage message.
|
722
|
-
common_utils.
|
723
|
+
common_utils.set_request_context(
|
723
724
|
client_entrypoint=client_entrypoint,
|
724
725
|
client_command=client_command,
|
725
726
|
using_remote_api_server=using_remote_api_server,
|
727
|
+
user=user,
|
726
728
|
)
|
727
729
|
|
728
730
|
# Clear cache should be called before reload_logger and usage reset,
|
sky/server/constants.py
CHANGED
@@ -11,8 +11,6 @@ API_VERSION = '9'
|
|
11
11
|
|
12
12
|
# Prefix for API request names.
|
13
13
|
REQUEST_NAME_PREFIX = 'sky.'
|
14
|
-
# The user ID of the SkyPilot system.
|
15
|
-
SKYPILOT_SYSTEM_USER_ID = 'skypilot-system'
|
16
14
|
# The memory (GB) that SkyPilot tries to not use to prevent OOM.
|
17
15
|
MIN_AVAIL_MEM_GB = 2
|
18
16
|
# Default encoder/decoder handler name.
|
sky/server/requests/executor.py
CHANGED
@@ -53,6 +53,7 @@ from sky.utils import context
|
|
53
53
|
from sky.utils import context_utils
|
54
54
|
from sky.utils import subprocess_utils
|
55
55
|
from sky.utils import timeline
|
56
|
+
from sky.workspaces import core as workspaces_core
|
56
57
|
|
57
58
|
if typing.TYPE_CHECKING:
|
58
59
|
import types
|
@@ -229,6 +230,9 @@ def override_request_env_and_config(
|
|
229
230
|
original_env = os.environ.copy()
|
230
231
|
os.environ.update(request_body.env_vars)
|
231
232
|
# Note: may be overridden by AuthProxyMiddleware.
|
233
|
+
# TODO(zhwu): we need to make the entire request a context available to the
|
234
|
+
# entire request execution, so that we can access info like user through
|
235
|
+
# the execution.
|
232
236
|
user = models.User(id=request_body.env_vars[constants.USER_ID_ENV_VAR],
|
233
237
|
name=request_body.env_vars[constants.USER_ENV_VAR])
|
234
238
|
global_user_state.add_or_update_user(user)
|
@@ -237,13 +241,17 @@ def override_request_env_and_config(
|
|
237
241
|
server_common.reload_for_new_request(
|
238
242
|
client_entrypoint=request_body.entrypoint,
|
239
243
|
client_command=request_body.entrypoint_command,
|
240
|
-
using_remote_api_server=request_body.using_remote_api_server
|
244
|
+
using_remote_api_server=request_body.using_remote_api_server,
|
245
|
+
user=user)
|
241
246
|
try:
|
242
247
|
logger.debug(
|
243
248
|
f'override path: {request_body.override_skypilot_config_path}')
|
244
249
|
with skypilot_config.override_skypilot_config(
|
245
250
|
request_body.override_skypilot_config,
|
246
251
|
request_body.override_skypilot_config_path):
|
252
|
+
# Rejecting requests to workspaces that the user does not have
|
253
|
+
# permission to access.
|
254
|
+
workspaces_core.reject_request_for_unauthorized_workspace(user)
|
247
255
|
yield
|
248
256
|
finally:
|
249
257
|
# We need to call the save_timeline() since atexit will not be
|
@@ -433,7 +441,7 @@ def prepare_request(
|
|
433
441
|
"""Prepare a request for execution."""
|
434
442
|
user_id = request_body.env_vars[constants.USER_ID_ENV_VAR]
|
435
443
|
if is_skypilot_system:
|
436
|
-
user_id =
|
444
|
+
user_id = constants.SKYPILOT_SYSTEM_USER_ID
|
437
445
|
global_user_state.add_or_update_user(
|
438
446
|
models.User(id=user_id, name=user_id))
|
439
447
|
request = api_requests.Request(request_id=request_id,
|
sky/server/requests/requests.py
CHANGED
@@ -11,7 +11,7 @@ import signal
|
|
11
11
|
import sqlite3
|
12
12
|
import time
|
13
13
|
import traceback
|
14
|
-
from typing import Any, Callable, Dict, List, Optional, Tuple
|
14
|
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
15
15
|
|
16
16
|
import colorama
|
17
17
|
import filelock
|
@@ -204,7 +204,8 @@ class Request:
|
|
204
204
|
"""
|
205
205
|
assert isinstance(self.request_body,
|
206
206
|
payloads.RequestBody), (self.name, self.request_body)
|
207
|
-
|
207
|
+
user = global_user_state.get_user(self.user_id)
|
208
|
+
user_name = user.name if user is not None else None
|
208
209
|
return RequestPayload(
|
209
210
|
request_id=self.request_id,
|
210
211
|
name=self.name,
|
@@ -464,7 +465,7 @@ def request_lock_path(request_id: str) -> str:
|
|
464
465
|
|
465
466
|
@contextlib.contextmanager
|
466
467
|
@init_db
|
467
|
-
def update_request(request_id: str):
|
468
|
+
def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
|
468
469
|
"""Get a SkyPilot API request."""
|
469
470
|
request = _get_request_no_lock(request_id)
|
470
471
|
yield request
|
sky/server/server.py
CHANGED
@@ -49,6 +49,7 @@ from sky.server.requests import preconditions
|
|
49
49
|
from sky.server.requests import requests as requests_lib
|
50
50
|
from sky.skylet import constants
|
51
51
|
from sky.usage import usage_lib
|
52
|
+
from sky.users import permission
|
52
53
|
from sky.users import server as users_rest
|
53
54
|
from sky.utils import admin_policy_utils
|
54
55
|
from sky.utils import common as common_lib
|
@@ -105,17 +106,21 @@ class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
105
106
|
"""Middleware to handle RBAC."""
|
106
107
|
|
107
108
|
async def dispatch(self, request: fastapi.Request, call_next):
|
108
|
-
|
109
|
+
# TODO(hailong): should have a list of paths
|
110
|
+
# that are not checked for RBAC
|
111
|
+
if (request.url.path.startswith('/dashboard/') or
|
112
|
+
request.url.path.startswith('/api/')):
|
109
113
|
return await call_next(request)
|
110
114
|
|
111
115
|
auth_user = _get_auth_user_header(request)
|
112
116
|
if auth_user is None:
|
113
117
|
return await call_next(request)
|
114
118
|
|
115
|
-
permission_service =
|
119
|
+
permission_service = permission.permission_service
|
116
120
|
# Check the role permission
|
117
|
-
if permission_service.
|
118
|
-
|
121
|
+
if permission_service.check_endpoint_permission(auth_user.id,
|
122
|
+
request.url.path,
|
123
|
+
request.method):
|
119
124
|
return fastapi.responses.JSONResponse(
|
120
125
|
status_code=403, content={'detail': 'Forbidden'})
|
121
126
|
|
@@ -154,9 +159,15 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
154
159
|
if auth_user is not None:
|
155
160
|
newly_added = global_user_state.add_or_update_user(auth_user)
|
156
161
|
if newly_added:
|
157
|
-
|
162
|
+
permission.permission_service.add_user_if_not_exists(
|
158
163
|
auth_user.id)
|
159
164
|
|
165
|
+
# Store user info in request.state for access by GET endpoints
|
166
|
+
if auth_user is not None:
|
167
|
+
request.state.auth_user = auth_user
|
168
|
+
else:
|
169
|
+
request.state.auth_user = None
|
170
|
+
|
160
171
|
body = await request.body()
|
161
172
|
if auth_user and body:
|
162
173
|
try:
|
@@ -177,6 +188,12 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
177
188
|
f'"env_vars" in request body is not a dictionary '
|
178
189
|
f'for request {request.state.request_id}. '
|
179
190
|
'Skipping user info injection into body.')
|
191
|
+
else:
|
192
|
+
original_json['env_vars'] = {}
|
193
|
+
original_json['env_vars'][
|
194
|
+
constants.USER_ID_ENV_VAR] = auth_user.id
|
195
|
+
original_json['env_vars'][
|
196
|
+
constants.USER_ENV_VAR] = auth_user.name
|
180
197
|
request._body = json.dumps(original_json).encode('utf-8') # pylint: disable=protected-access
|
181
198
|
return await call_next(request)
|
182
199
|
|
sky/skylet/constants.py
CHANGED
@@ -419,3 +419,6 @@ ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
|
|
419
419
|
'kubernetes', 'runpod', 'vast', 'vsphere', 'cudo', 'fluidstack',
|
420
420
|
'paperspace', 'do', 'nebius', 'ssh')
|
421
421
|
# END constants used for service catalog.
|
422
|
+
|
423
|
+
# The user ID of the SkyPilot system.
|
424
|
+
SKYPILOT_SYSTEM_USER_ID = 'skypilot-system'
|
sky/skylet/job_lib.py
CHANGED
@@ -794,7 +794,8 @@ def load_job_queue(payload: str) -> List[Dict[str, Any]]:
|
|
794
794
|
for job in jobs:
|
795
795
|
job['status'] = JobStatus(job['status'])
|
796
796
|
job['user_hash'] = job['username']
|
797
|
-
|
797
|
+
user = global_user_state.get_user(job['user_hash'])
|
798
|
+
job['username'] = user.name if user is not None else None
|
798
799
|
return jobs
|
799
800
|
|
800
801
|
|
sky/skypilot_config.py
CHANGED
@@ -765,6 +765,15 @@ def update_api_server_config_no_lock(config: config_utils.Config) -> None:
|
|
765
765
|
Args:
|
766
766
|
config: The config to save and sync.
|
767
767
|
"""
|
768
|
+
|
769
|
+
def is_running_pytest() -> bool:
|
770
|
+
return 'PYTEST_CURRENT_TEST' in os.environ
|
771
|
+
|
772
|
+
# Only allow this function to be called by the API Server in production.
|
773
|
+
if not is_running_pytest() and os.environ.get(
|
774
|
+
constants.ENV_VAR_IS_SKYPILOT_SERVER) is None:
|
775
|
+
raise ValueError('This function can only be called by the API Server.')
|
776
|
+
|
768
777
|
global_config_path = _resolve_server_config_path()
|
769
778
|
if global_config_path is None:
|
770
779
|
global_config_path = get_user_config_path()
|
sky/users/model.conf
CHANGED
sky/users/permission.py
CHANGED
@@ -3,7 +3,7 @@ import contextlib
|
|
3
3
|
import logging
|
4
4
|
import os
|
5
5
|
import threading
|
6
|
-
from typing import List
|
6
|
+
from typing import Generator, List
|
7
7
|
|
8
8
|
import casbin
|
9
9
|
import filelock
|
@@ -11,8 +11,11 @@ import sqlalchemy_adapter
|
|
11
11
|
|
12
12
|
from sky import global_user_state
|
13
13
|
from sky import sky_logging
|
14
|
+
from sky.skylet import constants
|
14
15
|
from sky.users import rbac
|
15
16
|
|
17
|
+
logging.getLogger('casbin.policy').setLevel(sky_logging.ERROR)
|
18
|
+
logging.getLogger('casbin.role').setLevel(sky_logging.ERROR)
|
16
19
|
logger = sky_logging.init_logger(__name__)
|
17
20
|
|
18
21
|
# Filelocks for the policy update.
|
@@ -38,17 +41,19 @@ class PermissionService:
|
|
38
41
|
model_path = os.path.join(os.path.dirname(__file__),
|
39
42
|
'model.conf')
|
40
43
|
enforcer = casbin.Enforcer(model_path, adapter)
|
41
|
-
logging.getLogger('casbin.policy').setLevel(
|
42
|
-
sky_logging.ERROR)
|
43
|
-
logging.getLogger('casbin.role').setLevel(sky_logging.ERROR)
|
44
44
|
self.enforcer = enforcer
|
45
45
|
else:
|
46
46
|
self.enforcer = _enforcer_instance.enforcer
|
47
|
-
|
47
|
+
with _policy_lock():
|
48
|
+
self._maybe_initialize_policies()
|
48
49
|
|
49
|
-
def _maybe_initialize_policies(self):
|
50
|
+
def _maybe_initialize_policies(self) -> None:
|
50
51
|
"""Initialize policies if they don't already exist."""
|
52
|
+
# TODO(zhwu): we should avoid running this on client side.
|
51
53
|
logger.debug(f'Initializing policies in process: {os.getpid()}')
|
54
|
+
self._load_policy_no_lock()
|
55
|
+
|
56
|
+
policy_updated = False
|
52
57
|
|
53
58
|
# Check if policies are already initialized by looking for existing
|
54
59
|
# permission policies in the enforcer
|
@@ -66,6 +71,17 @@ class PermissionService:
|
|
66
71
|
expected_policies.append(
|
67
72
|
[role, item['path'], item['method']])
|
68
73
|
|
74
|
+
# Add workspace policy
|
75
|
+
workspace_policy_permissions = rbac.get_workspace_policy_permissions()
|
76
|
+
logger.debug(f'Workspace policy permissions from config: '
|
77
|
+
f'{workspace_policy_permissions}')
|
78
|
+
|
79
|
+
for workspace_name, users in workspace_policy_permissions.items():
|
80
|
+
for user in users:
|
81
|
+
expected_policies.append([user, workspace_name, '*'])
|
82
|
+
logger.debug(f'Expected workspace policy: user={user}, '
|
83
|
+
f'workspace={workspace_name}')
|
84
|
+
|
69
85
|
# Check if all expected policies already exist
|
70
86
|
policies_exist = all(
|
71
87
|
any(policy == expected
|
@@ -86,48 +102,71 @@ class PermissionService:
|
|
86
102
|
for item in blocklist:
|
87
103
|
path = item['path']
|
88
104
|
method = item['method']
|
105
|
+
logger.debug(f'Adding role policy: role={role}, '
|
106
|
+
f'path={path}, method={method}')
|
89
107
|
self.enforcer.add_policy(role, path, method)
|
90
|
-
|
108
|
+
policy_updated = True
|
109
|
+
|
110
|
+
for workspace_name, users in workspace_policy_permissions.items():
|
111
|
+
for user in users:
|
112
|
+
logger.debug(f'Initializing workspace policy: user={user}, '
|
113
|
+
f'workspace={workspace_name}')
|
114
|
+
self.enforcer.add_policy(user, workspace_name, '*')
|
115
|
+
policy_updated = True
|
116
|
+
logger.debug('Policies initialized successfully')
|
91
117
|
else:
|
92
118
|
logger.debug('Policies already exist, skipping initialization')
|
93
119
|
|
94
120
|
# Always ensure users have default roles (this is idempotent)
|
95
121
|
all_users = global_user_state.get_all_users()
|
96
|
-
for
|
97
|
-
self.
|
122
|
+
for existing_user in all_users:
|
123
|
+
user_added = self._add_user_if_not_exists_no_lock(existing_user.id)
|
124
|
+
policy_updated = policy_updated or user_added
|
125
|
+
|
126
|
+
if policy_updated:
|
127
|
+
self.enforcer.save_policy()
|
98
128
|
|
99
|
-
def add_user_if_not_exists(self,
|
129
|
+
def add_user_if_not_exists(self, user_id: str) -> None:
|
100
130
|
"""Add user role relationship."""
|
101
131
|
with _policy_lock():
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
132
|
+
self._add_user_if_not_exists_no_lock(user_id)
|
133
|
+
|
134
|
+
def _add_user_if_not_exists_no_lock(self, user_id: str) -> bool:
|
135
|
+
"""Add user role relationship without lock.
|
136
|
+
|
137
|
+
Returns:
|
138
|
+
True if the user was added, False otherwise.
|
139
|
+
"""
|
140
|
+
user_roles = self.enforcer.get_roles_for_user(user_id)
|
141
|
+
if not user_roles:
|
142
|
+
logger.info(f'User {user_id} has no roles, adding'
|
143
|
+
f' default role {rbac.get_default_role()}')
|
144
|
+
self.enforcer.add_grouping_policy(user_id, rbac.get_default_role())
|
145
|
+
return True
|
146
|
+
return False
|
147
|
+
|
148
|
+
def update_role(self, user_id: str, new_role: str) -> None:
|
110
149
|
"""Update user role relationship."""
|
111
150
|
with _policy_lock():
|
112
151
|
# Get current roles
|
113
152
|
self._load_policy_no_lock()
|
114
153
|
# Avoid calling get_user_roles, as it will require the lock.
|
115
|
-
current_roles = self.enforcer.get_roles_for_user(
|
154
|
+
current_roles = self.enforcer.get_roles_for_user(user_id)
|
116
155
|
if not current_roles:
|
117
|
-
logger.warning(f'User {
|
156
|
+
logger.warning(f'User {user_id} has no roles')
|
118
157
|
else:
|
119
158
|
# TODO(hailong): how to handle multiple roles?
|
120
159
|
current_role = current_roles[0]
|
121
160
|
if current_role == new_role:
|
122
|
-
logger.info(f'User {
|
161
|
+
logger.info(f'User {user_id} already has role {new_role}')
|
123
162
|
return
|
124
|
-
self.enforcer.remove_grouping_policy(
|
163
|
+
self.enforcer.remove_grouping_policy(user_id, current_role)
|
125
164
|
|
126
165
|
# Update user role
|
127
|
-
self.enforcer.add_grouping_policy(
|
166
|
+
self.enforcer.add_grouping_policy(user_id, new_role)
|
128
167
|
self.enforcer.save_policy()
|
129
168
|
|
130
|
-
def get_user_roles(self,
|
169
|
+
def get_user_roles(self, user_id: str) -> List[str]:
|
131
170
|
"""Get all roles for a user.
|
132
171
|
|
133
172
|
This method returns all roles that the user has, including inherited
|
@@ -140,10 +179,11 @@ class PermissionService:
|
|
140
179
|
Returns:
|
141
180
|
A list of role names that the user has.
|
142
181
|
"""
|
143
|
-
self.
|
144
|
-
return self.enforcer.get_roles_for_user(
|
182
|
+
self._load_policy_no_lock()
|
183
|
+
return self.enforcer.get_roles_for_user(user_id)
|
145
184
|
|
146
|
-
def
|
185
|
+
def check_endpoint_permission(self, user_id: str, path: str,
|
186
|
+
method: str) -> bool:
|
147
187
|
"""Check permission."""
|
148
188
|
# We intentionally don't load the policy here, as it is a hot path, and
|
149
189
|
# we don't support updating the policy.
|
@@ -151,28 +191,105 @@ class PermissionService:
|
|
151
191
|
# it is a hot path in every request. It is ok to have a stale policy,
|
152
192
|
# as long as it is eventually consistent.
|
153
193
|
# self._load_policy_no_lock()
|
154
|
-
return self.enforcer.enforce(
|
194
|
+
return self.enforcer.enforce(user_id, path, method)
|
155
195
|
|
156
196
|
def _load_policy_no_lock(self):
|
157
197
|
"""Load policy from storage."""
|
158
198
|
self.enforcer.load_policy()
|
159
199
|
|
160
|
-
def
|
200
|
+
def load_policy(self):
|
161
201
|
"""Load policy from storage with lock."""
|
162
202
|
with _policy_lock():
|
163
203
|
self._load_policy_no_lock()
|
164
204
|
|
205
|
+
def check_workspace_permission(self, user_id: str,
|
206
|
+
workspace_name: str) -> bool:
|
207
|
+
"""Check workspace permission.
|
208
|
+
|
209
|
+
This method checks if a user has permission to access a specific
|
210
|
+
workspace.
|
211
|
+
|
212
|
+
For private workspaces, the user must have explicit permission.
|
213
|
+
|
214
|
+
For public workspaces, the permission is granted via a wildcard policy
|
215
|
+
('*').
|
216
|
+
"""
|
217
|
+
if os.getenv(constants.ENV_VAR_IS_SKYPILOT_SERVER) is None:
|
218
|
+
# When it is not on API server, we allow all users to access all
|
219
|
+
# workspaces, as the workspace check has been done on API server.
|
220
|
+
return True
|
221
|
+
role = self.get_user_roles(user_id)
|
222
|
+
if rbac.RoleName.ADMIN.value in role:
|
223
|
+
return True
|
224
|
+
# The Casbin model matcher already handles the wildcard '*' case:
|
225
|
+
# m = (g(r.sub, p.sub)|| p.sub == '*') && r.obj == p.obj &&
|
226
|
+
# r.act == p.act
|
227
|
+
# This means if there's a policy ('*', workspace_name, '*'), it will
|
228
|
+
# match any user
|
229
|
+
result = self.enforcer.enforce(user_id, workspace_name, '*')
|
230
|
+
logger.debug(f'Workspace permission check: user={user_id}, '
|
231
|
+
f'workspace={workspace_name}, result={result}')
|
232
|
+
return result
|
233
|
+
|
234
|
+
def add_workspace_policy(self, workspace_name: str,
|
235
|
+
users: List[str]) -> None:
|
236
|
+
"""Add workspace policy.
|
237
|
+
|
238
|
+
Args:
|
239
|
+
workspace_name: Name of the workspace
|
240
|
+
users: List of user IDs that should have access.
|
241
|
+
For public workspaces, this should be ['*'].
|
242
|
+
For private workspaces, this should be specific user IDs.
|
243
|
+
"""
|
244
|
+
with _policy_lock():
|
245
|
+
for user in users:
|
246
|
+
logger.debug(f'Adding workspace policy: user={user}, '
|
247
|
+
f'workspace={workspace_name}')
|
248
|
+
self.enforcer.add_policy(user, workspace_name, '*')
|
249
|
+
self.enforcer.save_policy()
|
250
|
+
|
251
|
+
def update_workspace_policy(self, workspace_name: str,
|
252
|
+
users: List[str]) -> None:
|
253
|
+
"""Update workspace policy.
|
254
|
+
|
255
|
+
Args:
|
256
|
+
workspace_name: Name of the workspace
|
257
|
+
users: List of user IDs that should have access.
|
258
|
+
For public workspaces, this should be ['*'].
|
259
|
+
For private workspaces, this should be specific user IDs.
|
260
|
+
"""
|
261
|
+
with _policy_lock():
|
262
|
+
self._load_policy_no_lock()
|
263
|
+
# Remove all existing policies for this workspace
|
264
|
+
self.enforcer.remove_filtered_policy(1, workspace_name)
|
265
|
+
# Add new policies
|
266
|
+
for user in users:
|
267
|
+
logger.debug(f'Updating workspace policy: user={user}, '
|
268
|
+
f'workspace={workspace_name}')
|
269
|
+
self.enforcer.add_policy(user, workspace_name, '*')
|
270
|
+
self.enforcer.save_policy()
|
271
|
+
|
272
|
+
def remove_workspace_policy(self, workspace_name: str) -> None:
|
273
|
+
"""Remove workspace policy."""
|
274
|
+
with _policy_lock():
|
275
|
+
self.enforcer.remove_filtered_policy(1, workspace_name)
|
276
|
+
self.enforcer.save_policy()
|
277
|
+
|
165
278
|
|
166
279
|
@contextlib.contextmanager
|
167
|
-
def _policy_lock():
|
280
|
+
def _policy_lock() -> Generator[None, None, None]:
|
168
281
|
"""Context manager for policy update lock."""
|
169
282
|
try:
|
170
283
|
with filelock.FileLock(POLICY_UPDATE_LOCK_PATH,
|
171
284
|
POLICY_UPDATE_LOCK_TIMEOUT_SECONDS):
|
172
285
|
yield
|
173
286
|
except filelock.Timeout as e:
|
174
|
-
raise RuntimeError(f'Failed to
|
287
|
+
raise RuntimeError(f'Failed to reload policy due to a timeout '
|
175
288
|
f'when trying to acquire the lock at '
|
176
289
|
f'{POLICY_UPDATE_LOCK_PATH}. '
|
177
290
|
'Please try again or manually remove the lock '
|
178
291
|
f'file if you believe it is stale.') from e
|
292
|
+
|
293
|
+
|
294
|
+
# Singleton instance of PermissionService for other modules to use.
|
295
|
+
permission_service = PermissionService()
|