skypilot-nightly 1.0.0.dev20250909__py3-none-any.whl → 1.0.0.dev20250912__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic. Click here for more details.

Files changed (97)
  1. sky/__init__.py +2 -2
  2. sky/authentication.py +19 -4
  3. sky/backends/backend_utils.py +160 -23
  4. sky/backends/cloud_vm_ray_backend.py +226 -74
  5. sky/catalog/__init__.py +7 -0
  6. sky/catalog/aws_catalog.py +4 -0
  7. sky/catalog/common.py +18 -0
  8. sky/catalog/data_fetchers/fetch_aws.py +13 -1
  9. sky/client/cli/command.py +2 -71
  10. sky/client/sdk.py +20 -0
  11. sky/client/sdk_async.py +23 -18
  12. sky/clouds/aws.py +26 -6
  13. sky/clouds/cloud.py +8 -0
  14. sky/dashboard/out/404.html +1 -1
  15. sky/dashboard/out/_next/static/chunks/3294.ba6586f9755b0edb.js +6 -0
  16. sky/dashboard/out/_next/static/chunks/{webpack-d4fabc08788e14af.js → webpack-e8a0c4c3c6f408fb.js} +1 -1
  17. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  18. sky/dashboard/out/clusters/[cluster].html +1 -1
  19. sky/dashboard/out/clusters.html +1 -1
  20. sky/dashboard/out/config.html +1 -1
  21. sky/dashboard/out/index.html +1 -1
  22. sky/dashboard/out/infra/[context].html +1 -1
  23. sky/dashboard/out/infra.html +1 -1
  24. sky/dashboard/out/jobs/[job].html +1 -1
  25. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  26. sky/dashboard/out/jobs.html +1 -1
  27. sky/dashboard/out/users.html +1 -1
  28. sky/dashboard/out/volumes.html +1 -1
  29. sky/dashboard/out/workspace/new.html +1 -1
  30. sky/dashboard/out/workspaces/[name].html +1 -1
  31. sky/dashboard/out/workspaces.html +1 -1
  32. sky/data/storage.py +5 -1
  33. sky/execution.py +21 -14
  34. sky/global_user_state.py +34 -0
  35. sky/jobs/client/sdk_async.py +4 -2
  36. sky/jobs/constants.py +3 -0
  37. sky/jobs/controller.py +734 -310
  38. sky/jobs/recovery_strategy.py +251 -129
  39. sky/jobs/scheduler.py +247 -174
  40. sky/jobs/server/core.py +20 -4
  41. sky/jobs/server/utils.py +2 -2
  42. sky/jobs/state.py +709 -508
  43. sky/jobs/utils.py +90 -40
  44. sky/logs/agent.py +10 -2
  45. sky/provision/aws/config.py +4 -1
  46. sky/provision/gcp/config.py +6 -1
  47. sky/provision/kubernetes/config.py +7 -2
  48. sky/provision/kubernetes/instance.py +84 -41
  49. sky/provision/kubernetes/utils.py +17 -8
  50. sky/provision/provisioner.py +1 -0
  51. sky/provision/vast/instance.py +1 -1
  52. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  53. sky/serve/replica_managers.py +0 -7
  54. sky/serve/serve_utils.py +5 -0
  55. sky/serve/server/impl.py +1 -2
  56. sky/serve/service.py +0 -2
  57. sky/server/common.py +8 -3
  58. sky/server/config.py +55 -27
  59. sky/server/constants.py +1 -0
  60. sky/server/daemons.py +7 -11
  61. sky/server/metrics.py +41 -8
  62. sky/server/requests/executor.py +41 -4
  63. sky/server/requests/serializers/encoders.py +1 -1
  64. sky/server/server.py +9 -1
  65. sky/server/uvicorn.py +11 -5
  66. sky/setup_files/dependencies.py +4 -2
  67. sky/skylet/attempt_skylet.py +1 -0
  68. sky/skylet/constants.py +14 -7
  69. sky/skylet/events.py +2 -10
  70. sky/skylet/log_lib.py +11 -0
  71. sky/skylet/log_lib.pyi +9 -0
  72. sky/task.py +62 -0
  73. sky/templates/kubernetes-ray.yml.j2 +120 -3
  74. sky/utils/accelerator_registry.py +3 -1
  75. sky/utils/command_runner.py +35 -11
  76. sky/utils/command_runner.pyi +25 -3
  77. sky/utils/common_utils.py +11 -1
  78. sky/utils/context_utils.py +15 -2
  79. sky/utils/controller_utils.py +5 -0
  80. sky/utils/db/db_utils.py +31 -2
  81. sky/utils/db/migration_utils.py +1 -1
  82. sky/utils/git.py +559 -1
  83. sky/utils/resource_checker.py +8 -7
  84. sky/utils/rich_utils.py +3 -1
  85. sky/utils/subprocess_utils.py +9 -0
  86. sky/volumes/volume.py +2 -0
  87. sky/workspaces/core.py +57 -21
  88. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/METADATA +38 -36
  89. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/RECORD +95 -95
  90. sky/client/cli/git.py +0 -549
  91. sky/dashboard/out/_next/static/chunks/3294.c80326aec9bfed40.js +0 -6
  92. /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → DAiq7V2xJnO1LSfmunZl6}/_buildManifest.js +0 -0
  93. /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → DAiq7V2xJnO1LSfmunZl6}/_ssgManifest.js +0 -0
  94. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/WHEEL +0 -0
  95. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/entry_points.txt +0 -0
  96. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/licenses/LICENSE +0 -0
  97. {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250912.dist-info}/top_level.txt +0 -0
