skypilot-nightly 1.0.0.dev20250610__py3-none-any.whl → 1.0.0.dev20250611__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/admin_policy.py +132 -6
- sky/benchmark/benchmark_state.py +39 -1
- sky/cli.py +1 -1
- sky/client/cli.py +1 -1
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/chunks/600.15a0009177e86b86.js +16 -0
- sky/dashboard/out/_next/static/chunks/938-ab185187a63f9cdb.js +1 -0
- sky/dashboard/out/_next/static/chunks/{webpack-0574a5a4ba3cf0ac.js → webpack-208a9812ab4f61c9.js} +1 -1
- sky/dashboard/out/_next/static/css/{8b1c8321d4c02372.css → 5d71bfc09f184bab.css} +1 -1
- sky/dashboard/out/_next/static/{4lwUJxN6KwBqUxqO1VccB → zJqasksBQ3HcqMpA2wTUZ}/_buildManifest.js +1 -1
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/jobs/scheduler.py +4 -5
- sky/jobs/state.py +104 -11
- sky/jobs/utils.py +5 -5
- sky/skylet/job_lib.py +95 -40
- sky/users/permission.py +34 -17
- sky/utils/admin_policy_utils.py +32 -13
- sky/utils/schemas.py +11 -3
- {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/METADATA +1 -1
- {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/RECORD +39 -39
- sky/dashboard/out/_next/static/chunks/600.9cc76ec442b22e10.js +0 -16
- sky/dashboard/out/_next/static/chunks/938-a75b7712639298b7.js +0 -1
- /sky/dashboard/out/_next/static/chunks/pages/{_app-4768de0aede04dc9.js → _app-7bbd9d39d6f9a98a.js} +0 -0
- /sky/dashboard/out/_next/static/{4lwUJxN6KwBqUxqO1VccB → zJqasksBQ3HcqMpA2wTUZ}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/top_level.txt +0 -0
sky/skylet/job_lib.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
This is a remote utility module that provides job queue functionality.
|
4
4
|
"""
|
5
5
|
import enum
|
6
|
+
import functools
|
6
7
|
import getpass
|
7
8
|
import json
|
8
9
|
import os
|
@@ -10,6 +11,7 @@ import pathlib
|
|
10
11
|
import shlex
|
11
12
|
import signal
|
12
13
|
import sqlite3
|
14
|
+
import threading
|
13
15
|
import time
|
14
16
|
import typing
|
15
17
|
from typing import Any, Dict, List, Optional, Sequence
|
@@ -119,9 +121,25 @@ def create_table(cursor, conn):
|
|
119
121
|
conn.commit()
|
120
122
|
|
121
123
|
|
122
|
-
_DB =
|
123
|
-
|
124
|
-
|
124
|
+
_DB = None
|
125
|
+
_db_init_lock = threading.Lock()
|
126
|
+
|
127
|
+
|
128
|
+
def init_db(func):
|
129
|
+
"""Initialize the database."""
|
130
|
+
|
131
|
+
@functools.wraps(func)
|
132
|
+
def wrapper(*args, **kwargs):
|
133
|
+
global _DB
|
134
|
+
if _DB is not None:
|
135
|
+
return func(*args, **kwargs)
|
136
|
+
|
137
|
+
with _db_init_lock:
|
138
|
+
if _DB is None:
|
139
|
+
_DB = db_utils.SQLiteConn(_DB_PATH, create_table)
|
140
|
+
return func(*args, **kwargs)
|
141
|
+
|
142
|
+
return wrapper
|
125
143
|
|
126
144
|
|
127
145
|
class JobStatus(enum.Enum):
|
@@ -210,30 +228,37 @@ _PRE_RESOURCE_STATUSES = [JobStatus.PENDING]
|
|
210
228
|
class JobScheduler:
|
211
229
|
"""Base class for job scheduler"""
|
212
230
|
|
231
|
+
@init_db
|
213
232
|
def queue(self, job_id: int, cmd: str) -> None:
|
214
|
-
|
215
|
-
|
216
|
-
|
233
|
+
assert _DB is not None
|
234
|
+
_DB.cursor.execute('INSERT INTO pending_jobs VALUES (?,?,?,?)',
|
235
|
+
(job_id, cmd, 0, int(time.time())))
|
236
|
+
_DB.conn.commit()
|
217
237
|
set_status(job_id, JobStatus.PENDING)
|
218
238
|
self.schedule_step()
|
219
239
|
|
240
|
+
@init_db
|
220
241
|
def remove_job_no_lock(self, job_id: int) -> None:
|
221
|
-
|
222
|
-
|
242
|
+
assert _DB is not None
|
243
|
+
_DB.cursor.execute(f'DELETE FROM pending_jobs WHERE job_id={job_id!r}')
|
244
|
+
_DB.conn.commit()
|
223
245
|
|
246
|
+
@init_db
|
224
247
|
def _run_job(self, job_id: int, run_cmd: str):
|
225
|
-
|
226
|
-
|
227
|
-
|
248
|
+
assert _DB is not None
|
249
|
+
_DB.cursor.execute(
|
250
|
+
(f'UPDATE pending_jobs SET submit={int(time.time())} '
|
251
|
+
f'WHERE job_id={job_id!r}'))
|
252
|
+
_DB.conn.commit()
|
228
253
|
pid = subprocess_utils.launch_new_process_tree(run_cmd)
|
229
254
|
# TODO(zhwu): Backward compatibility, remove this check after 0.10.0.
|
230
255
|
# This is for the case where the job is submitted with SkyPilot older
|
231
256
|
# than #4318, using ray job submit.
|
232
257
|
if 'job submit' in run_cmd:
|
233
258
|
pid = -1
|
234
|
-
|
235
|
-
|
236
|
-
|
259
|
+
_DB.cursor.execute((f'UPDATE jobs SET pid={pid} '
|
260
|
+
f'WHERE job_id={job_id!r}'))
|
261
|
+
_DB.conn.commit()
|
237
262
|
|
238
263
|
def schedule_step(self, force_update_jobs: bool = False) -> None:
|
239
264
|
if force_update_jobs:
|
@@ -282,8 +307,10 @@ class JobScheduler:
|
|
282
307
|
class FIFOScheduler(JobScheduler):
|
283
308
|
"""First in first out job scheduler"""
|
284
309
|
|
310
|
+
@init_db
|
285
311
|
def _get_pending_job_ids(self) -> List[int]:
|
286
|
-
|
312
|
+
assert _DB is not None
|
313
|
+
rows = _DB.cursor.execute(
|
287
314
|
'SELECT job_id FROM pending_jobs ORDER BY job_id').fetchall()
|
288
315
|
return [row[0] for row in rows]
|
289
316
|
|
@@ -308,26 +335,30 @@ def make_job_command_with_user_switching(username: str,
|
|
308
335
|
return ['sudo', '-H', 'su', '--login', username, '-c', command]
|
309
336
|
|
310
337
|
|
338
|
+
@init_db
|
311
339
|
def add_job(job_name: str, username: str, run_timestamp: str,
|
312
340
|
resources_str: str) -> int:
|
313
341
|
"""Atomically reserve the next available job id for the user."""
|
342
|
+
assert _DB is not None
|
314
343
|
job_submitted_at = time.time()
|
315
344
|
# job_id will autoincrement with the null value
|
316
|
-
|
345
|
+
_DB.cursor.execute(
|
317
346
|
'INSERT INTO jobs VALUES (null, ?, ?, ?, ?, ?, ?, null, ?, 0)',
|
318
347
|
(job_name, username, job_submitted_at, JobStatus.INIT.value,
|
319
348
|
run_timestamp, None, resources_str))
|
320
|
-
|
321
|
-
rows =
|
322
|
-
|
349
|
+
_DB.conn.commit()
|
350
|
+
rows = _DB.cursor.execute('SELECT job_id FROM jobs WHERE run_timestamp=(?)',
|
351
|
+
(run_timestamp,))
|
323
352
|
for row in rows:
|
324
353
|
job_id = row[0]
|
325
354
|
assert job_id is not None
|
326
355
|
return job_id
|
327
356
|
|
328
357
|
|
358
|
+
@init_db
|
329
359
|
def _set_status_no_lock(job_id: int, status: JobStatus) -> None:
|
330
360
|
"""Setting the status of the job in the database."""
|
361
|
+
assert _DB is not None
|
331
362
|
assert status != JobStatus.RUNNING, (
|
332
363
|
'Please use set_job_started() to set job status to RUNNING')
|
333
364
|
if status.is_terminal():
|
@@ -339,15 +370,15 @@ def _set_status_no_lock(job_id: int, status: JobStatus) -> None:
|
|
339
370
|
check_end_at_str = ' AND end_at IS NULL'
|
340
371
|
if status != JobStatus.FAILED_SETUP:
|
341
372
|
check_end_at_str = ''
|
342
|
-
|
373
|
+
_DB.cursor.execute(
|
343
374
|
'UPDATE jobs SET status=(?), end_at=(?) '
|
344
375
|
f'WHERE job_id=(?) {check_end_at_str}',
|
345
376
|
(status.value, end_at, job_id))
|
346
377
|
else:
|
347
|
-
|
378
|
+
_DB.cursor.execute(
|
348
379
|
'UPDATE jobs SET status=(?), end_at=NULL '
|
349
380
|
'WHERE job_id=(?)', (status.value, job_id))
|
350
|
-
|
381
|
+
_DB.conn.commit()
|
351
382
|
|
352
383
|
|
353
384
|
def set_status(job_id: int, status: JobStatus) -> None:
|
@@ -357,16 +388,19 @@ def set_status(job_id: int, status: JobStatus) -> None:
|
|
357
388
|
_set_status_no_lock(job_id, status)
|
358
389
|
|
359
390
|
|
391
|
+
@init_db
|
360
392
|
def set_job_started(job_id: int) -> None:
|
361
393
|
# TODO(mraheja): remove pylint disabling when filelock version updated.
|
362
394
|
# pylint: disable=abstract-class-instantiated
|
395
|
+
assert _DB is not None
|
363
396
|
with filelock.FileLock(_get_lock_path(job_id)):
|
364
|
-
|
397
|
+
_DB.cursor.execute(
|
365
398
|
'UPDATE jobs SET status=(?), start_at=(?), end_at=NULL '
|
366
399
|
'WHERE job_id=(?)', (JobStatus.RUNNING.value, time.time(), job_id))
|
367
|
-
|
400
|
+
_DB.conn.commit()
|
368
401
|
|
369
402
|
|
403
|
+
@init_db
|
370
404
|
def get_status_no_lock(job_id: int) -> Optional[JobStatus]:
|
371
405
|
"""Get the status of the job with the given id.
|
372
406
|
|
@@ -375,8 +409,9 @@ def get_status_no_lock(job_id: int) -> Optional[JobStatus]:
|
|
375
409
|
the status in a while loop as in `log_lib._follow_job_logs`. Otherwise, use
|
376
410
|
`get_status`.
|
377
411
|
"""
|
378
|
-
|
379
|
-
|
412
|
+
assert _DB is not None
|
413
|
+
rows = _DB.cursor.execute('SELECT status FROM jobs WHERE job_id=(?)',
|
414
|
+
(job_id,))
|
380
415
|
for (status,) in rows:
|
381
416
|
if status is None:
|
382
417
|
return None
|
@@ -391,11 +426,13 @@ def get_status(job_id: int) -> Optional[JobStatus]:
|
|
391
426
|
return get_status_no_lock(job_id)
|
392
427
|
|
393
428
|
|
429
|
+
@init_db
|
394
430
|
def get_statuses_payload(job_ids: List[Optional[int]]) -> str:
|
431
|
+
assert _DB is not None
|
395
432
|
# Per-job lock is not required here, since the staled job status will not
|
396
433
|
# affect the caller.
|
397
434
|
query_str = ','.join(['?'] * len(job_ids))
|
398
|
-
rows =
|
435
|
+
rows = _DB.cursor.execute(
|
399
436
|
f'SELECT job_id, status FROM jobs WHERE job_id IN ({query_str})',
|
400
437
|
job_ids)
|
401
438
|
statuses = {job_id: None for job_id in job_ids}
|
@@ -419,14 +456,17 @@ def load_statuses_payload(
|
|
419
456
|
return statuses
|
420
457
|
|
421
458
|
|
459
|
+
@init_db
|
422
460
|
def get_latest_job_id() -> Optional[int]:
|
423
|
-
|
461
|
+
assert _DB is not None
|
462
|
+
rows = _DB.cursor.execute(
|
424
463
|
'SELECT job_id FROM jobs ORDER BY job_id DESC LIMIT 1')
|
425
464
|
for (job_id,) in rows:
|
426
465
|
return job_id
|
427
466
|
return None
|
428
467
|
|
429
468
|
|
469
|
+
@init_db
|
430
470
|
def get_job_submitted_or_ended_timestamp_payload(job_id: int,
|
431
471
|
get_ended_time: bool) -> str:
|
432
472
|
"""Get the job submitted/ended timestamp.
|
@@ -440,9 +480,10 @@ def get_job_submitted_or_ended_timestamp_payload(job_id: int,
|
|
440
480
|
`format_job_queue()`), because the job may stay in PENDING if the cluster is
|
441
481
|
busy.
|
442
482
|
"""
|
483
|
+
assert _DB is not None
|
443
484
|
field = 'end_at' if get_ended_time else 'submitted_at'
|
444
|
-
rows =
|
445
|
-
|
485
|
+
rows = _DB.cursor.execute(f'SELECT {field} FROM jobs WHERE job_id=(?)',
|
486
|
+
(job_id,))
|
446
487
|
for (timestamp,) in rows:
|
447
488
|
return message_utils.encode_payload(timestamp)
|
448
489
|
return message_utils.encode_payload(None)
|
@@ -496,10 +537,12 @@ def _get_records_from_rows(rows) -> List[Dict[str, Any]]:
|
|
496
537
|
return records
|
497
538
|
|
498
539
|
|
540
|
+
@init_db
|
499
541
|
def _get_jobs(
|
500
542
|
user_hash: Optional[str],
|
501
543
|
status_list: Optional[List[JobStatus]] = None) -> List[Dict[str, Any]]:
|
502
544
|
"""Returns jobs with the given fields, sorted by job_id, descending."""
|
545
|
+
assert _DB is not None
|
503
546
|
if status_list is None:
|
504
547
|
status_list = list(JobStatus)
|
505
548
|
status_str_list = [repr(status.value) for status in status_list]
|
@@ -509,14 +552,16 @@ def _get_jobs(
|
|
509
552
|
# We use the old username field for compatibility.
|
510
553
|
filter_str += ' AND username=(?)'
|
511
554
|
params.append(user_hash)
|
512
|
-
rows =
|
555
|
+
rows = _DB.cursor.execute(
|
513
556
|
f'SELECT * FROM jobs {filter_str} ORDER BY job_id DESC', params)
|
514
557
|
records = _get_records_from_rows(rows)
|
515
558
|
return records
|
516
559
|
|
517
560
|
|
561
|
+
@init_db
|
518
562
|
def _get_jobs_by_ids(job_ids: List[int]) -> List[Dict[str, Any]]:
|
519
|
-
|
563
|
+
assert _DB is not None
|
564
|
+
rows = _DB.cursor.execute(
|
520
565
|
f"""\
|
521
566
|
SELECT * FROM jobs
|
522
567
|
WHERE job_id IN ({','.join(['?'] * len(job_ids))})
|
@@ -527,8 +572,10 @@ def _get_jobs_by_ids(job_ids: List[int]) -> List[Dict[str, Any]]:
|
|
527
572
|
return records
|
528
573
|
|
529
574
|
|
575
|
+
@init_db
|
530
576
|
def _get_pending_job(job_id: int) -> Optional[Dict[str, Any]]:
|
531
|
-
|
577
|
+
assert _DB is not None
|
578
|
+
rows = _DB.cursor.execute(
|
532
579
|
'SELECT created_time, submit, run_cmd FROM pending_jobs '
|
533
580
|
f'WHERE job_id={job_id!r}')
|
534
581
|
for row in rows:
|
@@ -698,16 +745,18 @@ def update_job_status(job_ids: List[int],
|
|
698
745
|
return statuses
|
699
746
|
|
700
747
|
|
748
|
+
@init_db
|
701
749
|
def fail_all_jobs_in_progress() -> None:
|
750
|
+
assert _DB is not None
|
702
751
|
in_progress_status = [
|
703
752
|
status.value for status in JobStatus.nonterminal_statuses()
|
704
753
|
]
|
705
|
-
|
754
|
+
_DB.cursor.execute(
|
706
755
|
f"""\
|
707
756
|
UPDATE jobs SET status=(?)
|
708
757
|
WHERE status IN ({','.join(['?'] * len(in_progress_status))})
|
709
758
|
""", (JobStatus.FAILED_DRIVER.value, *in_progress_status))
|
710
|
-
|
759
|
+
_DB.conn.commit()
|
711
760
|
|
712
761
|
|
713
762
|
def update_status() -> None:
|
@@ -720,12 +769,14 @@ def update_status() -> None:
|
|
720
769
|
update_job_status(nonterminal_job_ids)
|
721
770
|
|
722
771
|
|
772
|
+
@init_db
|
723
773
|
def is_cluster_idle() -> bool:
|
724
774
|
"""Returns if the cluster is idle (no in-flight jobs)."""
|
775
|
+
assert _DB is not None
|
725
776
|
in_progress_status = [
|
726
777
|
status.value for status in JobStatus.nonterminal_statuses()
|
727
778
|
]
|
728
|
-
rows =
|
779
|
+
rows = _DB.cursor.execute(
|
729
780
|
f"""\
|
730
781
|
SELECT COUNT(*) FROM jobs
|
731
782
|
WHERE status IN ({','.join(['?'] * len(in_progress_status))})
|
@@ -905,27 +956,31 @@ def cancel_jobs_encoded_results(jobs: Optional[List[int]],
|
|
905
956
|
return message_utils.encode_payload(cancelled_ids)
|
906
957
|
|
907
958
|
|
959
|
+
@init_db
|
908
960
|
def get_run_timestamp(job_id: Optional[int]) -> Optional[str]:
|
909
961
|
"""Returns the relative path to the log file for a job."""
|
910
|
-
|
962
|
+
assert _DB is not None
|
963
|
+
_DB.cursor.execute(
|
911
964
|
"""\
|
912
965
|
SELECT * FROM jobs
|
913
966
|
WHERE job_id=(?)""", (job_id,))
|
914
|
-
row =
|
967
|
+
row = _DB.cursor.fetchone()
|
915
968
|
if row is None:
|
916
969
|
return None
|
917
970
|
run_timestamp = row[JobInfoLoc.RUN_TIMESTAMP.value]
|
918
971
|
return run_timestamp
|
919
972
|
|
920
973
|
|
974
|
+
@init_db
|
921
975
|
def run_timestamp_with_globbing_payload(job_ids: List[Optional[str]]) -> str:
|
922
976
|
"""Returns the relative paths to the log files for job with globbing."""
|
977
|
+
assert _DB is not None
|
923
978
|
query_str = ' OR '.join(['job_id GLOB (?)'] * len(job_ids))
|
924
|
-
|
979
|
+
_DB.cursor.execute(
|
925
980
|
f"""\
|
926
981
|
SELECT * FROM jobs
|
927
982
|
WHERE {query_str}""", job_ids)
|
928
|
-
rows =
|
983
|
+
rows = _DB.cursor.fetchall()
|
929
984
|
run_timestamps = {}
|
930
985
|
for row in rows:
|
931
986
|
job_id = row[JobInfoLoc.JOB_ID.value]
|
sky/users/permission.py
CHANGED
@@ -30,26 +30,34 @@ class PermissionService:
|
|
30
30
|
"""Permission service for SkyPilot API Server."""
|
31
31
|
|
32
32
|
def __init__(self):
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
33
|
+
self.enforcer = None
|
34
|
+
self.init_lock = threading.Lock()
|
35
|
+
|
36
|
+
def _lazy_initialize(self):
|
37
|
+
if self.enforcer is not None:
|
38
|
+
return
|
39
|
+
with self.init_lock:
|
40
|
+
if self.enforcer is not None:
|
41
|
+
return
|
42
|
+
global _enforcer_instance
|
43
|
+
if _enforcer_instance is None:
|
44
|
+
# For different threads, we share the same enforcer instance.
|
45
|
+
with _lock:
|
46
|
+
if _enforcer_instance is None:
|
47
|
+
_enforcer_instance = self
|
48
|
+
engine = global_user_state.initialize_and_get_db()
|
49
|
+
adapter = sqlalchemy_adapter.Adapter(engine)
|
50
|
+
model_path = os.path.join(os.path.dirname(__file__),
|
51
|
+
'model.conf')
|
52
|
+
enforcer = casbin.Enforcer(model_path, adapter)
|
53
|
+
self.enforcer = enforcer
|
54
|
+
else:
|
55
|
+
self.enforcer = _enforcer_instance.enforcer
|
56
|
+
with _policy_lock():
|
57
|
+
self._maybe_initialize_policies()
|
49
58
|
|
50
59
|
def _maybe_initialize_policies(self) -> None:
|
51
60
|
"""Initialize policies if they don't already exist."""
|
52
|
-
# TODO(zhwu): we should avoid running this on client side.
|
53
61
|
logger.debug(f'Initializing policies in process: {os.getpid()}')
|
54
62
|
self._load_policy_no_lock()
|
55
63
|
|
@@ -128,6 +136,7 @@ class PermissionService:
|
|
128
136
|
|
129
137
|
def add_user_if_not_exists(self, user_id: str) -> None:
|
130
138
|
"""Add user role relationship."""
|
139
|
+
self._lazy_initialize()
|
131
140
|
with _policy_lock():
|
132
141
|
self._add_user_if_not_exists_no_lock(user_id)
|
133
142
|
|
@@ -147,6 +156,7 @@ class PermissionService:
|
|
147
156
|
|
148
157
|
def update_role(self, user_id: str, new_role: str) -> None:
|
149
158
|
"""Update user role relationship."""
|
159
|
+
self._lazy_initialize()
|
150
160
|
with _policy_lock():
|
151
161
|
# Get current roles
|
152
162
|
self._load_policy_no_lock()
|
@@ -179,6 +189,7 @@ class PermissionService:
|
|
179
189
|
Returns:
|
180
190
|
A list of role names that the user has.
|
181
191
|
"""
|
192
|
+
self._lazy_initialize()
|
182
193
|
self._load_policy_no_lock()
|
183
194
|
return self.enforcer.get_roles_for_user(user_id)
|
184
195
|
|
@@ -191,6 +202,7 @@ class PermissionService:
|
|
191
202
|
# it is a hot path in every request. It is ok to have a stale policy,
|
192
203
|
# as long as it is eventually consistent.
|
193
204
|
# self._load_policy_no_lock()
|
205
|
+
self._lazy_initialize()
|
194
206
|
return self.enforcer.enforce(user_id, path, method)
|
195
207
|
|
196
208
|
def _load_policy_no_lock(self):
|
@@ -199,6 +211,7 @@ class PermissionService:
|
|
199
211
|
|
200
212
|
def load_policy(self):
|
201
213
|
"""Load policy from storage with lock."""
|
214
|
+
self._lazy_initialize()
|
202
215
|
with _policy_lock():
|
203
216
|
self._load_policy_no_lock()
|
204
217
|
|
@@ -214,6 +227,7 @@ class PermissionService:
|
|
214
227
|
For public workspaces, the permission is granted via a wildcard policy
|
215
228
|
('*').
|
216
229
|
"""
|
230
|
+
self._lazy_initialize()
|
217
231
|
if os.getenv(constants.ENV_VAR_IS_SKYPILOT_SERVER) is None:
|
218
232
|
# When it is not on API server, we allow all users to access all
|
219
233
|
# workspaces, as the workspace check has been done on API server.
|
@@ -241,6 +255,7 @@ class PermissionService:
|
|
241
255
|
For public workspaces, this should be ['*'].
|
242
256
|
For private workspaces, this should be specific user IDs.
|
243
257
|
"""
|
258
|
+
self._lazy_initialize()
|
244
259
|
with _policy_lock():
|
245
260
|
for user in users:
|
246
261
|
logger.debug(f'Adding workspace policy: user={user}, '
|
@@ -258,6 +273,7 @@ class PermissionService:
|
|
258
273
|
For public workspaces, this should be ['*'].
|
259
274
|
For private workspaces, this should be specific user IDs.
|
260
275
|
"""
|
276
|
+
self._lazy_initialize()
|
261
277
|
with _policy_lock():
|
262
278
|
self._load_policy_no_lock()
|
263
279
|
# Remove all existing policies for this workspace
|
@@ -271,6 +287,7 @@ class PermissionService:
|
|
271
287
|
|
272
288
|
def remove_workspace_policy(self, workspace_name: str) -> None:
|
273
289
|
"""Remove workspace policy."""
|
290
|
+
self._lazy_initialize()
|
274
291
|
with _policy_lock():
|
275
292
|
self.enforcer.remove_filtered_policy(1, workspace_name)
|
276
293
|
self.enforcer.save_policy()
|
sky/utils/admin_policy_utils.py
CHANGED
@@ -3,6 +3,7 @@ import contextlib
|
|
3
3
|
import copy
|
4
4
|
import importlib
|
5
5
|
from typing import Iterator, Optional, Tuple, Union
|
6
|
+
import urllib.parse
|
6
7
|
|
7
8
|
import colorama
|
8
9
|
|
@@ -19,18 +20,34 @@ from sky.utils import ux_utils
|
|
19
20
|
logger = sky_logging.init_logger(__name__)
|
20
21
|
|
21
22
|
|
22
|
-
def
|
23
|
-
|
23
|
+
def _is_url(policy_string: str) -> bool:
|
24
|
+
"""Check if the policy string is a URL."""
|
25
|
+
try:
|
26
|
+
parsed = urllib.parse.urlparse(policy_string)
|
27
|
+
return parsed.scheme in ('http', 'https')
|
28
|
+
except Exception: # pylint: disable=broad-except
|
29
|
+
return False
|
30
|
+
|
31
|
+
|
32
|
+
def _get_policy_impl(
|
33
|
+
policy_location: Optional[str]
|
34
|
+
) -> Optional[admin_policy.PolicyInterface]:
|
24
35
|
"""Gets admin-defined policy."""
|
25
|
-
if
|
36
|
+
if policy_location is None:
|
26
37
|
return None
|
38
|
+
|
39
|
+
if _is_url(policy_location):
|
40
|
+
# Use the built-in URL policy class when an URL is specified.
|
41
|
+
return admin_policy.RestfulAdminPolicy(policy_location)
|
42
|
+
|
43
|
+
# Handle module path format
|
27
44
|
try:
|
28
|
-
module_path, class_name =
|
45
|
+
module_path, class_name = policy_location.rsplit('.', 1)
|
29
46
|
module = importlib.import_module(module_path)
|
30
47
|
except ImportError as e:
|
31
48
|
with ux_utils.print_exception_no_traceback():
|
32
49
|
raise ImportError(
|
33
|
-
f'Failed to import policy module: {
|
50
|
+
f'Failed to import policy module: {policy_location}. '
|
34
51
|
'Please check if the module is installed in your Python '
|
35
52
|
'environment.') from e
|
36
53
|
|
@@ -42,13 +59,15 @@ def _get_policy_cls(
|
|
42
59
|
f'Could not find {class_name} class in module {module_path}. '
|
43
60
|
'Please check with your policy admin for details.') from e
|
44
61
|
|
45
|
-
#
|
62
|
+
# Currently we only allow users to define subclass of AdminPolicy
|
63
|
+
# instead of inheriting from PolicyInterface or PolicyTemplate.
|
46
64
|
if not issubclass(policy_cls, admin_policy.AdminPolicy):
|
47
65
|
with ux_utils.print_exception_no_traceback():
|
48
66
|
raise ValueError(
|
49
|
-
f'Policy class {
|
50
|
-
'interface. Please check with your policy admin
|
51
|
-
|
67
|
+
f'Policy class {policy_cls!r} does not implement the '
|
68
|
+
'AdminPolicy interface. Please check with your policy admin '
|
69
|
+
'for details.')
|
70
|
+
return policy_cls()
|
52
71
|
|
53
72
|
|
54
73
|
@contextlib.contextmanager
|
@@ -102,9 +121,9 @@ def apply(
|
|
102
121
|
else:
|
103
122
|
dag = entrypoint
|
104
123
|
|
105
|
-
|
106
|
-
|
107
|
-
if
|
124
|
+
policy_location = skypilot_config.get_nested(('admin_policy',), None)
|
125
|
+
policy = _get_policy_impl(policy_location)
|
126
|
+
if policy is None:
|
108
127
|
return dag, skypilot_config.to_dict()
|
109
128
|
|
110
129
|
if at_client_side:
|
@@ -120,7 +139,7 @@ def apply(
|
|
120
139
|
user_request = admin_policy.UserRequest(task, config, request_options,
|
121
140
|
at_client_side)
|
122
141
|
try:
|
123
|
-
mutated_user_request =
|
142
|
+
mutated_user_request = policy.apply(user_request)
|
124
143
|
except Exception as e: # pylint: disable=broad-except
|
125
144
|
with ux_utils.print_exception_no_traceback():
|
126
145
|
raise exceptions.UserRequestRejectedByPolicy(
|
sky/utils/schemas.py
CHANGED
@@ -1149,9 +1149,17 @@ def get_config_schema():
|
|
1149
1149
|
|
1150
1150
|
admin_policy_schema = {
|
1151
1151
|
'type': 'string',
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1152
|
+
'anyOf': [
|
1153
|
+
{
|
1154
|
+
# Check regex to be a valid python module path
|
1155
|
+
'pattern': (r'^[a-zA-Z_][a-zA-Z0-9_]*'
|
1156
|
+
r'(\.[a-zA-Z_][a-zA-Z0-9_]*)+$'),
|
1157
|
+
},
|
1158
|
+
{
|
1159
|
+
# Check for valid HTTP/HTTPS URL
|
1160
|
+
'pattern': r'^https?://.*$',
|
1161
|
+
}
|
1162
|
+
]
|
1155
1163
|
}
|
1156
1164
|
|
1157
1165
|
allowed_clouds = {
|