skypilot-nightly 1.0.0.dev20250521__py3-none-any.whl → 1.0.0.dev20250523__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sky/__init__.py +2 -2
  2. sky/adaptors/kubernetes.py +46 -16
  3. sky/backends/cloud_vm_ray_backend.py +16 -4
  4. sky/check.py +109 -44
  5. sky/cli.py +261 -90
  6. sky/client/cli.py +261 -90
  7. sky/client/sdk.py +122 -3
  8. sky/clouds/__init__.py +5 -0
  9. sky/clouds/aws.py +4 -2
  10. sky/clouds/azure.py +4 -2
  11. sky/clouds/cloud.py +30 -6
  12. sky/clouds/cudo.py +2 -1
  13. sky/clouds/do.py +2 -1
  14. sky/clouds/fluidstack.py +2 -1
  15. sky/clouds/gcp.py +160 -23
  16. sky/clouds/ibm.py +4 -2
  17. sky/clouds/kubernetes.py +66 -22
  18. sky/clouds/lambda_cloud.py +2 -1
  19. sky/clouds/nebius.py +18 -2
  20. sky/clouds/oci.py +4 -2
  21. sky/clouds/paperspace.py +2 -1
  22. sky/clouds/runpod.py +2 -1
  23. sky/clouds/scp.py +2 -1
  24. sky/clouds/service_catalog/__init__.py +3 -0
  25. sky/clouds/service_catalog/common.py +9 -2
  26. sky/clouds/service_catalog/constants.py +2 -1
  27. sky/clouds/service_catalog/ssh_catalog.py +167 -0
  28. sky/clouds/ssh.py +203 -0
  29. sky/clouds/vast.py +2 -1
  30. sky/clouds/vsphere.py +2 -1
  31. sky/core.py +59 -17
  32. sky/dashboard/out/404.html +1 -1
  33. sky/dashboard/out/_next/static/{hvWzC5E6Q4CcKzXcWbgig → ECKwDNS9v9y3_IKFZ2lpp}/_buildManifest.js +1 -1
  34. sky/dashboard/out/_next/static/chunks/pages/infra-abf08c4384190a39.js +1 -0
  35. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  36. sky/dashboard/out/clusters/[cluster].html +1 -1
  37. sky/dashboard/out/clusters.html +1 -1
  38. sky/dashboard/out/index.html +1 -1
  39. sky/dashboard/out/infra.html +1 -1
  40. sky/dashboard/out/jobs/[job].html +1 -1
  41. sky/dashboard/out/jobs.html +1 -1
  42. sky/data/storage.py +1 -0
  43. sky/execution.py +56 -7
  44. sky/jobs/server/core.py +4 -2
  45. sky/optimizer.py +29 -15
  46. sky/provision/__init__.py +1 -0
  47. sky/provision/aws/instance.py +17 -1
  48. sky/provision/gcp/constants.py +147 -4
  49. sky/provision/gcp/instance_utils.py +10 -0
  50. sky/provision/gcp/volume_utils.py +247 -0
  51. sky/provision/kubernetes/instance.py +16 -5
  52. sky/provision/kubernetes/utils.py +37 -19
  53. sky/provision/nebius/instance.py +3 -1
  54. sky/provision/nebius/utils.py +14 -2
  55. sky/provision/ssh/__init__.py +18 -0
  56. sky/resources.py +177 -4
  57. sky/serve/server/core.py +2 -4
  58. sky/server/common.py +46 -9
  59. sky/server/constants.py +2 -0
  60. sky/server/html/token_page.html +154 -0
  61. sky/server/requests/executor.py +3 -6
  62. sky/server/requests/payloads.py +7 -0
  63. sky/server/server.py +80 -8
  64. sky/setup_files/dependencies.py +1 -0
  65. sky/skypilot_config.py +117 -31
  66. sky/task.py +24 -1
  67. sky/templates/gcp-ray.yml.j2 +44 -1
  68. sky/templates/nebius-ray.yml.j2 +12 -2
  69. sky/utils/admin_policy_utils.py +26 -22
  70. sky/utils/context.py +36 -6
  71. sky/utils/context_utils.py +15 -0
  72. sky/utils/infra_utils.py +21 -1
  73. sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
  74. sky/utils/kubernetes/create_cluster.sh +1 -0
  75. sky/utils/kubernetes/deploy_remote_cluster.py +1437 -0
  76. sky/utils/kubernetes/kubernetes_deploy_utils.py +117 -10
  77. sky/utils/kubernetes/ssh-tunnel.sh +387 -0
  78. sky/utils/log_utils.py +214 -1
  79. sky/utils/resources_utils.py +14 -0
  80. sky/utils/schemas.py +67 -0
  81. sky/utils/ux_utils.py +2 -1
  82. {skypilot_nightly-1.0.0.dev20250521.dist-info → skypilot_nightly-1.0.0.dev20250523.dist-info}/METADATA +6 -1
  83. {skypilot_nightly-1.0.0.dev20250521.dist-info → skypilot_nightly-1.0.0.dev20250523.dist-info}/RECORD +88 -81
  84. sky/dashboard/out/_next/static/chunks/pages/infra-9180cd91cee64b96.js +0 -1
  85. sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
  86. /sky/dashboard/out/_next/static/{hvWzC5E6Q4CcKzXcWbgig → ECKwDNS9v9y3_IKFZ2lpp}/_ssgManifest.js +0 -0
  87. {skypilot_nightly-1.0.0.dev20250521.dist-info → skypilot_nightly-1.0.0.dev20250523.dist-info}/WHEEL +0 -0
  88. {skypilot_nightly-1.0.0.dev20250521.dist-info → skypilot_nightly-1.0.0.dev20250523.dist-info}/entry_points.txt +0 -0
  89. {skypilot_nightly-1.0.0.dev20250521.dist-info → skypilot_nightly-1.0.0.dev20250523.dist-info}/licenses/LICENSE +0 -0
  90. {skypilot_nightly-1.0.0.dev20250521.dist-info → skypilot_nightly-1.0.0.dev20250523.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,154 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>SkyPilot API Server Login</title>
7
+ <style>
8
+ body {
9
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
10
+ display: flex;
11
+ flex-direction: column;
12
+ align-items: center;
13
+ justify-content: center;
14
+ min-height: 100vh;
15
+ margin: 0;
16
+ background-color: #f8f9fa;
17
+ color: #202124;
18
+ padding: 20px;
19
+ box-sizing: border-box;
20
+ }
21
+ .container {
22
+ background-color: #ffffff;
23
+ padding: 48px;
24
+ border-radius: 8px;
25
+ box-shadow: 0 1px 3px rgba(0,0,0,0.12), 0 1px 2px rgba(0,0,0,0.24);
26
+ text-align: center;
27
+ max-width: 600px;
28
+ width: 100%;
29
+ }
30
+ .logo {
31
+ width: 64px;
32
+ height: 64px;
33
+ margin-bottom: 20px;
34
+ display: inline-block;
35
+ }
36
+ .logo svg {
37
+ width: 100%;
38
+ height: 100%;
39
+ }
40
+ h1 {
41
+ font-size: 24px;
42
+ font-weight: 500;
43
+ margin-bottom: 20px;
44
+ color: #202124;
45
+ }
46
+ p {
47
+ font-size: 14px;
48
+ line-height: 1.5;
49
+ margin-bottom: 20px;
50
+ color: #5f6368;
51
+ }
52
+ .code-block {
53
+ background-color: #f1f3f4;
54
+ border: 1px solid #dadce0;
55
+ border-radius: 4px;
56
+ padding: 16px;
57
+ margin-top: 24px;
58
+ margin-bottom: 24px;
59
+ margin-left: auto;
60
+ margin-right: auto;
61
+ text-align: left;
62
+ word-break: break-all;
63
+ white-space: pre-wrap;
64
+ font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, Courier, monospace;
65
+ font-size: 13px;
66
+ line-height: 1.4;
67
+ max-width: 480px;
68
+ }
69
+ #token-box { /* Specifically for the token */
70
+ height: auto;
71
+ min-height: 6em; /* Ensure it's a reasonable size */
72
+ max-height: 15em; /* Prevent it from getting too large */
73
+ overflow-y: auto;
74
+ }
75
+ .copy-button {
76
+ background-color: #1a73e8;
77
+ color: white;
78
+ border: none;
79
+ border-radius: 4px;
80
+ padding: 10px 24px;
81
+ font-size: 14px;
82
+ font-weight: 500;
83
+ cursor: pointer;
84
+ transition: background-color 0.3s;
85
+ margin-top: 10px;
86
+ }
87
+ .copy-button:hover {
88
+ background-color: #287ae6;
89
+ }
90
+ .copy-button:active {
91
+ background-color: #1b66c9;
92
+ }
93
+ .footer-text {
94
+ font-size: 12px;
95
+ color: #5f6368;
96
+ margin-top: 30px;
97
+ }
98
+ </style>
99
+ </head>
100
+ <body>
101
+ <div class="container">
102
+ <div class="logo">
103
+ <!-- SkyPilot Logo Icon -->
104
+ <svg viewBox="0 0 50 50" fill="none" xmlns="http://www.w3.org/2000/svg">
105
+ <path d="M25.1258 30.8274L19.2842 31.6783L33.8316 46.2268L31.492 37.1925L25.1258 30.8274Z" fill="#372F8A"/>
106
+ <path d="M46.9433 0.000976562L0.719727 13.1148L15.2661 27.6601L16.633 21.3925L10.3728 15.1323L40.183 6.74118C40.183 6.74118 46.102 0.855027 46.9444 0.00203721L46.9433 0.000976562Z" fill="#372F8A"/>
107
+ <path d="M40.1821 6.74021L31.4922 37.1925L33.8318 46.2257L46.9445 0C46.1022 0.85299 40.1831 6.73915 40.1831 6.73915L40.1821 6.74021Z" fill="#372F8A"/>
108
+ <path d="M21.3356 25.6089L19.2842 31.6783L25.1258 30.8275L30.3741 16.6011L30.3275 16.617L21.3356 25.6089Z" fill="#195D7F"/>
109
+ <path d="M16.632 21.3918L15.2651 27.6605L21.3357 25.6091L30.3276 16.6172L16.632 21.3918Z" fill="#39A4DD"/>
110
+ </svg>
111
+ </div>
112
+ <h1>Sign in to SkyPilot CLI</h1>
113
+ <p>You are seeing this page because a SkyPilot command requires authentication.</p>
114
+
115
+ <p>Please copy the following token and paste it into your SkyPilot CLI prompt:</p>
116
+ <div id="token-box" class="code-block">SKYPILOT_API_SERVER_USER_TOKEN_PLACEHOLDER</div>
117
+ <button id="copy-btn" class="copy-button">Copy Token</button>
118
+
119
+ <p class="footer-text">You can close this tab after copying the token.</p>
120
+ </div>
121
+
122
+ <script>
123
// Token display box and the button that copies its content.
const tokenBox = document.getElementById('token-box');
const copyBtn = document.getElementById('copy-btn');

