skypilot-nightly 1.0.0.dev20241121__py3-none-any.whl → 1.0.0.dev20241122__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
sky/__init__.py CHANGED
@@ -5,7 +5,7 @@ from typing import Optional
  import urllib.request
 
  # Replaced with the current commit when building the wheels.
- _SKYPILOT_COMMIT_SHA = '627be7298f987ddf6ada7f3c0cf849edde97240a'
+ _SKYPILOT_COMMIT_SHA = '204d979fedece9b7b789dcd2610d1ebdbc8d1fc5'
 
 
  def _get_git_commit():
@@ -35,7 +35,7 @@ def _get_git_commit():
 
 
  __commit__ = _get_git_commit()
- __version__ = '1.0.0.dev20241121'
+ __version__ = '1.0.0.dev20241122'
  __root_dir__ = os.path.dirname(os.path.abspath(__file__))
 
 
@@ -683,7 +683,7 @@ def write_cluster_config(
          resources_utils.ClusterName(
              cluster_name,
              cluster_name_on_cloud,
-         ), region, zones, dryrun)
+         ), region, zones, num_nodes, dryrun)
      config_dict = {}
 
      specific_reservations = set(
@@ -844,7 +844,11 @@ def write_cluster_config(
              '{sky_wheel_hash}',
              wheel_hash).replace('{cloud}',
                                  str(cloud).lower())),
-
+         'skypilot_wheel_installation_commands':
+             constants.SKYPILOT_WHEEL_INSTALLATION_COMMANDS.replace(
+                 '{sky_wheel_hash}',
+                 wheel_hash).replace('{cloud}',
+                                     str(cloud).lower()),
          # Port of Ray (GCS server).
          # Ray's default port 6379 is conflicted with Redis.
          'ray_port': constants.SKY_REMOTE_RAY_PORT,
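Editor's note: both wheel-related variables are plain string templates whose two placeholders are filled before the cluster YAML is rendered. A minimal sketch of that substitution (the template text and hash below are made up; only the '{sky_wheel_hash}' / '{cloud}' placeholders and the .replace() chain come from the hunk above):

    # Hypothetical template and hash, for illustration only.
    template = 'pip install ~/.sky/wheels/{sky_wheel_hash}/skypilot-{cloud}.whl'
    wheel_hash = 'abc123def456'
    cloud = 'AWS'

    command = template.replace('{sky_wheel_hash}',
                               wheel_hash).replace('{cloud}', str(cloud).lower())
    print(command)
    # pip install ~/.sky/wheels/abc123def456/skypilot-aws.whl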
@@ -1191,18 +1195,18 @@ def ssh_credential_from_yaml(
 
 
  def parallel_data_transfer_to_nodes(
-     runners: List[command_runner.CommandRunner],
-     source: Optional[str],
-     target: str,
-     cmd: Optional[str],
-     run_rsync: bool,
-     *,
-     action_message: str,
-     # Advanced options.
-     log_path: str = os.devnull,
-     stream_logs: bool = False,
-     source_bashrc: bool = False,
- ):
+         runners: List[command_runner.CommandRunner],
+         source: Optional[str],
+         target: str,
+         cmd: Optional[str],
+         run_rsync: bool,
+         *,
+         action_message: str,
+         # Advanced options.
+         log_path: str = os.devnull,
+         stream_logs: bool = False,
+         source_bashrc: bool = False,
+         num_threads: Optional[int] = None):
      """Runs a command on all nodes and optionally runs rsync from src->dst.
 
      Args:
@@ -1214,6 +1218,7 @@ def parallel_data_transfer_to_nodes(
          log_path: str; Path to the log file
          stream_logs: bool; Whether to stream logs to stdout
          source_bashrc: bool; Source bashrc before running the command.
+         num_threads: Optional[int]; Number of threads to use.
      """
      style = colorama.Style
 
@@ -1254,7 +1259,7 @@ def parallel_data_transfer_to_nodes(
      message = (f' {style.DIM}{action_message} (to {num_nodes} node{plural})'
                 f': {origin_source} -> {target}{style.RESET_ALL}')
      logger.info(message)
-     subprocess_utils.run_in_parallel(_sync_node, runners)
+     subprocess_utils.run_in_parallel(_sync_node, runners, num_threads)
 
 
  def check_local_gpus() -> bool:
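Editor's note: the new keyword simply caps the thread pool used by run_in_parallel; passing None keeps the previous behavior. A hedged usage sketch (the runner list, provider string, and module path are assumptions; the helper names are the ones added in this diff):

    from typing import List

    from sky.backends import backend_utils  # module path assumed
    from sky.utils import command_runner
    from sky.utils import subprocess_utils


    def sync_workdir(runners: List[command_runner.CommandRunner]) -> None:
        # Hypothetical caller: bound the rsync fan-out by a cloud-aware limit.
        num_threads = subprocess_utils.get_parallel_threads('kubernetes')
        backend_utils.parallel_data_transfer_to_nodes(
            runners,
            source='~/my_workdir',
            target='~/sky_workdir',
            cmd=None,
            run_rsync=True,
            action_message='Syncing',
            num_threads=num_threads,  # new; None keeps the old behavior
        )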
@@ -1535,7 +1535,7 @@ class RetryingVmProvisioner(object):
                  to_provision,
                  resources_utils.ClusterName(
                      cluster_name, handle.cluster_name_on_cloud),
-                 region, zones))
+                 region, zones, num_nodes))
              config_dict['provision_record'] = provision_record
              config_dict['resources_vars'] = resources_vars
              config_dict['handle'] = handle
@@ -3093,9 +3093,12 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
                      f'{workdir} -> {SKY_REMOTE_WORKDIR}{style.RESET_ALL}')
          os.makedirs(os.path.expanduser(self.log_dir), exist_ok=True)
          os.system(f'touch {log_path}')
+         num_threads = subprocess_utils.get_parallel_threads(
+             str(handle.launched_resources.cloud))
          with rich_utils.safe_status(
                  ux_utils.spinner_message('Syncing workdir', log_path)):
-             subprocess_utils.run_in_parallel(_sync_workdir_node, runners)
+             subprocess_utils.run_in_parallel(_sync_workdir_node, runners,
+                                              num_threads)
          logger.info(ux_utils.finishing_message('Workdir synced.', log_path))
 
      def _sync_file_mounts(
@@ -4423,6 +4426,8 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
          start = time.time()
          runners = handle.get_command_runners()
          log_path = os.path.join(self.log_dir, 'file_mounts.log')
+         num_threads = subprocess_utils.get_max_workers_for_file_mounts(
+             file_mounts, str(handle.launched_resources.cloud))
 
          # Check the files and warn
          for dst, src in file_mounts.items():
@@ -4484,6 +4489,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
                      action_message='Syncing',
                      log_path=log_path,
                      stream_logs=False,
+                     num_threads=num_threads,
                  )
                  continue
 
@@ -4520,6 +4526,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
              # Need to source bashrc, as the cloud specific CLI or SDK may
              # require PATH in bashrc.
              source_bashrc=True,
+             num_threads=num_threads,
          )
          # (2) Run the commands to create symlinks on all the nodes.
          symlink_command = ' && '.join(symlink_commands)
@@ -4538,7 +4545,8 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
                      'Failed to create symlinks. The target destination '
                      f'may already exist. Log: {log_path}')
 
-         subprocess_utils.run_in_parallel(_symlink_node, runners)
+         subprocess_utils.run_in_parallel(_symlink_node, runners,
+                                          num_threads)
          end = time.time()
          logger.debug(f'File mount sync took {end - start} seconds.')
          logger.info(ux_utils.finishing_message('Files synced.', log_path))
@@ -4567,6 +4575,8 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
              return
          start = time.time()
          runners = handle.get_command_runners()
+         num_threads = subprocess_utils.get_parallel_threads(
+             str(handle.launched_resources.cloud))
          log_path = os.path.join(self.log_dir, 'storage_mounts.log')
 
          plural = 's' if len(storage_mounts) > 1 else ''
@@ -4605,6 +4615,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
                      # Need to source bashrc, as the cloud specific CLI or SDK
                      # may require PATH in bashrc.
                      source_bashrc=True,
+                     num_threads=num_threads,
                  )
              except exceptions.CommandError as e:
                  if e.returncode == exceptions.MOUNT_PATH_NON_EMPTY_CODE:
sky/clouds/aws.py CHANGED
@@ -401,6 +401,7 @@ class AWS(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False) -> Dict[str, Any]:
      del dryrun  # unused
      assert zones is not None, (region, zones)
sky/clouds/azure.py CHANGED
@@ -302,6 +302,7 @@ class Azure(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False) -> Dict[str, Any]:
      assert zones is None, ('Azure does not support zones', zones)
 
sky/clouds/cloud.py CHANGED
@@ -283,6 +283,7 @@ class Cloud:
          cluster_name: resources_utils.ClusterName,
          region: 'Region',
          zones: Optional[List['Zone']],
+         num_nodes: int,
          dryrun: bool = False,
      ) -> Dict[str, Optional[str]]:
      """Converts planned sky.Resources to cloud-specific resource variables.
sky/clouds/cudo.py CHANGED
@@ -196,6 +196,7 @@ class Cudo(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False,
      ) -> Dict[str, Optional[str]]:
      del zones, cluster_name  # unused
sky/clouds/fluidstack.py CHANGED
@@ -176,6 +176,7 @@ class Fluidstack(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: clouds.Region,
          zones: Optional[List[clouds.Zone]],
+         num_nodes: int,
          dryrun: bool = False,
      ) -> Dict[str, Optional[str]]:
 
sky/clouds/gcp.py CHANGED
@@ -417,6 +417,7 @@ class GCP(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False) -> Dict[str, Optional[str]]:
      assert zones is not None, (region, zones)
 
sky/clouds/ibm.py CHANGED
@@ -170,6 +170,7 @@ class IBM(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False,
      ) -> Dict[str, Optional[str]]:
      """Converts planned sky.Resources to cloud-specific resource variables.
sky/clouds/kubernetes.py CHANGED
@@ -10,8 +10,10 @@ from sky import sky_logging
  from sky import skypilot_config
  from sky.adaptors import kubernetes
  from sky.clouds import service_catalog
+ from sky.provision import instance_setup
  from sky.provision.kubernetes import network_utils
  from sky.provision.kubernetes import utils as kubernetes_utils
+ from sky.skylet import constants
  from sky.utils import common_utils
  from sky.utils import resources_utils
  from sky.utils import schemas
@@ -311,12 +313,34 @@ class Kubernetes(clouds.Cloud):
          # we don't have a notion of disk size in Kubernetes.
          return 0
 
+     @staticmethod
+     def _calculate_provision_timeout(num_nodes: int) -> int:
+         """Calculate provision timeout based on number of nodes.
+
+         The timeout scales linearly with the number of nodes to account for
+         scheduling overhead, but is capped to avoid excessive waiting.
+
+         Args:
+             num_nodes: Number of nodes being provisioned
+
+         Returns:
+             Timeout in seconds
+         """
+         base_timeout = 10  # Base timeout for single node
+         per_node_timeout = 0.2  # Additional seconds per node
+         max_timeout = 60  # Cap at 1 minute
+
+         return int(
+             min(base_timeout + (per_node_timeout * (num_nodes - 1)),
+                 max_timeout))
+
      def make_deploy_resources_variables(
              self,
              resources: 'resources_lib.Resources',
              cluster_name: resources_utils.ClusterName,
              region: Optional['clouds.Region'],
              zones: Optional[List['clouds.Zone']],
+             num_nodes: int,
              dryrun: bool = False) -> Dict[str, Optional[str]]:
          del cluster_name, zones, dryrun  # Unused.
          if region is None:
@@ -413,12 +437,24 @@ class Kubernetes(clouds.Cloud):
          # Larger timeout may be required for autoscaling clusters, since
          # autoscaler may take some time to provision new nodes.
          # Note that this timeout includes time taken by the Kubernetes scheduler
-         # itself, which can be upto 2-3 seconds.
-         # For non-autoscaling clusters, we conservatively set this to 10s.
+         # itself, which can be upto 2-3 seconds, and up to 10-15 seconds when
+         # scheduling 100s of pods.
+         # We use a linear scaling formula to determine the timeout based on the
+         # number of nodes.
+
+         timeout = self._calculate_provision_timeout(num_nodes)
          timeout = skypilot_config.get_nested(
              ('kubernetes', 'provision_timeout'),
-             10,
+             timeout,
              override_configs=resources.cluster_config_overrides)
+         # We specify object-store-memory to be 500MB to avoid taking up too
+         # much memory on the head node. 'num-cpus' should be set to limit
+         # the CPU usage on the head pod, otherwise the ray cluster will use the
+         # CPU resources on the node instead within the pod.
+         custom_ray_options = {
+             'object-store-memory': 500000000,
+             'num-cpus': str(int(cpus)),
+         }
          deploy_vars = {
              'instance_type': resources.instance_type,
              'custom_resources': custom_resources,
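Editor's note: as a quick sanity check of the linear formula above and its 60-second cap (the user-facing kubernetes.provision_timeout config still overrides whatever the formula produces):

    # Spot-check of _calculate_provision_timeout's formula:
    #   timeout = min(10 + 0.2 * (num_nodes - 1), 60)
    for num_nodes in (1, 10, 100, 300):
        timeout = int(min(10 + 0.2 * (num_nodes - 1), 60))
        print(f'{num_nodes} node(s) -> {timeout}s')
    # 1 -> 10s, 10 -> 11s, 100 -> 29s, 300 -> 60s (capped)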
@@ -445,6 +481,12 @@ class Kubernetes(clouds.Cloud):
              'k8s_topology_label_value': k8s_topology_label_value,
              'k8s_resource_key': k8s_resource_key,
              'image_id': image_id,
+             'ray_installation_commands': constants.RAY_INSTALLATION_COMMANDS,
+             'ray_head_start_command': instance_setup.ray_head_start_command(
+                 custom_resources, custom_ray_options),
+             'skypilot_ray_port': constants.SKY_REMOTE_RAY_PORT,
+             'ray_worker_start_command': instance_setup.ray_worker_start_command(
+                 custom_resources, custom_ray_options, no_restart=False),
          }
 
          # Add kubecontext if it is set. It may be None if SkyPilot is running
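Editor's note: the two new deploy variables are just the strings returned by the refactored helpers in instance_setup (shown later in this diff), generated once at deploy time. A hedged illustration with made-up resource values; only the function names and signatures come from this diff:

    from sky.provision import instance_setup

    custom_resources = '{"A100": 8}'  # hypothetical custom-resources JSON
    custom_ray_options = {
        'object-store-memory': 500000000,
        'num-cpus': '8',
    }

    head_cmd = instance_setup.ray_head_start_command(custom_resources,
                                                     custom_ray_options)
    worker_cmd = instance_setup.ray_worker_start_command(custom_resources,
                                                         custom_ray_options,
                                                         no_restart=False)
    # Both are single shell strings ('ray start --head ...' / 'ray start ...')
    # that the Kubernetes provisioning template presumably embeds; the worker
    # command resolves ${SKYPILOT_RAY_HEAD_IP} and ${SKYPILOT_RAY_PORT} at run time.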
@@ -157,6 +157,7 @@ class Lambda(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False) -> Dict[str, Optional[str]]:
      del cluster_name, dryrun  # Unused.
      assert zones is None, 'Lambda does not support zones.'
sky/clouds/oci.py CHANGED
@@ -208,6 +208,7 @@ class OCI(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: Optional['clouds.Region'],
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False) -> Dict[str, Optional[str]]:
      del cluster_name, dryrun  # Unused.
      assert region is not None, resources
sky/clouds/paperspace.py CHANGED
@@ -175,6 +175,7 @@ class Paperspace(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False) -> Dict[str, Optional[str]]:
      del zones, dryrun, cluster_name
 
sky/clouds/runpod.py CHANGED
@@ -160,6 +160,7 @@ class RunPod(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False) -> Dict[str, Optional[str]]:
      del zones, dryrun, cluster_name  # unused
 
sky/clouds/scp.py CHANGED
@@ -181,6 +181,7 @@ class SCP(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False) -> Dict[str, Optional[str]]:
      del cluster_name, dryrun  # Unused.
      assert zones is None, 'SCP does not support zones.'
sky/clouds/vsphere.py CHANGED
@@ -173,6 +173,7 @@ class Vsphere(clouds.Cloud):
          cluster_name: resources_utils.ClusterName,
          region: 'clouds.Region',
          zones: Optional[List['clouds.Zone']],
+         num_nodes: int,
          dryrun: bool = False,
      ) -> Dict[str, Optional[str]]:
      # TODO get image id here.
@@ -4,7 +4,6 @@ import functools
  import hashlib
  import json
  import os
- import resource
  import time
  from typing import Any, Callable, Dict, List, Optional, Tuple
 
@@ -20,6 +19,7 @@ from sky.utils import accelerator_registry
  from sky.utils import command_runner
  from sky.utils import common_utils
  from sky.utils import subprocess_utils
+ from sky.utils import timeline
  from sky.utils import ux_utils
 
  logger = sky_logging.init_logger(__name__)
@@ -115,7 +115,8 @@ def _parallel_ssh_with_cache(func,
      if max_workers is None:
          # Not using the default value of `max_workers` in ThreadPoolExecutor,
          # as 32 is too large for some machines.
-         max_workers = subprocess_utils.get_parallel_threads()
+         max_workers = subprocess_utils.get_parallel_threads(
+             cluster_info.provider_name)
      with futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
          results = []
          runners = provision.get_command_runners(cluster_info.provider_name,
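Editor's note: get_parallel_threads now takes the provider name, so the parallelism cap can differ per cloud; its implementation lives in sky/utils/subprocess_utils.py and is not part of this diff. A rough sketch of the shape such a helper could take (the numbers and the per-cloud table are assumptions, not the shipped code):

    import multiprocessing
    from typing import Optional


    def get_parallel_threads(cloud_str: Optional[str] = None) -> int:
        """Assumed sketch: CPU-derived default with an optional per-cloud override."""
        # Hypothetical override table; the real values are not shown in this diff.
        per_cloud_overrides = {'kubernetes': 4}
        default = max(4, multiprocessing.cpu_count() - 1)
        if cloud_str is not None:
            return per_cloud_overrides.get(cloud_str.lower(), default)
        return default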
@@ -170,6 +171,7 @@ def initialize_docker(cluster_name: str, docker_config: Dict[str, Any],
 
 
  @common.log_function_start_end
+ @timeline.event
  def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str],
                               cluster_info: common.ClusterInfo,
                               ssh_credentials: Dict[str, Any]) -> None:
@@ -245,20 +247,9 @@ def _ray_gpu_options(custom_resource: str) -> str:
      return f' --num-gpus={acc_count}'
 
 
- @common.log_function_start_end
- @_auto_retry()
- def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
-                            cluster_info: common.ClusterInfo,
-                            ssh_credentials: Dict[str, Any]) -> None:
-     """Start Ray on the head node."""
-     runners = provision.get_command_runners(cluster_info.provider_name,
-                                             cluster_info, **ssh_credentials)
-     head_runner = runners[0]
-     assert cluster_info.head_instance_id is not None, (cluster_name,
-                                                        cluster_info)
-
-     # Log the head node's output to the provision.log
-     log_path_abs = str(provision_logging.get_log_path())
+ def ray_head_start_command(custom_resource: Optional[str],
+                            custom_ray_options: Optional[Dict[str, Any]]) -> str:
+     """Returns the command to start Ray on the head node."""
      ray_options = (
          # --disable-usage-stats in `ray start` saves 10 seconds of idle wait.
          f'--disable-usage-stats '
@@ -270,11 +261,10 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
      if custom_resource:
          ray_options += f' --resources=\'{custom_resource}\''
          ray_options += _ray_gpu_options(custom_resource)
-
-     if cluster_info.custom_ray_options:
-         if 'use_external_ip' in cluster_info.custom_ray_options:
-             cluster_info.custom_ray_options.pop('use_external_ip')
-         for key, value in cluster_info.custom_ray_options.items():
+     if custom_ray_options:
+         if 'use_external_ip' in custom_ray_options:
+             custom_ray_options.pop('use_external_ip')
+         for key, value in custom_ray_options.items():
              ray_options += f' --{key}={value}'
 
      cmd = (
@@ -297,6 +287,62 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
          'RAY_worker_maximum_startup_concurrency=$(( 3 * $(nproc --all) )) '
          f'{constants.SKY_RAY_CMD} start --head {ray_options} || exit 1;' +
          _RAY_PRLIMIT + _DUMP_RAY_PORTS + RAY_HEAD_WAIT_INITIALIZED_COMMAND)
+     return cmd
+
+
+ def ray_worker_start_command(custom_resource: Optional[str],
+                              custom_ray_options: Optional[Dict[str, Any]],
+                              no_restart: bool) -> str:
+     """Returns the command to start Ray on the worker node."""
+     # We need to use the ray port in the env variable, because the head node
+     # determines the port to be used for the worker node.
+     ray_options = ('--address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT} '
+                    '--object-manager-port=8076')
+
+     if custom_resource:
+         ray_options += f' --resources=\'{custom_resource}\''
+         ray_options += _ray_gpu_options(custom_resource)
+
+     if custom_ray_options:
+         for key, value in custom_ray_options.items():
+             ray_options += f' --{key}={value}'
+
+     cmd = (
+         'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 '
+         f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || '
+         'exit 1;' + _RAY_PRLIMIT)
+     if no_restart:
+         # We do not use ray status to check whether ray is running, because
+         # on worker node, if the user started their own ray cluster, ray status
+         # will return 0, i.e., we don't know skypilot's ray cluster is running.
+         # Instead, we check whether the raylet process is running on gcs address
+         # that is connected to the head with the correct port.
+         cmd = (
+             f'ps aux | grep "ray/raylet/raylet" | '
+             'grep "gcs-address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT}" '
+             f'|| {{ {cmd} }}')
+     else:
+         cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd
+     return cmd
+
+
+ @common.log_function_start_end
+ @_auto_retry()
+ @timeline.event
+ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str],
+                            cluster_info: common.ClusterInfo,
+                            ssh_credentials: Dict[str, Any]) -> None:
+     """Start Ray on the head node."""
+     runners = provision.get_command_runners(cluster_info.provider_name,
+                                             cluster_info, **ssh_credentials)
+     head_runner = runners[0]
+     assert cluster_info.head_instance_id is not None, (cluster_name,
+                                                        cluster_info)
+
+     # Log the head node's output to the provision.log
+     log_path_abs = str(provision_logging.get_log_path())
+     cmd = ray_head_start_command(custom_resource,
+                                  cluster_info.custom_ray_options)
      logger.info(f'Running command on head node: {cmd}')
      # TODO(zhwu): add the output to log files.
      returncode, stdout, stderr = head_runner.run(
@@ -316,6 +362,7 @@
 
  @common.log_function_start_end
  @_auto_retry()
+ @timeline.event
  def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
                                custom_resource: Optional[str], ray_port: int,
                                cluster_info: common.ClusterInfo,
@@ -350,42 +397,17 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
      head_ip = (head_instance.internal_ip
                 if not use_external_ip else head_instance.external_ip)
 
-     ray_options = (f'--address={head_ip}:{constants.SKY_REMOTE_RAY_PORT} '
-                    f'--object-manager-port=8076')
-
-     if custom_resource:
-         ray_options += f' --resources=\'{custom_resource}\''
-         ray_options += _ray_gpu_options(custom_resource)
-
-     if cluster_info.custom_ray_options:
-         for key, value in cluster_info.custom_ray_options.items():
-             ray_options += f' --{key}={value}'
+     ray_cmd = ray_worker_start_command(custom_resource,
+                                        cluster_info.custom_ray_options,
+                                        no_restart)
 
-     # Unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY, see the comment in
-     # `start_ray_on_head_node`.
-     cmd = (
-         'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 '
-         f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || '
-         'exit 1;' + _RAY_PRLIMIT)
-     if no_restart:
-         # We do not use ray status to check whether ray is running, because
-         # on worker node, if the user started their own ray cluster, ray status
-         # will return 0, i.e., we don't know skypilot's ray cluster is running.
-         # Instead, we check whether the raylet process is running on gcs address
-         # that is connected to the head with the correct port.
-         cmd = (f'RAY_PORT={ray_port}; ps aux | grep "ray/raylet/raylet" | '
-                f'grep "gcs-address={head_ip}:${{RAY_PORT}}" || '
-                f'{{ {cmd} }}')
-     else:
-         cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd
+     cmd = (f'export SKYPILOT_RAY_HEAD_IP="{head_ip}"; '
+            f'export SKYPILOT_RAY_PORT={ray_port}; ' + ray_cmd)
 
      logger.info(f'Running command on worker nodes: {cmd}')
 
      def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner,
                                                 str]):
-         # for cmd in config_from_yaml['worker_start_ray_commands']:
-         #     cmd = cmd.replace('$RAY_HEAD_IP', ip_list[0][0])
-         #     runner.run(cmd)
          runner, instance_id = runner_and_id
          log_dir = metadata_utils.get_instance_log_dir(cluster_name, instance_id)
          log_path_abs = str(log_dir / ('ray_cluster' + '.log'))
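Editor's note: because the worker command now reads the head address from environment variables instead of baking it in, the same string can be reused for every worker. A small illustration of the wrapping done above, with made-up values (6380 is assumed for constants.SKY_REMOTE_RAY_PORT):

    # Made-up head IP/port; the export prefix mirrors the hunk above.
    head_ip = '10.0.0.4'
    ray_port = 6380
    ray_cmd = ('ray start --disable-usage-stats '
               '--address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT} '
               '--object-manager-port=8076')  # simplified worker command

    cmd = (f'export SKYPILOT_RAY_HEAD_IP="{head_ip}"; '
           f'export SKYPILOT_RAY_PORT={ray_port}; ' + ray_cmd)
    print(cmd)
    # export SKYPILOT_RAY_HEAD_IP="10.0.0.4"; export SKYPILOT_RAY_PORT=6380; ray start ...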
@@ -398,8 +420,10 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool,
              # by ray will have the correct PATH.
              source_bashrc=True)
 
+     num_threads = subprocess_utils.get_parallel_threads(
+         cluster_info.provider_name)
      results = subprocess_utils.run_in_parallel(
-         _setup_ray_worker, list(zip(worker_runners, cache_ids)))
+         _setup_ray_worker, list(zip(worker_runners, cache_ids)), num_threads)
      for returncode, stdout, stderr in results:
          if returncode:
              with ux_utils.print_exception_no_traceback():
@@ -412,6 +436,7 @@
 
  @common.log_function_start_end
  @_auto_retry()
+ @timeline.event
  def start_skylet_on_head_node(cluster_name: str,
                                cluster_info: common.ClusterInfo,
                                ssh_credentials: Dict[str, Any]) -> None:
@@ -473,28 +498,8 @@ def _internal_file_mounts(file_mounts: Dict,
      )
 
 
- def _max_workers_for_file_mounts(common_file_mounts: Dict[str, str]) -> int:
-     fd_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
-
-     fd_per_rsync = 5
-     for src in common_file_mounts.values():
-         if os.path.isdir(src):
-             # Assume that each file/folder under src takes 5 file descriptors
-             # on average.
-             fd_per_rsync = max(fd_per_rsync, len(os.listdir(src)) * 5)
-
-     # Reserve some file descriptors for the system and other processes
-     fd_reserve = 100
-
-     max_workers = (fd_limit - fd_reserve) // fd_per_rsync
-     # At least 1 worker, and avoid too many workers overloading the system.
-     max_workers = min(max(max_workers, 1),
-                       subprocess_utils.get_parallel_threads())
-     logger.debug(f'Using {max_workers} workers for file mounts.')
-     return max_workers
-
-
  @common.log_function_start_end
+ @timeline.event
  def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str],
                           cluster_info: common.ClusterInfo,
                           ssh_credentials: Dict[str, str]) -> None:
@@ -515,4 +520,5 @@ def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str],
          digest=None,
          cluster_info=cluster_info,
          ssh_credentials=ssh_credentials,
-         max_workers=_max_workers_for_file_mounts(common_file_mounts))
+         max_workers=subprocess_utils.get_max_workers_for_file_mounts(
+             common_file_mounts, cluster_info.provider_name))
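Editor's note: the removed _max_workers_for_file_mounts helper was not dropped; it moved into sky/utils/subprocess_utils.py as get_max_workers_for_file_mounts and is now cloud-aware (which is also why the `import resource` was removed from this module). A sketch reconstructed from the removed code above; the cloud-aware cap via get_parallel_threads is inferred from the call sites in this diff, and the shipped implementation may differ:

    import os
    import resource
    from typing import Dict, Optional

    from sky.utils import subprocess_utils


    def get_max_workers_for_file_mounts(common_file_mounts: Dict[str, str],
                                        cloud_str: Optional[str] = None) -> int:
        """Reconstruction of the removed helper, now bounded per cloud."""
        fd_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE)

        # Assume each file/folder under a source dir costs ~5 file descriptors
        # per concurrent rsync (same heuristic as the removed code).
        fd_per_rsync = 5
        for src in common_file_mounts.values():
            if os.path.isdir(src):
                fd_per_rsync = max(fd_per_rsync, len(os.listdir(src)) * 5)

        fd_reserve = 100  # Headroom for the system and other processes.
        max_workers = (fd_limit - fd_reserve) // fd_per_rsync
        # At least one worker, capped by the cloud-specific parallelism limit.
        return min(max(max_workers, 1),
                   subprocess_utils.get_parallel_threads(cloud_str))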