skypilot-nightly 1.0.0.dev20250311__py3-none-any.whl → 1.0.0.dev20250313__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. sky/__init__.py +2 -2
  2. sky/adaptors/gcp.py +7 -0
  3. sky/adaptors/nebius.py +11 -1
  4. sky/backends/backend_utils.py +38 -15
  5. sky/backends/cloud_vm_ray_backend.py +17 -52
  6. sky/cli.py +26 -13
  7. sky/client/cli.py +26 -13
  8. sky/client/sdk.py +2 -9
  9. sky/clouds/gcp.py +4 -1
  10. sky/clouds/nebius.py +8 -6
  11. sky/data/storage.py +16 -0
  12. sky/exceptions.py +11 -3
  13. sky/provision/kubernetes/utils.py +10 -1
  14. sky/server/common.py +16 -0
  15. sky/server/requests/event_loop.py +31 -0
  16. sky/server/requests/executor.py +50 -22
  17. sky/server/requests/preconditions.py +174 -0
  18. sky/server/requests/requests.py +43 -4
  19. sky/server/server.py +29 -8
  20. sky/server/stream_utils.py +9 -6
  21. sky/server/uvicorn.py +81 -0
  22. sky/setup_files/dependencies.py +4 -1
  23. sky/utils/accelerator_registry.py +1 -1
  24. sky/utils/controller_utils.py +10 -0
  25. sky/utils/subprocess_utils.py +56 -1
  26. {skypilot_nightly-1.0.0.dev20250311.dist-info → skypilot_nightly-1.0.0.dev20250313.dist-info}/METADATA +3 -3
  27. {skypilot_nightly-1.0.0.dev20250311.dist-info → skypilot_nightly-1.0.0.dev20250313.dist-info}/RECORD +31 -28
  28. {skypilot_nightly-1.0.0.dev20250311.dist-info → skypilot_nightly-1.0.0.dev20250313.dist-info}/LICENSE +0 -0
  29. {skypilot_nightly-1.0.0.dev20250311.dist-info → skypilot_nightly-1.0.0.dev20250313.dist-info}/WHEEL +0 -0
  30. {skypilot_nightly-1.0.0.dev20250311.dist-info → skypilot_nightly-1.0.0.dev20250313.dist-info}/entry_points.txt +0 -0
  31. {skypilot_nightly-1.0.0.dev20250311.dist-info → skypilot_nightly-1.0.0.dev20250313.dist-info}/top_level.txt +0 -0
@@ -27,8 +27,8 @@ import os
27
27
  import queue as queue_lib
28
28
  import signal
29
29
  import sys
30
+ import threading
30
31
  import time
31
- import traceback
32
32
  import typing
33
33
  from typing import Any, Callable, Generator, List, Optional, TextIO, Tuple
34
34
 
@@ -41,11 +41,13 @@ from sky import skypilot_config
41
41
  from sky.server import common as server_common
42
42
  from sky.server import constants as server_constants
43
43
  from sky.server.requests import payloads
44
+ from sky.server.requests import preconditions
44
45
  from sky.server.requests import requests as api_requests
45
46
  from sky.server.requests.queues import mp_queue
46
47
  from sky.skylet import constants
47
48
  from sky.utils import annotations
48
49
  from sky.utils import common_utils
50
+ from sky.utils import subprocess_utils
49
51
  from sky.utils import timeline
50
52
  from sky.utils import ux_utils
51
53
 
