skypilot-nightly 1.0.0.dev20250819__py3-none-any.whl → 1.0.0.dev20250820__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic. Click here for more details.

Files changed (55) hide show
  1. sky/__init__.py +5 -3
  2. sky/backends/cloud_vm_ray_backend.py +6 -13
  3. sky/backends/wheel_utils.py +2 -1
  4. sky/client/cli/command.py +20 -16
  5. sky/core.py +1 -1
  6. sky/dashboard/out/404.html +1 -1
  7. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  8. sky/dashboard/out/clusters/[cluster].html +1 -1
  9. sky/dashboard/out/clusters.html +1 -1
  10. sky/dashboard/out/config.html +1 -1
  11. sky/dashboard/out/index.html +1 -1
  12. sky/dashboard/out/infra/[context].html +1 -1
  13. sky/dashboard/out/infra.html +1 -1
  14. sky/dashboard/out/jobs/[job].html +1 -1
  15. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  16. sky/dashboard/out/jobs.html +1 -1
  17. sky/dashboard/out/users.html +1 -1
  18. sky/dashboard/out/volumes.html +1 -1
  19. sky/dashboard/out/workspace/new.html +1 -1
  20. sky/dashboard/out/workspaces/[name].html +1 -1
  21. sky/dashboard/out/workspaces.html +1 -1
  22. sky/exceptions.py +6 -1
  23. sky/global_user_state.py +18 -11
  24. sky/jobs/server/core.py +1 -1
  25. sky/models.py +1 -0
  26. sky/provision/aws/config.py +11 -11
  27. sky/provision/aws/instance.py +30 -27
  28. sky/provision/do/utils.py +2 -2
  29. sky/provision/kubernetes/network_utils.py +3 -3
  30. sky/provision/kubernetes/utils.py +2 -2
  31. sky/provision/kubernetes/volume.py +2 -0
  32. sky/serve/replica_managers.py +7 -0
  33. sky/serve/server/impl.py +1 -1
  34. sky/server/requests/payloads.py +1 -0
  35. sky/server/requests/serializers/encoders.py +14 -2
  36. sky/server/server.py +33 -0
  37. sky/setup_files/dependencies.py +17 -11
  38. sky/utils/common.py +27 -7
  39. sky/utils/common_utils.py +13 -9
  40. sky/utils/directory_utils.py +12 -0
  41. sky/utils/env_options.py +3 -0
  42. sky/utils/kubernetes/gpu_labeler.py +3 -3
  43. sky/utils/schemas.py +1 -0
  44. sky/utils/serialize_utils.py +16 -0
  45. sky/volumes/client/sdk.py +10 -7
  46. sky/volumes/server/core.py +12 -3
  47. sky/volumes/volume.py +17 -3
  48. {skypilot_nightly-1.0.0.dev20250819.dist-info → skypilot_nightly-1.0.0.dev20250820.dist-info}/METADATA +21 -13
  49. {skypilot_nightly-1.0.0.dev20250819.dist-info → skypilot_nightly-1.0.0.dev20250820.dist-info}/RECORD +55 -53
  50. /sky/dashboard/out/_next/static/{tYn7R2be3cQPYJfTxxE09 → 8ZscIHnvBWz3AXkxsJL6H}/_buildManifest.js +0 -0
  51. /sky/dashboard/out/_next/static/{tYn7R2be3cQPYJfTxxE09 → 8ZscIHnvBWz3AXkxsJL6H}/_ssgManifest.js +0 -0
  52. {skypilot_nightly-1.0.0.dev20250819.dist-info → skypilot_nightly-1.0.0.dev20250820.dist-info}/WHEEL +0 -0
  53. {skypilot_nightly-1.0.0.dev20250819.dist-info → skypilot_nightly-1.0.0.dev20250820.dist-info}/entry_points.txt +0 -0
  54. {skypilot_nightly-1.0.0.dev20250819.dist-info → skypilot_nightly-1.0.0.dev20250820.dist-info}/licenses/LICENSE +0 -0
  55. {skypilot_nightly-1.0.0.dev20250819.dist-info → skypilot_nightly-1.0.0.dev20250820.dist-info}/top_level.txt +0 -0
sky/global_user_state.py CHANGED
@@ -55,6 +55,13 @@ _SQLALCHEMY_ENGINE_LOCK = threading.Lock()
55
55
  DEFAULT_CLUSTER_EVENT_RETENTION_HOURS = 24.0
56
56
  MIN_CLUSTER_EVENT_DAEMON_INTERVAL_SECONDS = 3600
57
57
 
58
# Substrings that identify a unique-constraint violation inside the message of
# a sqlalchemy IntegrityError, one entry per supported database backend.
_UNIQUE_CONSTRAINT_FAILED_ERROR_MSGS = [
    'UNIQUE constraint failed',  # SQLite
    'duplicate key value violates unique constraint',  # PostgreSQL
]
64
+
58
65
  Base = declarative.declarative_base()
