skypilot-nightly 1.0.0.dev20250523__py3-none-any.whl → 1.0.0.dev20250524__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +62 -45
  3. sky/backends/cloud_vm_ray_backend.py +3 -1
  4. sky/check.py +332 -170
  5. sky/cli.py +44 -11
  6. sky/client/cli.py +44 -11
  7. sky/client/sdk.py +54 -10
  8. sky/clouds/gcp.py +19 -3
  9. sky/core.py +5 -2
  10. sky/dashboard/out/404.html +1 -1
  11. sky/dashboard/out/_next/static/aHej19bZyl4hoHgrzPCn7/_buildManifest.js +1 -0
  12. sky/dashboard/out/_next/static/chunks/480-ee58038f1a4afd5c.js +1 -0
  13. sky/dashboard/out/_next/static/chunks/488-50d843fdb5396d32.js +15 -0
  14. sky/dashboard/out/_next/static/chunks/498-d7722313e5e5b4e6.js +21 -0
  15. sky/dashboard/out/_next/static/chunks/573-f17bd89d9f9118b3.js +66 -0
  16. sky/dashboard/out/_next/static/chunks/578-7a4795009a56430c.js +6 -0
  17. sky/dashboard/out/_next/static/chunks/734-5f5ce8f347b7f417.js +1 -0
  18. sky/dashboard/out/_next/static/chunks/937.f97f83652028e944.js +1 -0
  19. sky/dashboard/out/_next/static/chunks/938-f347f6144075b0c8.js +1 -0
  20. sky/dashboard/out/_next/static/chunks/9f96d65d-5a3e4af68c26849e.js +1 -0
  21. sky/dashboard/out/_next/static/chunks/pages/_app-dec800f9ef1b10f4.js +1 -0
  22. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-37c042a356f8e608.js +1 -0
  23. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-9529d9e882a0e75c.js +16 -0
  24. sky/dashboard/out/_next/static/chunks/pages/clusters-9e6d1ec6e1ac5b29.js +1 -0
  25. sky/dashboard/out/_next/static/chunks/pages/infra-e690d864aa00e2ea.js +1 -0
  26. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-db6558a5ec687011.js +1 -0
  27. sky/dashboard/out/_next/static/chunks/pages/jobs-73d5e0c369d00346.js +16 -0
  28. sky/dashboard/out/_next/static/chunks/pages/users-2d319455c3f1c3e2.js +1 -0
  29. sky/dashboard/out/_next/static/chunks/pages/workspaces-02a7b60f2ead275f.js +1 -0
  30. sky/dashboard/out/_next/static/chunks/webpack-deda68c926e8d0bc.js +1 -0
  31. sky/dashboard/out/_next/static/css/d2cdba64c9202dd7.css +3 -0
  32. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  33. sky/dashboard/out/clusters/[cluster].html +1 -1
  34. sky/dashboard/out/clusters.html +1 -1
  35. sky/dashboard/out/index.html +1 -1
  36. sky/dashboard/out/infra.html +1 -1
  37. sky/dashboard/out/jobs/[job].html +1 -1
  38. sky/dashboard/out/jobs.html +1 -1
  39. sky/dashboard/out/users.html +1 -0
  40. sky/dashboard/out/workspaces.html +1 -0
  41. sky/data/storage.py +1 -1
  42. sky/global_user_state.py +42 -19
  43. sky/jobs/constants.py +1 -1
  44. sky/jobs/server/core.py +72 -56
  45. sky/jobs/state.py +26 -5
  46. sky/jobs/utils.py +65 -13
  47. sky/optimizer.py +6 -3
  48. sky/provision/fluidstack/instance.py +1 -0
  49. sky/serve/server/core.py +9 -6
  50. sky/server/html/token_page.html +6 -1
  51. sky/server/requests/executor.py +1 -0
  52. sky/server/requests/payloads.py +11 -0
  53. sky/server/server.py +68 -5
  54. sky/skylet/constants.py +4 -1
  55. sky/skypilot_config.py +83 -9
  56. sky/utils/cli_utils/status_utils.py +18 -8
  57. sky/utils/kubernetes/deploy_remote_cluster.py +150 -147
  58. sky/utils/log_utils.py +4 -0
  59. sky/utils/schemas.py +54 -0
  60. {skypilot_nightly-1.0.0.dev20250523.dist-info → skypilot_nightly-1.0.0.dev20250524.dist-info}/METADATA +1 -1
  61. {skypilot_nightly-1.0.0.dev20250523.dist-info → skypilot_nightly-1.0.0.dev20250524.dist-info}/RECORD +66 -59
  62. sky/dashboard/out/_next/static/ECKwDNS9v9y3_IKFZ2lpp/_buildManifest.js +0 -1
  63. sky/dashboard/out/_next/static/chunks/236-1a3a9440417720eb.js +0 -6
  64. sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
  65. sky/dashboard/out/_next/static/chunks/37-d584022b0da4ac3b.js +0 -6
  66. sky/dashboard/out/_next/static/chunks/393-e1eaa440481337ec.js +0 -1
  67. sky/dashboard/out/_next/static/chunks/480-f28cd152a98997de.js +0 -1
  68. sky/dashboard/out/_next/static/chunks/582-683f4f27b81996dc.js +0 -59
  69. sky/dashboard/out/_next/static/chunks/pages/_app-8cfab319f9fb3ae8.js +0 -1
  70. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-33bc2bec322249b1.js +0 -1
  71. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-e2fc2dd1955e6c36.js +0 -1
  72. sky/dashboard/out/_next/static/chunks/pages/clusters-3a748bd76e5c2984.js +0 -1
  73. sky/dashboard/out/_next/static/chunks/pages/infra-abf08c4384190a39.js +0 -1
  74. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-70756c2dad850a7e.js +0 -1
  75. sky/dashboard/out/_next/static/chunks/pages/jobs-ecd804b9272f4a7c.js +0 -1
  76. sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
  77. sky/dashboard/out/_next/static/css/7e7ce4ff31d3977b.css +0 -3
  78. /sky/dashboard/out/_next/static/{ECKwDNS9v9y3_IKFZ2lpp → aHej19bZyl4hoHgrzPCn7}/_ssgManifest.js +0 -0
  79. {skypilot_nightly-1.0.0.dev20250523.dist-info → skypilot_nightly-1.0.0.dev20250524.dist-info}/WHEEL +0 -0
  80. {skypilot_nightly-1.0.0.dev20250523.dist-info → skypilot_nightly-1.0.0.dev20250524.dist-info}/entry_points.txt +0 -0
  81. {skypilot_nightly-1.0.0.dev20250523.dist-info → skypilot_nightly-1.0.0.dev20250524.dist-info}/licenses/LICENSE +0 -0
  82. {skypilot_nightly-1.0.0.dev20250523.dist-info → skypilot_nightly-1.0.0.dev20250524.dist-info}/top_level.txt +0 -0