/**
 * Select the whole token inside the token box.
 *
 * The box is a plain element rather than an input, so the selection is
 * built from a range spanning the node's contents.
 */
function selectToken() {
    const range = document.createRange();
    range.selectNodeContents(tokenBox);
    const sel = window.getSelection();
    sel.removeAllRanges();
    sel.addRange(range);
}

/**
 * Copy the token to the clipboard.
 *
 * Uses the async Clipboard API when the browser exposes it and falls back
 * to document.execCommand('copy') otherwise.
 *
 * @returns {Promise<void>} Rejects when neither mechanism copied the token.
 */
async function copyToken() {
    // Keep the token highlighted so a manual copy shortcut still works and
    // so the execCommand fallback has a selection to act on.
    selectToken();
    if (navigator.clipboard && navigator.clipboard.writeText) {
        try {
            await navigator.clipboard.writeText(tokenBox.textContent);
            return;
        } catch (err) {
            // Clipboard API refused the write; try the legacy path below.
        }
    }
    // execCommand reports most failures by returning false instead of
    // throwing, so the return value has to be checked explicitly.
    if (!document.execCommand('copy')) {
        throw new Error('Copy command was rejected by the browser.');
    }
}

// Optional: Select the token when the page loads or when token box is clicked
tokenBox.addEventListener('click', selectToken);
window.addEventListener('load', selectToken);