59
66
 
60
67
  config_table = sqlalchemy.Table(
@@ -735,17 +742,17 @@ def add_cluster_event(cluster_name: str,
735
742
  ))
736
743
  session.commit()
737
744
  except sqlalchemy.exc.IntegrityError as e:
738
- if 'UNIQUE constraint failed' in str(e):
739
- # This can happen if the cluster event is added twice.
740
- # We can ignore this error unless the caller requests
741
- # to expose the error.
742
- if expose_duplicate_error:
743
- raise db_utils.UniqueConstraintViolationError(
744
- value=reason, message=str(e))
745
- else:
746
- pass
747
- else:
748
- raise e
745
+ for msg in _UNIQUE_CONSTRAINT_FAILED_ERROR_MSGS:
746
+ if msg in str(e):
747
+ # This can happen if the cluster event is added twice.
748
+ # We can ignore this error unless the caller requests
749
+ # to expose the error.
750
+ if expose_duplicate_error:
751
+ raise db_utils.UniqueConstraintViolationError(
752
+ value=reason, message=str(e))
753
+ else:
754
+ return
755
+ raise e
749
756
 
750
757
 
751
758
  def get_last_cluster_event(cluster_hash: str,
sky/jobs/server/core.py CHANGED
@@ -188,11 +188,11 @@ def launch(
188
188
 
189
189
  dag_uuid = str(uuid.uuid4().hex[:4])
190
190
  dag = dag_utils.convert_entrypoint_to_dag(entrypoint)
191
- dag.resolve_and_validate_volumes()
192
191
  # Always apply the policy again here, even though it might have been applied
193
192
  # in the CLI. This is to ensure that we apply the policy to the final DAG
194
193
  # and get the mutated config.
195
194
  dag, mutated_user_config = admin_policy_utils.apply(dag)
195
+ dag.resolve_and_validate_volumes()
196
196
  if not dag.is_chain():
197
197
  with ux_utils.print_exception_no_traceback():
198
198
  raise ValueError('Only single-task or chain DAG is '
sky/models.py CHANGED
@@ -108,3 +108,4 @@ class VolumeConfig(pydantic.BaseModel):
108
108
  name_on_cloud: str
109
109
  size: Optional[str]
110
110
  config: Dict[str, Any] = {}
111
+ labels: Optional[Dict[str, str]] = None
@@ -498,8 +498,8 @@ def _vpc_id_from_security_group_ids(ec2: 'mypy_boto3_ec2.ServiceResource',
498
498
  return vpc_ids[0]
499
499
 
500
500
 
501
- def _get_vpc_id_by_name(ec2: 'mypy_boto3_ec2.ServiceResource', vpc_name: str,
502
- region: str) -> str:
501
+ def get_vpc_id_by_name(ec2: 'mypy_boto3_ec2.ServiceResource', vpc_name: str,
502
+ region: str) -> str:
503
503
  """Returns the VPC ID of the unique VPC with a given name.
504
504
 
505
505
  Exits with code 1 if:
@@ -532,7 +532,7 @@ def _get_subnet_and_vpc_id(ec2: 'mypy_boto3_ec2.ServiceResource',
532
532
  use_internal_ips: bool,
533
533
  vpc_name: Optional[str]) -> Tuple[Any, str]:
534
534
  if vpc_name is not None:
535
- vpc_id_of_sg = _get_vpc_id_by_name(ec2, vpc_name, region)
535
+ vpc_id_of_sg = get_vpc_id_by_name(ec2, vpc_name, region)
536
536
  elif security_group_ids:
537
537
  vpc_id_of_sg = _vpc_id_from_security_group_ids(ec2, security_group_ids)
538
538
  else:
@@ -614,8 +614,8 @@ def _get_or_create_vpc_security_group(ec2: 'mypy_boto3_ec2.ServiceResource',
614
614
  due to AWS service issues.
615
615
  """
616
616
  # Figure out which security groups with this name exist for each VPC...
617
- security_group = _get_security_group_from_vpc_id(ec2, vpc_id,
618
- expected_sg_name)
617
+ security_group = get_security_group_from_vpc_id(ec2, vpc_id,
618
+ expected_sg_name)
619
619
  if security_group is not None:
620
620
  return security_group
621
621
 
@@ -631,7 +631,7 @@ def _get_or_create_vpc_security_group(ec2: 'mypy_boto3_ec2.ServiceResource',
631
631
  # The security group already exists, but we didn't see it
632
632
  # because of eventual consistency.
633
633
  logger.warning(f'{expected_sg_name} already exists when creating.')
634
- security_group = _get_security_group_from_vpc_id(
634
+ security_group = get_security_group_from_vpc_id(
635
635
  ec2, vpc_id, expected_sg_name)
636
636
  assert (security_group is not None and
637
637
  security_group.group_name == expected_sg_name), (
@@ -646,8 +646,8 @@ def _get_or_create_vpc_security_group(ec2: 'mypy_boto3_ec2.ServiceResource',
646
646
  logger.warning(message)
647
647
  raise exceptions.NoClusterLaunchedError(message) from e
648
648
 
649
- security_group = _get_security_group_from_vpc_id(ec2, vpc_id,
650
- expected_sg_name)
649
+ security_group = get_security_group_from_vpc_id(ec2, vpc_id,
650
+ expected_sg_name)
651
651
  assert security_group is not None, 'Failed to create security group'
652
652
  logger.info(f'Created new security group {colorama.Style.BRIGHT}'
653
653
  f'{security_group.group_name}{colorama.Style.RESET_ALL} '
@@ -655,9 +655,9 @@ def _get_or_create_vpc_security_group(ec2: 'mypy_boto3_ec2.ServiceResource',
655
655
  return security_group
656
656
 
657
657
 
658
- def _get_security_group_from_vpc_id(ec2: 'mypy_boto3_ec2.ServiceResource',
659
- vpc_id: str,
660
- group_name: str) -> Optional[Any]:
658
+ def get_security_group_from_vpc_id(ec2: 'mypy_boto3_ec2.ServiceResource',
659
+ vpc_id: str,
660
+ group_name: str) -> Optional[Any]:
661
661
  """Get security group by VPC ID and group name."""
662
662
  existing_groups = list(
663
663
  ec2.security_groups.filter(Filters=[{
@@ -18,6 +18,7 @@ from sky.clouds import aws as aws_cloud
18
18
  from sky.clouds.utils import aws_utils
19
19
  from sky.provision import common
20
20
  from sky.provision import constants
21
+ from sky.provision.aws import config as aws_config
21
22
  from sky.provision.aws import utils
22
23
  from sky.utils import common_utils
23
24
  from sky.utils import resources_utils
@@ -685,7 +686,9 @@ def terminate_instances(
685
686
  filters,
686
687
  included_instances=None,
687
688
  excluded_instances=None)
688
- default_sg = _get_sg_from_name(ec2, aws_cloud.DEFAULT_SECURITY_GROUP_NAME)
689
+ default_sg = aws_config.get_security_group_from_vpc_id(
690
+ ec2, _get_vpc_id(provider_config),
691
+ aws_cloud.DEFAULT_SECURITY_GROUP_NAME)
689
692
  if sg_name == aws_cloud.DEFAULT_SECURITY_GROUP_NAME:
690
693
  # Case 1: The default SG is used, we don't need to ensure instance are
691
694
  # terminated.
@@ -727,30 +730,6 @@ def terminate_instances(
727
730
  # of most cloud implementations (including AWS).
728
731
 
729
732
 
730
- def _get_sg_from_name(
731
- ec2: Any,
732
- sg_name: str,
733
- ) -> Any:
734
- # GroupNames will only filter SGs in the default VPC, so we need to use
735
- # Filters here. Ref:
736
- # https://boto3.amazonaws.com/v1/documentation/api/1.26.112/reference/services/ec2/service-resource/security_groups.html # pylint: disable=line-too-long
737
- sgs = ec2.security_groups.filter(Filters=[{
738
- 'Name': 'group-name',
739
- 'Values': [sg_name]
740
- }])
741
- num_sg = len(list(sgs))
742
- if num_sg == 0:
743
- logger.warning(f'Expected security group {sg_name} not found. ')
744
- return None
745
- if num_sg > 1:
746
- # TODO(tian): Better handle this case. Maybe we can check when creating
747
- # the SG and throw an error if there is already an existing SG with the
748
- # same name.
749
- logger.warning(f'Found {num_sg} security groups with name {sg_name}. ')
750
- return None
751
- return list(sgs)[0]
752
-
753
-
754
733
  def _maybe_move_to_new_sg(
755
734
  instance: Any,
756
735
  expected_sg: Any,
@@ -803,7 +782,9 @@ def open_ports(
803
782
  with ux_utils.print_exception_no_traceback():
804
783
  raise ValueError('Instance with cluster name '
805
784
  f'{cluster_name_on_cloud} not found.')
806
- sg = _get_sg_from_name(ec2, sg_name)
785
+ sg = aws_config.get_security_group_from_vpc_id(ec2,
786
+ _get_vpc_id(provider_config),
787
+ sg_name)
807
788
  if sg is None:
808
789
  with ux_utils.print_exception_no_traceback():
809
790
  raise ValueError('Cannot find new security group '
@@ -899,7 +880,9 @@ def cleanup_ports(
899
880
  # We only want to delete the SG that is dedicated to this cluster (i.e.,
900
881
  # this cluster have opened some ports).
901
882
  return
902
- sg = _get_sg_from_name(ec2, sg_name)
883
+ sg = aws_config.get_security_group_from_vpc_id(ec2,
884
+ _get_vpc_id(provider_config),
885
+ sg_name)
903
886
  if sg is None:
904
887
  logger.warning(
905
888
  'Find security group failed. Skip cleanup security group.')
@@ -1010,3 +993,23 @@ def get_cluster_info(
1010
993
  provider_name='aws',
1011
994
  provider_config=provider_config,
1012
995
  )
996
+
997
+
998
def _get_vpc_id(provider_config: Dict[str, Any]) -> str:
    """Returns the VPC ID used to look up the cluster's security groups.

    If ``provider_config`` carries a non-None ``vpc_name``, the VPC is
    resolved by name via ``aws_config.get_vpc_id_by_name``. Otherwise the
    region's default VPC is used.

    NOTE(review): a cluster placed in a non-default VPC without ``vpc_name``
    (e.g. selected through ``security_group_ids``) resolves to the default
    VPC here, so its security group may not be found. Confirm this matches
    how such clusters are provisioned.

    Args:
        provider_config: The cluster's provider config. Must contain
            ``region``; may contain ``vpc_name``.

    Returns:
        The VPC ID.

    Raises:
        ValueError: If no VPC name is configured and the region does not have
            exactly one default VPC.
    """
    region = provider_config['region']
    ec2 = _default_ec2_resource(region)
    # Treat an explicit `vpc_name: None` the same as an absent key, rather
    # than passing None to the by-name lookup.
    vpc_name = provider_config.get('vpc_name')
    if vpc_name is not None:
        return aws_config.get_vpc_id_by_name(ec2, vpc_name, region)

    # No VPC name configured: fall back to the region's default VPC.
    response = ec2.meta.client.describe_vpcs(Filters=[{
        'Name': 'isDefault',
        'Values': ['true']
    }])
    vpcs = response['Vpcs']
    if not vpcs:
        raise ValueError(f'No default VPC found in region {region}')
    if len(vpcs) > 1:
        raise ValueError(f'Multiple default VPCs found in region {region}')
    return vpcs[0]['VpcId']
sky/provision/do/utils.py CHANGED
@@ -30,7 +30,7 @@ POSSIBLE_CREDENTIALS_PATHS = [
30
30
  INITIAL_BACKOFF_SECONDS = 10
31
31
  MAX_BACKOFF_FACTOR = 10
32
32
  MAX_ATTEMPTS = 6
33
- SSH_KEY_NAME_ON_DO = f'sky-key-{common_utils.get_user_hash()}'
33
+ SSH_KEY_NAME_ON_DO_PREFIX = 'sky-key-'
34
34
 
35
35
  _client = None
36
36
  _ssh_key_id = None
@@ -125,7 +125,7 @@ def ssh_key_id(public_key: str):
125
125
 
126
126
  request = {
127
127
  'public_key': public_key,
128
- 'name': SSH_KEY_NAME_ON_DO,
128
+ 'name': SSH_KEY_NAME_ON_DO_PREFIX + common_utils.get_user_hash(),
129
129
  }
130
130
  _ssh_key_id = client().ssh_keys.create(body=request)['ssh_key']
131
131
  return _ssh_key_id
@@ -4,13 +4,13 @@ import time
4
4
  import typing
5
5
  from typing import Dict, List, Optional, Tuple, Union
6
6
 
7
- import sky
8
7
  from sky import exceptions
9
8
  from sky import sky_logging
10
9
  from sky import skypilot_config
11
10
  from sky.adaptors import common as adaptors_common
12
11
  from sky.adaptors import kubernetes
13
12
  from sky.provision.kubernetes import utils as kubernetes_utils
13
+ from sky.utils import directory_utils
14
14
  from sky.utils import kubernetes_enums
15
15
  from sky.utils import ux_utils
16
16
 
@@ -80,7 +80,7 @@ def get_networking_mode(
80
80
  def fill_loadbalancer_template(namespace: str, context: Optional[str],
81
81
  service_name: str, ports: List[int],
82
82
  selector_key: str, selector_value: str) -> Dict:
83
- template_path = os.path.join(sky.__root_dir__, 'templates',
83
+ template_path = os.path.join(directory_utils.get_sky_dir(), 'templates',
84
84
  _LOADBALANCER_TEMPLATE_NAME)
85
85
  if not os.path.exists(template_path):
86
86
  raise FileNotFoundError(
@@ -116,7 +116,7 @@ def fill_ingress_template(namespace: str, context: Optional[str],
116
116
  service_details: List[Tuple[str, int,
117
117
  str]], ingress_name: str,
118
118
  selector_key: str, selector_value: str) -> Dict:
119
- template_path = os.path.join(sky.__root_dir__, 'templates',
119
+ template_path = os.path.join(directory_utils.get_sky_dir(), 'templates',
120
120
  _INGRESS_TEMPLATE_NAME)
121
121
  if not os.path.exists(template_path):
122
122
  raise FileNotFoundError(
@@ -14,7 +14,6 @@ import typing
14
14
  from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
15
15
  from urllib.parse import urlparse
16
16
 
17
- import sky
18
17
  from sky import clouds
19
18
  from sky import exceptions
20
19
  from sky import global_user_state
@@ -31,6 +30,7 @@ from sky.skylet import constants
31
30
  from sky.utils import annotations
32
31
  from sky.utils import common_utils
33
32
  from sky.utils import config_utils
33
+ from sky.utils import directory_utils
34
34
  from sky.utils import env_options
35
35
  from sky.utils import kubernetes_enums
36
36
  from sky.utils import schemas
@@ -2444,7 +2444,7 @@ def clean_zombie_ssh_jump_pod(namespace: str, context: Optional[str],
2444
2444
 
2445
2445
  def fill_ssh_jump_template(ssh_key_secret: str, ssh_jump_image: str,
2446
2446
  ssh_jump_name: str, service_type: str) -> Dict:
2447
- template_path = os.path.join(sky.__root_dir__, 'templates',
2447
+ template_path = os.path.join(directory_utils.get_sky_dir(), 'templates',
2448
2448
  'kubernetes-ssh-jump.yml.j2')
2449
2449
  if not os.path.exists(template_path):
2450
2450
  raise FileNotFoundError(
@@ -203,6 +203,8 @@ def _get_pvc_spec(namespace: str,
203
203
  },
204
204
  }
205
205
  }
206
+ if config.labels:
207
+ pvc_spec['metadata']['labels'].update(config.labels)
206
208
  storage_class = config.config.get('storage_class_name')
207
209
  if storage_class is not None:
208
210
  pvc_spec['spec']['storageClassName'] = storage_class
@@ -48,6 +48,13 @@ _PROCESS_POOL_REFRESH_INTERVAL = 20
48
48
  _RETRY_INIT_GAP_SECONDS = 60
49
49
  _DEFAULT_DRAIN_SECONDS = 120
50
50
 
51
+ # TODO(tian): Backward compatibility. Remove this after 3 minor releases, i.e.
52
+ # 0.13.0. We moved ProcessStatus to common_utils.ProcessStatus in #6666, but
53
+ # old ReplicaInfo objects in the database will still try to unpickle using
54
+ # ProcessStatus in replica_managers. We set this alias to avoid breaking
55
+ # changes. See #6729 for more details.
56
+ ProcessStatus = common_utils.ProcessStatus
57
+
51
58
 
52
59
  # TODO(tian): Combine this with
53
60
  # sky/spot/recovery_strategy.py::StrategyExecutor::launch
sky/serve/server/impl.py CHANGED
@@ -129,11 +129,11 @@ def up(
129
129
  f'{constants.CLUSTER_NAME_VALID_REGEX}')
130
130
 
131
131
  dag = dag_utils.convert_entrypoint_to_dag(task)
132
- dag.resolve_and_validate_volumes()
133
132
  # Always apply the policy again here, even though it might have been applied
134
133
  # in the CLI. This is to ensure that we apply the policy to the final DAG
135
134
  # and get the mutated config.
136
135
  dag, mutated_user_config = admin_policy_utils.apply(dag)
136
+ dag.resolve_and_validate_volumes()
137
137
  dag.pre_mount_volumes()
138
138
  task = dag.tasks[0]
139
139
  assert task.service is not None
@@ -453,6 +453,7 @@ class VolumeApplyBody(RequestBody):
453
453
  zone: Optional[str] = None
454
454
  size: Optional[str] = None
455
455
  config: Optional[Dict[str, Any]] = None
456
+ labels: Optional[Dict[str, str]] = None
456
457
 
457
458
 
458
459
  class VolumeDeleteBody(RequestBody):
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple
10
10
 
11
11
  from sky.schemas.api import responses
12
12
  from sky.server import constants as server_constants
13
+ from sky.utils import serialize_utils
13
14
 
14
15
  if typing.TYPE_CHECKING:
15
16
  from sky import backends
@@ -22,6 +23,9 @@ handlers: Dict[str, Any] = {}
22
23
 
23
24
  def pickle_and_encode(obj: Any) -> str:
24
25
  try:
26
+ # Apply backwards compatibility processing at the lowest level
27
+ # to catch any handles that might have bypassed the encoders
28
+ obj = serialize_utils.prepare_handle_for_backwards_compatibility(obj)
25
29
  return base64.b64encode(pickle.dumps(obj)).decode('utf-8')
26
30
  except TypeError as e:
27
31
  raise ValueError(f'Failed to pickle object: {obj}') from e
@@ -58,7 +62,9 @@ def encode_status(
58
62
  for cluster in clusters:
59
63
  response_cluster = cluster.model_dump()
60
64
  response_cluster['status'] = cluster['status'].value
61
- response_cluster['handle'] = pickle_and_encode(cluster['handle'])
65
+ handle = serialize_utils.prepare_handle_for_backwards_compatibility(
66
+ cluster['handle'])
67
+ response_cluster['handle'] = pickle_and_encode(handle)
62
68
  response_cluster['storage_mounts_metadata'] = pickle_and_encode(
63
69
  response_cluster['storage_mounts_metadata'])
64
70
  response.append(response_cluster)
@@ -70,6 +76,7 @@ def encode_launch(
70
76
  job_id_handle: Tuple[Optional[int], Optional['backends.ResourceHandle']]
71
77
  ) -> Dict[str, Any]:
72
78
  job_id, handle = job_id_handle
79
+ handle = serialize_utils.prepare_handle_for_backwards_compatibility(handle)
73
80
  return {
74
81
  'job_id': job_id,
75
82
  'handle': pickle_and_encode(handle),
@@ -78,6 +85,9 @@ def encode_launch(
78
85
 
79
86
  @register_encoder('start')
80
87
  def encode_start(resource_handle: 'backends.CloudVmRayResourceHandle') -> str:
88
+ resource_handle = (
89
+ serialize_utils.prepare_handle_for_backwards_compatibility(
90
+ resource_handle))
81
91
  return pickle_and_encode(resource_handle)
82
92
 
83
93
 
@@ -143,7 +153,9 @@ def _encode_serve_status(
143
153
  service_status['status'] = service_status['status'].value
144
154
  for replica_info in service_status.get('replica_info', []):
145
155
  replica_info['status'] = replica_info['status'].value
146
- replica_info['handle'] = pickle_and_encode(replica_info['handle'])
156
+ handle = serialize_utils.prepare_handle_for_backwards_compatibility(
157
+ replica_info['handle'])
158
+ replica_info['handle'] = pickle_and_encode(handle)
147
159
  return service_statuses
148
160
 
149
161
 
sky/server/server.py CHANGED
@@ -83,6 +83,8 @@ else:
83
83
 
84
84
  P = ParamSpec('P')
85
85
 
86
+ _SERVER_USER_HASH_KEY = 'server_user_hash'
87
+
86
88
 
87
89
  def _add_timestamp_prefix_for_server_logs() -> None:
88
90
  server_logger = sky_logging.init_logger('sky.server')
@@ -1821,6 +1823,35 @@ async def root():
1821
1823
  return fastapi.responses.RedirectResponse(url='/dashboard/')
1822
1824
 
1823
1825
 
1826
def _init_or_restore_server_user_hash():
    """Initializes or restores the API server's user hash.

    The API server needs a user hash that stays stable across restarts and
    across replicas that share the same global user state db. The hash is
    therefore persisted in the db under ``_SERVER_USER_HASH_KEY``.

    - On first deployment (nothing stored yet), the hash returned by
      ``common_utils.get_user_hash()`` is saved to the db. That function reads
      the local cache file if it exists, so an upgraded server keeps its old
      hash.
    - On later startups, the stored hash is restored.

    In both cases the hash is written to the local cache file and the current
    process's server id is refreshed. Child processes pick the hash up from
    that cache file.
    """
    stored_hash = global_user_state.get_system_config(_SERVER_USER_HASH_KEY)
    if stored_hash is None:
        # Initial deployment: adopt the local hash and persist it to the db.
        stored_hash = common_utils.get_user_hash()
        global_user_state.set_system_config(_SERVER_USER_HASH_KEY, stored_hash)

    # For a local API server, the db value and the local file already agree,
    # so overwriting the local file is harmless.
    common_utils.set_user_hash_locally(stored_hash)
    common_lib.refresh_server_id()
1853
+
1854
+
1824
1855
  if __name__ == '__main__':
1825
1856
  import uvicorn
1826
1857
 
@@ -1830,6 +1861,8 @@ if __name__ == '__main__':
1830
1861
  global_user_state.initialize_and_get_db()
1831
1862
  # Initialize request db
1832
1863
  requests_lib.reset_db_and_logs()
1864
+ # Restore the server user hash
1865
+ _init_or_restore_server_user_hash()
1833
1866
 
1834
1867
  parser = argparse.ArgumentParser()
1835
1868
  parser.add_argument('--host', default='127.0.0.1')
@@ -72,12 +72,27 @@ install_requires = [
72
72
  'aiohttp',
73
73
  ]
74
74
 
75
+ # See requirements-dev.txt for the version of grpc and protobuf
76
+ # used to generate the code during development.
77
+
78
+ # The grpc version at runtime has to be newer than the version
79
+ # used to generate the code.
80
+ GRPC = 'grpcio>=1.63.0'
81
+ # >= 5.26.1 because the runtime version can't be older than the version
82
+ # used to generate the code.
83
+ # < 7.0.0 because code generated for a major version V will be supported by
84
+ # protobuf runtimes of version V and V+1.
85
+ # https://protobuf.dev/support/cross-version-runtime-guarantee
86
+ PROTOBUF = 'protobuf>=5.26.1, < 7.0.0'
87
+
75
88
  server_dependencies = [
76
89
  'casbin',
77
90
  'sqlalchemy_adapter',
78
91
  'passlib',
79
92
  'pyjwt',
80
93
  'aiohttp',
94
+ GRPC,
95
+ PROTOBUF,
81
96
  ]
82
97
 
83
98
  local_ray = [
@@ -88,18 +103,9 @@ local_ray = [
88
103
  'ray[default] >= 2.2.0, != 2.6.0',
89
104
  ]
90
105
 
91
- # See requirements-dev.txt for the version of grpc and protobuf
92
- # used to generate the code during development.
93
106
  remote = [
94
- # The grpc version at runtime has to be newer than the version
95
- # used to generate the code.
96
- 'grpcio>=1.63.0',
97
- # >= 5.26.1 because the runtime version can't be older than the version
98
- # used to generate the code.
99
- # < 7.0.0 because code generated for a major version V will be supported by
100
- # protobuf runtimes of version V and V+1.
101
- # https://protobuf.dev/support/cross-version-runtime-guarantee
102
- 'protobuf >= 5.26.1, < 7.0.0',
107
+ GRPC,
108
+ PROTOBUF,
103
109
  ]
104
110
 
105
111
  # NOTE: Change the templates/jobs-controller.yaml.j2 file if any of the
sky/utils/common.py CHANGED
@@ -11,18 +11,38 @@ from sky.utils import common_utils
11
11
 
12
12
  SKY_SERVE_CONTROLLER_PREFIX: str = 'sky-serve-controller-'
13
13
  JOB_CONTROLLER_PREFIX: str = 'sky-jobs-controller-'
14
+
14
15
  # We use the user hash (machine-specific) for the controller name. It will be
15
16
  # the same across the whole lifecycle of the server, including:
16
- # 1. all requests, because this global variable is set once during server
17
- # starts.
18
- # 2. SkyPilot API server restarts, as long as the `~/.sky` folder is persisted
19
- # and the env var set during starting the server is the same.
17
+ # 1. all requests, because all the server processes share the same user hash
18
+ # cache file.
19
+ # 2. SkyPilot API server restarts, because the API server will restore the
20
+ # user hash from the global user state db on startup.
21
+ # 3. Potential multiple server replicas, because multiple server replicas of
22
+ # the same deployment share the same global user state db.
20
23
  # This behavior is the same for the local API server (where SERVER_ID is the
21
24
  # same as the normal user hash). This ensures backwards-compatibility with jobs
22
25
  # controllers from before #4660.
23
- SERVER_ID = common_utils.get_user_hash()
24
- SKY_SERVE_CONTROLLER_NAME: str = f'{SKY_SERVE_CONTROLLER_PREFIX}{SERVER_ID}'
25
- JOB_CONTROLLER_NAME: str = f'{JOB_CONTROLLER_PREFIX}{SERVER_ID}'
26
+ SERVER_ID: str
27
+ SKY_SERVE_CONTROLLER_NAME: str
28
+ JOB_CONTROLLER_NAME: str
29
+
30
+
31
def refresh_server_id() -> None:
    """Refreshes the server id and the controller names derived from it.

    Re-reads the user hash from its authoritative source,
    ``common_utils.get_user_hash()``. It then recomputes the module-level
    ``SERVER_ID``, ``SKY_SERVE_CONTROLLER_NAME`` and ``JOB_CONTROLLER_NAME``.

    Call this after the persisted user hash changes, e.g. after the API
    server restores it from the database. This lets the current process
    pick up the new value.

    Note: only code that reads these names through the module object (e.g.
    ``common.SERVER_ID``) observes the refreshed values. Names bound earlier
    with ``from sky.utils.common import SERVER_ID`` keep the old value.
    """
    global SERVER_ID
    global SKY_SERVE_CONTROLLER_NAME
    global JOB_CONTROLLER_NAME
    SERVER_ID = common_utils.get_user_hash()
    SKY_SERVE_CONTROLLER_NAME = f'{SKY_SERVE_CONTROLLER_PREFIX}{SERVER_ID}'
    JOB_CONTROLLER_NAME = f'{JOB_CONTROLLER_PREFIX}{SERVER_ID}'


# Populate the module-level names at import time.
refresh_server_id()
26
46
 
27
47
 
28
48
  @contextlib.contextmanager
sky/utils/common_utils.py CHANGED
@@ -28,7 +28,6 @@ from sky.adaptors import common as adaptors_common
28
28
  from sky.skylet import constants
29
29
  from sky.usage import constants as usage_constants
30
30
  from sky.utils import annotations
31
- from sky.utils import common_utils
32
31
  from sky.utils import ux_utils
33
32
  from sky.utils import validator
34
33
 
@@ -41,7 +40,7 @@ else:
41
40
  psutil = adaptors_common.LazyImport('psutil')
42
41
  yaml = adaptors_common.LazyImport('yaml')
43
42
 
44
- _USER_HASH_FILE = os.path.expanduser('~/.sky/user_hash')
43
+ USER_HASH_FILE = os.path.expanduser('~/.sky/user_hash')
45
44
  USER_HASH_LENGTH = 8
46
45
 
47
46
  # We are using base36 to reduce the length of the hash. 2 chars -> 36^2 = 1296
@@ -131,21 +130,26 @@ def get_user_hash() -> str:
131
130
  assert user_hash is not None
132
131
  return user_hash
133
132
 
134
- if os.path.exists(_USER_HASH_FILE):
133
+ if os.path.exists(USER_HASH_FILE):
135
134
  # Read from cached user hash file.
136
- with open(_USER_HASH_FILE, 'r', encoding='utf-8') as f:
135
+ with open(USER_HASH_FILE, 'r', encoding='utf-8') as f:
137
136
  # Remove invalid characters.
138
137
  user_hash = f.read().strip()
139
138
  if is_valid_user_hash(user_hash):
140
139
  return user_hash
141
140
 
142
141
  user_hash = generate_user_hash()
143
- os.makedirs(os.path.dirname(_USER_HASH_FILE), exist_ok=True)
144
- with open(_USER_HASH_FILE, 'w', encoding='utf-8') as f:
145
- f.write(user_hash)
142
+ set_user_hash_locally(user_hash)
146
143
  return user_hash
147
144
 
148
145
 
146
def set_user_hash_locally(user_hash: str) -> None:
    """Persists ``user_hash`` to the local user hash cache file.

    ``get_user_hash()`` reads ``USER_HASH_FILE``. If the content is not a
    valid hash, it generates a new hash and overwrites the file.

    Writing in place (truncate, then write) lets a concurrent reader, or a
    crash between those two steps, observe an empty file and silently mint a
    new identity. To avoid that, the hash is written to a temporary file in
    the same directory and then moved into place with ``os.replace``, which
    is atomic on POSIX.

    Args:
        user_hash: The user hash to persist.
    """
    import tempfile  # pylint: disable=import-outside-toplevel

    cache_dir = os.path.dirname(USER_HASH_FILE)
    os.makedirs(cache_dir, exist_ok=True)
    # Same directory as the target so os.replace never crosses filesystems.
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, prefix='.user_hash.')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            f.write(user_hash)
        os.replace(tmp_path, USER_HASH_FILE)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
151
+
152
+
149
153
  def base36_encode(hex_str: str) -> str:
150
154
  """Converts a hex string to a base36 string."""
151
155
  int_value = int(hex_str, 16)
@@ -343,7 +347,7 @@ def get_current_user() -> 'models.User':
343
347
 
344
348
  def get_current_user_name() -> str:
345
349
  """Returns the current user name."""
346
- name = common_utils.get_current_user().name
350
+ name = get_current_user().name
347
351
  assert name is not None
348
352
  return name
349
353
 
@@ -886,7 +890,7 @@ def get_cleaned_username(username: str = '') -> str:
886
890
  Returns:
887
891
  A cleaned username.
888
892
  """
889
- username = username or common_utils.get_current_user_name()
893
+ username = username or get_current_user_name()
890
894
  username = username.lower()
891
895
  username = re.sub(r'[^a-z0-9-_]', '', username)
892
896
  username = re.sub(r'^[0-9-]+', '', username)
@@ -0,0 +1,12 @@
1
"""Directory utilities."""

import os

# Absolute path of the '<project_root>/sky' package directory. This file lives
# at '<project_root>/sky/utils/directory_utils.py', so the package root is two
# directory levels above it.
SKY_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def get_sky_dir() -> str:
    """Returns the absolute path of the sky package root directory.

    Package data such as templates is located relative to this directory,
    e.g. ``os.path.join(get_sky_dir(), 'templates', ...)``.
    """
    return SKY_DIR
sky/utils/env_options.py CHANGED
@@ -24,6 +24,9 @@ class Options(enum.Enum):
24
24
  # running in a Buildkite container environment, which requires special
25
25
  # handling for networking between containers.
26
26
  RUNNING_IN_BUILDKITE = ('BUILDKITE', False)
27
+ # Internal: This is used for testing to enable grpc for communication
28
+ # between the API server and the Skylet.
29
+ ENABLE_GRPC = ('SKYPILOT_ENABLE_GRPC', False)
27
30
 
28
31
  def __init__(self, env_var: str, default: bool) -> None:
29
32
  super().__init__()