sky/jobs/utils.py CHANGED
@@ -23,6 +23,7 @@ from sky import backends
23
23
  from sky import exceptions
24
24
  from sky import global_user_state
25
25
  from sky import sky_logging
26
+ from sky import skypilot_config
26
27
  from sky.adaptors import common as adaptors_common
27
28
  from sky.backends import backend_utils
28
29
  from sky.jobs import constants as managed_job_constants
@@ -463,7 +464,8 @@ def generate_managed_job_cluster_name(task_name: str, job_id: int) -> str:
463
464
 
464
465
 
465
466
  def cancel_jobs_by_id(job_ids: Optional[List[int]],
466
- all_users: bool = False) -> str:
467
+ all_users: bool = False,
468
+ current_workspace: Optional[str] = None) -> str:
467
469
  """Cancel jobs by id.
468
470
 
469
471
  If job_ids is None, cancel all jobs.
@@ -474,9 +476,11 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]],
474
476
  job_ids = list(set(job_ids))
475
477
  if not job_ids:
476
478
  return 'No job to cancel.'
477
- job_id_str = ', '.join(map(str, job_ids))
478
- logger.info(f'Cancelling jobs {job_id_str}.')
479
+ if current_workspace is None:
480
+ current_workspace = constants.SKYPILOT_DEFAULT_WORKSPACE
481
+
479
482
  cancelled_job_ids: List[int] = []
483
+ wrong_workspace_job_ids: List[int] = []
480
484
  for job_id in job_ids:
481
485
  # Check the status of the managed job status. If it is in
482
486
  # terminal state, we can safely skip it.
@@ -491,6 +495,11 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]],
491
495
 
492
496
  update_managed_jobs_statuses(job_id)
493
497
 
498
+ job_workspace = managed_job_state.get_workspace(job_id)
499
+ if current_workspace is not None and job_workspace != current_workspace:
500
+ wrong_workspace_job_ids.append(job_id)
501
+ continue
502
+
494
503
  # Send the signal to the jobs controller.
495
504
  signal_file = pathlib.Path(SIGNAL_FILE_PREFIX.format(job_id))
496
505
  # Filelock is needed to prevent race condition between signal
@@ -501,17 +510,30 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]],
501
510
  f.flush()
502
511
  cancelled_job_ids.append(job_id)
503
512
 
513
+ wrong_workspace_job_str = ''
514
+ if wrong_workspace_job_ids:
515
+ plural = 's' if len(wrong_workspace_job_ids) > 1 else ''
516
+ plural_verb = 'are' if len(wrong_workspace_job_ids) > 1 else 'is'
517
+ wrong_workspace_job_str = (
518
+ f' Job{plural} with ID{plural}'
519
+ f' {", ".join(map(str, wrong_workspace_job_ids))} '
520
+ f'{plural_verb} skipped as they are not in the active workspace '
521
+ f'{current_workspace!r}. Check the workspace of the job with: '
522
+ f'sky jobs queue')
523
+
504
524
  if not cancelled_job_ids:
505
- return 'No job to cancel.'
525
+ return f'No job to cancel.{wrong_workspace_job_str}'
506
526
  identity_str = f'Job with ID {cancelled_job_ids[0]} is'
507
527
  if len(cancelled_job_ids) > 1:
508
528
  cancelled_job_ids_str = ', '.join(map(str, cancelled_job_ids))
509
529
  identity_str = f'Jobs with IDs {cancelled_job_ids_str} are'
510
530
 
511
- return f'{identity_str} scheduled to be cancelled.'
531
+ msg = f'{identity_str} scheduled to be cancelled.{wrong_workspace_job_str}'
532
+ return msg
512
533
 
513
534
 
514
- def cancel_job_by_name(job_name: str) -> str:
535
+ def cancel_job_by_name(job_name: str,
536
+ current_workspace: Optional[str] = None) -> str:
515
537
  """Cancel a job by name."""
516
538
  job_ids = managed_job_state.get_nonterminal_job_ids_by_name(job_name)
517
539
  if not job_ids:
@@ -520,8 +542,8 @@ def cancel_job_by_name(job_name: str) -> str:
520
542
  return (f'{colorama.Fore.RED}Multiple running jobs found '
521
543
  f'with name {job_name!r}.\n'
522
544
  f'Job IDs: {job_ids}{colorama.Style.RESET_ALL}')
523
- cancel_jobs_by_id(job_ids)
524
- return f'Job {job_name!r} is scheduled to be cancelled.'
545
+ msg = cancel_jobs_by_id(job_ids, current_workspace=current_workspace)
546
+ return f'{job_name!r} {msg}'
525
547
 
526
548
 
527
549
  def stream_logs_by_id(job_id: int, follow: bool = True) -> Tuple[str, int]:
@@ -1020,10 +1042,15 @@ def format_job_table(
1020
1042
  jobs[get_hash(task)].append(task)
1021
1043
 
1022
1044
  status_counts: Dict[str, int] = collections.defaultdict(int)
1045
+ workspaces = set()
1023
1046
  for job_tasks in jobs.values():
1024
1047
  managed_job_status = _get_job_status_from_tasks(job_tasks)[0]
1025
1048
  if not managed_job_status.is_terminal():
1026
1049
  status_counts[managed_job_status.value] += 1
1050
+ workspaces.add(job_tasks[0].get('workspace',
1051
+ constants.SKYPILOT_DEFAULT_WORKSPACE))
1052
+
1053
+ show_workspace = len(workspaces) > 1 or show_all
1027
1054
 
1028
1055
  user_cols: List[str] = []
1029
1056
  if show_user:
@@ -1034,6 +1061,7 @@ def format_job_table(
1034
1061
  columns = [
1035
1062
  'ID',
1036
1063
  'TASK',
1064
+ *(['WORKSPACE'] if show_workspace else []),
1037
1065
  'NAME',
1038
1066
  *user_cols,
1039
1067
  'REQUESTED',
@@ -1093,6 +1121,8 @@ def format_job_table(
1093
1121
  for job_hash, job_tasks in jobs.items():
1094
1122
  if show_all:
1095
1123
  schedule_state = job_tasks[0]['schedule_state']
1124
+ workspace = job_tasks[0].get('workspace',
1125
+ constants.SKYPILOT_DEFAULT_WORKSPACE)
1096
1126
 
1097
1127
  if len(job_tasks) > 1:
1098
1128
  # Aggregate the tasks into a new row in the table.
@@ -1134,6 +1164,7 @@ def format_job_table(
1134
1164
  job_values = [
1135
1165
  job_id,
1136
1166
  '',
1167
+ *([''] if show_workspace else []),
1137
1168
  job_name,
1138
1169
  *user_values,
1139
1170
  '-',
@@ -1163,9 +1194,11 @@ def format_job_table(
1163
1194
  0, task['job_duration'], absolute=True)
1164
1195
  submitted = log_utils.readable_time_duration(task['submitted_at'])
1165
1196
  user_values = get_user_column_values(task)
1197
+ task_workspace = '-' if len(job_tasks) > 1 else workspace
1166
1198
  values = [
1167
1199
  task['job_id'] if len(job_tasks) == 1 else ' \u21B3',
1168
1200
  task['task_id'] if len(job_tasks) > 1 else '-',
1201
+ *([task_workspace] if show_workspace else []),
1169
1202
  task['task_name'],
1170
1203
  *user_values,
1171
1204
  task['resources'],
@@ -1263,22 +1296,36 @@ class ManagedJobCodeGen:
1263
1296
  def cancel_jobs_by_id(cls,
1264
1297
  job_ids: Optional[List[int]],
1265
1298
  all_users: bool = False) -> str:
1299
+ active_workspace = skypilot_config.get_active_workspace()
1266
1300
  code = textwrap.dedent(f"""\