@@ -262,13 +264,7 @@ def _request_execution_wrapper(request_id: str,
262
264
  _restore_output(original_stdout, original_stderr)
263
265
  return
264
266
  except (Exception, SystemExit) as e: # pylint: disable=broad-except
265
- with ux_utils.enable_traceback():
266
- stacktrace = traceback.format_exc()
267
- setattr(e, 'stacktrace', stacktrace)
268
- with api_requests.update_request(request_id) as request_task:
269
- assert request_task is not None, request_id
270
- request_task.status = api_requests.RequestStatus.FAILED
271
- request_task.set_error(e)
267
+ api_requests.set_request_failed(request_id, e)
272
268
  _restore_output(original_stdout, original_stderr)
273
269
  logger.info(f'Request {request_id} failed due to '
274
270
  f'{common_utils.format_exception(e)}')
@@ -283,16 +279,37 @@ def _request_execution_wrapper(request_id: str,
283
279
  logger.info(f'Request {request_id} finished')
284
280
 
285
281
 
286
- def schedule_request(request_id: str,
287
- request_name: str,
288
- request_body: payloads.RequestBody,
289
- func: Callable[P, Any],
290
- request_cluster_name: Optional[str] = None,
291
- ignore_return_value: bool = False,
292
- schedule_type: api_requests.ScheduleType = api_requests.
293
- ScheduleType.LONG,
294
- is_skypilot_system: bool = False) -> None:
295
- """Enqueue a request to the request queue."""
282
+ def schedule_request(
283
+ request_id: str,
284
+ request_name: str,
285
+ request_body: payloads.RequestBody,
286
+ func: Callable[P, Any],
287
+ request_cluster_name: Optional[str] = None,
288
+ ignore_return_value: bool = False,
289
+ schedule_type: api_requests.ScheduleType = (
290
+ api_requests.ScheduleType.LONG),
291
+ is_skypilot_system: bool = False,
292
+ precondition: Optional[preconditions.Precondition] = None) -> None:
293
+ """Enqueue a request to the request queue.
294
+
295
+ Args:
296
+ request_id: ID of the request.
297
+ request_name: Name of the request type, e.g. "sky.launch".
298
+ request_body: The request body containing parameters and environment
299
+ variables.
300
+ func: The function to execute when the request is processed.
301
+ request_cluster_name: The name of the cluster associated with this
302
+ request, if any.
303
+ ignore_return_value: If True, the return value of the function will be
304
+ ignored.
305
+ schedule_type: The type of scheduling to use for this request, refer to
306
+ `api_requests.ScheduleType` for more details.
307
+ is_skypilot_system: Denote whether the request is from SkyPilot system.
308
+ precondition: If a precondition is provided, the request will only be
309
+ scheduled for execution when the precondition is met (returns True).
310
+ The precondition is waited asynchronously and does not block the
311
+ caller.
312
+ """
296
313
  user_id = request_body.env_vars[constants.USER_ID_ENV_VAR]
297
314
  if is_skypilot_system:
298
315
  user_id = server_constants.SKYPILOT_SYSTEM_USER_ID
@@ -314,10 +331,17 @@ def schedule_request(request_id: str,
314
331
  return
315
332
 
316
333
  request.log_path.touch()
317
- input_tuple = (request_id, ignore_return_value)
318
334
 
319
- logger.info(f'Queuing request: {request_id}')
320
- _get_queue(schedule_type).put(input_tuple)
335
+ def enqueue():
336
+ input_tuple = (request_id, ignore_return_value)
337
+ logger.info(f'Queuing request: {request_id}')
338
+ _get_queue(schedule_type).put(input_tuple)
339
+
340
+ if precondition is not None:
341
+ # Wait async to avoid blocking caller.
342
+ precondition.wait_async(on_condition_met=enqueue)
343
+ else:
344
+ enqueue()
321
345
 
322
346
 
323
347
  def executor_initializer(proc_group: str):
@@ -431,13 +455,17 @@ def start(deploy: bool) -> List[multiprocessing.Process]:
431
455
 
432
456
  logger.info('Request queues created')
433
457
 
458
+ long_workers = []
434
459
  for worker_id in range(max_parallel_for_long):
435
460
  worker = RequestWorker(id=worker_id,
436
461
  schedule_type=api_requests.ScheduleType.LONG)
437
462
  worker_proc = multiprocessing.Process(target=request_worker,
438
463
  args=(worker, 1))
439
- worker_proc.start()
464
+ long_workers.append(worker_proc)
440
465
  sub_procs.append(worker_proc)
466
+ threading.Thread(target=subprocess_utils.slow_start_processes,
467
+ args=(long_workers,),
468
+ daemon=True).start()
441
469
 
442
470
  # Start a worker for short requests.
443
471
  worker = RequestWorker(id=1, schedule_type=api_requests.ScheduleType.SHORT)
@@ -0,0 +1,174 @@
1
+ """Precondition for a request to be executed.
2
+
3
+ Preconditions are introduced so that:
4
+ - Wait for precondition does not block executor process, which is expensive;
5
+ - Cross requests knowledge (e.g. waiting for other requests to be completed)
6
+ can be handled at precondition level, instead of invading the execution
7
+ logic of specific requests.
8
+ """
9
+ import abc
10
+ import asyncio
11
+ import time
12
+ from typing import Callable, Optional, Tuple
13
+
14
+ from sky import exceptions
15
+ from sky import global_user_state
16
+ from sky import sky_logging
17
+ from sky.server.requests import event_loop
18
+ from sky.server.requests import requests as api_requests
19
+ from sky.utils import common_utils
20
+ from sky.utils import status_lib
21
+
22
+ # The default interval seconds to check the precondition.
23
+ _PRECONDITION_CHECK_INTERVAL = 1
24
+ # The default timeout seconds to wait for the precondition to be met.
25
+ _PRECONDITION_TIMEOUT = 60 * 60
26
+
27
+ logger = sky_logging.init_logger(__name__)
28
+
29
+
30
+ class Precondition(abc.ABC):
31
+ """Abstract base class for a precondition for a request to be executed.
32
+
33
+ A Precondition can be awaited in either of the following ways:
34
+ - await Precondition: wait for the precondition to be met.
35
+ - Precondition.wait_async: wait for the precondition to be met in background
36
+ and execute the given callback on met.
37
+ """
38
+
39
+ def __init__(self,
40
+ request_id: str,
41
+ check_interval: float = _PRECONDITION_CHECK_INTERVAL,
42
+ timeout: float = _PRECONDITION_TIMEOUT):
43
+ self.request_id = request_id
44
+ self.check_interval = check_interval
45
+ self.timeout = timeout
46
+
47
+ def __await__(self):
48
+ """Make Precondition awaitable."""
49
+ return self._wait().__await__()
50
+
51
+ def wait_async(
52
+ self,
53
+ on_condition_met: Optional[Callable[[], None]] = None) -> None:
54
+ """Wait for the precondition asynchronously and execute the callback when it is met."""
55
+
56
+ async def wait_with_callback():
57
+ met = await self
58
+ if met and on_condition_met is not None:
59
+ on_condition_met()
60
+
61
+ event_loop.run(wait_with_callback())
62
+
63
+ @abc.abstractmethod
64
+ async def check(self) -> Tuple[bool, Optional[str]]:
65
+ """Check if the precondition is met.
66
+
67
+ Note that compared to _request_execution_wrapper, the env vars and
68
+ skypilot config here are not overridden since the lack of process
69
+ isolation, which may cause issues if the check accidentally depends on
70
+ these. Make sure the check function is independent of the request
71
+ environment.
72
+ TODO(aylei): a new request context isolation mechanism is needed to
73
+ enable more tasks/sub-tasks to be processed in coroutines or threads.
74
+
75
+ Returns:
76
+ A tuple of (bool, Optional[str]).
77
+ The bool indicates if the precondition is met.
78
+ The str is the current status of the precondition if any.
79
+ """
80
+ raise NotImplementedError
81
+
82
+ async def _wait(self) -> bool:
83
+ """Wait for the precondition to be met.
84
+
85
+ Returns:
86
+ True if the precondition was met; False on timeout, cancellation or error.
87
+ """
88
+ start_time = time.time()
89
+ last_status_msg = ''
90
+ while True:
91
+ if self.timeout > 0 and time.time() - start_time > self.timeout:
92
+ # Cancel the request on timeout.
93
+ api_requests.set_request_failed(
94
+ self.request_id,
95
+ exceptions.RequestCancelled(
96
+ f'Request {self.request_id} precondition wait timed '
97
+ f'out after {self.timeout}s'))
98
+ return False
99
+
100
+ # Check if the request has been cancelled
101
+ request = api_requests.get_request(self.request_id)
102
+ if request is None:
103
+ logger.error(f'Request {self.request_id} not found')
104
+ return False
105
+ if request.status == api_requests.RequestStatus.CANCELLED:
106
+ logger.debug(f'Request {self.request_id} cancelled')
107
+ return False
108
+
109
+ try:
110
+ met, status_msg = await self.check()
111
+ if met:
112
+ return True
113
+ if status_msg is not None and status_msg != last_status_msg:
114
+ # Update the status message if it has changed.
115
+ with api_requests.update_request(self.request_id) as req:
116
+ assert req is not None, self.request_id
117
+ req.status_msg = status_msg
118
+ last_status_msg = status_msg
119
+ except (Exception, SystemExit, KeyboardInterrupt) as e: # pylint: disable=broad-except
120
+ api_requests.set_request_failed(self.request_id, e)
121
+ logger.info(f'Request {self.request_id} failed due to '
122
+ f'{common_utils.format_exception(e)}')
123
+ return False
124
+
125
+ await asyncio.sleep(self.check_interval)
126
+
127
+
128
+ class ClusterStartCompletePrecondition(Precondition):
129
+ """Whether the start process of a cluster is complete.
130
+
131
+ This condition only waits for the start process of a cluster to complete, e.g.
132
+ `sky launch` or `sky start`.
133
+ For cluster that has been started but not in UP status, bypass the waiting
134
+ in favor of:
135
+ - allowing the task to refresh cluster status from cloud vendor;
136
+ - unified error message in task handlers.
137
+
138
+ Args:
139
+ request_id: The request ID of the task.
140
+ cluster_name: The name of the cluster to wait for.
141
+ """
142
+
143
+ def __init__(self, request_id: str, cluster_name: str, **kwargs):
144
+ super().__init__(request_id=request_id, **kwargs)
145
+ self.cluster_name = cluster_name
146
+
147
+ async def check(self) -> Tuple[bool, Optional[str]]:
148
+ cluster_record = global_user_state.get_cluster_from_name(
149
+ self.cluster_name)
150
+ if (cluster_record and
151
+ cluster_record['status'] is status_lib.ClusterStatus.UP):
152
+ # Shortcut for started clusters, ignore cluster not found
153
+ # since the cluster record might not yet be created by the
154
+ # launch task.
155
+ return True, None
156
+ # Check if there is a task starting the cluster, we do not check
157
+ # SUCCEEDED requests since successfully launched cluster can be
158
+ # restarted later on.
159
+ # Note that since the requests are not persistent yet between restarts,
160
+ # a cluster might be started in halfway and requests are lost.
161
+ # We unify these situations into a single state: the process of starting
162
+ # the cluster is done (either normally or abnormally) but cluster is not
163
+ # in UP status.
164
+ requests = api_requests.get_request_tasks(
165
+ status=[
166
+ api_requests.RequestStatus.RUNNING,
167
+ api_requests.RequestStatus.PENDING
168
+ ],
169
+ include_request_names=['sky.launch', 'sky.start'],
170
+ cluster_names=[self.cluster_name])
171
+ if len(requests) == 0:
172
+ # No running or pending tasks, the start process is done.
173
+ return True, None
174
+ return False, f'Waiting for cluster {self.cluster_name} to be UP.'
@@ -10,6 +10,7 @@ import shutil
10
10
  import signal
11
11
  import sqlite3
12
12
  import time
13
+ import traceback
13
14
  from typing import Any, Callable, Dict, List, Optional, Tuple
14
15
 
15
16
  import colorama
@@ -27,6 +28,7 @@ from sky.utils import common
27
28
  from sky.utils import common_utils
28
29
  from sky.utils import db_utils
29
30
  from sky.utils import env_options
31
+ from sky.utils import ux_utils
30
32
 
31
33
  logger = sky_logging.init_logger(__name__)
32
34
 
@@ -34,6 +36,7 @@ logger = sky_logging.init_logger(__name__)
34
36
  REQUEST_TABLE = 'requests'
35
37
  COL_CLUSTER_NAME = 'cluster_name'
36
38
  COL_USER_ID = 'user_id'
39
+ COL_STATUS_MSG = 'status_msg'
37
40
  REQUEST_LOG_PATH_PREFIX = '~/sky_logs/api_server/requests'
38
41
 
39
42
  # TODO(zhwu): For scalability, there are several TODOs:
@@ -81,6 +84,7 @@ REQUEST_COLUMNS = [
81
84
  COL_CLUSTER_NAME,
82
85
  'schedule_type',
83
86
  COL_USER_ID,
87
+ COL_STATUS_MSG,
84
88
  ]
85
89
 
86
90
 
@@ -109,6 +113,7 @@ class RequestPayload:
109
113
  user_name: Optional[str] = None
110
114
  # Resources the request operates on.
111
115
  cluster_name: Optional[str] = None
116
+ status_msg: Optional[str] = None
112
117
 
113
118
 
114
119
  @dataclasses.dataclass
@@ -129,6 +134,8 @@ class Request:
129
134
  schedule_type: ScheduleType = ScheduleType.LONG
130
135
  # Resources the request operates on.
131
136
  cluster_name: Optional[str] = None
137
+ # Status message of the request, indicates the reason of current status.
138
+ status_msg: Optional[str] = None
132
139
 
133
140
  @property
134
141
  def log_path(self) -> pathlib.Path:
@@ -138,7 +145,7 @@ class Request:
138
145
  log_path = (log_path_prefix / self.request_id).with_suffix('.log')
139
146
  return log_path
140
147
 
141
- def set_error(self, error: Exception) -> None:
148
+ def set_error(self, error: BaseException) -> None:
142
149
  """Set the error."""
143
150
  # TODO(zhwu): pickle.dump does not work well with custom exceptions if
144
151
  # it has more than 1 arguments.
@@ -212,6 +219,7 @@ class Request:
212
219
  user_id=self.user_id,
213
220
  user_name=user_name,
214
221
  cluster_name=self.cluster_name,
222
+ status_msg=self.status_msg,
215
223
  )
216
224
 
217
225
  def encode(self) -> RequestPayload:
@@ -232,6 +240,7 @@ class Request:
232
240
  schedule_type=self.schedule_type.value,
233
241
  user_id=self.user_id,
234
242
  cluster_name=self.cluster_name,
243
+ status_msg=self.status_msg,
235
244
  )
236
245
  except (TypeError, ValueError) as e:
237
246
  # The error is unexpected, so we don't suppress the stack trace.
@@ -262,6 +271,7 @@ class Request:
262
271
  schedule_type=ScheduleType(payload.schedule_type),
263
272
  user_id=payload.user_id,
264
273
  cluster_name=payload.cluster_name,
274
+ status_msg=payload.status_msg,
265
275
  )
