skypilot-nightly 1.0.0.dev20250510__py3-none-any.whl → 1.0.0.dev20250513__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +3 -0
  3. sky/backends/cloud_vm_ray_backend.py +7 -0
  4. sky/cli.py +109 -109
  5. sky/client/cli.py +109 -109
  6. sky/clouds/gcp.py +35 -8
  7. sky/dashboard/out/404.html +1 -1
  8. sky/dashboard/out/_next/static/{C0fkLhvxyqkymoV7IeInQ → 2dkponv64SfFShA8Rnw0D}/_buildManifest.js +1 -1
  9. sky/dashboard/out/_next/static/chunks/845-0ca6f2c1ba667c3b.js +1 -0
  10. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  11. sky/dashboard/out/clusters/[cluster].html +1 -1
  12. sky/dashboard/out/clusters.html +1 -1
  13. sky/dashboard/out/index.html +1 -1
  14. sky/dashboard/out/jobs/[job].html +1 -1
  15. sky/dashboard/out/jobs.html +1 -1
  16. sky/global_user_state.py +2 -0
  17. sky/provision/docker_utils.py +4 -1
  18. sky/provision/gcp/config.py +197 -15
  19. sky/provision/gcp/constants.py +64 -0
  20. sky/provision/nebius/instance.py +3 -1
  21. sky/provision/nebius/utils.py +4 -2
  22. sky/server/requests/executor.py +114 -22
  23. sky/server/requests/requests.py +15 -0
  24. sky/server/server.py +12 -7
  25. sky/server/uvicorn.py +12 -2
  26. sky/sky_logging.py +40 -2
  27. sky/skylet/constants.py +3 -0
  28. sky/skylet/log_lib.py +51 -11
  29. sky/templates/gcp-ray.yml.j2 +11 -0
  30. sky/templates/nebius-ray.yml.j2 +4 -0
  31. sky/templates/websocket_proxy.py +29 -9
  32. sky/utils/command_runner.py +3 -0
  33. sky/utils/context.py +264 -0
  34. sky/utils/context_utils.py +172 -0
  35. sky/utils/rich_utils.py +81 -37
  36. sky/utils/schemas.py +9 -1
  37. sky/utils/subprocess_utils.py +8 -2
  38. {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/METADATA +1 -1
  39. {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/RECORD +44 -42
  40. sky/dashboard/out/_next/static/chunks/845-0f8017370869e269.js +0 -1
  41. /sky/dashboard/out/_next/static/{C0fkLhvxyqkymoV7IeInQ → 2dkponv64SfFShA8Rnw0D}/_ssgManifest.js +0 -0
  42. {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/WHEEL +0 -0
  43. {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/entry_points.txt +0 -0
  44. {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/licenses/LICENSE +0 -0
  45. {skypilot_nightly-1.0.0.dev20250510.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/top_level.txt +0 -0
sky/server/server.py CHANGED
@@ -47,6 +47,7 @@ from sky.usage import usage_lib
47
47
  from sky.utils import admin_policy_utils
48
48
  from sky.utils import common as common_lib
49
49
  from sky.utils import common_utils
50
+ from sky.utils import context
50
51
  from sky.utils import dag_utils
51
52
  from sky.utils import env_options
52
53
  from sky.utils import status_lib
@@ -673,24 +674,28 @@ async def logs(
673
674
  # TODO(zhwu): This should wait for the request on the cluster, e.g., async
674
675
  # launch, to finish, so that a user does not need to manually pull the
675
676
  # request status.
676
- executor.schedule_request(
677
+ # Only initialize the context in logs handler to limit the scope of this
678
+ # experimental change.
679
+ # TODO(aylei): init in lifespan() to enable SkyPilot context in all APIs.
680
+ context.initialize()
681
+ request_task = executor.prepare_request(
677
682
  request_id=request.state.request_id,
678
683
  request_name='logs',
679
684
  request_body=cluster_job_body,
680
685
  func=core.tail_logs,
681
- # TODO(aylei): We have tail logs scheduled as SHORT request, because it
682
- # should be responsive. However, it can be long running if the user's
683
- # job keeps running, and we should avoid it taking the SHORT worker.
684
686
  schedule_type=requests_lib.ScheduleType.SHORT,
685
- request_cluster_name=cluster_job_body.cluster_name,
686
687
  )
688
+ task = asyncio.create_task(executor.execute_request_coroutine(request_task))
687
689
 
688
- request_task = requests_lib.get_request(request.state.request_id)
690
+ def cancel_task():
691
+ task.cancel()
689
692
 
693
+ # Cancel the task after the request is done or client disconnects
694
+ background_tasks.add_task(cancel_task)
690
695
  # TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
691
696
  # the same approach as /stream.
692
697
  return stream_utils.stream_response(
693
- request_id=request_task.request_id,
698
+ request_id=request.state.request_id,
694
699
  logs_path=request_task.log_path,
695
700
  background_tasks=background_tasks,
696
701
  )
sky/server/uvicorn.py CHANGED
@@ -3,6 +3,7 @@
3
3
  This module is a wrapper around uvicorn to customize the behavior of the
4
4
  server.
5
5
  """
6
+ import functools
6
7
  import os
7
8
  import threading
8
9
  from typing import Optional
@@ -10,6 +11,7 @@ from typing import Optional
10
11
  import uvicorn
11
12
  from uvicorn.supervisors import multiprocess
12
13
 
14
+ from sky.utils import context_utils
13
15
  from sky.utils import subprocess_utils
14
16
 
15
17
 
@@ -21,19 +23,27 @@ def run(config: uvicorn.Config):
21
23
  # guard by an exception.
22
24
  raise ValueError('Reload is not supported yet.')
23
25
  server = uvicorn.Server(config=config)
26
+ run_server_process = functools.partial(_run_server_process, server)
24
27
  try:
25
28
  if config.workers is not None and config.workers > 1:
26
29
  sock = config.bind_socket()
27
- SlowStartMultiprocess(config, target=server.run,
30
+ SlowStartMultiprocess(config,
31
+ target=run_server_process,
28
32
  sockets=[sock]).run()
29
33
  else:
30
- server.run()
34
+ run_server_process()
31
35
  finally:
32
36
  # Copied from uvicorn.run()
33
37
  if config.uds and os.path.exists(config.uds):
34
38
  os.remove(config.uds)
35
39
 
36
40
 
41
+ def _run_server_process(server: uvicorn.Server, *args, **kwargs):
42
+ """Run the server process with context awareness."""
43
+ context_utils.hijack_sys_attrs()
44
+ server.run(*args, **kwargs)
45
+
46
+
37
47
  class SlowStartMultiprocess(multiprocess.Multiprocess):
38
48
  """Uvicorn Multiprocess wrapper with slow start.
39
49
 
sky/sky_logging.py CHANGED
@@ -10,6 +10,7 @@ import threading
10
10
  import colorama
11
11
 
12
12
  from sky.skylet import constants
13
+ from sky.utils import context
13
14
  from sky.utils import env_options
14
15
  from sky.utils import rich_utils
15
16
 
@@ -47,6 +48,43 @@ class NewLineFormatter(logging.Formatter):
47
48
  return msg
48
49
 
49
50
 
51
+ class EnvAwareHandler(rich_utils.RichSafeStreamHandler):
52
+ """A handler that is aware of environment variables.
53
+
54
+ This handler dynamically reflects the log level from environment variables.
55
+ """
56
+
57
+ def __init__(self, stream=None, level=logging.NOTSET, sensitive=False):
58
+ super().__init__(stream)
59
+ self.level = level
60
+ self._sensitive = sensitive
61
+
62
+ @property
63
+ def level(self):
64
+ # Only refresh log level if we are in a context, since the log level
65
+ # has already been reloaded eagerly in multi-processing. Refresh again
66
+ # is a no-op and can be avoided.
67
+ # TODO(aylei): unify the mechanism for coroutine context and
68
+ # multi-processing.
69
+ if context.get() is not None:
70
+ if self._sensitive:
71
+ # For sensitive logger, suppress debug log despite the
72
+ # SKYPILOT_DEBUG env var if SUPPRESS_SENSITIVE_LOG is set
73
+ if env_options.Options.SUPPRESS_SENSITIVE_LOG.get():
74
+ return logging.INFO
75
+ if env_options.Options.SHOW_DEBUG_INFO.get():
76
+ return logging.DEBUG
77
+ else:
78
+ return self._level
79
+ else:
80
+ return self._level
81
+
82
+ @level.setter
83
+ def level(self, level):
84
+ # pylint: disable=protected-access
85
+ self._level = logging._checkLevel(level)
86
+
87
+
50
88
  _root_logger = logging.getLogger('sky')
51
89
  _default_handler = None
52
90
  _logging_config = threading.local()
@@ -67,7 +105,7 @@ def _setup_logger():
67
105
  _root_logger.setLevel(logging.DEBUG)
68
106
  global _default_handler
69
107
  if _default_handler is None:
70
- _default_handler = rich_utils.RichSafeStreamHandler(sys.stdout)
108
+ _default_handler = EnvAwareHandler(sys.stdout)
71
109
  _default_handler.flush = sys.stdout.flush # type: ignore
72
110
  if env_options.Options.SHOW_DEBUG_INFO.get():
73
111
  _default_handler.setLevel(logging.DEBUG)
@@ -87,7 +125,7 @@ def _setup_logger():
87
125
  # for certain loggers.
88
126
  for logger_name in _SENSITIVE_LOGGER:
89
127
  logger = logging.getLogger(logger_name)
90
- handler_to_logger = rich_utils.RichSafeStreamHandler(sys.stdout)
128
+ handler_to_logger = EnvAwareHandler(sys.stdout, sensitive=True)
91
129
  handler_to_logger.flush = sys.stdout.flush # type: ignore
92
130
  logger.addHandler(handler_to_logger)
93
131
  logger.setLevel(logging.INFO)
sky/skylet/constants.py CHANGED
@@ -370,6 +370,9 @@ OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
370
370
  ('kubernetes', 'pod_config'),
371
371
  ('kubernetes', 'provision_timeout'),
372
372
  ('gcp', 'managed_instance_group'),
373
+ ('gcp', 'enable_gvnic'),
374
+ ('gcp', 'enable_gpu_direct'),
375
+ ('gcp', 'placement_policy'),
373
376
  ]
374
377
  # When overriding the SkyPilot configs on the API server with the client one,
375
378
  # we skip the following keys because they are meant to be client-side configs.
sky/skylet/log_lib.py CHANGED
@@ -4,6 +4,7 @@ This is a remote utility module that provides logging functionality.
4
4
  """
5
5
  import collections
6
6
  import copy
7
+ import functools
7
8
  import io
8
9
  import multiprocessing.pool
9
10
  import os
@@ -21,6 +22,8 @@ import colorama
21
22
  from sky import sky_logging
22
23
  from sky.skylet import constants
23
24
  from sky.skylet import job_lib
25
+ from sky.utils import context
26
+ from sky.utils import context_utils
24
27
  from sky.utils import log_utils
25
28
  from sky.utils import subprocess_utils
26
29
  from sky.utils import ux_utils
@@ -77,6 +80,9 @@ def _handle_io_stream(io_stream, out_stream, args: _ProcessingArgs):
77
80
  with open(args.log_path, 'a', encoding='utf-8') as fout:
78
81
  with line_processor:
79
82
  while True:
83
+ ctx = context.get()
84
+ if ctx is not None and ctx.is_canceled():
85
+ return
80
86
  line = out_io.readline()
81
87
  if not line:
82
88
  break
@@ -111,30 +117,29 @@ def _handle_io_stream(io_stream, out_stream, args: _ProcessingArgs):
111
117
  return ''.join(out)
112
118
 
113
119
 
114
- def process_subprocess_stream(proc, args: _ProcessingArgs) -> Tuple[str, str]:
115
- """Redirect the process's filtered stdout/stderr to both stream and file"""
120
+ def process_subprocess_stream(proc, stdout_stream_handler,
121
+ stderr_stream_handler) -> Tuple[str, str]:
122
+ """Process the stream of a process in threads, blocking."""
116
123
  if proc.stderr is not None:
117
124
  # Asyncio does not work as the output processing can be executed in a
118
125
  # different thread.
119
126
  # selectors is possible to handle the multiplexing of stdout/stderr,
120
127
  # but it introduces buffering making the output not streaming.
121
128
  with multiprocessing.pool.ThreadPool(processes=1) as pool:
122
- err_args = copy.copy(args)
123
- err_args.line_processor = None
124
- stderr_fut = pool.apply_async(_handle_io_stream,
125
- args=(proc.stderr, sys.stderr,
126
- err_args))
129
+ stderr_fut = pool.apply_async(stderr_stream_handler,
130
+ args=(proc.stderr, sys.stderr))
127
131
  # Do not launch a thread for stdout as the rich.status does not
128
132
  # work in a thread, which is used in
129
133
  # log_utils.RayUpLineProcessor.
130
- stdout = _handle_io_stream(proc.stdout, sys.stdout, args)
134
+ stdout = stdout_stream_handler(proc.stdout, sys.stdout)
131
135
  stderr = stderr_fut.get()
132
136
  else:
133
- stdout = _handle_io_stream(proc.stdout, sys.stdout, args)
137
+ stdout = stdout_stream_handler(proc.stdout, sys.stdout)
134
138
  stderr = ''
135
139
  return stdout, stderr
136
140
 
137
141
 
142
+ @context_utils.cancellation_guard
138
143
  def run_with_log(
139
144
  cmd: Union[List[str], str],
140
145
  log_path: str,
@@ -176,7 +181,12 @@ def run_with_log(
176
181
  # Redirect stderr to stdout when using ray, to preserve the order of
177
182
  # stdout and stderr.
178
183
  stdout_arg = stderr_arg = None
179
- if process_stream:
184
+ ctx = context.get()
185
+ if process_stream or ctx is not None:
186
+ # Capture stdout/stderr of the subprocess if:
187
+ # 1. Post-processing is needed (process_stream=True)
188
+ # 2. Potential contextual handling is needed (ctx is not None)
189
+ # TODO(aylei): can we always capture the stdout/stderr?
180
190
  stdout_arg = subprocess.PIPE
181
191
  stderr_arg = subprocess.PIPE if not with_ray else subprocess.STDOUT
182
192
  # Use stdin=subprocess.DEVNULL by default, as allowing inputs will mess up
@@ -197,6 +207,8 @@ def run_with_log(
197
207
  subprocess_utils.kill_process_daemon(proc.pid)
198
208
  stdout = ''
199
209
  stderr = ''
210
+ stdout_stream_handler = None
211
+ stderr_stream_handler = None
200
212
 
201
213
  if process_stream:
202
214
  if skip_lines is None:
@@ -223,7 +235,35 @@ def run_with_log(
223
235
  replace_crlf=with_ray,
224
236
  streaming_prefix=streaming_prefix,
225
237
  )
226
- stdout, stderr = process_subprocess_stream(proc, args)
238
+ stdout_stream_handler = functools.partial(
239
+ _handle_io_stream,
240
+ args=args,
241
+ )
242
+ if proc.stderr is not None:
243
+ err_args = copy.copy(args)
244
+ err_args.line_processor = None
245
+ stderr_stream_handler = functools.partial(
246
+ _handle_io_stream,
247
+ args=err_args,
248
+ )
249
+ if ctx is not None:
250
+ # When runs in a coroutine, always process the subprocess
251
+ # stream to:
252
+ # 1. handle context cancellation
253
+ # 2. redirect subprocess stdout/stderr to the contextual
254
+ # stdout/stderr of current coroutine.
255
+ stdout, stderr = context_utils.pipe_and_wait_process(
256
+ ctx,
257
+ proc,
258
+ cancel_callback=subprocess_utils.kill_children_processes,
259
+ stdout_stream_handler=stdout_stream_handler,
260
+ stderr_stream_handler=stderr_stream_handler)
261
+ elif process_stream:
262
+ # When runs in a process, only process subprocess stream if
263
+ # necessary to avoid unnecessary stream handling overhead.
264
+ stdout, stderr = process_subprocess_stream(
265
+ proc, stdout_stream_handler, stderr_stream_handler)
266
+ # Ensure returncode is set.
227
267
  proc.wait()
228
268
  if require_outputs:
229
269
  return proc.returncode, stdout, stderr
@@ -69,6 +69,12 @@ provider:
69
69
  {%- if enable_gvnic %}
70
70
  enable_gvnic: {{ enable_gvnic }}
71
71
  {%- endif %}
72
+ {%- if enable_gpu_direct %}
73
+ enable_gpu_direct: {{ enable_gpu_direct }}
74
+ {%- endif %}
75
+ {%- if placement_policy %}
76
+ placement_policy: {{ placement_policy }}
77
+ {%- endif %}
72
78
 
73
79
  auth:
74
80
  ssh_user: gcpuser
@@ -148,6 +154,11 @@ available_node_types:
148
154
  - key: install-nvidia-driver
149
155
  value: "True"
150
156
  {%- endif %}
157
+ {%- if user_data is not none %}
158
+ - key: user-data
159
+ value: |-
160
+ {{ user_data | indent(10) }}
161
+ {%- endif %}
151
162
  {%- if use_spot or gpu is not none %}
152
163
  scheduling:
153
164
  {%- if use_spot %}
@@ -9,6 +9,7 @@ provider:
9
9
  type: external
10
10
  module: sky.provision.nebius
11
11
  region: "{{region}}"
12
+ use_internal_ips: {{use_internal_ips}}
12
13
 
13
14
  {%- if docker_image is not none %}
14
15
  docker:
@@ -34,6 +35,9 @@ docker:
34
35
  auth:
35
36
  ssh_user: ubuntu
36
37
  ssh_private_key: {{ssh_private_key}}
38
+ {% if ssh_proxy_command is not none %}
39
+ ssh_proxy_command: {{ssh_proxy_command}}
40
+ {% endif %}
37
41
 
38
42
  available_node_types:
39
43
  ray_head_default:
@@ -16,8 +16,11 @@ from typing import Dict
16
16
  from urllib.request import Request
17
17
 
18
18
  import websockets
19
+ from websockets.asyncio.client import ClientConnection
19
20
  from websockets.asyncio.client import connect
20
21
 
22
+ BUFFER_SIZE = 2**16 # 64KB
23
+
21
24
 
22
25
  def _get_cookie_header(url: str) -> Dict[str, str]:
23
26
  """Extract Cookie header value from a cookie jar for a specific URL"""
@@ -51,19 +54,36 @@ async def main(url: str) -> None:
51
54
  old_settings = None
52
55
 
53
56
  try:
54
- await asyncio.gather(stdin_to_websocket(websocket),
55
- websocket_to_stdout(websocket))
57
+ loop = asyncio.get_running_loop()
58
+ # Use asyncio.Stream primitives to wrap stdin and stdout, this is to
59
+ # avoid creating a new thread for each read/write operation
60
+ # excessively.
61
+ stdin_reader = asyncio.StreamReader()
62
+ protocol = asyncio.StreamReaderProtocol(stdin_reader)
63
+ await loop.connect_read_pipe(lambda: protocol, sys.stdin)
64
+ transport, protocol = await loop.connect_write_pipe(
65
+ asyncio.streams.FlowControlMixin, sys.stdout) # type: ignore
66
+ stdout_writer = asyncio.StreamWriter(transport, protocol, None,
67
+ loop)
68
+
69
+ await asyncio.gather(stdin_to_websocket(stdin_reader, websocket),
70
+ websocket_to_stdout(websocket, stdout_writer))
56
71
  finally:
57
72
  if old_settings:
58
73
  termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN,
59
74
  old_settings)
60
75
 
61
76
 
62
- async def stdin_to_websocket(websocket):
77
+ async def stdin_to_websocket(reader: asyncio.StreamReader,
78
+ websocket: ClientConnection):
63
79
  try:
64
80
  while True:
65
- data = await asyncio.get_event_loop().run_in_executor(
66
- None, sys.stdin.buffer.read, 1)
81
+ # Read at most BUFFER_SIZE bytes, this does not affect
82
+ # responsiveness since it will return as soon as
83
+ # there is at least one byte.
84
+ # The BUFFER_SIZE is chosen to be large enough to improve
85
+ # throughput.
86
+ data = await reader.read(BUFFER_SIZE)
67
87
  if not data:
68
88
  break
69
89
  await websocket.send(data)
@@ -73,13 +93,13 @@ async def stdin_to_websocket(websocket):
73
93
  await websocket.close()
74
94
 
75
95
 
76
- async def websocket_to_stdout(websocket):
96
+ async def websocket_to_stdout(websocket: ClientConnection,
97
+ writer: asyncio.StreamWriter):
77
98
  try:
78
99
  while True:
79
100
  message = await websocket.recv()
80
- sys.stdout.buffer.write(message)
81
- await asyncio.get_event_loop().run_in_executor(
82
- None, sys.stdout.buffer.flush)
101
+ writer.write(message)
102
+ await writer.drain()
83
103
  except websockets.exceptions.ConnectionClosed:
84
104
  print('WebSocket connection closed', file=sys.stderr)
85
105
  except Exception as e: # pylint: disable=broad-except
@@ -11,6 +11,7 @@ from sky import sky_logging
11
11
  from sky.skylet import constants
12
12
  from sky.skylet import log_lib
13
13
  from sky.utils import common_utils
14
+ from sky.utils import context_utils
14
15
  from sky.utils import control_master_utils
15
16
  from sky.utils import subprocess_utils
16
17
  from sky.utils import timeline
@@ -574,6 +575,7 @@ class SSHCommandRunner(CommandRunner):
574
575
  shell=True)
575
576
 
576
577
  @timeline.event
578
+ @context_utils.cancellation_guard
577
579
  def run(
578
580
  self,
579
581
  cmd: Union[str, List[str]],
@@ -779,6 +781,7 @@ class KubernetesCommandRunner(CommandRunner):
779
781
  return kubectl_cmd
780
782
 
781
783
  @timeline.event
784
+ @context_utils.cancellation_guard
782
785
  def run(
783
786
  self,
784
787
  cmd: Union[str, List[str]],