1267
1301
  if managed_job_version < 2:
1268
1302
  # For backward compatibility, since all_users is not supported
1269
- # before #4787. Assume th
1303
+ # before #4787.
1270
1304
  # TODO(cooperc): Remove compatibility before 0.12.0
1271
1305
  msg = utils.cancel_jobs_by_id({job_ids})
1272
- else:
1306
+ elif managed_job_version < 4:
1307
+ # For backward compatibility, since current_workspace is not
1308
+ # supported before #5660. Don't check the workspace.
1309
+ # TODO(zhwu): Remove compatibility before 0.12.0
1273
1310
  msg = utils.cancel_jobs_by_id({job_ids}, all_users={all_users})
1311
+ else:
1312
+ msg = utils.cancel_jobs_by_id({job_ids}, all_users={all_users},
1313
+ current_workspace={active_workspace!r})
1274
1314
  print(msg, end="", flush=True)
1275
1315
  """)
1276
1316
  return cls._build(code)
1277
1317
 
1278
1318
  @classmethod
1279
1319
  def cancel_job_by_name(cls, job_name: str) -> str:
1320
+ active_workspace = skypilot_config.get_active_workspace()
1280
1321
  code = textwrap.dedent(f"""\
1281
- msg = utils.cancel_job_by_name({job_name!r})
1322
+ if managed_job_version < 4:
1323
+ # For backward compatibility, since current_workspace is not
1324
+ # supported before #5660. Don't check the workspace.
1325
+ # TODO(zhwu): Remove compatibility before 0.12.0
1326
+ msg = utils.cancel_job_by_name({job_name!r})
1327
+ else:
1328
+ msg = utils.cancel_job_by_name({job_name!r}, {active_workspace!r})
1282
1329
  print(msg, end="", flush=True)