copyBtn.addEventListener('click', async () => {
    try {
        await copyToken();
        copyBtn.textContent = 'Copied!';
    } catch (err) {
        copyBtn.textContent = 'Error!';
        console.error('Failed to copy text: ', err);
    }
    // Restore the default label after a short delay.
    setTimeout(() => {
        copyBtn.textContent = 'Copy Token';
    }, 2000);
});
152
+ </script>
153
+ </body>
154
+ </html>
@@ -20,8 +20,6 @@ See the [README.md](../README.md) for detailed architecture of the executor.
20
20
  """
21
21
  import asyncio
22
22
  import contextlib
23
- import contextvars
24
- import functools
25
23
  import multiprocessing
26
24
  import os
27
25
  import queue as queue_lib
@@ -52,6 +50,7 @@ from sky.skylet import constants
52
50
  from sky.utils import annotations
53
51
  from sky.utils import common_utils
54
52
  from sky.utils import context
53
+ from sky.utils import context_utils
55
54
  from sky.utils import subprocess_utils
56
55
  from sky.utils import timeline
57
56
 
@@ -368,10 +367,8 @@ async def execute_request_coroutine(request: api_requests.Request):
368
367
  # 1. skypilot config is not contextual
369
368
  # 2. envs that read directly from os.environ are not contextual
370
369
  ctx.override_envs(request_body.env_vars)
371
- loop = asyncio.get_running_loop()
372
- pyctx = contextvars.copy_context()
373
- func_call = functools.partial(pyctx.run, func, **request_body.to_kwargs())
374
- fut: asyncio.Future = loop.run_in_executor(None, func_call)
370
+ fut: asyncio.Future = context_utils.to_thread(func,
371
+ **request_body.to_kwargs())
375
372
 
376
373
  async def poll_task(request_id: str) -> bool:
377
374
  request = api_requests.get_request(request_id)
@@ -446,6 +446,7 @@ class RealtimeGpuAvailabilityRequestBody(RequestBody):
446
446
  context: Optional[str] = None
447
447
  name_filter: Optional[str] = None
448
448
  quantity_filter: Optional[int] = None
449
+ is_ssh: Optional[bool] = None
449
450
 
450
451
 
451
452
  class KubernetesNodeInfoRequestBody(RequestBody):
@@ -485,6 +486,12 @@ class LocalUpBody(RequestBody):
485
486
  password: Optional[str] = None
486
487
 
487
488
 
489
class SSHUpBody(RequestBody):
    """Request payload shared by the SSH up and SSH down endpoints."""
    # NOTE(review): presumably selects which SSH infra the request targets;
    # None is the default - confirm the exact meaning against core.ssh_up.
    infra: Optional[str] = None
    # The ssh_down endpoint reuses the ssh_up function and relies on this
    # flag being True to request a teardown; it defaults to False.
    cleanup: bool = False
493
+
494
+
488
495
  class ServeTerminateReplicaBody(RequestBody):
489
496
  """The request body for the serve terminate replica endpoint."""
490
497
  service_name: str
sky/server/server.py CHANGED
@@ -2,9 +2,11 @@
2
2
 
3
3
  import argparse
4
4
  import asyncio
5
+ import base64
5
6
  import contextlib
6
7
  import dataclasses
7
8
  import datetime
9
+ import json
8
10
  import logging
9
11
  import multiprocessing
10
12
  import os
@@ -49,6 +51,7 @@ from sky.utils import admin_policy_utils
49
51
  from sky.utils import common as common_lib
50
52
  from sky.utils import common_utils
51
53
  from sky.utils import context
54
+ from sky.utils import context_utils
52
55
  from sky.utils import dag_utils
53
56
  from sky.utils import env_options
54
57
  from sky.utils import status_lib
@@ -218,6 +221,34 @@ app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
218
221
  app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
219
222
 
220
223
 
224
@app.get('/token')
async def token(request: fastapi.Request) -> fastapi.responses.HTMLResponse:
    """Renders the token page with the request's cookies embedded.

    The cookies are serialized to JSON, base64-encoded, and substituted for
    the placeholder string in ``html/token_page.html`` next to this module.

    Raises:
        fastapi.HTTPException: with status 500 if the template file is
            missing.
    """
    # Use base64 encoding to avoid having to escape anything in the HTML.
    cookies_json = json.dumps(request.cookies)
    encoded_token = base64.b64encode(
        cookies_json.encode('utf-8')).decode('utf-8')

    template_path = pathlib.Path(__file__).parent / 'html' / 'token_page.html'
    try:
        template = template_path.read_text(encoding='utf-8')
    except FileNotFoundError as e:
        raise fastapi.HTTPException(
            status_code=500, detail='Token page template not found.') from e

    response_headers = {
        'Cache-Control': 'no-cache, no-transform',
        # X-Accel-Buffering: no is useful for preventing buffering issues
        # with some reverse proxies.
        'X-Accel-Buffering': 'no'
    }
    page = template.replace('SKYPILOT_API_SERVER_USER_TOKEN_PLACEHOLDER',
                            encoded_token)
    return fastapi.responses.HTMLResponse(content=page,
                                          headers=response_headers)
250
+
251
+
221
252
  @app.post('/check')
222
253
  async def check(request: fastapi.Request,
223
254
  check_body: payloads.CheckBody) -> None:
@@ -327,25 +358,26 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
327
358
  # pairs.
328
359
  logger.debug(f'Validating tasks: {validate_body.dag}')
329
360
 
361
+ context.initialize()
362
+
330
363
  def validate_dag(dag: dag_utils.dag_lib.Dag):
331
364
  # TODO: Admin policy may contain arbitrary code, which may be expensive
332
365
  # to run and may block the server thread. However, moving it into the
333
366
  # executor adds a ~150ms penalty on the local API server because of
334
367
  # added RTTs. For now, we stick to doing the validation inline in the
335
368
  # server thread.
336
- dag, _ = admin_policy_utils.apply(
337
- dag, request_options=validate_body.request_options)
338
- # Skip validating workdir and file_mounts, as those need to be
339
- # validated after the files are uploaded to the SkyPilot API server
340
- # with `upload_mounts_to_api_server`.
341
- dag.validate(skip_file_mounts=True, skip_workdir=True)
369
+ with admin_policy_utils.apply_and_use_config_in_current_request(
370
+ dag, request_options=validate_body.request_options) as dag:
371
+ # Skip validating workdir and file_mounts, as those need to be
372
+ # validated after the files are uploaded to the SkyPilot API server
373
+ # with `upload_mounts_to_api_server`.
374
+ dag.validate(skip_file_mounts=True, skip_workdir=True)
342
375
 
343
376
  try:
344
377
  dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag)
345
- loop = asyncio.get_running_loop()
346
378
  # Apply admin policy and validate DAG is blocking, run it in a separate
347
379
  # thread executor to avoid blocking the uvicorn event loop.
348
- await loop.run_in_executor(None, validate_dag, dag)
380
+ await context_utils.to_thread(validate_dag, dag)
349
381
  except Exception as e: # pylint: disable=broad-except
350
382
  raise fastapi.HTTPException(
351
383
  status_code=400, detail=exceptions.serialize_exception(e)) from e
@@ -877,6 +909,33 @@ async def local_down(request: fastapi.Request) -> None:
877
909
  )
878
910
 
879
911
 
912
@app.post('/ssh_up')
async def ssh_up(request: fastapi.Request,
                 ssh_up_body: payloads.SSHUpBody) -> None:
    """Deploys a Kubernetes cluster on SSH targets.

    The work is not done inline: the request is handed to the executor as a
    LONG request that runs ``core.ssh_up`` with the given body.
    """
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='ssh_up',
        request_body=ssh_up_body,
        func=core.ssh_up,
        schedule_type=requests_lib.ScheduleType.LONG,
    )
923
+
924
+
925
@app.post('/ssh_down')
async def ssh_down(request: fastapi.Request,
                   ssh_up_body: payloads.SSHUpBody) -> None:
    """Tears down a Kubernetes cluster on SSH targets.

    Scheduled as a LONG request under the name 'ssh_down'. The function run
    is ``core.ssh_up``; teardown is selected through the body's ``cleanup``
    field.
    """
    # We still call ssh_up but with cleanup=True
    # NOTE(review): the body is forwarded unchanged, so this handler relies
    # on the client having set cleanup=True - confirm callers always do.
    request_id = request.state.request_id
    executor.schedule_request(
        request_id=request_id,
        request_name='ssh_down',
        request_body=ssh_up_body,
        func=core.ssh_up,  # Reuse ssh_up function with cleanup=True
        schedule_type=requests_lib.ScheduleType.LONG,
    )
937
+
938
+
880
939
  # === API server related APIs ===
881
940
  @app.get('/api/get')
882
941
  async def api_get(request_id: str) -> requests_lib.RequestPayload:
@@ -1153,6 +1212,19 @@ async def kubernetes_pod_ssh_proxy(
1153
1212
  proc.terminate()
1154
1213
 
1155
1214
 
1215
@app.get('/all_contexts')
async def all_contexts(request: fastapi.Request) -> None:
    """Gets all Kubernetes and SSH node pool contexts.

    Scheduled as a SHORT request that runs ``core.get_all_contexts`` with an
    empty request body.
    """
    request_id = request.state.request_id
    empty_body = payloads.RequestBody()
    executor.schedule_request(
        request_id=request_id,
        request_name='all_contexts',
        request_body=empty_body,
        func=core.get_all_contexts,
        schedule_type=requests_lib.ScheduleType.SHORT,
    )
1226
+
1227
+
1156
1228
  # === Internal APIs ===
1157
1229
  @app.get('/api/completion/cluster_name')
1158
1230
  async def complete_cluster_name(incomplete: str,) -> List[str]:
@@ -130,6 +130,7 @@ extras_require: Dict[str, List[str]] = {
130
130
  'oci': ['oci'] + local_ray,
131
131
  # Kubernetes 32.0.0 has an authentication bug: https://github.com/kubernetes-client/python/issues/2333 # pylint: disable=line-too-long
132
132
  'kubernetes': ['kubernetes>=20.0.0,!=32.0.0', 'websockets'],
133
+ 'ssh': ['kubernetes>=20.0.0,!=32.0.0', 'websockets'],
133
134
  'remote': remote,
134
135
  # For the container registry auth api. Reference:
135
136
  # https://github.com/runpod/runpod-python/releases/tag/1.6.1
sky/skypilot_config.py CHANGED
@@ -52,6 +52,7 @@ import contextlib
52
52
  import copy
53
53
  import json
54
54
  import os
55
+ import tempfile
55
56
  import threading
56
57
  import typing
57
58
  from typing import Any, Dict, Iterator, List, Optional, Tuple
@@ -62,6 +63,7 @@ from sky.adaptors import common as adaptors_common
62
63
  from sky.skylet import constants
63
64
  from sky.utils import common_utils
64
65
  from sky.utils import config_utils
66
+ from sky.utils import context
65
67
  from sky.utils import schemas
66
68
  from sky.utils import ux_utils
67
69
 
@@ -105,13 +107,66 @@ ENV_VAR_PROJECT_CONFIG = f'{constants.SKYPILOT_ENV_VAR_PREFIX}PROJECT_CONFIG'
105
107
  _GLOBAL_CONFIG_PATH = '~/.sky/config.yaml'
106
108
  _PROJECT_CONFIG_PATH = '.sky.yaml'
107
109
 
108
- # The loaded config.
109
- _dict = config_utils.Config()
110
- _loaded_config_path: Optional[str] = None
111
- _config_overridden: bool = False
110
+
111
class ConfigContext:
    """Bundle of the loaded config and the metadata tracked alongside it.

    Attributes are plain and mutable:
      config: the loaded config object.
      config_path: path of the file the config was loaded from, if any.
      config_overridden: whether the config is currently overridden.
    """

    def __init__(self,
                 config: Optional['config_utils.Config'] = None,
                 config_path: Optional[str] = None,
                 config_overridden: bool = False):
        """Creates a config context.

        Args:
            config: The config to hold. When omitted, a new empty
                ``config_utils.Config`` is created for this instance.
            config_path: Path of the loaded config file, or None.
            config_overridden: Whether the config is overridden.
        """
        # A default written as `config_utils.Config()` in the signature is
        # evaluated once and shared by every instance that relies on it;
        # build a fresh object per instance instead.
        self.config = config_utils.Config() if config is None else config
        self.config_path = config_path
        self.config_overridden = config_overridden
120
+
121
+
122
+ # The global loaded config.
123
+ _global_config_context = ConfigContext()
112
124
  _reload_config_lock = threading.Lock()
113
125
 
114
126
 
127
def _get_config_context() -> ConfigContext:
    """Returns the config context that applies to the caller.

    Without an active context (``context.get()`` is falsy) this is the
    global config context. With an active context, its own config context
    is returned; on first use it is created from a deep copy of the global
    config plus the global path and overridden flag.
    """
    active_ctx = context.get()
    if not active_ctx:
        return _global_config_context

    if active_ctx.config_context is None:
        # First access within this context: inherit from the global one.
        # The config is deep-copied; path and flag are copied by value.
        inherited_config = copy.deepcopy(_global_config_context.config)
        active_ctx.config_context = ConfigContext(
            config=inherited_config,
            config_path=_global_config_context.config_path,
            config_overridden=_global_config_context.config_overridden,
        )
    return active_ctx.config_context
144
+
145
+
146
def _get_loaded_config() -> config_utils.Config:
    """Returns the config stored in the current config context."""
    current = _get_config_context()
    return current.config


def _set_loaded_config(config: config_utils.Config) -> None:
    """Stores ``config`` in the current config context."""
    current = _get_config_context()
    current.config = config


def _get_loaded_config_path() -> Optional[str]:
    """Returns the config path stored in the current config context."""
    current = _get_config_context()
    return current.config_path


def _set_loaded_config_path(path: Optional[str]) -> None:
    """Stores ``path`` as the config path of the current config context."""
    current = _get_config_context()
    current.config_path = path


def _is_config_overridden() -> bool:
    """Returns the overridden flag of the current config context."""
    current = _get_config_context()
    return current.config_overridden


def _set_config_overridden(config_overridden: bool) -> None:
    """Sets the overridden flag of the current config context."""
    current = _get_config_context()
    current.config_overridden = config_overridden
168
+
169
+
115
170
  def get_user_config_path() -> str:
116
171
  """Returns the path to the user config file."""
117
172
  return _GLOBAL_CONFIG_PATH
@@ -224,7 +279,7 @@ def get_nested(keys: Tuple[str, ...],
224
279
  Returns:
225
280
  The value of the nested key, or 'default_value' if not found.
226
281
  """
