skypilot-nightly 1.0.0.dev20250604__py3-none-any.whl → 1.0.0.dev20250606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/admin_policy.py +5 -0
- sky/catalog/__init__.py +2 -2
- sky/catalog/common.py +7 -9
- sky/cli.py +11 -9
- sky/client/cli.py +11 -9
- sky/client/sdk.py +30 -12
- sky/clouds/kubernetes.py +2 -2
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/99m-BAySO8Q7J-ul1jZVL/_buildManifest.js +1 -0
- sky/dashboard/out/_next/static/chunks/{236-fef38aa6e5639300.js → 236-a90f0a9753a10420.js} +2 -2
- sky/dashboard/out/_next/static/chunks/614-635a84e87800f99e.js +66 -0
- sky/dashboard/out/_next/static/chunks/{856-f1b1f7f47edde2e8.js → 856-3a32da4b84176f6d.js} +1 -1
- sky/dashboard/out/_next/static/chunks/937.3759f538f11a0953.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-35cbeb5214fd4036.js +6 -0
- sky/dashboard/out/_next/static/chunks/pages/config-1a1eeb949dab8897.js +6 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-2d23a9c7571e6320.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/users-262aab38b9baaf3a.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/workspaces-384ea5fa0cea8f28.js +1 -0
- sky/dashboard/out/_next/static/chunks/{webpack-f27c9a32aa3d9c6d.js → webpack-65d465f948974c0d.js} +1 -1
- sky/dashboard/out/_next/static/css/667d941a2888ce6e.css +3 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/execution.py +44 -46
- sky/global_user_state.py +118 -83
- sky/jobs/client/sdk.py +4 -1
- sky/jobs/server/core.py +5 -1
- sky/models.py +1 -0
- sky/resources.py +22 -1
- sky/serve/load_balancer.py +56 -45
- sky/server/constants.py +3 -1
- sky/server/requests/payloads.py +9 -0
- sky/server/server.py +30 -9
- sky/setup_files/MANIFEST.in +1 -0
- sky/setup_files/dependencies.py +2 -0
- sky/skylet/constants.py +10 -4
- sky/skypilot_config.py +4 -2
- sky/templates/websocket_proxy.py +11 -1
- sky/users/__init__.py +0 -0
- sky/users/model.conf +15 -0
- sky/users/permission.py +178 -0
- sky/users/rbac.py +86 -0
- sky/users/server.py +66 -0
- sky/utils/schemas.py +20 -7
- sky/workspaces/core.py +2 -2
- {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/METADATA +3 -1
- {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/RECORD +70 -66
- sky/catalog/constants.py +0 -8
- sky/dashboard/out/_next/static/chunks/614-3d29f98e0634b179.js +0 -66
- sky/dashboard/out/_next/static/chunks/937.f97f83652028e944.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-62c9982dc3675725.js +0 -6
- sky/dashboard/out/_next/static/chunks/pages/config-35383adcb0edb5e2.js +0 -6
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-a62a3c65dc9bc57c.js +0 -11
- sky/dashboard/out/_next/static/chunks/pages/users-07b523ccb19317ad.js +0 -6
- sky/dashboard/out/_next/static/chunks/pages/workspaces-f54921ec9eb20965.js +0 -1
- sky/dashboard/out/_next/static/css/63d3995d8b528eb1.css +0 -3
- sky/dashboard/out/_next/static/vWwfD3jOky5J5jULHp8JT/_buildManifest.js +0 -1
- /sky/dashboard/out/_next/static/{vWwfD3jOky5J5jULHp8JT → 99m-BAySO8Q7J-ul1jZVL}/_ssgManifest.js +0 -0
- /sky/dashboard/out/_next/static/chunks/{121-8f55ee3fa6301784.js → 121-865d2bf8a3b84c6a.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/{37-947904ccc5687bac.js → 37-beedd583fea84cc8.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/{682-2be9b0f169727f2f.js → 682-6647f0417d5662f0.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/{843-a097338acb89b7d7.js → 843-c296541442d4af88.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/{969-d7b6fb7f602bfcb3.js → 969-c7abda31c10440ac.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/pages/{_app-67925f5e6382e22f.js → _app-cb81dc4d27f4d009.js} +0 -0
- /sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/{[job]-158b70da336d8607.js → [job]-65d04d5d77cbb6b6.js} +0 -0
- {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/top_level.txt +0 -0
sky/serve/load_balancer.py
CHANGED
@@ -2,7 +2,8 @@
|
|
2
2
|
import asyncio
|
3
3
|
import logging
|
4
4
|
import threading
|
5
|
-
|
5
|
+
import traceback
|
6
|
+
from typing import Dict, List, Optional, Union
|
6
7
|
|
7
8
|
import aiohttp
|
8
9
|
import fastapi
|
@@ -69,6 +70,48 @@ class SkyServeLoadBalancer:
|
|
69
70
|
# updating it from _sync_with_controller.
|
70
71
|
self._client_pool_lock: threading.Lock = threading.Lock()
|
71
72
|
|
73
|
+
async def _sync_with_controller_once(self) -> List[asyncio.Task]:
    """Run one sync round with the controller.

    Posts the aggregated request information to the controller's
    ``/controller/load_balancer_sync`` endpoint, reads back the list of
    ready replica URLs, and reconciles the HTTP client pool with it.

    Returns:
        The awaitables produced by ``aclose()`` on every client that was
        dropped from the pool. They are not awaited here; the caller is
        expected to await them. The list is empty when the sync request
        failed with a client error or a timeout.
    """
    pending_closes = []
    async with aiohttp.ClientSession() as session:
        try:
            # Send request information
            sync_url = (self._controller_url +
                        '/controller/load_balancer_sync')
            payload = {
                'request_aggregator': self._request_aggregator.to_dict()
            }
            async with session.post(
                    sync_url,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(5),
            ) as response:
                # Clean up after reporting request info to avoid OOM.
                self._request_aggregator.clear()
                response.raise_for_status()
                body = await response.json()
                ready_replica_urls = body.get('ready_replica_urls', [])
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.error(f'An error occurred when syncing with '
                         f'the controller: {e}'
                         f'\nTraceback: {traceback.format_exc()}')
        else:
            logger.info(f'Available Replica URLs: {ready_replica_urls}')
            with self._client_pool_lock:
                self._load_balancing_policy.set_ready_replicas(
                    ready_replica_urls)
                # Make sure every ready replica has a client.
                for url in ready_replica_urls:
                    if url not in self._client_pool:
                        self._client_pool[url] = httpx.AsyncClient(
                            base_url=url)
                # Detach clients whose replica is no longer ready. Only the
                # pool mutation happens under the lock.
                stale_urls = set(
                    self._client_pool.keys()) - set(ready_replica_urls)
                stale_clients = [
                    self._client_pool.pop(url) for url in stale_urls
                ]
            pending_closes.extend(
                client.aclose() for client in stale_clients)
    return pending_closes
|
114
|
+
|
72
115
|
async def _sync_with_controller(self):
|
73
116
|
"""Sync with controller periodically.
|
74
117
|
|
@@ -82,49 +125,16 @@ class SkyServeLoadBalancer:
|
|
82
125
|
await asyncio.sleep(5)
|
83
126
|
|
84
127
|
while True:
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
},
|
96
|
-
timeout=aiohttp.ClientTimeout(5),
|
97
|
-
) as response:
|
98
|
-
# Clean up after reporting request info to avoid OOM.
|
99
|
-
self._request_aggregator.clear()
|
100
|
-
response.raise_for_status()
|
101
|
-
response_json = await response.json()
|
102
|
-
ready_replica_urls = response_json.get(
|
103
|
-
'ready_replica_urls', [])
|
104
|
-
except aiohttp.ClientError as e:
|
105
|
-
logger.error('An error occurred when syncing with '
|
106
|
-
f'the controller: {e}')
|
107
|
-
else:
|
108
|
-
logger.info(f'Available Replica URLs: {ready_replica_urls}')
|
109
|
-
with self._client_pool_lock:
|
110
|
-
self._load_balancing_policy.set_ready_replicas(
|
111
|
-
ready_replica_urls)
|
112
|
-
for replica_url in ready_replica_urls:
|
113
|
-
if replica_url not in self._client_pool:
|
114
|
-
self._client_pool[replica_url] = (
|
115
|
-
httpx.AsyncClient(base_url=replica_url))
|
116
|
-
urls_to_close = set(
|
117
|
-
self._client_pool.keys()) - set(ready_replica_urls)
|
118
|
-
client_to_close = []
|
119
|
-
for replica_url in urls_to_close:
|
120
|
-
client_to_close.append(
|
121
|
-
self._client_pool.pop(replica_url))
|
122
|
-
for client in client_to_close:
|
123
|
-
close_client_tasks.append(client.aclose())
|
124
|
-
|
125
|
-
await asyncio.sleep(constants.LB_CONTROLLER_SYNC_INTERVAL_SECONDS)
|
126
|
-
# Await those tasks after the interval to avoid blocking.
|
127
|
-
await asyncio.gather(*close_client_tasks)
|
128
|
+
try:
|
129
|
+
close_client_tasks = await self._sync_with_controller_once()
|
130
|
+
await asyncio.sleep(
|
131
|
+
constants.LB_CONTROLLER_SYNC_INTERVAL_SECONDS)
|
132
|
+
# Await those tasks after the interval to avoid blocking.
|
133
|
+
await asyncio.gather(*close_client_tasks)
|
134
|
+
except Exception as e: # pylint: disable=broad-except
|
135
|
+
logger.error(f'An error occurred when syncing with '
|
136
|
+
f'the controller: {e}'
|
137
|
+
f'\nTraceback: {traceback.format_exc()}')
|
128
138
|
|
129
139
|
async def _proxy_request_to(
|
130
140
|
self, url: str, request: fastapi.Request
|
@@ -168,7 +178,8 @@ class SkyServeLoadBalancer:
|
|
168
178
|
background=background.BackgroundTask(background_func))
|
169
179
|
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
170
180
|
logger.error(f'Error when proxy request to {url}: '
|
171
|
-
f'{common_utils.format_exception(e)}'
|
181
|
+
f'{common_utils.format_exception(e)}'
|
182
|
+
f'\nTraceback: {traceback.format_exc()}')
|
172
183
|
return e
|
173
184
|
|
174
185
|
async def _proxy_with_retries(
|
sky/server/constants.py
CHANGED
@@ -7,7 +7,7 @@ from sky.skylet import constants
|
|
7
7
|
# API server version, whenever there is a change in API server that requires a
|
8
8
|
# restart of the local API server or error out when the client does not match
|
9
9
|
# the server version.
|
10
|
-
API_VERSION = '
|
10
|
+
API_VERSION = '9'
|
11
11
|
|
12
12
|
# Prefix for API request names.
|
13
13
|
REQUEST_NAME_PREFIX = 'sky.'
|
@@ -25,8 +25,10 @@ API_SERVER_REQUEST_DB_PATH = '~/.sky/api_server/requests.db'
|
|
25
25
|
CLUSTER_REFRESH_DAEMON_INTERVAL_SECONDS = 60
|
26
26
|
|
27
27
|
# Environment variable for a file path to the API cookie file.
|
28
|
+
# Keep in sync with websocket_proxy.py
|
28
29
|
API_COOKIE_FILE_ENV_VAR = f'{constants.SKYPILOT_ENV_VAR_PREFIX}API_COOKIE_FILE'
|
29
30
|
# Default file if unset.
|
31
|
+
# Keep in sync with websocket_proxy.py
|
30
32
|
API_COOKIE_FILE_DEFAULT_LOCATION = '~/.sky/cookies.txt'
|
31
33
|
|
32
34
|
# The path to the dashboard build output
|
sky/server/requests/payloads.py
CHANGED
@@ -196,8 +196,10 @@ class LaunchBody(RequestBody):
|
|
196
196
|
task: str
|
197
197
|
cluster_name: str
|
198
198
|
retry_until_up: bool = False
|
199
|
+
# TODO(aylei): remove this field in v0.12.0
|
199
200
|
idle_minutes_to_autostop: Optional[int] = None
|
200
201
|
dryrun: bool = False
|
202
|
+
# TODO(aylei): remove this field in v0.12.0
|
201
203
|
down: bool = False
|
202
204
|
backend: Optional[str] = None
|
203
205
|
optimize_target: common_lib.OptimizeTarget = common_lib.OptimizeTarget.COST
|
@@ -331,6 +333,12 @@ class ClusterJobsDownloadLogsBody(RequestBody):
|
|
331
333
|
local_dir: str = constants.SKY_LOGS_DIRECTORY
|
332
334
|
|
333
335
|
|
336
|
+
class UserUpdateBody(RequestBody):
    """The request body for the user update endpoint."""
    # ID of the user to update.
    user_id: str
    # NOTE(review): presumably the role name to assign to the user (one of
    # the values supported by sky/users/rbac.py) -- confirm against the
    # endpoint handler.
    role: str
|
340
|
+
|
341
|
+
|
334
342
|
class DownloadBody(RequestBody):
|
335
343
|
"""The request body for the download endpoint."""
|
336
344
|
folder_paths: List[str]
|
@@ -375,6 +383,7 @@ class JobsQueueBody(RequestBody):
|
|
375
383
|
refresh: bool = False
|
376
384
|
skip_finished: bool = False
|
377
385
|
all_users: bool = False
|
386
|
+
job_ids: Optional[List[int]] = None
|
378
387
|
|
379
388
|
|
380
389
|
class JobsCancelBody(RequestBody):
|
sky/server/server.py
CHANGED
@@ -49,6 +49,7 @@ from sky.server.requests import preconditions
|
|
49
49
|
from sky.server.requests import requests as requests_lib
|
50
50
|
from sky.skylet import constants
|
51
51
|
from sky.usage import usage_lib
|
52
|
+
from sky.users import server as users_rest
|
52
53
|
from sky.utils import admin_policy_utils
|
53
54
|
from sky.utils import common as common_lib
|
54
55
|
from sky.utils import common_utils
|
@@ -100,6 +101,27 @@ logger = sky_logging.init_logger(__name__)
|
|
100
101
|
# response will block other requests from being processed.
|
101
102
|
|
102
103
|
|
104
|
+
class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to handle RBAC.

    Requests under ``/dashboard/`` and requests without an authenticated
    user header are passed through unchecked. For every other request the
    permission service is consulted; a positive match is answered with
    HTTP 403 instead of being forwarded.
    """

    async def dispatch(self, request: fastapi.Request, call_next):
        if not request.url.path.startswith('/dashboard/'):
            auth_user = _get_auth_user_header(request)
            if auth_user is not None:
                # Check the role permission. A True result here leads to a
                # rejection, i.e. the check is used as a blocklist match.
                blocked = users_rest.permission_service.check_permission(
                    auth_user.id, request.url.path, request.method)
                if blocked:
                    return fastapi.responses.JSONResponse(
                        status_code=403, content={'detail': 'Forbidden'})

        return await call_next(request)
|
123
|
+
|
124
|
+
|
103
125
|
class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
104
126
|
"""Middleware to add a request ID to each request."""
|
105
127
|
|
@@ -130,7 +152,10 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
130
152
|
|
131
153
|
# Add user to database if auth_user is present
|
132
154
|
if auth_user is not None:
|
133
|
-
global_user_state.add_or_update_user(auth_user)
|
155
|
+
newly_added = global_user_state.add_or_update_user(auth_user)
|
156
|
+
if newly_added:
|
157
|
+
users_rest.permission_service.add_user_if_not_exists(
|
158
|
+
auth_user.id)
|
134
159
|
|
135
160
|
body = await request.body()
|
136
161
|
if auth_user and body:
|
@@ -244,11 +269,13 @@ class PathCleanMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
|
|
244
269
|
parent = pathlib.Path('/dashboard')
|
245
270
|
request_path = pathlib.Path(posixpath.normpath(request.url.path))
|
246
271
|
if not _is_relative_to(request_path, parent):
|
247
|
-
|
272
|
+
return fastapi.responses.JSONResponse(
|
273
|
+
status_code=403, content={'detail': 'Forbidden'})
|
248
274
|
return await call_next(request)
|
249
275
|
|
250
276
|
|
251
277
|
app = fastapi.FastAPI(prefix='/api/v1', debug=True, lifespan=lifespan)
|
278
|
+
app.add_middleware(RBACMiddleware)
|
252
279
|
app.add_middleware(InternalDashboardPrefixMiddleware)
|
253
280
|
app.add_middleware(PathCleanMiddleware)
|
254
281
|
app.add_middleware(CacheControlStaticMiddleware)
|
@@ -266,6 +293,7 @@ app.add_middleware(AuthProxyMiddleware)
|
|
266
293
|
app.add_middleware(RequestIDMiddleware)
|
267
294
|
app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
|
268
295
|
app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
|
296
|
+
app.include_router(users_rest.router, prefix='/users', tags=['users'])
|
269
297
|
app.include_router(workspaces_rest.router,
|
270
298
|
prefix='/workspaces',
|
271
299
|
tags=['workspaces'])
|
@@ -835,13 +863,6 @@ async def logs(
|
|
835
863
|
)
|
836
864
|
|
837
865
|
|
838
|
-
@app.get('/users')
|
839
|
-
async def users() -> List[Dict[str, Any]]:
|
840
|
-
"""Gets all users."""
|
841
|
-
user_list = global_user_state.get_all_users()
|
842
|
-
return [user.to_dict() for user in user_list]
|
843
|
-
|
844
|
-
|
845
866
|
@app.post('/download_logs')
|
846
867
|
async def download_logs(
|
847
868
|
request: fastapi.Request,
|
sky/setup_files/MANIFEST.in
CHANGED
sky/setup_files/dependencies.py
CHANGED
sky/skylet/constants.py
CHANGED
@@ -379,7 +379,7 @@ OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
|
|
379
379
|
SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [('admin_policy',),
|
380
380
|
('api_server',),
|
381
381
|
('allowed_clouds',),
|
382
|
-
('workspaces',)]
|
382
|
+
('workspaces',), ('db',)]
|
383
383
|
|
384
384
|
# Constants for Azure blob storage
|
385
385
|
WAIT_FOR_STORAGE_ACCOUNT_CREATION = 60
|
@@ -409,6 +409,12 @@ ENV_VAR_IS_SKYPILOT_SERVER = 'IS_SKYPILOT_SERVER'
|
|
409
409
|
|
410
410
|
SKYPILOT_DEFAULT_WORKSPACE = 'default'
|
411
411
|
|
412
|
-
#
|
413
|
-
|
414
|
-
|
412
|
+
# BEGIN constants used for service catalog.
|
413
|
+
HOSTED_CATALOG_DIR_URL = 'https://raw.githubusercontent.com/skypilot-org/skypilot-catalog/master/catalogs' # pylint: disable=line-too-long
|
414
|
+
HOSTED_CATALOG_DIR_URL_S3_MIRROR = 'https://skypilot-catalog.s3.us-east-1.amazonaws.com/catalogs' # pylint: disable=line-too-long
|
415
|
+
CATALOG_SCHEMA_VERSION = 'v7'
|
416
|
+
CATALOG_DIR = '~/.sky/catalogs'
|
417
|
+
ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
|
418
|
+
'kubernetes', 'runpod', 'vast', 'vsphere', 'cudo', 'fluidstack',
|
419
|
+
'paperspace', 'do', 'nebius', 'ssh')
|
420
|
+
# END constants used for service catalog.
|
sky/skypilot_config.py
CHANGED
@@ -756,13 +756,15 @@ def apply_cli_config(cli_config: Optional[List[str]]) -> Dict[str, Any]:
|
|
756
756
|
return parsed_config
|
757
757
|
|
758
758
|
|
759
|
-
def
|
759
|
+
def update_api_server_config_no_lock(config: config_utils.Config) -> None:
|
760
760
|
"""Dumps the new config to a file and syncs to ConfigMap if in Kubernetes.
|
761
761
|
|
762
762
|
Args:
|
763
763
|
config: The config to save and sync.
|
764
764
|
"""
|
765
|
-
global_config_path =
|
765
|
+
global_config_path = _resolve_server_config_path()
|
766
|
+
if global_config_path is None:
|
767
|
+
global_config_path = get_user_config_path()
|
766
768
|
|
767
769
|
# Always save to the local file (PVC in Kubernetes, local file otherwise)
|
768
770
|
common_utils.dump_yaml(global_config_path, dict(config))
|
sky/templates/websocket_proxy.py
CHANGED
@@ -21,11 +21,21 @@ from websockets.asyncio.client import connect
|
|
21
21
|
|
22
22
|
BUFFER_SIZE = 2**16 # 64KB
|
23
23
|
|
24
|
+
# Environment variable for a file path to the API cookie file.
|
25
|
+
# Keep in sync with server/constants.py
|
26
|
+
API_COOKIE_FILE_ENV_VAR = 'SKYPILOT_API_COOKIE_FILE'
|
27
|
+
# Default file if unset.
|
28
|
+
# Keep in sync with server/constants.py
|
29
|
+
API_COOKIE_FILE_DEFAULT_LOCATION = '~/.sky/cookies.txt'
|
30
|
+
|
24
31
|
|
25
32
|
def _get_cookie_header(url: str) -> Dict[str, str]:
|
26
33
|
"""Extract Cookie header value from a cookie jar for a specific URL"""
|
27
|
-
cookie_path = os.environ.get(
|
34
|
+
cookie_path = os.environ.get(API_COOKIE_FILE_ENV_VAR)
|
28
35
|
if cookie_path is None:
|
36
|
+
cookie_path = API_COOKIE_FILE_DEFAULT_LOCATION
|
37
|
+
cookie_path = os.path.expanduser(cookie_path)
|
38
|
+
if not os.path.exists(cookie_path):
|
29
39
|
return {}
|
30
40
|
|
31
41
|
request = Request(url)
|
sky/users/__init__.py
ADDED
File without changes
|
sky/users/model.conf
ADDED
@@ -0,0 +1,15 @@
|
|
1
|
+
# rbac_model.conf
|
2
|
+
[request_definition]
|
3
|
+
r = sub, obj, act
|
4
|
+
|
5
|
+
[policy_definition]
|
6
|
+
p = sub, obj, act
|
7
|
+
|
8
|
+
[role_definition]
|
9
|
+
g = _, _
|
10
|
+
|
11
|
+
[policy_effect]
|
12
|
+
e = some(where (p.eft == allow))
|
13
|
+
|
14
|
+
[matchers]
|
15
|
+
m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
|
sky/users/permission.py
ADDED
@@ -0,0 +1,178 @@
|
|
1
|
+
"""Permission service for SkyPilot API Server."""
|
2
|
+
import contextlib
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import threading
|
6
|
+
from typing import List
|
7
|
+
|
8
|
+
import casbin
|
9
|
+
import filelock
|
10
|
+
import sqlalchemy_adapter
|
11
|
+
|
12
|
+
from sky import global_user_state
|
13
|
+
from sky import sky_logging
|
14
|
+
from sky.users import rbac
|
15
|
+
|
16
|
+
logger = sky_logging.init_logger(__name__)
|
17
|
+
|
18
|
+
# Filelocks for the policy update.
|
19
|
+
POLICY_UPDATE_LOCK_PATH = os.path.expanduser('~/.sky/.policy_update.lock')
|
20
|
+
POLICY_UPDATE_LOCK_TIMEOUT_SECONDS = 20
|
21
|
+
|
22
|
+
_enforcer_instance = None
|
23
|
+
_lock = threading.Lock()
|
24
|
+
|
25
|
+
|
26
|
+
class PermissionService:
    """Permission service for SkyPilot API Server.

    Wraps a casbin enforcer whose model is read from ``model.conf`` next to
    this module and whose policies are stored through the SQLAlchemy
    adapter. All instances created in one process share the enforcer of the
    first instance.
    """

    def __init__(self):
        global _enforcer_instance
        # Create-or-reuse is done entirely under the lock so that:
        #  * the global is published only after ``self.enforcer`` is set
        #    (no other thread can observe a half-built instance);
        #  * an instance that loses the creation race still gets an
        #    enforcer;
        #  * a failure while building the enforcer leaves the global unset
        #    so a later construction can retry.
        with _lock:
            if _enforcer_instance is None:
                engine = global_user_state.SQLALCHEMY_ENGINE
                adapter = sqlalchemy_adapter.Adapter(engine)
                model_path = os.path.join(os.path.dirname(__file__),
                                          'model.conf')
                enforcer = casbin.Enforcer(model_path, adapter)
                logging.getLogger('casbin.policy').setLevel(
                    sky_logging.ERROR)
                logging.getLogger('casbin.role').setLevel(sky_logging.ERROR)
                self.enforcer = enforcer
                _enforcer_instance = self
            else:
                self.enforcer = _enforcer_instance.enforcer
        self._maybe_initialize_policies()

    def _maybe_initialize_policies(self):
        """Initialize policies if they don't already exist."""
        logger.debug(f'Initializing policies in process: {os.getpid()}')

        # Check if policies are already initialized by looking for existing
        # permission policies in the enforcer
        existing_policies = self.enforcer.get_policy()

        # Flatten every role's blocklist into [role, path, method] entries.
        # Roles without a 'blocklist' (or with empty permissions) contribute
        # nothing.
        role_permissions = rbac.get_role_permissions()
        expected_policies = []
        for role, permissions in role_permissions.items():
            if permissions['permissions'] and 'blocklist' in permissions[
                    'permissions']:
                blocklist = permissions['permissions']['blocklist']
                for item in blocklist:
                    expected_policies.append(
                        [role, item['path'], item['method']])

        # Check if all expected policies already exist
        policies_exist = all(
            expected in existing_policies for expected in expected_policies)

        if not policies_exist:
            # Only clear and reinitialize if policies don't exist or are
            # incomplete
            logger.debug('Policies not found or incomplete, initializing...')
            # Only clear p policies (permission policies),
            # keep g policies (role policies)
            self.enforcer.remove_filtered_policy(0)
            for role, path, method in expected_policies:
                self.enforcer.add_policy(role, path, method)
            self.enforcer.save_policy()
        else:
            logger.debug('Policies already exist, skipping initialization')

        # Always ensure users have default roles (this is idempotent)
        all_users = global_user_state.get_all_users()
        for user in all_users:
            self.add_user_if_not_exists(user.id)

    def add_user_if_not_exists(self, user: str) -> None:
        """Give the user the default role if the user has no role yet.

        Args:
            user: The user ID.
        """
        with _policy_lock():
            user_roles = self.enforcer.get_roles_for_user(user)
            if not user_roles:
                logger.info(f'User {user} has no roles, adding'
                            f' default role {rbac.get_default_role()}')
                self.enforcer.add_grouping_policy(user, rbac.get_default_role())
                self.enforcer.save_policy()

    def update_role(self, user: str, new_role: str):
        """Update user role relationship.

        Reloads the policy, then replaces the user's first current role
        with ``new_role``. Does nothing if the user already has that role.
        A user without any role simply gets ``new_role`` added.

        Args:
            user: The user ID.
            new_role: The role name to assign.
        """
        with _policy_lock():
            # Get current roles
            self._load_policy_no_lock()
            # Avoid calling get_user_roles, as it will require the lock.
            current_roles = self.enforcer.get_roles_for_user(user)
            if not current_roles:
                logger.warning(f'User {user} has no roles')
            else:
                # TODO(hailong): how to handle multiple roles?
                current_role = current_roles[0]
                if current_role == new_role:
                    logger.info(f'User {user} already has role {new_role}')
                    return
                self.enforcer.remove_grouping_policy(user, current_role)

            # Update user role
            self.enforcer.add_grouping_policy(user, new_role)
            self.enforcer.save_policy()

    def get_user_roles(self, user: str) -> List[str]:
        """Get the roles for a user.

        Reloads the policy from storage (under the policy lock) and returns
        what the enforcer's ``get_roles_for_user`` reports.

        NOTE(review): this presumably returns only directly assigned roles;
        casbin exposes a separate call for implicit (inherited) roles --
        confirm before relying on inheritance here.

        Args:
            user: The user ID to get roles for.

        Returns:
            A list of role names that the user has.
        """
        self._load_policy()
        return self.enforcer.get_roles_for_user(user)

    def check_permission(self, user: str, path: str, method: str) -> bool:
        """Check whether a policy matches the (user, path, method) request.

        The policies are populated from the per-role blocklists in
        ``_maybe_initialize_policies``, so a True result means the request
        matches a blocklist entry for one of the user's roles.

        Args:
            user: The user ID.
            path: The request path.
            method: The HTTP method.

        Returns:
            The result of the enforcer's ``enforce`` call.
        """
        # We intentionally don't load the policy here, as it is a hot path, and
        # we don't support updating the policy.
        # We don't hold the lock for checking permission, as it is read only and
        # it is a hot path in every request. It is ok to have a stale policy,
        # as long as it is eventually consistent.
        return self.enforcer.enforce(user, path, method)

    def _load_policy_no_lock(self):
        """Load policy from storage."""
        self.enforcer.load_policy()

    def _load_policy(self):
        """Load policy from storage with lock."""
        with _policy_lock():
            self._load_policy_no_lock()
|
164
|
+
|
165
|
+
|
166
|
+
@contextlib.contextmanager
def _policy_lock():
    """Hold the policy-update file lock for the duration of the block.

    Raises:
        RuntimeError: if a ``filelock.Timeout`` is raised while acquiring
            or holding the lock (the lock is requested with a timeout of
            ``POLICY_UPDATE_LOCK_TIMEOUT_SECONDS``).
    """
    policy_file_lock = filelock.FileLock(POLICY_UPDATE_LOCK_PATH,
                                         POLICY_UPDATE_LOCK_TIMEOUT_SECONDS)
    try:
        with policy_file_lock:
            yield
    except filelock.Timeout as timeout_error:
        raise RuntimeError(f'Failed to load policy due to a timeout '
                           f'when trying to acquire the lock at '
                           f'{POLICY_UPDATE_LOCK_PATH}. '
                           'Please try again or manually remove the lock '
                           f'file if you believe it is stale.'
                          ) from timeout_error
|
sky/users/rbac.py
ADDED
@@ -0,0 +1,86 @@
|
|
1
|
+
"""RBAC (Role-Based Access Control) functionality for SkyPilot API Server."""
|
2
|
+
|
3
|
+
import enum
|
4
|
+
from typing import Dict, List
|
5
|
+
|
6
|
+
from sky import sky_logging
|
7
|
+
from sky import skypilot_config
|
8
|
+
|
9
|
+
logger = sky_logging.init_logger(__name__)
|
10
|
+
|
11
|
+
# Default blocklist for the 'user' role. Each entry is a (path, method) pair.
# Covers the workspace config/update/create/delete endpoints and the user
# update endpoint, all via POST.
_DEFAULT_USER_BLOCKLIST = [{
    'path': '/workspaces/config',
    'method': 'POST'
}, {
    'path': '/workspaces/update',
    'method': 'POST'
}, {
    'path': '/workspaces/create',
    'method': 'POST'
}, {
    'path': '/workspaces/delete',
    'method': 'POST'
}, {
    'path': '/users/update',
    'method': 'POST'
}]
|
29
|
+
|
30
|
+
|
31
|
+
# Define roles
|
32
|
+
class RoleName(str, enum.Enum):
    """Names of the supported roles; each member is also a plain ``str``."""
    ADMIN = 'admin'
    USER = 'user'
|
35
|
+
|
36
|
+
|
37
|
+
def get_supported_roles() -> List[str]:
    """Return the string value of every ``RoleName`` member, in order."""
    supported = []
    for member in RoleName:
        supported.append(member.value)
    return supported
|
39
|
+
|
40
|
+
|
41
|
+
def get_default_role() -> str:
    """Return the configured default role.

    Reads ``rbac.default_role`` from the SkyPilot config and falls back to
    the admin role when it is not set.
    """
    fallback_role = RoleName.ADMIN.value
    return skypilot_config.get_nested(('rbac', 'default_role'),
                                      default_value=fallback_role)
|
44
|
+
|
45
|
+
|
46
|
+
def get_role_permissions(
) -> Dict[str, Dict[str, Dict[str, List[Dict[str, str]]]]]:
    """Get all role permissions from config.

    Role names from the ``rbac.roles`` config are lower-cased. Names that
    are not in ``get_supported_roles()`` are logged and left out of the
    result. If no 'user' role is configured, one is added with the default
    user blocklist.

    The result is a new dict; the mapping returned by the config lookup is
    not modified.

    Returns:
        Dictionary containing all roles and their permissions configuration.
        Example:
        {
            'admin': {
                'permissions': {
                    'blocklist': []
                }
            },
            'user': {
                'permissions': {
                    'blocklist': [
                        {'path': '/workspaces/config', 'method': 'POST'},
                        {'path': '/workspaces/update', 'method': 'POST'}
                    ]
                }
            }
        }
    """
    # Get all roles from the config
    config_permissions = skypilot_config.get_nested(('rbac', 'roles'),
                                                    default_value={})
    supported_roles = get_supported_roles()
    # Collect into a separate dict: inserting normalized keys into the dict
    # being iterated raises "dictionary changed size during iteration" as
    # soon as a configured role name is not already lowercase.
    role_permissions = {}
    for role, permissions in config_permissions.items():
        role_name = role.lower()
        if role_name not in supported_roles:
            logger.warning(f'Invalid role: {role_name}')
            continue
        role_permissions[role_name] = permissions
    # Add default roles if not present
    if 'user' not in role_permissions:
        role_permissions['user'] = {
            'permissions': {
                'blocklist': _DEFAULT_USER_BLOCKLIST
            }
        }
    return role_permissions
|