1283
1330
  """)
1284
1331
  return cls._build(code)
@@ -1314,11 +1361,16 @@ class ManagedJobCodeGen:
1314
1361
  return cls._build(code)
1315
1362
 
1316
1363
  @classmethod
1317
- def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag') -> str:
1364
+ def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag',
1365
+ workspace) -> str:
1318
1366
  dag_name = managed_job_dag.name
1319
1367
  # Add the managed job to queue table.
1320
1368
  code = textwrap.dedent(f"""\
1321
- managed_job_state.set_job_info({job_id}, {dag_name!r})
1369
+ set_job_info_kwargs = {{'workspace': {workspace!r}}}
1370
+ if managed_job_version < 4:
1371
+ set_job_info_kwargs = {{}}
1372
+ managed_job_state.set_job_info(
1373
+ {job_id}, {dag_name!r}, **set_job_info_kwargs)
1322
1374
  """)
1323
1375
  for task_id, task in enumerate(managed_job_dag.tasks):
1324
1376
  resources_str = backend_utils.get_task_resources_str(
sky/optimizer.py CHANGED
@@ -14,6 +14,7 @@ from sky import clouds
14
14
  from sky import exceptions
15
15
  from sky import resources as resources_lib
16
16
  from sky import sky_logging
17
+ from sky import skypilot_config
17
18
  from sky import task as task_lib
18
19
  from sky.adaptors import common as adaptors_common
19
20
  from sky.clouds import cloud as sky_cloud
@@ -1217,9 +1218,11 @@ def _check_specified_clouds(dag: 'dag_lib.Dag') -> None:
1217
1218
  clouds_to_check_again = list(clouds_need_recheck -
1218
1219
  global_disabled_clouds)
1219
1220
  if len(clouds_to_check_again) > 0:
1220
- sky_check.check_capability(sky_cloud.CloudCapability.COMPUTE,
1221
- quiet=True,
1222
- clouds=clouds_to_check_again)
1221
+ sky_check.check_capability(
1222
+ sky_cloud.CloudCapability.COMPUTE,
1223
+ quiet=True,
1224
+ clouds=clouds_to_check_again,
1225
+ workspace=skypilot_config.get_active_workspace())
1223
1226
  enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh(
1224
1227
  capability=sky_cloud.CloudCapability.COMPUTE,
1225
1228
  raise_if_no_cloud_access=True)
@@ -26,6 +26,7 @@ logger = sky_logging.init_logger(__name__)
26
26
 
27
27
  def get_internal_ip(node_info: Dict[str, Any]) -> None:
28
28
  node_info['internal_ip'] = node_info['ip_address']
29
+
29
30
  private_key_path, _ = auth.get_or_generate_keys()
30
31
  runner = command_runner.SSHCommandRunner(
31
32
  (node_info['ip_address'], 22),
sky/serve/server/core.py CHANGED
@@ -14,6 +14,7 @@ from sky import backends
14
14
  from sky import exceptions
15
15
  from sky import execution
16
16
  from sky import sky_logging
17
+ from sky import skypilot_config
17
18
  from sky import task as task_lib
18
19
  from sky.backends import backend_utils
19
20
  from sky.clouds.service_catalog import common as service_catalog_common
@@ -221,12 +222,14 @@ def up(
221
222
  # Since the controller may be shared among multiple users, launch the
222
223
  # controller with the API server's user hash.
223
224
  with common.with_server_user_hash():
224
- controller_job_id, controller_handle = execution.launch(
225
- task=controller_task,
226
- cluster_name=controller_name,
227
- retry_until_up=True,
228
- _disable_controller_check=True,
229
- )
225
+ with skypilot_config.local_active_workspace_ctx(
226
+ constants.SKYPILOT_DEFAULT_WORKSPACE):
227
+ controller_job_id, controller_handle = execution.launch(
228
+ task=controller_task,
229
+ cluster_name=controller_name,
230
+ retry_until_up=True,
231
+ _disable_controller_check=True,
232
+ )
230
233
 
231
234
  style = colorama.Style
232
235
  fore = colorama.Fore
@@ -49,6 +49,11 @@
49
49
  margin-bottom: 20px;
50
50
  color: #5f6368;
51
51
  }
52
+ .user-identifier {
53
+ font-size: 12px; /* Smaller font size */
54
+ color: #80868b; /* Lighter color */
55
+ margin-bottom: 8px; /* Adjusted margin */
56
+ }
52
57
  .code-block {
53
58
  background-color: #f1f3f4;
54
59
  border: 1px solid #dadce0;
@@ -110,8 +115,8 @@
110
115
  </svg>
111
116
  </div>
112
117
  <h1>Sign in to SkyPilot CLI</h1>
118
+ <p class="user-identifier">USER_PLACEHOLDER</p>
113
119
  <p>You are seeing this page because a SkyPilot command requires authentication.</p>
114
-
115
120
  <p>Please copy the following token and paste it into your SkyPilot CLI prompt:</p>
116
121
  <div id="token-box" class="code-block">SKYPILOT_API_SERVER_USER_TOKEN_PLACEHOLDER</div>
117
122
  <button id="copy-btn" class="copy-button">Copy Token</button>
@@ -228,6 +228,7 @@ def override_request_env_and_config(
228
228
  """Override the environment and SkyPilot config for a request."""
229
229
  original_env = os.environ.copy()
230
230
  os.environ.update(request_body.env_vars)
231
+ # Note: may be overridden by AuthProxyMiddleware.
231
232
  user = models.User(id=request_body.env_vars[constants.USER_ID_ENV_VAR],
232
233
  name=request_body.env_vars[constants.USER_ENV_VAR])
233
234
  global_user_state.add_or_update_user(user)
@@ -88,6 +88,11 @@ class RequestBody(pydantic.BaseModel):
88
88
  using_remote_api_server: bool = False
89
89
  override_skypilot_config: Optional[Dict[str, Any]] = {}
90
90
 
91
+ # Allow extra fields in the request body, which is useful for backward
92
+ # compatibility, i.e., we can add new fields to the request body without
93
+ # breaking the existing old API server.
94
+ model_config = pydantic.ConfigDict(extra='allow')
95
+
91
96
  def __init__(self, **data):
92
97
  data['env_vars'] = data.get('env_vars', request_body_env_vars())
93
98
  usage_lib_entrypoint = usage_lib.messages.usage.entrypoint
@@ -126,6 +131,7 @@ class CheckBody(RequestBody):
126
131
  """The request body for the check endpoint."""
127
132
  clouds: Optional[Tuple[str, ...]] = None
128
133
  verbose: bool = False
134
+ workspace: Optional[str] = None
129
135
 
130
136
 
131
137
  class DagRequestBody(RequestBody):
@@ -525,3 +531,8 @@ class UploadZipFileResponse(pydantic.BaseModel):
525
531
  """The response body for the upload zip file endpoint."""
526
532
  status: str
527
533
  missing_chunks: Optional[List[str]] = None
534
+
535
+
536
class EnabledCloudsBody(RequestBody):
    """Request body for the enabled clouds endpoint.

    Adds an optional ``workspace`` name on top of the common request
    fields; it is None when the client does not supply one.
    """
    workspace: Optional[str] = None
sky/server/server.py CHANGED
@@ -6,6 +6,7 @@ import base64
6
6
  import contextlib
7
7
  import dataclasses
8
8
  import datetime
9
+ import hashlib
9
10
  import json
10
11
  import logging
11
12
  import multiprocessing
@@ -31,7 +32,9 @@ from sky import core
31
32
  from sky import exceptions
32
33
  from sky import execution
33
34
  from sky import global_user_state
35
+ from sky import models
34
36
  from sky import sky_logging
37
+ from sky import skypilot_config
35
38
  from sky.clouds import service_catalog
36
39
  from sky.data import storage_utils
37
40
  from sky.jobs.server import server as jobs_rest
@@ -110,6 +113,38 @@ class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
110
113
  return response
111
114
 
112
115
 
116
def _get_auth_user_header(request: fastapi.Request) -> Optional[models.User]:
    """Builds the user identified by an upstream auth proxy, if any.

    The identity comes from the ``X-Auth-Request-Email`` request header. The
    header value becomes the user's name. The user's id is the first
    ``common_utils.USER_HASH_LENGTH`` hex characters of the MD5 digest of
    that value.

    Args:
        request: The incoming request.

    Returns:
        The derived ``models.User``, or None when the header is absent.
    """
    email = request.headers.get('X-Auth-Request-Email')
    if email is None:
        return None
    # MD5 is only used to derive a stable, short identifier from the email;
    # it is not used for any security purpose here.
    digest = hashlib.md5(email.encode()).hexdigest()
    return models.User(id=digest[:common_utils.USER_HASH_LENGTH], name=email)
123
+
124
+
125
class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
    """Middleware to handle auth proxy.

    When the request carries an ``X-Auth-Request-Email`` header (see
    ``_get_auth_user_header``) and a JSON object body with an ``env_vars``
    object, the user id/name entries in ``env_vars`` are overwritten with
    the header-derived identity before the request is passed on.
    """

    async def dispatch(self, request: fastapi.Request, call_next):
        auth_user = _get_auth_user_header(request)
        body = await request.body()
        if auth_user and body:
            try:
                original_json = await request.json()
            except ValueError as e:
                # Covers json.JSONDecodeError and UnicodeDecodeError (a body
                # that is not valid UTF-8); both subclass ValueError.
                logger.error(f'Error parsing request JSON: {e}')
            else:
                # Only a JSON object whose `env_vars` is itself an object can
                # be overridden; anything else would previously raise a
                # TypeError here and turn the request into a 500.
                env_vars = (original_json.get('env_vars')
                            if isinstance(original_json, dict) else None)
                # NOTE(review): when `env_vars` is absent the body is passed
                # through unchanged, so no override happens - confirm that
                # downstream handling of such bodies is acceptable.
                if isinstance(env_vars, dict):
                    logger.debug(
                        f'Overriding user for {request.state.request_id}: '
                        f'{auth_user.name}, {auth_user.id}')
                    env_vars[constants.USER_ID_ENV_VAR] = auth_user.id
                    env_vars[constants.USER_ENV_VAR] = auth_user.name
                    request._body = json.dumps(original_json).encode('utf-8')  # pylint: disable=protected-access
        return await call_next(request)
146
+
147
+
113
148
  # Default expiration time for upload ids before cleanup.
114
149
  _DEFAULT_UPLOAD_EXPIRATION_TIME = datetime.timedelta(hours=1)
115
150
  # Key: (upload_id, user_hash), Value: the time when the upload id needs to be
@@ -216,6 +251,7 @@ app.add_middleware(
216
251
  allow_headers=['*'],
217
252
  # TODO(syang): remove X-Request-ID when v0.10.0 is released.
218
253
  expose_headers=['X-Request-ID', 'X-Skypilot-Request-ID'])
254
+ app.add_middleware(AuthProxyMiddleware)
219
255
  app.add_middleware(RequestIDMiddleware)
220
256
  app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
221
257
  app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
@@ -223,8 +259,18 @@ app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
223
259
 
224
260
  @app.get('/token')
225
261
  async def token(request: fastapi.Request) -> fastapi.responses.HTMLResponse:
262
+ # If we have auth info, save this user to the database.
263
+ user = _get_auth_user_header(request)
264
+ if user is not None:
265
+ global_user_state.add_or_update_user(user)
266
+
267
+ token_data = {
268
+ 'v': 1, # Token version number, bump for backwards incompatible.
269
+ 'user': user.id if user is not None else None,
270
+ 'cookies': request.cookies,
271
+ }
226
272
  # Use base64 encoding to avoid having to escape anything in the HTML.
227
- json_bytes = json.dumps(request.cookies).encode('utf-8')
273
+ json_bytes = json.dumps(token_data).encode('utf-8')
228
274
  base64_str = base64.b64encode(json_bytes).decode('utf-8')
229
275
 
230
276
  html_dir = pathlib.Path(__file__).parent / 'html'
@@ -236,8 +282,10 @@ async def token(request: fastapi.Request) -> fastapi.responses.HTMLResponse:
236
282
  raise fastapi.HTTPException(
237
283
  status_code=500, detail='Token page template not found.') from e
238
284
 
285
+ user_info_string = f'Logged in as {user.name}' if user is not None else ''
239
286
  html_content = html_content.replace(
240
- 'SKYPILOT_API_SERVER_USER_TOKEN_PLACEHOLDER', base64_str)
287
+ 'SKYPILOT_API_SERVER_USER_TOKEN_PLACEHOLDER',
288
+ base64_str).replace('USER_PLACEHOLDER', user_info_string)
241
289
 
242
290
  return fastapi.responses.HTMLResponse(
243
291
  content=html_content,
@@ -263,17 +311,30 @@ async def check(request: fastapi.Request,
263
311
 
264
312
 
265
313
  @app.get('/enabled_clouds')
266
- async def enabled_clouds(request: fastapi.Request) -> None:
314
+ async def enabled_clouds(request: fastapi.Request,
315
+ workspace: Optional[str] = None) -> None:
267
316
  """Gets enabled clouds on the server."""
268
317
  executor.schedule_request(
269
318
  request_id=request.state.request_id,
270
319
  request_name='enabled_clouds',
271
- request_body=payloads.RequestBody(),
320
+ request_body=payloads.EnabledCloudsBody(workspace=workspace),
272
321
  func=core.enabled_clouds,
273
322
  schedule_type=requests_lib.ScheduleType.SHORT,
274
323
  )
275
324
 
276
325
 
326
@app.get('/workspaces')
async def get_workspace_config(request: fastapi.Request) -> None:
    """Gets workspace config on the server.

    Schedules a SHORT request named 'workspaces' whose work is
    ``skypilot_config.get_workspaces``, keyed by the request id stored on
    the incoming request's state.
    """
    request_id = request.state.request_id
    body = payloads.RequestBody()
    executor.schedule_request(
        request_id=request_id,
        request_name='workspaces',
        request_body=body,
        func=skypilot_config.get_workspaces,
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
+
337
+
277
338
  @app.post('/realtime_kubernetes_gpu_availability')
278
339
  async def realtime_kubernetes_gpu_availability(
279
340
  request: fastapi.Request,
@@ -1113,7 +1174,7 @@ async def api_status(
1113
1174
 
1114
1175
 
1115
1176
  @app.get('/api/health')
1116
- async def health() -> Dict[str, str]:
1177
+ async def health(request: fastapi.Request) -> Dict[str, Any]:
1117
1178
  """Checks the health of the API server.
