skypilot-nightly 1.0.0.dev20241120__py3-none-any.whl → 1.0.0.dev20241122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +20 -15
  3. sky/backends/cloud_vm_ray_backend.py +21 -3
  4. sky/clouds/aws.py +1 -0
  5. sky/clouds/azure.py +1 -0
  6. sky/clouds/cloud.py +1 -0
  7. sky/clouds/cudo.py +1 -0
  8. sky/clouds/fluidstack.py +1 -0
  9. sky/clouds/gcp.py +1 -0
  10. sky/clouds/ibm.py +1 -0
  11. sky/clouds/kubernetes.py +45 -3
  12. sky/clouds/lambda_cloud.py +1 -0
  13. sky/clouds/oci.py +1 -0
  14. sky/clouds/paperspace.py +1 -0
  15. sky/clouds/runpod.py +1 -0
  16. sky/clouds/scp.py +1 -0
  17. sky/clouds/vsphere.py +1 -0
  18. sky/provision/instance_setup.py +80 -83
  19. sky/provision/kubernetes/instance.py +108 -76
  20. sky/provision/kubernetes/utils.py +2 -0
  21. sky/provision/oci/instance.py +4 -2
  22. sky/provision/provisioner.py +95 -19
  23. sky/resources.py +2 -1
  24. sky/skylet/constants.py +31 -21
  25. sky/templates/kubernetes-ray.yml.j2 +169 -39
  26. sky/utils/subprocess_utils.py +49 -4
  27. {skypilot_nightly-1.0.0.dev20241120.dist-info → skypilot_nightly-1.0.0.dev20241122.dist-info}/METADATA +65 -55
  28. {skypilot_nightly-1.0.0.dev20241120.dist-info → skypilot_nightly-1.0.0.dev20241122.dist-info}/RECORD +32 -32
  29. {skypilot_nightly-1.0.0.dev20241120.dist-info → skypilot_nightly-1.0.0.dev20241122.dist-info}/WHEEL +1 -1
  30. {skypilot_nightly-1.0.0.dev20241120.dist-info → skypilot_nightly-1.0.0.dev20241122.dist-info}/LICENSE +0 -0
  31. {skypilot_nightly-1.0.0.dev20241120.dist-info → skypilot_nightly-1.0.0.dev20241122.dist-info}/entry_points.txt +0 -0
  32. {skypilot_nightly-1.0.0.dev20241120.dist-info → skypilot_nightly-1.0.0.dev20241122.dist-info}/top_level.txt +0 -0
sky/provision/kubernetes/instance.py CHANGED
@@ -20,12 +20,13 @@ from sky.utils import command_runner
  from sky.utils import common_utils
  from sky.utils import kubernetes_enums
  from sky.utils import subprocess_utils
+ from sky.utils import timeline
  from sky.utils import ux_utils
 
  POLL_INTERVAL = 2
  _TIMEOUT_FOR_POD_TERMINATION = 60  # 1 minutes
  _MAX_RETRIES = 3
- NUM_THREADS = subprocess_utils.get_parallel_threads() * 2
+ _NUM_THREADS = subprocess_utils.get_parallel_threads('kubernetes')
 
  logger = sky_logging.init_logger(__name__)
  TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
@@ -120,6 +121,9 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes):
      are recorded as events. This function retrieves those events and raises
      descriptive errors for better debugging and user feedback.
      """
+     timeout_err_msg = ('Timed out while waiting for nodes to start. '
+                        'Cluster may be out of resources or '
+                        'may be too slow to autoscale.')
      for new_node in new_nodes:
          pod = kubernetes.core_api(context).read_namespaced_pod(
              new_node.metadata.name, namespace)
@@ -148,9 +152,6 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes):
              if event.reason == 'FailedScheduling':
                  event_message = event.message
                  break
-         timeout_err_msg = ('Timed out while waiting for nodes to start. '
-                            'Cluster may be out of resources or '
-                            'may be too slow to autoscale.')
          if event_message is not None:
              if pod_status == 'Pending':
                  logger.info(event_message)
@@ -219,6 +220,7 @@ def _raise_command_running_error(message: str, command: str, pod_name: str,
              f'code {rc}: {command!r}\nOutput: {stdout}.')
 
 
+ @timeline.event
  def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int):
      """Wait for all pods to be scheduled.
 
@@ -229,6 +231,10 @@ def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int):
 
      If timeout is set to a negative value, this method will wait indefinitely.
      """
+     # Create a set of pod names we're waiting for
+     if not new_nodes:
+         return
+     expected_pod_names = {node.metadata.name for node in new_nodes}
      start_time = time.time()
 
      def _evaluate_timeout() -> bool:
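Several provisioning helpers in this release gain a `@timeline.event` decorator (first visible in the hunk above), which suggests the decorated calls are recorded as timed events for SkyPilot's provisioning timeline. The sketch below only illustrates that general decorator shape under that assumption; it is not the actual `sky/utils/timeline.py` implementation, and the printed output format is invented.

    # Illustrative sketch of an event-timing decorator; not the actual
    # sky.utils.timeline implementation.
    import functools
    import time
    from typing import Any, Callable


    def event(func: Callable) -> Callable:
        """Record the wall-clock duration of the wrapped call under its name."""

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            start = time.time()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed = time.time() - start
                # A real implementation would append to a structured trace file;
                # here we simply print the measurement.
                print(f'[timeline] {func.__module__}.{func.__qualname__}: '
                      f'{elapsed:.3f}s')

        return wrapper


    @event
    def _example_step() -> None:
        time.sleep(0.1)


    if __name__ == '__main__':
        _example_step()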
@@ -238,19 +244,34 @@ def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int):
          return time.time() - start_time < timeout
 
      while _evaluate_timeout():
-         all_pods_scheduled = True
-         for node in new_nodes:
-             # Iterate over each pod to check their status
-             pod = kubernetes.core_api(context).read_namespaced_pod(
-                 node.metadata.name, namespace)
-             if pod.status.phase == 'Pending':
+         # Get all pods in a single API call using the cluster name label
+         # which all pods in new_nodes should share
+         cluster_name = new_nodes[0].metadata.labels[TAG_SKYPILOT_CLUSTER_NAME]
+         pods = kubernetes.core_api(context).list_namespaced_pod(
+             namespace,
+             label_selector=f'{TAG_SKYPILOT_CLUSTER_NAME}={cluster_name}').items
+
+         # Get the set of found pod names and check if we have all expected pods
+         found_pod_names = {pod.metadata.name for pod in pods}
+         missing_pods = expected_pod_names - found_pod_names
+         if missing_pods:
+             logger.info('Retrying waiting for pods: '
+                         f'Missing pods: {missing_pods}')
+             time.sleep(0.5)
+             continue
+
+         # Check if all pods are scheduled
+         all_scheduled = True
+         for pod in pods:
+             if (pod.metadata.name in expected_pod_names and
+                     pod.status.phase == 'Pending'):
                  # If container_statuses is None, then the pod hasn't
                  # been scheduled yet.
                  if pod.status.container_statuses is None:
-                     all_pods_scheduled = False
+                     all_scheduled = False
                      break
 
-         if all_pods_scheduled:
+         if all_scheduled:
              return
          time.sleep(1)
 
@@ -266,12 +287,18 @@ def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int):
                  f'Error: {common_utils.format_exception(e)}') from None
 
 
+ @timeline.event
  def _wait_for_pods_to_run(namespace, context, new_nodes):
      """Wait for pods and their containers to be ready.
 
      Pods may be pulling images or may be in the process of container
      creation.
      """
+     if not new_nodes:
+         return
+
+     # Create a set of pod names we're waiting for
+     expected_pod_names = {node.metadata.name for node in new_nodes}
 
      def _check_init_containers(pod):
          # Check if any of the init containers failed
@@ -299,12 +326,25 @@ def _wait_for_pods_to_run(namespace, context, new_nodes):
                      f'{pod.metadata.name}. Error details: {msg}.')
 
      while True:
-         all_pods_running = True
-         # Iterate over each pod to check their status
-         for node in new_nodes:
-             pod = kubernetes.core_api(context).read_namespaced_pod(
-                 node.metadata.name, namespace)
+         # Get all pods in a single API call
+         cluster_name = new_nodes[0].metadata.labels[TAG_SKYPILOT_CLUSTER_NAME]
+         all_pods = kubernetes.core_api(context).list_namespaced_pod(
+             namespace,
+             label_selector=f'{TAG_SKYPILOT_CLUSTER_NAME}={cluster_name}').items
+
+         # Get the set of found pod names and check if we have all expected pods
+         found_pod_names = {pod.metadata.name for pod in all_pods}
+         missing_pods = expected_pod_names - found_pod_names
+         if missing_pods:
+             logger.info('Retrying running pods check: '
+                         f'Missing pods: {missing_pods}')
+             time.sleep(0.5)
+             continue
 
+         all_pods_running = True
+         for pod in all_pods:
+             if pod.metadata.name not in expected_pod_names:
+                 continue
              # Continue if pod and all the containers within the
              # pod are successfully created and running.
              if pod.status.phase == 'Running' and all(
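Both `_wait_for_pods_to_schedule` and `_wait_for_pods_to_run` now poll with a single `list_namespaced_pod` call filtered by the SkyPilot cluster-name label instead of one `read_namespaced_pod` call per pod, so each polling iteration costs one API request regardless of cluster size. Below is a minimal standalone sketch of that pattern with the official `kubernetes` Python client; the namespace, label key, and cluster name are placeholders, and a reachable cluster configured in `~/.kube/config` is assumed.

    # Sketch: poll pods by label selector instead of reading each pod by name.
    # Assumes the `kubernetes` package and a cluster reachable via kubeconfig;
    # the label key/values below are illustrative, not SkyPilot's real tags.
    from kubernetes import client, config

    NAMESPACE = 'default'           # placeholder namespace
    LABEL_KEY = 'skypilot-cluster'  # placeholder label key
    CLUSTER_NAME = 'my-cluster'     # placeholder cluster name


    def list_cluster_pods() -> dict:
        """Return {pod_name: phase} for one cluster's pods in a single API call."""
        config.load_kube_config()
        core_api = client.CoreV1Api()
        pods = core_api.list_namespaced_pod(
            NAMESPACE, label_selector=f'{LABEL_KEY}={CLUSTER_NAME}').items
        return {pod.metadata.name: pod.status.phase for pod in pods}


    if __name__ == '__main__':
        for name, phase in list_cluster_pods().items():
            print(name, phase)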
@@ -367,6 +407,7 @@ def _run_function_with_retries(func: Callable,
              raise
 
 
+ @timeline.event
  def pre_init(namespace: str, context: Optional[str], new_nodes: List) -> None:
      """Pre-initialization step for SkyPilot pods.
 
@@ -514,7 +555,7 @@ def pre_init(namespace: str, context: Optional[str], new_nodes: List) -> None:
          logger.info(f'{"-"*20}End: Pre-init in pod {pod_name!r} {"-"*20}')
 
      # Run pre_init in parallel across all new_nodes
-     subprocess_utils.run_in_parallel(_pre_init_thread, new_nodes, NUM_THREADS)
+     subprocess_utils.run_in_parallel(_pre_init_thread, new_nodes, _NUM_THREADS)
 
 
  def _label_pod(namespace: str, context: Optional[str], pod_name: str,
@@ -528,6 +569,7 @@ def _label_pod(namespace: str, context: Optional[str], pod_name: str,
          _request_timeout=kubernetes.API_TIMEOUT)
 
 
+ @timeline.event
  def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict,
                                          context: Optional[str]) -> Any:
      """Attempts to create a Kubernetes Pod and handle any errors.
@@ -606,6 +648,7 @@ def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict,
              raise e
 
 
+ @timeline.event
  def _create_pods(region: str, cluster_name_on_cloud: str,
                   config: common.ProvisionConfig) -> common.ProvisionRecord:
      """Create pods based on the config."""
@@ -627,7 +670,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
      terminating_pods = kubernetes_utils.filter_pods(namespace, context, tags,
                                                      ['Terminating'])
      start_time = time.time()
-     while (len(terminating_pods) > 0 and
+     while (terminating_pods and
             time.time() - start_time < _TIMEOUT_FOR_POD_TERMINATION):
          logger.debug(f'run_instances: Found {len(terminating_pods)} '
                       'terminating pods. Waiting them to finish: '
@@ -636,7 +679,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
          terminating_pods = kubernetes_utils.filter_pods(namespace, context,
                                                          tags, ['Terminating'])
 
-     if len(terminating_pods) > 0:
+     if terminating_pods:
          # If there are still terminating pods, we force delete them.
          logger.debug(f'run_instances: Found {len(terminating_pods)} '
                       'terminating pods still in terminating state after '
@@ -695,24 +738,29 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
      created_pods = {}
      logger.debug(f'run_instances: calling create_namespaced_pod '
                   f'(count={to_start_count}).')
-     for _ in range(to_start_count):
-         if head_pod_name is None:
-             pod_spec['metadata']['labels'].update(constants.HEAD_NODE_TAGS)
+
+     def _create_pod_thread(i: int):
+         pod_spec_copy = copy.deepcopy(pod_spec)
+         if head_pod_name is None and i == 0:
+             # First pod should be head if no head exists
+             pod_spec_copy['metadata']['labels'].update(constants.HEAD_NODE_TAGS)
              head_selector = head_service_selector(cluster_name_on_cloud)
-             pod_spec['metadata']['labels'].update(head_selector)
-             pod_spec['metadata']['name'] = f'{cluster_name_on_cloud}-head'
+             pod_spec_copy['metadata']['labels'].update(head_selector)
+             pod_spec_copy['metadata']['name'] = f'{cluster_name_on_cloud}-head'
          else:
-             pod_spec['metadata']['labels'].update(constants.WORKER_NODE_TAGS)
-             pod_uuid = str(uuid.uuid4())[:4]
+             # Worker pods
+             pod_spec_copy['metadata']['labels'].update(
+                 constants.WORKER_NODE_TAGS)
+             pod_uuid = str(uuid.uuid4())[:6]
              pod_name = f'{cluster_name_on_cloud}-{pod_uuid}'
-             pod_spec['metadata']['name'] = f'{pod_name}-worker'
+             pod_spec_copy['metadata']['name'] = f'{pod_name}-worker'
              # For multi-node support, we put a soft-constraint to schedule
              # worker pods on different nodes than the head pod.
              # This is not set as a hard constraint because if different nodes
              # are not available, we still want to be able to schedule worker
              # pods on larger nodes which may be able to fit multiple SkyPilot
              # "nodes".
-             pod_spec['spec']['affinity'] = {
+             pod_spec_copy['spec']['affinity'] = {
                  'podAntiAffinity': {
                      # Set as a soft constraint
                      'preferredDuringSchedulingIgnoredDuringExecution': [{
@@ -747,17 +795,22 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
                  'value': 'present',
                  'effect': 'NoSchedule'
              }
-             pod_spec['spec']['tolerations'] = [tpu_toleration]
+             pod_spec_copy['spec']['tolerations'] = [tpu_toleration]
 
-         pod = _create_namespaced_pod_with_retries(namespace, pod_spec, context)
+         return _create_namespaced_pod_with_retries(namespace, pod_spec_copy,
+                                                    context)
+
+     # Create pods in parallel
+     pods = subprocess_utils.run_in_parallel(_create_pod_thread,
+                                             range(to_start_count), _NUM_THREADS)
+
+     # Process created pods
+     for pod in pods:
          created_pods[pod.metadata.name] = pod
-         if head_pod_name is None:
+         if head_pod_name is None and pod.metadata.labels.get(
+                 constants.TAG_RAY_NODE_KIND) == 'head':
              head_pod_name = pod.metadata.name
 
-     wait_pods_dict = kubernetes_utils.filter_pods(namespace, context, tags,
-                                                   ['Pending'])
-     wait_pods = list(wait_pods_dict.values())
-
      networking_mode = network_utils.get_networking_mode(
          config.provider_config.get('networking_mode'))
      if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
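`_create_pods` now builds a deep copy of the pod spec for each pod and submits the creations through `subprocess_utils.run_in_parallel` with the cloud-specific `_NUM_THREADS` bound, rather than creating pods one at a time. The following is a rough standalone sketch of the same bounded fan-out pattern using `concurrent.futures`; `create_one`, `BASE_SPEC`, and the thread count are illustrative stand-ins, not SkyPilot APIs.

    # Sketch: fan out N independent creation calls across a bounded thread pool,
    # mirroring the parallel pod creation above. All names here are illustrative.
    import copy
    from concurrent.futures import ThreadPoolExecutor
    from typing import Any, Dict, List

    NUM_THREADS = 8  # assumed bound; the real value is derived per cloud

    BASE_SPEC: Dict[str, Any] = {'metadata': {'name': '', 'labels': {}}}


    def create_one(index: int) -> Dict[str, Any]:
        """Create one 'pod' from a deep copy of the shared spec (stub)."""
        spec = copy.deepcopy(BASE_SPEC)  # avoid mutating the shared template
        role = 'head' if index == 0 else 'worker'
        spec['metadata']['name'] = f'my-cluster-{role}-{index}'
        spec['metadata']['labels']['ray-node-type'] = role
        # A real implementation would call the Kubernetes API here.
        return spec


    def create_all(count: int) -> List[Dict[str, Any]]:
        with ThreadPoolExecutor(max_workers=NUM_THREADS) as pool:
            # map() preserves input order, so index 0 (the head) stays first.
            return list(pool.map(create_one, range(count)))


    if __name__ == '__main__':
        for pod in create_all(3):
            print(pod['metadata']['name'])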
@@ -766,52 +819,24 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
          ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump']
          jump_pod = kubernetes.core_api(context).read_namespaced_pod(
              ssh_jump_pod_name, namespace)
-         wait_pods.append(jump_pod)
+         pods.append(jump_pod)
      provision_timeout = provider_config['timeout']
 
      wait_str = ('indefinitely'
                  if provision_timeout < 0 else f'for {provision_timeout}s')
      logger.debug(f'run_instances: waiting {wait_str} for pods to schedule and '
-                  f'run: {list(wait_pods_dict.keys())}')
+                  f'run: {[pod.metadata.name for pod in pods]}')
 
      # Wait until the pods are scheduled and surface cause for error
      # if there is one
-     _wait_for_pods_to_schedule(namespace, context, wait_pods, provision_timeout)
+     _wait_for_pods_to_schedule(namespace, context, pods, provision_timeout)
      # Wait until the pods and their containers are up and running, and
      # fail early if there is an error
      logger.debug(f'run_instances: waiting for pods to be running (pulling '
-                  f'images): {list(wait_pods_dict.keys())}')
-     _wait_for_pods_to_run(namespace, context, wait_pods)
+                  f'images): {[pod.metadata.name for pod in pods]}')
+     _wait_for_pods_to_run(namespace, context, pods)
      logger.debug(f'run_instances: all pods are scheduled and running: '
-                  f'{list(wait_pods_dict.keys())}')
-
-     running_pods = kubernetes_utils.filter_pods(namespace, context, tags,
-                                                 ['Running'])
-     initialized_pods = kubernetes_utils.filter_pods(namespace, context, {
-         TAG_POD_INITIALIZED: 'true',
-         **tags
-     }, ['Running'])
-     uninitialized_pods = {
-         pod_name: pod
-         for pod_name, pod in running_pods.items()
-         if pod_name not in initialized_pods
-     }
-     if len(uninitialized_pods) > 0:
-         logger.debug(f'run_instances: Initializing {len(uninitialized_pods)} '
-                      f'pods: {list(uninitialized_pods.keys())}')
-         uninitialized_pods_list = list(uninitialized_pods.values())
-
-         # Run pre-init steps in the pod.
-         pre_init(namespace, context, uninitialized_pods_list)
-
-         for pod in uninitialized_pods.values():
-             _label_pod(namespace,
-                        context,
-                        pod.metadata.name,
-                        label={
-                            TAG_POD_INITIALIZED: 'true',
-                            **pod.metadata.labels
-                        })
+                  f'{[pod.metadata.name for pod in pods]}')
 
      assert head_pod_name is not None, 'head_instance_id should not be None'
      return common.ProvisionRecord(
@@ -854,11 +879,6 @@ def _terminate_node(namespace: str, context: Optional[str],
                      pod_name: str) -> None:
      """Terminate a pod."""
      logger.debug('terminate_instances: calling delete_namespaced_pod')
-     try:
-         kubernetes_utils.clean_zombie_ssh_jump_pod(namespace, context, pod_name)
-     except Exception as e:  # pylint: disable=broad-except
-         logger.warning('terminate_instances: Error occurred when analyzing '
-                        f'SSH Jump pod: {e}')
      try:
          kubernetes.core_api(context).delete_namespaced_service(
              pod_name, namespace, _request_timeout=config_lib.DELETION_TIMEOUT)
@@ -895,6 +915,18 @@ def terminate_instances(
      }
      pods = kubernetes_utils.filter_pods(namespace, context, tag_filters, None)
 
+     # Clean up the SSH jump pod if in use
+     networking_mode = network_utils.get_networking_mode(
+         provider_config.get('networking_mode'))
+     if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT:
+         pod_name = list(pods.keys())[0]
+         try:
+             kubernetes_utils.clean_zombie_ssh_jump_pod(namespace, context,
+                                                        pod_name)
+         except Exception as e:  # pylint: disable=broad-except
+             logger.warning('terminate_instances: Error occurred when analyzing '
+                            f'SSH Jump pod: {e}')
+
      def _is_head(pod) -> bool:
          return pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head'
 
@@ -907,7 +939,7 @@ def terminate_instances(
 
      # Run pod termination in parallel
      subprocess_utils.run_in_parallel(_terminate_pod_thread, pods.items(),
-                                      NUM_THREADS)
+                                      _NUM_THREADS)
 
 
  def get_cluster_info(
sky/provision/kubernetes/utils.py CHANGED
@@ -28,6 +28,7 @@ from sky.utils import common_utils
  from sky.utils import env_options
  from sky.utils import kubernetes_enums
  from sky.utils import schemas
+ from sky.utils import timeline
  from sky.utils import ux_utils
 
  if typing.TYPE_CHECKING:
@@ -2053,6 +2054,7 @@ def get_namespace_from_config(provider_config: Dict[str, Any]) -> str:
              get_kube_config_context_namespace(context))
 
 
+ @timeline.event
  def filter_pods(namespace: str,
                  context: Optional[str],
                  tag_filters: Dict[str, str],
sky/provision/oci/instance.py CHANGED
@@ -123,8 +123,8 @@ def run_instances(region: str, cluster_name_on_cloud: str,
      # Let's create additional new nodes (if neccessary)
      to_start_count = config.count - len(resume_instances)
      created_instances = []
+     node_config = config.node_config
      if to_start_count > 0:
-         node_config = config.node_config
          compartment = query_helper.find_compartment(region)
          vcn = query_helper.find_create_vcn_subnet(region)
 
@@ -242,10 +242,12 @@ def run_instances(region: str, cluster_name_on_cloud: str,
 
      assert head_instance_id is not None, head_instance_id
 
+     # Format: TenancyPrefix:AvailabilityDomain, e.g. bxtG:US-SANJOSE-1-AD-1
+     _, ad = str(node_config['AvailabilityDomain']).split(':', maxsplit=1)
      return common.ProvisionRecord(
          provider_name='oci',
          region=region,
-         zone=None,
+         zone=ad,
          cluster_name=cluster_name_on_cloud,
          head_instance_id=head_instance_id,
          created_instance_ids=[n['inst_id'] for n in created_instances],
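The OCI provisioner now reports the availability domain as the provision record's zone by splitting the configured `AvailabilityDomain` value on its first colon. A small worked example of that split, using the sample value from the comment in the hunk above:

    # Worked example of the availability-domain split shown above.
    availability_domain = 'bxtG:US-SANJOSE-1-AD-1'  # 'TenancyPrefix:AvailabilityDomain'
    _, ad = availability_domain.split(':', maxsplit=1)
    print(ad)  # -> 'US-SANJOSE-1-AD-1'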
sky/provision/provisioner.py CHANGED
@@ -29,6 +29,7 @@ from sky.utils import common_utils
  from sky.utils import resources_utils
  from sky.utils import rich_utils
  from sky.utils import subprocess_utils
+ from sky.utils import timeline
  from sky.utils import ux_utils
 
  # Do not use __name__ as we do not want to propagate logs to sky.provision,
@@ -343,6 +344,7 @@ def _wait_ssh_connection_indirect(ip: str,
      return True, ''
 
 
+ @timeline.event
  def wait_for_ssh(cluster_info: provision_common.ClusterInfo,
                   ssh_credentials: Dict[str, str]):
      """Wait until SSH is ready.
@@ -432,11 +434,15 @@ def _post_provision_setup(
              ux_utils.spinner_message(
                  'Launching - Waiting for SSH access',
                  provision_logging.config.log_path)) as status:
-
-         logger.debug(
-             f'\nWaiting for SSH to be available for {cluster_name!r} ...')
-         wait_for_ssh(cluster_info, ssh_credentials)
-         logger.debug(f'SSH Connection ready for {cluster_name!r}')
+         # If on Kubernetes, skip SSH check since the pods are guaranteed to be
+         # ready by the provisioner, and we use kubectl instead of SSH to run the
+         # commands and rsync on the pods. SSH will still be ready after a while
+         # for the users to SSH into the pod.
+         if cloud_name.lower() != 'kubernetes':
+             logger.debug(
+                 f'\nWaiting for SSH to be available for {cluster_name!r} ...')
+             wait_for_ssh(cluster_info, ssh_credentials)
+             logger.debug(f'SSH Connection ready for {cluster_name!r}')
          vm_str = 'Instance' if cloud_name.lower() != 'kubernetes' else 'Pod'
          plural = '' if len(cluster_info.instances) == 1 else 's'
          verb = 'is' if len(cluster_info.instances) == 1 else 'are'
@@ -496,31 +502,94 @@ def _post_provision_setup(
              **ssh_credentials)
          head_runner = runners[0]
 
-         status.update(
-             runtime_preparation_str.format(step=3, step_name='runtime'))
-         full_ray_setup = True
-         ray_port = constants.SKY_REMOTE_RAY_PORT
-         if not provision_record.is_instance_just_booted(
-                 head_instance.instance_id):
+         def is_ray_cluster_healthy(ray_status_output: str,
+                                    expected_num_nodes: int) -> bool:
+             """Parse the output of `ray status` to get #active nodes.
+
+             The output of `ray status` looks like:
+             Node status
+             ---------------------------------------------------------------
+             Active:
+              1 node_291a8b849439ad6186387c35dc76dc43f9058108f09e8b68108cf9ec
+              1 node_0945fbaaa7f0b15a19d2fd3dc48f3a1e2d7c97e4a50ca965f67acbfd
+             Pending:
+              (no pending nodes)
+             Recent failures:
+              (no failures)
+             """
+             start = ray_status_output.find('Active:')
+             end = ray_status_output.find('Pending:', start)
+             if start == -1 or end == -1:
+                 return False
+             num_active_nodes = 0
+             for line in ray_status_output[start:end].split('\n'):
+                 if line.strip() and not line.startswith('Active:'):
+                     num_active_nodes += 1
+             return num_active_nodes == expected_num_nodes
+
+         def check_ray_port_and_cluster_healthy() -> Tuple[int, bool, bool]:
+             head_ray_needs_restart = True
+             ray_cluster_healthy = False
+             ray_port = constants.SKY_REMOTE_RAY_PORT
+
              # Check if head node Ray is alive
              returncode, stdout, _ = head_runner.run(
                  instance_setup.RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND,
                  stream_logs=False,
                  require_outputs=True)
-             if returncode:
-                 logger.debug('Ray cluster on head is not up. Restarting...')
-             else:
-                 logger.debug('Ray cluster on head is up.')
+             if not returncode:
                  ray_port = common_utils.decode_payload(stdout)['ray_port']
-             full_ray_setup = bool(returncode)
+                 logger.debug(f'Ray cluster on head is up with port {ray_port}.')
+
+             head_ray_needs_restart = bool(returncode)
+             # This is a best effort check to see if the ray cluster has expected
+             # number of nodes connected.
+             ray_cluster_healthy = (not head_ray_needs_restart and
+                                    is_ray_cluster_healthy(
+                                        stdout, cluster_info.num_instances))
+             return ray_port, ray_cluster_healthy, head_ray_needs_restart
+
+         status.update(
+             runtime_preparation_str.format(step=3, step_name='runtime'))
+
+         ray_port = constants.SKY_REMOTE_RAY_PORT
+         head_ray_needs_restart = True
+         ray_cluster_healthy = False
+         if (not provision_record.is_instance_just_booted(
+                 head_instance.instance_id)):
+             # Check if head node Ray is alive
+             (ray_port, ray_cluster_healthy,
+              head_ray_needs_restart) = check_ray_port_and_cluster_healthy()
+         elif cloud_name.lower() == 'kubernetes':
+             timeout = 90  # 1.5-min maximum timeout
+             start = time.time()
+             while True:
+                 # Wait until Ray cluster is ready
+                 (ray_port, ray_cluster_healthy,
+                  head_ray_needs_restart) = check_ray_port_and_cluster_healthy()
+                 if ray_cluster_healthy:
+                     logger.debug('Ray cluster is ready. Skip head and worker '
+                                  'node ray cluster setup.')
+                     break
+                 if time.time() - start > timeout:
+                     # In most cases, the ray cluster will be ready after a few
+                     # seconds. Trigger ray start on head or worker nodes to be
+                     # safe, if the ray cluster is not ready after timeout.
+                     break
+                 logger.debug('Ray cluster is not ready yet, waiting for the '
+                              'async setup to complete...')
+                 time.sleep(1)
 
-         if full_ray_setup:
+         if head_ray_needs_restart:
              logger.debug('Starting Ray on the entire cluster.')
              instance_setup.start_ray_on_head_node(
                  cluster_name.name_on_cloud,
                  custom_resource=custom_resource,
                  cluster_info=cluster_info,
                  ssh_credentials=ssh_credentials)
+         else:
+             logger.debug('Ray cluster on head is ready. Skip starting ray '
+                          'cluster on head node.')
 
          # NOTE: We have to check all worker nodes to make sure they are all
          # healthy, otherwise we can only start Ray on newly started worker
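The new `is_ray_cluster_healthy` helper decides whether head and worker Ray setup can be skipped by counting the entries between `Active:` and `Pending:` in `ray status` output. Below is a standalone restatement of that counting logic, run against the sample output quoted in its docstring; it is for illustration only, not the function as it lives in the module.

    # Restatement of the active-node counting shown in the hunk above,
    # applied to the sample `ray status` output from its docstring.
    SAMPLE_RAY_STATUS = """Node status
    ---------------------------------------------------------------
    Active:
     1 node_291a8b849439ad6186387c35dc76dc43f9058108f09e8b68108cf9ec
     1 node_0945fbaaa7f0b15a19d2fd3dc48f3a1e2d7c97e4a50ca965f67acbfd
    Pending:
     (no pending nodes)
    Recent failures:
     (no failures)
    """


    def count_active_nodes(ray_status_output: str) -> int:
        start = ray_status_output.find('Active:')
        end = ray_status_output.find('Pending:', start)
        if start == -1 or end == -1:
            return 0
        return sum(1
                   for line in ray_status_output[start:end].split('\n')
                   if line.strip() and not line.startswith('Active:'))


    if __name__ == '__main__':
        active = count_active_nodes(SAMPLE_RAY_STATUS)
        print(active)       # -> 2
        print(active == 2)  # healthy for a 2-node cluster -> True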
@@ -531,10 +600,13 @@ def _post_provision_setup(
          #     if provision_record.is_instance_just_booted(inst.instance_id):
          #         worker_ips.append(inst.public_ip)
 
-         if cluster_info.num_instances > 1:
+         # We don't need to restart ray on worker nodes if the ray cluster is
+         # already healthy, i.e. the head node has expected number of nodes
+         # connected to the ray cluster.
+         if cluster_info.num_instances > 1 and not ray_cluster_healthy:
              instance_setup.start_ray_on_worker_nodes(
                  cluster_name.name_on_cloud,
-                 no_restart=not full_ray_setup,
+                 no_restart=not head_ray_needs_restart,
                  custom_resource=custom_resource,
                  # Pass the ray_port to worker nodes for backward compatibility
                  # as in some existing clusters the ray_port is not dumped with
@@ -543,6 +615,9 @@
                  ray_port=ray_port,
                  cluster_info=cluster_info,
                  ssh_credentials=ssh_credentials)
+         elif ray_cluster_healthy:
+             logger.debug('Ray cluster is ready. Skip starting ray cluster on '
+                          'worker nodes.')
 
          instance_setup.start_skylet_on_head_node(cluster_name.name_on_cloud,
                                                   cluster_info, ssh_credentials)
@@ -553,6 +628,7 @@ def _post_provision_setup(
      return cluster_info
 
 
+ @timeline.event
  def post_provision_runtime_setup(
          cloud_name: str, cluster_name: resources_utils.ClusterName,
          cluster_yaml: str, provision_record: provision_common.ProvisionRecord,
sky/resources.py CHANGED
@@ -1041,6 +1041,7 @@ class Resources:
      def make_deploy_variables(self, cluster_name: resources_utils.ClusterName,
                                region: clouds.Region,
                                zones: Optional[List[clouds.Zone]],
+                               num_nodes: int,
                                dryrun: bool) -> Dict[str, Optional[str]]:
          """Converts planned sky.Resources to resource variables.
 
@@ -1062,7 +1063,7 @@ class Resources:
 
          # Cloud specific variables
          cloud_specific_variables = self.cloud.make_deploy_resources_variables(
-             self, cluster_name, region, zones, dryrun)
+             self, cluster_name, region, zones, num_nodes, dryrun)
 
          # Docker run options
          docker_run_options = skypilot_config.get_nested(
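Alongside this change, every cloud's `make_deploy_resources_variables` gains a node-count parameter (the one-line additions in the `sky/clouds/*.py` files listed above), and `Resources.make_deploy_variables` forwards `num_nodes` to it. As a hedged illustration of that general pattern only (the class and method bodies below are invented, not SkyPilot's), threading a new argument through an abstract interface and its implementations looks like this:

    # Illustrative sketch of threading a new `num_nodes` argument through an
    # abstract interface and its implementations; class names are invented.
    from abc import ABC, abstractmethod
    from typing import Dict, List, Optional


    class BaseCloud(ABC):

        @abstractmethod
        def make_deploy_resources_variables(
                self, region: str, zones: Optional[List[str]],
                num_nodes: int) -> Dict[str, Optional[str]]:
            """Return template variables for deployment; now node-count aware."""


    class ExampleCloud(BaseCloud):

        def make_deploy_resources_variables(
                self, region: str, zones: Optional[List[str]],
                num_nodes: int) -> Dict[str, Optional[str]]:
            # A provider could size shared resources (e.g. a placement group or
            # a headless service) based on how many nodes are being launched.
            return {'region': region, 'num_nodes': str(num_nodes)}


    if __name__ == '__main__':
        print(ExampleCloud().make_deploy_resources_variables('us-east-1', None, 3))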