266
276
  except (TypeError, ValueError) as e:
267
277
  logger.error(
@@ -415,7 +425,8 @@ def create_table(cursor, conn):
415
425
  pid INTEGER,
416
426
  {COL_CLUSTER_NAME} TEXT,
417
427
  schedule_type TEXT,
418
- {COL_USER_ID} TEXT)""")
428
+ {COL_USER_ID} TEXT,
429
+ {COL_STATUS_MSG} TEXT)""")
419
430
 
420
431
 
421
432
  _DB = None
@@ -436,7 +447,7 @@ def init_db(func):
436
447
 
437
448
  def reset_db_and_logs():
438
449
  """Create the database."""
439
- common_utils.remove_file_if_exists(_DB_PATH)
450
+ server_common.clear_local_api_server_database()
440
451
  shutil.rmtree(pathlib.Path(REQUEST_LOG_PATH_PREFIX).expanduser(),
441
452
  ignore_errors=True)
442
453
  shutil.rmtree(server_common.API_SERVER_CLIENT_DIR.expanduser(),
@@ -507,8 +518,9 @@ def create_if_not_exists(request: Request) -> bool:
507
518
  def get_request_tasks(
508
519
  status: Optional[List[RequestStatus]] = None,
509
520
  cluster_names: Optional[List[str]] = None,
510
- exclude_request_names: Optional[List[str]] = None,
511
521
  user_id: Optional[str] = None,
522
+ exclude_request_names: Optional[List[str]] = None,
523
+ include_request_names: Optional[List[str]] = None,
512
524
  ) -> List[Request]:
513
525
  """Get a list of requests that match the given filters.
514
526
 
@@ -516,9 +528,21 @@ def get_request_tasks(
516
528
  status: a list of statuses of the requests to filter on.
517
529
  cluster_names: a list of cluster names to filter requests on.
518
530
  exclude_request_names: a list of request names to exclude from results.
531
+ Mutually exclusive with include_request_names.
519
532
  user_id: the user ID to filter requests on.
520
533
  If None, all users are included.
534
+ include_request_names: a list of request names to filter on.
535
+ Mutually exclusive with exclude_request_names.
536
+
537
+ Raises:
538
+ ValueError: If both exclude_request_names and include_request_names are
539
+ provided.
521
540
  """
541
+ if exclude_request_names is not None and include_request_names is not None:
542
+ raise ValueError(
543
+ 'Only one of exclude_request_names or include_request_names can be '
544
+ 'provided, not both.')
545
+
522
546
  filters = []
523
547
  filter_params = []
524
548
  if status is not None:
@@ -534,6 +558,10 @@ def get_request_tasks(
534
558
  if user_id is not None:
535
559
  filters.append(f'{COL_USER_ID} = ?')
536
560
  filter_params.append(user_id)
561
+ if include_request_names is not None:
562
+ request_names_str = ','.join(
563
+ repr(name) for name in include_request_names)
564
+ filters.append(f'name IN ({request_names_str})')
537
565
  assert _DB is not None
538
566
  with _DB.conn:
539
567
  cursor = _DB.conn.cursor()
@@ -565,3 +593,14 @@ def _add_or_update_request_no_lock(request: Request):
565
593
  cursor.execute(
566
594
  f'INSERT OR REPLACE INTO {REQUEST_TABLE} ({key_str}) '
567
595
  f'VALUES ({fill_str})', row)
596
+
597
+
598
+ def set_request_failed(request_id: str, e: BaseException) -> None:
599
+ """Set a request to failed and populate the error message."""
600
+ with ux_utils.enable_traceback():
601
+ stacktrace = traceback.format_exc()
602
+ setattr(e, 'stacktrace', stacktrace)
603
+ with update_request(request_id) as request_task:
604
+ assert request_task is not None, request_id
605
+ request_task.status = RequestStatus.FAILED
606
+ request_task.set_error(e)
sky/server/server.py CHANGED
@@ -6,6 +6,7 @@ import contextlib
6
6
  import dataclasses
7
7
  import datetime
8
8
  import logging
9
+ import multiprocessing
9
10
  import os
10
11
  import pathlib
11
12
  import re
@@ -38,6 +39,7 @@ from sky.server import constants as server_constants
38
39
  from sky.server import stream_utils
39
40
  from sky.server.requests import executor
40
41
  from sky.server.requests import payloads
42
+ from sky.server.requests import preconditions
41
43
  from sky.server.requests import requests as requests_lib
42
44
  from sky.skylet import constants
43
45
  from sky.usage import usage_lib
@@ -47,6 +49,7 @@ from sky.utils import common_utils
47
49
  from sky.utils import dag_utils
48
50
  from sky.utils import env_options
49
51
  from sky.utils import status_lib
52
+ from sky.utils import subprocess_utils
50
53
 
51
54
  # pylint: disable=ungrouped-imports
52
55
  if sys.version_info >= (3, 10):
@@ -496,13 +499,18 @@ async def launch(launch_body: payloads.LaunchBody,
496
499
  # pylint: disable=redefined-builtin
497
500
  async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
498
501
  """Executes a task on an existing cluster."""
502
+ cluster_name = exec_body.cluster_name
499
503
  executor.schedule_request(
500
504
  request_id=request.state.request_id,
501
505
  request_name='exec',
502
506
  request_body=exec_body,
503
507
  func=execution.exec,
508
+ precondition=preconditions.ClusterStartCompletePrecondition(
509
+ request_id=request.state.request_id,
510
+ cluster_name=cluster_name,
511
+ ),
504
512
  schedule_type=requests_lib.ScheduleType.LONG,
505
- request_cluster_name=exec_body.cluster_name,
513
+ request_cluster_name=cluster_name,
506
514
  )
507
515
 
508
516
 
@@ -1088,6 +1096,9 @@ async def complete_storage_name(incomplete: str,) -> List[str]:
1088
1096
 
1089
1097
  if __name__ == '__main__':
1090
1098
  import uvicorn
1099
+
1100
+ from sky.server import uvicorn as skyuvicorn
1101
+
1091
1102
  requests_lib.reset_db_and_logs()
1092
1103
 
1093
1104
  parser = argparse.ArgumentParser()
@@ -1109,16 +1120,26 @@ if __name__ == '__main__':
1109
1120
  logger.info(f'Starting SkyPilot API server, workers={num_workers}')
1110
1121
  # We don't support reload for now, since it may cause leakage of request
1111
1122
  # workers or interrupt running requests.
1112
- uvicorn.run('sky.server.server:app',
1113
- host=cmd_args.host,
1114
- port=cmd_args.port,
1115
- workers=num_workers)
1123
+ config = uvicorn.Config('sky.server.server:app',
1124
+ host=cmd_args.host,
1125
+ port=cmd_args.port,
1126
+ workers=num_workers)
1127
+ skyuvicorn.run(config)
1116
1128
  except Exception as exc: # pylint: disable=broad-except
1117
1129
  logger.error(f'Failed to start SkyPilot API server: '
1118
1130
  f'{common_utils.format_exception(exc, use_bracket=True)}')
1119
1131
  raise
1120
1132
  finally:
1121
1133
  logger.info('Shutting down SkyPilot API server...')
1122
- for sub_proc in sub_procs:
1123
- sub_proc.terminate()
1124
- sub_proc.join()
1134
+
1135
+ def cleanup(proc: multiprocessing.Process) -> None:
1136
+ try:
1137
+ proc.terminate()
1138
+ proc.join()
1139
+ finally:
1140
+ # The process may not be started yet, close it anyway.
1141
+ proc.close()
1142
+
1143
+ subprocess_utils.run_in_parallel(cleanup,
1144
+ sub_procs,
1145
+ num_threads=len(sub_procs))
@@ -55,19 +55,22 @@ async def log_streamer(request_id: Optional[str],
55
55
  if show_request_waiting_spinner:
56
56
  yield status_msg.init()
57
57
  yield status_msg.start()
58
- is_waiting_msg_logged = False
58
+ last_waiting_msg = ''
59
59
  waiting_msg = (f'Waiting for {request_task.name!r} request to be '
60
60
  f'scheduled: {request_id}')
61
61
  while request_task.status < requests_lib.RequestStatus.RUNNING:
62
+ if request_task.status_msg is not None:
63
+ waiting_msg = request_task.status_msg
62
64
  if show_request_waiting_spinner:
63
65
  yield status_msg.update(f'[dim]{waiting_msg}[/dim]')
64
- elif plain_logs and not is_waiting_msg_logged:
65
- is_waiting_msg_logged = True
66
+ elif plain_logs and waiting_msg != last_waiting_msg:
67
+ # Only log when waiting message changes.
68
+ last_waiting_msg = waiting_msg
66
69
  # Use smaller padding (1024 bytes) to force browser rendering
67
70
  yield f'{waiting_msg}' + ' ' * 4096 + '\n'
68
- # Sleep 0 to yield, so other coroutines can run. This busy waiting
69
- # loop is performance critical for short-running requests, so we do
70
- # not want to yield too long.
71
+ # Sleep shortly to avoid storming the DB and CPU and allow other
72
+ # coroutines to run. This busy waiting loop is performance critical
73
+ # for short-running requests, so we do not want to yield too long.
71
74
  await asyncio.sleep(0.1)
72
75
  request_task = requests_lib.get_request(request_id)
73
76
  if not follow:
sky/server/uvicorn.py ADDED
@@ -0,0 +1,81 @@
1
+ """Uvicorn wrapper for SkyPilot API server.
2
+
3
+ This module is a wrapper around uvicorn to customize the behavior of the
4
+ server.
5
+ """
6
+ import os
7
+ import threading
8
+ from typing import Optional
9
+
10
+ import uvicorn
11
+ from uvicorn.supervisors import multiprocess
12
+
13
+ from sky.utils import subprocess_utils
14
+
15
+
16
+ def run(config: uvicorn.Config):
17
+ """Run uvicorn server."""
18
+ if config.reload:
19
+ # Reload and multi-workers are mutually exclusive
20
+ # in uvicorn. Since we do not use reload now, simply
21
+ # guard by an exception.
22
+ raise ValueError('Reload is not supported yet.')
23
+ server = uvicorn.Server(config=config)
24
+ try:
25
+ if config.workers is not None and config.workers > 1:
26
+ sock = config.bind_socket()
27
+ SlowStartMultiprocess(config, target=server.run,
28
+ sockets=[sock]).run()
29
+ else:
30
+ server.run()
31
+ finally:
32
+ # Copied from uvicorn.run()
33
+ if config.uds and os.path.exists(config.uds):
34
+ os.remove(config.uds)
35
+
36
+
37
+ class SlowStartMultiprocess(multiprocess.Multiprocess):
38
+ """Uvicorn Multiprocess wrapper with slow start.
39
+
40
+ Slow start offers faster and more stable start time.
41
+ Profile shows the start time is more stable and accelerated from
42
+ ~7s to ~3.3s on a 12-core machine after switching LONG workers and
43
+ Uvicorn workers to slow start.
44
+ Refer to subprocess_utils.slow_start_processes() for more details.
45
+ """
46
+
47
+ def __init__(self, config: uvicorn.Config, **kwargs):
48
+ """Initialize the multiprocess wrapper.
49
+
50
+ Args:
51
+ config: The uvicorn config.
52
+ """
53
+ super().__init__(config, **kwargs)
54
+ self._init_thread: Optional[threading.Thread] = None
55
+
56
+ def init_processes(self) -> None:
57
+ # Slow start worker processes asynchronously to avoid blocking signal
58
+ # handling of uvicorn.
59
+ self._init_thread = threading.Thread(target=self.slow_start_processes,
60
+ daemon=True)
61
+ self._init_thread.start()
62
+
63
+ def slow_start_processes(self) -> None:
64
+ """Initialize processes with slow start."""
65
+ to_start = []
66
+ # Init N worker processes
67
+ for _ in range(self.processes_num):
68
+ to_start.append(
69
+ multiprocess.Process(self.config, self.target, self.sockets))
70
+ # Start the processes with slow start, we only append start to
71
+ # self.processes because Uvicorn periodically restarts unstarted
72
+ # workers.
73
+ subprocess_utils.slow_start_processes(to_start,
74
+ on_start=self.processes.append,
75
+ should_exit=self.should_exit)
76
+
77
+ def terminate_all(self) -> None:
78
+ """Wait init thread to finish before terminating all processes."""
79
+ if self._init_thread is not None:
80
+ self._init_thread.join()
81
+ super().terminate_all()
@@ -120,7 +120,10 @@ extras_require: Dict[str, List[str]] = {
120
120
  # https://github.com/googleapis/google-api-python-client/commit/f6e9d3869ed605b06f7cbf2e8cf2db25108506e6
121
121
  'gcp': ['google-api-python-client>=2.69.0', 'google-cloud-storage'],
122
122
  'ibm': [
123
- 'ibm-cloud-sdk-core', 'ibm-vpc', 'ibm-platform-services', 'ibm-cos-sdk'
123
+ 'ibm-cloud-sdk-core',
124
+ 'ibm-vpc',
125
+ 'ibm-platform-services>=0.48.0',
126
+ 'ibm-cos-sdk',
124
127
  ] + local_ray,
125
128
  'docker': ['docker'] + local_ray,
126
129
  'lambda': [], # No dependencies needed for lambda
@@ -77,7 +77,7 @@ def canonicalize_accelerator_name(accelerator: str,
77
77
  # Look for Kubernetes accelerators online if the accelerator is not found
78
78
  # in the public cloud catalog. This is to make sure custom accelerators
79
79
  # on Kubernetes can be correctly canonicalized.
80
- if not names and cloud_str in ['kubernetes', None]:
80
+ if not names and cloud_str in ['Kubernetes', None]:
81
81
  with rich_utils.safe_status(
82
82
  ux_utils.spinner_message('Listing accelerators on Kubernetes')):
83
83
  searched = service_catalog.list_accelerators(