skypilot-nightly 1.0.0.dev20250607__py3-none-any.whl → 1.0.0.dev20250609__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +18 -2
  3. sky/check.py +4 -3
  4. sky/cli.py +5 -6
  5. sky/client/cli.py +5 -6
  6. sky/core.py +3 -2
  7. sky/dashboard/out/404.html +1 -1
  8. sky/dashboard/out/_next/static/chunks/470-680c19413b8f808b.js +1 -0
  9. sky/dashboard/out/_next/static/chunks/{614-635a84e87800f99e.js → 63-e2d7b1e75e67c713.js} +8 -8
  10. sky/dashboard/out/_next/static/chunks/843-16c7194621b2b512.js +11 -0
  11. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/{[job]-18aed9b56247d074.js → [job]-d31688d3e52736dd.js} +1 -1
  12. sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-b919a73aecdfa78f.js → [cluster]-e7d8710a9b0491e5.js} +1 -1
  13. sky/dashboard/out/_next/static/chunks/pages/{clusters-4f6b9dd9abcb33ad.js → clusters-3c674e5d970e05cb.js} +1 -1
  14. sky/dashboard/out/_next/static/chunks/pages/config-3aac7a015c6eede1.js +6 -0
  15. sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-3a18d0eeb5119fe4.js → [context]-46d2e4ad6c487260.js} +1 -1
  16. sky/dashboard/out/_next/static/chunks/pages/{infra-a1a6abeeb58c1051.js → infra-7013d816a2a0e76c.js} +1 -1
  17. sky/dashboard/out/_next/static/chunks/pages/jobs/{[job]-1354e28c81eeb686.js → [job]-f7f0c9e156d328bc.js} +1 -1
  18. sky/dashboard/out/_next/static/chunks/pages/{jobs-23bfc8bf373423db.js → jobs-87e60396c376292f.js} +1 -1
  19. sky/dashboard/out/_next/static/chunks/pages/users-9355a0f13d1db61d.js +16 -0
  20. sky/dashboard/out/_next/static/chunks/pages/workspace/{new-e1f9c0c3ff7ac4bd.js → new-9a749cca1813bd27.js} +1 -1
  21. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-686590e0ee4b2412.js → [name]-8eeb628e03902f1b.js} +1 -1
  22. sky/dashboard/out/_next/static/chunks/pages/workspaces-8fbcc5ab4af316d0.js +1 -0
  23. sky/dashboard/out/_next/static/css/8b1c8321d4c02372.css +3 -0
  24. sky/dashboard/out/_next/static/xos0euNCptbGAM7_Q3Acl/_buildManifest.js +1 -0
  25. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  26. sky/dashboard/out/clusters/[cluster].html +1 -1
  27. sky/dashboard/out/clusters.html +1 -1
  28. sky/dashboard/out/config.html +1 -1
  29. sky/dashboard/out/index.html +1 -1
  30. sky/dashboard/out/infra/[context].html +1 -1
  31. sky/dashboard/out/infra.html +1 -1
  32. sky/dashboard/out/jobs/[job].html +1 -1
  33. sky/dashboard/out/jobs.html +1 -1
  34. sky/dashboard/out/users.html +1 -1
  35. sky/dashboard/out/workspace/new.html +1 -1
  36. sky/dashboard/out/workspaces/[name].html +1 -1
  37. sky/dashboard/out/workspaces.html +1 -1
  38. sky/exceptions.py +5 -0
  39. sky/global_user_state.py +11 -6
  40. sky/jobs/server/core.py +9 -1
  41. sky/jobs/server/server.py +0 -95
  42. sky/jobs/utils.py +2 -1
  43. sky/models.py +18 -0
  44. sky/serve/server/core.py +1 -1
  45. sky/server/common.py +4 -2
  46. sky/server/constants.py +0 -2
  47. sky/server/requests/executor.py +10 -2
  48. sky/server/requests/requests.py +4 -3
  49. sky/server/server.py +22 -5
  50. sky/skylet/constants.py +3 -0
  51. sky/skylet/job_lib.py +2 -1
  52. sky/skypilot_config.py +9 -0
  53. sky/users/model.conf +1 -1
  54. sky/users/permission.py +148 -31
  55. sky/users/rbac.py +26 -0
  56. sky/users/server.py +14 -13
  57. sky/utils/common.py +6 -1
  58. sky/utils/common_utils.py +21 -3
  59. sky/utils/schemas.py +9 -0
  60. sky/workspaces/core.py +100 -8
  61. sky/workspaces/server.py +15 -2
  62. sky/workspaces/utils.py +56 -0
  63. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/METADATA +1 -1
  64. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/RECORD +72 -71
  65. sky/dashboard/out/_next/static/1qG0HTmVilJPxQdBk0fX5/_buildManifest.js +0 -1
  66. sky/dashboard/out/_next/static/chunks/470-ad1e0db3afcbd9c9.js +0 -1
  67. sky/dashboard/out/_next/static/chunks/843-c296541442d4af88.js +0 -11
  68. sky/dashboard/out/_next/static/chunks/pages/config-fe375a56342cf609.js +0 -6
  69. sky/dashboard/out/_next/static/chunks/pages/users-5800045bd04e69c2.js +0 -16
  70. sky/dashboard/out/_next/static/chunks/pages/workspaces-76b07aa5da91b0df.js +0 -1
  71. sky/dashboard/out/_next/static/css/667d941a2888ce6e.css +0 -3
  72. /sky/dashboard/out/_next/static/chunks/{856-3a32da4b84176f6d.js → 856-affc52adf5403a3a.js} +0 -0
  73. /sky/dashboard/out/_next/static/chunks/{973-6d78a0814682d771.js → 973-aed916d5b02d2d63.js} +0 -0
  74. /sky/dashboard/out/_next/static/chunks/pages/{_app-cb81dc4d27f4d009.js → _app-5f16aba5794ee8e7.js} +0 -0
  75. /sky/dashboard/out/_next/static/{1qG0HTmVilJPxQdBk0fX5 → xos0euNCptbGAM7_Q3Acl}/_ssgManifest.js +0 -0
  76. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/WHEEL +0 -0
  77. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/entry_points.txt +0 -0
  78. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/licenses/LICENSE +0 -0
  79. {skypilot_nightly-1.0.0.dev20250607.dist-info → skypilot_nightly-1.0.0.dev20250609.dist-info}/top_level.txt +0 -0
