skypilot-nightly 1.0.0.dev20250510__py3-none-any.whl → 1.0.0.dev20250513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/backends/backend_utils.py +3 -0
- sky/backends/cloud_vm_ray_backend.py +7 -0
- sky/cli.py +109 -109
- sky/client/cli.py +109 -109
- sky/clouds/gcp.py +35 -8
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/{C0fkLhvxyqkymoV7IeInQ → 2dkponv64SfFShA8Rnw0D}/_buildManifest.js +1 -1
- sky/dashboard/out/_next/static/chunks/845-0ca6f2c1ba667c3b.js +1 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/global_user_state.py +2 -0
- sky/provision/docker_utils.py +4 -1
- sky/provision/gcp/config.py +197 -15
- sky/provision/gcp/constants.py +64 -0
- sky/provision/nebius/instance.py +3 -1
- sky/provision/nebius/utils.py +4 -2
- sky/server/requests/executor.py +114 -22
- sky/server/requests/requests.py +15 -0
- sky/server/server.py +12 -7
- sky/server/uvicorn.py +12 -2
- sky/sky_logging.py +40 -2
- sky/skylet/constants.py +3 -0
- sky/skylet/log_lib.py +51 -11
- sky/templates/gcp-ray.yml.j2 +11 -0
- sky/templates/nebius-ray.yml.j2 +4 -0
- sky/templates/websocket_proxy.py +29 -9
- sky/utils/command_runner.py +3 -0
- sky/utils/context.py +264 -0
- sky/utils/context_utils.py +172 -0
- sky/utils/rich_utils.py +81 -37
- sky/utils/schemas.py +9 -1
- sky/utils/subprocess_utils.py +8 -2
- {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/METADATA +1 -1
- {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/RECORD +44 -42
- sky/dashboard/out/_next/static/chunks/845-0f8017370869e269.js +0 -1
- /sky/dashboard/out/_next/static/{C0fkLhvxyqkymoV7IeInQ → 2dkponv64SfFShA8Rnw0D}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/top_level.txt +0 -0
sky/utils/context.py
ADDED
@@ -0,0 +1,264 @@
|
|
1
|
+
"""SkyPilot context for threads and coroutines."""
|
2
|
+
|
3
|
+
import asyncio
from collections.abc import Mapping
from collections.abc import MutableMapping
import contextvars
import os
import pathlib
import subprocess
import sys
import threading
from typing import Dict, Optional, TextIO
|
12
|
+
|
13
|
+
|
14
|
+
class Context(object):
    """SkyPilot typed context vars for threads and coroutines.

    This is a wrapper around `contextvars.ContextVar` that provides a typed
    interface for the SkyPilot specific context variables that can be accessed
    at any layer of the call stack. ContextVar values are per thread/coroutine:
    an asyncio task runs in a copy of the context it was created from, and
    get() returns None until initialize() has been called in the current
    context.

    Adding a new context variable for a new feature is as simple as:
    1. Add a new instance variable to the Context class.
    2. (Optional) Add new accessor methods if the variable should be protected.

    To propagate the context to a new thread/coroutine, use
    `contextvars.copy_context()`.

    Example:
        import asyncio
        import contextvars
        import time
        from sky.utils import context

        def sync_task():
            while True:
                if context.get().is_canceled():
                    break
                time.sleep(1)

        async def fastapi_handler():
            # context.initialize() has been called in lifespan
            # asyncio.to_thread copies current context implicitly; wrap it in
            # a task so the thread starts now instead of at `await task`.
            task = asyncio.create_task(asyncio.to_thread(sync_task))
            # Or explicitly:
            # loop = asyncio.get_running_loop()
            # ctx = contextvars.copy_context()
            # task = loop.run_in_executor(None, ctx.run, sync_task)
            await asyncio.sleep(1)
            context.get().cancel()
            await task
    """

    def __init__(self):
        # threading.Event rather than asyncio.Event: cancel() is typically
        # called from a coroutine while is_canceled() is polled from worker
        # threads (as in the class docstring example), and asyncio.Event is
        # documented as not thread-safe (on Python < 3.10 it also binds to an
        # event loop at construction time).
        self._canceled = threading.Event()
        # Path that stdout/stderr of this context are redirected to, if any.
        self._log_file: Optional[pathlib.Path] = None
        # Open handle for `_log_file`; None when output is not redirected.
        self._log_file_handle: Optional[TextIO] = None
        # Environment variable overrides visible through ContextualEnviron.
        self.env_overrides: Dict[str, str] = {}

    def cancel(self):
        """Cancel the context. Safe to call from any thread."""
        self._canceled.set()

    def is_canceled(self) -> bool:
        """Check if the context is canceled."""
        return self._canceled.is_set()

    def redirect_log(
            self, log_file: Optional[pathlib.Path]) -> Optional[pathlib.Path]:
        """Redirect the stdout and stderr of current context to a file.

        The file is opened in append mode with UTF-8 encoding. The handle of
        the previous redirection (if any) is closed.

        Args:
            log_file: The log file to redirect to. If None, the stdout and
                stderr will be restored to the original streams.

        Returns:
            The old log file, or None if the stdout and stderr were not
            redirected.
        """
        original_log_file = self._log_file
        original_log_handle = self._log_file_handle
        if log_file is None:
            self._log_file_handle = None
        else:
            self._log_file_handle = open(log_file, 'a', encoding='utf-8')
        self._log_file = log_file
        # Close the old handle only after the new one is in place so output
        # always has a valid destination.
        if original_log_handle is not None:
            original_log_handle.close()
        return original_log_file

    def output_stream(self, fallback: TextIO) -> TextIO:
        """Return the redirected log handle, or `fallback` if not redirected.

        Args:
            fallback: Stream to use when no log file redirection is active.
        """
        if self._log_file_handle is None:
            return fallback
        return self._log_file_handle

    def override_envs(self, envs: Dict[str, str]):
        """Add or replace contextual environment variable overrides.

        Args:
            envs: Mapping of variable names to values; existing overrides
                with the same names are replaced.
        """
        self.env_overrides.update(envs)
+
|
103
|
+
# Holds the Context of the current thread/coroutine; None until initialize().
_CONTEXT = contextvars.ContextVar('sky_context', default=None)


def get() -> Optional[Context]:
    """Return the SkyPilot context bound to the caller, if any.

    A None result means initialize() has not been called for the current
    context. Sync code can use this to tell whether it is running in a
    cancellable context, and skip polling the cancellation event when it is
    not.
    """
    current = _CONTEXT.get()
    return current
|
114
|
+
|
115
|
+
|
116
|
+
class ContextualEnviron(MutableMapping):
    """Environment variables wrapper with contextual overrides.

    An instance of ContextualEnviron will typically be used to replace
    os.environ to make the environ access of the current process
    context-aware.

    Reads (`env[key]`, iteration, `len()`, `copy()`, `|`) see the current
    context's `env_overrides` layered on top of the wrapped mapping. Writes
    (`env[key] = v`, `del env[key]`, `setdefault()`, `|=`) always go to the
    wrapped mapping, and `repr()` shows only the wrapped mapping.

    Behavior of spawning a subprocess:
    - The contextual overrides will not be applied to the subprocess by
      default.
    - When using env=os.environ to pass the environment variables to the
      subprocess explicitly, the subprocess will inherit the contextual
      environment variables at the time of the spawn, that is, it will not
      see the updates to the environment variables after the spawn. Also,
      os.environ of the subprocess will not be a ContextualEnviron unless
      the subprocess hijacks os.environ explicitly.
    - Optionally, context.Popen() can be used to automatically pass
      os.environ with overrides to subprocess.

    Example:
    1. Parent process:
        # Hijack os.environ to be a ContextualEnviron
        os.environ = ContextualEnviron(os.environ)
        ctx = context.get()
        ctx.override_envs({'FOO': 'BAR1'})
        proc = subprocess.Popen(..., env=os.environ)
        # Or use context.Popen instead
        # proc = context.Popen(...)
        ctx.override_envs({'FOO': 'BAR2'})
    2. Subprocess:
        assert os.environ['FOO'] == 'BAR1'
        ctx = context.get()
        # Override the contextual env var in the subprocess does not take
        # effect since the os.environ is not hijacked.
        ctx.override_envs({'FOO': 'BAR3'})
        assert os.environ['FOO'] == 'BAR1'
    """

    def __init__(self, environ):
        self._environ = environ

    def __getitem__(self, key):
        ctx = get()
        if ctx is not None and key in ctx.env_overrides:
            return ctx.env_overrides[key]
        return self._environ[key]

    def __iter__(self):
        # This method must NOT contain `yield` itself. Otherwise it becomes a
        # generator function, and returning `iter(self._environ)` in the
        # no-context branch would end the iteration immediately, so
        # dict()/len()/items(), and therefore Popen(env=os.environ), would
        # see an empty environment whenever no context is set.
        ctx = get()
        if ctx is None:
            return iter(self._environ)
        return self._iter_with_overrides(ctx)

    def _iter_with_overrides(self, ctx):
        """Yield override keys first, then base keys not overridden."""
        yield from ctx.env_overrides
        for key in self._environ:
            # Deduplicate the keys
            if key not in ctx.env_overrides:
                yield key

    def __len__(self):
        return len(dict(self))

    def __setitem__(self, key, value):
        return self._environ.__setitem__(key, value)

    def __delitem__(self, key):
        return self._environ.__delitem__(key)

    def __repr__(self):
        return self._environ.__repr__()

    def copy(self):
        """Return a plain copy of the wrapped mapping with overrides applied."""
        copied = self._environ.copy()
        ctx = get()
        if ctx is not None:
            copied.update(ctx.env_overrides)
        return copied

    def setdefault(self, key, default=None):
        return self._environ.setdefault(key, default)

    def __ior__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        self.update(other)
        return self

    def __or__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        new = dict(self)
        new.update(other)
        return new

    def __ror__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        new = dict(other)
        new.update(self)
        return new
|
218
|
+
|
219
|
+
|
220
|
+
class Popen(subprocess.Popen):
    """subprocess.Popen whose `env` defaults to the current os.environ.

    When os.environ has been replaced by a ContextualEnviron, this passes the
    contextual overrides to the child process. An explicit non-None `env`
    argument is used unchanged.
    """

    def __init__(self, *args, **kwargs):
        if kwargs.get('env') is None:
            # Resolved at call time so a hijacked os.environ is picked up.
            kwargs['env'] = os.environ
        super().__init__(*args, **kwargs)
|
227
|
+
|
228
|
+
|
229
|
+
def initialize():
    """Install a fresh Context as the current SkyPilot context.

    After this call, get() in the current thread/coroutine (and in contexts
    copied from it afterwards) returns the new instance.
    """
    fresh = Context()
    _CONTEXT.set(fresh)
|
232
|
+
|
233
|
+
|
234
|
+
class _ContextualStream:
    """Base class for text streams whose target depends on the context.

    Any attribute not found on the instance itself is forwarded (through
    __getattr__) to the active stream. That is the current context's
    redirected log file when one is set, otherwise the stream given to the
    constructor. This is how the TextIO interface is provided.
    """
    _original_stream: TextIO

    def __init__(self, original_stream: TextIO):
        self._original_stream = original_stream

    def __getattr__(self, attr: str):
        target = self._active_stream()
        return getattr(target, attr)

    def _active_stream(self) -> TextIO:
        ctx = get()
        return (self._original_stream if ctx is None else
                ctx.output_stream(self._original_stream))
|
253
|
+
|
254
|
+
|
255
|
+
class Stdout(_ContextualStream):
    """Context-aware stand-in for sys.stdout.

    Falls back to the sys.stdout that was current when it was constructed.
    """

    def __init__(self):
        fallback = sys.stdout
        super().__init__(fallback)
|
259
|
+
|
260
|
+
|
261
|
+
class Stderr(_ContextualStream):
    """Context-aware stand-in for sys.stderr.

    Falls back to the sys.stderr that was current when it was constructed.
    """

    def __init__(self):
        fallback = sys.stderr
        super().__init__(fallback)
|
@@ -0,0 +1,172 @@
|
|
1
|
+
"""Utilities for SkyPilot context."""
|
2
|
+
import asyncio
|
3
|
+
import functools
|
4
|
+
import io
|
5
|
+
import multiprocessing
|
6
|
+
import os
|
7
|
+
import subprocess
|
8
|
+
import sys
|
9
|
+
import typing
|
10
|
+
from typing import Any, Callable, IO, Optional, Tuple, TypeVar
|
11
|
+
|
12
|
+
from sky import sky_logging
|
13
|
+
from sky.utils import context
|
14
|
+
from sky.utils import subprocess_utils
|
15
|
+
|
16
|
+
# Handler for one process output stream: called with (in_stream, out_stream),
# typically forwards data from in_stream to out_stream, and returns a string
# that pipe_and_wait_process reports as the captured output for that stream.
StreamHandler = Callable[[IO[Any], IO[Any]], str]
|
17
|
+
|
18
|
+
|
19
|
+
# TODO(aylei): call hijack_sys_attrs() proactively in module init at
# server-side once we have context widely adopted.
def hijack_sys_attrs():
    """Make process-wide system attributes context-aware.

    Replaces sys.stdout/sys.stderr with contextual streams, reloads the
    SkyPilot logger so it uses them, wraps os.environ in a
    ContextualEnviron, and swaps subprocess.Popen for context.Popen.

    This function should be called at the very beginning of the processes
    that might use sky.utils.context.
    """
    # setattr() bypasses the TextIO type check on sys.stdout/sys.stderr.
    for stream_name, stream_cls in (('stdout', context.Stdout),
                                    ('stderr', context.Stderr)):
        setattr(sys, stream_name, stream_cls())
    # Rebuild the logger so that it writes to the new stdout/stderr.
    sky_logging.reload_logger()
    # Environment reads consult the current context's overrides first.
    contextual_environ = context.ContextualEnviron(os.environ)
    setattr(os, 'environ', contextual_environ)
    # Subprocesses receive the contextual environ unless env= is given.
    setattr(subprocess, 'Popen', context.Popen)
|
39
|
+
|
40
|
+
|
41
|
+
def passthrough_stream_handler(in_stream: IO[Any], out_stream: IO[Any]) -> str:
    """Copy a process output stream to out_stream, line by line.

    in_stream is decoded as UTF-8 (undecodable bytes become U+FFFD) with
    line endings kept as-is. Each line is flushed right away so output shows
    up as soon as it is produced. Nothing is captured, so the return value is
    always ''.
    """
    reader = io.TextIOWrapper(in_stream,
                              encoding='utf-8',
                              newline='',
                              errors='replace',
                              write_through=True)
    # readline() returns '' only at EOF.
    for chunk in iter(reader.readline, ''):
        out_stream.write(chunk)
        out_stream.flush()
    return ''
|
56
|
+
|
57
|
+
|
58
|
+
def pipe_and_wait_process(
        ctx: context.Context,
        proc: subprocess.Popen,
        poll_interval: float = 0.5,
        cancel_callback: Optional[Callable[[], None]] = None,
        stdout_stream_handler: Optional[StreamHandler] = None,
        stderr_stream_handler: Optional[StreamHandler] = None
) -> Tuple[str, str]:
    """Pipe the process output and wait for it, honoring cancellation.

    Args:
        ctx: The context whose cancellation aborts the wait, and whose output
            streams receive the piped process output.
        proc: The process to wait for.
        poll_interval: The interval to poll the process.
        cancel_callback: The callback to call if the context is cancelled.
        stdout_stream_handler: An optional handler to handle the stdout stream,
            if None, the stdout stream will be passed through.
        stderr_stream_handler: An optional handler to handle the stderr stream,
            if None, the stderr stream will be passed through.

    Returns:
        A (stdout, stderr) tuple holding the strings returned by the two
        handlers. An entry is '' when the corresponding pipe of `proc` is
        None. The default passthrough handler always returns ''.

    Raises:
        asyncio.CancelledError: If ctx is cancelled before the process exits
            (raised by wait_process after killing the process).
    """
    # Import the submodule explicitly: `import multiprocessing` alone does not
    # load `multiprocessing.pool`, so using that attribute would only work if
    # some other module happened to import it first.
    # pylint: disable=import-outside-toplevel
    from multiprocessing import pool as mp_pool

    if stdout_stream_handler is None:
        stdout_stream_handler = passthrough_stream_handler
    if stderr_stream_handler is None:
        stderr_stream_handler = passthrough_stream_handler

    # The pool starts both worker threads up front; an unused one just idles.
    with mp_pool.ThreadPool(processes=2) as pool:
        # Context will be lost in the new threads, so capture the current
        # output streams here and pass them to the handlers directly.
        stdout_fut = None
        if proc.stdout is not None:
            stdout_fut = pool.apply_async(
                stdout_stream_handler,
                (proc.stdout, ctx.output_stream(sys.stdout)))
        stderr_fut = None
        if proc.stderr is not None:
            stderr_fut = pool.apply_async(
                stderr_stream_handler,
                (proc.stderr, ctx.output_stream(sys.stderr)))
        try:
            wait_process(ctx,
                         proc,
                         poll_interval=poll_interval,
                         cancel_callback=cancel_callback)
        finally:
            # Wait for the stream handler threads to drain the pipes when the
            # process is done or has been killed on cancellation.
            for fut in (stdout_fut, stderr_fut):
                if fut is not None:
                    fut.wait()
        # get() re-raises any exception raised inside a handler thread.
        stdout = stdout_fut.get() if stdout_fut is not None else ''
        stderr = stderr_fut.get() if stderr_fut is not None else ''
    return stdout, stderr
|
110
|
+
|
111
|
+
|
112
|
+
def wait_process(ctx: context.Context,
                 proc: subprocess.Popen,
                 poll_interval: float = 0.5,
                 cancel_callback: Optional[Callable[[], None]] = None):
    """Block until `proc` exits, aborting early if `ctx` gets cancelled.

    Cancellation is checked before every wait of `poll_interval` seconds.

    Args:
        ctx: The context whose cancellation aborts the wait.
        proc: The process to wait for.
        poll_interval: The interval to poll the process.
        cancel_callback: The callback to call if the context is cancelled.

    Raises:
        asyncio.CancelledError: If `ctx` is cancelled before `proc` exits;
            `proc` is killed before the error is raised.
    """
    while not ctx.is_canceled():
        try:
            proc.wait(poll_interval)
        except subprocess.TimeoutExpired:
            # Still running: re-check cancellation and wait again.
            continue
        # Process exited
        return
    if cancel_callback is not None:
        cancel_callback()
    # Kill the process regardless of what the callback did; the utility
    # tolerates a process that has already terminated.
    subprocess_utils.kill_process_with_grace_period(proc)
    raise asyncio.CancelledError()
|
139
|
+
|
140
|
+
|
141
|
+
# Type of the decorated callable, so the decorator preserves its signature.
F = TypeVar('F', bound=Callable[..., Any])


def cancellation_guard(func: F) -> F:
    """Make a synchronous function refuse to start in a cancelled context.

    Before each call the wrapper looks up the current SkyPilot context. If
    one exists and is already cancelled, it raises instead of running
    `func`. This mirrors asyncio, which surfaces cancellation at `await`
    points. Calls made outside any context always run.

    Args:
        func: The function to be decorated.

    Returns:
        The wrapped function that checks cancellation before execution.

    Raises:
        asyncio.CancelledError: If the context is cancelled before execution.
    """

    @functools.wraps(func)
    def guarded(*args, **kwargs):
        current = context.get()
        already_cancelled = current is not None and current.is_canceled()
        if already_cancelled:
            raise asyncio.CancelledError(
                f'Function {func.__name__} cancelled before execution')
        return func(*args, **kwargs)

    return typing.cast(F, guarded)
|
sky/utils/rich_utils.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1
1
|
"""Rich status spinner utils."""
|
2
2
|
import contextlib
|
3
|
+
import contextvars
|
3
4
|
import enum
|
4
5
|
import logging
|
5
6
|
import threading
|
6
7
|
import typing
|
7
|
-
from typing import
|
8
|
+
from typing import Callable, Iterator, Optional, Tuple, Union
|
8
9
|
|
9
10
|
from sky.adaptors import common as adaptors_common
|
10
11
|
from sky.utils import annotations
|
12
|
+
from sky.utils import context
|
11
13
|
from sky.utils import message_utils
|
12
14
|
from sky.utils import rich_console_utils
|
13
15
|
|
@@ -18,11 +20,31 @@ else:
|
|
18
20
|
requests = adaptors_common.LazyImport('requests')
|
19
21
|
rich_console = adaptors_common.LazyImport('rich.console')
|
20
22
|
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
23
|
+
# A status is either a rich spinner (client side) or an EncodedStatus whose
# control codes are written to the output stream (server side).
GeneralStatus = Union['rich_console.Status', 'EncodedStatus']

# Client-side status: a plain module global (client_status() only creates it
# from the main thread).
_client_status: Optional[GeneralStatus] = None
# Server-side status: context-local, so each initialized SkyPilot context on
# the API server tracks its own status.
_server_status: contextvars.ContextVar[Optional[GeneralStatus]] = (
    contextvars.ContextVar('server_status', default=None))


def _get_client_status() -> Optional[GeneralStatus]:
    """Return the client-side status, or None if none is active."""
    return _client_status


def _get_server_status() -> Optional[GeneralStatus]:
    """Return the status of the current context, or None if none is set."""
    return _server_status.get()


def _set_client_status(status: Optional[GeneralStatus]):
    """Replace the client-side status; None clears it."""
    global _client_status
    _client_status = status


def _set_server_status(status: Optional[GeneralStatus]):
    """Replace the status of the current context only; None clears it."""
    _server_status.set(status)


# Depth of currently entered _RevertibleStatus contexts.
_status_nesting_level = 0

# Shared by status exit and safe_logger() to serialize console output.
_logging_lock = threading.RLock()
|
@@ -128,20 +150,22 @@ class _NoOpConsoleStatus:
|
|
128
150
|
class _RevertibleStatus:
|
129
151
|
"""A wrapper for status that can revert to previous message after exit."""
|
130
152
|
|
131
|
-
def __init__(self, message: str,
|
153
|
+
def __init__(self, message: str, get_status_fn: Callable[[], GeneralStatus],
|
154
|
+
set_status_fn: Callable[[Optional[GeneralStatus]], None]):
|
132
155
|
self.previous_message = None
|
133
|
-
self.
|
134
|
-
|
156
|
+
self.get_status_fn = get_status_fn
|
157
|
+
self.set_status_fn = set_status_fn
|
158
|
+
status = self.get_status_fn()
|
135
159
|
if status is not None:
|
136
160
|
self.previous_message = status.status
|
137
161
|
self.message = message
|
138
162
|
|
139
163
|
def __enter__(self):
|
140
164
|
global _status_nesting_level
|
141
|
-
|
165
|
+
self.get_status_fn().update(self.message)
|
142
166
|
_status_nesting_level += 1
|
143
|
-
|
144
|
-
return
|
167
|
+
self.get_status_fn().__enter__()
|
168
|
+
return self.get_status_fn()
|
145
169
|
|
146
170
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
147
171
|
# We use the same lock with the `safe_logger` to avoid the following 2
|
@@ -160,32 +184,48 @@ class _RevertibleStatus:
|
|
160
184
|
_status_nesting_level -= 1
|
161
185
|
if _status_nesting_level <= 0:
|
162
186
|
_status_nesting_level = 0
|
163
|
-
if
|
164
|
-
|
165
|
-
|
166
|
-
_statuses[self.status_type] = None
|
187
|
+
if self.get_status_fn() is not None:
|
188
|
+
self.get_status_fn().__exit__(exc_type, exc_val, exc_tb)
|
189
|
+
self.set_status_fn(None)
|
167
190
|
else:
|
168
|
-
|
191
|
+
self.get_status_fn().update(self.previous_message)
|
169
192
|
|
170
193
|
def update(self, *args, **kwargs):
|
171
|
-
|
194
|
+
self.get_status_fn().update(*args, **kwargs)
|
172
195
|
|
173
196
|
def stop(self):
|
174
|
-
|
197
|
+
self.get_status_fn().stop()
|
175
198
|
|
176
199
|
def start(self):
|
177
|
-
|
200
|
+
self.get_status_fn().start()
|
201
|
+
|
202
|
+
|
203
|
+
def _is_thread_safe() -> bool:
|
204
|
+
"""Check if the current status context is thread-safe.
|
205
|
+
|
206
|
+
We are thread-safe if we are on the main thread or the server_status is
|
207
|
+
context-local, i.e. an async context has been initialized.
|
208
|
+
"""
|
209
|
+
return (threading.current_thread() is threading.main_thread() or
|
210
|
+
context.get() is not None)
|
178
211
|
|
179
212
|
|
180
213
|
def safe_status(msg: str) -> Union['rich_console.Status', _NoOpConsoleStatus]:
    """A wrapper for multi-threaded server-side console.status.

    This function will encode rich status with control codes and output the
    encoded string to stdout. Client-side decode control codes from server
    output and update the rich status. This function is safe to be called in
    async/multi-threaded context.

    A no-op status is returned when not running on the API server, when the
    caller is not in a thread-safe context, or when logging is silenced.

    See also: :func:`client_status`, :class:`EncodedStatus`.
    """
    from sky import sky_logging  # pylint: disable=import-outside-toplevel
    enabled = (annotations.is_on_api_server and _is_thread_safe() and
               not sky_logging.is_silent())
    if not enabled:
        return _NoOpConsoleStatus()
    # Lazily create the status for this context; nested calls reuse it.
    if _get_server_status() is None:
        _set_server_status(EncodedStatus(msg))
    return _RevertibleStatus(msg, _get_server_status, _set_server_status)
|
190
230
|
|
191
231
|
|
@@ -196,22 +236,26 @@ def stop_safe_status():
|
|
196
236
|
stream logs from user program and do not want it to interfere with the
|
197
237
|
spinner display.
|
198
238
|
"""
|
199
|
-
if (
|
200
|
-
|
201
|
-
|
239
|
+
if _is_thread_safe():
|
240
|
+
return
|
241
|
+
server_status = _get_server_status()
|
242
|
+
if server_status is not None:
|
243
|
+
server_status.stop()
|
202
244
|
|
203
245
|
|
204
246
|
def force_update_status(msg: str):
    """Update the status message even if sky_logging.is_silent() is true.

    Does nothing outside a thread-safe context or when no server status is
    active in the current context.
    """
    if _is_thread_safe():
        active_status = _get_server_status()
        if active_status is not None:
            active_status.update(msg)
|
209
253
|
|
210
254
|
|
211
255
|
@contextlib.contextmanager
|
212
256
|
def safe_logger():
|
213
257
|
with _logging_lock:
|
214
|
-
client_status_obj =
|
258
|
+
client_status_obj = _get_client_status()
|
215
259
|
|
216
260
|
client_status_live = (client_status_obj is not None and
|
217
261
|
client_status_obj._live.is_started) # pylint: disable=protected-access
|
@@ -230,13 +274,13 @@ class RichSafeStreamHandler(logging.StreamHandler):
|
|
230
274
|
|
231
275
|
|
232
276
|
def client_status(msg: str) -> Union['rich_console.Status', _NoOpConsoleStatus]:
    """A wrapper for multi-threaded client-side console.status.

    Only the main thread gets a real spinner, and only when logging is not
    silenced; otherwise a no-op status is returned. Nested calls share one
    client status, and each restores the previous message on exit.
    """
    from sky import sky_logging  # pylint: disable=import-outside-toplevel
    on_main_thread = threading.current_thread() is threading.main_thread()
    if not on_main_thread or sky_logging.is_silent():
        return _NoOpConsoleStatus()
    if _get_client_status() is None:
        _set_client_status(rich_console_utils.get_console().status(msg))
    return _RevertibleStatus(msg, _get_client_status, _set_client_status)
|
241
285
|
|
242
286
|
|
sky/utils/schemas.py
CHANGED
@@ -837,6 +837,12 @@ def get_config_schema():
|
|
837
837
|
'enable_gvnic': {
|
838
838
|
'type': 'boolean'
|
839
839
|
},
|
840
|
+
'enable_gpu_direct': {
|
841
|
+
'type': 'boolean'
|
842
|
+
},
|
843
|
+
'placement_policy': {
|
844
|
+
'type': 'string',
|
845
|
+
},
|
840
846
|
'vpc_name': {
|
841
847
|
'oneOf': [
|
842
848
|
{
|
@@ -966,7 +972,9 @@ def get_config_schema():
|
|
966
972
|
'nebius': {
|
967
973
|
'type': 'object',
|
968
974
|
'required': [],
|
969
|
-
'properties': {
|
975
|
+
'properties': {
|
976
|
+
**_NETWORK_CONFIG_SCHEMA,
|
977
|
+
},
|
970
978
|
'additionalProperties': {
|
971
979
|
'type': 'object',
|
972
980
|
'required': [],
|
sky/utils/subprocess_utils.py
CHANGED
@@ -208,8 +208,11 @@ def kill_children_processes(parent_pids: Optional[Union[
|
|
208
208
|
kill_process_with_grace_period(child, force=force)
|
209
209
|
|
210
210
|
|
211
|
-
|
212
|
-
|
211
|
+
# Process handle types accepted by kill_process_with_grace_period().
GenericProcess = Union[multiprocessing.Process, psutil.Process,
                       subprocess.Popen]
|
213
|
+
|
214
|
+
|
215
|
+
def kill_process_with_grace_period(proc: GenericProcess,
|
213
216
|
force: bool = False,
|
214
217
|
grace_period: int = 10) -> None:
|
215
218
|
"""Kill a process with SIGTERM and wait for it to exit.
|
@@ -223,6 +226,9 @@ def kill_process_with_grace_period(proc: Union[multiprocessing.Process,
|
|
223
226
|
if isinstance(proc, psutil.Process):
|
224
227
|
alive = proc.is_running
|
225
228
|
wait = proc.wait
|
229
|
+
elif isinstance(proc, subprocess.Popen):
|
230
|
+
alive = lambda: proc.poll() is None
|
231
|
+
wait = proc.wait
|
226
232
|
else:
|
227
233
|
alive = proc.is_alive
|
228
234
|
wait = proc.join
|