skypilot-nightly 1.0.0.dev20250927__py3-none-any.whl → 1.0.0.dev20251002__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly might be problematic.

Files changed (54):
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +18 -10
  3. sky/backends/cloud_vm_ray_backend.py +2 -2
  4. sky/check.py +0 -29
  5. sky/client/cli/command.py +48 -28
  6. sky/client/cli/table_utils.py +279 -1
  7. sky/client/sdk.py +7 -18
  8. sky/core.py +15 -16
  9. sky/dashboard/out/404.html +1 -1
  10. sky/dashboard/out/_next/static/{UDSEoDB67vwFMZyCJ4HWU → 16g0-hgEgk6Db72hpE8MY}/_buildManifest.js +1 -1
  11. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/{[pool]-07349868f7905d37.js → [pool]-509b2977a6373bf6.js} +1 -1
  12. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  13. sky/dashboard/out/clusters/[cluster].html +1 -1
  14. sky/dashboard/out/clusters.html +1 -1
  15. sky/dashboard/out/config.html +1 -1
  16. sky/dashboard/out/index.html +1 -1
  17. sky/dashboard/out/infra/[context].html +1 -1
  18. sky/dashboard/out/infra.html +1 -1
  19. sky/dashboard/out/jobs/[job].html +1 -1
  20. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  21. sky/dashboard/out/jobs.html +1 -1
  22. sky/dashboard/out/users.html +1 -1
  23. sky/dashboard/out/volumes.html +1 -1
  24. sky/dashboard/out/workspace/new.html +1 -1
  25. sky/dashboard/out/workspaces/[name].html +1 -1
  26. sky/dashboard/out/workspaces.html +1 -1
  27. sky/data/storage.py +11 -0
  28. sky/data/storage_utils.py +1 -45
  29. sky/jobs/client/sdk.py +3 -2
  30. sky/jobs/controller.py +15 -0
  31. sky/jobs/server/core.py +24 -2
  32. sky/jobs/server/server.py +1 -1
  33. sky/jobs/utils.py +2 -1
  34. sky/provision/kubernetes/instance.py +1 -1
  35. sky/provision/kubernetes/utils.py +50 -28
  36. sky/schemas/api/responses.py +76 -0
  37. sky/server/common.py +2 -1
  38. sky/server/requests/serializers/decoders.py +16 -4
  39. sky/server/requests/serializers/encoders.py +12 -5
  40. sky/task.py +4 -0
  41. sky/utils/cluster_utils.py +23 -5
  42. sky/utils/command_runner.py +21 -5
  43. sky/utils/command_runner.pyi +11 -0
  44. sky/utils/volume.py +5 -0
  45. sky/volumes/client/sdk.py +3 -2
  46. sky/volumes/server/core.py +3 -2
  47. {skypilot_nightly-1.0.0.dev20250927.dist-info → skypilot_nightly-1.0.0.dev20251002.dist-info}/METADATA +33 -33
  48. {skypilot_nightly-1.0.0.dev20250927.dist-info → skypilot_nightly-1.0.0.dev20251002.dist-info}/RECORD +53 -54
  49. sky/volumes/utils.py +0 -224
  50. /sky/dashboard/out/_next/static/{UDSEoDB67vwFMZyCJ4HWU → 16g0-hgEgk6Db72hpE8MY}/_ssgManifest.js +0 -0
  51. {skypilot_nightly-1.0.0.dev20250927.dist-info → skypilot_nightly-1.0.0.dev20251002.dist-info}/WHEEL +0 -0
  52. {skypilot_nightly-1.0.0.dev20250927.dist-info → skypilot_nightly-1.0.0.dev20251002.dist-info}/entry_points.txt +0 -0
  53. {skypilot_nightly-1.0.0.dev20250927.dist-info → skypilot_nightly-1.0.0.dev20251002.dist-info}/licenses/LICENSE +0 -0
  54. {skypilot_nightly-1.0.0.dev20250927.dist-info → skypilot_nightly-1.0.0.dev20251002.dist-info}/top_level.txt +0 -0
sky/jobs/server/core.py CHANGED
@@ -28,6 +28,7 @@ from sky.jobs import constants as managed_job_constants
 from sky.jobs import state as managed_job_state
 from sky.jobs import utils as managed_job_utils
 from sky.provision import common as provision_common
+from sky.schemas.api import responses
 from sky.serve import serve_state
 from sky.serve import serve_utils
 from sky.serve.server import impl
@@ -296,8 +297,7 @@ def launch(
         # TODO: do something with returned status?
         _, _ = backend_utils.refresh_cluster_status_handle(
             cluster_name=cluster_name,
-            force_refresh_statuses=set(status_lib.ClusterStatus),
-            acquire_per_cluster_status_lock=False)
+            force_refresh_statuses=set(status_lib.ClusterStatus))
     except (exceptions.ClusterOwnerIdentityMismatchError,
             exceptions.CloudUserIdentityError,
             exceptions.ClusterStatusFetchingError) as e:
@@ -644,6 +644,28 @@ def queue(refresh: bool,


 @usage_lib.entrypoint
+def queue_v2_api(
+    refresh: bool,
+    skip_finished: bool = False,
+    all_users: bool = False,
+    job_ids: Optional[List[int]] = None,
+    user_match: Optional[str] = None,
+    workspace_match: Optional[str] = None,
+    name_match: Optional[str] = None,
+    pool_match: Optional[str] = None,
+    page: Optional[int] = None,
+    limit: Optional[int] = None,
+    statuses: Optional[List[str]] = None,
+) -> Tuple[List[responses.ManagedJobRecord], int, Dict[str, int], int]:
+    """Gets statuses of managed jobs and parses the
+    jobs into responses.ManagedJobRecord."""
+    jobs, total, status_counts, total_no_filter = queue_v2(
+        refresh, skip_finished, all_users, job_ids, user_match, workspace_match,
+        name_match, pool_match, page, limit, statuses)
+    return [responses.ManagedJobRecord(**job) for job in jobs
+           ], total, status_counts, total_no_filter
+
+
 def queue_v2(
     refresh: bool,
     skip_finished: bool = False,
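Note: the wrapper keeps queue_v2 returning plain dicts for internal callers, while the API layer hands out typed pydantic records. A hedged sketch of the return shape (server-side call, assuming a configured jobs setup; shown only to illustrate the tuple):

from sky.jobs.server import core

jobs, total, status_counts, total_no_filter = core.queue_v2_api(
    refresh=False, skip_finished=True, all_users=True)
for job in jobs:
    # Typed records: attribute access instead of job['job_id'].
    print(job.job_id, job.job_name, job.status)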
sky/jobs/server/server.py CHANGED
@@ -68,7 +68,7 @@ async def queue_v2(request: fastapi.Request,
         request_id=request.state.request_id,
         request_name='jobs.queue_v2',
         request_body=jobs_queue_body_v2,
-        func=core.queue_v2,
+        func=core.queue_v2_api,
         schedule_type=(api_requests.ScheduleType.LONG
                        if jobs_queue_body_v2.refresh else
                        api_requests.ScheduleType.SHORT),
sky/jobs/utils.py CHANGED
@@ -33,6 +33,7 @@ from sky.backends import cloud_vm_ray_backend
 from sky.jobs import constants as managed_job_constants
 from sky.jobs import scheduler
 from sky.jobs import state as managed_job_state
+from sky.schemas.api import responses
 from sky.skylet import constants
 from sky.skylet import job_lib
 from sky.skylet import log_lib
@@ -1517,7 +1518,7 @@ def load_managed_job_queue(


 def _get_job_status_from_tasks(
-    job_tasks: List[Dict[str, Any]]
+    job_tasks: Union[List[responses.ManagedJobRecord], List[Dict[str, Any]]]
 ) -> Tuple[managed_job_state.ManagedJobStatus, int]:
     """Get the current task status and the current task id for a job."""
    managed_task_status = managed_job_state.ManagedJobStatus.SUCCEEDED
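Note: with the widened signature, callers may pass either pydantic records or legacy dict rows, so downstream code needs uniform field access. A minimal illustrative helper (hypothetical, not part of this diff):

from typing import Any, Union

import pydantic


def get_field(task: Union[pydantic.BaseModel, dict], key: str) -> Any:
    # Records expose fields as attributes; legacy rows are plain dicts.
    if isinstance(task, pydantic.BaseModel):
        return getattr(task, key)
    return task[key]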
sky/provision/kubernetes/instance.py CHANGED
@@ -847,7 +847,7 @@ def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict,
 def _wait_for_deployment_pod(context,
                              namespace,
                              deployment,
-                             timeout=60) -> List:
+                             timeout=300) -> List:
     label_selector = ','.join([
         f'{key}={value}'
         for key, value in deployment.spec.selector.match_labels.items()
sky/provision/kubernetes/utils.py CHANGED
@@ -1,4 +1,5 @@
 """Kubernetes utilities for SkyPilot."""
+import collections
 import copy
 import dataclasses
 import datetime
@@ -3117,14 +3118,6 @@ def get_kubernetes_node_info(
         information.
     """
     nodes = get_kubernetes_nodes(context=context)
-    # Get the pods to get the real-time resource usage
-    try:
-        pods = get_all_pods_in_kubernetes_cluster(context=context)
-    except kubernetes.api_exception() as e:
-        if e.status == 403:
-            pods = None
-        else:
-            raise

     lf, _ = detect_gpu_label_formatter(context)
     if not lf:
@@ -3132,6 +3125,46 @@ def get_kubernetes_node_info(
     else:
         label_keys = lf.get_label_keys()

+    # Check if all nodes have no accelerators to avoid fetching pods
+    any_node_has_accelerators = False
+    for node in nodes:
+        accelerator_count = get_node_accelerator_count(context,
+                                                       node.status.allocatable)
+        if accelerator_count > 0:
+            any_node_has_accelerators = True
+            break
+
+    # Get the pods to get the real-time resource usage
+    pods = None
+    allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
+    if any_node_has_accelerators:
+        try:
+            pods = get_all_pods_in_kubernetes_cluster(context=context)
+            # Pre-compute allocated accelerator count per node
+            for pod in pods:
+                if pod.status.phase in ['Running', 'Pending']:
+                    # Skip pods that should not count against GPU count
+                    if should_exclude_pod_from_gpu_allocation(pod):
+                        logger.debug(f'Excluding low priority pod '
+                                     f'{pod.metadata.name} from GPU allocation '
+                                     f'calculations')
+                        continue
+                    # Iterate over all the containers in the pod and sum the
+                    # GPU requests
+                    pod_allocated_qty = 0
+                    for container in pod.spec.containers:
+                        if container.resources.requests:
+                            pod_allocated_qty += get_node_accelerator_count(
+                                context, container.resources.requests)
+                    if pod_allocated_qty > 0:
+                        allocated_qty_by_node[
+                            pod.spec.node_name] += pod_allocated_qty
+        except kubernetes.api_exception() as e:
+            if e.status == 403:
+                pass
+            else:
+                raise
+
     node_info_dict: Dict[str, models.KubernetesNodeInfo] = {}
     has_multi_host_tpu = False

@@ -3161,32 +3194,21 @@ def get_kubernetes_node_info(
                 node_ip = address.address
                 break

-        allocated_qty = 0
         accelerator_count = get_node_accelerator_count(context,
                                                        node.status.allocatable)
+        if accelerator_count == 0:
+            node_info_dict[node.metadata.name] = models.KubernetesNodeInfo(
+                name=node.metadata.name,
+                accelerator_type=accelerator_name,
+                total={'accelerator_count': 0},
+                free={'accelerators_available': 0},
+                ip_address=node_ip)
+            continue

         if pods is None:
             accelerators_available = -1
-
         else:
-            for pod in pods:
-                # Get all the pods running on the node
-                if (pod.spec.node_name == node.metadata.name and
-                        pod.status.phase in ['Running', 'Pending']):
-                    # Skip pods that should not count against GPU count
-                    if should_exclude_pod_from_gpu_allocation(pod):
-                        logger.debug(
-                            f'Excluding low priority pod '
-                            f'{pod.metadata.name} from GPU allocation '
-                            f'calculations on node {node.metadata.name}')
-                        continue
-                    # Iterate over all the containers in the pod and sum the
-                    # GPU requests
-                    for container in pod.spec.containers:
-                        if container.resources.requests:
-                            allocated_qty += get_node_accelerator_count(
-                                context, container.resources.requests)
-
+            allocated_qty = allocated_qty_by_node[node.metadata.name]
             accelerators_available = accelerator_count - allocated_qty

         # Exclude multi-host TPUs from being processed.
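Note: the refactor replaces an O(nodes × pods) rescan with one pass that tallies allocations per node; the later per-node loop becomes a constant-time dict lookup. A simplified, self-contained sketch of the pattern (pod shapes reduced to tuples):

import collections
from typing import Dict, List, Tuple

Pod = Tuple[str, int]  # (node_name, gpus_requested)


def allocated_by_node(pods: List[Pod]) -> Dict[str, int]:
    allocated: Dict[str, int] = collections.defaultdict(int)
    for node_name, gpus in pods:
        if gpus > 0:
            allocated[node_name] += gpus
    return allocated


tally = allocated_by_node([('node-a', 2), ('node-a', 1), ('node-b', 0)])
assert tally['node-a'] == 3
assert tally['node-b'] == 0  # defaultdict: untouched nodes read as zero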
sky/schemas/api/responses.py CHANGED
@@ -5,7 +5,9 @@ from typing import Any, Dict, List, Optional

 import pydantic

+from sky import data
 from sky import models
+from sky.jobs import state as job_state
 from sky.server import common
 from sky.skylet import job_lib
 from sky.utils import status_lib
@@ -143,3 +145,77 @@ class UploadStatus(enum.Enum):
     """Status of the upload."""
     UPLOADING = 'uploading'
     COMPLETED = 'completed'
+
+
+class StorageRecord(ResponseBaseModel):
+    """Response for the storage list endpoint."""
+    name: str
+    launched_at: int
+    store: List[data.StoreType]
+    last_use: str
+    status: status_lib.StorageStatus
+
+
+# TODO (syang) figure out which fields are always present
+# and therefore can be non-optional.
+class ManagedJobRecord(ResponseBaseModel):
+    """A single managed job record."""
+    job_id: Optional[int] = None
+    task_id: Optional[int] = None
+    job_name: Optional[str] = None
+    task_name: Optional[str] = None
+    job_duration: Optional[float] = None
+    workspace: Optional[str] = None
+    status: Optional[job_state.ManagedJobStatus] = None
+    schedule_state: Optional[str] = None
+    resources: Optional[str] = None
+    cluster_resources: Optional[str] = None
+    cluster_resources_full: Optional[str] = None
+    cloud: Optional[str] = None
+    region: Optional[str] = None
+    zone: Optional[str] = None
+    infra: Optional[str] = None
+    recovery_count: Optional[int] = None
+    details: Optional[str] = None
+    failure_reason: Optional[str] = None
+    user_name: Optional[str] = None
+    user_hash: Optional[str] = None
+    submitted_at: Optional[float] = None
+    start_at: Optional[float] = None
+    end_at: Optional[float] = None
+    user_yaml: Optional[str] = None
+    entrypoint: Optional[str] = None
+    metadata: Optional[Dict[str, Any]] = None
+    controller_pid: Optional[int] = None
+    dag_yaml_path: Optional[str] = None
+    env_file_path: Optional[str] = None
+    last_recovered_at: Optional[float] = None
+    run_timestamp: Optional[str] = None
+    priority: Optional[int] = None
+    original_user_yaml_path: Optional[str] = None
+    pool: Optional[str] = None
+    pool_hash: Optional[str] = None
+    current_cluster_name: Optional[str] = None
+    job_id_on_pool_cluster: Optional[int] = None
+    accelerators: Optional[Dict[str, int]] = None
+
+
+class VolumeRecord(ResponseBaseModel):
+    """A single volume record."""
+    name: str
+    type: str
+    launched_at: int
+    cloud: str
+    region: str
+    zone: Optional[str] = None
+    size: str
+    config: Dict[str, Any]
+    name_on_cloud: str
+    user_hash: str
+    user_name: str
+    workspace: str
+    last_attached_at: Optional[int] = None
+    last_use: Optional[str] = None
+    status: Optional[str] = None
+    usedby_pods: List[str]
+    usedby_clusters: List[str]
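Note: every ManagedJobRecord field is Optional with a None default, so partial rows from the jobs table validate without errors. A minimal stand-in model illustrating the pattern (trimmed; not the real class):

from typing import Optional

import pydantic


class JobRecord(pydantic.BaseModel):
    job_id: Optional[int] = None
    job_name: Optional[str] = None
    status: Optional[str] = None


row = {'job_id': 42, 'status': 'RUNNING'}  # missing job_name is fine
record = JobRecord(**row)
assert record.job_name is None
# model_dump() restores the JSON-friendly dict used on the wire.
assert record.model_dump() == {
    'job_id': 42, 'job_name': None, 'status': 'RUNNING'}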
sky/server/common.py CHANGED
@@ -780,6 +780,7 @@ def check_server_healthy_or_start_fn(deploy: bool = False,
             os.path.expanduser(constants.API_SERVER_CREATION_LOCK_PATH)):
         # Check again if server is already running. Other processes may
         # have started the server while we were waiting for the lock.
+        get_api_server_status.cache_clear()  # type: ignore[attr-defined]
         api_server_info = get_api_server_status(endpoint)
         if api_server_info.status == ApiServerStatus.UNHEALTHY:
             _start_api_server(deploy, host, foreground, metrics,
@@ -841,7 +842,7 @@ def process_mounts_in_task_on_api_server(task: str, env_vars: Dict[str, str],
     for task_config in task_configs:
         if task_config is None:
             continue
-        file_mounts_mapping = task_config.get('file_mounts_mapping', {})
+        file_mounts_mapping = task_config.pop('file_mounts_mapping', {})
         if not file_mounts_mapping:
             # We did not mount any files to new paths on the remote server
             # so no need to resolve filepaths.
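Note: the cache_clear() call suggests get_api_server_status is memoized (hence the type: ignore), so a status cached before acquiring the creation lock could be stale once another process has started the server. A minimal sketch of the idiom, assuming a functools.lru_cache-style cache:

import functools


def _probe(endpoint: str) -> str:
    # Stand-in for the real HTTP health check.
    return 'healthy'


@functools.lru_cache(maxsize=1)
def get_status(endpoint: str) -> str:
    return _probe(endpoint)


get_status('http://127.0.0.1:46580')   # result is now cached
# ... acquire the creation lock ...
get_status.cache_clear()                # drop the pre-lock answer
get_status('http://127.0.0.1:46580')   # re-probes for a fresh status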
sky/server/requests/serializers/decoders.py CHANGED
@@ -72,7 +72,7 @@ def decode_status_kubernetes(
         List[Dict[str, Any]], Optional[str]]
 ) -> Tuple[List[kubernetes_utils.KubernetesSkyPilotClusterInfoPayload],
            List[kubernetes_utils.KubernetesSkyPilotClusterInfoPayload],
-           List[Dict[str, Any]], Optional[str]]:
+           List[responses.ManagedJobRecord], Optional[str]]:
     (encoded_all_clusters, encoded_unmanaged_clusters, all_jobs,
      context) = return_value
     all_clusters = []
@@ -85,6 +85,7 @@ def decode_status_kubernetes(
         cluster['status'] = status_lib.ClusterStatus(cluster['status'])
         unmanaged_clusters.append(
             kubernetes_utils.KubernetesSkyPilotClusterInfoPayload(**cluster))
+    all_jobs = [responses.ManagedJobRecord(**job) for job in all_jobs]
     return all_clusters, unmanaged_clusters, all_jobs, context


@@ -115,7 +116,7 @@ def decode_jobs_queue(return_value: List[dict],) -> List[Dict[str, Any]]:


 @register_decoders('jobs.queue_v2')
-def decode_jobs_queue_v2(return_value) -> List[Dict[str, Any]]:
+def decode_jobs_queue_v2(return_value) -> List[responses.ManagedJobRecord]:
     """Decode jobs queue response.

     Supports legacy list, or a dict {jobs, total}.
@@ -129,6 +130,7 @@ def decode_jobs_queue_v2(return_value) -> List[Dict[str, Any]]:
         jobs = return_value
     for job in jobs:
         job['status'] = managed_jobs.ManagedJobStatus(job['status'])
+    jobs = [responses.ManagedJobRecord(**job) for job in jobs]
     return jobs


@@ -181,14 +183,24 @@ def decode_list_accelerators(

 @register_decoders('storage_ls')
 def decode_storage_ls(
-        return_value: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        return_value: List[Dict[str, Any]]) -> List[responses.StorageRecord]:
     for storage_info in return_value:
         storage_info['status'] = status_lib.StorageStatus(
             storage_info['status'])
         storage_info['store'] = [
             storage.StoreType(store) for store in storage_info['store']
         ]
-    return return_value
+    return [
+        responses.StorageRecord(**storage_info) for storage_info in return_value
+    ]
+
+
+@register_decoders('volume_list')
+def decode_volume_list(
+        return_value: List[Dict[str, Any]]) -> List[responses.VolumeRecord]:
+    return [
+        responses.VolumeRecord(**volume_info) for volume_info in return_value
+    ]


 @register_decoders('job_status')
sky/server/requests/serializers/encoders.py CHANGED
@@ -107,7 +107,7 @@ def encode_status_kubernetes(
     return_value: Tuple[
         List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
         List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
-        List[Dict[str, Any]], Optional[str]]
+        List[responses.ManagedJobRecord], Optional[str]]
 ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]],
            Optional[str]]:
     all_clusters, unmanaged_clusters, all_jobs, context = return_value
@@ -121,6 +121,7 @@ def encode_status_kubernetes(
         encoded_cluster = dataclasses.asdict(cluster)
         encoded_cluster['status'] = encoded_cluster['status'].value
         encoded_unmanaged_clusters.append(encoded_cluster)
+    all_jobs = [job.model_dump() for job in all_jobs]
     return encoded_all_clusters, encoded_unmanaged_clusters, all_jobs, context


@@ -150,9 +151,9 @@ def encode_jobs_queue_v2(
     for job in jobs:
         job['status'] = job['status'].value
     if total is None:
-        return jobs
+        return [job.model_dump() for job in jobs]
     return {
-        'jobs': jobs,
+        'jobs': [job.model_dump() for job in jobs],
         'total': total,
         'total_no_filter': total_no_filter,
         'status_counts': status_counts
@@ -203,11 +204,17 @@ def encode_enabled_clouds(clouds: List['clouds.Cloud']) -> List[str]:

 @register_encoder('storage_ls')
 def encode_storage_ls(
-        return_value: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        return_value: List[responses.StorageRecord]) -> List[Dict[str, Any]]:
     for storage_info in return_value:
         storage_info['status'] = storage_info['status'].value
         storage_info['store'] = [store.value for store in storage_info['store']]
-    return return_value
+    return [storage_info.model_dump() for storage_info in return_value]
+
+
+@register_encoder('volume_list')
+def encode_volume_list(
+        return_value: List[responses.VolumeRecord]) -> List[Dict[str, Any]]:
+    return [volume_info.model_dump() for volume_info in return_value]


 @register_encoder('job_status')
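Note: the encoder/decoder pair keeps plain dicts on the wire while both ends work with typed records: the server calls model_dump(), the client rebuilds the model. A self-contained sketch with a trimmed stand-in model:

from typing import Any, Dict, List

import pydantic


class VolumeRec(pydantic.BaseModel):
    name: str
    size: str


def encode(records: List[VolumeRec]) -> List[Dict[str, Any]]:
    return [r.model_dump() for r in records]


def decode(payload: List[Dict[str, Any]]) -> List[VolumeRec]:
    return [VolumeRec(**d) for d in payload]


original = [VolumeRec(name='vol-1', size='100Gi')]
assert decode(encode(original)) == original  # lossless round trip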
sky/task.py CHANGED
@@ -649,6 +649,10 @@ class Task:
             config['workdir'] = _fill_in_env_vars(config['workdir'],
                                                   env_and_secrets)

+        if config.get('volumes') is not None:
+            config['volumes'] = _fill_in_env_vars(config['volumes'],
+                                                  env_and_secrets)
+
         task = Task(
             config.pop('name', None),
             run=config.pop('run', None),
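Note: _fill_in_env_vars is internal to sky/task.py; schematically it substitutes ${VAR} references throughout a nested config value, which the change above now also applies to the volumes section. An illustrative stand-in, not the real implementation:

import re
from typing import Any, Dict


def fill_env(value: Any, env: Dict[str, str]) -> Any:
    if isinstance(value, str):
        return re.sub(r'\$\{(\w+)\}',
                      lambda m: env.get(m.group(1), m.group(0)), value)
    if isinstance(value, dict):
        return {k: fill_env(v, env) for k, v in value.items()}
    if isinstance(value, list):
        return [fill_env(v, env) for v in value]
    return value


config = {'volumes': {'/data': 'pvc-${MY_USER}'}}
assert fill_env(config, {'MY_USER': 'alice'}) == {
    'volumes': {'/data': 'pvc-alice'}}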
sky/utils/cluster_utils.py CHANGED
@@ -193,11 +193,29 @@ class SSHConfigHelper(object):
         proxy_command = auth_config.get('ssh_proxy_command', None)

         docker_proxy_command_generator = None
+        proxy_command_for_nodes = proxy_command
         if docker_user is not None:
-            docker_proxy_command_generator = lambda ip, port: ' '.join(
-                ['ssh'] + command_runner.ssh_options_list(
-                    key_path, ssh_control_name=None, port=port) +
-                ['-W', '%h:%p', f'{auth_config["ssh_user"]}@{ip}'])
+
+            def _docker_proxy_cmd(ip: str, port: int) -> str:
+                inner_proxy = proxy_command
+                inner_port = port or 22
+                if inner_proxy is not None:
+                    inner_proxy = inner_proxy.replace('%h', ip)
+                    inner_proxy = inner_proxy.replace('%p', str(inner_port))
+                return ' '.join(['ssh'] + command_runner.ssh_options_list(
+                    key_path,
+                    ssh_control_name=None,
+                    ssh_proxy_command=inner_proxy,
+                    port=inner_port,
+                    # ProxyCommand (ssh -W) is a forwarding tunnel, not an
+                    # interactive session. ControlMaster would cache these
+                    # processes, causing them to hang and block subsequent
+                    # connections. Each ProxyCommand should be ephemeral.
+                    disable_control_master=True
+                ) + ['-W', '%h:%p', f'{auth_config["ssh_user"]}@{ip}'])
+
+            docker_proxy_command_generator = _docker_proxy_cmd
+            proxy_command_for_nodes = None

         codegen = ''
         # Add the nodes to the codegen
@@ -212,7 +230,7 @@ class SSHConfigHelper(object):
             # TODO(romilb): Update port number when k8s supports multinode
             codegen += cls._get_generated_config(
                 sky_autogen_comment, node_name, ip, username,
-                key_path_for_config, proxy_command, port,
+                key_path_for_config, proxy_command_for_nodes, port,
                 docker_proxy_command) + '\n'

         cluster_config_path = os.path.expanduser(
sky/utils/command_runner.py CHANGED
@@ -652,15 +652,31 @@ class SSHCommandRunner(CommandRunner):
         if docker_user is not None:
             assert port is None or port == 22, (
                 f'port must be None or 22 for docker_user, got {port}.')
-            # Already checked in resources
-            assert ssh_proxy_command is None, (
-                'ssh_proxy_command is not supported when using docker.')
+            # When connecting via docker, the outer SSH hop points to the
+            # container's sshd (localhost). Preserve the user proxy for the
+            # inner hop that reaches the host VM, and clear the outer proxy to
+            # avoid forwarding localhost through the jump host.
+            inner_proxy_command = ssh_proxy_command
+            inner_proxy_port = port or 22
+            self._ssh_proxy_command = None
             self.ip = 'localhost'
             self.ssh_user = docker_user
             self.port = constants.DEFAULT_DOCKER_PORT
+            if inner_proxy_command is not None:
+                # Replace %h/%p placeholders with actual host values, since the
+                # final destination from the perspective of the user proxy is
+                # the host VM (ip, inner_proxy_port).
+                inner_proxy_command = inner_proxy_command.replace('%h', ip)
+                inner_proxy_command = inner_proxy_command.replace(
+                    '%p', str(inner_proxy_port))
             self._docker_ssh_proxy_command = lambda ssh: ' '.join(
-                ssh + ssh_options_list(ssh_private_key, None
-                                      ) + ['-W', '%h:%p', f'{ssh_user}@{ip}'])
+                ssh + ssh_options_list(ssh_private_key,
+                                       None,
+                                       ssh_proxy_command=inner_proxy_command,
+                                       port=inner_proxy_port,
+                                       disable_control_master=self.
+                                       disable_control_master) +
+                ['-W', '%h:%p', f'{ssh_user}@{ip}'])
         else:
             self.ip = ip
             self.ssh_user = ssh_user
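Note: roughly, the change composes two hops: the user's ProxyCommand (with %h/%p resolved to the host VM) carries the inner connection, and ssh -W then tunnels on to the container's sshd. A simplified sketch of the resulting command strings (the real ssh_options_list adds key, timeout, and control flags; values here are made up):

ip, ssh_user, docker_port = '1.2.3.4', 'ubuntu', 10022
user_proxy = 'ssh -W %h:%p jump@bastion'

# Inner hop: user proxy rewritten to target the host VM directly.
inner = user_proxy.replace('%h', ip).replace('%p', '22')
# Outer hop: reach the host VM through the inner proxy, then -W forwards
# stdin/stdout to the container's sshd.
docker_proxy = f"ssh -o ProxyCommand='{inner}' -W %h:%p {ssh_user}@{ip}"
# The client then connects to the container:
#   ssh -o ProxyCommand="<docker_proxy>" -p <docker_port> <docker_user>@localhost
print(docker_proxy)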
sky/utils/command_runner.pyi CHANGED
@@ -142,8 +142,10 @@ class SSHCommandRunner(CommandRunner):
         ssh_user: str,
         ssh_private_key: str,
         ssh_control_name: Optional[str] = ...,
+        ssh_proxy_command: Optional[str] = ...,
         docker_user: Optional[str] = ...,
         disable_control_master: Optional[bool] = ...,
+        port_forward_execute_remote_command: Optional[bool] = ...,
     ) -> None:
         ...
@@ -198,6 +200,15 @@ class SSHCommandRunner(CommandRunner):
                 **kwargs) -> Union[Tuple[int, str, str], int]:
         ...

+    def ssh_base_command(
+        self,
+        *,
+        ssh_mode: SshMode,
+        port_forward: Optional[List[Tuple[int, int]]],
+        connect_timeout: Optional[int],
+    ) -> List[str]:
+        ...
+
     def rsync(self,
               source: str,
               target: str,
sky/utils/volume.py CHANGED
@@ -26,6 +26,11 @@ class VolumeType(enum.Enum):
     PVC = 'k8s-pvc'
     RUNPOD_NETWORK_VOLUME = 'runpod-network-volume'

+    @classmethod
+    def supported_types(cls) -> list:
+        """Return list of supported volume type values."""
+        return [vt.value for vt in cls]
+

 class VolumeMount:
     """Volume mount specification."""
sky/volumes/client/sdk.py CHANGED
@@ -1,11 +1,12 @@
 """SDK functions for managed jobs."""
 import json
 import typing
-from typing import Any, Dict, List
+from typing import List

 from sky import exceptions
 from sky import sky_logging
 from sky.adaptors import common as adaptors_common
+from sky.schemas.api import responses
 from sky.server import common as server_common
 from sky.server import versions
 from sky.server.requests import payloads
@@ -116,7 +117,7 @@ def validate(volume: volume_lib.Volume) -> None:
 @usage_lib.entrypoint
 @server_common.check_server_healthy_or_start
 @annotations.client_api
-def ls() -> server_common.RequestId[List[Dict[str, Any]]]:
+def ls() -> server_common.RequestId[List[responses.VolumeRecord]]:
     """Lists all volumes.

     Returns:
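Note: with the new annotation, ls() resolves to typed VolumeRecord objects rather than raw dicts. A hedged usage sketch, assuming the SDK's usual request/get flow:

import sky
from sky.volumes.client import sdk as volumes_sdk

request_id = volumes_sdk.ls()
volumes = sky.get(request_id)  # List[responses.VolumeRecord]
for v in volumes:
    print(v.name, v.type, v.size)  # attribute access, not v['name']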
sky/volumes/server/core.py CHANGED
@@ -11,6 +11,7 @@ from sky import global_user_state
 from sky import models
 from sky import provision
 from sky import sky_logging
+from sky.schemas.api import responses
 from sky.utils import common_utils
 from sky.utils import registry
 from sky.utils import rich_utils
@@ -56,7 +57,7 @@ def volume_refresh():
             volume_name, status=status_lib.VolumeStatus.IN_USE)


-def volume_list() -> List[Dict[str, Any]]:
+def volume_list() -> List[responses.VolumeRecord]:
     """Gets the volumes.

     Returns:
@@ -143,7 +144,7 @@ def volume_list() -> List[Dict[str, Any]]:
         record['name_on_cloud'] = config.name_on_cloud
         record['usedby_pods'] = usedby_pods
         record['usedby_clusters'] = usedby_clusters
-        records.append(record)
+        records.append(responses.VolumeRecord(**record))
     return records