skypilot-nightly 1.0.0.dev20250617__py3-none-any.whl → 1.0.0.dev20250618__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +7 -0
  3. sky/backends/cloud_vm_ray_backend.py +48 -36
  4. sky/cli.py +5 -5729
  5. sky/client/cli.py +11 -2
  6. sky/client/sdk.py +22 -2
  7. sky/clouds/kubernetes.py +5 -0
  8. sky/dashboard/out/404.html +1 -1
  9. sky/dashboard/out/_next/static/{vA3PPpkBwpRTRNBHFYAw_ → LRpGymRCqq-feuFyoWz4m}/_buildManifest.js +1 -1
  10. sky/dashboard/out/_next/static/chunks/641.c8e452bc5070a630.js +1 -0
  11. sky/dashboard/out/_next/static/chunks/984.ae8c08791d274ca0.js +50 -0
  12. sky/dashboard/out/_next/static/chunks/pages/users-928edf039219e47b.js +1 -0
  13. sky/dashboard/out/_next/static/chunks/webpack-ebc2404fd6ce581c.js +1 -0
  14. sky/dashboard/out/_next/static/css/6c12ecc3bd2239b6.css +3 -0
  15. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  16. sky/dashboard/out/clusters/[cluster].html +1 -1
  17. sky/dashboard/out/clusters.html +1 -1
  18. sky/dashboard/out/config.html +1 -1
  19. sky/dashboard/out/index.html +1 -1
  20. sky/dashboard/out/infra/[context].html +1 -1
  21. sky/dashboard/out/infra.html +1 -1
  22. sky/dashboard/out/jobs/[job].html +1 -1
  23. sky/dashboard/out/jobs.html +1 -1
  24. sky/dashboard/out/users.html +1 -1
  25. sky/dashboard/out/workspace/new.html +1 -1
  26. sky/dashboard/out/workspaces/[name].html +1 -1
  27. sky/dashboard/out/workspaces.html +1 -1
  28. sky/global_user_state.py +50 -11
  29. sky/logs/__init__.py +17 -0
  30. sky/logs/agent.py +73 -0
  31. sky/logs/gcp.py +91 -0
  32. sky/models.py +1 -0
  33. sky/provision/instance_setup.py +35 -0
  34. sky/provision/provisioner.py +11 -0
  35. sky/server/common.py +21 -9
  36. sky/server/requests/payloads.py +19 -1
  37. sky/server/server.py +121 -29
  38. sky/setup_files/dependencies.py +11 -1
  39. sky/skylet/constants.py +9 -1
  40. sky/skylet/job_lib.py +75 -19
  41. sky/templates/kubernetes-ray.yml.j2 +9 -0
  42. sky/users/permission.py +49 -19
  43. sky/users/rbac.py +10 -1
  44. sky/users/server.py +274 -9
  45. sky/utils/schemas.py +40 -0
  46. {skypilot_nightly-1.0.0.dev20250617.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/METADATA +9 -1
  47. {skypilot_nightly-1.0.0.dev20250617.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/RECORD +58 -54
  48. sky/dashboard/out/_next/static/chunks/600.bd2ed8c076b720ec.js +0 -16
  49. sky/dashboard/out/_next/static/chunks/pages/users-c69ffcab9d6e5269.js +0 -1
  50. sky/dashboard/out/_next/static/chunks/webpack-1b69b196a4dbffef.js +0 -1
  51. sky/dashboard/out/_next/static/css/8e97adcaacc15293.css +0 -3
  52. /sky/dashboard/out/_next/static/{vA3PPpkBwpRTRNBHFYAw_ → LRpGymRCqq-feuFyoWz4m}/_ssgManifest.js +0 -0
  53. /sky/dashboard/out/_next/static/chunks/{37-824c707421f6f003.js → 37-3a4d77ad62932eaf.js} +0 -0
  54. /sky/dashboard/out/_next/static/chunks/{843-ab9c4f609239155f.js → 843-b3040e493f6e7947.js} +0 -0
  55. /sky/dashboard/out/_next/static/chunks/{938-385d190b95815e11.js → 938-1493ac755eadeb35.js} +0 -0
  56. /sky/dashboard/out/_next/static/chunks/{973-c807fc34f09c7df3.js → 973-db3c97c2bfbceb65.js} +0 -0
  57. /sky/dashboard/out/_next/static/chunks/pages/{_app-32b2caae3445bf3b.js → _app-c416e87d5c2715cf.js} +0 -0
  58. /sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-c8c2191328532b7d.js → [name]-c4ff1ec05e2f3daf.js} +0 -0
  59. {skypilot_nightly-1.0.0.dev20250617.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/WHEEL +0 -0
  60. {skypilot_nightly-1.0.0.dev20250617.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/entry_points.txt +0 -0
  61. {skypilot_nightly-1.0.0.dev20250617.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/licenses/LICENSE +0 -0
  62. {skypilot_nightly-1.0.0.dev20250617.dist-info → skypilot_nightly-1.0.0.dev20250618.dist-info}/top_level.txt +0 -0
sky/server/server.py CHANGED
@@ -23,6 +23,7 @@ import zipfile
23
23
  import aiofiles
24
24
  import fastapi
25
25
  from fastapi.middleware import cors
26
+ from passlib.hash import apr_md5_crypt
26
27
  import starlette.middleware.base
27
28
 
28
29
  import sky
@@ -102,6 +103,74 @@ logger = sky_logging.init_logger(__name__)
102
103
  # response will block other requests from being processed.
103
104
 
104
105
 
106
+ def _basic_auth_401_response(content: str):
107
+ """Return a 401 response with basic auth realm."""
108
+ return fastapi.responses.JSONResponse(
109
+ status_code=401,
110
+ headers={'WWW-Authenticate': 'Basic realm=\"SkyPilot\"'},
111
+ content=content)
112
+
113
+
114
+ # TODO(hailong): Remove this function and use request.state.auth_user instead.
115
+ async def _override_user_info_in_request_body(request: fastapi.Request,
116
+ auth_user: Optional[models.User]):
117
+ body = await request.body()
118
+ if auth_user and body:
119
+ try:
120
+ original_json = await request.json()
121
+ except json.JSONDecodeError as e:
122
+ logger.error(f'Error parsing request JSON: {e}')
123
+ else:
124
+ logger.debug(f'Overriding user for {request.state.request_id}: '
125
+ f'{auth_user.name}, {auth_user.id}')
126
+ if 'env_vars' in original_json:
127
+ if isinstance(original_json.get('env_vars'), dict):
128
+ original_json['env_vars'][
129
+ constants.USER_ID_ENV_VAR] = auth_user.id
130
+ original_json['env_vars'][
131
+ constants.USER_ENV_VAR] = auth_user.name
132
+ else:
133
+ logger.warning(
134
+ f'"env_vars" in request body is not a dictionary '
135
+ f'for request {request.state.request_id}. '
136
+ 'Skipping user info injection into body.')
137
+ else:
138
+ original_json['env_vars'] = {}
139
+ original_json['env_vars'][
140
+ constants.USER_ID_ENV_VAR] = auth_user.id
141
+ original_json['env_vars'][
142
+ constants.USER_ENV_VAR] = auth_user.name
143
+ request._body = json.dumps(original_json).encode('utf-8') # pylint: disable=protected-access
144
+
145
+
146
+ def _try_set_basic_auth_user(request: fastapi.Request):
147
+ auth_header = request.headers.get('authorization')
148
+ if not auth_header or not auth_header.lower().startswith('basic '):
149
+ return
150
+
151
+ # Check username and password
152
+ encoded = auth_header.split(' ', 1)[1]
153
+ try:
154
+ decoded = base64.b64decode(encoded).decode()
155
+ username, password = decoded.split(':', 1)
156
+ except Exception: # pylint: disable=broad-except
157
+ return
158
+
159
+ users = global_user_state.get_user_by_name(username)
160
+ if not users:
161
+ return
162
+
163
+ for user in users:
164
+ if not user.name or not user.password:
165
+ continue
166
+ username_encoded = username.encode('utf8')
167
+ db_username_encoded = user.name.encode('utf8')
168
+ if (username_encoded == db_username_encoded and
169
+ apr_md5_crypt.verify(password, user.password)):
170
+ request.state.auth_user = user
171
+ break
172
+
173
+
105
174
  class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
106
175
  """Middleware to handle RBAC."""
107
176
 
@@ -112,7 +181,7 @@ class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
112
181
  request.url.path.startswith('/api/')):
113
182
  return await call_next(request)
114
183
 
115
- auth_user = _get_auth_user_header(request)
184
+ auth_user = request.state.auth_user
116
185
  if auth_user is None:
117
186
  return await call_next(request)
118
187
 
@@ -149,6 +218,50 @@ def _get_auth_user_header(request: fastapi.Request) -> Optional[models.User]:
149
218
  return models.User(id=user_hash, name=user_name)
150
219
 
151
220
 
221
+ class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
222
+ """Middleware to handle HTTP Basic Auth."""
223
+
224
+ async def dispatch(self, request: fastapi.Request, call_next):
225
+ if request.url.path.startswith('/api/'):
226
+ # Try to set the auth user from the basic auth header so the
227
+ # following endpoint handlers can leverage the auth_user info
228
+ _try_set_basic_auth_user(request)
229
+ return await call_next(request)
230
+
231
+ auth_header = request.headers.get('authorization')
232
+ if not auth_header or not auth_header.lower().startswith('basic '):
233
+ return _basic_auth_401_response('Invalid basic auth')
234
+
235
+ # Check username and password
236
+ encoded = auth_header.split(' ', 1)[1]
237
+ try:
238
+ decoded = base64.b64decode(encoded).decode()
239
+ username, password = decoded.split(':', 1)
240
+ except Exception: # pylint: disable=broad-except
241
+ return _basic_auth_401_response('Invalid basic auth')
242
+
243
+ users = global_user_state.get_user_by_name(username)
244
+ if not users:
245
+ return _basic_auth_401_response('Invalid credentials')
246
+
247
+ valid_user = False
248
+ for user in users:
249
+ if not user.name or not user.password:
250
+ continue
251
+ username_encoded = username.encode('utf8')
252
+ db_username_encoded = user.name.encode('utf8')
253
+ if (username_encoded == db_username_encoded and
254
+ apr_md5_crypt.verify(password, user.password)):
255
+ valid_user = True
256
+ request.state.auth_user = user
257
+ await _override_user_info_in_request_body(request, user)
258
+ break
259
+ if not valid_user:
260
+ return _basic_auth_401_response('Invalid credentials')
261
+
262
+ return await call_next(request)
263
+
264
+
152
265
  class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
153
266
  """Middleware to handle auth proxy."""
154
267
 
@@ -168,33 +281,7 @@ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
168
281
  else:
169
282
  request.state.auth_user = None
170
283
 
171
- body = await request.body()
172
- if auth_user and body:
173
- try:
174
- original_json = await request.json()
175
- except json.JSONDecodeError as e:
176
- logger.error(f'Error parsing request JSON: {e}')
177
- else:
178
- logger.debug(f'Overriding user for {request.state.request_id}: '
179
- f'{auth_user.name}, {auth_user.id}')
180
- if 'env_vars' in original_json:
181
- if isinstance(original_json.get('env_vars'), dict):
182
- original_json['env_vars'][
183
- constants.USER_ID_ENV_VAR] = auth_user.id
184
- original_json['env_vars'][
185
- constants.USER_ENV_VAR] = auth_user.name
186
- else:
187
- logger.warning(
188
- f'"env_vars" in request body is not a dictionary '
189
- f'for request {request.state.request_id}. '
190
- 'Skipping user info injection into body.')
191
- else:
192
- original_json['env_vars'] = {}
193
- original_json['env_vars'][
194
- constants.USER_ID_ENV_VAR] = auth_user.id
195
- original_json['env_vars'][
196
- constants.USER_ENV_VAR] = auth_user.name
197
- request._body = json.dumps(original_json).encode('utf-8') # pylint: disable=protected-access
284
+ await _override_user_info_in_request_body(request, auth_user)
198
285
  return await call_next(request)
199
286
 
200
287
 
@@ -306,6 +393,9 @@ app.add_middleware(
306
393
  allow_headers=['*'],
307
394
  # TODO(syang): remove X-Request-ID when v0.10.0 is released.
308
395
  expose_headers=['X-Request-ID', 'X-Skypilot-Request-ID'])
396
+ enable_basic_auth = os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH, 'false')
397
+ if str(enable_basic_auth).lower() == 'true':
398
+ app.add_middleware(BasicAuthMiddleware)
309
399
  app.add_middleware(AuthProxyMiddleware)
310
400
  app.add_middleware(RequestIDMiddleware)
311
401
  app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
@@ -1232,7 +1322,7 @@ async def health(request: fastapi.Request) -> Dict[str, Any]:
1232
1322
  disk, which can be used to warn about restarting the API server
1233
1323
  - commit: str; The commit hash of SkyPilot used for API server.
1234
1324
  """
1235
- user = _get_auth_user_header(request)
1325
+ user = request.state.auth_user
1236
1326
  return {
1237
1327
  'status': common.ApiServerStatus.HEALTHY.value,
1238
1328
  'api_version': server_constants.API_VERSION,
@@ -1240,6 +1330,8 @@ async def health(request: fastapi.Request) -> Dict[str, Any]:
1240
1330
  'version_on_disk': common.get_skypilot_version_on_disk(),
1241
1331
  'commit': sky.__commit__,
1242
1332
  'user': user.to_dict() if user is not None else None,
1333
+ 'basic_auth_enabled': os.environ.get(
1334
+ constants.ENV_VAR_ENABLE_BASIC_AUTH, 'false').lower() == 'true',
1243
1335
  }
1244
1336
 
1245
1337
 
sky/setup_files/dependencies.py CHANGED
@@ -58,8 +58,17 @@ install_requires = [
58
58
  'setproctitle',
59
59
  'sqlalchemy',
60
60
  'psycopg2-binary',
61
+ # TODO(hailong): These three dependencies should be removed after we make
62
+ # the client-side actually not importing them.
61
63
  'casbin',
62
64
  'sqlalchemy_adapter',
65
+ 'passlib',
66
+ ]
67
+
68
+ server_dependencies = [
69
+ 'casbin',
70
+ 'sqlalchemy_adapter',
71
+ 'passlib',
63
72
  ]
64
73
 
65
74
  local_ray = [
@@ -162,7 +171,8 @@ extras_require: Dict[str, List[str]] = {
162
171
  'nebius': [
163
172
  'nebius>=0.2.0',
164
173
  ] + aws_dependencies,
165
- 'hyperbolic': [] # No dependencies needed for hyperbolic
174
+ 'hyperbolic': [], # No dependencies needed for hyperbolic
175
+ 'server': server_dependencies,
166
176
  }
167
177
 
168
178
  # Nebius needs python3.10. If python 3.9 [all] will not install nebius
sky/skylet/constants.py CHANGED
@@ -89,7 +89,7 @@ TASK_ID_LIST_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}TASK_IDS'
89
89
  # cluster yaml is updated.
90
90
  #
91
91
  # TODO(zongheng,zhanghao): make the upgrading of skylet automatic?
92
- SKYLET_VERSION = '13'
92
+ SKYLET_VERSION = '14'
93
93
  # The version of the lib files that skylet/jobs use. Whenever there is an API
94
94
  # change for the job_lib or log_lib, we need to bump this version, so that the
95
95
  # user can be notified to update their SkyPilot version on the remote cluster.
@@ -411,6 +411,11 @@ SKY_USER_FILE_PATH = '~/.sky/generated'
411
411
  # Environment variable that is set to 'true' if this is a skypilot server.
412
412
  ENV_VAR_IS_SKYPILOT_SERVER = 'IS_SKYPILOT_SERVER'
413
413
 
414
+ # Environment variable that is set to 'true' if basic
415
+ # authentication is enabled in the API server.
416
+ ENV_VAR_ENABLE_BASIC_AUTH = 'ENABLE_BASIC_AUTH'
417
+ SKYPILOT_INITIAL_BASIC_AUTH = 'SKYPILOT_INITIAL_BASIC_AUTH'
418
+
414
419
  SKYPILOT_DEFAULT_WORKSPACE = 'default'
415
420
 
416
421
  # BEGIN constants used for service catalog.
@@ -426,6 +431,9 @@ ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
426
431
  # The user ID of the SkyPilot system.
427
432
  SKYPILOT_SYSTEM_USER_ID = 'skypilot-system'
428
433
 
434
+ # The directory to store the logging configuration.
435
+ LOGGING_CONFIG_DIR = '~/.sky/logging'
436
+
429
437
  # Resources constants
430
438
  TIME_UNITS = {
431
439
  's': 1 / 60,
sky/skylet/job_lib.py CHANGED
@@ -14,7 +14,7 @@ import sqlite3
14
14
  import threading
15
15
  import time
16
16
  import typing
17
- from typing import Any, Dict, List, Optional, Sequence
17
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
18
18
 
19
19
  import colorama
20
20
  import filelock
@@ -62,6 +62,7 @@ class JobInfoLoc(enum.IntEnum):
62
62
  END_AT = 7
63
63
  RESOURCES = 8
64
64
  PID = 9
65
+ LOG_PATH = 10
65
66
 
66
67
 
67
68
  def create_table(cursor, conn):
@@ -101,7 +102,8 @@ def create_table(cursor, conn):
101
102
  start_at FLOAT DEFAULT -1,
102
103
  end_at FLOAT DEFAULT NULL,
103
104
  resources TEXT DEFAULT NULL,
104
- pid INTEGER DEFAULT -1)""")
105
+ pid INTEGER DEFAULT -1,
106
+ log_dir TEXT DEFAULT NULL)""")
105
107
 
106
108
  cursor.execute("""CREATE TABLE IF NOT EXISTS pending_jobs(
107
109
  job_id INTEGER,
@@ -114,6 +116,8 @@ def create_table(cursor, conn):
114
116
  db_utils.add_column_to_table(cursor, conn, 'jobs', 'resources', 'TEXT')
115
117
  db_utils.add_column_to_table(cursor, conn, 'jobs', 'pid',
116
118
  'INTEGER DEFAULT -1')
119
+ db_utils.add_column_to_table(cursor, conn, 'jobs', 'log_dir',
120
+ 'TEXT DEFAULT NULL')
117
121
  conn.commit()
118
122
 
119
123
 
@@ -335,13 +339,13 @@ def make_job_command_with_user_switching(username: str,
335
339
 
336
340
  @init_db
337
341
  def add_job(job_name: str, username: str, run_timestamp: str,
338
- resources_str: str) -> int:
342
+ resources_str: str) -> Tuple[int, str]:
339
343
  """Atomically reserve the next available job id for the user."""
340
344
  assert _DB is not None
341
345
  job_submitted_at = time.time()
342
346
  # job_id will autoincrement with the null value
343
347
  _DB.cursor.execute(
344
- 'INSERT INTO jobs VALUES (null, ?, ?, ?, ?, ?, ?, null, ?, 0)',
348
+ 'INSERT INTO jobs VALUES (null, ?, ?, ?, ?, ?, ?, null, ?, 0, null)',
345
349
  (job_name, username, job_submitted_at, JobStatus.INIT.value,
346
350
  run_timestamp, None, resources_str))
347
351
  _DB.conn.commit()
@@ -350,7 +354,41 @@ def add_job(job_name: str, username: str, run_timestamp: str,
350
354
  for row in rows:
351
355
  job_id = row[0]
352
356
  assert job_id is not None
353
- return job_id
357
+ log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, f'{job_id}-{job_name}')
358
+ set_log_dir_no_lock(job_id, log_dir)
359
+ return job_id, log_dir
360
+
361
+
362
+ @init_db
363
+ def set_log_dir_no_lock(job_id: int, log_dir: str) -> None:
364
+ """Set the log directory for the job.
365
+
366
+ We persist the log directory for the job to allow changing the log directory
367
+ generation logic over versions.
368
+
369
+ Args:
370
+ job_id: The ID of the job.
371
+ log_dir: The log directory for the job.
372
+ """
373
+ assert _DB is not None
374
+ _DB.cursor.execute('UPDATE jobs SET log_dir=(?) WHERE job_id=(?)',
375
+ (log_dir, job_id))
376
+ _DB.conn.commit()
377
+
378
+
379
+ @init_db
380
+ def get_log_dir_for_job(job_id: int) -> Optional[str]:
381
+ """Get the log directory for the job.
382
+
383
+ Args:
384
+ job_id: The ID of the job.
385
+ """
386
+ assert _DB is not None
387
+ rows = _DB.cursor.execute('SELECT log_dir FROM jobs WHERE job_id=(?)',
388
+ (job_id,))
389
+ for row in rows:
390
+ return row[0]
391
+ return None
354
392
 
355
393
 
356
394
  @init_db
@@ -978,8 +1016,8 @@ def get_run_timestamp(job_id: Optional[int]) -> Optional[str]:
978
1016
 
979
1017
 
980
1018
  @init_db
981
- def run_timestamp_with_globbing_payload(job_ids: List[Optional[str]]) -> str:
982
- """Returns the relative paths to the log files for job with globbing."""
1019
+ def get_log_dir_for_jobs(job_ids: List[Optional[str]]) -> str:
1020
+ """Returns the relative paths to the log files for jobs with globbing."""
983
1021
  assert _DB is not None
984
1022
  query_str = ' OR '.join(['job_id GLOB (?)'] * len(job_ids))
985
1023
  _DB.cursor.execute(
@@ -987,12 +1025,16 @@ def run_timestamp_with_globbing_payload(job_ids: List[Optional[str]]) -> str:
987
1025
  SELECT * FROM jobs
988
1026
  WHERE {query_str}""", job_ids)
989
1027
  rows = _DB.cursor.fetchall()
990
- run_timestamps = {}
1028
+ job_to_dir = {}
991
1029
  for row in rows:
992
1030
  job_id = row[JobInfoLoc.JOB_ID.value]
993
- run_timestamp = row[JobInfoLoc.RUN_TIMESTAMP.value]
994
- run_timestamps[str(job_id)] = run_timestamp
995
- return message_utils.encode_payload(run_timestamps)
1031
+ if row[JobInfoLoc.LOG_PATH.value]:
1032
+ job_to_dir[str(job_id)] = row[JobInfoLoc.LOG_PATH.value]
1033
+ else:
1034
+ run_timestamp = row[JobInfoLoc.RUN_TIMESTAMP.value]
1035
+ job_to_dir[str(job_id)] = os.path.join(constants.SKY_LOGS_DIRECTORY,
1036
+ run_timestamp)
1037
+ return message_utils.encode_payload(job_to_dir)
996
1038
 
997
1039
 
998
1040
  class JobLibCodeGen:
@@ -1024,12 +1066,16 @@ class JobLibCodeGen:
1024
1066
  '\nif int(constants.SKYLET_VERSION) < 9: '
1025
1067
  'raise RuntimeError("SkyPilot runtime is too old, which does not '
1026
1068
  'support submitting jobs.")',
1027
- '\njob_id = job_lib.add_job('
1069
+ '\nresult = job_lib.add_job('
1028
1070
  f'{job_name!r},'
1029
1071
  f'{username!r},'
1030
1072
  f'{run_timestamp!r},'
1031
1073
  f'{resources_str!r})',
1032
- 'print("Job ID: " + str(job_id), flush=True)',
1074
+ ('\nif isinstance(result, tuple):'
1075
+ '\n print("Job ID: " + str(result[0]), flush=True)'
1076
+ '\n print("Log Dir: " + str(result[1]), flush=True)'
1077
+ '\nelse:'
1078
+ '\n print("Job ID: " + str(result), flush=True)'),
1033
1079
  ]
1034
1080
  return cls._build(code)
1035
1081
 
@@ -1098,9 +1144,17 @@ class JobLibCodeGen:
1098
1144
  # We use != instead of is not because 1 is not None will print a warning:
1099
1145
  # <stdin>:1: SyntaxWarning: "is not" with a literal. Did you mean "!="?
1100
1146
  f'job_id = {job_id} if {job_id} != None else job_lib.get_latest_job_id()',
1101
- 'run_timestamp = job_lib.get_run_timestamp(job_id)',
1102
- f'log_dir = None if run_timestamp is None else os.path.join({constants.SKY_LOGS_DIRECTORY!r}, run_timestamp)',
1103
- f'tail_log_kwargs = {{"job_id": job_id, "log_dir": log_dir, "managed_job_id": {managed_job_id!r}, "follow": {follow}}}',
1147
+ # For backward compatibility, use the legacy generation rule for
1148
+ # jobs submitted before 0.11.0.
1149
+ ('log_dir = None\n'
1150
+ 'if hasattr(job_lib, "get_log_dir_for_job"):\n'
1151
+ ' log_dir = job_lib.get_log_dir_for_job(job_id)\n'
1152
+ 'if log_dir is None:\n'
1153
+ ' run_timestamp = job_lib.get_run_timestamp(job_id)\n'
1154
+ f' log_dir = None if run_timestamp is None else os.path.join({constants.SKY_LOGS_DIRECTORY!r}, run_timestamp)'
1155
+ ),
1156
+ # Add a newline to leave the if indent block above.
1157
+ f'\ntail_log_kwargs = {{"job_id": job_id, "log_dir": log_dir, "managed_job_id": {managed_job_id!r}, "follow": {follow}}}',
1104
1158
  f'{_LINUX_NEW_LINE}if getattr(constants, "SKYLET_LIB_VERSION", 1) > 1: tail_log_kwargs["tail"] = {tail}',
1105
1159
  f'{_LINUX_NEW_LINE}log_lib.tail_logs(**tail_log_kwargs)',
1106
1160
  # After tailing, check the job status and exit with appropriate code
@@ -1140,12 +1194,14 @@ class JobLibCodeGen:
1140
1194
  return cls._build(code)
1141
1195
 
1142
1196
  @classmethod
1143
- def get_run_timestamp_with_globbing(cls,
1144
- job_ids: Optional[List[str]]) -> str:
1197
+ def get_log_dirs_for_jobs(cls, job_ids: Optional[List[str]]) -> str:
1145
1198
  code = [
1146
1199
  f'job_ids = {job_ids} if {job_ids} is not None '
1147
1200
  'else [job_lib.get_latest_job_id()]',
1148
- 'log_dirs = job_lib.run_timestamp_with_globbing_payload(job_ids)',
1201
+ # TODO(aylei): backward compatibility, remove after 0.12.0.
1202
+ 'log_dirs = job_lib.get_log_dir_for_jobs(job_ids) if '
1203
+ 'hasattr(job_lib, "get_log_dir_for_jobs") else '
1204
+ 'job_lib.run_timestamp_with_globbing_payload(job_ids)',
1149
1205
  'print(log_dirs, flush=True)',
1150
1206
  ]
1151
1207
  return cls._build(code)
sky/templates/kubernetes-ray.yml.j2 CHANGED
@@ -273,6 +273,15 @@ available_node_types:
273
273
  {% if (k8s_acc_label_key is not none and k8s_acc_label_values is not none) %}
274
274
  skypilot-binpack: "gpu"
275
275
  {% endif %}
276
+ {% if k8s_kueue_local_queue_name %}
277
+ kueue.x-k8s.io/queue-name: {{k8s_kueue_local_queue_name}}
278
+ kueue.x-k8s.io/pod-group-name: {{cluster_name_on_cloud}}
279
+ {% endif %}
280
+ {% if k8s_kueue_local_queue_name %}
281
+ annotations:
282
+ kueue.x-k8s.io/retriable-in-group: "false"
283
+ kueue.x-k8s.io/pod-group-total-count: "{{ num_nodes|string }}"
284
+ {% endif %}
276
285
  spec:
277
286
  # serviceAccountName: skypilot-service-account
278
287
  serviceAccountName: {{k8s_service_account_name}}
sky/users/permission.py CHANGED
@@ -1,8 +1,8 @@
1
1
  """Permission service for SkyPilot API Server."""
2
2
  import contextlib
3
+ import hashlib
3
4
  import logging
4
5
  import os
5
- import threading
6
6
  from typing import Generator, List
7
7
 
8
8
  import casbin
@@ -10,9 +10,11 @@ import filelock
10
10
  import sqlalchemy_adapter
11
11
 
12
12
  from sky import global_user_state
13
+ from sky import models
13
14
  from sky import sky_logging
14
15
  from sky.skylet import constants
15
16
  from sky.users import rbac
17
+ from sky.utils import common_utils
16
18
 
17
19
  logging.getLogger('casbin.policy').setLevel(sky_logging.ERROR)
18
20
  logging.getLogger('casbin.role').setLevel(sky_logging.ERROR)
@@ -23,31 +25,46 @@ POLICY_UPDATE_LOCK_PATH = os.path.expanduser('~/.sky/.policy_update.lock')
23
25
  POLICY_UPDATE_LOCK_TIMEOUT_SECONDS = 20
24
26
 
25
27
  _enforcer_instance = None
26
- _lock = threading.Lock()
27
28
 
28
29
 
29
30
  class PermissionService:
30
31
  """Permission service for SkyPilot API Server."""
31
32
 
32
33
  def __init__(self):
33
- global _enforcer_instance
34
- if _enforcer_instance is None:
35
- # For different threads, we share the same enforcer instance.
36
- with _lock:
37
- if _enforcer_instance is None:
38
- _enforcer_instance = self
39
- engine = global_user_state.initialize_and_get_db()
40
- adapter = sqlalchemy_adapter.Adapter(engine)
41
- model_path = os.path.join(os.path.dirname(__file__),
42
- 'model.conf')
43
- enforcer = casbin.Enforcer(model_path, adapter)
44
- self.enforcer = enforcer
45
- else:
46
- self.enforcer = _enforcer_instance.enforcer
47
- else:
48
- self.enforcer = _enforcer_instance.enforcer
49
34
  with _policy_lock():
50
- self._maybe_initialize_policies()
35
+ global _enforcer_instance
36
+ if _enforcer_instance is None:
37
+ _enforcer_instance = self
38
+ engine = global_user_state.initialize_and_get_db()
39
+ adapter = sqlalchemy_adapter.Adapter(engine)
40
+ model_path = os.path.join(os.path.dirname(__file__),
41
+ 'model.conf')
42
+ enforcer = casbin.Enforcer(model_path, adapter)
43
+ self.enforcer = enforcer
44
+ self._maybe_initialize_policies()
45
+ self._maybe_initialize_basic_auth_user()
46
+ else:
47
+ self.enforcer = _enforcer_instance.enforcer
48
+
49
+ def _maybe_initialize_basic_auth_user(self) -> None:
50
+ """Initialize basic auth user if it is enabled."""
51
+ basic_auth = os.environ.get(constants.SKYPILOT_INITIAL_BASIC_AUTH)
52
+ if not basic_auth:
53
+ return
54
+ username, password = basic_auth.split(':', 1)
55
+ if username and password:
56
+ user_hash = hashlib.md5(
57
+ username.encode()).hexdigest()[:common_utils.USER_HASH_LENGTH]
58
+ user_info = global_user_state.get_user(user_hash)
59
+ if user_info:
60
+ logger.info(f'Basic auth user {username} already exists')
61
+ return
62
+ global_user_state.add_or_update_user(
63
+ models.User(id=user_hash, name=username, password=password))
64
+ self.enforcer.add_grouping_policy(user_hash,
65
+ rbac.RoleName.ADMIN.value)
66
+ self.enforcer.save_policy()
67
+ logger.info(f'Basic auth user {username} initialized')
51
68
 
52
69
  def _maybe_initialize_policies(self) -> None:
53
70
  """Initialize policies if they don't already exist."""
@@ -147,6 +164,19 @@ class PermissionService:
147
164
  return True
148
165
  return False
149
166
 
167
+ def delete_user(self, user_id: str) -> None:
168
+ """Delete user role relationship."""
169
+ with _policy_lock():
170
+ # Get current roles
171
+ self._load_policy_no_lock()
172
+ # Avoid calling get_user_roles, as it will require the lock.
173
+ current_roles = self.enforcer.get_roles_for_user(user_id)
174
+ if not current_roles:
175
+ logger.warning(f'User {user_id} has no roles')
176
+ return
177
+ self.enforcer.remove_grouping_policy(user_id, current_roles[0])
178
+ self.enforcer.save_policy()
179
+
150
180
  def update_role(self, user_id: str, new_role: str) -> None:
151
181
  """Update user role relationship."""
152
182
  with _policy_lock():
sky/users/rbac.py CHANGED
@@ -25,8 +25,17 @@ _DEFAULT_USER_BLOCKLIST = [{
25
25
  'path': '/workspaces/delete',
26
26
  'method': 'POST'
27
27
  }, {
28
- 'path': '/users/update',
28
+ 'path': '/users/delete',
29
29
  'method': 'POST'
30
+ }, {
31
+ 'path': '/users/create',
32
+ 'method': 'POST'
33
+ }, {
34
+ 'path': '/users/import',
35
+ 'method': 'POST'
36
+ }, {
37
+ 'path': '/users/export',
38
+ 'method': 'GET'
30
39
  }]
31
40
 
32
41