skypilot-nightly 1.0.0.dev20250514__py3-none-any.whl → 1.0.0.dev20250516__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. sky/__init__.py +2 -2
  2. sky/backends/backend.py +3 -2
  3. sky/backends/backend_utils.py +19 -17
  4. sky/backends/cloud_vm_ray_backend.py +30 -11
  5. sky/clouds/aws.py +11 -9
  6. sky/clouds/azure.py +16 -13
  7. sky/clouds/cloud.py +4 -3
  8. sky/clouds/cudo.py +3 -2
  9. sky/clouds/do.py +3 -2
  10. sky/clouds/fluidstack.py +3 -3
  11. sky/clouds/gcp.py +1 -1
  12. sky/clouds/ibm.py +12 -10
  13. sky/clouds/kubernetes.py +3 -2
  14. sky/clouds/lambda_cloud.py +6 -6
  15. sky/clouds/nebius.py +6 -5
  16. sky/clouds/oci.py +9 -7
  17. sky/clouds/paperspace.py +3 -2
  18. sky/clouds/runpod.py +9 -9
  19. sky/clouds/scp.py +5 -3
  20. sky/clouds/vast.py +8 -7
  21. sky/clouds/vsphere.py +4 -2
  22. sky/core.py +18 -12
  23. sky/dashboard/out/404.html +1 -1
  24. sky/dashboard/out/_next/static/chunks/pages/index-6b0d9e5031b70c58.js +1 -0
  25. sky/dashboard/out/_next/static/{tdxxQrPV6NW90a983oHXe → y1yf6Xc0zwam5fFluIyUm}/_buildManifest.js +1 -1
  26. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  27. sky/dashboard/out/clusters/[cluster].html +1 -1
  28. sky/dashboard/out/clusters.html +1 -1
  29. sky/dashboard/out/index.html +1 -1
  30. sky/dashboard/out/jobs/[job].html +1 -1
  31. sky/dashboard/out/jobs.html +1 -1
  32. sky/execution.py +33 -0
  33. sky/global_user_state.py +2 -0
  34. sky/jobs/recovery_strategy.py +4 -1
  35. sky/jobs/server/core.py +6 -12
  36. sky/optimizer.py +19 -13
  37. sky/provision/kubernetes/utils.py +26 -1
  38. sky/resources.py +203 -44
  39. sky/serve/server/core.py +0 -5
  40. sky/serve/spot_placer.py +3 -0
  41. sky/server/requests/executor.py +114 -22
  42. sky/server/requests/requests.py +15 -0
  43. sky/server/server.py +63 -20
  44. sky/server/uvicorn.py +12 -2
  45. sky/setup_files/dependencies.py +4 -1
  46. sky/sky_logging.py +40 -2
  47. sky/skylet/log_lib.py +60 -11
  48. sky/skylet/log_lib.pyi +5 -0
  49. sky/task.py +8 -6
  50. sky/utils/cli_utils/status_utils.py +6 -5
  51. sky/utils/command_runner.py +3 -0
  52. sky/utils/context.py +264 -0
  53. sky/utils/context_utils.py +172 -0
  54. sky/utils/controller_utils.py +39 -43
  55. sky/utils/dag_utils.py +4 -2
  56. sky/utils/resources_utils.py +3 -0
  57. sky/utils/rich_utils.py +81 -37
  58. sky/utils/schemas.py +33 -24
  59. sky/utils/subprocess_utils.py +8 -2
  60. {skypilot_nightly-1.0.0.dev20250514.dist-info → skypilot_nightly-1.0.0.dev20250516.dist-info}/METADATA +2 -2
  61. {skypilot_nightly-1.0.0.dev20250514.dist-info → skypilot_nightly-1.0.0.dev20250516.dist-info}/RECORD +66 -64
  62. {skypilot_nightly-1.0.0.dev20250514.dist-info → skypilot_nightly-1.0.0.dev20250516.dist-info}/WHEEL +1 -1
  63. sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
  64. sky/dashboard/out/_next/static/{tdxxQrPV6NW90a983oHXe → y1yf6Xc0zwam5fFluIyUm}/_ssgManifest.js +0 -0
  65. {skypilot_nightly-1.0.0.dev20250514.dist-info → skypilot_nightly-1.0.0.dev20250516.dist-info}/entry_points.txt +0 -0
  66. {skypilot_nightly-1.0.0.dev20250514.dist-info → skypilot_nightly-1.0.0.dev20250516.dist-info}/licenses/LICENSE +0 -0
  67. {skypilot_nightly-1.0.0.dev20250514.dist-info → skypilot_nightly-1.0.0.dev20250516.dist-info}/top_level.txt +0 -0