1118
1179
 
1119
1180
  Returns:
@@ -1125,12 +1186,14 @@ async def health() -> Dict[str, str]:
1125
1186
  disk, which can be used to warn about restarting the API server
1126
1187
  - commit: str; The commit hash of SkyPilot used for API server.
1127
1188
  """
1189
+ user = _get_auth_user_header(request)
1128
1190
  return {
1129
1191
  'status': common.ApiServerStatus.HEALTHY.value,
1130
1192
  'api_version': server_constants.API_VERSION,
1131
1193
  'version': sky.__version__,
1132
1194
  'version_on_disk': common.get_skypilot_version_on_disk(),
1133
1195
  'commit': sky.__commit__,
1196
+ 'user': user.to_dict() if user is not None else None,
1134
1197
  }
1135
1198
 
1136
1199
 
sky/skylet/constants.py CHANGED
@@ -378,7 +378,8 @@ OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
378
378
  # we skip the following keys because they are meant to be client-side configs.
379
379
  SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [('admin_policy',),
380
380
  ('api_server',),
381
- ('allowed_clouds',)]
381
+ ('allowed_clouds',),
382
+ ('workspaces',)]
382
383
 
383
384
  # Constants for Azure blob storage
384
385
  WAIT_FOR_STORAGE_ACCOUNT_CREATION = 60
@@ -405,3 +406,5 @@ SKY_USER_FILE_PATH = '~/.sky/generated'
405
406
 
406
407
  # Environment variable that is set to 'true' if this is a skypilot server.
407
408
  ENV_VAR_IS_SKYPILOT_SERVER = 'IS_SKYPILOT_SERVER'
409
+
410
+ SKYPILOT_DEFAULT_WORKSPACE = 'default'
sky/skypilot_config.py CHANGED
@@ -123,6 +123,8 @@ class ConfigContext:
123
123
  _global_config_context = ConfigContext()
124
124
  _reload_config_lock = threading.Lock()
125
125
 
126
+ _active_workspace_context = threading.local()
127
+
126
128
 
127
129
  def _get_config_context() -> ConfigContext:
128
130
  """Get config context for current context.