sky/jobs/server/server.py CHANGED
@@ -1,12 +1,9 @@
1
1
  """REST API for managed jobs."""
2
- import os
3
2
 
4
3
  import fastapi
5
- import httpx
6
4
 
7
5
  from sky import sky_logging
8
6
  from sky.jobs.server import core
9
- from sky.jobs.server import dashboard_utils
10
7
  from sky.server import common as server_common
11
8
  from sky.server import stream_utils
12
9
  from sky.server.requests import executor
@@ -14,7 +11,6 @@ from sky.server.requests import payloads
14
11
  from sky.server.requests import requests as api_requests
15
12
  from sky.skylet import constants
16
13
  from sky.utils import common
17
- from sky.utils import common_utils
18
14
 
19
15
  logger = sky_logging.init_logger(__name__)
20
16
 
@@ -110,94 +106,3 @@ async def download_logs(
110
106
  if jobs_download_logs_body.refresh else api_requests.ScheduleType.SHORT,
111
107
  request_cluster_name=common.JOB_CONTROLLER_NAME,
112
108
  )
113
-
114
-
115
- @router.get('/dashboard')
116
- async def dashboard(request: fastapi.Request,
117
- user_hash: str) -> fastapi.Response:
118
- # TODO(cooperc): Support showing only jobs for a specific user.
119
-
120
- # FIX(zhwu/cooperc/eric): Fix log downloading (assumes global
121
- # /download_log/xx route)
122
-
123
- # Note: before #4717, each user had their own controller, and thus their own
124
- # dashboard. Now, all users share the same controller, so this isn't really
125
- # necessary. TODO(cooperc): clean up.
126
-
127
- # TODO: Put this in an executor to avoid blocking the main server thread.
128
- # It can take a long time if it needs to check the controller status.
129
-
130
- # Find the port for the dashboard of the user
131
- os.environ[constants.USER_ID_ENV_VAR] = user_hash
132
- server_common.reload_for_new_request(client_entrypoint=None,
133
- client_command=None,
134
- using_remote_api_server=False)
135
- logger.info(f'Starting dashboard for user hash: {user_hash}')
136
-
137
- with dashboard_utils.get_dashboard_lock_for_user(user_hash):
138
- max_retries = 3
139
- for attempt in range(max_retries):
140
- port, pid = dashboard_utils.get_dashboard_session(user_hash)
141
- if port == 0 or attempt > 0:
142
- # Let the client know that we are waiting for starting the
143
- # dashboard.
144
- try:
145
- port, pid = core.start_dashboard_forwarding()
146
- except Exception as e: # pylint: disable=broad-except
147
- # We catch all exceptions to gracefully handle unknown
148
- # errors and raise an HTTPException to the client.
149
- msg = (
150
- 'Dashboard failed to start: '
151
- f'{common_utils.format_exception(e, use_bracket=True)}')
152
- logger.error(msg)
153
- raise fastapi.HTTPException(status_code=503, detail=msg)
154
- dashboard_utils.add_dashboard_session(user_hash, port, pid)
155
-
156
- # Assuming the dashboard is forwarded to localhost on the API server
157
- dashboard_url = f'http://localhost:{port}'
158
- try:
159
- # Ping the dashboard to check if it's still running
160
- async with httpx.AsyncClient() as client:
161
- response = await client.request('GET',
162
- dashboard_url,
163
- timeout=5)
164
- if response.is_success:
165
- break # Connection successful, proceed with the request
166
- # Raise an HTTPException here which will be caught by the
167
- # following except block to retry with new connection
168
- response.raise_for_status()
169
- except Exception as e: # pylint: disable=broad-except
170
- # We catch all exceptions to gracefully handle unknown
171
- # errors and retry or raise an HTTPException to the client.
172
- # Assume an exception indicates that the dashboard connection
173
- # is stale - remove it so that a new one is created.
174
- dashboard_utils.remove_dashboard_session(user_hash)
175
- msg = (
176
- f'Dashboard connection attempt {attempt + 1} failed with '
177
- f'{common_utils.format_exception(e, use_bracket=True)}')
178
- logger.info(msg)
179
- if attempt == max_retries - 1:
180
- raise fastapi.HTTPException(status_code=503, detail=msg)
181
-
182
- # Create a client session to forward the request
183
- try:
184
- async with httpx.AsyncClient() as client:
185
- # Make the request and get the response
186
- response = await client.request(
187
- method='GET',
188
- url=f'{dashboard_url}',
189
- headers=request.headers.raw,
190
- )
191
-
192
- # Create a new response with the content already read
193
- content = await response.aread()
194
- return fastapi.Response(
195
- content=content,
196
- status_code=response.status_code,
197
- headers=dict(response.headers),
198
- media_type=response.headers.get('content-type'))
199
- except Exception as e:
200
- msg = (f'Failed to forward request to dashboard: '
201
- f'{common_utils.format_exception(e, use_bracket=True)}')
202
- logger.error(msg)
203
- raise fastapi.HTTPException(status_code=502, detail=msg)
sky/jobs/utils.py CHANGED
@@ -1025,7 +1025,8 @@ def load_managed_job_queue(payload: str) -> List[Dict[str, Any]]:
1025
1025
  if 'user_hash' in job and job['user_hash'] is not None:
1026
1026
  # Skip jobs that do not have user_hash info.
1027
1027
  # TODO(cooperc): Remove check before 0.12.0.
1028
- job['user_name'] = global_user_state.get_user(job['user_hash']).name
1028
+ user = global_user_state.get_user(job['user_hash'])
1029
+ job['user_name'] = user.name if user is not None else None
1029
1030
  return jobs
1030
1031
 
1031
1032
 
sky/models.py CHANGED
@@ -2,8 +2,13 @@
2
2
 
3
3
  import collections
4
4
  import dataclasses
5
+ import getpass
6
+ import os
5
7
  from typing import Any, Dict, Optional
6
8
 
9
+ from sky.skylet import constants
10
+ from sky.utils import common_utils
11
+
7
12
 
8
13
  @dataclasses.dataclass
9
14
  class User:
@@ -16,6 +21,19 @@ class User:
16
21
  def to_dict(self) -> Dict[str, Any]:
17
22
  return {'id': self.id, 'name': self.name}
18
23
 
24
+ def to_env_vars(self) -> Dict[str, Any]:
25
+ return {
26
+ constants.USER_ID_ENV_VAR: self.id,
27
+ constants.USER_ENV_VAR: self.name,
28
+ }
29
+
30
+ @classmethod
31
+ def get_current_user(cls) -> 'User':
32
+ """Returns the current user."""
33
+ user_name = os.getenv(constants.USER_ENV_VAR, getpass.getuser())
34
+ user_hash = common_utils.get_user_hash()
35
+ return User(id=user_hash, name=user_name)
36
+
19
37
 
20
38
  RealtimeGpuAvailability = collections.namedtuple(
21
39
  'RealtimeGpuAvailability', ['gpu', 'counts', 'capacity', 'available'])