sky/server/requests/executor.py CHANGED
@@ -18,7 +18,10 @@ The number of the workers is determined by the system resources.
18
18
 
19
19
  See the [README.md](../README.md) for detailed architecture of the executor.
20
20
  """
21
+ import asyncio
21
22
  import contextlib
23
+ import contextvars
24
+ import functools
22
25
  import multiprocessing
23
26
  import os
24
27
  import queue as queue_lib
@@ -47,6 +50,7 @@ from sky.server.requests.queues import mp_queue
47
50
  from sky.skylet import constants
48
51
  from sky.utils import annotations
49
52
  from sky.utils import common_utils
53
+ from sky.utils import context
50
54
  from sky.utils import subprocess_utils
51
55
  from sky.utils import timeline
52
56
 
@@ -60,7 +64,6 @@ else:
60
64
  from typing_extensions import ParamSpec
61
65
 
62
66
  P = ParamSpec('P')
63
-
64
67
  logger = sky_logging.init_logger(__name__)
65
68
 
66
69
  # On macOS, the default start method for multiprocessing is 'fork', which
@@ -341,6 +344,114 @@ def _request_execution_wrapper(request_id: str,
341
344
  logger.info(f'Request {request_id} finished')
342
345
 
343
346
 
347
+ async def execute_request_coroutine(request: api_requests.Request):
348
+ """Execute a request in current event loop.
349
+
350
+ Similar to _request_execution_wrapper, but executed as coroutine in current
351
+ event loop. This is designed for executing tasks that are not CPU
352
+ intensive, e.g. sky logs.
353
+ """
354
+ ctx = context.get()
355
+ if ctx is None:
356
+ raise ValueError('Context is not initialized')
357
+ logger.info(f'Executing request {request.request_id} in coroutine')
358
+ func = request.entrypoint
359
+ request_body = request.request_body
360
+ with api_requests.update_request(request.request_id) as request_task:
361
+ request_task.status = api_requests.RequestStatus.RUNNING
362
+ # Redirect stdout and stderr to the request log path.
363
+ original_output = ctx.redirect_log(request.log_path)
364
+ # Override environment variables that back env_options.Options
365
+ # TODO(aylei): compared to process executor, running task in coroutine has
366
+ # two issues to fix:
367
+ # 1. skypilot config is not contextual
368
+ # 2. envs that read directly from os.environ are not contextual
369
+ ctx.override_envs(request_body.env_vars)
370
+ loop = asyncio.get_running_loop()
371
+ pyctx = contextvars.copy_context()
372
+ func_call = functools.partial(pyctx.run, func, **request_body.to_kwargs())
373
+ fut: asyncio.Future = loop.run_in_executor(None, func_call)
374
+
375
+ async def poll_task(request_id: str) -> bool:
376
+ request = api_requests.get_request(request_id)
377
+ if request is None:
378
+ raise RuntimeError('Request not found')
379
+
380
+ if request.status == api_requests.RequestStatus.CANCELLED:
381
+ ctx.cancel()
382
+ return True
383
+
384
+ if fut.done():
385
+ try:
386
+ result = await fut
387
+ api_requests.set_request_succeeded(request_id, result)
388
+ except asyncio.CancelledError:
389
+ # The task is cancelled by ctx.cancel(), where the status
390
+ # should already be set to CANCELLED.
391
+ pass
392
+ except Exception as e: # pylint: disable=broad-except
393
+ ctx.redirect_log(original_output)
394
+ api_requests.set_request_failed(request_id, e)
395
+ logger.error(f'Request {request_id} failed due to '
396
+ f'{common_utils.format_exception(e)}')
397
+ return True
398
+ return False
399
+
400
+ try:
401
+ while True:
402
+ res = await poll_task(request.request_id)
403
+ if res:
404
+ break
405
+ await asyncio.sleep(0.5)
406
+ except asyncio.CancelledError:
407
+ # Current coroutine is cancelled due to client disconnect, set the
408
+ # request status for consistency.
409
+ api_requests.set_request_cancelled(request.request_id)
410
+ pass
411
+ # pylint: disable=broad-except
412
+ except (Exception, KeyboardInterrupt, SystemExit) as e:
413
+ # Handle any other error
414
+ ctx.redirect_log(original_output)
415
+ ctx.cancel()
416
+ api_requests.set_request_failed(request.request_id, e)
417
+ logger.error(f'Request {request.request_id} interrupted due to '
418
+ f'unhandled exception: {common_utils.format_exception(e)}')
419
+ raise
420
+
421
+
422
+ def prepare_request(
423
+ request_id: str,
424
+ request_name: str,
425
+ request_body: payloads.RequestBody,
426
+ func: Callable[P, Any],
427
+ request_cluster_name: Optional[str] = None,
428
+ schedule_type: api_requests.ScheduleType = (api_requests.ScheduleType.LONG),
429
+ is_skypilot_system: bool = False,
430
+ ) -> api_requests.Request:
431
+ """Prepare a request for execution."""
432
+ user_id = request_body.env_vars[constants.USER_ID_ENV_VAR]
433
+ if is_skypilot_system:
434
+ user_id = server_constants.SKYPILOT_SYSTEM_USER_ID
435
+ global_user_state.add_or_update_user(
436
+ models.User(id=user_id, name=user_id))
437
+ request = api_requests.Request(request_id=request_id,
438
+ name=server_constants.REQUEST_NAME_PREFIX +
439
+ request_name,
440
+ entrypoint=func,
441
+ request_body=request_body,
442
+ status=api_requests.RequestStatus.PENDING,
443
+ created_at=time.time(),
444
+ schedule_type=schedule_type,
445
+ user_id=user_id,
446
+ cluster_name=request_cluster_name)
447
+
448
+ if not api_requests.create_if_not_exists(request):
449
+ raise RuntimeError(f'Request {request_id} already exists.')
450
+
451
+ request.log_path.touch()
452
+ return request
453
+
454
+
344
455
  def schedule_request(
345
456
  request_id: str,
346
457
  request_name: str,
@@ -372,27 +483,8 @@ def schedule_request(
372
483
  The precondition is waited asynchronously and does not block the
373
484
  caller.
374
485
  """