sky/skylet/events.py CHANGED
@@ -11,7 +11,7 @@ import psutil
11
11
  from sky import clouds
12
12
  from sky import sky_logging
13
13
  from sky.backends import cloud_vm_ray_backend
14
- from sky.jobs import scheduler as managed_job_scheduler
14
+ from sky.jobs import scheduler
15
15
  from sky.jobs import state as managed_job_state
16
16
  from sky.jobs import utils as managed_job_utils
17
17
  from sky.serve import serve_utils
@@ -76,15 +76,7 @@ class ManagedJobEvent(SkyletEvent):
76
76
  def _run(self):
77
77
  logger.info('=== Updating managed job status ===')
78
78
  managed_job_utils.update_managed_jobs_statuses()
79
-
80
-
81
- class ManagedJobSchedulingEvent(SkyletEvent):
82
- """Skylet event for scheduling managed jobs."""
83
- EVENT_INTERVAL_SECONDS = 20
84
-
85
- def _run(self):
86
- logger.info('=== Scheduling next jobs ===')
87
- managed_job_scheduler.maybe_schedule_next_jobs()
79
+ scheduler.maybe_start_controllers()
88
80
 
89
81
 
90
82
  class ServiceUpdateEvent(SkyletEvent):
sky/skylet/log_lib.py CHANGED
@@ -354,6 +354,17 @@ def run_bash_command_with_log(bash_command: str,
354
354
  shell=True)
355
355
 
356
356
 
