skypilot-nightly 1.0.0.dev20241028__py3-none-any.whl → 1.0.0.dev20241030__py3-none-any.whl

This diff shows the changes between two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (33)
  1. sky/__init__.py +2 -2
  2. sky/adaptors/azure.py +3 -0
  3. sky/backends/backend_utils.py +10 -133
  4. sky/backends/cloud_vm_ray_backend.py +17 -105
  5. sky/clouds/azure.py +10 -1
  6. sky/execution.py +5 -4
  7. sky/jobs/controller.py +38 -22
  8. sky/jobs/recovery_strategy.py +30 -5
  9. sky/jobs/state.py +33 -5
  10. sky/jobs/utils.py +28 -4
  11. sky/optimizer.py +11 -7
  12. sky/provision/azure/azure-config-template.json +7 -1
  13. sky/provision/azure/config.py +65 -45
  14. sky/provision/azure/instance.py +275 -70
  15. sky/provision/constants.py +7 -0
  16. sky/provision/gcp/instance.py +0 -7
  17. sky/resources.py +25 -8
  18. sky/serve/core.py +0 -2
  19. sky/serve/serve_state.py +3 -7
  20. sky/serve/serve_utils.py +2 -14
  21. sky/serve/service_spec.py +0 -28
  22. sky/setup_files/setup.py +4 -3
  23. sky/skylet/job_lib.py +37 -53
  24. sky/skylet/log_lib.py +5 -14
  25. sky/templates/azure-ray.yml.j2 +1 -0
  26. sky/utils/dag_utils.py +14 -4
  27. sky/utils/schemas.py +25 -15
  28. {skypilot_nightly-1.0.0.dev20241028.dist-info → skypilot_nightly-1.0.0.dev20241030.dist-info}/METADATA +13 -11
  29. {skypilot_nightly-1.0.0.dev20241028.dist-info → skypilot_nightly-1.0.0.dev20241030.dist-info}/RECORD +33 -33
  30. {skypilot_nightly-1.0.0.dev20241028.dist-info → skypilot_nightly-1.0.0.dev20241030.dist-info}/WHEEL +1 -1
  31. {skypilot_nightly-1.0.0.dev20241028.dist-info → skypilot_nightly-1.0.0.dev20241030.dist-info}/LICENSE +0 -0
  32. {skypilot_nightly-1.0.0.dev20241028.dist-info → skypilot_nightly-1.0.0.dev20241030.dist-info}/entry_points.txt +0 -0
  33. {skypilot_nightly-1.0.0.dev20241028.dist-info → skypilot_nightly-1.0.0.dev20241030.dist-info}/top_level.txt +0 -0
sky/__init__.py CHANGED
@@ -5,7 +5,7 @@ from typing import Optional
  import urllib.request
 
  # Replaced with the current commit when building the wheels.
- _SKYPILOT_COMMIT_SHA = 'c0c17483d1f692ad639144050f5f6fa0966e47a5'
+ _SKYPILOT_COMMIT_SHA = '9d50f192b262d5f6cc74b5b6644f3a9e3ea31f2f'
 
 
  def _get_git_commit():
@@ -35,7 +35,7 @@ def _get_git_commit():
 
 
  __commit__ = _get_git_commit()
- __version__ = '1.0.0.dev20241028'
+ __version__ = '1.0.0.dev20241030'
  __root_dir__ = os.path.dirname(os.path.abspath(__file__))
 
 
sky/adaptors/azure.py CHANGED
@@ -131,6 +131,9 @@ def get_client(name: str,
  from azure.mgmt import authorization
  return authorization.AuthorizationManagementClient(
  credential, subscription_id)
+ elif name == 'msi':
+ from azure.mgmt import msi
+ return msi.ManagedServiceIdentityClient(credential, subscription_id)
  elif name == 'graph':
  import msgraph
  return msgraph.GraphServiceClient(credential)
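
Note: the 'msi' branch added above extends the adaptor's lazy-import client factory with a Managed Service Identity client. The sketch below illustrates the same deferred-import dispatch in isolation; the helper name and the explicit credential/subscription_id parameters are assumptions for the example, not SkyPilot's actual get_client signature.

    def _get_azure_client(name: str, credential, subscription_id: str):
        # Deferred imports keep the adaptor cheap to import; each Azure SDK
        # submodule is only loaded when its client is first requested.
        if name == 'authorization':
            from azure.mgmt import authorization
            return authorization.AuthorizationManagementClient(
                credential, subscription_id)
        if name == 'msi':
            from azure.mgmt import msi  # client added in this release
            return msi.ManagedServiceIdentityClient(credential,
                                                    subscription_id)
        raise ValueError(f'Unsupported Azure client: {name!r}')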
sky/backends/backend_utils.py CHANGED
@@ -401,6 +401,8 @@ class SSHConfigHelper(object):
 
  ssh_conf_path = '~/.ssh/config'
  ssh_conf_lock_path = os.path.expanduser('~/.sky/ssh_config.lock')
+ ssh_conf_per_cluster_lock_path = os.path.expanduser(
+ '~/.sky/ssh_config_{}.lock')
  ssh_cluster_path = SKY_USER_FILE_PATH + '/ssh/{}'
 
  @classmethod
@@ -486,12 +488,6 @@ class SSHConfigHelper(object):
 
  config_path = os.path.expanduser(cls.ssh_conf_path)
 
- # For backward compatibility: before #2706, we wrote the config of SkyPilot clusters
- # directly in ~/.ssh/config. For these clusters, we remove the config in ~/.ssh/config
- # and write/overwrite the config in ~/.sky/ssh/<cluster_name> instead.
- cls._remove_stale_cluster_config_for_backward_compatibility(
- cluster_name, ip, auth_config, docker_user)
-
  if not os.path.exists(config_path):
  config = ['\n']
  with open(config_path,
@@ -560,139 +556,20 @@ class SSHConfigHelper(object):
  f.write(codegen)
 
  @classmethod
- def _remove_stale_cluster_config_for_backward_compatibility(
- cls,
- cluster_name: str,
- ip: str,
- auth_config: Dict[str, str],
- docker_user: Optional[str] = None,
- ):
- """Remove authentication information for cluster from local SSH config.
-
- If no existing host matching the provided specification is found, then
- nothing is removed.
-
- Args:
- ip: Head node's IP address.
- auth_config: read_yaml(handle.cluster_yaml)['auth']
- docker_user: If not None, use this user to ssh into the docker
- """
- username = auth_config['ssh_user']
- config_path = os.path.expanduser(cls.ssh_conf_path)
- cluster_config_path = os.path.expanduser(
- cls.ssh_cluster_path.format(cluster_name))
- if not os.path.exists(config_path):
- return
-
- with open(config_path, 'r', encoding='utf-8') as f:
- config = f.readlines()
-
- start_line_idx = None
-
- # Scan the config for the cluster name.
- for i, line in enumerate(config):
- next_line = config[i + 1] if i + 1 < len(config) else ''
- if docker_user is None:
- found = (line.strip() == f'HostName {ip}' and
- next_line.strip() == f'User {username}')
- else:
- found = (line.strip() == 'HostName localhost' and
- next_line.strip() == f'User {docker_user}')
- if found:
- # Find the line starting with ProxyCommand and contains the ip
- found = False
- for idx in range(i, len(config)):
- # Stop if we reach an empty line, which means a new host
- if not config[idx].strip():
- break
- if config[idx].strip().startswith('ProxyCommand'):
- proxy_command_line = config[idx].strip()
- if proxy_command_line.endswith(f'@{ip}'):
- found = True
- break
- if found:
- start_line_idx = i - 1
- break
-
- if start_line_idx is not None:
- # Scan for end of previous config.
- cursor = start_line_idx
- while cursor > 0 and len(config[cursor].strip()) > 0:
- cursor -= 1
- prev_end_line_idx = cursor
-
- # Scan for end of the cluster config.
- end_line_idx = None
- cursor = start_line_idx + 1
- start_line_idx -= 1 # remove auto-generated comment
- while cursor < len(config):
- if config[cursor].strip().startswith(
- '# ') or config[cursor].strip().startswith('Host '):
- end_line_idx = cursor
- break
- cursor += 1
-
- # Remove sky-generated config and update the file.
- config[prev_end_line_idx:end_line_idx] = [
- '\n'
- ] if end_line_idx is not None else []
- with open(config_path, 'w', encoding='utf-8') as f:
- f.write(''.join(config).strip())
- f.write('\n' * 2)
-
- # Delete include statement if it exists in the config.
- sky_autogen_comment = ('# Added by sky (use `sky stop/down '
- f'{cluster_name}` to remove)')
- with open(config_path, 'r', encoding='utf-8') as f:
- config = f.readlines()
-
- for i, line in enumerate(config):
- config_str = line.strip()
- if f'Include {cluster_config_path}' in config_str:
- with open(config_path, 'w', encoding='utf-8') as f:
- if i < len(config) - 1 and config[i + 1] == '\n':
- del config[i + 1]
- # Delete Include string
- del config[i]
- # Delete Sky Autogen Comment
- if i > 0 and sky_autogen_comment in config[i - 1].strip():
- del config[i - 1]
- f.write(''.join(config))
- break
- if 'Host' in config_str:
- break
-
- @classmethod
- # TODO: We can remove this after 0.6.0 and have a lock only per cluster.
- @timeline.FileLockEvent(ssh_conf_lock_path)
- def remove_cluster(
- cls,
- cluster_name: str,
- ip: str,
- auth_config: Dict[str, str],
- docker_user: Optional[str] = None,
- ):
+ def remove_cluster(cls, cluster_name: str):
  """Remove authentication information for cluster from ~/.sky/ssh/<cluster_name>.
 
- For backward compatibility also remove the config from ~/.ssh/config if it exists.
-
  If no existing host matching the provided specification is found, then
  nothing is removed.
 
  Args:
- ip: Head node's IP address.
- auth_config: read_yaml(handle.cluster_yaml)['auth']
- docker_user: If not None, use this user to ssh into the docker
+ cluster_name: Cluster name.
  """
- cluster_config_path = os.path.expanduser(
- cls.ssh_cluster_path.format(cluster_name))
- common_utils.remove_file_if_exists(cluster_config_path)
-
- # Ensures backward compatibility: before #2706, we wrote the config of SkyPilot clusters
- # directly in ~/.ssh/config. For these clusters, we should clean up the config.
- # TODO: Remove this after 0.6.0
- cls._remove_stale_cluster_config_for_backward_compatibility(
- cluster_name, ip, auth_config, docker_user)
+ with timeline.FileLockEvent(
+ cls.ssh_conf_per_cluster_lock_path.format(cluster_name)):
+ cluster_config_path = os.path.expanduser(
+ cls.ssh_cluster_path.format(cluster_name))
+ common_utils.remove_file_if_exists(cluster_config_path)
 
 
  def _replace_yaml_dicts(
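
Note: remove_cluster now needs only the cluster name; it deletes the per-cluster file under ~/.sky/ssh/ while holding a per-cluster lock, so teardowns of different clusters no longer contend on the single global ssh_config.lock. Below is a minimal sketch of that per-resource locking pattern using the filelock package (used elsewhere in this diff); the function is illustrative, not the actual SkyPilot implementation.

    import contextlib
    import os

    import filelock

    _LOCK_PATH = os.path.expanduser('~/.sky/ssh_config_{}.lock')
    _CLUSTER_SSH_CONFIG = os.path.expanduser('~/.sky/ssh/{}')


    def remove_cluster_ssh_config(cluster_name: str) -> None:
        # One lock file per cluster: removing cluster A's SSH config never
        # blocks on a concurrent removal of cluster B's.
        with filelock.FileLock(_LOCK_PATH.format(cluster_name)):
            with contextlib.suppress(FileNotFoundError):
                os.remove(_CLUSTER_SSH_CONFIG.format(cluster_name))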
@@ -867,7 +744,7 @@ def write_cluster_config(
  labels = skypilot_config.get_nested((str(cloud).lower(), 'labels'), {})
  # Deprecated: instance_tags have been replaced by labels. For backward
  # compatibility, we support them and the schema allows them only if
- # `labels` are not specified. This should be removed after 0.7.0.
+ # `labels` are not specified. This should be removed after 0.8.0.
  labels = skypilot_config.get_nested((str(cloud).lower(), 'instance_tags'),
  labels)
  # labels is a dict, which is guaranteed by the type check in
sky/backends/cloud_vm_ray_backend.py CHANGED
@@ -2118,13 +2118,8 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
  stable_internal_external_ips: Optional[List[Tuple[str,
  str]]] = None,
  stable_ssh_ports: Optional[List[int]] = None,
- cluster_info: Optional[provision_common.ClusterInfo] = None,
- # The following 2 fields are deprecated. SkyPilot new provisioner
- # API handles the TPU node creation/deletion.
- # Backward compatibility for TPU nodes created before #2943.
- # TODO (zhwu): Remove this after 0.6.0.
- tpu_create_script: Optional[str] = None,
- tpu_delete_script: Optional[str] = None) -> None:
+ cluster_info: Optional[provision_common.ClusterInfo] = None
+ ) -> None:
  self._version = self._VERSION
  self.cluster_name = cluster_name
  self.cluster_name_on_cloud = cluster_name_on_cloud
@@ -2139,12 +2134,6 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
  self.launched_nodes = launched_nodes
  self.launched_resources = launched_resources
  self.docker_user: Optional[str] = None
- # Deprecated. SkyPilot new provisioner API handles the TPU node
- # creation/deletion.
- # Backward compatibility for TPU nodes created before #2943.
- # TODO (zhwu): Remove this after 0.6.0.
- self.tpu_create_script = tpu_create_script
- self.tpu_delete_script = tpu_delete_script
 
  def __repr__(self):
  return (f'ResourceHandle('
@@ -2160,10 +2149,7 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
  f'\n\tlaunched_resources={self.launched_nodes}x '
  f'{self.launched_resources}, '
  f'\n\tdocker_user={self.docker_user},'
- f'\n\tssh_user={self.ssh_user},'
- # TODO (zhwu): Remove this after 0.6.0.
- f'\n\ttpu_create_script={self.tpu_create_script}, '
- f'\n\ttpu_delete_script={self.tpu_delete_script})')
+ f'\n\tssh_user={self.ssh_user}')
 
  def get_cluster_name(self):
  return self.cluster_name
@@ -2176,26 +2162,6 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
  return common_utils.read_yaml(self.cluster_yaml).get(
  'provider', {}).get('use_internal_ips', False)
 
- def _update_cluster_region(self):
- """Update the region in handle.launched_resources.
-
- This is for backward compatibility to handle the clusters launched
- long before. We should remove this after 0.6.0.
- """
- if self.launched_resources.region is not None:
- return
-
- config = common_utils.read_yaml(self.cluster_yaml)
- provider = config['provider']
- cloud = self.launched_resources.cloud
- if cloud.is_same_cloud(clouds.Azure()):
- region = provider['location']
- elif cloud.is_same_cloud(clouds.GCP()) or cloud.is_same_cloud(
- clouds.AWS()):
- region = provider['region']
-
- self.launched_resources = self.launched_resources.copy(region=region)
-
  def update_ssh_ports(self, max_attempts: int = 1) -> None:
  """Fetches and sets the SSH ports for the cluster nodes.
 
@@ -2567,8 +2533,6 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
  if version < 4:
  self.update_ssh_ports()
 
- self._update_cluster_region()
-
  if version < 8:
  try:
  self._update_cluster_info()
@@ -2649,8 +2613,6 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
  if record is not None:
  usage_lib.messages.usage.update_cluster_status(record['status'])
 
- # Backward compatibility: the old launched_resources without region info
- # was handled by ResourceHandle._update_cluster_region.
  assert launched_resources.region is not None, handle
 
  mismatch_str = (f'To fix: specify a new cluster name, or down the '
@@ -3213,9 +3175,19 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
  returncode = _run_setup(f'{create_script_code} && {setup_cmd}',)
  if returncode == 255:
  is_message_too_long = False
- with open(setup_log_path, 'r', encoding='utf-8') as f:
- if 'too long' in f.read():
- is_message_too_long = True
+ try:
+ with open(os.path.expanduser(setup_log_path),
+ 'r',
+ encoding='utf-8') as f:
+ if 'too long' in f.read():
+ is_message_too_long = True
+ except Exception as e: # pylint: disable=broad-except
+ # We don't crash the setup if we cannot read the log file.
+ # Instead, we should retry the setup with dumping the script
+ # to a file to be safe.
+ logger.debug('Failed to read setup log file '
+ f'{setup_log_path}: {e}')
+ is_message_too_long = True
 
  if is_message_too_long:
  # If the setup script is too long, we retry it with dumping
@@ -3585,9 +3557,6 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
  backend_utils.CLUSTER_STATUS_LOCK_PATH.format(cluster_name))
 
  try:
- # TODO(mraheja): remove pylint disabling when filelock
- # version updated
- # pylint: disable=abstract-class-instantiated
  with filelock.FileLock(
  lock_path,
  backend_utils.CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS):
@@ -4096,55 +4065,9 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
  * Removing ssh configs for the cluster;
  * Updating the local state of the cluster;
  * Removing the terminated cluster's scripts and ray yaml files.
-
- Raises:
- RuntimeError: If it fails to delete the TPU.
  """
- log_path = os.path.join(os.path.expanduser(self.log_dir),
- 'teardown.log')
- log_abs_path = os.path.abspath(log_path)
  cluster_name_on_cloud = handle.cluster_name_on_cloud
 
- # Backward compatibility for TPU nodes created before #2943. Any TPU
- # node launched before that PR have the delete script generated (and do
- # not have the tpu_node config set in its cluster yaml), so we have to
- # call the deletion script to clean up the TPU node.
- # For TPU nodes launched after the PR, deletion is done in SkyPilot's
- # new GCP provisioner API.
- # TODO (zhwu): Remove this after 0.6.0.
- if (handle.tpu_delete_script is not None and
- os.path.exists(handle.tpu_delete_script)):
- # Only call the deletion script if the cluster config does not
- # contain TPU node config. Otherwise, the deletion should
- # already be handled by the new provisioner.
- config = common_utils.read_yaml(handle.cluster_yaml)
- tpu_node_config = config['provider'].get('tpu_node')
- if tpu_node_config is None:
- with rich_utils.safe_status(
- ux_utils.spinner_message('Terminating TPU')):
- tpu_rc, tpu_stdout, tpu_stderr = log_lib.run_with_log(
- ['bash', handle.tpu_delete_script],
- log_abs_path,
- stream_logs=False,
- require_outputs=True)
- if tpu_rc != 0:
- if _TPU_NOT_FOUND_ERROR in tpu_stderr:
- logger.info('TPU not found. '
- 'It should have been deleted already.')
- elif purge:
- logger.warning(
- _TEARDOWN_PURGE_WARNING.format(
- reason='stopping/terminating TPU',
- details=tpu_stderr))
- else:
- raise RuntimeError(
- _TEARDOWN_FAILURE_MESSAGE.format(
- extra_reason='It is caused by TPU failure.',
- cluster_name=common_utils.cluster_name_in_hint(
- handle.cluster_name, cluster_name_on_cloud),
- stdout=tpu_stdout,
- stderr=tpu_stderr))
-
  if (terminate and handle.launched_resources.is_image_managed is True):
  # Delete the image when terminating a "cloned" cluster, i.e.,
  # whose image is created by SkyPilot (--clone-disk-from)
@@ -4189,11 +4112,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
  # The cluster file must exist because the cluster_yaml will only
  # be removed after the cluster entry in the database is removed.
  config = common_utils.read_yaml(handle.cluster_yaml)
- auth_config = config['auth']
- backend_utils.SSHConfigHelper.remove_cluster(handle.cluster_name,
- handle.head_ip,
- auth_config,
- handle.docker_user)
+ backend_utils.SSHConfigHelper.remove_cluster(handle.cluster_name)
 
  global_user_state.remove_cluster(handle.cluster_name,
  terminate=terminate)
@@ -4202,13 +4121,6 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
  # This function could be directly called from status refresh,
  # where we need to cleanup the cluster profile.
  metadata_utils.remove_cluster_metadata(handle.cluster_name)
- # Clean up TPU creation/deletion scripts
- # Backward compatibility for TPU nodes created before #2943.
- # TODO (zhwu): Remove this after 0.6.0.
- if handle.tpu_delete_script is not None:
- assert handle.tpu_create_script is not None
- common_utils.remove_file_if_exists(handle.tpu_create_script)
- common_utils.remove_file_if_exists(handle.tpu_delete_script)
 
  # Clean up generated config
  # No try-except is needed since Ray will fail to teardown the
sky/clouds/azure.py CHANGED
@@ -12,6 +12,7 @@ import colorama
  from sky import clouds
  from sky import exceptions
  from sky import sky_logging
+ from sky import skypilot_config
  from sky.adaptors import azure
  from sky.clouds import service_catalog
  from sky.clouds.utils import azure_utils
@@ -353,6 +354,13 @@ class Azure(clouds.Cloud):
  need_nvidia_driver_extension = (acc_dict is not None and
  'A10' in acc_dict)
 
+ # Determine resource group for deploying the instance.
+ resource_group_name = skypilot_config.get_nested(
+ ('azure', 'resource_group_vm'), None)
+ use_external_resource_group = resource_group_name is not None
+ if resource_group_name is None:
+ resource_group_name = f'{cluster_name.name_on_cloud}-{region_name}'
+
  # Setup commands to eliminate the banner and restart sshd.
  # This script will modify /etc/ssh/sshd_config and add a bash script
  # into .bashrc. The bash script will restart sshd if it has not been
@@ -409,7 +417,8 @@ class Azure(clouds.Cloud):
  'disk_tier': Azure._get_disk_type(disk_tier),
  'cloud_init_setup_commands': cloud_init_setup_commands,
  'azure_subscription_id': self.get_project_id(dryrun),
- 'resource_group': f'{cluster_name.name_on_cloud}-{region_name}',
+ 'resource_group': resource_group_name,
+ 'use_external_resource_group': use_external_resource_group,
  }
 
  # Setting disk performance tier for high disk tier.
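
Note: together, the two sky/clouds/azure.py hunks above let users deploy into a pre-existing Azure resource group via the new ('azure', 'resource_group_vm') config key; when it is unset, the per-cluster '<cluster>-<region>' group name is still derived as before. A small sketch of the selection logic follows (the helper name is hypothetical; the key path and fallback format string come from the hunk):

    from typing import Optional, Tuple


    def select_resource_group(configured: Optional[str],
                              cluster_name_on_cloud: str,
                              region_name: str) -> Tuple[str, bool]:
        # Returns (resource_group_name, use_external_resource_group).
        # A user-configured group is treated as "external" (pre-existing);
        # otherwise a per-cluster group name is derived.
        if configured is not None:
            return configured, True
        return f'{cluster_name_on_cloud}-{region_name}', False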
sky/execution.py CHANGED
@@ -171,10 +171,11 @@ def _execute(
  task = dag.tasks[0]
 
  if any(r.job_recovery is not None for r in task.resources):
- with ux_utils.print_exception_no_traceback():
- raise ValueError(
- 'Job recovery is specified in the task. To launch a '
- 'managed job, please use: sky jobs launch')
+ logger.warning(
+ f'{colorama.Style.DIM}The task has `job_recovery` specified, '
+ 'but is launched as an unmanaged job. It will be ignored.'
+ 'To enable job recovery, use managed jobs: sky jobs launch.'
+ f'{colorama.Style.RESET_ALL}')
 
  cluster_exists = False
  if cluster_name is not None:
sky/jobs/controller.py CHANGED
@@ -160,6 +160,11 @@ class JobsController:
  if task_id == 0:
  submitted_at = backend_utils.get_timestamp_from_run_timestamp(
  self._backend.run_timestamp)
+ assert task.name is not None, task
+ cluster_name = managed_job_utils.generate_managed_job_cluster_name(
+ task.name, self._job_id)
+ self._strategy_executor = recovery_strategy.StrategyExecutor.make(
+ cluster_name, self._backend, task, self._retry_until_up)
  managed_job_state.set_submitted(
  self._job_id,
  task_id,
@@ -167,15 +172,14 @@ class JobsController:
  submitted_at,
  resources_str=backend_utils.get_task_resources_str(
  task, is_managed_job=True),
+ specs={
+ 'max_restarts_on_errors':
+ self._strategy_executor.max_restarts_on_errors
+ },
  callback_func=callback_func)
  logger.info(
  f'Submitted managed job {self._job_id} (task: {task_id}, name: '
  f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}')
- assert task.name is not None, task
- cluster_name = managed_job_utils.generate_managed_job_cluster_name(
- task.name, self._job_id)
- self._strategy_executor = recovery_strategy.StrategyExecutor.make(
- cluster_name, self._backend, task, self._retry_until_up)
 
  logger.info('Started monitoring.')
  managed_job_state.set_starting(job_id=self._job_id,
@@ -283,23 +287,35 @@ class JobsController:
  failure_reason = (
  'To see the details, run: '
  f'sky jobs logs --controller {self._job_id}')
-
- managed_job_state.set_failed(
- self._job_id,
- task_id,
- failure_type=managed_job_status,
- failure_reason=failure_reason,
- end_time=end_time,
- callback_func=callback_func)
- return False
- # Although the cluster is healthy, we fail to access the
- # job status. Try to recover the job (will not restart the
- # cluster, if the cluster is healthy).
- assert job_status is None, job_status
- logger.info('Failed to fetch the job status while the '
- 'cluster is healthy. Try to recover the job '
- '(the cluster will not be restarted).')
-
+ should_restart_on_failure = (
+ self._strategy_executor.should_restart_on_failure())
+ if should_restart_on_failure:
+ max_restarts = (
+ self._strategy_executor.max_restarts_on_errors)
+ logger.info(
+ f'User program crashed '
+ f'({managed_job_status.value}). '
+ f'Retry the job as max_restarts_on_errors is '
+ f'set to {max_restarts}. '
+ f'[{self._strategy_executor.restart_cnt_on_failure}'
+ f'/{max_restarts}]')
+ else:
+ managed_job_state.set_failed(
+ self._job_id,
+ task_id,
+ failure_type=managed_job_status,
+ failure_reason=failure_reason,
+ end_time=end_time,
+ callback_func=callback_func)
+ return False
+ else:
+ # Although the cluster is healthy, we fail to access the
+ # job status. Try to recover the job (will not restart the
+ # cluster, if the cluster is healthy).
+ assert job_status is None, job_status
+ logger.info('Failed to fetch the job status while the '
+ 'cluster is healthy. Try to recover the job '
+ '(the cluster will not be restarted).')
  # When the handle is None, the cluster should be cleaned up already.
  if handle is not None:
  resources = handle.launched_resources
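
Note: the controller hunk above gates job failure on a restart budget. When the user program crashes, the job is only marked FAILED after should_restart_on_failure() reports that the max_restarts_on_errors budget is exhausted; otherwise the controller logs the retry count and falls through to the normal recovery path. A condensed, hypothetical sketch of that control flow (simplified names, not the actual controller code):

    def handle_user_program_failure(executor, set_failed, recover) -> bool:
        # Returns False when the managed job is finished (marked FAILED).
        if executor.should_restart_on_failure():
            # Still within the restart budget: re-run the task via the usual
            # recovery path instead of terminating the managed job.
            recover()
            return True
        set_failed()
        return False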
sky/jobs/recovery_strategy.py CHANGED
@@ -66,7 +66,8 @@ class StrategyExecutor:
  RETRY_INIT_GAP_SECONDS = 60
 
  def __init__(self, cluster_name: str, backend: 'backends.Backend',
- task: 'task_lib.Task', retry_until_up: bool) -> None:
+ task: 'task_lib.Task', retry_until_up: bool,
+ max_restarts_on_errors: int) -> None:
  """Initialize the strategy executor.
 
  Args:
@@ -82,6 +83,8 @@ class StrategyExecutor:
  self.cluster_name = cluster_name
  self.backend = backend
  self.retry_until_up = retry_until_up
+ self.max_restarts_on_errors = max_restarts_on_errors
+ self.restart_cnt_on_failure = 0
 
  def __init_subclass__(cls, name: str, default: bool = False):
  RECOVERY_STRATEGIES[name] = cls
@@ -109,8 +112,17 @@ class StrategyExecutor:
  # set the new_task_resources to be the same type (list or set) as the
  # original task.resources
  task.set_resources(type(task.resources)(new_resources_list))
- return RECOVERY_STRATEGIES[job_recovery](cluster_name, backend, task,
- retry_until_up)
+ if isinstance(job_recovery, dict):
+ job_recovery_name = job_recovery.pop('strategy',
+ DEFAULT_RECOVERY_STRATEGY)
+ max_restarts_on_errors = job_recovery.pop('max_restarts_on_errors',
+ 0)
+ else:
+ job_recovery_name = job_recovery
+ max_restarts_on_errors = 0
+ return RECOVERY_STRATEGIES[job_recovery_name](cluster_name, backend,
+ task, retry_until_up,
+ max_restarts_on_errors)
 
  def launch(self) -> float:
  """Launch the cluster for the first time.
@@ -368,6 +380,17 @@ class StrategyExecutor:
  f'{gap_seconds:.1f} seconds.')
  time.sleep(gap_seconds)
 
+ def should_restart_on_failure(self) -> bool:
+ """Increments counter & checks if job should be restarted on a failure.
+
+ Returns:
+ True if the job should be restarted, otherwise False.
+ """
+ self.restart_cnt_on_failure += 1
+ if self.restart_cnt_on_failure > self.max_restarts_on_errors:
+ return False
+ return True
+
 
  class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER',
  default=False):
@@ -376,8 +399,10 @@ class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER',
  _MAX_RETRY_CNT = 240 # Retry for 4 hours.
 
  def __init__(self, cluster_name: str, backend: 'backends.Backend',
- task: 'task_lib.Task', retry_until_up: bool) -> None:
- super().__init__(cluster_name, backend, task, retry_until_up)
+ task: 'task_lib.Task', retry_until_up: bool,
+ max_restarts_on_errors: int) -> None:
+ super().__init__(cluster_name, backend, task, retry_until_up,
+ max_restarts_on_errors)
  # Note down the cloud/region of the launched cluster, so that we can
  # first retry in the same cloud/region. (Inside recover() we may not
  # rely on cluster handle, as it can be None if the cluster is
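
Note: with the recovery_strategy.py changes above, job_recovery may be a bare strategy name or a dict carrying 'strategy' and 'max_restarts_on_errors', and every StrategyExecutor now tracks a restart counter. The self-contained sketch below re-implements only the counter semantics of should_restart_on_failure to show the budget behavior; it is illustrative, not SkyPilot code.

    class RestartBudget:
        """Illustrative re-implementation of the counter added above."""

        def __init__(self, max_restarts_on_errors: int) -> None:
            self.max_restarts_on_errors = max_restarts_on_errors
            self.restart_cnt_on_failure = 0

        def should_restart_on_failure(self) -> bool:
            # Each user-program failure consumes one unit of the budget.
            self.restart_cnt_on_failure += 1
            return self.restart_cnt_on_failure <= self.max_restarts_on_errors


    budget = RestartBudget(max_restarts_on_errors=2)
    print([budget.should_restart_on_failure() for _ in range(3)])
    # [True, True, False]: the third failure exhausts the budget, so the
    # controller would then mark the managed job as FAILED.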