skypilot-nightly 1.0.0.dev20250910__py3-none-any.whl → 1.0.0.dev20250912__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (68)
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +125 -22
  3. sky/backends/cloud_vm_ray_backend.py +224 -72
  4. sky/catalog/__init__.py +7 -0
  5. sky/catalog/aws_catalog.py +4 -0
  6. sky/catalog/common.py +18 -0
  7. sky/catalog/data_fetchers/fetch_aws.py +13 -1
  8. sky/client/cli/command.py +2 -71
  9. sky/client/sdk_async.py +5 -2
  10. sky/clouds/aws.py +23 -5
  11. sky/clouds/cloud.py +8 -0
  12. sky/dashboard/out/404.html +1 -1
  13. sky/dashboard/out/_next/static/chunks/3294.ba6586f9755b0edb.js +6 -0
  14. sky/dashboard/out/_next/static/chunks/{webpack-1d7e11230da3ca89.js → webpack-e8a0c4c3c6f408fb.js} +1 -1
  15. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  16. sky/dashboard/out/clusters/[cluster].html +1 -1
  17. sky/dashboard/out/clusters.html +1 -1
  18. sky/dashboard/out/config.html +1 -1
  19. sky/dashboard/out/index.html +1 -1
  20. sky/dashboard/out/infra/[context].html +1 -1
  21. sky/dashboard/out/infra.html +1 -1
  22. sky/dashboard/out/jobs/[job].html +1 -1
  23. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  24. sky/dashboard/out/jobs.html +1 -1
  25. sky/dashboard/out/users.html +1 -1
  26. sky/dashboard/out/volumes.html +1 -1
  27. sky/dashboard/out/workspace/new.html +1 -1
  28. sky/dashboard/out/workspaces/[name].html +1 -1
  29. sky/dashboard/out/workspaces.html +1 -1
  30. sky/global_user_state.py +34 -0
  31. sky/jobs/client/sdk_async.py +4 -2
  32. sky/jobs/controller.py +4 -2
  33. sky/jobs/recovery_strategy.py +1 -1
  34. sky/jobs/state.py +26 -16
  35. sky/jobs/utils.py +6 -11
  36. sky/logs/agent.py +10 -2
  37. sky/provision/kubernetes/config.py +7 -2
  38. sky/provision/kubernetes/instance.py +84 -41
  39. sky/provision/vast/instance.py +1 -1
  40. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  41. sky/server/config.py +14 -5
  42. sky/server/metrics.py +41 -8
  43. sky/server/requests/executor.py +41 -4
  44. sky/server/server.py +1 -0
  45. sky/server/uvicorn.py +11 -5
  46. sky/skylet/constants.py +12 -7
  47. sky/skylet/log_lib.py +11 -0
  48. sky/skylet/log_lib.pyi +9 -0
  49. sky/task.py +62 -0
  50. sky/templates/kubernetes-ray.yml.j2 +120 -3
  51. sky/utils/accelerator_registry.py +3 -1
  52. sky/utils/command_runner.py +35 -11
  53. sky/utils/command_runner.pyi +22 -0
  54. sky/utils/context_utils.py +15 -2
  55. sky/utils/db/migration_utils.py +1 -1
  56. sky/utils/git.py +559 -1
  57. sky/utils/resource_checker.py +8 -7
  58. sky/workspaces/core.py +57 -21
  59. {skypilot_nightly-1.0.0.dev20250910.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/METADATA +33 -33
  60. {skypilot_nightly-1.0.0.dev20250910.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/RECORD +66 -66
  61. sky/client/cli/git.py +0 -549
  62. sky/dashboard/out/_next/static/chunks/3294.c80326aec9bfed40.js +0 -6
  63. /sky/dashboard/out/_next/static/{3SYxqNGnvvPS8h3gdD2T7 → DAiq7V2xJnO1LSfmunZl6}/_buildManifest.js +0 -0
  64. /sky/dashboard/out/_next/static/{3SYxqNGnvvPS8h3gdD2T7 → DAiq7V2xJnO1LSfmunZl6}/_ssgManifest.js +0 -0
  65. {skypilot_nightly-1.0.0.dev20250910.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/WHEEL +0 -0
  66. {skypilot_nightly-1.0.0.dev20250910.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/entry_points.txt +0 -0
  67. {skypilot_nightly-1.0.0.dev20250910.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/licenses/LICENSE +0 -0
  68. {skypilot_nightly-1.0.0.dev20250910.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/top_level.txt +0 -0
sky/server/requests/executor.py CHANGED
@@ -31,6 +31,7 @@ import time
 import typing
 from typing import Any, Callable, Generator, List, Optional, TextIO, Tuple
 
+import psutil
 import setproctitle
 
 from sky import exceptions
@@ -130,8 +131,9 @@ queue_backend = server_config.QueueBackend.MULTIPROCESSING
 def executor_initializer(proc_group: str):
     setproctitle.setproctitle(f'SkyPilot:executor:{proc_group}:'
                               f'{multiprocessing.current_process().pid}')
+    # Executor never stops, unless the whole process is killed.
     threading.Thread(target=metrics_lib.process_monitor,
-                     args=(f'worker:{proc_group}',),
+                     args=(f'worker:{proc_group}', threading.Event()),
                      daemon=True).start()
 
 
@@ -373,11 +375,13 @@ def _request_execution_wrapper(request_id: str,
     4. Handle the SIGTERM signal to abort the request gracefully.
     5. Maintain the lifecycle of the temp dir used by the request.
     """
+    pid = multiprocessing.current_process().pid
+    proc = psutil.Process(pid)
+    rss_begin = proc.memory_info().rss
     db_utils.set_max_connections(num_db_connections_per_worker)
     # Handle the SIGTERM signal to abort the request processing gracefully.
     signal.signal(signal.SIGTERM, _sigterm_handler)
 
-    pid = multiprocessing.current_process().pid
     logger.info(f'Running request {request_id} with pid {pid}')
     with api_requests.update_request(request_id) as request_task:
         assert request_task is not None, request_id
@@ -443,8 +447,41 @@
         _restore_output(original_stdout, original_stderr)
         logger.info(f'Request {request_id} finished')
     finally:
-        with metrics_lib.time_it(name='release_memory', group='internal'):
-            common_utils.release_memory()
+        try:
+            # Capture the peak RSS before GC.
+            peak_rss = max(proc.memory_info().rss,
+                           metrics_lib.peak_rss_bytes)
+            with metrics_lib.time_it(name='release_memory',
+                                     group='internal'):
+                common_utils.release_memory()
+            _record_memory_metrics(request_name, proc, rss_begin, peak_rss)
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error(f'Failed to record memory metrics: '
+                         f'{common_utils.format_exception(e)}')
+
+
+_first_request = True
+
+
+def _record_memory_metrics(request_name: str, proc: psutil.Process,
+                           rss_begin: int, peak_rss: int) -> None:
+    """Record the memory metrics for a request."""
+    # Do not record full memory delta for the first request as it
+    # will loads the sky core modules and make the memory usage
+    # estimation inaccurate.
+    global _first_request
+    if _first_request:
+        _first_request = False
+        return
+    rss_end = proc.memory_info().rss
+
+    # Answer "how much RSS this request contributed?"
+    metrics_lib.SKY_APISERVER_REQUEST_RSS_INCR_BYTES.labels(
+        name=request_name).observe(max(rss_end - rss_begin, 0))
+    # Estimate the memory usage by the request by capturing the
+    # peak memory delta during the request execution.
+    metrics_lib.SKY_APISERVER_REQUEST_MEMORY_USAGE_BYTES.labels(
+        name=request_name).observe(max(peak_rss - rss_begin, 0))
 
 
 async def execute_request_coroutine(request: api_requests.Request):
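Editor's note: the executor change above snapshots RSS when a request starts, captures the peak before garbage collection, and observes two deltas into histograms. Below is a minimal sketch of that accounting pattern, assuming only psutil and prometheus_client are installed; the metric names and the run_with_memory_metrics helper are illustrative, not the ones defined in sky/server/metrics.py.

import os

import psutil
from prometheus_client import Histogram

REQUEST_RSS_INCR = Histogram('request_rss_incr_bytes',
                             'RSS retained after a request', ['name'])
REQUEST_MEMORY_USAGE = Histogram('request_memory_usage_bytes',
                                 'Peak RSS delta during a request', ['name'])


def run_with_memory_metrics(name, fn, *args, **kwargs):
    proc = psutil.Process(os.getpid())
    rss_begin = proc.memory_info().rss
    try:
        return fn(*args, **kwargs)
    finally:
        # Without a background sampler, RSS at exit is only a lower bound on
        # the true peak (the real code tracks peak_rss_bytes separately).
        rss_end = proc.memory_info().rss
        REQUEST_RSS_INCR.labels(name=name).observe(max(rss_end - rss_begin, 0))
        REQUEST_MEMORY_USAGE.labels(name=name).observe(
            max(rss_end - rss_begin, 0))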
sky/server/server.py CHANGED
@@ -1214,6 +1214,7 @@ async def logs(
         request_body=cluster_job_body,
         func=core.tail_logs,
         schedule_type=requests_lib.ScheduleType.SHORT,
+        request_cluster_name=cluster_job_body.cluster_name,
     )
     task = asyncio.create_task(executor.execute_request_coroutine(request_task))
 
sky/server/uvicorn.py CHANGED
@@ -213,11 +213,17 @@ class Server(uvicorn.Server):
             # Same as set PYTHONASYNCIODEBUG=1, but with custom threshold.
             event_loop.set_debug(True)
             event_loop.slow_callback_duration = lag_threshold
-        threading.Thread(target=metrics_lib.process_monitor,
-                         args=('server',),
-                         daemon=True).start()
-        with self.capture_signals():
-            asyncio.run(self.serve(*args, **kwargs))
+        stop_monitor = threading.Event()
+        monitor = threading.Thread(target=metrics_lib.process_monitor,
+                                   args=('server', stop_monitor),
+                                   daemon=True)
+        monitor.start()
+        try:
+            with self.capture_signals():
+                asyncio.run(self.serve(*args, **kwargs))
+        finally:
+            stop_monitor.set()
+            monitor.join()
 
 
 def run(config: uvicorn.Config, max_db_connections: Optional[int] = None):
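Editor's note: this change swaps a fire-and-forget daemon thread for a stoppable one — the monitor receives a threading.Event, and shutdown sets the event and joins the thread. A self-contained sketch of that pattern follows; process_monitor and sample below are simplified stand-ins for metrics_lib.process_monitor, not its actual body.

import threading
import time


def process_monitor(group: str, stop: threading.Event,
                    interval: float = 1.0) -> None:
    # wait() doubles as the sampling sleep and the shutdown signal, so the
    # thread exits promptly once stop is set.
    while not stop.wait(interval):
        sample(group)


def sample(group: str) -> None:
    print(f'{group}: sampled at {time.time():.0f}')


stop_monitor = threading.Event()
monitor = threading.Thread(target=process_monitor,
                           args=('server', stop_monitor),
                           daemon=True)
monitor.start()
try:
    time.sleep(3)  # stand-in for asyncio.run(self.serve(...))
finally:
    stop_monitor.set()
    monitor.join()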
sky/skylet/constants.py CHANGED
@@ -62,11 +62,14 @@ SKY_UV_INSTALL_CMD = (f'{SKY_UV_CMD} -V >/dev/null 2>&1 || '
                      'curl -LsSf https://astral.sh/uv/install.sh '
                      f'| UV_INSTALL_DIR={SKY_UV_INSTALL_DIR} sh')
 SKY_UV_PIP_CMD: str = (f'VIRTUAL_ENV={SKY_REMOTE_PYTHON_ENV} {SKY_UV_CMD} pip')
-# Deleting the SKY_REMOTE_PYTHON_ENV_NAME from the PATH to deactivate the
-# environment. `deactivate` command does not work when conda is used.
+SKY_UV_RUN_CMD: str = (f'VIRTUAL_ENV={SKY_REMOTE_PYTHON_ENV} {SKY_UV_CMD} run')
+# Deleting the SKY_REMOTE_PYTHON_ENV_NAME from the PATH and unsetting relevant
+# VIRTUAL_ENV envvars to deactivate the environment. `deactivate` command does
+# not work when conda is used.
 DEACTIVATE_SKY_REMOTE_PYTHON_ENV = (
     'export PATH='
-    f'$(echo $PATH | sed "s|$(echo ~)/{SKY_REMOTE_PYTHON_ENV_NAME}/bin:||")')
+    f'$(echo $PATH | sed "s|$(echo ~)/{SKY_REMOTE_PYTHON_ENV_NAME}/bin:||") && '
+    'unset VIRTUAL_ENV && unset VIRTUAL_ENV_PROMPT')
 
 # Prefix for SkyPilot environment variables
 SKYPILOT_ENV_VAR_PREFIX = 'SKYPILOT_'
@@ -98,7 +101,7 @@ SKYLET_VERSION = '18'
 SKYLET_LIB_VERSION = 4
 SKYLET_VERSION_FILE = '~/.sky/skylet_version'
 SKYLET_GRPC_PORT = 46590
-SKYLET_GRPC_TIMEOUT_SECONDS = 5
+SKYLET_GRPC_TIMEOUT_SECONDS = 10
 
 # Docker default options
 DEFAULT_DOCKER_CONTAINER_NAME = 'sky_container'
@@ -229,7 +232,7 @@ RAY_INSTALLATION_COMMANDS = (
     'export PATH=$PATH:$HOME/.local/bin; '
     # Writes ray path to file if it does not exist or the file is empty.
     f'[ -s {SKY_RAY_PATH_FILE} ] || '
-    f'{{ {ACTIVATE_SKY_REMOTE_PYTHON_ENV} && '
+    f'{{ {SKY_UV_RUN_CMD} '
     f'which ray > {SKY_RAY_PATH_FILE} || exit 1; }}; ')
 
 SKYPILOT_WHEEL_INSTALLATION_COMMANDS = (
@@ -374,7 +377,6 @@ OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
     ('ssh', 'pod_config'),
     ('kubernetes', 'custom_metadata'),
     ('kubernetes', 'pod_config'),
-    ('kubernetes', 'context_configs'),
     ('kubernetes', 'provision_timeout'),
     ('kubernetes', 'dws'),
     ('kubernetes', 'kueue'),
@@ -449,7 +451,7 @@ SKYPILOT_DEFAULT_WORKSPACE = 'default'
 # BEGIN constants used for service catalog.
 HOSTED_CATALOG_DIR_URL = 'https://raw.githubusercontent.com/skypilot-org/skypilot-catalog/master/catalogs'  # pylint: disable=line-too-long
 HOSTED_CATALOG_DIR_URL_S3_MIRROR = 'https://skypilot-catalog.s3.us-east-1.amazonaws.com/catalogs'  # pylint: disable=line-too-long
-CATALOG_SCHEMA_VERSION = 'v7'
+CATALOG_SCHEMA_VERSION = 'v8'
 CATALOG_DIR = '~/.sky/catalogs'
 ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
               'kubernetes', 'runpod', 'vast', 'vsphere', 'cudo', 'fluidstack',
@@ -510,3 +512,6 @@ SKY_LOCKS_DIR = os.path.expanduser('~/.sky/locks')
 
 ENV_VAR_LOOP_LAG_THRESHOLD_MS = (SKYPILOT_ENV_VAR_PREFIX +
                                  'DEBUG_LOOP_LAG_THRESHOLD_MS')
+
+ARM64_ARCH = 'arm64'
+X86_64_ARCH = 'x86_64'
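Editor's note: DEACTIVATE_SKY_REMOTE_PYTHON_ENV now both strips the venv's bin directory from PATH and unsets the VIRTUAL_ENV markers, since `deactivate` is unavailable under conda. A rough Python rendering of that shell logic, assuming a hypothetical venv name (the real one comes from SKY_REMOTE_PYTHON_ENV_NAME):

import os


def deactivate(env: dict, venv_name: str = 'skypilot-runtime') -> dict:
    # Drop the venv's bin dir from PATH, mirroring the sed expression.
    bin_dir = os.path.join(os.path.expanduser('~'), venv_name, 'bin')
    path = [p for p in env.get('PATH', '').split(os.pathsep) if p != bin_dir]
    env = dict(env, PATH=os.pathsep.join(path))
    # `deactivate` does not exist under conda, hence the manual unset.
    env.pop('VIRTUAL_ENV', None)
    env.pop('VIRTUAL_ENV_PROMPT', None)
    return env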
sky/skylet/log_lib.py CHANGED
@@ -354,6 +354,17 @@ def run_bash_command_with_log(bash_command: str,
                          shell=True)
 
 
+def run_bash_command_with_log_and_return_pid(
+        bash_command: str,
+        log_path: str,
+        env_vars: Optional[Dict[str, str]] = None,
+        stream_logs: bool = False,
+        with_ray: bool = False):
+    return_code = run_bash_command_with_log(bash_command, log_path, env_vars,
+                                            stream_logs, with_ray)
+    return {'return_code': return_code, 'pid': os.getpid()}
+
+
 def _follow_job_logs(file,
                      job_id: int,
                      start_streaming: bool,
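Editor's note: the new wrapper returns the worker's PID alongside the exit code, so a caller can correlate the two in a single round trip. A hedged usage sketch (the command and log path are illustrative):

from sky.skylet.log_lib import run_bash_command_with_log_and_return_pid

result = run_bash_command_with_log_and_return_pid(
    bash_command='echo hello',
    log_path='/tmp/hello.log')
# The dict carries both outcomes of the call.
if result['return_code'] != 0:
    print(f"command failed in worker process {result['pid']}")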
sky/skylet/log_lib.pyi CHANGED
@@ -129,6 +129,15 @@ def run_bash_command_with_log(bash_command: str,
     ...
 
 
+def run_bash_command_with_log_and_return_pid(
+        bash_command: str,
+        log_path: str,
+        env_vars: Optional[Dict[str, str]] = ...,
+        stream_logs: bool = ...,
+        with_ray: bool = ...):
+    ...
+
+
 def tail_logs(job_id: int,
               log_dir: Optional[str],
               managed_job_id: Optional[int] = ...,
sky/task.py CHANGED
@@ -20,6 +20,7 @@ from sky.provision import docker_utils
 from sky.serve import service_spec
 from sky.skylet import constants
 from sky.utils import common_utils
+from sky.utils import git
 from sky.utils import registry
 from sky.utils import schemas
 from sky.utils import ux_utils
@@ -1596,6 +1597,67 @@ class Task:
                 d[k] = v
         return d
 
+    def update_workdir(self, workdir: Optional[str], git_url: Optional[str],
+                       git_ref: Optional[str]) -> 'Task':
+        """Updates the task workdir.
+
+        Args:
+            workdir: The workdir to update.
+            git_url: The git url to update.
+            git_ref: The git ref to update.
+        """
+        if self.workdir is None or isinstance(self.workdir, str):
+            if workdir is not None:
+                self.workdir = workdir
+                return self
+            if git_url is not None:
+                self.workdir = {}
+                self.workdir['url'] = git_url
+                if git_ref is not None:
+                    self.workdir['ref'] = git_ref
+                return self
+            return self
+        if git_url is not None:
+            self.workdir['url'] = git_url
+        if git_ref is not None:
+            self.workdir['ref'] = git_ref
+        return self
+
+    def update_envs_and_secrets_from_workdir(self) -> 'Task':
+        """Updates the task envs and secrets from the workdir."""
+        if self.workdir is None:
+            return self
+        if not isinstance(self.workdir, dict):
+            return self
+        url = self.workdir['url']
+        ref = self.workdir.get('ref', '')
+        token = os.environ.get(git.GIT_TOKEN_ENV_VAR)
+        ssh_key_path = os.environ.get(git.GIT_SSH_KEY_PATH_ENV_VAR)
+        try:
+            git_repo = git.GitRepo(url, ref, token, ssh_key_path)
+            clone_info = git_repo.get_repo_clone_info()
+            if clone_info is None:
+                return self
+            self.envs[git.GIT_URL_ENV_VAR] = clone_info.url
+            if ref:
+                ref_type = git_repo.get_ref_type()
+                if ref_type == git.GitRefType.COMMIT:
+                    self.envs[git.GIT_COMMIT_HASH_ENV_VAR] = ref
+                elif ref_type == git.GitRefType.BRANCH:
+                    self.envs[git.GIT_BRANCH_ENV_VAR] = ref
+                elif ref_type == git.GitRefType.TAG:
+                    self.envs[git.GIT_TAG_ENV_VAR] = ref
+            if clone_info.token is None and clone_info.ssh_key is None:
+                return self
+            if clone_info.token is not None:
+                self.secrets[git.GIT_TOKEN_ENV_VAR] = clone_info.token
+            if clone_info.ssh_key is not None:
+                self.secrets[git.GIT_SSH_KEY_ENV_VAR] = clone_info.ssh_key
+        except exceptions.GitError as e:
+            with ux_utils.print_exception_no_traceback():
+                raise ValueError(f'{str(e)}') from None
+        return self
+
     def to_yaml_config(self,
                        use_user_specified_yaml: bool = False) -> Dict[str, Any]:
         """Returns a yaml-style dict representation of the task.
sky/templates/kubernetes-ray.yml.j2 CHANGED
@@ -654,8 +654,125 @@ available_node_types:
             # after v0.11.0 release.
             touch /tmp/apt_ssh_setup_started
 
-            DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get update > /tmp/apt-update.log 2>&1 || \
-              echo "Warning: apt-get update failed. Continuing anyway..." >> /tmp/apt-update.log
+            # Helper: run apt-get update with retries
+            apt_update_with_retries() {
+              # do not fail the whole shell; we handle return codes
+              set +e
+              local log=/tmp/apt-update.log
+              local tries=3
+              local delay=1
+              local i
+              for i in $(seq 1 $tries); do
+                DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get update >> "$log" 2>&1 && { set -e; return 0; }
+                echo "apt-get update attempt $i/$tries failed; retrying in ${delay}s" >> "$log"
+                sleep $delay
+                delay=$((delay * 2))
+              done
+              set -e
+              return 1
+            }
+            apt_install_with_retries() {
+              local packages="$@"
+              [ -z "$packages" ] && return 0
+              set +e
+              local log=/tmp/apt-update.log
+              local tries=3
+              local delay=1
+              local i
+              for i in $(seq 1 $tries); do
+                DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y -o Dpkg::Options::="--force-confdef" -o Dpkg::Options::="--force-confold" $packages && { set -e; return 0; }
+                echo "apt-get install failed for: $packages (attempt $i/$tries). Running -f install and retrying..." >> "$log"
+                DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get -f install -y >> "$log" 2>&1 || true
+                DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get clean >> "$log" 2>&1 || true
+                sleep $delay
+                delay=$((delay * 2))
+              done
+              set -e
+              return 1
+            }
+            apt_update_install_with_retries() {
+              apt_update_with_retries
+              apt_install_with_retries "$@"
+            }
+            backup_dir=/etc/apt/sources.list.backup_skypilot
+            backup_source() {
+              $(prefix_cmd) mkdir -p "$backup_dir"
+              if [ -f /etc/apt/sources.list ] && [ ! -f "$backup_dir/sources.list" ]; then
+                $(prefix_cmd) cp -a /etc/apt/sources.list "$backup_dir/sources.list" || true
+              fi
+            }
+            restore_source() {
+              if [ -f "$backup_dir/sources.list" ]; then
+                $(prefix_cmd) cp -a "$backup_dir/sources.list" /etc/apt/sources.list || true
+              fi
+            }
+            update_apt_sources() {
+              local host=$1
+              local apt_file=$2
+              $(prefix_cmd) sed -i -E "s|https?://[a-zA-Z0-9.-]+\.ubuntu\.com/ubuntu|http://$host/ubuntu|g" $apt_file
+            }
+            # Helper: install packages across mirrors with retries
+            apt_install_with_mirrors() {
+              local required=$1; shift
+              local packages="$@"
+              [ -z "$packages" ] && return 0
+              set +e
+              # Install packages with default sources first
+              local log=/tmp/apt-update.log
+              echo "$(date +%Y-%m-%d\ %H:%M:%S) Installing packages: $packages" >> "$log"
+              restore_source
+              apt_update_install_with_retries $packages >> "$log" 2>&1 && { set -e; return 0; }
+              echo "Install failed with default sources: $packages" >> "$log"
+              # Detect distro (ubuntu/debian)
+              local APT_OS="unknown"
+              if [ -f /etc/os-release ]; then
+                . /etc/os-release
+                case "$ID" in
+                  debian) APT_OS="debian" ;;
+                  ubuntu) APT_OS="ubuntu" ;;
+                  *)
+                    if [ -n "$ID_LIKE" ]; then
+                      case " $ID $ID_LIKE " in
+                        *ubuntu*) APT_OS="ubuntu" ;;
+                        *debian*) APT_OS="debian" ;;
+                      esac
+                    fi
+                    ;;
+                esac
+              fi
+              # Build mirror candidates
+              # deb.debian.org is a CDN endpoint, if one backend goes down,
+              # the CDN automatically fails over to another mirror,
+              # so we only retry for ubuntu here.
+              if [ "$APT_OS" = "ubuntu" ]; then
+                # Backup current sources once
+                backup_source
+                # Selected from https://launchpad.net/ubuntu/+archivemirrors
+                # and results from apt-select
+                local MIRROR_CANDIDATES="mirrors.wikimedia.org mirror.umd.edu"
+                for host in $MIRROR_CANDIDATES; do
+                  echo "Trying APT mirror ($APT_OS): $host" >> "$log"
+                  if [ -f /etc/apt/sources.list ]; then
+                    update_apt_sources $host /etc/apt/sources.list
+                  else
+                    echo "Error: /etc/apt/sources.list not found" >> "$log"
+                    break
+                  fi
+                  apt_update_install_with_retries $packages >> "$log" 2>&1 && { set -e; return 0; }
+                  echo "Install failed with mirror ($APT_OS): $host" >> "$log"
+                  # Restore to default sources
+                  restore_source
+                done
+              fi
+              set -e
+              if [ "$required" = "1" ]; then
+                echo "Error: required package install failed across all mirrors: $packages" >> "$log"
+                return 1
+              else
+                echo "Optional package install failed across all mirrors: $packages; skipping." >> "$log"
+                return 0
+              fi
+            }
             # Install both fuse2 and fuse3 for compatibility for all possible fuse adapters in advance,
             # so that both fusemount and fusermount3 can be masked before enabling SSH access.
             PACKAGES="rsync curl wget netcat gcc patch pciutils fuse fuse3 openssh-server";
@@ -682,7 +799,7 @@
             done;
             if [ ! -z "$INSTALL_FIRST" ]; then
               echo "Installing core packages: $INSTALL_FIRST";
-              DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y -o Dpkg::Options::="--force-confdef" -o Dpkg::Options::="--force-confold" $INSTALL_FIRST;
+              apt_install_with_mirrors 1 $INSTALL_FIRST || { echo "Error: core package installation failed." >> /tmp/apt-update.log; exit 1; }
             fi;
             # SSH and other packages are not necessary, so we disable set -e
             set +e
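Editor's note: the apt helpers above retry update/install with bounded exponential backoff, then fall back to alternate Ubuntu mirrors for required packages. The same retry shape, sketched in Python for readability; run() and the mirror handling are illustrative stand-ins for the shell code.

import subprocess
import time


def run(cmd: list) -> bool:
    return subprocess.run(cmd, check=False).returncode == 0


def with_retries(cmd: list, tries: int = 3, delay: float = 1.0) -> bool:
    for attempt in range(1, tries + 1):
        if run(cmd):
            return True
        print(f'attempt {attempt}/{tries} failed; retrying in {delay}s')
        time.sleep(delay)
        delay *= 2  # exponential backoff, as in delay=$((delay * 2))
    return False


def install_with_mirrors(packages: list, mirrors: list) -> bool:
    # Default sources first, then each fallback mirror in turn; the real
    # template rewrites /etc/apt/sources.list between attempts.
    for mirror in [None, *mirrors]:
        print(f"trying {'default sources' if mirror is None else mirror}")
        if with_retries(['apt-get', 'install', '-y', *packages]):
            return True
    return False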
sky/utils/accelerator_registry.py CHANGED
@@ -107,10 +107,12 @@ def canonicalize_accelerator_name(accelerator: str,
     if not names and cloud_str in ['Kubernetes', None]:
         with rich_utils.safe_status(
                 ux_utils.spinner_message('Listing accelerators on Kubernetes')):
+            # Only search for Kubernetes to reduce the lookup cost.
+            # For other clouds, the catalog has been searched in previous steps.
             searched = catalog.list_accelerators(
                 name_filter=accelerator,
                 case_sensitive=False,
-                clouds=cloud_str,
+                clouds='Kubernetes',
             )
         names = list(searched.keys())
     if accelerator in names:
sky/utils/command_runner.py CHANGED
@@ -469,15 +469,19 @@ class CommandRunner:
         """Close the cached connection to the remote machine."""
         pass
 
-    def port_forward_command(self,
-                             port_forward: List[Tuple[int, int]],
-                             connect_timeout: int = 1) -> List[str]:
+    def port_forward_command(
+            self,
+            port_forward: List[Tuple[int, int]],
+            connect_timeout: int = 1,
+            ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
         """Command for forwarding ports from localhost to the remote machine.
 
         Args:
             port_forward: A list of ports to forward from the localhost to the
                 remote host.
             connect_timeout: The timeout for the connection.
+            ssh_mode: The mode to use for ssh.
+                See SSHMode for more details.
         """
         raise NotImplementedError
 
@@ -592,6 +596,7 @@ class SSHCommandRunner(CommandRunner):
             ssh_proxy_command: Optional[str] = None,
             docker_user: Optional[str] = None,
             disable_control_master: Optional[bool] = False,
+            port_forward_execute_remote_command: Optional[bool] = False,
     ):
         """Initialize SSHCommandRunner.
 
@@ -618,6 +623,10 @@ class SSHCommandRunner(CommandRunner):
             disable_control_master: bool; specifies either or not the ssh
                 command will utilize ControlMaster. We currently disable
                 it for k8s instance.
+            port_forward_execute_remote_command: bool; specifies whether to
+                add -N to the port forwarding command. This is useful if you
+                want to run a command on the remote machine to make sure the
+                SSH tunnel is established.
         """
         super().__init__(node)
         ip, port = node
@@ -646,22 +655,28 @@ class SSHCommandRunner(CommandRunner):
         self.ssh_user = ssh_user
         self.port = port
         self._docker_ssh_proxy_command = None
+        self.port_forward_execute_remote_command = (
+            port_forward_execute_remote_command)
 
-    def port_forward_command(self,
-                             port_forward: List[Tuple[int, int]],
-                             connect_timeout: int = 1) -> List[str]:
+    def port_forward_command(
+            self,
+            port_forward: List[Tuple[int, int]],
+            connect_timeout: int = 1,
+            ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
         """Command for forwarding ports from localhost to the remote machine.
 
         Args:
             port_forward: A list of ports to forward from the local port to the
                 remote port.
             connect_timeout: The timeout for the ssh connection.
+            ssh_mode: The mode to use for ssh.
+                See SSHMode for more details.
 
         Returns:
             The command for forwarding ports from localhost to the remote
             machine.
         """
-        return self.ssh_base_command(ssh_mode=SshMode.INTERACTIVE,
+        return self.ssh_base_command(ssh_mode=ssh_mode,
                                      port_forward=port_forward,
                                      connect_timeout=connect_timeout)
 
@@ -680,7 +695,11 @@ class SSHCommandRunner(CommandRunner):
         for local, remote in port_forward:
             logger.debug(
                 f'Forwarding local port {local} to remote port {remote}.')
-            ssh += ['-NL', f'{local}:localhost:{remote}']
+            if self.port_forward_execute_remote_command:
+                ssh += ['-L']
+            else:
+                ssh += ['-NL']
+            ssh += [f'{local}:localhost:{remote}']
         if self._docker_ssh_proxy_command is not None:
             docker_ssh_proxy_command = self._docker_ssh_proxy_command(ssh)
         else:
@@ -894,9 +913,11 @@ class KubernetesCommandRunner(CommandRunner):
         else:
             return f'pod/{self.pod_name}'
 
-    def port_forward_command(self,
-                             port_forward: List[Tuple[int, int]],
-                             connect_timeout: int = 1) -> List[str]:
+    def port_forward_command(
+            self,
+            port_forward: List[Tuple[int, int]],
+            connect_timeout: int = 1,
+            ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
         """Command for forwarding ports from localhost to the remote machine.
 
         Args:
@@ -904,7 +925,10 @@ class KubernetesCommandRunner(CommandRunner):
                 remote port. Currently, only one port is supported, i.e. the
                 list should have only one element.
             connect_timeout: The timeout for the ssh connection.
+            ssh_mode: The mode to use for ssh.
+                See SSHMode for more details.
         """
+        del ssh_mode  # unused
         assert port_forward and len(port_forward) == 1, (
             'Only one port is supported for Kubernetes port-forward.')
         kubectl_args = [
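Editor's note: ssh's -N flag means "do not execute a remote command"; dropping it (when port_forward_execute_remote_command is set) lets the caller attach a command whose execution confirms the tunnel is up. A self-contained sketch of the argv branch introduced above:

from typing import List, Tuple


def forward_args(port_forward: List[Tuple[int, int]],
                 execute_remote_command: bool) -> List[str]:
    ssh: List[str] = []
    for local, remote in port_forward:
        # Without -N, ssh will run a remote command through the tunnel.
        ssh += ['-L'] if execute_remote_command else ['-NL']
        ssh += [f'{local}:localhost:{remote}']
    return ssh


assert forward_args([(8080, 80)], False) == ['-NL', '8080:localhost:80']
assert forward_args([(8080, 80)], True) == ['-L', '8080:localhost:80']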
sky/utils/command_runner.pyi CHANGED
@@ -106,6 +106,13 @@ class CommandRunner:
                  max_retry: int = ...) -> None:
         ...
 
+    def port_forward_command(
+            self,
+            port_forward: List[Tuple[int, int]],
+            connect_timeout: int = 1,
+            ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
+        ...
+
     @classmethod
     def make_runner_list(cls: typing.Type[CommandRunner],
                          node_list: Iterable[Tuple[Any, ...]],
@@ -127,6 +134,7 @@ class SSHCommandRunner(CommandRunner):
     ssh_control_name: Optional[str]
     docker_user: str
     disable_control_master: Optional[bool]
+    port_forward_execute_remote_command: Optional[bool]
 
     def __init__(
         self,
@@ -200,6 +208,13 @@ class SSHCommandRunner(CommandRunner):
                  max_retry: int = ...) -> None:
         ...
 
+    def port_forward_command(
+            self,
+            port_forward: List[Tuple[int, int]],
+            connect_timeout: int = 1,
+            ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
+        ...
+
 
 class KubernetesCommandRunner(CommandRunner):
 
@@ -272,6 +287,13 @@ class KubernetesCommandRunner(CommandRunner):
                  max_retry: int = ...) -> None:
         ...
 
+    def port_forward_command(
+            self,
+            port_forward: List[Tuple[int, int]],
+            connect_timeout: int = 1,
+            ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
+        ...
+
 
 class LocalProcessCommandRunner(CommandRunner):
sky/utils/context_utils.py CHANGED
@@ -10,6 +10,8 @@ import sys
 import typing
 from typing import Any, Callable, IO, Optional, Tuple, TypeVar
 
+from typing_extensions import ParamSpec
+
 from sky import sky_logging
 from sky.utils import context
 from sky.utils import subprocess_utils
@@ -173,9 +175,14 @@ def cancellation_guard(func: F) -> F:
     return typing.cast(F, wrapper)
 
 
+P = ParamSpec('P')
+T = TypeVar('T')
+
+
 # TODO(aylei): replace this with asyncio.to_thread once we drop support for
 # python 3.8
-def to_thread(func, /, *args, **kwargs):
+def to_thread(func: Callable[P, T], /, *args: P.args,
+              **kwargs: P.kwargs) -> 'asyncio.Future[T]':
     """Asynchronously run function *func* in a separate thread.
 
     This is same as asyncio.to_thread added in python 3.9
@@ -183,5 +190,11 @@ def to_thread(func, /, *args, **kwargs):
     loop = asyncio.get_running_loop()
     # This is critical to pass the current coroutine context to the new thread
     pyctx = contextvars.copy_context()
-    func_call = functools.partial(pyctx.run, func, *args, **kwargs)
+    func_call: Callable[..., T] = functools.partial(
+        # partial deletes arguments type and thus can't figure out the return
+        # type of pyctx.run
+        pyctx.run,  # type: ignore
+        func,
+        *args,
+        **kwargs)
     return loop.run_in_executor(None, func_call)
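Editor's note: with the ParamSpec annotations, type checkers can now verify both the arguments forwarded through to_thread and its awaited result type. A hedged usage sketch:

import asyncio

from sky.utils.context_utils import to_thread


def blocking_add(a: int, b: int) -> int:
    return a + b


async def main() -> None:
    total = await to_thread(blocking_add, 1, 2)  # result is typed as int
    # to_thread(blocking_add, '1', 2) would now be flagged by mypy/pyright.
    print(total)


asyncio.run(main())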
sky/utils/db/migration_utils.py CHANGED
@@ -17,7 +17,7 @@ logger = sky_logging.init_logger(__name__)
 DB_INIT_LOCK_TIMEOUT_SECONDS = 10
 
 GLOBAL_USER_STATE_DB_NAME = 'state_db'
-GLOBAL_USER_STATE_VERSION = '007'
+GLOBAL_USER_STATE_VERSION = '008'
 GLOBAL_USER_STATE_LOCK_PATH = '~/.sky/locks/.state_db.lock'
 
 SPOT_JOBS_DB_NAME = 'spot_jobs_db'