@@ -194,8 +196,7 @@ def get_user_config() -> config_utils.Config:
194
196
 
195
197
  # load the user config file
196
198
  if os.path.exists(user_config_path):
197
- user_config = parse_config_file(user_config_path)
198
- _validate_config(user_config, user_config_path)
199
+ user_config = parse_and_validate_config_file(user_config_path)
199
200
  else:
200
201
  user_config = config_utils.Config()
201
202
  return user_config
@@ -223,8 +224,7 @@ def _get_project_config() -> config_utils.Config:
223
224
 
224
225
  # load the project config file
225
226
  if os.path.exists(project_config_path):
226
- project_config = parse_config_file(project_config_path)
227
- _validate_config(project_config, project_config_path)
227
+ project_config = parse_and_validate_config_file(project_config_path)
228
228
  else:
229
229
  project_config = config_utils.Config()
230
230
  return project_config
@@ -252,8 +252,7 @@ def get_server_config() -> config_utils.Config:
252
252
 
253
253
  # load the server config file
254
254
  if os.path.exists(server_config_path):
255
- server_config = parse_config_file(server_config_path)
256
- _validate_config(server_config, server_config_path)
255
+ server_config = parse_and_validate_config_file(server_config_path)
257
256
  else:
258
257
  server_config = config_utils.Config()