375
- user_id = request_body.env_vars[constants.USER_ID_ENV_VAR]
376
- if is_skypilot_system:
377
- user_id = server_constants.SKYPILOT_SYSTEM_USER_ID
378
- global_user_state.add_or_update_user(
379
- models.User(id=user_id, name=user_id))
380
- request = api_requests.Request(request_id=request_id,
381
- name=server_constants.REQUEST_NAME_PREFIX +
382
- request_name,
383
- entrypoint=func,
384
- request_body=request_body,
385
- status=api_requests.RequestStatus.PENDING,
386
- created_at=time.time(),
387
- schedule_type=schedule_type,
388
- user_id=user_id,
389
- cluster_name=request_cluster_name)
390
-
391
- if not api_requests.create_if_not_exists(request):
392
- logger.debug(f'Request {request_id} already exists.')
393
- return
394
-
395
- request.log_path.touch()
486
+ prepare_request(request_id, request_name, request_body, func,
487
+ request_cluster_name, schedule_type, is_skypilot_system)
396
488
 
397
489
  def enqueue():
398
490
  input_tuple = (request_id, ignore_return_value)
sky/server/requests/requests.py CHANGED
@@ -606,3 +606,18 @@ def set_request_failed(request_id: str, e: BaseException) -> None:
606
606
  assert request_task is not None, request_id
607
607
  request_task.status = RequestStatus.FAILED
608
608
  request_task.set_error(e)
609
+
610
+
611
+ def set_request_succeeded(request_id: str, result: Any) -> None:
612
+ """Set a request to succeeded and populate the result."""
613
+ with update_request(request_id) as request_task:
614
+ assert request_task is not None, request_id
615
+ request_task.status = RequestStatus.SUCCEEDED
616
+ request_task.set_return_value(result)
617
+
618
+
619
+ def set_request_cancelled(request_id: str) -> None:
620
+ """Set a request to cancelled."""
621
+ with update_request(request_id) as request_task:
622
+ assert request_task is not None, request_id
623
+ request_task.status = RequestStatus.CANCELLED
sky/server/server.py CHANGED
@@ -9,6 +9,7 @@ import logging
9
9
  import multiprocessing
10
10
  import os
11
11
  import pathlib
12
+ import posixpath
12
13
  import re
13
14
  import shutil
14
15
  import sys
@@ -47,6 +48,7 @@ from sky.usage import usage_lib
47
48
  from sky.utils import admin_policy_utils
48
49
  from sky.utils import common as common_lib