357
+ def run_bash_command_with_log_and_return_pid(
358
+ bash_command: str,
359
+ log_path: str,
360
+ env_vars: Optional[Dict[str, str]] = None,
361
+ stream_logs: bool = False,
362
+ with_ray: bool = False):
363
+ return_code = run_bash_command_with_log(bash_command, log_path, env_vars,
364
+ stream_logs, with_ray)
365
+ return {'return_code': return_code, 'pid': os.getpid()}
366
+
367
+
357
368
  def _follow_job_logs(file,
358
369
  job_id: int,
359
370
  start_streaming: bool,
sky/skylet/log_lib.pyi CHANGED
@@ -129,6 +129,15 @@ def run_bash_command_with_log(bash_command: str,
129
129
  ...
130
130
 
131
131
 
132
+ def run_bash_command_with_log_and_return_pid(
133
+ bash_command: str,
134
+ log_path: str,
135
+ env_vars: Optional[Dict[str, str]] = ...,
136
+ stream_logs: bool = ...,
137
+ with_ray: bool = ...):
138
+ ...
139
+
140
+
132
141
  def tail_logs(job_id: int,
133
142
  log_dir: Optional[str],
134
143
  managed_job_id: Optional[int] = ...,
sky/task.py CHANGED
@@ -20,6 +20,7 @@ from sky.provision import docker_utils
20
20
  from sky.serve import service_spec
21
21
  from sky.skylet import constants
22
22
  from sky.utils import common_utils
23
+ from sky.utils import git
23
24
  from sky.utils import registry
24
25
  from sky.utils import schemas
25
26
  from sky.utils import ux_utils
@@ -1596,6 +1597,67 @@ class Task:
1596
1597
  d[k] = v
1597
1598
  return d
1598
1599
 
1600
+ def update_workdir(self, workdir: Optional[str], git_url: Optional[str],
1601
+ git_ref: Optional[str]) -> 'Task':
1602
+ """Updates the task workdir.
1603
+
1604
+ Args:
1605
+ workdir: The workdir to update.
1606
+ git_url: The git url to update.
1607
+ git_ref: The git ref to update.
1608
+ """
1609
+ if self.workdir is None or isinstance(self.workdir, str):
1610
+ if workdir is not None:
1611
+ self.workdir = workdir
1612
+ return self
1613
+ if git_url is not None:
1614
+ self.workdir = {}
1615
+ self.workdir['url'] = git_url
1616
+ if git_ref is not None:
1617
+ self.workdir['ref'] = git_ref
1618
+ return self
1619
+ return self
1620
+ if git_url is not None:
1621
+ self.workdir['url'] = git_url
1622
+ if git_ref is not None:
1623
+ self.workdir['ref'] = git_ref
1624
+ return self
1625
+
1626
+ def update_envs_and_secrets_from_workdir(self) -> 'Task':
1627
+ """Updates the task envs and secrets from the workdir."""
1628
+ if self.workdir is None:
1629
+ return self
1630
+ if not isinstance(self.workdir, dict):
1631
+ return self
1632
+ url = self.workdir['url']
1633
+ ref = self.workdir.get('ref', '')
1634
+ token = os.environ.get(git.GIT_TOKEN_ENV_VAR)
1635
+ ssh_key_path = os.environ.get(git.GIT_SSH_KEY_PATH_ENV_VAR)
1636
+ try:
1637
+ git_repo = git.GitRepo(url, ref, token, ssh_key_path)
1638
+ clone_info = git_repo.get_repo_clone_info()
1639
+ if clone_info is None:
1640
+ return self
1641
+ self.envs[git.GIT_URL_ENV_VAR] = clone_info.url
1642
+ if ref:
1643
+ ref_type = git_repo.get_ref_type()
1644
+ if ref_type == git.GitRefType.COMMIT:
1645
+ self.envs[git.GIT_COMMIT_HASH_ENV_VAR] = ref
1646
+ elif ref_type == git.GitRefType.BRANCH:
1647
+ self.envs[git.GIT_BRANCH_ENV_VAR] = ref
1648
+ elif ref_type == git.GitRefType.TAG:
1649
+ self.envs[git.GIT_TAG_ENV_VAR] = ref
1650
+ if clone_info.token is None and clone_info.ssh_key is None:
1651
+ return self
1652
+ if clone_info.token is not None:
1653
+ self.secrets[git.GIT_TOKEN_ENV_VAR] = clone_info.token
1654
+ if clone_info.ssh_key is not None:
1655
+ self.secrets[git.GIT_SSH_KEY_ENV_VAR] = clone_info.ssh_key
1656
+ except exceptions.GitError as e:
1657
+ with ux_utils.print_exception_no_traceback():
1658
+ raise ValueError(f'{str(e)}') from None
1659
+ return self
1660
+
1599
1661
  def to_yaml_config(self,
1600
1662
  use_user_specified_yaml: bool = False) -> Dict[str, Any]:
1601
1663
  """Returns a yaml-style dict representation of the task.
sky/templates/kubernetes-ray.yml.j2 CHANGED
@@ -654,8 +654,125 @@ available_node_types:
654
654
  # after v0.11.0 release.
655
655
  touch /tmp/apt_ssh_setup_started
656
656
 
657
- DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get update > /tmp/apt-update.log 2>&1 || \
658
- echo "Warning: apt-get update failed. Continuing anyway..." >> /tmp/apt-update.log
657
+ # Helper: run apt-get update with retries
658
+ apt_update_with_retries() {
659
+ # do not fail the whole shell; we handle return codes
660
+ set +e
661
+ local log=/tmp/apt-update.log
662
+ local tries=3
663
+ local delay=1
664
+ local i
665
+ for i in $(seq 1 $tries); do
666
+ DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get update >> "$log" 2>&1 && { set -e; return 0; }
667
+ echo "apt-get update attempt $i/$tries failed; retrying in ${delay}s" >> "$log"
668
+ sleep $delay
669
+ delay=$((delay * 2))
670
+ done
671
+ set -e
672
+ return 1
673
+ }
674
+ apt_install_with_retries() {
675
+ local packages="$@"
676
+ [ -z "$packages" ] && return 0
677
+ set +e
678
+ local log=/tmp/apt-update.log
679
+ local tries=3
680
+ local delay=1
681
+ local i
682
+ for i in $(seq 1 $tries); do
683
+ DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y -o Dpkg::Options::="--force-confdef" -o Dpkg::Options::="--force-confold" $packages && { set -e; return 0; }
684
+ echo "apt-get install failed for: $packages (attempt $i/$tries). Running -f install and retrying..." >> "$log"
685
+ DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get -f install -y >> "$log" 2>&1 || true
686
+ DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get clean >> "$log" 2>&1 || true
687
+ sleep $delay
688
+ delay=$((delay * 2))
689
+ done
690
+ set -e
691
+ return 1
692
+ }
693
+ apt_update_install_with_retries() {
694
+ apt_update_with_retries
695
+ apt_install_with_retries "$@"
696
+ }
697
+ backup_dir=/etc/apt/sources.list.backup_skypilot
698
+ backup_source() {
699
+ $(prefix_cmd) mkdir -p "$backup_dir"
700
+ if [ -f /etc/apt/sources.list ] && [ ! -f "$backup_dir/sources.list" ]; then
701
+ $(prefix_cmd) cp -a /etc/apt/sources.list "$backup_dir/sources.list" || true
702
+ fi
703
+ }
704
+ restore_source() {
705
+ if [ -f "$backup_dir/sources.list" ]; then
706
+ $(prefix_cmd) cp -a "$backup_dir/sources.list" /etc/apt/sources.list || true
707
+ fi
708
+ }
709
+ update_apt_sources() {
710
+ local host=$1
711
+ local apt_file=$2
712
+ $(prefix_cmd) sed -i -E "s|https?://[a-zA-Z0-9.-]+\.ubuntu\.com/ubuntu|http://$host/ubuntu|g" $apt_file
713
+ }
714
+ # Helper: install packages across mirrors with retries
715
+ apt_install_with_mirrors() {
716
+ local required=$1; shift
717
+ local packages="$@"
718
+ [ -z "$packages" ] && return 0
719
+ set +e
720
+ # Install packages with default sources first
721
+ local log=/tmp/apt-update.log
722
+ echo "$(date +%Y-%m-%d\ %H:%M:%S) Installing packages: $packages" >> "$log"
723
+ restore_source
724
+ apt_update_install_with_retries $packages >> "$log" 2>&1 && { set -e; return 0; }
725
+ echo "Install failed with default sources: $packages" >> "$log"
726
+ # Detect distro (ubuntu/debian)
727
+ local APT_OS="unknown"
728
+ if [ -f /etc/os-release ]; then
729
+ . /etc/os-release
730
+ case "$ID" in
731
+ debian) APT_OS="debian" ;;
732
+ ubuntu) APT_OS="ubuntu" ;;
733
+ *)
734
+ if [ -n "$ID_LIKE" ]; then
735
+ case " $ID $ID_LIKE " in
736
+ *ubuntu*) APT_OS="ubuntu" ;;
737
+ *debian*) APT_OS="debian" ;;
738
+ esac
739
+ fi
740
+ ;;
741
+ esac
742
+ fi
743
+ # Build mirror candidates
744
+ # deb.debian.org is a CDN endpoint, if one backend goes down,
745
+ # the CDN automatically fails over to another mirror,
746
+ # so we only retry for ubuntu here.
747
+ if [ "$APT_OS" = "ubuntu" ]; then
748
+ # Backup current sources once
749
+ backup_source
750
+ # Selected from https://launchpad.net/ubuntu/+archivemirrors
751
+ # and results from apt-select
752
+ local MIRROR_CANDIDATES="mirrors.wikimedia.org mirror.umd.edu"
753
+ for host in $MIRROR_CANDIDATES; do
754
+ echo "Trying APT mirror ($APT_OS): $host" >> "$log"
755
+ if [ -f /etc/apt/sources.list ]; then
756
+ update_apt_sources $host /etc/apt/sources.list
757
+ else
758
+ echo "Error: /etc/apt/sources.list not found" >> "$log"
759
+ break
760
+ fi
761
+ apt_update_install_with_retries $packages >> "$log" 2>&1 && { set -e; return 0; }
762
+ echo "Install failed with mirror ($APT_OS): $host" >> "$log"
763
+ # Restore to default sources
764
+ restore_source
765
+ done
766
+ fi
767
+ set -e
768
+ if [ "$required" = "1" ]; then
769
+ echo "Error: required package install failed across all mirrors: $packages" >> "$log"
770
+ return 1
771
+ else
772
+ echo "Optional package install failed across all mirrors: $packages; skipping." >> "$log"
773
+ return 0
774
+ fi
775
+ }
659
776
  # Install both fuse2 and fuse3 for compatibility for all possible fuse adapters in advance,
660
777
  # so that both fusemount and fusermount3 can be masked before enabling SSH access.
661
778
  PACKAGES="rsync curl wget netcat gcc patch pciutils fuse fuse3 openssh-server";
@@ -682,7 +799,7 @@ available_node_types:
682
799
  done;
683
800
  if [ ! -z "$INSTALL_FIRST" ]; then
684
801
  echo "Installing core packages: $INSTALL_FIRST";
685
- DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y -o Dpkg::Options::="--force-confdef" -o Dpkg::Options::="--force-confold" $INSTALL_FIRST;
802
+ apt_install_with_mirrors 1 $INSTALL_FIRST || { echo "Error: core package installation failed." >> /tmp/apt-update.log; exit 1; }
686
803
  fi;
687
804
  # SSH and other packages are not necessary, so we disable set -e
688
805
  set +e
sky/utils/accelerator_registry.py CHANGED
@@ -107,10 +107,12 @@ def canonicalize_accelerator_name(accelerator: str,
107
107
  if not names and cloud_str in ['Kubernetes', None]:
108
108
  with rich_utils.safe_status(
109
109
  ux_utils.spinner_message('Listing accelerators on Kubernetes')):
110
+ # Only search for Kubernetes to reduce the lookup cost.
111
+ # For other clouds, the catalog has been searched in previous steps.
110
112
  searched = catalog.list_accelerators(
111
113
  name_filter=accelerator,
112
114
  case_sensitive=False,
113
- clouds=cloud_str,
115
+ clouds='Kubernetes',
114
116
  )
115
117
  names = list(searched.keys())
116
118
  if accelerator in names:
sky/utils/command_runner.py CHANGED
@@ -469,15 +469,19 @@ class CommandRunner:
469
469
  """Close the cached connection to the remote machine."""
470
470
  pass
471
471
 
472
- def port_forward_command(self,
473
- port_forward: List[Tuple[int, int]],
474
- connect_timeout: int = 1) -> List[str]:
472
+ def port_forward_command(
473
+ self,
474
+ port_forward: List[Tuple[int, int]],
475
+ connect_timeout: int = 1,
476
+ ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
475
477
  """Command for forwarding ports from localhost to the remote machine.
476
478
 
477
479
  Args:
478
480
  port_forward: A list of ports to forward from the localhost to the
479
481
  remote host.
480
482
  connect_timeout: The timeout for the connection.
483
+ ssh_mode: The mode to use for ssh.
484
+ See SSHMode for more details.
481
485
  """
482
486
  raise NotImplementedError
483
487
 
@@ -592,6 +596,7 @@ class SSHCommandRunner(CommandRunner):
592
596
  ssh_proxy_command: Optional[str] = None,
593
597
  docker_user: Optional[str] = None,
594
598
  disable_control_master: Optional[bool] = False,
599
+ port_forward_execute_remote_command: Optional[bool] = False,
595
600
  ):
596
601
  """Initialize SSHCommandRunner.
597
602
 
@@ -618,6 +623,10 @@ class SSHCommandRunner(CommandRunner):
618
623
  disable_control_master: bool; specifies either or not the ssh
619
624
  command will utilize ControlMaster. We currently disable
620
625
  it for k8s instance.
626
+ port_forward_execute_remote_command: bool; specifies whether to
627
+ add -N to the port forwarding command. This is useful if you
628
+ want to run a command on the remote machine to make sure the
629
+ SSH tunnel is established.
621
630
  """
622
631
  super().__init__(node)
623
632
  ip, port = node
@@ -646,22 +655,28 @@ class SSHCommandRunner(CommandRunner):
646
655
  self.ssh_user = ssh_user
647
656
  self.port = port
648
657
  self._docker_ssh_proxy_command = None
658
+ self.port_forward_execute_remote_command = (
659
+ port_forward_execute_remote_command)
649
660
 
650
- def port_forward_command(self,
651
- port_forward: List[Tuple[int, int]],
652
- connect_timeout: int = 1) -> List[str]:
661
+ def port_forward_command(
662
+ self,
663
+ port_forward: List[Tuple[int, int]],
664
+ connect_timeout: int = 1,
665
+ ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
653
666
  """Command for forwarding ports from localhost to the remote machine.
654
667
 
655
668
  Args:
656
669
  port_forward: A list of ports to forward from the local port to the
657
670
  remote port.
658
671
  connect_timeout: The timeout for the ssh connection.
672
+ ssh_mode: The mode to use for ssh.
673
+ See SSHMode for more details.
659
674
 
660
675
  Returns:
661
676
  The command for forwarding ports from localhost to the remote
662
677
  machine.
663
678
  """
664
- return self.ssh_base_command(ssh_mode=SshMode.INTERACTIVE,
679
+ return self.ssh_base_command(ssh_mode=ssh_mode,
665
680
  port_forward=port_forward,
666
681
  connect_timeout=connect_timeout)
667
682
 
@@ -680,7 +695,11 @@ class SSHCommandRunner(CommandRunner):
680
695
  for local, remote in port_forward:
681
696
  logger.debug(
682
697
  f'Forwarding local port {local} to remote port {remote}.')
683
- ssh += ['-NL', f'{local}:localhost:{remote}']
698
+ if self.port_forward_execute_remote_command:
699
+ ssh += ['-L']
700
+ else:
701
+ ssh += ['-NL']
702
+ ssh += [f'{local}:localhost:{remote}']
684
703
  if self._docker_ssh_proxy_command is not None:
685
704
  docker_ssh_proxy_command = self._docker_ssh_proxy_command(ssh)
686
705
  else:
@@ -894,9 +913,11 @@ class KubernetesCommandRunner(CommandRunner):
894
913
  else:
895
914
  return f'pod/{self.pod_name}'
896
915
 
897
- def port_forward_command(self,
898
- port_forward: List[Tuple[int, int]],
899
- connect_timeout: int = 1) -> List[str]:
916
+ def port_forward_command(
917
+ self,
918
+ port_forward: List[Tuple[int, int]],
919
+ connect_timeout: int = 1,
920
+ ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
900
921
  """Command for forwarding ports from localhost to the remote machine.
901
922
 
902
923
  Args:
@@ -904,7 +925,10 @@ class KubernetesCommandRunner(CommandRunner):
904
925
  remote port. Currently, only one port is supported, i.e. the
905
926
  list should have only one element.
906
927
  connect_timeout: The timeout for the ssh connection.
928
+ ssh_mode: The mode to use for ssh.
929
+ See SSHMode for more details.
907
930
  """
931
+ del ssh_mode # unused
908
932
  assert port_forward and len(port_forward) == 1, (
909
933
  'Only one port is supported for Kubernetes port-forward.')
910
934
  kubectl_args = [
sky/utils/command_runner.pyi CHANGED
@@ -36,9 +36,9 @@ def ssh_options_list(
36
36
 
37
37
 
38
38
  class SshMode(enum.Enum):
39
- NON_INTERACTIVE: int
40
- INTERACTIVE: int
41
- LOGIN: int
39
+ NON_INTERACTIVE = ...
40
+ INTERACTIVE = ...
41
+ LOGIN = ...
42
42
 
43
43
 
44
44
  class CommandRunner:
@@ -106,6 +106,13 @@ class CommandRunner:
106
106
  max_retry: int = ...) -> None:
107
107
  ...
108
108
 
109
+ def port_forward_command(
110
+ self,
111
+ port_forward: List[Tuple[int, int]],
112
+ connect_timeout: int = 1,
113
+ ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
114
+ ...
115
+
109
116
  @classmethod
110
117
  def make_runner_list(cls: typing.Type[CommandRunner],
111
118
  node_list: Iterable[Tuple[Any, ...]],
@@ -127,6 +134,7 @@ class SSHCommandRunner(CommandRunner):
127
134
  ssh_control_name: Optional[str]
128
135
  docker_user: str
129
136
  disable_control_master: Optional[bool]
137
+ port_forward_execute_remote_command: Optional[bool]
130
138
 
131
139
  def __init__(
132
140
  self,
@@ -200,6 +208,13 @@ class SSHCommandRunner(CommandRunner):
200
208
  max_retry: int = ...) -> None:
201
209
  ...
202
210
 
211
+ def port_forward_command(
212
+ self,
213
+ port_forward: List[Tuple[int, int]],
214
+ connect_timeout: int = 1,
215
+ ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
216
+ ...
217
+
203
218
 
204
219
  class KubernetesCommandRunner(CommandRunner):
205
220
 
@@ -272,6 +287,13 @@ class KubernetesCommandRunner(CommandRunner):
272
287
  max_retry: int = ...) -> None:
273
288
  ...
274
289
 
290
+ def port_forward_command(
291
+ self,
292
+ port_forward: List[Tuple[int, int]],
293
+ connect_timeout: int = 1,
294
+ ssh_mode: SshMode = SshMode.INTERACTIVE) -> List[str]:
295
+ ...
296
+
275
297
 
276
298
  class LocalProcessCommandRunner(CommandRunner):
277
299
 
sky/utils/common_utils.py CHANGED
@@ -996,7 +996,17 @@ def get_mem_size_gb() -> float:
996
996
  except ValueError as e:
997
997
  with ux_utils.print_exception_no_traceback():
998
998
  raise ValueError(
999
- f'Failed to parse the memory size from {mem_size}') from e
999
+ f'Failed to parse the memory size from {mem_size} (GB)'
1000
+ ) from e
1001
+ mem_size = os.getenv('SKYPILOT_POD_MEMORY_BYTES_LIMIT')
1002
+ if mem_size is not None:
1003
+ try:
1004
+ return float(mem_size) / (1024**3)
1005
+ except ValueError as e:
1006
+ with ux_utils.print_exception_no_traceback():
1007
+ raise ValueError(
1008
+ f'Failed to parse the memory size from {mem_size} (bytes)'
1009
+ ) from e
1000
1010
  return _mem_size_gb()
1001
1011
 
1002
1012
 
sky/utils/context_utils.py CHANGED
@@ -10,6 +10,8 @@ import sys
10
10
  import typing
11
11
  from typing import Any, Callable, IO, Optional, Tuple, TypeVar
12
12
 
13
+ from typing_extensions import ParamSpec
14
+
13
15
  from sky import sky_logging
14
16
  from sky.utils import context
15
17
  from sky.utils import subprocess_utils
@@ -173,9 +175,14 @@ def cancellation_guard(func: F) -> F:
173
175
  return typing.cast(F, wrapper)
174
176
 
175
177
 
178
+ P = ParamSpec('P')
179
+ T = TypeVar('T')
180
+
181
+
176
182
  # TODO(aylei): replace this with asyncio.to_thread once we drop support for
177
183
  # python 3.8
178
- def to_thread(func, /, *args, **kwargs):
184
+ def to_thread(func: Callable[P, T], /, *args: P.args,
185
+ **kwargs: P.kwargs) -> 'asyncio.Future[T]':
179
186
  """Asynchronously run function *func* in a separate thread.
180
187
 
181
188
  This is same as asyncio.to_thread added in python 3.9
@@ -183,5 +190,11 @@ def to_thread(func, /, *args, **kwargs):
183
190
  loop = asyncio.get_running_loop()
184
191
  # This is critical to pass the current coroutine context to the new thread
185
192
  pyctx = contextvars.copy_context()
186
- func_call = functools.partial(pyctx.run, func, *args, **kwargs)
193
+ func_call: Callable[..., T] = functools.partial(
194
+ # partial deletes arguments type and thus can't figure out the return
195
+ # type of pyctx.run
196
+ pyctx.run, # type: ignore
197
+ func,
198
+ *args,
199
+ **kwargs)
187
200
  return loop.run_in_executor(None, func_call)
sky/utils/controller_utils.py CHANGED
@@ -228,6 +228,11 @@ def get_controller_for_pool(pool: bool) -> Controllers:
228
228
  def high_availability_specified(cluster_name: Optional[str]) -> bool:
229
229
  """Check if the controller high availability is specified in user config.
230
230
  """
231
+ # pylint: disable=import-outside-toplevel
232
+ from sky.jobs import utils as managed_job_utils
233
+ if managed_job_utils.is_consolidation_mode():
234
+ return True
235
+
231
236
  controller = Controllers.from_name(cluster_name)
232
237
  if controller is None:
233
238
  return False
sky/utils/db/db_utils.py CHANGED
@@ -7,12 +7,13 @@ import pathlib
7
7
  import sqlite3
8
8
  import threading
9
9
  import typing
10
- from typing import Any, Callable, Dict, Iterable, Optional
10
+ from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
11
11
 
12
12
  import aiosqlite
13
13
  import aiosqlite.context
14
14
  import sqlalchemy
15
15
  from sqlalchemy import exc as sqlalchemy_exc
16
+ from sqlalchemy.ext import asyncio as sqlalchemy_async
16
17
 
17
18
  from sky import sky_logging
18
19
  from sky.skylet import constants
@@ -375,11 +376,34 @@ def get_max_connections():
375
376
  return _max_connections
376
377
 
377
378
 
378
- def get_engine(db_name: str):
379
+ @typing.overload
380
+ def get_engine(
381
+ db_name: str,
382
+ async_engine: Literal[False] = False) -> sqlalchemy.engine.Engine:
383
+ ...
384
+
385
+
386
+ @typing.overload
387
+ def get_engine(db_name: str,
388
+ async_engine: Literal[True]) -> sqlalchemy_async.AsyncEngine:
389
+ ...
390
+
391
+
392
+ def get_engine(
393
+ db_name: str,
394
+ async_engine: bool = False
395
+ ) -> Union[sqlalchemy.engine.Engine, sqlalchemy_async.AsyncEngine]:
379
396
  conn_string = None
380
397
  if os.environ.get(constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
381
398
  conn_string = os.environ.get(constants.ENV_VAR_DB_CONNECTION_URI)
382
399
  if conn_string:
400
+ if async_engine:
401
+ conn_string = conn_string.replace('postgresql://',
402
+ 'postgresql+asyncpg://')
403
+ # This is an AsyncEngine, instead of a (normal, synchronous) Engine,
404
+ # so we should not put it in the cache. Instead, just return.
405
+ return sqlalchemy_async.create_async_engine(
406
+ conn_string, poolclass=sqlalchemy.NullPool)
383
407
  with _db_creation_lock:
384
408
  if conn_string not in _postgres_engine_cache:
385
409
  if _max_connections == 0:
@@ -401,6 +425,11 @@ def get_engine(db_name: str):
401
425
  else:
402
426
  db_path = os.path.expanduser(f'~/.sky/{db_name}.db')
403
427
  pathlib.Path(db_path).parents[0].mkdir(parents=True, exist_ok=True)
428
+ if async_engine:
429
+ # This is an AsyncEngine, instead of a (normal, synchronous) Engine,
430
+ # so we should not put it in the cache. Instead, just return.
431
+ return sqlalchemy_async.create_async_engine(
432
+ 'sqlite+aiosqlite:///' + db_path, connect_args={'timeout': 30})
404
433
  if db_path not in _sqlite_engine_cache:
405
434
  _sqlite_engine_cache[db_path] = sqlalchemy.create_engine(
406
435
  'sqlite:///' + db_path)
sky/utils/db/migration_utils.py CHANGED
@@ -17,7 +17,7 @@ logger = sky_logging.init_logger(__name__)
17
17
  DB_INIT_LOCK_TIMEOUT_SECONDS = 10
18
18
 
19
19
  GLOBAL_USER_STATE_DB_NAME = 'state_db'
20
- GLOBAL_USER_STATE_VERSION = '007'
20
+ GLOBAL_USER_STATE_VERSION = '008'
21
21
  GLOBAL_USER_STATE_LOCK_PATH = '~/.sky/locks/.state_db.lock'
22
22
 
23
23
  SPOT_JOBS_DB_NAME = 'spot_jobs_db'