259
258
  return server_config
@@ -287,6 +286,60 @@ def get_nested(keys: Tuple[str, ...],
287
286
  disallowed_override_keys=None)
288
287
 
289
288
 
289
def get_workspace_cloud(cloud: str,
                        workspace: Optional[str] = None) -> config_utils.Config:
    """Returns the config section of one cloud within a workspace.

    Looks up ``workspaces.<workspace>`` in the SkyPilot config and returns
    the entry keyed by the lowercased cloud name.

    Args:
        cloud: Cloud name; it is lowercased before the lookup.
        workspace: Workspace name. Defaults to ``get_active_workspace()``.

    Returns:
        The cloud's config section, or an empty ``config_utils.Config`` when
        the workspace or the cloud section is missing or has no value.
    """
    if workspace is None:
        workspace = get_active_workspace()
    workspace_config = get_nested(keys=('workspaces', workspace),
                                  default_value=None)
    if workspace_config is None:
        return config_utils.Config()
    cloud_config = workspace_config.get(cloud.lower())
    # A cloud key written with no value in YAML (e.g. `gcp:`) loads as None;
    # return an empty Config so callers always get a Config back.
    if cloud_config is None:
        return config_utils.Config()
    return cloud_config
+
302
+
303
@contextlib.contextmanager
def local_active_workspace_ctx(workspace: str) -> Iterator[None]:
    """Temporarily set the active workspace IN CURRENT THREAD.

    The value is stored on the module's ``threading.local`` object, so only
    code running in the calling thread sees it. Work that the wrapped code
    hands off to other threads will not pick up this workspace, which makes
    the helper error-prone. It cannot be made process-global either, because
    backend_utils.refresh_cluster_status() is called in multiple threads
    that may need different active workspaces.

    On exit the thread's previous context state is restored. This also
    happens when the body raises.

    Args:
        workspace: The workspace to set as active.
    """
    # TODO(zhwu): make this function global by default and able to set it
    # to thread-local with an argument.
    original_workspace = get_active_workspace()
    if original_workspace == workspace:
        # No change, do nothing.
        yield
        return
    # Remember the raw thread-local value (None when unset) so exiting
    # restores exactly the previous state instead of pinning the
    # config-derived workspace into the thread-local context.
    previous_context_workspace = getattr(_active_workspace_context,
                                         'workspace', None)
    _active_workspace_context.workspace = workspace
    logger.debug(f'Set context workspace: {workspace}')
    try:
        yield
    finally:
        # Restore even if the body raised; otherwise the temporary workspace
        # would leak into later work on this thread.
        logger.debug(f'Reset context workspace: {original_workspace}')
        _active_workspace_context.workspace = previous_context_workspace