sky/serve/server/core.py CHANGED
@@ -221,7 +221,7 @@ def up(
221
221
  # for the first time; otherwise it is a name conflict.
222
222
  # Since the controller may be shared among multiple users, launch the
223
223
  # controller with the API server's user hash.
224
- with common.with_server_user_hash():
224
+ with common.with_server_user():
225
225
  with skypilot_config.local_active_workspace_ctx(
226
226
  constants.SKYPILOT_DEFAULT_WORKSPACE):
227
227
  controller_job_id, controller_handle = execution.launch(
sky/server/common.py CHANGED
@@ -39,6 +39,7 @@ if typing.TYPE_CHECKING:
39
39
  import requests
40
40
 
41
41
  from sky import dag as dag_lib
42
+ from sky import models
42
43
  else:
43
44
  pydantic = adaptors_common.LazyImport('pydantic')
44
45
  requests = adaptors_common.LazyImport('requests')
@@ -710,7 +711,7 @@ def request_body_to_params(body: 'pydantic.BaseModel') -> Dict[str, Any]:
710
711
 
711
712
  def reload_for_new_request(client_entrypoint: Optional[str],
712
713
  client_command: Optional[str],
713
- using_remote_api_server: bool):
714
+ using_remote_api_server: bool, user: 'models.User'):
714
715
  """Reload modules, global variables, and usage message for a new request."""
715
716
  # This should be called first to make sure the logger is up-to-date.
716
717
  sky_logging.reload_logger()
@@ -719,10 +720,11 @@ def reload_for_new_request(client_entrypoint: Optional[str],
719
720
  skypilot_config.safe_reload_config()
720
721
 
721
722
  # Reset the client entrypoint and command for the usage message.
722
- common_utils.set_client_status(
723
+ common_utils.set_request_context(
723
724
  client_entrypoint=client_entrypoint,
724
725
  client_command=client_command,
725
726
  using_remote_api_server=using_remote_api_server,
727
+ user=user,
726
728
  )
727
729
 
728
730
  # Clear cache should be called before reload_logger and usage reset,
sky/server/constants.py CHANGED
@@ -11,8 +11,6 @@ API_VERSION = '9'
11
11
 
12
12
  # Prefix for API request names.
13
13
  REQUEST_NAME_PREFIX = 'sky.'
14
- # The user ID of the SkyPilot system.
15
- SKYPILOT_SYSTEM_USER_ID = 'skypilot-system'
16
14
  # The memory (GB) that SkyPilot tries to not use to prevent OOM.
17
15
  MIN_AVAIL_MEM_GB = 2
18
16
  # Default encoder/decoder handler name.
@@ -53,6 +53,7 @@ from sky.utils import context
53
53
  from sky.utils import context_utils
54
54
  from sky.utils import subprocess_utils
55
55
  from sky.utils import timeline
56
+ from sky.workspaces import core as workspaces_core
56
57
 
57
58
  if typing.TYPE_CHECKING:
58
59
  import types
@@ -229,6 +230,9 @@ def override_request_env_and_config(
229
230
  original_env = os.environ.copy()
230
231
  os.environ.update(request_body.env_vars)
231
232
  # Note: may be overridden by AuthProxyMiddleware.
233
+ # TODO(zhwu): we need to make the entire request a context available to the
234
+ # entire request execution, so that we can access info like user through
235
+ # the execution.
232
236
  user = models.User(id=request_body.env_vars[constants.USER_ID_ENV_VAR],
233
237
  name=request_body.env_vars[constants.USER_ENV_VAR])
234
238
  global_user_state.add_or_update_user(user)
@@ -237,13 +241,17 @@ def override_request_env_and_config(
237
241
  server_common.reload_for_new_request(
238
242
  client_entrypoint=request_body.entrypoint,
239
243
  client_command=request_body.entrypoint_command,
240
- using_remote_api_server=request_body.using_remote_api_server)
244
+ using_remote_api_server=request_body.using_remote_api_server,
245
+ user=user)
241
246
  try:
242
247
  logger.debug(
243
248
  f'override path: {request_body.override_skypilot_config_path}')
244
249
  with skypilot_config.override_skypilot_config(
245
250
  request_body.override_skypilot_config,
246
251
  request_body.override_skypilot_config_path):
252
+ # Rejecting requests to workspaces that the user does not have
253
+ # permission to access.
254
+ workspaces_core.reject_request_for_unauthorized_workspace(user)
247
255
  yield
248
256
  finally:
249
257
  # We need to call the save_timeline() since atexit will not be
@@ -433,7 +441,7 @@ def prepare_request(
433
441
  """Prepare a request for execution."""
434
442
  user_id = request_body.env_vars[constants.USER_ID_ENV_VAR]
435
443
  if is_skypilot_system:
436
- user_id = server_constants.SKYPILOT_SYSTEM_USER_ID
444
+ user_id = constants.SKYPILOT_SYSTEM_USER_ID
437
445
  global_user_state.add_or_update_user(
438
446
  models.User(id=user_id, name=user_id))
439
447
  request = api_requests.Request(request_id=request_id,
@@ -11,7 +11,7 @@ import signal
11
11
  import sqlite3
12
12
  import time
13
13
  import traceback
14
- from typing import Any, Callable, Dict, List, Optional, Tuple
14
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
15
15
 
16
16
  import colorama
17
17
  import filelock
@@ -204,7 +204,8 @@ class Request:
204
204
  """
205
205
  assert isinstance(self.request_body,
206
206
  payloads.RequestBody), (self.name, self.request_body)
207
- user_name = global_user_state.get_user(self.user_id).name
207
+ user = global_user_state.get_user(self.user_id)
208
+ user_name = user.name if user is not None else None
208
209
  return RequestPayload(
209
210
  request_id=self.request_id,
210
211
  name=self.name,
@@ -464,7 +465,7 @@ def request_lock_path(request_id: str) -> str:
464
465
 
465
466
  @contextlib.contextmanager
466
467
  @init_db
467
- def update_request(request_id: str):
468
+ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
468
469
  """Get a SkyPilot API request."""
469
470
  request = _get_request_no_lock(request_id)
470
471
  yield request
sky/server/server.py CHANGED
@@ -49,6 +49,7 @@ from sky.server.requests import preconditions
49
49
  from sky.server.requests import requests as requests_lib
50
50
  from sky.skylet import constants
51
51
  from sky.usage import usage_lib
52
+ from sky.users import permission
52
53
  from sky.users import server as users_rest
53
54
  from sky.utils import admin_policy_utils
54
55
  from sky.utils import common as common_lib
@@ -105,17 +106,21 @@ class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
105
106
  """Middleware to handle RBAC."""
106
107
 
107
108
  async def dispatch(self, request: fastapi.Request, call_next):
108
- if request.url.path.startswith('/dashboard/'):
109
+ # TODO(hailong): should have a list of paths
110
+ # that are not checked for RBAC
111
+ if (request.url.path.startswith('/dashboard/') or
112
+ request.url.path.startswith('/api/')):
109
113
  return await call_next(request)
110
114
 
111
115
  auth_user = _get_auth_user_header(request)
112
116
  if auth_user is None:
113
117
  return await call_next(request)
114
118
 
115
- permission_service = users_rest.permission_service
119
+ permission_service = permission.permission_service
116
120
  # Check the role permission
117
- if permission_service.check_permission(auth_user.id, request.url.path,
118
- request.method):
121
+ if permission_service.check_endpoint_permission(auth_user.id,
122
+ request.url.path,
123
+ request.method):
119
124
  return fastapi.responses.JSONResponse(
120
125
  status_code=403, content={'detail': 'Forbidden'})
121
126
 
@@ -154,9 +159,15 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
154
159
  if auth_user is not None:
155
160
  newly_added = global_user_state.add_or_update_user(auth_user)
156
161
  if newly_added:
157
- users_rest.permission_service.add_user_if_not_exists(
162
+ permission.permission_service.add_user_if_not_exists(
158
163
  auth_user.id)
159
164
 
165
+ # Store user info in request.state for access by GET endpoints
166
+ if auth_user is not None:
167
+ request.state.auth_user = auth_user
168
+ else:
169
+ request.state.auth_user = None
170
+
160
171
  body = await request.body()
161
172
  if auth_user and body:
162
173
  try:
@@ -177,6 +188,12 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
177
188
  f'"env_vars" in request body is not a dictionary '
178
189
  f'for request {request.state.request_id}. '
179
190
  'Skipping user info injection into body.')
191
+ else:
192
+ original_json['env_vars'] = {}
193
+ original_json['env_vars'][
194
+ constants.USER_ID_ENV_VAR] = auth_user.id
195
+ original_json['env_vars'][
196
+ constants.USER_ENV_VAR] = auth_user.name
180
197
  request._body = json.dumps(original_json).encode('utf-8') # pylint: disable=protected-access
181
198
  return await call_next(request)
182
199
 
sky/skylet/constants.py CHANGED
@@ -419,3 +419,6 @@ ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
419
419
  'kubernetes', 'runpod', 'vast', 'vsphere', 'cudo', 'fluidstack',
420
420
  'paperspace', 'do', 'nebius', 'ssh')
421
421
  # END constants used for service catalog.
422
+
423
+ # The user ID of the SkyPilot system.
424
+ SKYPILOT_SYSTEM_USER_ID = 'skypilot-system'
sky/skylet/job_lib.py CHANGED
@@ -794,7 +794,8 @@ def load_job_queue(payload: str) -> List[Dict[str, Any]]:
794
794
  for job in jobs:
795
795
  job['status'] = JobStatus(job['status'])
796
796
  job['user_hash'] = job['username']
797
- job['username'] = global_user_state.get_user(job['user_hash']).name
797
+ user = global_user_state.get_user(job['user_hash'])
798
+ job['username'] = user.name if user is not None else None
798
799
  return jobs
799
800
 
800
801
 
sky/skypilot_config.py CHANGED
@@ -765,6 +765,15 @@ def update_api_server_config_no_lock(config: config_utils.Config) -> None:
765
765
  Args:
766
766
  config: The config to save and sync.
767
767
  """
768
+
769
+ def is_running_pytest() -> bool:
770
+ return 'PYTEST_CURRENT_TEST' in os.environ
771
+
772
+ # Only allow this function to be called by the API Server in production.
773
+ if not is_running_pytest() and os.environ.get(
774
+ constants.ENV_VAR_IS_SKYPILOT_SERVER) is None:
775
+ raise ValueError('This function can only be called by the API Server.')
776
+
768
777
  global_config_path = _resolve_server_config_path()
769
778
  if global_config_path is None:
770
779
  global_config_path = get_user_config_path()
sky/users/model.conf CHANGED
@@ -12,4 +12,4 @@ g = _, _
12
12
  e = some(where (p.eft == allow))
13
13
 
14
14
  [matchers]
15
- m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
15
+ m = (g(r.sub, p.sub)|| p.sub == '*') && r.obj == p.obj && r.act == p.act
sky/users/permission.py CHANGED
@@ -3,7 +3,7 @@ import contextlib
3
3
  import logging
4
4
  import os
5
5
  import threading
6
- from typing import List
6
+ from typing import Generator, List
7
7
 
8
8
  import casbin
9
9
  import filelock
@@ -11,8 +11,11 @@ import sqlalchemy_adapter
11
11
 
12
12
  from sky import global_user_state
13
13
  from sky import sky_logging
14
+ from sky.skylet import constants
14
15
  from sky.users import rbac
15
16
 
17
+ logging.getLogger('casbin.policy').setLevel(sky_logging.ERROR)
18
+ logging.getLogger('casbin.role').setLevel(sky_logging.ERROR)
16
19
  logger = sky_logging.init_logger(__name__)
17
20
 
18
21
  # Filelocks for the policy update.
@@ -38,17 +41,19 @@ class PermissionService:
38
41
  model_path = os.path.join(os.path.dirname(__file__),
39
42
  'model.conf')
40
43
  enforcer = casbin.Enforcer(model_path, adapter)
41
- logging.getLogger('casbin.policy').setLevel(
42
- sky_logging.ERROR)
43
- logging.getLogger('casbin.role').setLevel(sky_logging.ERROR)
44
44
  self.enforcer = enforcer
45
45
  else:
46
46
  self.enforcer = _enforcer_instance.enforcer
47
- self._maybe_initialize_policies()
47
+ with _policy_lock():
48
+ self._maybe_initialize_policies()
48
49
 
49
- def _maybe_initialize_policies(self):
50
+ def _maybe_initialize_policies(self) -> None:
50
51
  """Initialize policies if they don't already exist."""
52
+ # TODO(zhwu): we should avoid running this on client side.
51
53
  logger.debug(f'Initializing policies in process: {os.getpid()}')
54
+ self._load_policy_no_lock()
55
+
56
+ policy_updated = False
52
57
 
53
58
  # Check if policies are already initialized by looking for existing
54
59
  # permission policies in the enforcer
@@ -66,6 +71,17 @@ class PermissionService:
66
71
  expected_policies.append(
67
72
  [role, item['path'], item['method']])
68
73
 
74
+ # Add workspace policy
75
+ workspace_policy_permissions = rbac.get_workspace_policy_permissions()
76
+ logger.debug(f'Workspace policy permissions from config: '
77
+ f'{workspace_policy_permissions}')
78
+
79
+ for workspace_name, users in workspace_policy_permissions.items():
80
+ for user in users:
81
+ expected_policies.append([user, workspace_name, '*'])
82
+ logger.debug(f'Expected workspace policy: user={user}, '
83
+ f'workspace={workspace_name}')
84
+
69
85
  # Check if all expected policies already exist
70
86
  policies_exist = all(
71
87
  any(policy == expected
@@ -86,48 +102,71 @@ class PermissionService:
86
102
  for item in blocklist:
87
103
  path = item['path']
88
104
  method = item['method']
105
+ logger.debug(f'Adding role policy: role={role}, '
106
+ f'path={path}, method={method}')
89
107
  self.enforcer.add_policy(role, path, method)
90
- self.enforcer.save_policy()
108
+ policy_updated = True
109
+
110
+ for workspace_name, users in workspace_policy_permissions.items():
111
+ for user in users:
112
+ logger.debug(f'Initializing workspace policy: user={user}, '
113
+ f'workspace={workspace_name}')
114
+ self.enforcer.add_policy(user, workspace_name, '*')
115
+ policy_updated = True
116
+ logger.debug('Policies initialized successfully')
91
117
  else:
92
118
  logger.debug('Policies already exist, skipping initialization')
93
119
 
94
120
  # Always ensure users have default roles (this is idempotent)
95
121
  all_users = global_user_state.get_all_users()
96
- for user in all_users:
97
- self.add_user_if_not_exists(user.id)
122
+ for existing_user in all_users:
123
+ user_added = self._add_user_if_not_exists_no_lock(existing_user.id)
124
+ policy_updated = policy_updated or user_added
125
+
126
+ if policy_updated:
127
+ self.enforcer.save_policy()
98
128
 
99
- def add_user_if_not_exists(self, user: str) -> None:
129
+ def add_user_if_not_exists(self, user_id: str) -> None:
100
130
  """Add user role relationship."""
101
131
  with _policy_lock():
102
- user_roles = self.enforcer.get_roles_for_user(user)
103
- if not user_roles:
104
- logger.info(f'User {user} has no roles, adding'
105
- f' default role {rbac.get_default_role()}')
106
- self.enforcer.add_grouping_policy(user, rbac.get_default_role())
107
- self.enforcer.save_policy()
108
-
109
- def update_role(self, user: str, new_role: str):
132
+ self._add_user_if_not_exists_no_lock(user_id)
133
+
134
+ def _add_user_if_not_exists_no_lock(self, user_id: str) -> bool:
135
+ """Add user role relationship without lock.
136
+
137
+ Returns:
138
+ True if the user was added, False otherwise.
139
+ """
140
+ user_roles = self.enforcer.get_roles_for_user(user_id)
141
+ if not user_roles:
142
+ logger.info(f'User {user_id} has no roles, adding'
143
+ f' default role {rbac.get_default_role()}')
144
+ self.enforcer.add_grouping_policy(user_id, rbac.get_default_role())
145
+ return True
146
+ return False
147
+
148
+ def update_role(self, user_id: str, new_role: str) -> None:
110
149
  """Update user role relationship."""
111
150
  with _policy_lock():
112
151
  # Get current roles
113
152
  self._load_policy_no_lock()
114
153
  # Avoid calling get_user_roles, as it will require the lock.
115
- current_roles = self.enforcer.get_roles_for_user(user)
154
+ current_roles = self.enforcer.get_roles_for_user(user_id)
116
155
  if not current_roles:
117
- logger.warning(f'User {user} has no roles')
156
+ logger.warning(f'User {user_id} has no roles')
118
157
  else:
119
158
  # TODO(hailong): how to handle multiple roles?
120
159
  current_role = current_roles[0]
121
160
  if current_role == new_role:
122
- logger.info(f'User {user} already has role {new_role}')
161
+ logger.info(f'User {user_id} already has role {new_role}')
123
162
  return
124
- self.enforcer.remove_grouping_policy(user, current_role)
163
+ self.enforcer.remove_grouping_policy(user_id, current_role)
125
164
 
126
165
  # Update user role
127
- self.enforcer.add_grouping_policy(user, new_role)
166
+ self.enforcer.add_grouping_policy(user_id, new_role)
128
167
  self.enforcer.save_policy()
129
168
 
130
- def get_user_roles(self, user: str) -> List[str]:
169
+ def get_user_roles(self, user_id: str) -> List[str]:
131
170
  """Get all roles for a user.
132
171
 
133
172
  This method returns all roles that the user has, including inherited
@@ -140,10 +179,11 @@ class PermissionService:
140
179
  Returns:
141
180
  A list of role names that the user has.
142
181
  """
143
- self._load_policy()
144
- return self.enforcer.get_roles_for_user(user)
182
+ self._load_policy_no_lock()
183
+ return self.enforcer.get_roles_for_user(user_id)
145
184
 
146
- def check_permission(self, user: str, path: str, method: str) -> bool:
185
+ def check_endpoint_permission(self, user_id: str, path: str,
186
+ method: str) -> bool:
147
187
  """Check permission."""
148
188
  # We intentionally don't load the policy here, as it is a hot path, and
149
189
  # we don't support updating the policy.
@@ -151,28 +191,105 @@ class PermissionService:
151
191
  # it is a hot path in every request. It is ok to have a stale policy,
152
192
  # as long as it is eventually consistent.
153
193
  # self._load_policy_no_lock()
154
- return self.enforcer.enforce(user, path, method)
194
+ return self.enforcer.enforce(user_id, path, method)
155
195
 
156
196
  def _load_policy_no_lock(self):
157
197
  """Load policy from storage."""
158
198
  self.enforcer.load_policy()
159
199
 
160
- def _load_policy(self):
200
+ def load_policy(self):
161
201
  """Load policy from storage with lock."""
162
202
  with _policy_lock():
163
203
  self._load_policy_no_lock()
164
204
 
205
+ def check_workspace_permission(self, user_id: str,
206
+ workspace_name: str) -> bool:
207
+ """Check workspace permission.
208
+
209
+ This method checks if a user has permission to access a specific
210
+ workspace.
211
+
212
+ For private workspaces, the user must have explicit permission.
213
+
214
+ For public workspaces, the permission is granted via a wildcard policy
215
+ ('*').
216
+ """
217
+ if os.getenv(constants.ENV_VAR_IS_SKYPILOT_SERVER) is None:
218
+ # When it is not on API server, we allow all users to access all
219
+ # workspaces, as the workspace check has been done on API server.
220
+ return True
221
+ role = self.get_user_roles(user_id)
222
+ if rbac.RoleName.ADMIN.value in role:
223
+ return True
224
+ # The Casbin model matcher already handles the wildcard '*' case:
225
+ # m = (g(r.sub, p.sub)|| p.sub == '*') && r.obj == p.obj &&
226
+ # r.act == p.act
227
+ # This means if there's a policy ('*', workspace_name, '*'), it will
228
+ # match any user
229
+ result = self.enforcer.enforce(user_id, workspace_name, '*')
230
+ logger.debug(f'Workspace permission check: user={user_id}, '
231
+ f'workspace={workspace_name}, result={result}')
232
+ return result
233
+
234
+ def add_workspace_policy(self, workspace_name: str,
235
+ users: List[str]) -> None:
236
+ """Add workspace policy.
237
+
238
+ Args:
239
+ workspace_name: Name of the workspace
240
+ users: List of user IDs that should have access.
241
+ For public workspaces, this should be ['*'].
242
+ For private workspaces, this should be specific user IDs.
243
+ """
244
+ with _policy_lock():
245
+ for user in users:
246
+ logger.debug(f'Adding workspace policy: user={user}, '
247
+ f'workspace={workspace_name}')
248
+ self.enforcer.add_policy(user, workspace_name, '*')
249
+ self.enforcer.save_policy()
250
+
251
+ def update_workspace_policy(self, workspace_name: str,
252
+ users: List[str]) -> None:
253
+ """Update workspace policy.
254
+
255
+ Args:
256
+ workspace_name: Name of the workspace
257
+ users: List of user IDs that should have access.
258
+ For public workspaces, this should be ['*'].
259
+ For private workspaces, this should be specific user IDs.
260
+ """
261
+ with _policy_lock():
262
+ self._load_policy_no_lock()
263
+ # Remove all existing policies for this workspace
264
+ self.enforcer.remove_filtered_policy(1, workspace_name)
265
+ # Add new policies
266
+ for user in users:
267
+ logger.debug(f'Updating workspace policy: user={user}, '
268
+ f'workspace={workspace_name}')
269
+ self.enforcer.add_policy(user, workspace_name, '*')
270
+ self.enforcer.save_policy()
271
+
272
+ def remove_workspace_policy(self, workspace_name: str) -> None:
273
+ """Remove workspace policy."""
274
+ with _policy_lock():
275
+ self.enforcer.remove_filtered_policy(1, workspace_name)
276
+ self.enforcer.save_policy()
277
+
165
278
 
166
279
  @contextlib.contextmanager
167
- def _policy_lock():
280
+ def _policy_lock() -> Generator[None, None, None]:
168
281
  """Context manager for policy update lock."""
169
282
  try:
170
283
  with filelock.FileLock(POLICY_UPDATE_LOCK_PATH,
171
284
  POLICY_UPDATE_LOCK_TIMEOUT_SECONDS):
172
285
  yield
173
286
  except filelock.Timeout as e:
174
- raise RuntimeError(f'Failed to load policy due to a timeout '
287
+ raise RuntimeError(f'Failed to reload policy due to a timeout '
175
288
  f'when trying to acquire the lock at '
176
289
  f'{POLICY_UPDATE_LOCK_PATH}. '
177
290
  'Please try again or manually remove the lock '
178
291
  f'file if you believe it is stale.') from e
292
+
293
+
294
+ # Singleton instance of PermissionService for other modules to use.
295
+ permission_service = PermissionService()