49
50
  from sky.utils import common_utils
51
+ from sky.utils import context
50
52
  from sky.utils import dag_utils
51
53
  from sky.utils import env_options
52
54
  from sky.utils import status_lib
@@ -166,8 +168,36 @@ class InternalDashboardPrefixMiddleware(
166
168
  return await call_next(request)
167
169
 
168
170
 
171
+ class CacheControlStaticMiddleware(starlette.middleware.base.BaseHTTPMiddleware
172
+ ):
173
+ """Middleware to add cache control headers to static files."""
174
+
175
+ async def dispatch(self, request: fastapi.Request, call_next):
176
+ if request.url.path.startswith('/dashboard/_next'):
177
+ response = await call_next(request)
178
+ response.headers['Cache-Control'] = 'max-age=3600'
179
+ return response
180
+ return await call_next(request)
181
+
182
+
183
+ class PathCleanMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
184
+ """Middleware to check the path of requests."""
185
+
186
+ async def dispatch(self, request: fastapi.Request, call_next):
187
+ if request.url.path.startswith('/dashboard/'):
188
+ # If the requested path is not relative to the expected directory,
189
+ # then the user is attempting path traversal, so deny the request.
190
+ parent = pathlib.Path('/dashboard')
191
+ request_path = pathlib.Path(posixpath.normpath(request.url.path))
192
+ if not _is_relative_to(request_path, parent):
193
+ raise fastapi.HTTPException(status_code=403, detail='Forbidden')
194
+ return await call_next(request)
195
+
196
+
169
197
  app = fastapi.FastAPI(prefix='/api/v1', debug=True, lifespan=lifespan)
170
198
  app.add_middleware(InternalDashboardPrefixMiddleware)
199
+ app.add_middleware(PathCleanMiddleware)
200
+ app.add_middleware(CacheControlStaticMiddleware)
171
201
  app.add_middleware(
172
202
  cors.CORSMiddleware,
173
203
  # TODO(zhwu): in production deployment, we should restrict the allowed
@@ -673,24 +703,28 @@ async def logs(
673
703
  # TODO(zhwu): This should wait for the request on the cluster, e.g., async
674
704
  # launch, to finish, so that a user does not need to manually pull the
675
705
  # request status.
676
- executor.schedule_request(
706
+ # Only initialize the context in logs handler to limit the scope of this
707
+ # experimental change.
708
+ # TODO(aylei): init in lifespan() to enable SkyPilot context in all APIs.
709
+ context.initialize()
710
+ request_task = executor.prepare_request(
677
711
  request_id=request.state.request_id,
678
712
  request_name='logs',
679
713
  request_body=cluster_job_body,
680
714
  func=core.tail_logs,
681
- # TODO(aylei): We have tail logs scheduled as SHORT request, because it
682
- # should be responsive. However, it can be long running if the user's
683
- # job keeps running, and we should avoid it taking the SHORT worker.
684
715
  schedule_type=requests_lib.ScheduleType.SHORT,
685
- request_cluster_name=cluster_job_body.cluster_name,
686
716
  )
717
+ task = asyncio.create_task(executor.execute_request_coroutine(request_task))
687
718
 
688
- request_task = requests_lib.get_request(request.state.request_id)
719
+ def cancel_task():
720
+ task.cancel()
689
721
 
722
+ # Cancel the task after the request is done or client disconnects
723
+ background_tasks.add_task(cancel_task)
690
724
  # TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
691
725
  # the same approach as /stream.
692
726
  return stream_utils.stream_response(
693
- request_id=request_task.request_id,
727
+ request_id=request.state.request_id,
694
728
  logs_path=request_task.log_path,
695
729
  background_tasks=background_tasks,
696
730
  )
@@ -1125,25 +1159,28 @@ async def complete_storage_name(incomplete: str,) -> List[str]:
1125
1159
  return global_user_state.get_storage_names_start_with(incomplete)
1126
1160
 
1127
1161
 
1128
- # Add a route to serve static files
1129
- @app.get('/{full_path:path}')
1130
- async def serve_static_or_dashboard(full_path: str):
1131
- """Serves static files for any unmatched routes.
1162
+ @app.get('/dashboard/{full_path:path}')
1163
+ async def serve_dashboard(full_path: str):
1164
+ """Serves the Next.js dashboard application.
1132
1165
 
1133
- Handles the /dashboard prefix from Next.js configuration.
1134
- """
1135
- # Check if the path starts with 'dashboard/' and remove it if it does
1136
- if full_path.startswith('dashboard/'):
1137
- full_path = full_path[len('dashboard/'):]
1166
+ Args:
1167
+ full_path: The path requested by the client.
1168
+ e.g. /clusters, /jobs
1138
1169
 
1139
- # Try to serve the file directly from the out directory first
1170
+ Returns:
1171
+ FileResponse for static files or index.html for client-side routing.
1172
+
1173
+ Raises:
1174
+ HTTPException: If the path is invalid or file not found.
1175
+ """
1176
+ # Try to serve the staticfile directly e.g. /skypilot.svg,
1177
+ # /favicon.ico, and /_next/, etc.
1140
1178
  file_path = os.path.join(server_constants.DASHBOARD_DIR, full_path)
1141
1179
  if os.path.isfile(file_path):
1142
1180
  return fastapi.responses.FileResponse(file_path)
1143
1181
 
1144
- # If file not found, serve the index.html for client-side routing.
1145
- # For example, the non-matched arbitrary route (/ or /test) from
1146
- # client will be redirected to the index.html.
1182
+ # Serve index.html for client-side routing
1183
+ # e.g. /clusters, /jobs
1147
1184
  index_path = os.path.join(server_constants.DASHBOARD_DIR, 'index.html')
1148
1185
  try:
1149
1186
  with open(index_path, 'r', encoding='utf-8') as f:
@@ -1154,6 +1191,12 @@ async def serve_static_or_dashboard(full_path: str):
1154
1191
  raise fastapi.HTTPException(status_code=500, detail=str(e))
1155
1192
 
1156
1193
 
1194
+ # Redirect the root path to dashboard
1195
+ @app.get('/')
1196
+ async def root():
1197
+ return fastapi.responses.RedirectResponse(url='/dashboard/')
1198
+
1199
+
1157
1200
  if __name__ == '__main__':
1158
1201
  import uvicorn
1159
1202
 
sky/server/uvicorn.py CHANGED
@@ -3,6 +3,7 @@
3
3
  This module is a wrapper around uvicorn to customize the behavior of the
4
4
  server.
5
5
  """
6
+ import functools
6
7
  import os
7
8
  import threading
8
9
  from typing import Optional
@@ -10,6 +11,7 @@ from typing import Optional
10
11
  import uvicorn
11
12
  from uvicorn.supervisors import multiprocess
12
13
 
14
+ from sky.utils import context_utils
13
15
  from sky.utils import subprocess_utils
14
16
 
15
17
 
@@ -21,19 +23,27 @@ def run(config: uvicorn.Config):
21
23
  # guard by an exception.
22
24
  raise ValueError('Reload is not supported yet.')
23
25
  server = uvicorn.Server(config=config)
26
+ run_server_process = functools.partial(_run_server_process, server)
24
27
  try:
25
28
  if config.workers is not None and config.workers > 1:
26
29
  sock = config.bind_socket()
27
- SlowStartMultiprocess(config, target=server.run,
30
+ SlowStartMultiprocess(config,
31
+ target=run_server_process,
28
32
  sockets=[sock]).run()
29
33
  else:
30
- server.run()
34
+ run_server_process()
31
35
  finally:
32
36
  # Copied from unvicorn.run()
33
37
  if config.uds and os.path.exists(config.uds):
34
38
  os.remove(config.uds)
35
39
 
36
40
 
41
+ def _run_server_process(server: uvicorn.Server, *args, **kwargs):
42
+ """Run the server process with context awareness."""
43
+ context_utils.hijack_sys_attrs()
44
+ server.run(*args, **kwargs)
45
+
46
+
37
47
  class SlowStartMultiprocess(multiprocess.Multiprocess):
38
48
  """Uvicorn Multiprocess wrapper with slow start.
39
49
 
sky/setup_files/dependencies.py CHANGED
@@ -12,7 +12,10 @@ install_requires = [
12
12
  'wheel<0.46.0', # https://github.com/skypilot-org/skypilot/issues/5153
13
13
  'cachetools',
14
14
  # NOTE: ray requires click>=7.0.
15
- 'click >= 7.0',
15
+ # click 8.2.0 has a bug in parsing the command line arguments:
16
+ # https://github.com/pallets/click/issues/2894
17
+ # TODO(aylei): remove this once the bug is fixed in click.
18
+ 'click >= 7.0, < 8.2.0',
16
19
  'colorama',
17
20
  'cryptography',
18
21
  # Jinja has a bug in older versions because of the lack of pinning
sky/sky_logging.py CHANGED
@@ -10,6 +10,7 @@ import threading
10
10
  import colorama
11
11
 
12
12
  from sky.skylet import constants
13
+ from sky.utils import context
13
14
  from sky.utils import env_options
14
15
  from sky.utils import rich_utils
15
16
 
@@ -47,6 +48,43 @@ class NewLineFormatter(logging.Formatter):
47
48
  return msg
48
49
 
49
50
 
51
+ class EnvAwareHandler(rich_utils.RichSafeStreamHandler):
52
+ """A handler that is aware of environment variables.
53
+
54
+ This handler dynamically reflects the log level from environment variables.
55
+ """
56
+
57
+ def __init__(self, stream=None, level=logging.NOTSET, sensitive=False):
58
+ super().__init__(stream)
59
+ self.level = level
60
+ self._sensitive = sensitive
61
+
62
+ @property
63
+ def level(self):
64
+ # Only refresh log level if we are in a context, since the log level
65
+ # has already been reloaded eagerly in multi-processing. Refresh again
66
+ # is a no-op and can be avoided.
67
+ # TODO(aylei): unify the mechanism for coroutine context and
68
+ # multi-processing.
69
+ if context.get() is not None:
70
+ if self._sensitive:
71
+ # For sensitive logger, suppress debug log despite the
72
+ # SKYPILOT_DEBUG env var if SUPPRESS_SENSITIVE_LOG is set
73
+ if env_options.Options.SUPPRESS_SENSITIVE_LOG.get():
74
+ return logging.INFO
75
+ if env_options.Options.SHOW_DEBUG_INFO.get():
76
+ return logging.DEBUG
77
+ else:
78
+ return self._level
79
+ else:
80
+ return self._level
81
+
82
+ @level.setter
83
+ def level(self, level):
84
+ # pylint: disable=protected-access
85
+ self._level = logging._checkLevel(level)
86
+
87
+
50
88
  _root_logger = logging.getLogger('sky')
51
89
  _default_handler = None
52
90
  _logging_config = threading.local()
@@ -67,7 +105,7 @@ def _setup_logger():
67
105
  _root_logger.setLevel(logging.DEBUG)
68
106
  global _default_handler
69
107
  if _default_handler is None:
70
- _default_handler = rich_utils.RichSafeStreamHandler(sys.stdout)
108
+ _default_handler = EnvAwareHandler(sys.stdout)
71
109
  _default_handler.flush = sys.stdout.flush # type: ignore
72
110
  if env_options.Options.SHOW_DEBUG_INFO.get():
73
111
  _default_handler.setLevel(logging.DEBUG)
@@ -87,7 +125,7 @@ def _setup_logger():
87
125
  # for certain loggers.
88
126
  for logger_name in _SENSITIVE_LOGGER:
89
127
  logger = logging.getLogger(logger_name)
90
- handler_to_logger = rich_utils.RichSafeStreamHandler(sys.stdout)
128
+ handler_to_logger = EnvAwareHandler(sys.stdout, sensitive=True)
91
129
  handler_to_logger.flush = sys.stdout.flush # type: ignore
92
130
  logger.addHandler(handler_to_logger)
93
131
  logger.setLevel(logging.INFO)
sky/skylet/log_lib.py CHANGED
@@ -4,6 +4,7 @@ This is a remote utility module that provides logging functionality.
4
4
  """
5
5
  import collections
6
6
  import copy
7
+ import functools
7
8
  import io
8
9
  import multiprocessing.pool
9
10
  import os
@@ -21,6 +22,8 @@ import colorama
21
22
  from sky import sky_logging
22
23
  from sky.skylet import constants
23
24
  from sky.skylet import job_lib
25
+ from sky.utils import context
26
+ from sky.utils import context_utils
24
27
  from sky.utils import log_utils
25
28
  from sky.utils import subprocess_utils
26
29
  from sky.utils import ux_utils
@@ -59,6 +62,16 @@ class _ProcessingArgs:
59
62
  self.streaming_prefix = streaming_prefix
60
63
 
61
64
 
65
+ def _get_context():
66
+ # TODO(aylei): remove this after we drop the backward-compatibility for
67
+ # 0.9.x in 0.12.0
68
+ # Keep backward-compatibility for the old version of SkyPilot runtimes.
69
+ if 'context' in globals():
70
+ return context.get()
71
+ else:
72
+ return None
73
+
74
+
62
75
  def _handle_io_stream(io_stream, out_stream, args: _ProcessingArgs):
63
76
  """Process the stream of a process."""
64
77
  out_io = io.TextIOWrapper(io_stream,
@@ -77,6 +90,9 @@ def _handle_io_stream(io_stream, out_stream, args: _ProcessingArgs):
77
90
  with open(args.log_path, 'a', encoding='utf-8') as fout:
78
91
  with line_processor:
79
92
  while True:
93
+ ctx = _get_context()
94
+ if ctx is not None and ctx.is_canceled():
95
+ return
80
96
  line = out_io.readline()
81
97
  if not line:
82
98
  break
@@ -111,26 +127,24 @@ def _handle_io_stream(io_stream, out_stream, args: _ProcessingArgs):
111
127
  return ''.join(out)
112
128
 
113
129
 
114
- def process_subprocess_stream(proc, args: _ProcessingArgs) -> Tuple[str, str]:
115
- """Redirect the process's filtered stdout/stderr to both stream and file"""
130
+ def process_subprocess_stream(proc, stdout_stream_handler,
131
+ stderr_stream_handler) -> Tuple[str, str]:
132
+ """Process the stream of a process in threads, blocking."""
116
133
  if proc.stderr is not None:
117
134
  # Asyncio does not work as the output processing can be executed in a
118
135
  # different thread.
119
136
  # selectors is possible to handle the multiplexing of stdout/stderr,
120
137
  # but it introduces buffering making the output not streaming.
121
138
  with multiprocessing.pool.ThreadPool(processes=1) as pool:
122
- err_args = copy.copy(args)
123
- err_args.line_processor = None
124
- stderr_fut = pool.apply_async(_handle_io_stream,
125
- args=(proc.stderr, sys.stderr,
126
- err_args))
139
+ stderr_fut = pool.apply_async(stderr_stream_handler,
140
+ args=(proc.stderr, sys.stderr))
127
141
  # Do not launch a thread for stdout as the rich.status does not
128
142
  # work in a thread, which is used in
129
143
  # log_utils.RayUpLineProcessor.
130
- stdout = _handle_io_stream(proc.stdout, sys.stdout, args)
144
+ stdout = stdout_stream_handler(proc.stdout, sys.stdout)
131
145
  stderr = stderr_fut.get()
132
146
  else:
133
- stdout = _handle_io_stream(proc.stdout, sys.stdout, args)
147
+ stdout = stdout_stream_handler(proc.stdout, sys.stdout)
134
148
  stderr = ''
135
149
  return stdout, stderr
136
150
 
@@ -176,7 +190,12 @@ def run_with_log(
176
190
  # Redirect stderr to stdout when using ray, to preserve the order of
177
191
  # stdout and stderr.
178
192
  stdout_arg = stderr_arg = None
179
- if process_stream:
193
+ ctx = _get_context()
194
+ if process_stream or ctx is not None:
195
+ # Capture stdout/stderr of the subprocess if:
196
+ # 1. Post-processing is needed (process_stream=True)
197
+ # 2. Potential contextual handling is needed (ctx is not None)
198
+ # TODO(aylei): can we always capture the stdout/stderr?
180
199
  stdout_arg = subprocess.PIPE
181
200
  stderr_arg = subprocess.PIPE if not with_ray else subprocess.STDOUT
182
201
  # Use stdin=subprocess.DEVNULL by default, as allowing inputs will mess up
@@ -197,6 +216,8 @@ def run_with_log(
197
216
  subprocess_utils.kill_process_daemon(proc.pid)
198
217
  stdout = ''
199
218
  stderr = ''
219
+ stdout_stream_handler = None
220
+ stderr_stream_handler = None
200
221
 
201
222
  if process_stream:
202
223
  if skip_lines is None:
@@ -223,7 +244,35 @@ def run_with_log(
223
244
  replace_crlf=with_ray,
224
245
  streaming_prefix=streaming_prefix,
225
246
  )
226
- stdout, stderr = process_subprocess_stream(proc, args)
247
+ stdout_stream_handler = functools.partial(
248
+ _handle_io_stream,
249
+ args=args,
250
+ )
251
+ if proc.stderr is not None:
252
+ err_args = copy.copy(args)
253
+ err_args.line_processor = None
254
+ stderr_stream_handler = functools.partial(
255
+ _handle_io_stream,
256
+ args=err_args,
257
+ )
258
+ if ctx is not None:
259
+ # When running in a coroutine, always process the subprocess
260
+ # stream to:
261
+ # 1. handle context cancellation
262
+ # 2. redirect subprocess stdout/stderr to the contextual
263
+ # stdout/stderr of current coroutine.
264
+ stdout, stderr = context_utils.pipe_and_wait_process(
265
+ ctx,
266
+ proc,
267
+ cancel_callback=subprocess_utils.kill_children_processes,
268
+ stdout_stream_handler=stdout_stream_handler,
269
+ stderr_stream_handler=stderr_stream_handler)
270
+ elif process_stream:
271
+ # When running in a process, only process subprocess stream if
272
+ # necessary to avoid unnecessary stream handling overhead.
273
+ stdout, stderr = process_subprocess_stream(
274
+ proc, stdout_stream_handler, stderr_stream_handler)
275
+ # Ensure returncode is set.
227
276
  proc.wait()
228
277
  if require_outputs:
229
278
  return proc.returncode, stdout, stderr
sky/skylet/log_lib.pyi CHANGED
@@ -11,6 +11,7 @@ from typing_extensions import Literal
11
11
  from sky import sky_logging as sky_logging
12
12
  from sky.skylet import constants as constants
13
13
  from sky.skylet import job_lib as job_lib
14
+ from sky.utils import context
14
15
  from sky.utils import log_utils as log_utils
15
16
 
16
17
  SKY_LOG_WAITING_GAP_SECONDS: int = ...
@@ -41,6 +42,10 @@ class _ProcessingArgs:
41
42
  ...
42
43
 
43
44
 
45
+ def _get_context() -> Optional[context.Context]:
46
+ ...
47
+
48
+
44
49
  def _handle_io_stream(io_stream, out_stream, args: _ProcessingArgs):
45
50
  ...
46
51
 
sky/task.py CHANGED
@@ -165,7 +165,8 @@ def _with_docker_login_config(
165
165
  f'ignored.{colorama.Style.RESET_ALL}')
166
166
  return resources
167
167
  # Already checked in extract_docker_image
168
- assert len(resources.image_id) == 1, resources.image_id
168
+ assert resources.image_id is not None and len(
169
+ resources.image_id) == 1, resources.image_id
169
170
  region = list(resources.image_id.keys())[0]
170
171
  return resources.copy(image_id={region: 'docker:' + docker_image},
171
172
  _docker_login_config=docker_login_config)
@@ -775,7 +776,7 @@ class Task:
775
776
  for _, storage_obj in self.storage_mounts.items():
776
777
  if storage_obj.mode in storage_lib.MOUNTABLE_STORAGE_MODES:
777
778
  for r in self.resources:
778
- r.requires_fuse = True
779
+ r.set_requires_fuse(True)
779
780
  break
780
781
 
781
782
  return self
@@ -931,7 +932,7 @@ class Task:
931
932
  self.storage_mounts = {}
932
933
  # Clear the requires_fuse flag if no storage mounts are set.
933
934
  for r in self.resources:
934
- r.requires_fuse = False
935
+ r.set_requires_fuse(False)
935
936
  return self
936
937
  for target, storage_obj in storage_mounts.items():
937
938
  # TODO(zhwu): /home/username/sky_workdir as the target path need
@@ -956,7 +957,7 @@ class Task:
956
957
  # If any storage is using MOUNT mode, we need to enable FUSE in
957
958
  # the resources.
958
959
  for r in self.resources:
959
- r.requires_fuse = True
960
+ r.set_requires_fuse(True)
960
961
  # Storage source validation is done in Storage object
961
962
  self.storage_mounts = storage_mounts
962
963
  return self
@@ -1234,13 +1235,14 @@ class Task:
1234
1235
 
1235
1236
  add_if_not_none('name', self.name)
1236
1237
 
1237
- tmp_resource_config = {}
1238
+ tmp_resource_config: Union[Dict[str, Union[str, int]],
1239
+ Dict[str, List[Dict[str, Union[str, int]]]]]
1238
1240
  if len(self.resources) > 1:
1239
1241
  resource_list = []
1240
1242
  for r in self.resources:
1241
1243
  resource_list.append(r.to_yaml_config())
1242
1244
  key = 'ordered' if isinstance(self.resources, list) else 'any_of'
1243
- tmp_resource_config[key] = resource_list
1245
+ tmp_resource_config = {key: resource_list}
1244
1246
  else:
1245
1247
  tmp_resource_config = list(self.resources)[0].to_yaml_config()
1246
1248