skypilot-nightly 1.0.0.dev20250604__py3-none-any.whl → 1.0.0.dev20250606__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. sky/__init__.py +2 -2
  2. sky/admin_policy.py +5 -0
  3. sky/catalog/__init__.py +2 -2
  4. sky/catalog/common.py +7 -9
  5. sky/cli.py +11 -9
  6. sky/client/cli.py +11 -9
  7. sky/client/sdk.py +30 -12
  8. sky/clouds/kubernetes.py +2 -2
  9. sky/dashboard/out/404.html +1 -1
  10. sky/dashboard/out/_next/static/99m-BAySO8Q7J-ul1jZVL/_buildManifest.js +1 -0
  11. sky/dashboard/out/_next/static/chunks/{236-fef38aa6e5639300.js → 236-a90f0a9753a10420.js} +2 -2
  12. sky/dashboard/out/_next/static/chunks/614-635a84e87800f99e.js +66 -0
  13. sky/dashboard/out/_next/static/chunks/{856-f1b1f7f47edde2e8.js → 856-3a32da4b84176f6d.js} +1 -1
  14. sky/dashboard/out/_next/static/chunks/937.3759f538f11a0953.js +1 -0
  15. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-35cbeb5214fd4036.js +6 -0
  16. sky/dashboard/out/_next/static/chunks/pages/config-1a1eeb949dab8897.js +6 -0
  17. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-2d23a9c7571e6320.js +16 -0
  18. sky/dashboard/out/_next/static/chunks/pages/users-262aab38b9baaf3a.js +16 -0
  19. sky/dashboard/out/_next/static/chunks/pages/workspaces-384ea5fa0cea8f28.js +1 -0
  20. sky/dashboard/out/_next/static/chunks/{webpack-f27c9a32aa3d9c6d.js → webpack-65d465f948974c0d.js} +1 -1
  21. sky/dashboard/out/_next/static/css/667d941a2888ce6e.css +3 -0
  22. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  23. sky/dashboard/out/clusters/[cluster].html +1 -1
  24. sky/dashboard/out/clusters.html +1 -1
  25. sky/dashboard/out/config.html +1 -1
  26. sky/dashboard/out/index.html +1 -1
  27. sky/dashboard/out/infra/[context].html +1 -1
  28. sky/dashboard/out/infra.html +1 -1
  29. sky/dashboard/out/jobs/[job].html +1 -1
  30. sky/dashboard/out/jobs.html +1 -1
  31. sky/dashboard/out/users.html +1 -1
  32. sky/dashboard/out/workspace/new.html +1 -1
  33. sky/dashboard/out/workspaces/[name].html +1 -1
  34. sky/dashboard/out/workspaces.html +1 -1
  35. sky/execution.py +44 -46
  36. sky/global_user_state.py +118 -83
  37. sky/jobs/client/sdk.py +4 -1
  38. sky/jobs/server/core.py +5 -1
  39. sky/models.py +1 -0
  40. sky/resources.py +22 -1
  41. sky/serve/load_balancer.py +56 -45
  42. sky/server/constants.py +3 -1
  43. sky/server/requests/payloads.py +9 -0
  44. sky/server/server.py +30 -9
  45. sky/setup_files/MANIFEST.in +1 -0
  46. sky/setup_files/dependencies.py +2 -0
  47. sky/skylet/constants.py +10 -4
  48. sky/skypilot_config.py +4 -2
  49. sky/templates/websocket_proxy.py +11 -1
  50. sky/users/__init__.py +0 -0
  51. sky/users/model.conf +15 -0
  52. sky/users/permission.py +178 -0
  53. sky/users/rbac.py +86 -0
  54. sky/users/server.py +66 -0
  55. sky/utils/schemas.py +20 -7
  56. sky/workspaces/core.py +2 -2
  57. {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/METADATA +3 -1
  58. {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/RECORD +70 -66
  59. sky/catalog/constants.py +0 -8
  60. sky/dashboard/out/_next/static/chunks/614-3d29f98e0634b179.js +0 -66
  61. sky/dashboard/out/_next/static/chunks/937.f97f83652028e944.js +0 -1
  62. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-62c9982dc3675725.js +0 -6
  63. sky/dashboard/out/_next/static/chunks/pages/config-35383adcb0edb5e2.js +0 -6
  64. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-a62a3c65dc9bc57c.js +0 -11
  65. sky/dashboard/out/_next/static/chunks/pages/users-07b523ccb19317ad.js +0 -6
  66. sky/dashboard/out/_next/static/chunks/pages/workspaces-f54921ec9eb20965.js +0 -1
  67. sky/dashboard/out/_next/static/css/63d3995d8b528eb1.css +0 -3
  68. sky/dashboard/out/_next/static/vWwfD3jOky5J5jULHp8JT/_buildManifest.js +0 -1
  69. /sky/dashboard/out/_next/static/{vWwfD3jOky5J5jULHp8JT → 99m-BAySO8Q7J-ul1jZVL}/_ssgManifest.js +0 -0
  70. /sky/dashboard/out/_next/static/chunks/{121-8f55ee3fa6301784.js → 121-865d2bf8a3b84c6a.js} +0 -0
  71. /sky/dashboard/out/_next/static/chunks/{37-947904ccc5687bac.js → 37-beedd583fea84cc8.js} +0 -0
  72. /sky/dashboard/out/_next/static/chunks/{682-2be9b0f169727f2f.js → 682-6647f0417d5662f0.js} +0 -0
  73. /sky/dashboard/out/_next/static/chunks/{843-a097338acb89b7d7.js → 843-c296541442d4af88.js} +0 -0
  74. /sky/dashboard/out/_next/static/chunks/{969-d7b6fb7f602bfcb3.js → 969-c7abda31c10440ac.js} +0 -0
  75. /sky/dashboard/out/_next/static/chunks/pages/{_app-67925f5e6382e22f.js → _app-cb81dc4d27f4d009.js} +0 -0
  76. /sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/{[job]-158b70da336d8607.js → [job]-65d04d5d77cbb6b6.js} +0 -0
  77. {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/WHEEL +0 -0
  78. {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/entry_points.txt +0 -0
  79. {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/licenses/LICENSE +0 -0
  80. {skypilot_nightly-1.0.0.dev20250604.dist-info → skypilot_nightly-1.0.0.dev20250606.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,8 @@
2
2
  import asyncio
3
3
  import logging
4
4
  import threading
5
- from typing import Dict, Optional, Union
5
+ import traceback
6
+ from typing import Dict, List, Optional, Union
6
7
 
7
8
  import aiohttp
8
9
  import fastapi
@@ -69,6 +70,48 @@ class SkyServeLoadBalancer:
69
70
  # updating it from _sync_with_controller.
70
71
  self._client_pool_lock: threading.Lock = threading.Lock()
71
72
 
73
async def _sync_with_controller_once(self) -> List[asyncio.Task]:
    """Run a single sync round-trip with the controller.

    Posts the aggregated request information to the controller's
    ``/controller/load_balancer_sync`` endpoint. On success, the returned
    ``ready_replica_urls`` are applied to the load balancing policy and to
    the per-replica client pool: missing clients are created, clients for
    URLs that are no longer ready are removed.

    ``aiohttp.ClientError`` and ``asyncio.TimeoutError`` raised by the
    request are logged and swallowed; the pool is left untouched then.

    Returns:
        One un-awaited ``client.aclose()`` result per client removed from
        the pool in this round (empty if the request failed or nothing was
        removed). Nothing is awaited here; awaiting is left to the caller.
    """
    pending_closes = []
    sync_url = self._controller_url + '/controller/load_balancer_sync'
    async with aiohttp.ClientSession() as session:
        try:
            # Report the aggregated request information.
            async with session.post(
                    sync_url,
                    json={
                        'request_aggregator':
                            self._request_aggregator.to_dict()
                    },
                    timeout=aiohttp.ClientTimeout(5),
            ) as response:
                # The stats were sent; drop them right away so the
                # aggregator cannot grow until OOM, even if the status
                # check below fails.
                self._request_aggregator.clear()
                response.raise_for_status()
                payload = await response.json()
                ready_urls = payload.get('ready_replica_urls', [])
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.error(f'An error occurred when syncing with '
                         f'the controller: {e}'
                         f'\nTraceback: {traceback.format_exc()}')
        else:
            logger.info(f'Available Replica URLs: {ready_urls}')
            with self._client_pool_lock:
                self._load_balancing_policy.set_ready_replicas(ready_urls)
                # Make sure every ready replica has a client.
                for url in ready_urls:
                    if url not in self._client_pool:
                        self._client_pool[url] = httpx.AsyncClient(
                            base_url=url)
                # Evict clients whose replica is no longer ready.
                stale_urls = set(self._client_pool.keys()) - set(ready_urls)
                stale_clients = [
                    self._client_pool.pop(url) for url in stale_urls
                ]
            pending_closes.extend(
                client.aclose() for client in stale_clients)
    return pending_closes
114
+
72
115
  async def _sync_with_controller(self):
73
116
  """Sync with controller periodically.
74
117
 
@@ -82,49 +125,16 @@ class SkyServeLoadBalancer:
82
125
  await asyncio.sleep(5)
83
126
 
84
127
  while True:
85
- close_client_tasks = []
86
- async with aiohttp.ClientSession() as session:
87
- try:
88
- # Send request information
89
- async with session.post(
90
- self._controller_url +
91
- '/controller/load_balancer_sync',
92
- json={
93
- 'request_aggregator':
94
- self._request_aggregator.to_dict()
95
- },
96
- timeout=aiohttp.ClientTimeout(5),
97
- ) as response:
98
- # Clean up after reporting request info to avoid OOM.
99
- self._request_aggregator.clear()
100
- response.raise_for_status()
101
- response_json = await response.json()
102
- ready_replica_urls = response_json.get(
103
- 'ready_replica_urls', [])
104
- except aiohttp.ClientError as e:
105
- logger.error('An error occurred when syncing with '
106
- f'the controller: {e}')
107
- else:
108
- logger.info(f'Available Replica URLs: {ready_replica_urls}')
109
- with self._client_pool_lock:
110
- self._load_balancing_policy.set_ready_replicas(
111
- ready_replica_urls)
112
- for replica_url in ready_replica_urls:
113
- if replica_url not in self._client_pool:
114
- self._client_pool[replica_url] = (
115
- httpx.AsyncClient(base_url=replica_url))
116
- urls_to_close = set(
117
- self._client_pool.keys()) - set(ready_replica_urls)
118
- client_to_close = []
119
- for replica_url in urls_to_close:
120
- client_to_close.append(
121
- self._client_pool.pop(replica_url))
122
- for client in client_to_close:
123
- close_client_tasks.append(client.aclose())
124
-
125
- await asyncio.sleep(constants.LB_CONTROLLER_SYNC_INTERVAL_SECONDS)
126
- # Await those tasks after the interval to avoid blocking.
127
- await asyncio.gather(*close_client_tasks)
128
+ try:
129
+ close_client_tasks = await self._sync_with_controller_once()
130
+ await asyncio.sleep(
131
+ constants.LB_CONTROLLER_SYNC_INTERVAL_SECONDS)
132
+ # Await those tasks after the interval to avoid blocking.
133
+ await asyncio.gather(*close_client_tasks)
134
+ except Exception as e: # pylint: disable=broad-except
135
+ logger.error(f'An error occurred when syncing with '
136
+ f'the controller: {e}'
137
+ f'\nTraceback: {traceback.format_exc()}')
128
138
 
129
139
  async def _proxy_request_to(
130
140
  self, url: str, request: fastapi.Request
@@ -168,7 +178,8 @@ class SkyServeLoadBalancer:
168
178
  background=background.BackgroundTask(background_func))
169
179
  except (httpx.RequestError, httpx.HTTPStatusError) as e:
170
180
  logger.error(f'Error when proxy request to {url}: '
171
- f'{common_utils.format_exception(e)}')
181
+ f'{common_utils.format_exception(e)}'
182
+ f'\nTraceback: {traceback.format_exc()}')
172
183
  return e
173
184
 
174
185
  async def _proxy_with_retries(
sky/server/constants.py CHANGED
@@ -7,7 +7,7 @@ from sky.skylet import constants
7
7
  # API server version, whenever there is a change in API server that requires a
8
8
  # restart of the local API server or error out when the client does not match
9
9
  # the server version.
10
- API_VERSION = '8'
10
+ API_VERSION = '9'
11
11
 
12
12
  # Prefix for API request names.
13
13
  REQUEST_NAME_PREFIX = 'sky.'
@@ -25,8 +25,10 @@ API_SERVER_REQUEST_DB_PATH = '~/.sky/api_server/requests.db'
25
25
  CLUSTER_REFRESH_DAEMON_INTERVAL_SECONDS = 60
26
26
 
27
27
  # Environment variable for a file path to the API cookie file.
28
+ # Keep in sync with websocket_proxy.py
28
29
  API_COOKIE_FILE_ENV_VAR = f'{constants.SKYPILOT_ENV_VAR_PREFIX}API_COOKIE_FILE'
29
30
  # Default file if unset.
31
+ # Keep in sync with websocket_proxy.py
30
32
  API_COOKIE_FILE_DEFAULT_LOCATION = '~/.sky/cookies.txt'
31
33
 
32
34
  # The path to the dashboard build output
@@ -196,8 +196,10 @@ class LaunchBody(RequestBody):
196
196
  task: str
197
197
  cluster_name: str
198
198
  retry_until_up: bool = False
199
+ # TODO(aylei): remove this field in v0.12.0
199
200
  idle_minutes_to_autostop: Optional[int] = None
200
201
  dryrun: bool = False
202
+ # TODO(aylei): remove this field in v0.12.0
201
203
  down: bool = False
202
204
  backend: Optional[str] = None
203
205
  optimize_target: common_lib.OptimizeTarget = common_lib.OptimizeTarget.COST
@@ -331,6 +333,12 @@ class ClusterJobsDownloadLogsBody(RequestBody):
331
333
  local_dir: str = constants.SKY_LOGS_DIRECTORY
332
334
 
333
335
 
336
class UserUpdateBody(RequestBody):
    """Payload accepted by the user update endpoint."""
    # Identifier of the user the update applies to.
    # NOTE(review): presumably the same ID stored in the user table -
    # confirm against the endpoint handler.
    user_id: str
    # Role carried by the update request.
    # NOTE(review): presumably one of the supported RBAC role names -
    # confirm against the endpoint handler.
    role: str
340
+
341
+
334
342
  class DownloadBody(RequestBody):
335
343
  """The request body for the download endpoint."""
336
344
  folder_paths: List[str]
@@ -375,6 +383,7 @@ class JobsQueueBody(RequestBody):
375
383
  refresh: bool = False
376
384
  skip_finished: bool = False
377
385
  all_users: bool = False
386
+ job_ids: Optional[List[int]] = None
378
387
 
379
388
 
380
389
  class JobsCancelBody(RequestBody):
sky/server/server.py CHANGED
@@ -49,6 +49,7 @@ from sky.server.requests import preconditions
49
49
  from sky.server.requests import requests as requests_lib
50
50
  from sky.skylet import constants
51
51
  from sky.usage import usage_lib
52
+ from sky.users import server as users_rest
52
53
  from sky.utils import admin_policy_utils
53
54
  from sky.utils import common as common_lib
54
55
  from sky.utils import common_utils
@@ -100,6 +101,27 @@ logger = sky_logging.init_logger(__name__)
100
101
  # response will block other requests from being processed.
101
102
 
102
103
 
104
class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware that rejects requests matched by the RBAC policy.

    Requests whose path starts with ``/dashboard/`` and requests without an
    authenticated user header are passed through unchecked.
    """

    async def dispatch(self, request: fastapi.Request, call_next):
        """Answer 403 if the permission service matches the request.

        A truthy result from ``check_permission`` for
        (user id, request path, HTTP method) short-circuits the request
        with a JSON 403 response; anything else is forwarded to the next
        handler.
        """
        path = request.url.path
        # Dashboard assets are never checked, so the auth header is not
        # even parsed for them.
        auth_user = (None if path.startswith('/dashboard/') else
                     _get_auth_user_header(request))
        if auth_user is not None:
            matched = users_rest.permission_service.check_permission(
                auth_user.id, path, request.method)
            if matched:
                return fastapi.responses.JSONResponse(
                    status_code=403, content={'detail': 'Forbidden'})
        return await call_next(request)
123
+
124
+
103
125
  class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
104
126
  """Middleware to add a request ID to each request."""
105
127
 
@@ -130,7 +152,10 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
130
152
 
131
153
  # Add user to database if auth_user is present
132
154
  if auth_user is not None:
133
- global_user_state.add_or_update_user(auth_user)
155
+ newly_added = global_user_state.add_or_update_user(auth_user)
156
+ if newly_added:
157
+ users_rest.permission_service.add_user_if_not_exists(
158
+ auth_user.id)
134
159
 
135
160
  body = await request.body()
136
161
  if auth_user and body:
@@ -244,11 +269,13 @@ class PathCleanMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
244
269
  parent = pathlib.Path('/dashboard')
245
270
  request_path = pathlib.Path(posixpath.normpath(request.url.path))
246
271
  if not _is_relative_to(request_path, parent):
247
- raise fastapi.HTTPException(status_code=403, detail='Forbidden')
272
+ return fastapi.responses.JSONResponse(
273
+ status_code=403, content={'detail': 'Forbidden'})
248
274
  return await call_next(request)
249
275
 
250
276
 
251
277
  app = fastapi.FastAPI(prefix='/api/v1', debug=True, lifespan=lifespan)
278
+ app.add_middleware(RBACMiddleware)
252
279
  app.add_middleware(InternalDashboardPrefixMiddleware)
253
280
  app.add_middleware(PathCleanMiddleware)
254
281
  app.add_middleware(CacheControlStaticMiddleware)
@@ -266,6 +293,7 @@ app.add_middleware(AuthProxyMiddleware)
266
293
  app.add_middleware(RequestIDMiddleware)
267
294
  app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
268
295
  app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
296
+ app.include_router(users_rest.router, prefix='/users', tags=['users'])
269
297
  app.include_router(workspaces_rest.router,
270
298
  prefix='/workspaces',
271
299
  tags=['workspaces'])
@@ -835,13 +863,6 @@ async def logs(
835
863
  )
836
864
 
837
865
 
838
- @app.get('/users')
839
- async def users() -> List[Dict[str, Any]]:
840
- """Gets all users."""
841
- user_list = global_user_state.get_all_users()
842
- return [user.to_dict() for user in user_list]
843
-
844
-
845
866
  @app.post('/download_logs')
846
867
  async def download_logs(
847
868
  request: fastapi.Request,
@@ -16,3 +16,4 @@ include sky/templates/*
16
16
  include sky/utils/kubernetes/*
17
17
  include sky/server/html/*
18
18
  recursive-include sky/dashboard/out *
19
+ include sky/users/*.conf
@@ -58,6 +58,8 @@ install_requires = [
58
58
  'setproctitle',
59
59
  'sqlalchemy',
60
60
  'psycopg2-binary',
61
+ 'casbin',
62
+ 'sqlalchemy_adapter',
61
63
  ]
62
64
 
63
65
  local_ray = [
sky/skylet/constants.py CHANGED
@@ -379,7 +379,7 @@ OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
379
379
  SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [('admin_policy',),
380
380
  ('api_server',),
381
381
  ('allowed_clouds',),
382
- ('workspaces',)]
382
+ ('workspaces',), ('db',)]
383
383
 
384
384
  # Constants for Azure blob storage
385
385
  WAIT_FOR_STORAGE_ACCOUNT_CREATION = 60
@@ -409,6 +409,12 @@ ENV_VAR_IS_SKYPILOT_SERVER = 'IS_SKYPILOT_SERVER'
409
409
 
410
410
  SKYPILOT_DEFAULT_WORKSPACE = 'default'
411
411
 
412
- # Experimental - may be deprecated in the future without notice.
413
- SKYPILOT_API_SERVER_DB_URL_ENV_VAR: str = (
414
- f'{SKYPILOT_ENV_VAR_PREFIX}API_SERVER_DB_URL')
412
+ # BEGIN constants used for service catalog.
413
+ HOSTED_CATALOG_DIR_URL = 'https://raw.githubusercontent.com/skypilot-org/skypilot-catalog/master/catalogs' # pylint: disable=line-too-long
414
+ HOSTED_CATALOG_DIR_URL_S3_MIRROR = 'https://skypilot-catalog.s3.us-east-1.amazonaws.com/catalogs' # pylint: disable=line-too-long
415
+ CATALOG_SCHEMA_VERSION = 'v7'
416
+ CATALOG_DIR = '~/.sky/catalogs'
417
+ ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
418
+ 'kubernetes', 'runpod', 'vast', 'vsphere', 'cudo', 'fluidstack',
419
+ 'paperspace', 'do', 'nebius', 'ssh')
420
+ # END constants used for service catalog.
sky/skypilot_config.py CHANGED
@@ -756,13 +756,15 @@ def apply_cli_config(cli_config: Optional[List[str]]) -> Dict[str, Any]:
756
756
  return parsed_config
757
757
 
758
758
 
759
- def update_config_no_lock(config: config_utils.Config) -> None:
759
+ def update_api_server_config_no_lock(config: config_utils.Config) -> None:
760
760
  """Dumps the new config to a file and syncs to ConfigMap if in Kubernetes.
761
761
 
762
762
  Args:
763
763
  config: The config to save and sync.
764
764
  """
765
- global_config_path = os.path.expanduser(get_user_config_path())
765
+ global_config_path = _resolve_server_config_path()
766
+ if global_config_path is None:
767
+ global_config_path = get_user_config_path()
766
768
 
767
769
  # Always save to the local file (PVC in Kubernetes, local file otherwise)
768
770
  common_utils.dump_yaml(global_config_path, dict(config))
@@ -21,11 +21,21 @@ from websockets.asyncio.client import connect
21
21
 
22
22
  BUFFER_SIZE = 2**16 # 64KB
23
23
 
24
+ # Environment variable for a file path to the API cookie file.
25
+ # Keep in sync with server/constants.py
26
+ API_COOKIE_FILE_ENV_VAR = 'SKYPILOT_API_COOKIE_FILE'
27
+ # Default file if unset.
28
+ # Keep in sync with server/constants.py
29
+ API_COOKIE_FILE_DEFAULT_LOCATION = '~/.sky/cookies.txt'
30
+
24
31
 
25
32
  def _get_cookie_header(url: str) -> Dict[str, str]:
26
33
  """Extract Cookie header value from a cookie jar for a specific URL"""
27
- cookie_path = os.environ.get('SKYPILOT_API_COOKIE_FILE')
34
+ cookie_path = os.environ.get(API_COOKIE_FILE_ENV_VAR)
28
35
  if cookie_path is None:
36
+ cookie_path = API_COOKIE_FILE_DEFAULT_LOCATION
37
+ cookie_path = os.path.expanduser(cookie_path)
38
+ if not os.path.exists(cookie_path):
29
39
  return {}
30
40
 
31
41
  request = Request(url)
sky/users/__init__.py ADDED
File without changes
sky/users/model.conf ADDED
@@ -0,0 +1,15 @@
1
+ # model.conf: Casbin RBAC model (request, policy, role, effect and matcher definitions)
2
+ [request_definition]
3
+ r = sub, obj, act
4
+
5
+ [policy_definition]
6
+ p = sub, obj, act
7
+
8
+ [role_definition]
9
+ g = _, _
10
+
11
+ [policy_effect]
12
+ e = some(where (p.eft == allow))
13
+
14
+ [matchers]
15
+ m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
@@ -0,0 +1,178 @@
1
+ """Permission service for SkyPilot API Server."""
2
+ import contextlib
3
+ import logging
4
+ import os
5
+ import threading
6
+ from typing import List
7
+
8
+ import casbin
9
+ import filelock
10
+ import sqlalchemy_adapter
11
+
12
+ from sky import global_user_state
13
+ from sky import sky_logging
14
+ from sky.users import rbac
15
+
16
+ logger = sky_logging.init_logger(__name__)
17
+
18
+ # Filelocks for the policy update.
19
+ POLICY_UPDATE_LOCK_PATH = os.path.expanduser('~/.sky/.policy_update.lock')
20
+ POLICY_UPDATE_LOCK_TIMEOUT_SECONDS = 20
21
+
22
+ _enforcer_instance = None
23
+ _lock = threading.Lock()
24
+
25
+
26
+ class PermissionService:
27
+ """Permission service for SkyPilot API Server."""
28
+
29
+ def __init__(self):
30
+ global _enforcer_instance
31
+ if _enforcer_instance is None:
32
+ # For different threads, we share the same enforcer instance.
33
+ with _lock:
34
+ if _enforcer_instance is None:
35
+ _enforcer_instance = self
36
+ engine = global_user_state.SQLALCHEMY_ENGINE
37
+ adapter = sqlalchemy_adapter.Adapter(engine)
38
+ model_path = os.path.join(os.path.dirname(__file__),
39
+ 'model.conf')
40
+ enforcer = casbin.Enforcer(model_path, adapter)
41
+ logging.getLogger('casbin.policy').setLevel(
42
+ sky_logging.ERROR)
43
+ logging.getLogger('casbin.role').setLevel(sky_logging.ERROR)
44
+ self.enforcer = enforcer
45
+ else:
46
+ self.enforcer = _enforcer_instance.enforcer
47
+ self._maybe_initialize_policies()
48
+
49
+ def _maybe_initialize_policies(self):
50
+ """Initialize policies if they don't already exist."""
51
+ logger.debug(f'Initializing policies in process: {os.getpid()}')
52
+
53
+ # Check if policies are already initialized by looking for existing
54
+ # permission policies in the enforcer
55
+ existing_policies = self.enforcer.get_policy()
56
+
57
+ # If we already have policies for the expected roles, skip
58
+ # initialization
59
+ role_permissions = rbac.get_role_permissions()
60
+ expected_policies = []
61
+ for role, permissions in role_permissions.items():
62
+ if permissions['permissions'] and 'blocklist' in permissions[
63
+ 'permissions']:
64
+ blocklist = permissions['permissions']['blocklist']
65
+ for item in blocklist:
66
+ expected_policies.append(
67
+ [role, item['path'], item['method']])
68
+
69
+ # Check if all expected policies already exist
70
+ policies_exist = all(
71
+ any(policy == expected
72
+ for policy in existing_policies)
73
+ for expected in expected_policies)
74
+
75
+ if not policies_exist:
76
+ # Only clear and reinitialize if policies don't exist or are
77
+ # incomplete
78
+ logger.debug('Policies not found or incomplete, initializing...')
79
+ # Only clear p policies (permission policies),
80
+ # keep g policies (role policies)
81
+ self.enforcer.remove_filtered_policy(0)
82
+ for role, permissions in role_permissions.items():
83
+ if permissions['permissions'] and 'blocklist' in permissions[
84
+ 'permissions']:
85
+ blocklist = permissions['permissions']['blocklist']
86
+ for item in blocklist:
87
+ path = item['path']
88
+ method = item['method']
89
+ self.enforcer.add_policy(role, path, method)
90
+ self.enforcer.save_policy()
91
+ else:
92
+ logger.debug('Policies already exist, skipping initialization')
93
+
94
+ # Always ensure users have default roles (this is idempotent)
95
+ all_users = global_user_state.get_all_users()
96
+ for user in all_users:
97
+ self.add_user_if_not_exists(user.id)
98
+
99
+ def add_user_if_not_exists(self, user: str) -> None:
100
+ """Add user role relationship."""
101
+ with _policy_lock():
102
+ user_roles = self.enforcer.get_roles_for_user(user)
103
+ if not user_roles:
104
+ logger.info(f'User {user} has no roles, adding'
105
+ f' default role {rbac.get_default_role()}')
106
+ self.enforcer.add_grouping_policy(user, rbac.get_default_role())
107
+ self.enforcer.save_policy()
108
+
109
+ def update_role(self, user: str, new_role: str):
110
+ """Update user role relationship."""
111
+ with _policy_lock():
112
+ # Get current roles
113
+ self._load_policy_no_lock()
114
+ # Avoid calling get_user_roles, as it will require the lock.
115
+ current_roles = self.enforcer.get_roles_for_user(user)
116
+ if not current_roles:
117
+ logger.warning(f'User {user} has no roles')
118
+ else:
119
+ # TODO(hailong): how to handle multiple roles?
120
+ current_role = current_roles[0]
121
+ if current_role == new_role:
122
+ logger.info(f'User {user} already has role {new_role}')
123
+ return
124
+ self.enforcer.remove_grouping_policy(user, current_role)
125
+
126
+ # Update user role
127
+ self.enforcer.add_grouping_policy(user, new_role)
128
+ self.enforcer.save_policy()
129
+
130
+ def get_user_roles(self, user: str) -> List[str]:
131
+ """Get all roles for a user.
132
+
133
+ This method returns all roles that the user has, including inherited
134
+ roles. For example, if a user has role 'admin' and 'admin' inherits
135
+ from 'user', this method will return ['admin', 'user'].
136
+
137
+ Args:
138
+ user: The user ID to get roles for.
139
+
140
+ Returns:
141
+ A list of role names that the user has.
142
+ """
143
+ self._load_policy()
144
+ return self.enforcer.get_roles_for_user(user)
145
+
146
+ def check_permission(self, user: str, path: str, method: str) -> bool:
147
+ """Check permission."""
148
+ # We intentionally don't load the policy here, as it is a hot path, and
149
+ # we don't support updating the policy.
150
+ # We don't hold the lock for checking permission, as it is read only and
151
+ # it is a hot path in every request. It is ok to have a stale policy,
152
+ # as long as it is eventually consistent.
153
+ # self._load_policy_no_lock()
154
+ return self.enforcer.enforce(user, path, method)
155
+
156
+ def _load_policy_no_lock(self):
157
+ """Load policy from storage."""
158
+ self.enforcer.load_policy()
159
+
160
+ def _load_policy(self):
161
+ """Load policy from storage with lock."""
162
+ with _policy_lock():
163
+ self._load_policy_no_lock()
164
+
165
+
166
@contextlib.contextmanager
def _policy_lock():
    """Hold the policy update file lock for the duration of the block.

    Raises:
        RuntimeError: If a ``filelock.Timeout`` is raised while acquiring
            or holding the lock (acquisition waits at most
            ``POLICY_UPDATE_LOCK_TIMEOUT_SECONDS``).
    """
    lock = filelock.FileLock(POLICY_UPDATE_LOCK_PATH,
                             POLICY_UPDATE_LOCK_TIMEOUT_SECONDS)
    try:
        with lock:
            yield
    except filelock.Timeout as e:
        message = (f'Failed to load policy due to a timeout '
                   f'when trying to acquire the lock at '
                   f'{POLICY_UPDATE_LOCK_PATH}. '
                   'Please try again or manually remove the lock '
                   f'file if you believe it is stale.')
        raise RuntimeError(message) from e
sky/users/rbac.py ADDED
@@ -0,0 +1,86 @@
1
+ """RBAC (Role-Based Access Control) functionality for SkyPilot API Server."""
2
+
3
+ import enum
4
+ from typing import Dict, List
5
+
6
+ from sky import sky_logging
7
+ from sky import skypilot_config
8
+
9
+ logger = sky_logging.init_logger(__name__)
10
+
11
# Default blocklist for the 'user' role: workspace config/create/update/
# delete requests and user updates are blocked for it.
_DEFAULT_USER_BLOCKLIST = [{
    'path': '/workspaces/config',
    'method': 'POST'
}, {
    'path': '/workspaces/update',
    'method': 'POST'
}, {
    'path': '/workspaces/create',
    'method': 'POST'
}, {
    'path': '/workspaces/delete',
    'method': 'POST'
}, {
    'path': '/users/update',
    'method': 'POST'
}]


class RoleName(str, enum.Enum):
    """Names of the roles supported by RBAC."""
    ADMIN = 'admin'
    USER = 'user'


def get_supported_roles() -> List[str]:
    """Return the string values of all supported roles."""
    return [role_name.value for role_name in RoleName]


def get_default_role() -> str:
    """Return the default role.

    Read from the ``rbac.default_role`` config key; 'admin' when unset.
    """
    return skypilot_config.get_nested(('rbac', 'default_role'),
                                      default_value=RoleName.ADMIN.value)


def _normalize_role_permissions(config_roles: Dict[str, Dict],
                                supported_roles: List[str]) -> tuple:
    """Split configured roles into supported and rejected ones.

    Role names are compared lower-cased. ``config_roles`` is not modified.

    Returns:
        ``(normalized, rejected)``: a new dict mapping each lower-cased
        supported role name to its permissions, and the list of lower-cased
        role names that are not supported.
    """
    normalized = {}
    rejected = []
    for role, permissions in config_roles.items():
        role_name = role.lower()
        if role_name in supported_roles:
            normalized[role_name] = permissions
        else:
            rejected.append(role_name)
    return normalized, rejected


def get_role_permissions(
) -> Dict[str, Dict[str, Dict[str, List[Dict[str, str]]]]]:
    """Get all role permissions from config.

    Roles come from the ``rbac.roles`` config key. Role names are matched
    case-insensitively against the supported roles and returned lower-cased;
    unsupported roles are logged as invalid and left out. If 'user' is not
    configured it gets the default blocklist. No default entry is added for
    'admin'.

    The result is a new dict: neither the config mapping nor the
    module-level default blocklist is handed out for mutation.

    Returns:
        Dictionary containing the configured roles and their permissions.
        Example:
        {
            'user': {
                'permissions': {
                    'blocklist': [
                        {'path': '/workspaces/config', 'method': 'POST'},
                        {'path': '/workspaces/update', 'method': 'POST'}
                    ]
                }
            }
        }
    """
    config_roles = skypilot_config.get_nested(('rbac', 'roles'),
                                              default_value={})
    # Build a fresh mapping instead of inserting into the dict being
    # iterated, which raises RuntimeError as soon as a key is not already
    # lower-case.
    role_permissions, rejected = _normalize_role_permissions(
        config_roles or {}, get_supported_roles())
    for role_name in rejected:
        logger.warning(f'Invalid role: {role_name}')
    # Add default roles if not present
    if 'user' not in role_permissions:
        role_permissions['user'] = {
            'permissions': {
                'blocklist': [dict(item) for item in _DEFAULT_USER_BLOCKLIST]
            }
        }
    return role_permissions