skypilot-nightly 1.0.0.dev20250509__py3-none-any.whl → 1.0.0.dev20250513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/backends/backend_utils.py +3 -0
- sky/backends/cloud_vm_ray_backend.py +7 -0
- sky/cli.py +109 -109
- sky/client/cli.py +109 -109
- sky/clouds/gcp.py +35 -8
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/{LksQgChY5izXjokL3LcEu → 2dkponv64SfFShA8Rnw0D}/_buildManifest.js +1 -1
- sky/dashboard/out/_next/static/chunks/845-0ca6f2c1ba667c3b.js +1 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/global_user_state.py +2 -0
- sky/provision/docker_utils.py +4 -1
- sky/provision/gcp/config.py +197 -15
- sky/provision/gcp/constants.py +64 -0
- sky/provision/gcp/instance.py +5 -3
- sky/provision/gcp/instance_utils.py +8 -4
- sky/provision/nebius/instance.py +3 -1
- sky/provision/nebius/utils.py +4 -2
- sky/server/requests/executor.py +114 -22
- sky/server/requests/requests.py +15 -0
- sky/server/server.py +12 -7
- sky/server/uvicorn.py +12 -2
- sky/sky_logging.py +40 -2
- sky/skylet/constants.py +3 -0
- sky/skylet/log_lib.py +51 -11
- sky/templates/gcp-ray.yml.j2 +11 -0
- sky/templates/nebius-ray.yml.j2 +4 -0
- sky/templates/websocket_proxy.py +29 -9
- sky/utils/command_runner.py +3 -0
- sky/utils/context.py +264 -0
- sky/utils/context_utils.py +172 -0
- sky/utils/rich_utils.py +81 -37
- sky/utils/schemas.py +9 -1
- sky/utils/subprocess_utils.py +8 -2
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/METADATA +1 -5
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/RECORD +46 -44
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/WHEEL +1 -1
- sky/dashboard/out/_next/static/chunks/845-0f8017370869e269.js +0 -1
- /sky/dashboard/out/_next/static/{LksQgChY5izXjokL3LcEu → 2dkponv64SfFShA8Rnw0D}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20250513.dist-info}/top_level.txt +0 -0
sky/server/requests/executor.py
CHANGED
@@ -18,7 +18,10 @@ The number of the workers is determined by the system resources.
|
|
18
18
|
|
19
19
|
See the [README.md](../README.md) for detailed architecture of the executor.
|
20
20
|
"""
|
21
|
+
import asyncio
|
21
22
|
import contextlib
|
23
|
+
import contextvars
|
24
|
+
import functools
|
22
25
|
import multiprocessing
|
23
26
|
import os
|
24
27
|
import queue as queue_lib
|
@@ -47,6 +50,7 @@ from sky.server.requests.queues import mp_queue
|
|
47
50
|
from sky.skylet import constants
|
48
51
|
from sky.utils import annotations
|
49
52
|
from sky.utils import common_utils
|
53
|
+
from sky.utils import context
|
50
54
|
from sky.utils import subprocess_utils
|
51
55
|
from sky.utils import timeline
|
52
56
|
|
@@ -60,7 +64,6 @@ else:
|
|
60
64
|
from typing_extensions import ParamSpec
|
61
65
|
|
62
66
|
P = ParamSpec('P')
|
63
|
-
|
64
67
|
logger = sky_logging.init_logger(__name__)
|
65
68
|
|
66
69
|
# On macOS, the default start method for multiprocessing is 'fork', which
|
@@ -341,6 +344,114 @@ def _request_execution_wrapper(request_id: str,
|
|
341
344
|
logger.info(f'Request {request_id} finished')
|
342
345
|
|
343
346
|
|
347
|
+
async def execute_request_coroutine(request: api_requests.Request):
|
348
|
+
"""Execute a request in current event loop.
|
349
|
+
|
350
|
+
Similar to _request_execution_wrapper, but executed as coroutine in current
|
351
|
+
event loop. This is designed for executing tasks that are not CPU
|
352
|
+
intensive, e.g. sky logs.
|
353
|
+
"""
|
354
|
+
ctx = context.get()
|
355
|
+
if ctx is None:
|
356
|
+
raise ValueError('Context is not initialized')
|
357
|
+
logger.info(f'Executing request {request.request_id} in coroutine')
|
358
|
+
func = request.entrypoint
|
359
|
+
request_body = request.request_body
|
360
|
+
with api_requests.update_request(request.request_id) as request_task:
|
361
|
+
request_task.status = api_requests.RequestStatus.RUNNING
|
362
|
+
# Redirect stdout and stderr to the request log path.
|
363
|
+
original_output = ctx.redirect_log(request.log_path)
|
364
|
+
# Override environment variables that backs env_options.Options
|
365
|
+
# TODO(aylei): compared to process executor, running task in coroutine has
|
366
|
+
# two issues to fix:
|
367
|
+
# 1. skypilot config is not contextual
|
368
|
+
# 2. envs that read directly from os.environ are not contextual
|
369
|
+
ctx.override_envs(request_body.env_vars)
|
370
|
+
loop = asyncio.get_running_loop()
|
371
|
+
pyctx = contextvars.copy_context()
|
372
|
+
func_call = functools.partial(pyctx.run, func, **request_body.to_kwargs())
|
373
|
+
fut: asyncio.Future = loop.run_in_executor(None, func_call)
|
374
|
+
|
375
|
+
async def poll_task(request_id: str) -> bool:
|
376
|
+
request = api_requests.get_request(request_id)
|
377
|
+
if request is None:
|
378
|
+
raise RuntimeError('Request not found')
|
379
|
+
|
380
|
+
if request.status == api_requests.RequestStatus.CANCELLED:
|
381
|
+
ctx.cancel()
|
382
|
+
return True
|
383
|
+
|
384
|
+
if fut.done():
|
385
|
+
try:
|
386
|
+
result = await fut
|
387
|
+
api_requests.set_request_succeeded(request_id, result)
|
388
|
+
except asyncio.CancelledError:
|
389
|
+
# The task is cancelled by ctx.cancel(), where the status
|
390
|
+
# should already be set to CANCELLED.
|
391
|
+
pass
|
392
|
+
except Exception as e: # pylint: disable=broad-except
|
393
|
+
ctx.redirect_log(original_output)
|
394
|
+
api_requests.set_request_failed(request_id, e)
|
395
|
+
logger.error(f'Request {request_id} failed due to '
|
396
|
+
f'{common_utils.format_exception(e)}')
|
397
|
+
return True
|
398
|
+
return False
|
399
|
+
|
400
|
+
try:
|
401
|
+
while True:
|
402
|
+
res = await poll_task(request.request_id)
|
403
|
+
if res:
|
404
|
+
break
|
405
|
+
await asyncio.sleep(0.5)
|
406
|
+
except asyncio.CancelledError:
|
407
|
+
# Current coroutine is cancelled due to client disconnect, set the
|
408
|
+
# request status for consistency.
|
409
|
+
api_requests.set_request_cancelled(request.request_id)
|
410
|
+
pass
|
411
|
+
# pylint: disable=broad-except
|
412
|
+
except (Exception, KeyboardInterrupt, SystemExit) as e:
|
413
|
+
# Handle any other error
|
414
|
+
ctx.redirect_log(original_output)
|
415
|
+
ctx.cancel()
|
416
|
+
api_requests.set_request_failed(request.request_id, e)
|
417
|
+
logger.error(f'Request {request.request_id} interrupted due to '
|
418
|
+
f'unhandled exception: {common_utils.format_exception(e)}')
|
419
|
+
raise
|
420
|
+
|
421
|
+
|
422
|
+
def prepare_request(
|
423
|
+
request_id: str,
|
424
|
+
request_name: str,
|
425
|
+
request_body: payloads.RequestBody,
|
426
|
+
func: Callable[P, Any],
|
427
|
+
request_cluster_name: Optional[str] = None,
|
428
|
+
schedule_type: api_requests.ScheduleType = (api_requests.ScheduleType.LONG),
|
429
|
+
is_skypilot_system: bool = False,
|
430
|
+
) -> api_requests.Request:
|
431
|
+
"""Prepare a request for execution."""
|
432
|
+
user_id = request_body.env_vars[constants.USER_ID_ENV_VAR]
|
433
|
+
if is_skypilot_system:
|
434
|
+
user_id = server_constants.SKYPILOT_SYSTEM_USER_ID
|
435
|
+
global_user_state.add_or_update_user(
|
436
|
+
models.User(id=user_id, name=user_id))
|
437
|
+
request = api_requests.Request(request_id=request_id,
|
438
|
+
name=server_constants.REQUEST_NAME_PREFIX +
|
439
|
+
request_name,
|
440
|
+
entrypoint=func,
|
441
|
+
request_body=request_body,
|
442
|
+
status=api_requests.RequestStatus.PENDING,
|
443
|
+
created_at=time.time(),
|
444
|
+
schedule_type=schedule_type,
|
445
|
+
user_id=user_id,
|
446
|
+
cluster_name=request_cluster_name)
|
447
|
+
|
448
|
+
if not api_requests.create_if_not_exists(request):
|
449
|
+
raise RuntimeError(f'Request {request_id} already exists.')
|
450
|
+
|
451
|
+
request.log_path.touch()
|
452
|
+
return request
|
453
|
+
|
454
|
+
|
344
455
|
def schedule_request(
|
345
456
|
request_id: str,
|
346
457
|
request_name: str,
|
@@ -372,27 +483,8 @@ def schedule_request(
|
|
372
483
|
The precondition is waited asynchronously and does not block the
|
373
484
|
caller.
|
374
485
|
"""
|
375
|
-
|
376
|
-
|
377
|
-
user_id = server_constants.SKYPILOT_SYSTEM_USER_ID
|
378
|
-
global_user_state.add_or_update_user(
|
379
|
-
models.User(id=user_id, name=user_id))
|
380
|
-
request = api_requests.Request(request_id=request_id,
|
381
|
-
name=server_constants.REQUEST_NAME_PREFIX +
|
382
|
-
request_name,
|
383
|
-
entrypoint=func,
|
384
|
-
request_body=request_body,
|
385
|
-
status=api_requests.RequestStatus.PENDING,
|
386
|
-
created_at=time.time(),
|
387
|
-
schedule_type=schedule_type,
|
388
|
-
user_id=user_id,
|
389
|
-
cluster_name=request_cluster_name)
|
390
|
-
|
391
|
-
if not api_requests.create_if_not_exists(request):
|
392
|
-
logger.debug(f'Request {request_id} already exists.')
|
393
|
-
return
|
394
|
-
|
395
|
-
request.log_path.touch()
|
486
|
+
prepare_request(request_id, request_name, request_body, func,
|
487
|
+
request_cluster_name, schedule_type, is_skypilot_system)
|
396
488
|
|
397
489
|
def enqueue():
|
398
490
|
input_tuple = (request_id, ignore_return_value)
|
sky/server/requests/requests.py
CHANGED
@@ -606,3 +606,18 @@ def set_request_failed(request_id: str, e: BaseException) -> None:
|
|
606
606
|
assert request_task is not None, request_id
|
607
607
|
request_task.status = RequestStatus.FAILED
|
608
608
|
request_task.set_error(e)
|
609
|
+
|
610
|
+
|
611
|
+
def set_request_succeeded(request_id: str, result: Any) -> None:
|
612
|
+
"""Set a request to succeeded and populate the result."""
|
613
|
+
with update_request(request_id) as request_task:
|
614
|
+
assert request_task is not None, request_id
|
615
|
+
request_task.status = RequestStatus.SUCCEEDED
|
616
|
+
request_task.set_return_value(result)
|
617
|
+
|
618
|
+
|
619
|
+
def set_request_cancelled(request_id: str) -> None:
|
620
|
+
"""Set a request to cancelled."""
|
621
|
+
with update_request(request_id) as request_task:
|
622
|
+
assert request_task is not None, request_id
|
623
|
+
request_task.status = RequestStatus.CANCELLED
|
sky/server/server.py
CHANGED
@@ -47,6 +47,7 @@ from sky.usage import usage_lib
|
|
47
47
|
from sky.utils import admin_policy_utils
|
48
48
|
from sky.utils import common as common_lib
|
49
49
|
from sky.utils import common_utils
|
50
|
+
from sky.utils import context
|
50
51
|
from sky.utils import dag_utils
|
51
52
|
from sky.utils import env_options
|
52
53
|
from sky.utils import status_lib
|
@@ -673,24 +674,28 @@ async def logs(
|
|
673
674
|
# TODO(zhwu): This should wait for the request on the cluster, e.g., async
|
674
675
|
# launch, to finish, so that a user does not need to manually pull the
|
675
676
|
# request status.
|
676
|
-
|
677
|
+
# Only initialize the context in logs handler to limit the scope of this
|
678
|
+
# experimental change.
|
679
|
+
# TODO(aylei): init in lifespan() to enable SkyPilot context in all APIs.
|
680
|
+
context.initialize()
|
681
|
+
request_task = executor.prepare_request(
|
677
682
|
request_id=request.state.request_id,
|
678
683
|
request_name='logs',
|
679
684
|
request_body=cluster_job_body,
|
680
685
|
func=core.tail_logs,
|
681
|
-
# TODO(aylei): We have tail logs scheduled as SHORT request, because it
|
682
|
-
# should be responsive. However, it can be long running if the user's
|
683
|
-
# job keeps running, and we should avoid it taking the SHORT worker.
|
684
686
|
schedule_type=requests_lib.ScheduleType.SHORT,
|
685
|
-
request_cluster_name=cluster_job_body.cluster_name,
|
686
687
|
)
|
688
|
+
task = asyncio.create_task(executor.execute_request_coroutine(request_task))
|
687
689
|
|
688
|
-
|
690
|
+
def cancel_task():
|
691
|
+
task.cancel()
|
689
692
|
|
693
|
+
# Cancel the task after the request is done or client disconnects
|
694
|
+
background_tasks.add_task(cancel_task)
|
690
695
|
# TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
|
691
696
|
# the same approach as /stream.
|
692
697
|
return stream_utils.stream_response(
|
693
|
-
request_id=
|
698
|
+
request_id=request.state.request_id,
|
694
699
|
logs_path=request_task.log_path,
|
695
700
|
background_tasks=background_tasks,
|
696
701
|
)
|
sky/server/uvicorn.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
This module is a wrapper around uvicorn to customize the behavior of the
|
4
4
|
server.
|
5
5
|
"""
|
6
|
+
import functools
|
6
7
|
import os
|
7
8
|
import threading
|
8
9
|
from typing import Optional
|
@@ -10,6 +11,7 @@ from typing import Optional
|
|
10
11
|
import uvicorn
|
11
12
|
from uvicorn.supervisors import multiprocess
|
12
13
|
|
14
|
+
from sky.utils import context_utils
|
13
15
|
from sky.utils import subprocess_utils
|
14
16
|
|
15
17
|
|
@@ -21,19 +23,27 @@ def run(config: uvicorn.Config):
|
|
21
23
|
# guard by an exception.
|
22
24
|
raise ValueError('Reload is not supported yet.')
|
23
25
|
server = uvicorn.Server(config=config)
|
26
|
+
run_server_process = functools.partial(_run_server_process, server)
|
24
27
|
try:
|
25
28
|
if config.workers is not None and config.workers > 1:
|
26
29
|
sock = config.bind_socket()
|
27
|
-
SlowStartMultiprocess(config,
|
30
|
+
SlowStartMultiprocess(config,
|
31
|
+
target=run_server_process,
|
28
32
|
sockets=[sock]).run()
|
29
33
|
else:
|
30
|
-
|
34
|
+
run_server_process()
|
31
35
|
finally:
|
32
36
|
# Copied from unvicorn.run()
|
33
37
|
if config.uds and os.path.exists(config.uds):
|
34
38
|
os.remove(config.uds)
|
35
39
|
|
36
40
|
|
41
|
+
def _run_server_process(server: uvicorn.Server, *args, **kwargs):
|
42
|
+
"""Run the server process with contextually aware."""
|
43
|
+
context_utils.hijack_sys_attrs()
|
44
|
+
server.run(*args, **kwargs)
|
45
|
+
|
46
|
+
|
37
47
|
class SlowStartMultiprocess(multiprocess.Multiprocess):
|
38
48
|
"""Uvicorn Multiprocess wrapper with slow start.
|
39
49
|
|
sky/sky_logging.py
CHANGED
@@ -10,6 +10,7 @@ import threading
|
|
10
10
|
import colorama
|
11
11
|
|
12
12
|
from sky.skylet import constants
|
13
|
+
from sky.utils import context
|
13
14
|
from sky.utils import env_options
|
14
15
|
from sky.utils import rich_utils
|
15
16
|
|
@@ -47,6 +48,43 @@ class NewLineFormatter(logging.Formatter):
|
|
47
48
|
return msg
|
48
49
|
|
49
50
|
|
51
|
+
class EnvAwareHandler(rich_utils.RichSafeStreamHandler):
|
52
|
+
"""A handler that awares environment variables.
|
53
|
+
|
54
|
+
This handler dynamically reflects the log level from environment variables.
|
55
|
+
"""
|
56
|
+
|
57
|
+
def __init__(self, stream=None, level=logging.NOTSET, sensitive=False):
|
58
|
+
super().__init__(stream)
|
59
|
+
self.level = level
|
60
|
+
self._sensitive = sensitive
|
61
|
+
|
62
|
+
@property
|
63
|
+
def level(self):
|
64
|
+
# Only refresh log level if we are in a context, since the log level
|
65
|
+
# has already been reloaded eagerly in multi-processing. Refresh again
|
66
|
+
# is a no-op and can be avoided.
|
67
|
+
# TODO(aylei): unify the mechanism for coroutine context and
|
68
|
+
# multi-processing.
|
69
|
+
if context.get() is not None:
|
70
|
+
if self._sensitive:
|
71
|
+
# For sensitive logger, suppress debug log despite the
|
72
|
+
# SKYPILOT_DEBUG env var if SUPPRESS_SENSITIVE_LOG is set
|
73
|
+
if env_options.Options.SUPPRESS_SENSITIVE_LOG.get():
|
74
|
+
return logging.INFO
|
75
|
+
if env_options.Options.SHOW_DEBUG_INFO.get():
|
76
|
+
return logging.DEBUG
|
77
|
+
else:
|
78
|
+
return self._level
|
79
|
+
else:
|
80
|
+
return self._level
|
81
|
+
|
82
|
+
@level.setter
|
83
|
+
def level(self, level):
|
84
|
+
# pylint: disable=protected-access
|
85
|
+
self._level = logging._checkLevel(level)
|
86
|
+
|
87
|
+
|
50
88
|
_root_logger = logging.getLogger('sky')
|
51
89
|
_default_handler = None
|
52
90
|
_logging_config = threading.local()
|
@@ -67,7 +105,7 @@ def _setup_logger():
|
|
67
105
|
_root_logger.setLevel(logging.DEBUG)
|
68
106
|
global _default_handler
|
69
107
|
if _default_handler is None:
|
70
|
-
_default_handler =
|
108
|
+
_default_handler = EnvAwareHandler(sys.stdout)
|
71
109
|
_default_handler.flush = sys.stdout.flush # type: ignore
|
72
110
|
if env_options.Options.SHOW_DEBUG_INFO.get():
|
73
111
|
_default_handler.setLevel(logging.DEBUG)
|
@@ -87,7 +125,7 @@ def _setup_logger():
|
|
87
125
|
# for certain loggers.
|
88
126
|
for logger_name in _SENSITIVE_LOGGER:
|
89
127
|
logger = logging.getLogger(logger_name)
|
90
|
-
handler_to_logger =
|
128
|
+
handler_to_logger = EnvAwareHandler(sys.stdout, sensitive=True)
|
91
129
|
handler_to_logger.flush = sys.stdout.flush # type: ignore
|
92
130
|
logger.addHandler(handler_to_logger)
|
93
131
|
logger.setLevel(logging.INFO)
|
sky/skylet/constants.py
CHANGED
@@ -370,6 +370,9 @@ OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
|
|
370
370
|
('kubernetes', 'pod_config'),
|
371
371
|
('kubernetes', 'provision_timeout'),
|
372
372
|
('gcp', 'managed_instance_group'),
|
373
|
+
('gcp', 'enable_gvnic'),
|
374
|
+
('gcp', 'enable_gpu_direct'),
|
375
|
+
('gcp', 'placement_policy'),
|
373
376
|
]
|
374
377
|
# When overriding the SkyPilot configs on the API server with the client one,
|
375
378
|
# we skip the following keys because they are meant to be client-side configs.
|
sky/skylet/log_lib.py
CHANGED
@@ -4,6 +4,7 @@ This is a remote utility module that provides logging functionality.
|
|
4
4
|
"""
|
5
5
|
import collections
|
6
6
|
import copy
|
7
|
+
import functools
|
7
8
|
import io
|
8
9
|
import multiprocessing.pool
|
9
10
|
import os
|
@@ -21,6 +22,8 @@ import colorama
|
|
21
22
|
from sky import sky_logging
|
22
23
|
from sky.skylet import constants
|
23
24
|
from sky.skylet import job_lib
|
25
|
+
from sky.utils import context
|
26
|
+
from sky.utils import context_utils
|
24
27
|
from sky.utils import log_utils
|
25
28
|
from sky.utils import subprocess_utils
|
26
29
|
from sky.utils import ux_utils
|
@@ -77,6 +80,9 @@ def _handle_io_stream(io_stream, out_stream, args: _ProcessingArgs):
|
|
77
80
|
with open(args.log_path, 'a', encoding='utf-8') as fout:
|
78
81
|
with line_processor:
|
79
82
|
while True:
|
83
|
+
ctx = context.get()
|
84
|
+
if ctx is not None and ctx.is_canceled():
|
85
|
+
return
|
80
86
|
line = out_io.readline()
|
81
87
|
if not line:
|
82
88
|
break
|
@@ -111,30 +117,29 @@ def _handle_io_stream(io_stream, out_stream, args: _ProcessingArgs):
|
|
111
117
|
return ''.join(out)
|
112
118
|
|
113
119
|
|
114
|
-
def process_subprocess_stream(proc,
|
115
|
-
|
120
|
+
def process_subprocess_stream(proc, stdout_stream_handler,
|
121
|
+
stderr_stream_handler) -> Tuple[str, str]:
|
122
|
+
"""Process the stream of a process in threads, blocking."""
|
116
123
|
if proc.stderr is not None:
|
117
124
|
# Asyncio does not work as the output processing can be executed in a
|
118
125
|
# different thread.
|
119
126
|
# selectors is possible to handle the multiplexing of stdout/stderr,
|
120
127
|
# but it introduces buffering making the output not streaming.
|
121
128
|
with multiprocessing.pool.ThreadPool(processes=1) as pool:
|
122
|
-
|
123
|
-
|
124
|
-
stderr_fut = pool.apply_async(_handle_io_stream,
|
125
|
-
args=(proc.stderr, sys.stderr,
|
126
|
-
err_args))
|
129
|
+
stderr_fut = pool.apply_async(stderr_stream_handler,
|
130
|
+
args=(proc.stderr, sys.stderr))
|
127
131
|
# Do not launch a thread for stdout as the rich.status does not
|
128
132
|
# work in a thread, which is used in
|
129
133
|
# log_utils.RayUpLineProcessor.
|
130
|
-
stdout =
|
134
|
+
stdout = stdout_stream_handler(proc.stdout, sys.stdout)
|
131
135
|
stderr = stderr_fut.get()
|
132
136
|
else:
|
133
|
-
stdout =
|
137
|
+
stdout = stdout_stream_handler(proc.stdout, sys.stdout)
|
134
138
|
stderr = ''
|
135
139
|
return stdout, stderr
|
136
140
|
|
137
141
|
|
142
|
+
@context_utils.cancellation_guard
|
138
143
|
def run_with_log(
|
139
144
|
cmd: Union[List[str], str],
|
140
145
|
log_path: str,
|
@@ -176,7 +181,12 @@ def run_with_log(
|
|
176
181
|
# Redirect stderr to stdout when using ray, to preserve the order of
|
177
182
|
# stdout and stderr.
|
178
183
|
stdout_arg = stderr_arg = None
|
179
|
-
|
184
|
+
ctx = context.get()
|
185
|
+
if process_stream or ctx is not None:
|
186
|
+
# Capture stdout/stderr of the subprocess if:
|
187
|
+
# 1. Post-processing is needed (process_stream=True)
|
188
|
+
# 2. Potential contextual handling is needed (ctx is not None)
|
189
|
+
# TODO(aylei): can we always capture the stdout/stderr?
|
180
190
|
stdout_arg = subprocess.PIPE
|
181
191
|
stderr_arg = subprocess.PIPE if not with_ray else subprocess.STDOUT
|
182
192
|
# Use stdin=subprocess.DEVNULL by default, as allowing inputs will mess up
|
@@ -197,6 +207,8 @@ def run_with_log(
|
|
197
207
|
subprocess_utils.kill_process_daemon(proc.pid)
|
198
208
|
stdout = ''
|
199
209
|
stderr = ''
|
210
|
+
stdout_stream_handler = None
|
211
|
+
stderr_stream_handler = None
|
200
212
|
|
201
213
|
if process_stream:
|
202
214
|
if skip_lines is None:
|
@@ -223,7 +235,35 @@ def run_with_log(
|
|
223
235
|
replace_crlf=with_ray,
|
224
236
|
streaming_prefix=streaming_prefix,
|
225
237
|
)
|
226
|
-
|
238
|
+
stdout_stream_handler = functools.partial(
|
239
|
+
_handle_io_stream,
|
240
|
+
args=args,
|
241
|
+
)
|
242
|
+
if proc.stderr is not None:
|
243
|
+
err_args = copy.copy(args)
|
244
|
+
err_args.line_processor = None
|
245
|
+
stderr_stream_handler = functools.partial(
|
246
|
+
_handle_io_stream,
|
247
|
+
args=err_args,
|
248
|
+
)
|
249
|
+
if ctx is not None:
|
250
|
+
# When runs in a coroutine, always process the subprocess
|
251
|
+
# stream to:
|
252
|
+
# 1. handle context cancellation
|
253
|
+
# 2. redirect subprocess stdout/stderr to the contextual
|
254
|
+
# stdout/stderr of current coroutine.
|
255
|
+
stdout, stderr = context_utils.pipe_and_wait_process(
|
256
|
+
ctx,
|
257
|
+
proc,
|
258
|
+
cancel_callback=subprocess_utils.kill_children_processes,
|
259
|
+
stdout_stream_handler=stdout_stream_handler,
|
260
|
+
stderr_stream_handler=stderr_stream_handler)
|
261
|
+
elif process_stream:
|
262
|
+
# When runs in a process, only process subprocess stream if
|
263
|
+
# necessary to avoid unnecessary stream handling overhead.
|
264
|
+
stdout, stderr = process_subprocess_stream(
|
265
|
+
proc, stdout_stream_handler, stderr_stream_handler)
|
266
|
+
# Ensure returncode is set.
|
227
267
|
proc.wait()
|
228
268
|
if require_outputs:
|
229
269
|
return proc.returncode, stdout, stderr
|
sky/templates/gcp-ray.yml.j2
CHANGED
@@ -69,6 +69,12 @@ provider:
|
|
69
69
|
{%- if enable_gvnic %}
|
70
70
|
enable_gvnic: {{ enable_gvnic }}
|
71
71
|
{%- endif %}
|
72
|
+
{%- if enable_gpu_direct %}
|
73
|
+
enable_gpu_direct: {{ enable_gpu_direct }}
|
74
|
+
{%- endif %}
|
75
|
+
{%- if placement_policy %}
|
76
|
+
placement_policy: {{ placement_policy }}
|
77
|
+
{%- endif %}
|
72
78
|
|
73
79
|
auth:
|
74
80
|
ssh_user: gcpuser
|
@@ -148,6 +154,11 @@ available_node_types:
|
|
148
154
|
- key: install-nvidia-driver
|
149
155
|
value: "True"
|
150
156
|
{%- endif %}
|
157
|
+
{%- if user_data is not none %}
|
158
|
+
- key: user-data
|
159
|
+
value: |-
|
160
|
+
{{ user_data | indent(10) }}
|
161
|
+
{%- endif %}
|
151
162
|
{%- if use_spot or gpu is not none %}
|
152
163
|
scheduling:
|
153
164
|
{%- if use_spot %}
|
sky/templates/nebius-ray.yml.j2
CHANGED
@@ -9,6 +9,7 @@ provider:
|
|
9
9
|
type: external
|
10
10
|
module: sky.provision.nebius
|
11
11
|
region: "{{region}}"
|
12
|
+
use_internal_ips: {{use_internal_ips}}
|
12
13
|
|
13
14
|
{%- if docker_image is not none %}
|
14
15
|
docker:
|
@@ -34,6 +35,9 @@ docker:
|
|
34
35
|
auth:
|
35
36
|
ssh_user: ubuntu
|
36
37
|
ssh_private_key: {{ssh_private_key}}
|
38
|
+
{% if ssh_proxy_command is not none %}
|
39
|
+
ssh_proxy_command: {{ssh_proxy_command}}
|
40
|
+
{% endif %}
|
37
41
|
|
38
42
|
available_node_types:
|
39
43
|
ray_head_default:
|
sky/templates/websocket_proxy.py
CHANGED
@@ -16,8 +16,11 @@ from typing import Dict
|
|
16
16
|
from urllib.request import Request
|
17
17
|
|
18
18
|
import websockets
|
19
|
+
from websockets.asyncio.client import ClientConnection
|
19
20
|
from websockets.asyncio.client import connect
|
20
21
|
|
22
|
+
BUFFER_SIZE = 2**16 # 64KB
|
23
|
+
|
21
24
|
|
22
25
|
def _get_cookie_header(url: str) -> Dict[str, str]:
|
23
26
|
"""Extract Cookie header value from a cookie jar for a specific URL"""
|
@@ -51,19 +54,36 @@ async def main(url: str) -> None:
|
|
51
54
|
old_settings = None
|
52
55
|
|
53
56
|
try:
|
54
|
-
|
55
|
-
|
57
|
+
loop = asyncio.get_running_loop()
|
58
|
+
# Use asyncio.Stream primitives to wrap stdin and stdout, this is to
|
59
|
+
# avoid creating a new thread for each read/write operation
|
60
|
+
# excessively.
|
61
|
+
stdin_reader = asyncio.StreamReader()
|
62
|
+
protocol = asyncio.StreamReaderProtocol(stdin_reader)
|
63
|
+
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
|
64
|
+
transport, protocol = await loop.connect_write_pipe(
|
65
|
+
asyncio.streams.FlowControlMixin, sys.stdout) # type: ignore
|
66
|
+
stdout_writer = asyncio.StreamWriter(transport, protocol, None,
|
67
|
+
loop)
|
68
|
+
|
69
|
+
await asyncio.gather(stdin_to_websocket(stdin_reader, websocket),
|
70
|
+
websocket_to_stdout(websocket, stdout_writer))
|
56
71
|
finally:
|
57
72
|
if old_settings:
|
58
73
|
termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN,
|
59
74
|
old_settings)
|
60
75
|
|
61
76
|
|
62
|
-
async def stdin_to_websocket(
|
77
|
+
async def stdin_to_websocket(reader: asyncio.StreamReader,
|
78
|
+
websocket: ClientConnection):
|
63
79
|
try:
|
64
80
|
while True:
|
65
|
-
|
66
|
-
|
81
|
+
# Read at most BUFFER_SIZE bytes, this not affect
|
82
|
+
# responsiveness since it will return as soon as
|
83
|
+
# there is at least one byte.
|
84
|
+
# The BUFFER_SIZE is chosen to be large enough to improve
|
85
|
+
# throughput.
|
86
|
+
data = await reader.read(BUFFER_SIZE)
|
67
87
|
if not data:
|
68
88
|
break
|
69
89
|
await websocket.send(data)
|
@@ -73,13 +93,13 @@ async def stdin_to_websocket(websocket):
|
|
73
93
|
await websocket.close()
|
74
94
|
|
75
95
|
|
76
|
-
async def websocket_to_stdout(websocket
|
96
|
+
async def websocket_to_stdout(websocket: ClientConnection,
|
97
|
+
writer: asyncio.StreamWriter):
|
77
98
|
try:
|
78
99
|
while True:
|
79
100
|
message = await websocket.recv()
|
80
|
-
|
81
|
-
await
|
82
|
-
None, sys.stdout.buffer.flush)
|
101
|
+
writer.write(message)
|
102
|
+
await writer.drain()
|
83
103
|
except websockets.exceptions.ConnectionClosed:
|
84
104
|
print('WebSocket connection closed', file=sys.stderr)
|
85
105
|
except Exception as e: # pylint: disable=broad-except
|
sky/utils/command_runner.py
CHANGED
@@ -11,6 +11,7 @@ from sky import sky_logging
|
|
11
11
|
from sky.skylet import constants
|
12
12
|
from sky.skylet import log_lib
|
13
13
|
from sky.utils import common_utils
|
14
|
+
from sky.utils import context_utils
|
14
15
|
from sky.utils import control_master_utils
|
15
16
|
from sky.utils import subprocess_utils
|
16
17
|
from sky.utils import timeline
|
@@ -574,6 +575,7 @@ class SSHCommandRunner(CommandRunner):
|
|
574
575
|
shell=True)
|
575
576
|
|
576
577
|
@timeline.event
|
578
|
+
@context_utils.cancellation_guard
|
577
579
|
def run(
|
578
580
|
self,
|
579
581
|
cmd: Union[str, List[str]],
|
@@ -779,6 +781,7 @@ class KubernetesCommandRunner(CommandRunner):
|
|
779
781
|
return kubectl_cmd
|
780
782
|
|
781
783
|
@timeline.event
|
784
|
+
@context_utils.cancellation_guard
|
782
785
|
def run(
|
783
786
|
self,
|
784
787
|
cmd: Union[str, List[str]],
|