+
333
+
334
def get_active_workspace(force_user_workspace: bool = False) -> str:
    """Returns the name of the currently active workspace.

    A workspace stored on the thread-local context takes precedence unless
    ``force_user_workspace`` is True. Otherwise the ``active_workspace``
    config key is read via ``get_nested``, falling back to
    ``constants.SKYPILOT_DEFAULT_WORKSPACE``.

    Args:
        force_user_workspace: Ignore the thread-local context and always read
            the value from config.
    """
    thread_workspace = getattr(_active_workspace_context, 'workspace', None)
    use_thread_value = thread_workspace is not None and not force_user_workspace
    if not use_thread_value:
        return get_nested(keys=('active_workspace',),
                          default_value=constants.SKYPILOT_DEFAULT_WORKSPACE)
    logger.debug(f'Get context workspace: {thread_workspace}')
    return thread_workspace
+
342
+
290
343
  def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]:
291
344
  """Returns a deep-copied config with the nested key set to value.
292
345
 
@@ -357,7 +410,7 @@ def _reload_config() -> None:
357
410
  _reload_config_as_client()
358
411
 
359
412
 
360
- def parse_config_file(config_path: str) -> config_utils.Config:
413
+ def parse_and_validate_config_file(config_path: str) -> config_utils.Config:
361
414
  config = config_utils.Config()
362
415
  try:
363
416
  config_dict = common_utils.read_yaml(config_path)
@@ -413,7 +466,7 @@ def _reload_config_from_internal_file(internal_config_path: str) -> None:
413
466
  'exist. Please double check the path or unset the env var: '
414
467
  f'unset {ENV_VAR_SKYPILOT_CONFIG}')
415
468
  logger.debug(f'Using config path: {config_path}')
416
- _set_loaded_config(parse_config_file(config_path))
469
+ _set_loaded_config(parse_and_validate_config_file(config_path))
417
470
  _set_loaded_config_path(config_path)
418
471
 
419
472
 
@@ -512,6 +565,19 @@ def override_skypilot_config(
512
565
  override_configs=dict(override_configs),
513
566
  allowed_override_keys=None,
514
567
  disallowed_override_keys=constants.SKIPPED_CLIENT_OVERRIDE_KEYS)
568
+ workspace = config.get_nested(
569
+ keys=('active_workspace',),
570
+ default_value=constants.SKYPILOT_DEFAULT_WORKSPACE)
571
+ if (workspace != constants.SKYPILOT_DEFAULT_WORKSPACE and workspace
572
+ not in get_nested(keys=('workspaces',), default_value={})):
573
+ raise ValueError(f'Workspace {workspace} does not exist. '
574
+ 'Use `sky check` to see if it is defined on the API '
575
+ 'server and try again.')
576
+ # Initialize the active workspace context to the workspace specified, so
577
+ # that a new request is not affected by the previous request's workspace.
578
+ global _active_workspace_context
579
+ _active_workspace_context = threading.local()
580
+
515
581
  try:
516
582
  common_utils.validate_schema(
517
583
  config,
@@ -592,7 +658,7 @@ def _compose_cli_config(cli_config: Optional[List[str]]) -> config_utils.Config:
592
658
  'Cannot use multiple --config flags with a config file.')
593
659
  config_source = maybe_config_path
594
660
  # cli_config is a path to a config file
595
- parsed_config = parse_config_file(maybe_config_path)
661
+ parsed_config = parse_and_validate_config_file(maybe_config_path)
596
662
  else: # cli_config is a comma-separated list of key-value pairs
597
663
  parsed_config = _parse_dotlist(cli_config)
598
664
  _validate_config(parsed_config, config_source)
@@ -623,3 +689,11 @@ def apply_cli_config(cli_config: Optional[List[str]]) -> Dict[str, Any]:
623
689
  overlay_skypilot_config(original_config=_get_loaded_config(),
624
690
  override_configs=parsed_config))
625
691
  return parsed_config
692
+
693
+
694
def get_workspaces() -> Dict[str, Any]:
    """Returns all configured workspaces, always including the default one.

    Returns:
        A new dict mapping workspace name to its config, built from the
        ``workspaces`` config key. ``constants.SKYPILOT_DEFAULT_WORKSPACE``
        is added with an empty config when it is not configured.
    """
    # `workspaces:` written with no value in YAML loads as None; treat it as
    # empty instead of failing on the membership test. Copy into a new dict
    # so the mapping returned by get_nested is never mutated in place.
    workspaces = dict(get_nested(('workspaces',), default_value=None) or {})
    workspaces.setdefault(constants.SKYPILOT_DEFAULT_WORKSPACE, {})
    return workspaces