227
- return _dict.get_nested(
282
+ return _get_loaded_config().get_nested(
228
283
  keys,
229
284
  default_value,
230
285
  override_configs,
@@ -237,14 +292,14 @@ def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]:
237
292
 
238
293
  Like get_nested(), if any key is not found, this will not raise an error.
239
294
  """
240
- copied_dict = copy.deepcopy(_dict)
295
+ copied_dict = copy.deepcopy(_get_loaded_config())
241
296
  copied_dict.set_nested(keys, value)
242
297
  return dict(**copied_dict)
243
298
 
244
299
 
245
300
  def to_dict() -> config_utils.Config:
246
301
  """Returns a deep-copied version of the current config."""
247
- return copy.deepcopy(_dict)
302
+ return copy.deepcopy(_get_loaded_config())
248
303
 
249
304
 
250
305
  def _get_config_file_path(envvar: str) -> Optional[str]:
@@ -345,10 +400,9 @@ def _parse_dotlist(dotlist: List[str]) -> config_utils.Config:
345
400
 
346
401
 
347
402
  def _reload_config_from_internal_file(internal_config_path: str) -> None:
348
- global _dict, _loaded_config_path
349
403
  # Reset the global variables, to avoid using stale values.
350
- _dict = config_utils.Config()
351
- _loaded_config_path = None
404
+ _set_loaded_config(config_utils.Config())
405
+ _set_loaded_config_path(None)
352
406
 
353
407
  config_path = os.path.expanduser(internal_config_path)
354
408
  if not os.path.exists(config_path):
@@ -359,14 +413,13 @@ def _reload_config_from_internal_file(internal_config_path: str) -> None:
359
413
  'exist. Please double check the path or unset the env var: '
360
414
  f'unset {ENV_VAR_SKYPILOT_CONFIG}')
361
415
  logger.debug(f'Using config path: {config_path}')
362
- _dict = parse_config_file(config_path)
363
- _loaded_config_path = config_path
416
+ _set_loaded_config(parse_config_file(config_path))
417
+ _set_loaded_config_path(config_path)
364
418
 
365
419
 
366
420
  def _reload_config_as_server() -> None:
367
- global _dict
368
421
  # Reset the global variables, to avoid using stale values.
369
- _dict = config_utils.Config()
422
+ _set_loaded_config(config_utils.Config())
370
423
 
371
424
  overrides: List[config_utils.Config] = []
372
425
  server_config = get_server_config()
@@ -382,13 +435,12 @@ def _reload_config_as_server() -> None:
382
435
  logger.debug(
383
436
  f'server config: \n'
384
437
  f'{common_utils.dump_yaml_str(dict(overlaid_server_config))}')
385
- _dict = overlaid_server_config
438
+ _set_loaded_config(overlaid_server_config)
386
439
 
387
440
 
388
441
  def _reload_config_as_client() -> None:
389
- global _dict
390
442
  # Reset the global variables, to avoid using stale values.
391
- _dict = config_utils.Config()
443
+ _set_loaded_config(config_utils.Config())
392
444
 
393
445
  overrides: List[config_utils.Config] = []
394
446
  user_config = get_user_config()
@@ -407,15 +459,15 @@ def _reload_config_as_client() -> None:
407
459
  logger.debug(
408
460
  f'client config (before task and CLI overrides): \n'
409
461
  f'{common_utils.dump_yaml_str(dict(overlaid_client_config))}')
410
- _dict = overlaid_client_config
462
+ _set_loaded_config(overlaid_client_config)
411
463
 
412
464
 
413
465
  def loaded_config_path() -> Optional[str]:
414
466
  """Returns the path to the loaded config file, or
415
467
  '<overridden>' if the config is overridden."""
416
- if _config_overridden:
468
+ if _is_config_overridden():
417
469
  return '<overridden>'
418
- return _loaded_config_path
470
+ return _get_loaded_config_path()
419
471
 
420
472
 
421
473
  # Load on import, synchronization is guaranteed by python interpreter.
@@ -424,21 +476,20 @@ _reload_config()
424
476
 
425
477
  def loaded() -> bool:
426
478
  """Returns if the user configurations are loaded."""
427
- return bool(_dict)
479
+ return bool(_get_loaded_config())
428
480
 
429
481
 
430
482
  @contextlib.contextmanager
431
483
  def override_skypilot_config(
432
484
  override_configs: Optional[Dict[str, Any]]) -> Iterator[None]:
433
485
  """Overrides the user configurations."""
434
- global _dict, _config_overridden
435
486
  # TODO(SKY-1215): allow admin user to extend the disallowed keys or specify
436
487
  # allowed keys.
437
488
  if not override_configs:
438
489
  # If no override configs (None or empty dict), do nothing.
439
490
  yield
440
491
  return
441
- original_config = _dict
492
+ original_config = _get_loaded_config()
442
493
  override_configs = config_utils.Config(override_configs)
443
494
  disallowed_diff_keys = []
444
495
  for key in constants.SKIPPED_CLIENT_OVERRIDE_KEYS:
@@ -455,7 +506,7 @@ def override_skypilot_config(
455
506
  'and will be ignored. Remove these keys to disable this warning. '
456
507
  'If you want to specify it, please modify it on server side or '
457
508
  'contact your administrator.')
458
- config = _dict.get_nested(
509
+ config = original_config.get_nested(
459
510
  keys=tuple(),
460
511
  default_value=None,
461
512
  override_configs=dict(override_configs),
@@ -469,8 +520,8 @@ def override_skypilot_config(
469
520
  'https://docs.skypilot.co/en/latest/reference/config.html. ' # pylint: disable=line-too-long
470
521
  'Error: ',
471
522
  skip_none=False)
472
- _config_overridden = True
473
- _dict = config
523
+ _set_config_overridden(True)
524
+ _set_loaded_config(config)
474
525
  yield
475
526
  except exceptions.InvalidSkyPilotConfigError as e:
476
527
  with ux_utils.print_exception_no_traceback():
@@ -483,8 +534,43 @@ def override_skypilot_config(
483
534
  f'{common_utils.dump_yaml_str(dict(override_configs))}\n'
484
535
  f'Details: {e}') from e
485
536
  finally:
486
- _dict = original_config
487
- _config_overridden = False
537
+ _set_loaded_config(original_config)
538
+ _set_config_overridden(False)
539
+
540
+
541
@contextlib.contextmanager
def replace_skypilot_config(new_configs: config_utils.Config) -> Iterator[None]:
    """Temporarily replaces the loaded config with ``new_configs``.

    While the context is active, the loaded config of the current process
    or context is ``new_configs``, and the config env var points at a YAML
    dump of it so that spawned sub-processes pick it up. On exit, including
    when the body raises, the previous config and env var are restored.

    If ``new_configs`` equals the currently loaded config, nothing changes.

    This function is concurrent safe when it is:
    1. called in different processes;
    2. or called in the same process but with different contexts, refer to
       sky.utils.context for more details.

    Args:
        new_configs: The config to use inside the ``with`` block.
    """
    original_config = _get_loaded_config()
    if new_configs == original_config:
        # Nothing to replace; leave the config and env var untouched.
        yield
        return

    original_env_var = os.environ.get(ENV_VAR_SKYPILOT_CONFIG)
    try:
        # Modify the global config of current process or context
        _set_loaded_config(new_configs)
        # The file is created with delete=False and is not removed here.
        # NOTE(review): confirm whether sub-processes can outlive this
        # context before adding any cleanup of the dumped file.
        with tempfile.NamedTemporaryFile(delete=False,
                                         mode='w',
                                         prefix='mutated-skypilot-config-',
                                         suffix='.yaml') as temp_file:
            common_utils.dump_yaml(temp_file.name, dict(**new_configs))
        # Modify the env var of current process or context so that the
        # new config will be used by spawned sub-processes.
        # Note that this code modifies os.environ directly because it
        # will be hijacked to be context-aware if a context is active.
        os.environ[ENV_VAR_SKYPILOT_CONFIG] = temp_file.name
        yield
    finally:
        # Restore the original config and env var even if the body (or the
        # setup above) raised, so a failure cannot leave the replacement
        # config in place.
        _set_loaded_config(original_config)
        if original_env_var:
            os.environ[ENV_VAR_SKYPILOT_CONFIG] = original_env_var
        else:
            os.environ.pop(ENV_VAR_SKYPILOT_CONFIG, None)
488
574
 
489
575
 
490
576
  def _compose_cli_config(cli_config: Optional[List[str]]) -> config_utils.Config:
@@ -529,11 +615,11 @@ def apply_cli_config(cli_config: Optional[List[str]]) -> Dict[str, Any]:
529
615
  cli_config: A path to a config file or a comma-separated
530
616
  list of key-value pairs.
531
617
  """
532
- global _dict
533
618
  parsed_config = _compose_cli_config(cli_config)
534
619
  if sky_logging.logging_enabled(logger, sky_logging.DEBUG):
535
620
  logger.debug(f'applying following CLI overrides: \n'
536
621
  f'{common_utils.dump_yaml_str(dict(parsed_config))}')
537
- _dict = overlay_skypilot_config(original_config=_dict,
538
- override_configs=parsed_config)
622
+ _set_loaded_config(
623
+ overlay_skypilot_config(original_config=_get_loaded_config(),
624
+ override_configs=parsed_config))
539
625
  return parsed_config
sky/task.py CHANGED
@@ -512,6 +512,7 @@ class Task:
512
512
  # storage objects with the storage/storage_mount objects.
513
513
  fm_storages = []
514
514
  file_mounts = config.pop('file_mounts', None)
515
+ volumes = []
515
516
  if file_mounts is not None:
516
517
  copy_mounts = {}
517
518
  for dst_path, src in file_mounts.items():
@@ -521,7 +522,27 @@ class Task:
521
522
  # If the src is not a str path, it is likely a dict. Try to
522
523
  # parse storage object.
523
524
  elif isinstance(src, dict):
524
- fm_storages.append((dst_path, src))
525
+ if (src.get('store') ==
526
+ storage_lib.StoreType.VOLUME.value.lower()):
527
+ # Build the volumes config for resources.
528
+ volume_config = {
529
+ 'path': dst_path,
530
+ }
531
+ if src.get('name'):
532
+ volume_config['name'] = src.get('name')
533
+ persistent = src.get('persistent', False)
534
+ volume_config['auto_delete'] = not persistent
535
+ volume_config_detail = src.get('config', {})
536
+ volume_config.update(volume_config_detail)
537
+ volumes.append(volume_config)
538
+ source_path = src.get('source')
539
+ if source_path:
540
+ # For volume, copy the source path to the
541
+ # data directory of the volume mount point.
542
+ copy_mounts[
543
+ f'{dst_path.rstrip("/")}/data'] = source_path
544
+ else:
545
+ fm_storages.append((dst_path, src))
525
546
  else:
526
547
  with ux_utils.print_exception_no_traceback():
527
548
  raise ValueError(f'Unable to parse file_mount '
@@ -599,6 +620,8 @@ class Task:
599
620
  'experimental.config_overrides')
600
621
  resources_config[
601
622
  '_cluster_config_overrides'] = cluster_config_override
623
+ if volumes:
624
+ resources_config['volumes'] = volumes
602
625
  task.set_resources(sky.Resources.from_yaml_config(resources_config))
603
626
 
604
627
  service = config.pop('service', None)