skypilot-nightly 1.0.0.dev20250909__py3-none-any.whl → 1.0.0.dev20250910__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (67)
  1. sky/__init__.py +2 -2
  2. sky/authentication.py +19 -4
  3. sky/backends/backend_utils.py +35 -1
  4. sky/backends/cloud_vm_ray_backend.py +2 -2
  5. sky/client/sdk.py +20 -0
  6. sky/client/sdk_async.py +18 -16
  7. sky/clouds/aws.py +3 -1
  8. sky/dashboard/out/404.html +1 -1
  9. sky/dashboard/out/_next/static/chunks/{webpack-d4fabc08788e14af.js → webpack-1d7e11230da3ca89.js} +1 -1
  10. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  11. sky/dashboard/out/clusters/[cluster].html +1 -1
  12. sky/dashboard/out/clusters.html +1 -1
  13. sky/dashboard/out/config.html +1 -1
  14. sky/dashboard/out/index.html +1 -1
  15. sky/dashboard/out/infra/[context].html +1 -1
  16. sky/dashboard/out/infra.html +1 -1
  17. sky/dashboard/out/jobs/[job].html +1 -1
  18. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  19. sky/dashboard/out/jobs.html +1 -1
  20. sky/dashboard/out/users.html +1 -1
  21. sky/dashboard/out/volumes.html +1 -1
  22. sky/dashboard/out/workspace/new.html +1 -1
  23. sky/dashboard/out/workspaces/[name].html +1 -1
  24. sky/dashboard/out/workspaces.html +1 -1
  25. sky/data/storage.py +5 -1
  26. sky/execution.py +21 -14
  27. sky/jobs/constants.py +3 -0
  28. sky/jobs/controller.py +732 -310
  29. sky/jobs/recovery_strategy.py +251 -129
  30. sky/jobs/scheduler.py +247 -174
  31. sky/jobs/server/core.py +20 -4
  32. sky/jobs/server/utils.py +2 -2
  33. sky/jobs/state.py +702 -511
  34. sky/jobs/utils.py +94 -39
  35. sky/provision/aws/config.py +4 -1
  36. sky/provision/gcp/config.py +6 -1
  37. sky/provision/kubernetes/utils.py +17 -8
  38. sky/provision/provisioner.py +1 -0
  39. sky/serve/replica_managers.py +0 -7
  40. sky/serve/serve_utils.py +5 -0
  41. sky/serve/server/impl.py +1 -2
  42. sky/serve/service.py +0 -2
  43. sky/server/common.py +8 -3
  44. sky/server/config.py +43 -24
  45. sky/server/constants.py +1 -0
  46. sky/server/daemons.py +7 -11
  47. sky/server/requests/serializers/encoders.py +1 -1
  48. sky/server/server.py +8 -1
  49. sky/setup_files/dependencies.py +4 -2
  50. sky/skylet/attempt_skylet.py +1 -0
  51. sky/skylet/constants.py +3 -1
  52. sky/skylet/events.py +2 -10
  53. sky/utils/command_runner.pyi +3 -3
  54. sky/utils/common_utils.py +11 -1
  55. sky/utils/controller_utils.py +5 -0
  56. sky/utils/db/db_utils.py +31 -2
  57. sky/utils/rich_utils.py +3 -1
  58. sky/utils/subprocess_utils.py +9 -0
  59. sky/volumes/volume.py +2 -0
  60. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/METADATA +39 -37
  61. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/RECORD +67 -67
  62. /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → 3SYxqNGnvvPS8h3gdD2T7}/_buildManifest.js +0 -0
  63. /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → 3SYxqNGnvvPS8h3gdD2T7}/_ssgManifest.js +0 -0
  64. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/WHEEL +0 -0
  65. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/entry_points.txt +0 -0
  66. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/licenses/LICENSE +0 -0
  67. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/top_level.txt +0 -0
sky/jobs/utils.py CHANGED
@@ -4,9 +4,11 @@ NOTE: whenever an API change is made in this file, we need to bump the
 jobs.constants.MANAGED_JOBS_VERSION and handle the API change in the
 ManagedJobCodeGen.
 """
+import asyncio
 import collections
 import datetime
 import enum
+import logging
 import os
 import pathlib
 import shlex
@@ -14,11 +16,11 @@ import textwrap
 import time
 import traceback
 import typing
-from typing import Any, Deque, Dict, List, Optional, Set, TextIO, Tuple, Union
+from typing import (Any, Deque, Dict, List, Literal, Optional, Set, TextIO,
+                    Tuple, Union)

 import colorama
 import filelock
-from typing_extensions import Literal

 from sky import backends
 from sky import exceptions
@@ -37,6 +39,7 @@ from sky.usage import usage_lib
 from sky.utils import annotations
 from sky.utils import command_runner
 from sky.utils import common_utils
+from sky.utils import context_utils
 from sky.utils import controller_utils
 from sky.utils import infra_utils
 from sky.utils import log_utils
@@ -56,9 +59,9 @@ else:

 logger = sky_logging.init_logger(__name__)

-SIGNAL_FILE_PREFIX = '/tmp/sky_jobs_controller_signal_{}'
 # Controller checks its job's status every this many seconds.
-JOB_STATUS_CHECK_GAP_SECONDS = 20
+# This is a tradeoff between the latency and the resource usage.
+JOB_STATUS_CHECK_GAP_SECONDS = 15

 # Controller checks if its job has started every this many seconds.
 JOB_STARTED_STATUS_CHECK_GAP_SECONDS = 5
@@ -82,7 +85,7 @@ _JOB_CANCELLED_MESSAGE = (
 # blocking for a long time. This should be significantly longer than the
 # JOB_STATUS_CHECK_GAP_SECONDS to avoid timing out before the controller can
 # update the state.
-_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 40
+_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 120


 class ManagedJobQueueResultType(enum.Enum):
@@ -99,7 +102,11 @@ class UserSignal(enum.Enum):


 # ====== internal functions ======
-def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
+def terminate_cluster(
+    cluster_name: str,
+    max_retry: int = 6,
+    _logger: logging.Logger = logger,  # pylint: disable=invalid-name
+) -> None:
     """Terminate the cluster."""
     from sky import core  # pylint: disable=import-outside-toplevel
     retry_cnt = 0
@@ -122,18 +129,18 @@ def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
             return
         except exceptions.ClusterDoesNotExist:
             # The cluster is already down.
-            logger.debug(f'The cluster {cluster_name} is already down.')
+            _logger.debug(f'The cluster {cluster_name} is already down.')
             return
         except Exception as e:  # pylint: disable=broad-except
             retry_cnt += 1
             if retry_cnt >= max_retry:
                 raise RuntimeError(
                     f'Failed to terminate the cluster {cluster_name}.') from e
-            logger.error(
+            _logger.error(
                 f'Failed to terminate the cluster {cluster_name}. Retrying.'
                 f'Details: {common_utils.format_exception(e)}')
             with ux_utils.enable_traceback():
-                logger.error(f'  Traceback: {traceback.format_exc()}')
+                _logger.error(f'  Traceback: {traceback.format_exc()}')
             time.sleep(backoff.current_backoff())


@@ -183,6 +190,9 @@ def _validate_consolidation_mode_config(
 # Use LRU Cache so that the check is only done once.
 @annotations.lru_cache(scope='request', maxsize=1)
 def is_consolidation_mode() -> bool:
+    if os.environ.get(constants.OVERRIDE_CONSOLIDATION_MODE) is not None:
+        return True
+
     consolidation_mode = skypilot_config.get_nested(
         ('jobs', 'controller', 'consolidation_mode'), default_value=False)
     # We should only do this check on API server, as the controller will not
@@ -199,6 +209,7 @@ def ha_recovery_for_consolidation_mode():
     # already has all runtime installed. Directly start jobs recovery here.
    # Refers to sky/templates/kubernetes-ray.yml.j2 for more details.
     runner = command_runner.LocalProcessCommandRunner()
+    scheduler.maybe_start_controllers()
     with open(constants.HA_PERSISTENT_RECOVERY_LOG_PATH.format('jobs_'),
               'w',
               encoding='utf-8') as f:
@@ -214,7 +225,7 @@ def ha_recovery_for_consolidation_mode():
             # just keep running.
             if controller_pid is not None:
                 try:
-                    if _controller_process_alive(controller_pid, job_id):
+                    if controller_process_alive(controller_pid, job_id):
                         f.write(f'Controller pid {controller_pid} for '
                                 f'job {job_id} is still running. '
                                 'Skipping recovery.\n')
@@ -227,7 +238,7 @@ def ha_recovery_for_consolidation_mode():

             if job['schedule_state'] not in [
                     managed_job_state.ManagedJobScheduleState.DONE,
-                    managed_job_state.ManagedJobScheduleState.WAITING
+                    managed_job_state.ManagedJobScheduleState.WAITING,
             ]:
                 script = managed_job_state.get_ha_recovery_script(job_id)
                 if script is None:
@@ -242,56 +253,66 @@ def ha_recovery_for_consolidation_mode():
         f.write(f'Total recovery time: {time.time() - start} seconds\n')


-def get_job_status(backend: 'backends.CloudVmRayBackend', cluster_name: str,
-                   job_id: Optional[int]) -> Optional['job_lib.JobStatus']:
+async def get_job_status(
+        backend: 'backends.CloudVmRayBackend', cluster_name: str,
+        job_id: Optional[int],
+        job_logger: logging.Logger) -> Optional['job_lib.JobStatus']:
     """Check the status of the job running on a managed job cluster.

     It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER,
     FAILED_SETUP or CANCELLED.
     """
-    handle = global_user_state.get_handle_from_cluster_name(cluster_name)
+    # TODO(luca) make this async
+    handle = await context_utils.to_thread(
+        global_user_state.get_handle_from_cluster_name, cluster_name)
     if handle is None:
         # This can happen if the cluster was preempted and background status
         # refresh already noticed and cleaned it up.
-        logger.info(f'Cluster {cluster_name} not found.')
+        job_logger.info(f'Cluster {cluster_name} not found.')
         return None
     assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
     job_ids = None if job_id is None else [job_id]
     for i in range(_JOB_STATUS_FETCH_MAX_RETRIES):
         try:
-            logger.info('=== Checking the job status... ===')
-            statuses = backend.get_job_status(handle,
-                                              job_ids=job_ids,
-                                              stream_logs=False)
+            job_logger.info('=== Checking the job status... ===')
+            statuses = await context_utils.to_thread(backend.get_job_status,
+                                                     handle,
+                                                     job_ids=job_ids,
+                                                     stream_logs=False)
             status = list(statuses.values())[0]
             if status is None:
-                logger.info('No job found.')
+                job_logger.info('No job found.')
             else:
-                logger.info(f'Job status: {status}')
-                logger.info('=' * 34)
+                job_logger.info(f'Job status: {status}')
+                job_logger.info('=' * 34)
                 return status
         except exceptions.CommandError as e:
             # Retry on k8s transient network errors. This is useful when using
             # coreweave which may have transient network issue sometimes.
             if (e.detailed_reason is not None and
                     _JOB_K8S_TRANSIENT_NW_MSG in e.detailed_reason):
-                logger.info('Failed to connect to the cluster. Retrying '
-                            f'({i + 1}/{_JOB_STATUS_FETCH_MAX_RETRIES})...')
-                logger.info('=' * 34)
-                time.sleep(1)
+                job_logger.info('Failed to connect to the cluster. Retrying '
+                                f'({i + 1}/{_JOB_STATUS_FETCH_MAX_RETRIES})...')
+                job_logger.info('=' * 34)
+                await asyncio.sleep(1)
             else:
-                logger.info(f'Failed to get job status: {e.detailed_reason}')
-                logger.info('=' * 34)
+                job_logger.info(
+                    f'Failed to get job status: {e.detailed_reason}')
+                job_logger.info('=' * 34)
                 return None
     return None


-def _controller_process_alive(pid: int, job_id: int) -> bool:
+def controller_process_alive(pid: int, job_id: int) -> bool:
     """Check if the controller process is alive."""
     try:
+        if pid < 0:
+            # new job controller process will always be negative
+            pid = -pid
         process = psutil.Process(pid)
         cmd_str = ' '.join(process.cmdline())
-        return process.is_running() and f'--job-id {job_id}' in cmd_str
+        return process.is_running() and ((f'--job-id {job_id}' in cmd_str) or
+                                          ('controller' in cmd_str))
     except psutil.NoSuchProcess:
         return False

@@ -466,7 +487,7 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
             failure_reason = f'No controller pid set for {schedule_state.value}'
         else:
             logger.debug(f'Checking controller pid {pid}')
-            if _controller_process_alive(pid, job_id):
+            if controller_process_alive(pid, job_id):
                 # The controller is still running, so this job is fine.
                 continue

@@ -604,7 +625,17 @@ def event_callback_func(job_id: int, task_id: int, task: 'sky.Task'):
                     f'Bash:{event_callback},log_path:{log_path},result:{result}')
         logger.info(f'=== END: event callback for {status!r} ===')

-    return callback_func
+    try:
+        asyncio.get_running_loop()
+
+        # In async context
+        async def async_callback_func(status: str):
+            return await context_utils.to_thread(callback_func, status)
+
+        return async_callback_func
+    except RuntimeError:
+        # Not in async context
+        return callback_func


 # ======== user functions ========
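Editor's note: the new return path above lets event_callback_func serve both synchronous and asyncio callers; asyncio.get_running_loop() raises RuntimeError when no loop is running, and that decides whether the plain callback or an async wrapper is handed back. Below is a self-contained sketch of the same dispatch pattern using only the standard library, where asyncio.to_thread (Python 3.9+) stands in for SkyPilot's context_utils.to_thread; it is an illustration, not the SkyPilot implementation.

    import asyncio


    def make_callback(callback):
        """Return `callback` as-is in sync code, or an async wrapper inside a loop."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop: the caller is synchronous.
            return callback

        async def async_callback(*args, **kwargs):
            # Run the sync callback in a worker thread so it never blocks the loop.
            return await asyncio.to_thread(callback, *args, **kwargs)

        return async_callback


    async def _demo():
        cb = make_callback(lambda status: f'got {status}')
        print(await cb('RUNNING'))  # wrapped: must be awaited


    print(make_callback(lambda status: f'got {status}')('RUNNING'))  # sync path
    asyncio.run(_demo())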
@@ -651,16 +682,41 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]],
             logger.info(f'Job {job_id} is already in terminal state '
                         f'{job_status.value}. Skipped.')
             continue
+        elif job_status == managed_job_state.ManagedJobStatus.PENDING:
+            # the if is a short circuit, this will be atomic.
+            cancelled = managed_job_state.set_pending_cancelled(job_id)
+            if cancelled:
+                cancelled_job_ids.append(job_id)
+                continue

         update_managed_jobs_statuses(job_id)

+        job_controller_pid = managed_job_state.get_job_controller_pid(job_id)
+        if job_controller_pid is not None and job_controller_pid < 0:
+            # This is a consolidated job controller, so we need to cancel the
+            # with the controller server API
+            try:
+                # we create a file as a signal to the controller server
+                signal_file = pathlib.Path(
+                    managed_job_constants.CONSOLIDATED_SIGNAL_PATH, f'{job_id}')
+                signal_file.touch()
+                cancelled_job_ids.append(job_id)
+            except OSError as e:
+                logger.error(f'Failed to cancel job {job_id} '
+                             f'with controller server: {e}')
+                # don't add it to the to be cancelled job ids, since we don't
+                # know for sure yet.
+                continue
+            continue
+
         job_workspace = managed_job_state.get_workspace(job_id)
         if current_workspace is not None and job_workspace != current_workspace:
             wrong_workspace_job_ids.append(job_id)
             continue

         # Send the signal to the jobs controller.
-        signal_file = pathlib.Path(SIGNAL_FILE_PREFIX.format(job_id))
+        signal_file = (pathlib.Path(
+            managed_job_constants.SIGNAL_FILE_PREFIX.format(job_id)))
         # Filelock is needed to prevent race condition between signal
         # check/removal and signal writing.
         with filelock.FileLock(str(signal_file) + '.lock'):
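Editor's note: for consolidation-mode jobs (identified above by a negative controller pid), cancellation is now requested by touching a per-job file under a signal directory rather than by writing the per-job signal file used for cluster-backed controllers. A rough sketch of such a file-based handshake is shown below; the directory path and the polling side are illustrative assumptions, not the actual controller code.

    import pathlib
    import tempfile

    # Illustrative location; the real path comes from sky.jobs.constants.
    SIGNAL_DIR = pathlib.Path(tempfile.gettempdir()) / 'job_cancel_signals'


    def request_cancel(job_id: int) -> None:
        # API server side: touching the file is the whole "cancel" request.
        SIGNAL_DIR.mkdir(parents=True, exist_ok=True)
        (SIGNAL_DIR / str(job_id)).touch()


    def should_cancel(job_id: int) -> bool:
        # Controller side: poll for the file, consume it once seen.
        signal_file = SIGNAL_DIR / str(job_id)
        if signal_file.exists():
            signal_file.unlink()
            return True
        return False


    request_cancel(42)
    assert should_cancel(42)
    assert not should_cancel(42)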
@@ -1159,8 +1215,7 @@ def dump_managed_job_queue(
                 # It's possible for a WAITING/ALIVE_WAITING job to be ready to
                 # launch, but the scheduler just hasn't run yet.
                 managed_job_state.ManagedJobScheduleState.WAITING,
-                managed_job_state.ManagedJobScheduleState.ALIVE_WAITING,
-        ):
+                managed_job_state.ManagedJobScheduleState.ALIVE_WAITING):
             # This job will not block others.
             continue

@@ -1370,12 +1425,12 @@ def load_managed_job_queue(
     """Load job queue from json string."""
     result = message_utils.decode_payload(payload)
     result_type = ManagedJobQueueResultType.DICT
-    status_counts = {}
+    status_counts: Dict[str, int] = {}
     if isinstance(result, dict):
-        jobs = result['jobs']
-        total = result['total']
+        jobs: List[Dict[str, Any]] = result['jobs']
+        total: int = result['total']
         status_counts = result.get('status_counts', {})
-        total_no_filter = result.get('total_no_filter', total)
+        total_no_filter: int = result.get('total_no_filter', total)
     else:
         jobs = result
         total = len(jobs)
sky/provision/aws/config.py CHANGED
@@ -305,7 +305,10 @@ def _get_route_tables(ec2: 'mypy_boto3_ec2.ServiceResource',
     Returns:
         A list of route tables associated with the options VPC and region
     """
-    filters = [{'Name': 'association.main', 'Values': [str(main).lower()]}]
+    filters: List['ec2_type_defs.FilterTypeDef'] = [{
+        'Name': 'association.main',
+        'Values': [str(main).lower()],
+    }]
     if vpc_id is not None:
         filters.append({'Name': 'vpc-id', 'Values': [vpc_id]})
     logger.debug(
sky/provision/gcp/config.py CHANGED
@@ -5,6 +5,8 @@ import time
 import typing
 from typing import Any, Dict, List, Set, Tuple

+from typing_extensions import TypedDict
+
 from sky.adaptors import gcp
 from sky.clouds.utils import gcp_utils
 from sky.provision import common
@@ -415,6 +417,9 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam) -> dict:
     return iam_role


+AllowedList = TypedDict('AllowedList', {'IPProtocol': str, 'ports': List[str]})
+
+
 def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str,
                           compute):
     """Check if the firewall rules in the VPC are sufficient."""
@@ -466,7 +471,7 @@ def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str,
     }
     """
     source2rules: Dict[Tuple[str, str], Dict[str, Set[int]]] = {}
-    source2allowed_list: Dict[Tuple[str, str], List[Dict[str, str]]] = {}
+    source2allowed_list: Dict[Tuple[str, str], List[AllowedList]] = {}
     for rule in rules:
         # Rules applied to specific VM (targetTags) may not work for the
         # current VM, so should be skipped.
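Editor's note: AllowedList gives the firewall rules' 'allowed' entries a concrete shape for the type checker instead of the looser Dict[str, str]. The sketch below shows how a functional-syntax TypedDict like this is declared and consumed; the sample rule data and the covers_ssh helper are made up for illustration, and on Python 3.8+ typing.TypedDict works the same way.

    from typing import List

    from typing_extensions import TypedDict

    # Same functional-syntax declaration as the diff above.
    AllowedList = TypedDict('AllowedList', {'IPProtocol': str, 'ports': List[str]})


    def covers_ssh(allowed: List[AllowedList]) -> bool:
        # True if any entry permits TCP port 22.
        return any(
            entry['IPProtocol'] in ('tcp', 'all') and '22' in entry['ports']
            for entry in allowed)


    sample: List[AllowedList] = [{'IPProtocol': 'tcp', 'ports': ['22', '443']}]
    print(covers_ssh(sample))  # True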
sky/provision/kubernetes/utils.py CHANGED
@@ -451,6 +451,9 @@ class CoreWeaveLabelFormatter(GPULabelFormatter):

     LABEL_KEY = 'gpu.nvidia.com/class'

+    # TODO (kyuds): fill in more label values for different accelerators.
+    ACC_VALUE_MAPPINGS = {'H100_NVLINK_80GB': 'H100'}
+
     @classmethod
     def get_label_key(cls, accelerator: Optional[str] = None) -> str:
         return cls.LABEL_KEY
@@ -469,7 +472,8 @@ class CoreWeaveLabelFormatter(GPULabelFormatter):

     @classmethod
     def get_accelerator_from_label_value(cls, value: str) -> str:
-        return value
+        # return original label value if not found in mappings.
+        return cls.ACC_VALUE_MAPPINGS.get(value, value)


 class GKELabelFormatter(GPULabelFormatter):
@@ -1012,15 +1016,16 @@ class GKEAutoscaler(Autoscaler):
         to fit the instance type.
         """
         for accelerator in node_pool_accelerators:
+            raw_value = accelerator['acceleratorType']
             node_accelerator_type = (
-                GKELabelFormatter.get_accelerator_from_label_value(
-                    accelerator['acceleratorType']))
+                GKELabelFormatter.get_accelerator_from_label_value(raw_value))
             # handle heterogenous nodes.
             if not node_accelerator_type:
                 continue
             node_accelerator_count = accelerator['acceleratorCount']
-            if node_accelerator_type == requested_gpu_type and int(
-                    node_accelerator_count) >= requested_gpu_count:
+            viable_names = [node_accelerator_type.lower(), raw_value.lower()]
+            if (requested_gpu_type.lower() in viable_names and
+                    int(node_accelerator_count) >= requested_gpu_count):
                 return True
         return False

@@ -1448,9 +1453,13 @@ def get_accelerator_label_key_values(
             if is_multi_host_tpu(node_metadata_labels):
                 continue
             for label, value in label_list:
-                if (label_formatter.match_label_key(label) and
-                        label_formatter.get_accelerator_from_label_value(
-                            value).lower() == acc_type.lower()):
+                if label_formatter.match_label_key(label):
+                    # match either canonicalized name or raw name
+                    accelerator = (label_formatter.
+                                   get_accelerator_from_label_value(value))
+                    viable = [value.lower(), accelerator.lower()]
+                    if acc_type.lower() not in viable:
+                        continue
                     if is_tpu_on_gke(acc_type):
                         assert isinstance(label_formatter,
                                           GKELabelFormatter)
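Editor's note: both the CoreWeave formatter and the GKE autoscaler check above now accept a request phrased either as the canonicalized accelerator name or as the raw label value. A minimal stand-alone illustration of that comparison, with the mapping table reduced to the single entry shown in the diff:

    ACC_VALUE_MAPPINGS = {'H100_NVLINK_80GB': 'H100'}


    def canonicalize(value: str) -> str:
        # Fall back to the raw label value when there is no known mapping.
        return ACC_VALUE_MAPPINGS.get(value, value)


    def label_matches(requested: str, raw_label_value: str) -> bool:
        # Accept a request phrased either as the canonical name or the raw value.
        viable = {raw_label_value.lower(), canonicalize(raw_label_value).lower()}
        return requested.lower() in viable


    print(label_matches('H100', 'H100_NVLINK_80GB'))              # True
    print(label_matches('h100_nvlink_80gb', 'H100_NVLINK_80GB'))  # True
    print(label_matches('A100', 'H100_NVLINK_80GB'))              # False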
sky/provision/provisioner.py CHANGED
@@ -526,6 +526,7 @@ def _post_provision_setup(
             status.update(
                 ux_utils.spinner_message(
                     'Checking controller version compatibility'))
+
             try:
                 server_jobs_utils.check_version_mismatch_and_non_terminal_jobs()
             except exceptions.ClusterNotUpError:
sky/serve/replica_managers.py CHANGED
@@ -22,7 +22,6 @@ from sky import global_user_state
 from sky import sky_logging
 from sky import task as task_lib
 from sky.backends import backend_utils
-from sky.jobs import scheduler as jobs_scheduler
 from sky.serve import constants as serve_constants
 from sky.serve import serve_state
 from sky.serve import serve_utils
@@ -1052,7 +1051,6 @@ class SkyPilotReplicaManager(ReplicaManager):
                 self._service_name, replica_id)
             assert info is not None, replica_id
             error_in_sky_launch = False
-            schedule_next_jobs = False
             if info.status == serve_state.ReplicaStatus.PENDING:
                 # sky.launch not started yet
                 if controller_utils.can_provision():
@@ -1080,7 +1078,6 @@ class SkyPilotReplicaManager(ReplicaManager):
                 else:
                     info.status_property.sky_launch_status = (
                         common_utils.ProcessStatus.SUCCEEDED)
-                    schedule_next_jobs = True
                 if self._spot_placer is not None and info.is_spot:
                     # TODO(tian): Currently, we set the location to
                     # preemptive if the launch process failed. This is
@@ -1100,16 +1097,12 @@ class SkyPilotReplicaManager(ReplicaManager):
                     self._spot_placer.set_active(location)
             serve_state.add_or_update_replica(self._service_name,
                                               replica_id, info)
-            if schedule_next_jobs and self._is_pool:
-                jobs_scheduler.maybe_schedule_next_jobs()
             if error_in_sky_launch:
                 # Teardown after update replica info since
                 # _terminate_replica will update the replica info too.
                 self._terminate_replica(replica_id,
                                         sync_down_logs=True,
                                         replica_drain_delay_seconds=0)
-            # Try schedule next job after acquiring the lock.
-            jobs_scheduler.maybe_schedule_next_jobs()
         down_process_pool_snapshot = list(self._down_process_pool.items())
         for replica_id, p in down_process_pool_snapshot:
             if p.is_alive():
sky/serve/serve_utils.py CHANGED
@@ -294,6 +294,11 @@ def is_consolidation_mode(pool: bool = False) -> bool:
     # We should only do this check on API server, as the controller will not
     # have related config and will always seemingly disabled for consolidation
     # mode. Check #6611 for more details.
+    if (os.environ.get(skylet_constants.OVERRIDE_CONSOLIDATION_MODE) is not None
+            and controller.controller_type == 'jobs'):
+        # if we are in the job controller, we must always be in consolidation
+        # mode.
+        return True
     if os.environ.get(skylet_constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
         _validate_consolidation_mode_config(consolidation_mode, pool)
     return consolidation_mode
sky/serve/server/impl.py CHANGED
@@ -280,8 +280,7 @@ def up(
     ]
     run_script = '\n'.join(env_cmds + [run_script])
     # Dump script for high availability recovery.
-    if controller_utils.high_availability_specified(controller_name):
-        serve_state.set_ha_recovery_script(service_name, run_script)
+    serve_state.set_ha_recovery_script(service_name, run_script)
     backend.run_on_head(controller_handle, run_script)

     style = colorama.Style
sky/serve/service.py CHANGED
@@ -21,7 +21,6 @@ from sky import task as task_lib
 from sky.backends import backend_utils
 from sky.backends import cloud_vm_ray_backend
 from sky.data import data_utils
-from sky.jobs import scheduler as jobs_scheduler
 from sky.serve import constants
 from sky.serve import controller
 from sky.serve import load_balancer
@@ -278,7 +277,6 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int, entrypoint: str):
        pool=service_spec.pool,
        controller_pid=os.getpid(),
        entrypoint=entrypoint)
-    jobs_scheduler.maybe_schedule_next_jobs()
     # Directly throw an error here. See sky/serve/api.py::up
     # for more details.
     if not success:
sky/server/common.py CHANGED
@@ -538,12 +538,17 @@ def _start_api_server(deploy: bool = False,

     # Check available memory before starting the server.
     avail_mem_size_gb: float = common_utils.get_mem_size_gb()
-    if avail_mem_size_gb <= server_constants.MIN_AVAIL_MEM_GB:
+    # pylint: disable=import-outside-toplevel
+    import sky.jobs.utils as job_utils
+    max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
+                  if job_utils.is_consolidation_mode() else
+                  server_constants.MIN_AVAIL_MEM_GB)
+    if avail_mem_size_gb <= max_memory:
         logger.warning(
             f'{colorama.Fore.YELLOW}Your SkyPilot API server machine only '
             f'has {avail_mem_size_gb:.1f}GB memory available. '
-            f'At least {server_constants.MIN_AVAIL_MEM_GB}GB is '
-            'recommended to support higher load with better performance.'
+            f'At least {max_memory}GB is recommended to support higher '
+            'load with better performance.'
             f'{colorama.Style.RESET_ALL}')

     args = [sys.executable, *API_SERVER_CMD.split()]
sky/server/config.py CHANGED
@@ -19,8 +19,9 @@ from sky.utils import common_utils
 # TODO(aylei): maintaining these constants is error-prone, we may need to
 # automatically tune parallelism at runtime according to system usage stats
 # in the future.
-_LONG_WORKER_MEM_GB = 0.4
-_SHORT_WORKER_MEM_GB = 0.25
+# TODO(luca): The future is now! ^^^
+LONG_WORKER_MEM_GB = 0.4
+SHORT_WORKER_MEM_GB = 0.25
 # To control the number of long workers.
 _CPU_MULTIPLIER_FOR_LONG_WORKERS = 2
 # Limit the number of long workers of local API server, since local server is
@@ -75,8 +76,8 @@ class ServerConfig:


 def compute_server_config(deploy: bool,
-                          max_db_connections: Optional[int] = None
-                          ) -> ServerConfig:
+                          max_db_connections: Optional[int] = None,
+                          quiet: bool = False) -> ServerConfig:
     """Compute the server config based on environment.

     We have different assumptions for the resources in different deployment
@@ -140,7 +141,12 @@ def compute_server_config(deploy: bool,
         burstable_parallel_for_short = _BURSTABLE_WORKERS_FOR_LOCAL
     # Runs in low resource mode if the available memory is less than
     # server_constants.MIN_AVAIL_MEM_GB.
-    if not deploy and mem_size_gb < server_constants.MIN_AVAIL_MEM_GB:
+    # pylint: disable=import-outside-toplevel
+    import sky.jobs.utils as job_utils
+    max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
+                  if job_utils.is_consolidation_mode() else
+                  server_constants.MIN_AVAIL_MEM_GB)
+    if not deploy and mem_size_gb < max_memory:
         # Permanent worker process may have significant memory consumption
         # (~350MB per worker) after running commands like `sky check`, so we
         # don't start any permanent workers in low resource local mode. This
@@ -151,25 +157,29 @@ def compute_server_config(deploy: bool,
         # permanently because it never exits.
         max_parallel_for_long = 0
         max_parallel_for_short = 0
-        logger.warning(
-            'SkyPilot API server will run in low resource mode because '
-            'the available memory is less than '
-            f'{server_constants.MIN_AVAIL_MEM_GB}GB.')
+        if not quiet:
+            logger.warning(
+                'SkyPilot API server will run in low resource mode because '
+                'the available memory is less than '
+                f'{server_constants.MIN_AVAIL_MEM_GB}GB.')
     elif max_db_connections is not None:
         if max_parallel_all_workers > max_db_connections:
-            logger.warning(
-                f'Max parallel all workers ({max_parallel_all_workers}) '
-                f'is greater than max db connections ({max_db_connections}). '
-                'Increase the number of max db connections to '
-                f'at least {max_parallel_all_workers} for optimal performance.')
+            if not quiet:
+                logger.warning(
+                    f'Max parallel all workers ({max_parallel_all_workers}) '
+                    'is greater than max db connections '
+                    f'({max_db_connections}). Increase the number of max db '
+                    f'connections to at least {max_parallel_all_workers} for '
+                    'optimal performance.')
     else:
         num_db_connections_per_worker = 1

-    logger.info(
-        f'SkyPilot API server will start {num_server_workers} server processes '
-        f'with {max_parallel_for_long} background workers for long requests '
-        f'and will allow at max {max_parallel_for_short} short requests in '
-        f'parallel.')
+    if not quiet:
+        logger.info(
+            f'SkyPilot API server will start {num_server_workers} server '
+            f'processes with {max_parallel_for_long} background workers for '
+            f'long requests and will allow at max {max_parallel_for_short} '
+            'short requests in parallel.')
     return ServerConfig(
         num_server_workers=num_server_workers,
         queue_backend=queue_backend,
@@ -190,10 +200,15 @@ def _max_long_worker_parallism(cpu_count: int,
                               local=False) -> int:
     """Max parallelism for long workers."""
     # Reserve min available memory to avoid OOM.
-    available_mem = max(0, mem_size_gb - server_constants.MIN_AVAIL_MEM_GB)
+    # pylint: disable=import-outside-toplevel
+    import sky.jobs.utils as job_utils
+    max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
+                  if job_utils.is_consolidation_mode() else
+                  server_constants.MIN_AVAIL_MEM_GB)
+    available_mem = max(0, mem_size_gb - max_memory)
     cpu_based_max_parallel = cpu_count * _CPU_MULTIPLIER_FOR_LONG_WORKERS
     mem_based_max_parallel = int(available_mem * _MAX_MEM_PERCENT_FOR_BLOCKING /
-                                 _LONG_WORKER_MEM_GB)
+                                 LONG_WORKER_MEM_GB)
     n = max(_MIN_LONG_WORKERS,
             min(cpu_based_max_parallel, mem_based_max_parallel))
     if local:
@@ -205,8 +220,12 @@ def _max_short_worker_parallism(mem_size_gb: float,
                                long_worker_parallism: int) -> int:
     """Max parallelism for short workers."""
     # Reserve memory for long workers and min available memory.
-    reserved_mem = server_constants.MIN_AVAIL_MEM_GB + (long_worker_parallism *
-                                                        _LONG_WORKER_MEM_GB)
+    # pylint: disable=import-outside-toplevel
+    import sky.jobs.utils as job_utils
+    max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
+                  if job_utils.is_consolidation_mode() else
+                  server_constants.MIN_AVAIL_MEM_GB)
+    reserved_mem = max_memory + (long_worker_parallism * LONG_WORKER_MEM_GB)
     available_mem = max(0, mem_size_gb - reserved_mem)
-    n = max(_MIN_SHORT_WORKERS, int(available_mem / _SHORT_WORKER_MEM_GB))
+    n = max(_MIN_SHORT_WORKERS, int(available_mem / SHORT_WORKER_MEM_GB))
     return n
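Editor's note: worker-pool sizing is still driven by available memory; only the reserved floor changes when the API server also hosts consolidated job controllers. The back-of-the-envelope sketch below reuses LONG_WORKER_MEM_GB, the memory floors, and the CPU multiplier of 2 from the diff, while the 30% blocking-memory share and the minimum of 1 long worker are assumed stand-ins for private constants not shown here.

    LONG_WORKER_MEM_GB = 0.4              # from the diff above
    MIN_AVAIL_MEM_GB = 2                  # from sky/server/constants.py
    MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE = 4
    MAX_MEM_PERCENT_FOR_BLOCKING = 0.3    # assumed value, for illustration only


    def max_long_workers(mem_size_gb: float, cpu_count: int,
                         consolidation_mode: bool) -> int:
        reserved = (MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
                    if consolidation_mode else MIN_AVAIL_MEM_GB)
        available_mem = max(0, mem_size_gb - reserved)
        cpu_based = cpu_count * 2  # _CPU_MULTIPLIER_FOR_LONG_WORKERS
        mem_based = int(available_mem * MAX_MEM_PERCENT_FOR_BLOCKING /
                        LONG_WORKER_MEM_GB)
        return max(1, min(cpu_based, mem_based))  # floor of 1 is assumed


    # Consolidation mode reserves 2 GB more, so the same machine gets a
    # smaller long-worker pool.
    print(max_long_workers(16, 8, consolidation_mode=False))
    print(max_long_workers(16, 8, consolidation_mode=True))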
sky/server/constants.py CHANGED
@@ -34,6 +34,7 @@ VERSION_HEADER = 'X-SkyPilot-Version'
 REQUEST_NAME_PREFIX = 'sky.'
 # The memory (GB) that SkyPilot tries to not use to prevent OOM.
 MIN_AVAIL_MEM_GB = 2
+MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE = 4
 # Default encoder/decoder handler name.
 DEFAULT_HANDLER_NAME = 'default'
 # The path to the API request database.