skypilot-nightly 1.0.0.dev20250427__py3-none-any.whl → 1.0.0.dev20250429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sky/__init__.py +2 -2
  2. sky/adaptors/nebius.py +28 -40
  3. sky/backends/backend_utils.py +19 -2
  4. sky/backends/cloud_vm_ray_backend.py +33 -8
  5. sky/backends/local_docker_backend.py +1 -2
  6. sky/cli.py +91 -38
  7. sky/client/cli.py +91 -38
  8. sky/client/sdk.py +3 -2
  9. sky/clouds/aws.py +12 -6
  10. sky/clouds/azure.py +3 -0
  11. sky/clouds/cloud.py +8 -2
  12. sky/clouds/cudo.py +2 -0
  13. sky/clouds/do.py +3 -0
  14. sky/clouds/fluidstack.py +3 -0
  15. sky/clouds/gcp.py +7 -0
  16. sky/clouds/ibm.py +2 -0
  17. sky/clouds/kubernetes.py +42 -19
  18. sky/clouds/lambda_cloud.py +1 -0
  19. sky/clouds/nebius.py +18 -10
  20. sky/clouds/oci.py +6 -3
  21. sky/clouds/paperspace.py +2 -0
  22. sky/clouds/runpod.py +2 -0
  23. sky/clouds/scp.py +2 -0
  24. sky/clouds/service_catalog/constants.py +1 -1
  25. sky/clouds/service_catalog/kubernetes_catalog.py +7 -7
  26. sky/clouds/vast.py +2 -0
  27. sky/clouds/vsphere.py +2 -0
  28. sky/core.py +58 -29
  29. sky/dashboard/out/404.html +1 -1
  30. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  31. sky/dashboard/out/clusters/[cluster].html +1 -1
  32. sky/dashboard/out/clusters.html +1 -1
  33. sky/dashboard/out/favicon.ico +0 -0
  34. sky/dashboard/out/index.html +1 -1
  35. sky/dashboard/out/jobs/[job].html +1 -1
  36. sky/dashboard/out/jobs.html +1 -1
  37. sky/exceptions.py +6 -0
  38. sky/execution.py +19 -4
  39. sky/global_user_state.py +1 -0
  40. sky/optimizer.py +35 -11
  41. sky/provision/common.py +2 -5
  42. sky/provision/docker_utils.py +22 -16
  43. sky/provision/instance_setup.py +1 -1
  44. sky/provision/kubernetes/instance.py +276 -93
  45. sky/provision/kubernetes/network.py +1 -1
  46. sky/provision/kubernetes/utils.py +36 -24
  47. sky/provision/provisioner.py +6 -0
  48. sky/serve/replica_managers.py +51 -5
  49. sky/serve/serve_state.py +41 -0
  50. sky/serve/service.py +108 -63
  51. sky/server/common.py +6 -3
  52. sky/server/config.py +184 -0
  53. sky/server/requests/executor.py +17 -156
  54. sky/server/server.py +4 -4
  55. sky/setup_files/dependencies.py +0 -1
  56. sky/skylet/constants.py +7 -0
  57. sky/skypilot_config.py +27 -6
  58. sky/task.py +1 -1
  59. sky/templates/kubernetes-ray.yml.j2 +145 -15
  60. sky/templates/nebius-ray.yml.j2 +63 -0
  61. sky/utils/command_runner.py +17 -3
  62. sky/utils/command_runner.pyi +2 -0
  63. sky/utils/controller_utils.py +24 -0
  64. sky/utils/kubernetes/rsync_helper.sh +20 -4
  65. sky/utils/schemas.py +13 -0
  66. {skypilot_nightly-1.0.0.dev20250427.dist-info → skypilot_nightly-1.0.0.dev20250429.dist-info}/METADATA +2 -2
  67. {skypilot_nightly-1.0.0.dev20250427.dist-info → skypilot_nightly-1.0.0.dev20250429.dist-info}/RECORD +73 -72
  68. {skypilot_nightly-1.0.0.dev20250427.dist-info → skypilot_nightly-1.0.0.dev20250429.dist-info}/WHEEL +1 -1
  69. /sky/dashboard/out/_next/static/{kTfCjujxwqIQ4b7YvP7Uq → BMtJJ079_cyYmtW2-7nVS}/_buildManifest.js +0 -0
  70. /sky/dashboard/out/_next/static/{kTfCjujxwqIQ4b7YvP7Uq → BMtJJ079_cyYmtW2-7nVS}/_ssgManifest.js +0 -0
  71. {skypilot_nightly-1.0.0.dev20250427.dist-info → skypilot_nightly-1.0.0.dev20250429.dist-info}/entry_points.txt +0 -0
  72. {skypilot_nightly-1.0.0.dev20250427.dist-info → skypilot_nightly-1.0.0.dev20250429.dist-info}/licenses/LICENSE +0 -0
  73. {skypilot_nightly-1.0.0.dev20250427.dist-info → skypilot_nightly-1.0.0.dev20250429.dist-info}/top_level.txt +0 -0
sky/provision/kubernetes/utils.py CHANGED
@@ -45,6 +45,16 @@ else:
     jinja2 = adaptors_common.LazyImport('jinja2')
     yaml = adaptors_common.LazyImport('yaml')
 
+# Please be careful when changing this.
+# When mounting, Kubernetes changes the ownership of the parent directory
+# to root:root.
+# See https://stackoverflow.com/questions/50818029/mounted-folder-created-as-root-instead-of-current-user-in-docker/50820023#50820023. # pylint: disable=line-too-long
+HIGH_AVAILABILITY_DEPLOYMENT_VOLUME_MOUNT_NAME = 'sky-data'
+# Path where the persistent volume for HA controller is mounted.
+# TODO(andy): Consider using dedicated path like `/var/skypilot`
+# and store all data that needs to be persisted in future.
+HIGH_AVAILABILITY_DEPLOYMENT_VOLUME_MOUNT_PATH = '/home/sky'
+
 # TODO(romilb): Move constants to constants.py
 DEFAULT_NAMESPACE = 'default'
 
@@ -233,7 +243,7 @@ class GPULabelFormatter:
         raise NotImplementedError
 
     @classmethod
-    def get_label_value(cls, accelerator: str) -> str:
+    def get_label_values(cls, accelerator: str) -> List[str]:
         """Given a GPU type, returns the label value to be used"""
         raise NotImplementedError
 
@@ -301,10 +311,10 @@ class SkyPilotLabelFormatter(GPULabelFormatter):
         return [cls.LABEL_KEY]
 
     @classmethod
-    def get_label_value(cls, accelerator: str) -> str:
+    def get_label_values(cls, accelerator: str) -> List[str]:
         # For SkyPilot formatter, we use the accelerator str directly.
         # See sky.utils.kubernetes.gpu_labeler.
-        return accelerator.lower()
+        return [accelerator.lower()]
 
     @classmethod
     def match_label_key(cls, label_key: str) -> bool:
@@ -341,8 +351,8 @@ class CoreWeaveLabelFormatter(GPULabelFormatter):
         return [cls.LABEL_KEY]
 
     @classmethod
-    def get_label_value(cls, accelerator: str) -> str:
-        return accelerator.upper()
+    def get_label_values(cls, accelerator: str) -> List[str]:
+        return [accelerator.upper()]
 
     @classmethod
     def match_label_key(cls, label_key: str) -> bool:
@@ -428,8 +438,8 @@ class GKELabelFormatter(GPULabelFormatter):
         return count_to_topology
 
     @classmethod
-    def get_label_value(cls, accelerator: str) -> str:
-        return get_gke_accelerator_name(accelerator)
+    def get_label_values(cls, accelerator: str) -> List[str]:
+        return [get_gke_accelerator_name(accelerator)]
 
     @classmethod
    def get_accelerator_from_label_value(cls, value: str) -> str:
@@ -462,7 +472,7 @@ class GFDLabelFormatter(GPULabelFormatter):
    https://docs.nvidia.com/datacenter/cloud-native/gpu-operator/latest/overview.html
 
    This LabelFormatter can't be used in autoscaling clusters since accelerators
-   may map to multiple label, so we're not implementing `get_label_value`
+   may map to multiple label, so we're not implementing `get_label_values`
    """
 
    LABEL_KEY = 'nvidia.com/gpu.product'
@@ -476,10 +486,10 @@ class GFDLabelFormatter(GPULabelFormatter):
         return [cls.LABEL_KEY]
 
     @classmethod
-    def get_label_value(cls, accelerator: str) -> str:
-        """An accelerator can map to many Nvidia GFD labels
-        (e.g., A100-80GB-PCIE vs. A100-SXM4-80GB).
-        As a result, we do not support get_label_value for GFDLabelFormatter."""
+    def get_label_values(cls, accelerator: str) -> List[str]:
+        # An accelerator can map to many Nvidia GFD labels
+        # (e.g., A100-80GB-PCIE vs. A100-SXM4-80GB).
+        # TODO implement get_label_values for GFDLabelFormatter
         raise NotImplementedError
 
     @classmethod
@@ -1022,15 +1032,17 @@ def check_instance_fits(context: Optional[str],
         # met.
         assert acc_count is not None, (acc_type, acc_count)
         try:
-            gpu_label_key, gpu_label_val, _, _ = (
-                get_accelerator_label_key_value(context, acc_type, acc_count))
+            gpu_label_key, gpu_label_values, _, _ = (
+                get_accelerator_label_key_values(context, acc_type, acc_count))
+            if gpu_label_values is None:
+                gpu_label_values = []
         except exceptions.ResourcesUnavailableError as e:
             # If GPU not found, return empty list and error message.
             return False, str(e)
         # Get the set of nodes that have the GPU type
         gpu_nodes = [
             node for node in nodes if gpu_label_key in node.metadata.labels and
-            node.metadata.labels[gpu_label_key] == gpu_label_val
+            node.metadata.labels[gpu_label_key] in gpu_label_values
         ]
         if not gpu_nodes:
             return False, f'No GPU nodes found with {acc_type} on the cluster'
@@ -1072,12 +1084,12 @@ def check_instance_fits(context: Optional[str],
     return fits, reason
 
 
-def get_accelerator_label_key_value(
+def get_accelerator_label_key_values(
     context: Optional[str],
     acc_type: str,
    acc_count: int,
    check_mode=False
-) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
+) -> Tuple[Optional[str], Optional[List[str]], Optional[str], Optional[str]]:
    """Returns the label key and value for the given GPU/TPU type.

    Args:
@@ -1131,7 +1143,7 @@ def get_accelerator_label_key_value(
             tpu_topology_label_key = formatter.get_tpu_topology_label_key()
             tpu_topology_label_value = formatter.get_tpu_topology_label_value(
                 acc_type, acc_count)
-        return formatter.get_label_key(acc_type), formatter.get_label_value(
+        return formatter.get_label_key(acc_type), formatter.get_label_values(
             acc_type), tpu_topology_label_key, tpu_topology_label_value
 
     has_gpus, cluster_resources = detect_accelerator_resource(context)
@@ -1210,12 +1222,12 @@ def get_accelerator_label_key_value(
                         # different topologies that maps to identical
                         # number of TPU chips.
                         if tpu_topology_chip_count == acc_count:
-                            return (label, value, topology_label_key,
+                            return (label, [value], topology_label_key,
                                     topology_value)
                         else:
                             continue
                     else:
-                        return label, value, None, None
+                        return label, [value], None, None
 
     # If no node is found with the requested acc_type, raise error
     with ux_utils.print_exception_no_traceback():
@@ -1377,10 +1389,10 @@ def check_credentials(context: Optional[str],
             # `get_unlabeled_accelerator_nodes`.
             # Therefore, if `get_unlabeled_accelerator_nodes` detects unlabelled
             # nodes, we skip this check.
-            get_accelerator_label_key_value(context,
-                                            acc_type='',
-                                            acc_count=0,
-                                            check_mode=True)
+            get_accelerator_label_key_values(context,
+                                             acc_type='',
+                                             acc_count=0,
+                                             check_mode=True)
         except exceptions.ResourcesUnavailableError as e:
             # If GPUs are not available, we return cluster as enabled
             # (since it can be a CPU-only cluster) but we also return the
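The rename from `get_label_value` to `get_label_values` changes the formatter contract from one label value per accelerator to a list of candidates, and `check_instance_fits` accordingly switches from equality to membership when filtering nodes. Below is a minimal, self-contained sketch of the new contract; the `ExampleFormatter` class and node labels are hypothetical illustrations, not code from the SkyPilot package:

```python
from typing import List


class GPULabelFormatter:
    """Maps an accelerator type to candidate node-label values."""

    @classmethod
    def get_label_values(cls, accelerator: str) -> List[str]:
        raise NotImplementedError


class ExampleFormatter(GPULabelFormatter):
    """Hypothetical formatter: one accelerator, several label aliases."""

    @classmethod
    def get_label_values(cls, accelerator: str) -> List[str]:
        if accelerator == 'A100-80GB':
            # e.g. the PCIE and SXM4 variants of the same accelerator.
            return ['A100-80GB-PCIE', 'A100-SXM4-80GB']
        return [accelerator]


# Node filtering now uses membership, mirroring the
# `node.metadata.labels[gpu_label_key] in gpu_label_values` change above.
node_labels = {'nvidia.com/gpu.product': 'A100-SXM4-80GB'}
values = ExampleFormatter.get_label_values('A100-80GB')
assert node_labels['nvidia.com/gpu.product'] in values
```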
sky/provision/provisioner.py CHANGED
@@ -149,6 +149,12 @@ def bulk_provision(
             # Skip the teardown if the cloud config is expired and
             # the provisioner should failover to other clouds.
             raise
+        except exceptions.InconsistentHighAvailabilityError:
+            # Skip the teardown if the high availability property in the
+            # user config is inconsistent with the actual cluster.
+            # This error is a user error instead of a provisioning failure.
+            # And there is no possibility to fix it by teardown.
+            raise
         except Exception:  # pylint: disable=broad-except
             zone_str = 'all zones'
             if zones:
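The new handler sits above the broad `except Exception` clause that normally triggers teardown, so an HA mismatch propagates to the user instead of destroying the cluster. A condensed sketch of that ordering follows; the `_provision`/`_teardown` helpers are placeholders, not SkyPilot APIs:

```python
class InconsistentHighAvailabilityError(Exception):
    """User config disagrees with the cluster's actual HA property."""


def _provision() -> None:
    raise InconsistentHighAvailabilityError('HA mismatch')


def _teardown() -> None:
    print('tearing down')


def bulk_provision_sketch() -> None:
    try:
        _provision()
    except InconsistentHighAvailabilityError:
        # User error: teardown cannot fix it, so re-raise before the
        # broad handler below gets a chance to tear the cluster down.
        raise
    except Exception:  # pylint: disable=broad-except
        _teardown()
        raise


try:
    bulk_provision_sketch()
except InconsistentHighAvailabilityError:
    print('surfaced to the user; no teardown ran')
```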
sky/serve/replica_managers.py CHANGED
@@ -387,11 +387,12 @@ class ReplicaStatusProperty:
 class ReplicaInfo:
     """Replica info for each replica."""
 
-    _VERSION = 1
+    _VERSION = 2
 
     def __init__(self, replica_id: int, cluster_name: str, replica_port: str,
                  is_spot: bool, location: Optional[spot_placer.Location],
-                 version: int) -> None:
+                 version: int, resources_override: Optional[Dict[str,
+                                                                 Any]]) -> None:
         self._version = self._VERSION
         self.replica_id: int = replica_id
         self.cluster_name: str = cluster_name
@@ -403,6 +404,7 @@ class ReplicaInfo:
         self.is_spot: bool = is_spot
         self.location: Optional[Dict[str, Optional[str]]] = (
             location.to_pickleable() if location is not None else None)
+        self.resources_override: Optional[Dict[str, Any]] = resources_override
 
     def get_spot_location(self) -> Optional[spot_placer.Location]:
         return spot_placer.Location.from_pickleable(self.location)
@@ -569,6 +571,9 @@ class ReplicaInfo:
         if version < 1:
             self.location = None
 
+        if version < 2:
+            self.resources_override = None
+
         self.__dict__.update(state)
 
 
@@ -650,6 +655,44 @@ class SkyPilotReplicaManager(ReplicaManager):
         threading.Thread(target=self._job_status_fetcher).start()
         threading.Thread(target=self._replica_prober).start()
 
+        self._recover_replica_operations()
+
+    def _recover_replica_operations(self):
+        """Let's see are there something to do for ReplicaManager in a
+        recovery run"""
+        assert (not self._launch_process_pool and not self._down_process_pool
+               ), 'We should not have any running processes in a recovery run'
+
+        # There is a FIFO queue with capacity _MAX_NUM_LAUNCH for
+        # _launch_replica.
+        # We prioritize PROVISIONING replicas since they were previously
+        # launched but may have been interrupted and need to be restarted.
+        # This is why we process PENDING replicas only after PROVISIONING
+        # replicas.
+        to_up_replicas = serve_state.get_replicas_at_status(
+            self._service_name, serve_state.ReplicaStatus.PROVISIONING)
+        to_up_replicas.extend(
+            serve_state.get_replicas_at_status(
+                self._service_name, serve_state.ReplicaStatus.PENDING))
+
+        for replica_info in to_up_replicas:
+            # It should be robust enough for `execution.launch` to handle cases
+            # where the provisioning is partially done.
+            # So we mock the original request based on all call sites,
+            # including SkyServeController._run_autoscaler.
+            self._launch_replica(
+                replica_info.replica_id,
+                resources_override=replica_info.resources_override)
+
+        for replica_info in serve_state.get_replicas_at_status(
+                self._service_name, serve_state.ReplicaStatus.SHUTTING_DOWN):
+            self._terminate_replica(
+                replica_info.replica_id,
+                sync_down_logs=False,
+                replica_drain_delay_seconds=0,
+                purge=replica_info.status_property.purged,
+                is_scale_down=replica_info.status_property.is_scale_down)
+
     ################################
     # Replica management functions #
     ################################
@@ -705,7 +748,7 @@ class SkyPilotReplicaManager(ReplicaManager):
         replica_port = _get_resources_ports(self._task_yaml_path)
 
         info = ReplicaInfo(replica_id, cluster_name, replica_port, use_spot,
-                           location, self.latest_version)
+                           location, self.latest_version, resources_override)
         serve_state.add_or_update_replica(self._service_name, replica_id, info)
         # Don't start right now; we will start it later in _refresh_process_pool
         # to avoid too many sky.launch running at the same time.
@@ -884,7 +927,9 @@ class SkyPilotReplicaManager(ReplicaManager):
         the fly. If any of them finished, it will update the status of the
         corresponding replica.
         """
-        for replica_id, p in list(self._launch_process_pool.items()):
+        # To avoid `dictionary changed size during iteration` error.
+        launch_process_pool_snapshot = list(self._launch_process_pool.items())
+        for replica_id, p in launch_process_pool_snapshot:
             if not p.is_alive():
                 info = serve_state.get_replica_info_from_id(
                     self._service_name, replica_id)
@@ -943,7 +988,8 @@ class SkyPilotReplicaManager(ReplicaManager):
             self._terminate_replica(replica_id,
                                     sync_down_logs=True,
                                     replica_drain_delay_seconds=0)
-        for replica_id, p in list(self._down_process_pool.items()):
+        down_process_pool_snapshot = list(self._down_process_pool.items())
+        for replica_id, p in down_process_pool_snapshot:
             if not p.is_alive():
                 logger.info(
                     f'Terminate process for replica {replica_id} finished.')
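Bumping `_VERSION` to 2 and backfilling in `__setstate__` is what lets `ReplicaInfo` objects pickled by an older controller unpickle cleanly after an upgrade. A simplified sketch of the pattern, with the class body reduced to the one new field (this is not the real `ReplicaInfo`):

```python
from typing import Any, Dict, Optional


class ReplicaInfoSketch:
    _VERSION = 2

    def __init__(self, resources_override: Optional[Dict[str, Any]]) -> None:
        self._version = self._VERSION
        self.resources_override = resources_override

    def __setstate__(self, state: Dict[str, Any]) -> None:
        version = state.get('_version', -1)
        if version < 2:
            # Pickles written before version 2 lack the new field;
            # default it before loading the stored attributes.
            self.resources_override = None
        self.__dict__.update(state)


# Simulate unpickling a version-1 object whose state has no
# resources_override entry.
old = ReplicaInfoSketch.__new__(ReplicaInfoSketch)
old.__setstate__({'_version': 1})
assert old.resources_override is None
```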
sky/serve/serve_state.py CHANGED
@@ -479,6 +479,14 @@ def total_number_provisioning_replicas() -> int:
     return provisioning_count
 
 
+def get_replicas_at_status(
+    service_name: str,
+    status: ReplicaStatus,
+) -> List['replica_managers.ReplicaInfo']:
+    replicas = get_replica_infos(service_name)
+    return [replica for replica in replicas if replica.status == status]
+
+
 # === Version functions ===
 def add_version(service_name: str) -> int:
     """Adds a version to the database."""
@@ -549,3 +557,36 @@ def delete_all_versions(service_name: str) -> None:
             """\
             DELETE FROM version_specs
             WHERE service_name=(?)""", (service_name,))
+
+
+def get_latest_version(service_name: str) -> Optional[int]:
+    with db_utils.safe_cursor(_DB_PATH) as cursor:
+        rows = cursor.execute(
+            """\
+            SELECT MAX(version) FROM version_specs
+            WHERE service_name=(?)""", (service_name,)).fetchall()
+        if not rows or rows[0][0] is None:
+            return None
+        return rows[0][0]
+
+
+def get_service_controller_port(service_name: str) -> int:
+    """Gets the controller port of a service."""
+    with db_utils.safe_cursor(_DB_PATH) as cursor:
+        cursor.execute('SELECT controller_port FROM services WHERE name = ?',
+                       (service_name,))
+        row = cursor.fetchone()
+        if row is None:
+            raise ValueError(f'Service {service_name} does not exist.')
+        return row[0]
+
+
+def get_service_load_balancer_port(service_name: str) -> int:
+    """Gets the load balancer port of a service."""
+    with db_utils.safe_cursor(_DB_PATH) as cursor:
+        cursor.execute('SELECT load_balancer_port FROM services WHERE name = ?',
+                       (service_name,))
+        row = cursor.fetchone()
+        if row is None:
+            raise ValueError(f'Service {service_name} does not exist.')
+        return row[0]
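The `rows[0][0] is None` check in `get_latest_version` exists because SQLite returns a single NULL row when `MAX()` aggregates over zero rows. A self-contained sketch against an in-memory database; the two-column schema is a stand-in for the real `version_specs` table:

```python
import sqlite3
from typing import Optional


def get_latest_version(conn: sqlite3.Connection,
                       service_name: str) -> Optional[int]:
    rows = conn.execute(
        'SELECT MAX(version) FROM version_specs WHERE service_name=(?)',
        (service_name,)).fetchall()
    # MAX() over zero rows yields one NULL row, not an empty result set.
    if not rows or rows[0][0] is None:
        return None
    return rows[0][0]


conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE version_specs (service_name TEXT, version INT)')
conn.executemany('INSERT INTO version_specs VALUES (?, ?)',
                 [('svc', 1), ('svc', 2)])
assert get_latest_version(conn, 'svc') == 2
assert get_latest_version(conn, 'missing') is None
```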
sky/serve/service.py CHANGED
@@ -25,6 +25,7 @@ from sky.serve import load_balancer
 from sky.serve import replica_managers
 from sky.serve import serve_state
 from sky.serve import serve_utils
+from sky.skylet import constants as skylet_constants
 from sky.utils import common_utils
 from sky.utils import subprocess_utils
 from sky.utils import ux_utils
@@ -136,8 +137,25 @@ def _cleanup(service_name: str) -> bool:
     return failed
 
 
+def _cleanup_task_run_script(job_id: int) -> None:
+    """Clean up task run script.
+    Please see `kubernetes-ray.yml.j2` for more details.
+    """
+    task_run_dir = pathlib.Path(
+        skylet_constants.PERSISTENT_RUN_SCRIPT_DIR).expanduser()
+    if task_run_dir.exists():
+        this_task_run_script = task_run_dir / f'sky_job_{job_id}'
+        if this_task_run_script.exists():
+            this_task_run_script.unlink()
+            logger.info(f'Task run script {this_task_run_script} removed')
+        else:
+            logger.warning(f'Task run script {this_task_run_script} not found')
+
+
 def _start(service_name: str, tmp_task_yaml: str, job_id: int):
-    """Starts the service."""
+    """Starts the service.
+    This including the controller and load balancer.
+    """
     # Generate ssh key pair to avoid race condition when multiple sky.launch
     # are executed at the same time.
     authentication.get_or_generate_keys()
@@ -147,62 +165,79 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
     # Already checked before submit to controller.
     assert task.service is not None, task
     service_spec = task.service
-    if (len(serve_state.get_services()) >=
-            serve_utils.get_num_service_threshold()):
-        cleanup_storage(tmp_task_yaml)
-        with ux_utils.print_exception_no_traceback():
-            raise RuntimeError('Max number of services reached.')
-    success = serve_state.add_service(
-        service_name,
-        controller_job_id=job_id,
-        policy=service_spec.autoscaling_policy_str(),
-        requested_resources_str=backend_utils.get_task_resources_str(task),
-        load_balancing_policy=service_spec.load_balancing_policy,
-        status=serve_state.ServiceStatus.CONTROLLER_INIT,
-        tls_encrypted=service_spec.tls_credential is not None)
-    # Directly throw an error here. See sky/serve/api.py::up
-    # for more details.
-    if not success:
-        cleanup_storage(tmp_task_yaml)
-        with ux_utils.print_exception_no_traceback():
-            raise ValueError(f'Service {service_name} already exists.')
-
-    # Add initial version information to the service state.
-    serve_state.add_or_update_version(service_name, constants.INITIAL_VERSION,
-                                      service_spec)
-
-    # Create the service working directory.
+
+    def is_recovery_mode(service_name: str) -> bool:
+        """Check if service exists in database to determine recovery mode.
+        """
+        service = serve_state.get_service_from_name(service_name)
+        return service is not None
+
+    is_recovery = is_recovery_mode(service_name)
+    logger.info(f'It is a {"first" if not is_recovery else "recovery"} run')
+
+    if is_recovery:
+        version = serve_state.get_latest_version(service_name)
+        if version is None:
+            raise ValueError(f'No version found for service {service_name}')
+    else:
+        version = constants.INITIAL_VERSION
+        # Add initial version information to the service state.
+        serve_state.add_or_update_version(service_name, version, service_spec)
+
     service_dir = os.path.expanduser(
         serve_utils.generate_remote_service_dir_name(service_name))
-    os.makedirs(service_dir, exist_ok=True)
-
-    # Copy the tmp task yaml file to the final task yaml file.
-    # This is for the service name conflict case. The _execute will
-    # sync file mounts first and then realized a name conflict. We
-    # don't want the new file mounts to overwrite the old one, so we
-    # sync to a tmp file first and then copy it to the final name
-    # if there is no name conflict.
-    task_yaml = serve_utils.generate_task_yaml_file_name(
-        service_name, constants.INITIAL_VERSION)
-    shutil.copy(tmp_task_yaml, task_yaml)
-
-    # Generate load balancer log file name.
-    load_balancer_log_file = os.path.expanduser(
-        serve_utils.generate_remote_load_balancer_log_file_name(service_name))
+    task_yaml = serve_utils.generate_task_yaml_file_name(service_name, version)
+
+    if not is_recovery:
+        if (len(serve_state.get_services()) >=
+                serve_utils.get_num_service_threshold()):
+            cleanup_storage(tmp_task_yaml)
+            with ux_utils.print_exception_no_traceback():
+                raise RuntimeError('Max number of services reached.')
+        success = serve_state.add_service(
+            service_name,
+            controller_job_id=job_id,
+            policy=service_spec.autoscaling_policy_str(),
+            requested_resources_str=backend_utils.get_task_resources_str(task),
+            load_balancing_policy=service_spec.load_balancing_policy,
+            status=serve_state.ServiceStatus.CONTROLLER_INIT,
+            tls_encrypted=service_spec.tls_credential is not None)
+        # Directly throw an error here. See sky/serve/api.py::up
+        # for more details.
+        if not success:
+            cleanup_storage(tmp_task_yaml)
+            with ux_utils.print_exception_no_traceback():
+                raise ValueError(f'Service {service_name} already exists.')
+
+        # Create the service working directory.
+        os.makedirs(service_dir, exist_ok=True)
+
+        # Copy the tmp task yaml file to the final task yaml file.
+        # This is for the service name conflict case. The _execute will
+        # sync file mounts first and then realized a name conflict. We
+        # don't want the new file mounts to overwrite the old one, so we
+        # sync to a tmp file first and then copy it to the final name
+        # if there is no name conflict.
+        shutil.copy(tmp_task_yaml, task_yaml)
 
     controller_process = None
     load_balancer_process = None
     try:
         with filelock.FileLock(
                 os.path.expanduser(constants.PORT_SELECTION_FILE_LOCK_PATH)):
-            controller_port = common_utils.find_free_port(
-                constants.CONTROLLER_PORT_START)
-
-            # We expose the controller to the public network when running
-            # inside a kubernetes cluster to allow external load balancers
-            # (example, for high availability load balancers) to communicate
-            # with the controller.
-            def _get_host():
+            # Start the controller.
+            controller_port = (
+                common_utils.find_free_port(constants.CONTROLLER_PORT_START)
+                if not is_recovery else
+                serve_state.get_service_controller_port(service_name))
+
+            def _get_controller_host():
+                """Get the controller host address.
+                We expose the controller to the public network when running
+                inside a kubernetes cluster to allow external load balancers
+                (example, for high availability load balancers) to communicate
+                with the controller.
+                """
                 if 'KUBERNETES_SERVICE_HOST' in os.environ:
                     return '0.0.0.0'
                 # Not using localhost to avoid using ipv6 address and causing
@@ -211,26 +246,28 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
                 # ('::1', 20001, 0, 0): cannot assign requested address
                 return '127.0.0.1'
 
-            controller_host = _get_host()
-
-            # Start the controller.
+            controller_host = _get_controller_host()
             controller_process = multiprocessing.Process(
                 target=controller.run_controller,
                 args=(service_name, service_spec, task_yaml, controller_host,
                       controller_port))
             controller_process.start()
-            serve_state.set_service_controller_port(service_name,
-                                                    controller_port)
 
-            controller_addr = f'http://{controller_host}:{controller_port}'
+            if not is_recovery:
+                serve_state.set_service_controller_port(service_name,
+                                                        controller_port)
 
-            load_balancer_port = common_utils.find_free_port(
-                constants.LOAD_BALANCER_PORT_START)
-
-            # Extract the load balancing policy from the service spec
-            policy_name = service_spec.load_balancing_policy
+            controller_addr = f'http://{controller_host}:{controller_port}'
 
             # Start the load balancer.
+            load_balancer_port = (
+                common_utils.find_free_port(constants.LOAD_BALANCER_PORT_START)
+                if not is_recovery else
+                serve_state.get_service_load_balancer_port(service_name))
+            load_balancer_log_file = os.path.expanduser(
+                serve_utils.generate_remote_load_balancer_log_file_name(
+                    service_name))
+
             # TODO(tian): Probably we could enable multiple ports specified in
             # service spec and we could start multiple load balancers.
             # After that, we will have a mapping from replica port to endpoint.
@@ -238,11 +275,14 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
                 target=ux_utils.RedirectOutputForProcess(
                     load_balancer.run_load_balancer,
                     load_balancer_log_file).run,
-                args=(controller_addr, load_balancer_port, policy_name,
+                args=(controller_addr, load_balancer_port,
+                      service_spec.load_balancing_policy,
                       service_spec.tls_credential))
             load_balancer_process.start()
-            serve_state.set_service_load_balancer_port(service_name,
-                                                       load_balancer_port)
+
+            if not is_recovery:
+                serve_state.set_service_load_balancer_port(
+                    service_name, load_balancer_port)
 
         while True:
             _handle_signal(service_name)
@@ -262,6 +302,7 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
                                              force=True)
         for process in process_to_kill:
             process.join()
+
         failed = _cleanup(service_name)
         if failed:
             serve_state.set_service_status_and_active_versions(
@@ -273,8 +314,12 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
         serve_state.delete_all_versions(service_name)
         logger.info(f'Service {service_name} terminated successfully.')
 
+    _cleanup_task_run_script(job_id)
+
 
 if __name__ == '__main__':
+    logger.info('Starting service...')
+
     parser = argparse.ArgumentParser(description='Sky Serve Service')
     parser.add_argument('--service-name',
                         type=str,
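The recovery branch of `_start` only allocates fresh controller and load-balancer ports on a first run; on a recovery run it reads the ports previously persisted in the services table, so restarted processes come back on the same endpoints. A toy sketch of that branching; the in-memory dict and hardcoded port stand in for `serve_state` and `common_utils.find_free_port`:

```python
from typing import Dict, Optional

_state: Dict[str, int] = {}  # toy stand-in for the services table


def get_persisted_port(service_name: str) -> Optional[int]:
    return _state.get(service_name)


def set_persisted_port(service_name: str, port: int) -> None:
    _state[service_name] = port


def start_controller(service_name: str, is_recovery: bool) -> int:
    if not is_recovery:
        port = 20001  # first run: pick a fresh port (find_free_port)
        set_persisted_port(service_name, port)  # persist for recoveries
    else:
        saved = get_persisted_port(service_name)
        assert saved is not None, 'recovery run requires a persisted port'
        port = saved  # recovery run: reuse the persisted port
    return port


assert start_controller('svc', is_recovery=False) == 20001
assert start_controller('svc', is_recovery=True) == 20001
```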
sky/server/common.py CHANGED
@@ -333,7 +333,7 @@ def _start_api_server(deploy: bool = False,
             break
 
     server_url = get_server_url(host)
-    dashboard_msg = (f'Dashboard: {get_dashboard_url(server_url)}')
+    dashboard_msg = ''
     api_server_info = get_api_server_status(server_url)
     if api_server_info.version == _DEV_VERSION:
         dashboard_msg += (
@@ -343,12 +343,15 @@ def _start_api_server(deploy: bool = False,
             dashboard_msg += (
                 'Dashboard is not built, '
                 'to build: npm --prefix sky/dashboard install '
-                '&& npm --prefix sky/dashboard run build')
+                '&& npm --prefix sky/dashboard run build\n')
         else:
             dashboard_msg += (
                 'Dashboard may be stale when installed from source, '
                 'to rebuild: npm --prefix sky/dashboard install '
-                '&& npm --prefix sky/dashboard run build')
+                '&& npm --prefix sky/dashboard run build\n')
+    dashboard_msg += (
+        f'{ux_utils.INDENT_LAST_SYMBOL}{colorama.Fore.GREEN}'
+        f'Dashboard: {get_dashboard_url(server_url)}')
     dashboard_msg += f'{colorama.Style.RESET_ALL}'
     logger.info(
         ux_utils.finishing_message(
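The net effect of the common.py change is that the dashboard URL now renders last, as a highlighted tree-suffix line after any rebuild hints. A rough sketch of the assembled message; the '└── ' glyph is an assumed value for `ux_utils.INDENT_LAST_SYMBOL`, and the URL is illustrative:

```python
import colorama

INDENT_LAST_SYMBOL = '└── '  # assumption: SkyPilot's tree-style prefix

dashboard_msg = ''
dashboard_msg += ('Dashboard may be stale when installed from source, '
                  'to rebuild: npm --prefix sky/dashboard install '
                  '&& npm --prefix sky/dashboard run build\n')
# The URL is appended last so it is the final, green-highlighted line.
dashboard_msg += (f'{INDENT_LAST_SYMBOL}{colorama.Fore.GREEN}'
                  'Dashboard: http://127.0.0.1:46580/dashboard')
dashboard_msg += f'{colorama.Style.RESET_ALL}'
print(dashboard_msg)
```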