skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +64 -32
- sky/adaptors/aws.py +23 -6
- sky/adaptors/azure.py +432 -15
- sky/adaptors/cloudflare.py +5 -5
- sky/adaptors/common.py +19 -9
- sky/adaptors/do.py +20 -0
- sky/adaptors/gcp.py +3 -2
- sky/adaptors/kubernetes.py +122 -88
- sky/adaptors/nebius.py +100 -0
- sky/adaptors/oci.py +39 -1
- sky/adaptors/vast.py +29 -0
- sky/admin_policy.py +101 -0
- sky/authentication.py +117 -98
- sky/backends/backend.py +52 -20
- sky/backends/backend_utils.py +669 -557
- sky/backends/cloud_vm_ray_backend.py +1099 -808
- sky/backends/local_docker_backend.py +14 -8
- sky/backends/wheel_utils.py +38 -20
- sky/benchmark/benchmark_utils.py +22 -23
- sky/check.py +76 -27
- sky/cli.py +1586 -1139
- sky/client/__init__.py +1 -0
- sky/client/cli.py +5683 -0
- sky/client/common.py +345 -0
- sky/client/sdk.py +1765 -0
- sky/cloud_stores.py +283 -19
- sky/clouds/__init__.py +7 -2
- sky/clouds/aws.py +303 -112
- sky/clouds/azure.py +185 -179
- sky/clouds/cloud.py +115 -37
- sky/clouds/cudo.py +29 -22
- sky/clouds/do.py +313 -0
- sky/clouds/fluidstack.py +44 -54
- sky/clouds/gcp.py +206 -65
- sky/clouds/ibm.py +26 -21
- sky/clouds/kubernetes.py +345 -91
- sky/clouds/lambda_cloud.py +40 -29
- sky/clouds/nebius.py +297 -0
- sky/clouds/oci.py +129 -90
- sky/clouds/paperspace.py +22 -18
- sky/clouds/runpod.py +53 -34
- sky/clouds/scp.py +28 -24
- sky/clouds/service_catalog/__init__.py +19 -13
- sky/clouds/service_catalog/aws_catalog.py +29 -12
- sky/clouds/service_catalog/azure_catalog.py +33 -6
- sky/clouds/service_catalog/common.py +95 -75
- sky/clouds/service_catalog/constants.py +3 -3
- sky/clouds/service_catalog/cudo_catalog.py +13 -3
- sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
- sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
- sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
- sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
- sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
- sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
- sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
- sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
- sky/clouds/service_catalog/do_catalog.py +111 -0
- sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
- sky/clouds/service_catalog/gcp_catalog.py +16 -2
- sky/clouds/service_catalog/ibm_catalog.py +2 -2
- sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
- sky/clouds/service_catalog/lambda_catalog.py +8 -3
- sky/clouds/service_catalog/nebius_catalog.py +116 -0
- sky/clouds/service_catalog/oci_catalog.py +31 -4
- sky/clouds/service_catalog/paperspace_catalog.py +2 -2
- sky/clouds/service_catalog/runpod_catalog.py +2 -2
- sky/clouds/service_catalog/scp_catalog.py +2 -2
- sky/clouds/service_catalog/vast_catalog.py +104 -0
- sky/clouds/service_catalog/vsphere_catalog.py +2 -2
- sky/clouds/utils/aws_utils.py +65 -0
- sky/clouds/utils/azure_utils.py +91 -0
- sky/clouds/utils/gcp_utils.py +5 -9
- sky/clouds/utils/oci_utils.py +47 -5
- sky/clouds/utils/scp_utils.py +4 -3
- sky/clouds/vast.py +280 -0
- sky/clouds/vsphere.py +22 -18
- sky/core.py +361 -107
- sky/dag.py +41 -28
- sky/data/data_transfer.py +37 -0
- sky/data/data_utils.py +211 -32
- sky/data/mounting_utils.py +182 -30
- sky/data/storage.py +2118 -270
- sky/data/storage_utils.py +126 -5
- sky/exceptions.py +179 -8
- sky/execution.py +158 -85
- sky/global_user_state.py +150 -34
- sky/jobs/__init__.py +12 -10
- sky/jobs/client/__init__.py +0 -0
- sky/jobs/client/sdk.py +302 -0
- sky/jobs/constants.py +49 -11
- sky/jobs/controller.py +161 -99
- sky/jobs/dashboard/dashboard.py +171 -25
- sky/jobs/dashboard/templates/index.html +572 -60
- sky/jobs/recovery_strategy.py +157 -156
- sky/jobs/scheduler.py +307 -0
- sky/jobs/server/__init__.py +1 -0
- sky/jobs/server/core.py +598 -0
- sky/jobs/server/dashboard_utils.py +69 -0
- sky/jobs/server/server.py +190 -0
- sky/jobs/state.py +627 -122
- sky/jobs/utils.py +615 -206
- sky/models.py +27 -0
- sky/optimizer.py +142 -83
- sky/provision/__init__.py +20 -5
- sky/provision/aws/config.py +124 -42
- sky/provision/aws/instance.py +130 -53
- sky/provision/azure/__init__.py +7 -0
- sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
- sky/provision/azure/config.py +220 -0
- sky/provision/azure/instance.py +1012 -37
- sky/provision/common.py +31 -3
- sky/provision/constants.py +25 -0
- sky/provision/cudo/__init__.py +2 -1
- sky/provision/cudo/cudo_utils.py +112 -0
- sky/provision/cudo/cudo_wrapper.py +37 -16
- sky/provision/cudo/instance.py +28 -12
- sky/provision/do/__init__.py +11 -0
- sky/provision/do/config.py +14 -0
- sky/provision/do/constants.py +10 -0
- sky/provision/do/instance.py +287 -0
- sky/provision/do/utils.py +301 -0
- sky/provision/docker_utils.py +82 -46
- sky/provision/fluidstack/fluidstack_utils.py +57 -125
- sky/provision/fluidstack/instance.py +15 -43
- sky/provision/gcp/config.py +19 -9
- sky/provision/gcp/constants.py +7 -1
- sky/provision/gcp/instance.py +55 -34
- sky/provision/gcp/instance_utils.py +339 -80
- sky/provision/gcp/mig_utils.py +210 -0
- sky/provision/instance_setup.py +172 -133
- sky/provision/kubernetes/__init__.py +1 -0
- sky/provision/kubernetes/config.py +104 -90
- sky/provision/kubernetes/constants.py +8 -0
- sky/provision/kubernetes/instance.py +680 -325
- sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
- sky/provision/kubernetes/network.py +54 -20
- sky/provision/kubernetes/network_utils.py +70 -21
- sky/provision/kubernetes/utils.py +1370 -251
- sky/provision/lambda_cloud/__init__.py +11 -0
- sky/provision/lambda_cloud/config.py +10 -0
- sky/provision/lambda_cloud/instance.py +265 -0
- sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
- sky/provision/logging.py +1 -1
- sky/provision/nebius/__init__.py +11 -0
- sky/provision/nebius/config.py +11 -0
- sky/provision/nebius/instance.py +285 -0
- sky/provision/nebius/utils.py +318 -0
- sky/provision/oci/__init__.py +15 -0
- sky/provision/oci/config.py +51 -0
- sky/provision/oci/instance.py +436 -0
- sky/provision/oci/query_utils.py +681 -0
- sky/provision/paperspace/constants.py +6 -0
- sky/provision/paperspace/instance.py +4 -3
- sky/provision/paperspace/utils.py +2 -0
- sky/provision/provisioner.py +207 -130
- sky/provision/runpod/__init__.py +1 -0
- sky/provision/runpod/api/__init__.py +3 -0
- sky/provision/runpod/api/commands.py +119 -0
- sky/provision/runpod/api/pods.py +142 -0
- sky/provision/runpod/instance.py +64 -8
- sky/provision/runpod/utils.py +239 -23
- sky/provision/vast/__init__.py +10 -0
- sky/provision/vast/config.py +11 -0
- sky/provision/vast/instance.py +247 -0
- sky/provision/vast/utils.py +162 -0
- sky/provision/vsphere/common/vim_utils.py +1 -1
- sky/provision/vsphere/instance.py +8 -18
- sky/provision/vsphere/vsphere_utils.py +1 -1
- sky/resources.py +247 -102
- sky/serve/__init__.py +9 -9
- sky/serve/autoscalers.py +361 -299
- sky/serve/client/__init__.py +0 -0
- sky/serve/client/sdk.py +366 -0
- sky/serve/constants.py +12 -3
- sky/serve/controller.py +106 -36
- sky/serve/load_balancer.py +63 -12
- sky/serve/load_balancing_policies.py +84 -2
- sky/serve/replica_managers.py +42 -34
- sky/serve/serve_state.py +62 -32
- sky/serve/serve_utils.py +271 -160
- sky/serve/server/__init__.py +0 -0
- sky/serve/{core.py → server/core.py} +271 -90
- sky/serve/server/server.py +112 -0
- sky/serve/service.py +52 -16
- sky/serve/service_spec.py +95 -32
- sky/server/__init__.py +1 -0
- sky/server/common.py +430 -0
- sky/server/constants.py +21 -0
- sky/server/html/log.html +174 -0
- sky/server/requests/__init__.py +0 -0
- sky/server/requests/executor.py +472 -0
- sky/server/requests/payloads.py +487 -0
- sky/server/requests/queues/__init__.py +0 -0
- sky/server/requests/queues/mp_queue.py +76 -0
- sky/server/requests/requests.py +567 -0
- sky/server/requests/serializers/__init__.py +0 -0
- sky/server/requests/serializers/decoders.py +192 -0
- sky/server/requests/serializers/encoders.py +166 -0
- sky/server/server.py +1106 -0
- sky/server/stream_utils.py +141 -0
- sky/setup_files/MANIFEST.in +2 -5
- sky/setup_files/dependencies.py +159 -0
- sky/setup_files/setup.py +14 -125
- sky/sky_logging.py +59 -14
- sky/skylet/autostop_lib.py +2 -2
- sky/skylet/constants.py +183 -50
- sky/skylet/events.py +22 -10
- sky/skylet/job_lib.py +403 -258
- sky/skylet/log_lib.py +111 -71
- sky/skylet/log_lib.pyi +6 -0
- sky/skylet/providers/command_runner.py +6 -8
- sky/skylet/providers/ibm/node_provider.py +2 -2
- sky/skylet/providers/scp/config.py +11 -3
- sky/skylet/providers/scp/node_provider.py +8 -8
- sky/skylet/skylet.py +3 -1
- sky/skylet/subprocess_daemon.py +69 -17
- sky/skypilot_config.py +119 -57
- sky/task.py +205 -64
- sky/templates/aws-ray.yml.j2 +37 -7
- sky/templates/azure-ray.yml.j2 +27 -82
- sky/templates/cudo-ray.yml.j2 +7 -3
- sky/templates/do-ray.yml.j2 +98 -0
- sky/templates/fluidstack-ray.yml.j2 +7 -4
- sky/templates/gcp-ray.yml.j2 +26 -6
- sky/templates/ibm-ray.yml.j2 +3 -2
- sky/templates/jobs-controller.yaml.j2 +46 -11
- sky/templates/kubernetes-ingress.yml.j2 +7 -0
- sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
- sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
- sky/templates/kubernetes-ray.yml.j2 +292 -25
- sky/templates/lambda-ray.yml.j2 +30 -40
- sky/templates/nebius-ray.yml.j2 +79 -0
- sky/templates/oci-ray.yml.j2 +18 -57
- sky/templates/paperspace-ray.yml.j2 +10 -6
- sky/templates/runpod-ray.yml.j2 +26 -4
- sky/templates/scp-ray.yml.j2 +3 -2
- sky/templates/sky-serve-controller.yaml.j2 +12 -1
- sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
- sky/templates/vast-ray.yml.j2 +70 -0
- sky/templates/vsphere-ray.yml.j2 +8 -3
- sky/templates/websocket_proxy.py +64 -0
- sky/usage/constants.py +10 -1
- sky/usage/usage_lib.py +130 -37
- sky/utils/accelerator_registry.py +35 -51
- sky/utils/admin_policy_utils.py +147 -0
- sky/utils/annotations.py +51 -0
- sky/utils/cli_utils/status_utils.py +81 -23
- sky/utils/cluster_utils.py +356 -0
- sky/utils/command_runner.py +452 -89
- sky/utils/command_runner.pyi +77 -3
- sky/utils/common.py +54 -0
- sky/utils/common_utils.py +319 -108
- sky/utils/config_utils.py +204 -0
- sky/utils/control_master_utils.py +48 -0
- sky/utils/controller_utils.py +548 -266
- sky/utils/dag_utils.py +93 -32
- sky/utils/db_utils.py +18 -4
- sky/utils/env_options.py +29 -7
- sky/utils/kubernetes/create_cluster.sh +8 -60
- sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
- sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
- sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
- sky/utils/kubernetes/gpu_labeler.py +4 -4
- sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
- sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
- sky/utils/kubernetes/rsync_helper.sh +24 -0
- sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
- sky/utils/log_utils.py +240 -33
- sky/utils/message_utils.py +81 -0
- sky/utils/registry.py +127 -0
- sky/utils/resources_utils.py +94 -22
- sky/utils/rich_utils.py +247 -18
- sky/utils/schemas.py +284 -64
- sky/{status_lib.py → utils/status_lib.py} +12 -7
- sky/utils/subprocess_utils.py +212 -46
- sky/utils/timeline.py +12 -7
- sky/utils/ux_utils.py +168 -15
- skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
- skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
- sky/clouds/cloud_registry.py +0 -31
- sky/jobs/core.py +0 -330
- sky/skylet/providers/azure/__init__.py +0 -2
- sky/skylet/providers/azure/azure-vm-template.json +0 -301
- sky/skylet/providers/azure/config.py +0 -170
- sky/skylet/providers/azure/node_provider.py +0 -466
- sky/skylet/providers/lambda_cloud/__init__.py +0 -2
- sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
- sky/skylet/providers/oci/__init__.py +0 -2
- sky/skylet/providers/oci/node_provider.py +0 -488
- sky/skylet/providers/oci/query_helper.py +0 -383
- sky/skylet/providers/oci/utils.py +0 -21
- sky/utils/cluster_yaml_utils.py +0 -24
- sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
- skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
- skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
sky/client/common.py
ADDED
@@ -0,0 +1,345 @@
|
|
1
|
+
"""Common utilities for the client."""
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
import dataclasses
|
5
|
+
import json
|
6
|
+
import logging
|
7
|
+
import math
|
8
|
+
import os
|
9
|
+
import pathlib
|
10
|
+
import tempfile
|
11
|
+
import time
|
12
|
+
import typing
|
13
|
+
from typing import Dict, Generator, Iterable
|
14
|
+
import uuid
|
15
|
+
import zipfile
|
16
|
+
|
17
|
+
import httpx
|
18
|
+
import requests
|
19
|
+
|
20
|
+
from sky import sky_logging
|
21
|
+
from sky.data import data_utils
|
22
|
+
from sky.data import storage_utils
|
23
|
+
from sky.server import common as server_common
|
24
|
+
from sky.server.requests import payloads
|
25
|
+
from sky.skylet import constants
|
26
|
+
from sky.utils import common_utils
|
27
|
+
from sky.utils import rich_utils
|
28
|
+
from sky.utils import subprocess_utils
|
29
|
+
from sky.utils import ux_utils
|
30
|
+
|
31
|
+
if typing.TYPE_CHECKING:
|
32
|
+
import sky
|
33
|
+
import sky.dag as dag_lib
|
34
|
+
|
35
|
+
logger = sky_logging.init_logger(__name__)

# Number of bytes requested per iteration while streaming the logs archive
# from the API server.
_DOWNLOAD_CHUNK_BYTES = 8192
# Size of each piece of the zip file uploaded to the API server. The zip file
# is split into chunks to avoid network issues for large request bodies, which
# can be caused by NGINX's client_max_body_size.
_UPLOAD_CHUNK_BYTES = 512 * 1024 * 1024

# Directory that holds one log file per upload to the API server.
FILE_UPLOAD_LOGS_DIR = os.path.join(constants.SKY_LOGS_DIRECTORY,
                                    'file_uploads')

# Connection timeout when sending requests to the API server.
API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS = 5
|
49
|
+
|
50
|
+
|
51
|
+
def download_logs_from_api_server(
        paths_on_api_server: Iterable[str]) -> Dict[str, str]:
    """Downloads the logs from the API server.

    Posts the requested folders to the `/download` endpoint, streams the
    returned zip archive into a temporary file and extracts every member
    into the matching local directory.

    Args:
        paths_on_api_server: The paths on the API server to download.

    Returns:
        A dictionary mapping the remote path on API server to the local path.

    Raises:
        ValueError: If an archive member is not located under any of the
            requested paths.
        Exception: If the API server answers with a non-200 status code.
    """
    # Materialize once: the argument is only required to be an iterable, and
    # a one-shot iterator would be exhausted by the mapping below, leaving
    # the request body with an empty list of folders.
    remote_paths = list(paths_on_api_server)
    # TODO(zhwu): handling the replacement locally is not stable, and
    # may cause issues when we change the pattern of the remote path.
    # This should be moved to remote API server. A proper way might be
    # set the returned path to be started with a special prefix, instead
    # of using the `api_server_user_logs_dir_prefix()`.
    remote_logs_prefix = str(server_common.api_server_user_logs_dir_prefix())
    remote2local_path_dict = {
        remote_path: remote_path.replace(remote_logs_prefix,
                                         constants.SKY_LOGS_DIRECTORY)
        for remote_path in remote_paths
    }
    body = payloads.DownloadBody(folder_paths=remote_paths,)
    # The response is streamed, so use it as a context manager to make sure
    # the underlying connection is released on every exit path.
    with requests.post(f'{server_common.get_server_url()}/download',
                       json=json.loads(body.model_dump_json()),
                       stream=True) as response:
        if response.status_code != 200:
            raise Exception(f'Failed to download logs: '
                            f'{response.status_code} {response.text}')

        remote_home_path = response.headers.get('X-Home-Path')
        assert remote_home_path is not None, response.headers
        with tempfile.NamedTemporaryFile(prefix='skypilot-logs-download-',
                                         delete=True) as temp_file:
            # Download the zip file from the API server to the local machine.
            for chunk in response.iter_content(
                    chunk_size=_DOWNLOAD_CHUNK_BYTES):
                temp_file.write(chunk)
            temp_file.flush()

            # Unzip the downloaded file and save the logs to the correct
            # local directory.
            with zipfile.ZipFile(temp_file, 'r') as zipf:
                for member in zipf.namelist():
                    # Determine the new path
                    zipped_filename = os.path.basename(member)
                    zipped_dir = os.path.dirname('/' + member)
                    local_dir = zipped_dir.replace(remote_home_path, '~')
                    for remote_path, local_path in (
                            remote2local_path_dict.items()):
                        if local_dir.startswith(remote_path):
                            # Swap only the matched leading prefix; a plain
                            # str.replace would also rewrite any later
                            # occurrence of the same text in the path.
                            local_dir = (local_path +
                                         local_dir[len(remote_path):])
                            break
                    else:
                        raise ValueError(
                            f'Invalid folder path: {zipped_dir}')
                    new_path = pathlib.Path(
                        local_dir).expanduser().resolve() / zipped_filename
                    new_path.parent.mkdir(parents=True, exist_ok=True)
                    if member.endswith('/'):
                        # If it is a directory, we need to create it.
                        new_path.mkdir(parents=True, exist_ok=True)
                    else:
                        with zipf.open(member) as member_file:
                            new_path.write_bytes(member_file.read())

    return remote2local_path_dict
|
116
|
+
|
117
|
+
|
118
|
+
# === Upload files to API server ===
|
119
|
+
|
120
|
+
|
121
|
+
class FileChunkIterator:
    """An iterable that streams one fixed-size chunk of a file.

    Iterating seeks `file_obj` to `chunk_index * chunk_size` and yields the
    bytes of that chunk in pieces of at most 64KB, stopping after
    `chunk_size` bytes or at end of file, whichever comes first.

    Attributes are plain values set by the constructor; `bytes_read` holds
    the number of bytes yielded by the most recent iteration.
    """

    def __init__(self, file_obj, chunk_size: int, chunk_index: int):
        self.file_obj = file_obj
        self.chunk_size = chunk_size
        self.chunk_index = chunk_index
        self.bytes_read = 0

    def __iter__(self):
        # Start counting from zero on every iteration. Without the reset, a
        # second pass over the same object yields nothing (and a restart
        # after a partial pass yields a truncated chunk), because the seek
        # below rewinds the file while the counter keeps its old value.
        self.bytes_read = 0
        # Seek to the correct position for this chunk
        self.file_obj.seek(self.chunk_index * self.chunk_size)
        while self.bytes_read < self.chunk_size:
            # Read a smaller buffer size to keep memory usage low
            buffer_size = min(64 * 1024,
                              self.chunk_size - self.bytes_read)  # 64KB buffer
            data = self.file_obj.read(buffer_size)
            if not data:
                # End of file reached before the chunk was full.
                break
            self.bytes_read += len(data)
            yield data
|
142
|
+
|
143
|
+
|
144
|
+
@dataclasses.dataclass
class UploadChunkParams:
    """Arguments for uploading a single chunk of the zip file.

    Everything `_upload_chunk_with_retry` needs is bundled into one object,
    so that a list of these can be handed to a parallel runner that calls the
    function with one argument per item.
    """
    # HTTP client used to post the chunk.
    client: httpx.Client
    # Identifier sent with every chunk of the same upload.
    upload_id: str
    # Zero-based position of this chunk within the zip file.
    chunk_index: int
    # Number of chunks the zip file is split into.
    total_chunks: int
    # Path of the local zip file the chunk is read from.
    file_path: str
    # Logger that receives progress and error messages for the upload.
    upload_logger: logging.Logger
    # Path of the upload log file, included in the final error message.
    log_file: str
|
153
|
+
|
154
|
+
|
155
|
+
def _upload_chunk_with_retry(params: UploadChunkParams) -> None:
    """Uploads a chunk of a zip file to the API server.

    The chunk is posted to the `/upload` endpoint. Any response other than
    HTTP 200 is retried after a one second pause, up to three attempts in
    total.

    Args:
        params: The client, upload identifiers, zip file path and logging
            targets for the chunk to upload.

    Raises:
        RuntimeError: If the last attempt still does not return HTTP 200.
    """
    upload_logger = params.upload_logger
    chunk_desc = f'{params.chunk_index + 1} / {params.total_chunks}'
    upload_logger.info(f'Uploading chunk: {chunk_desc}')

    server_url = server_common.get_server_url()
    max_attempts = 3
    with open(params.file_path, 'rb') as f:
        for attempt in range(max_attempts):
            # A fresh iterator per attempt, so every retry re-reads the chunk
            # from its start.
            response = params.client.post(
                f'{server_url}/upload',
                params={
                    'user_hash': common_utils.get_user_hash(),
                    'upload_id': params.upload_id,
                    'chunk_index': str(params.chunk_index),
                    'total_chunks': str(params.total_chunks),
                },
                content=FileChunkIterator(f, _UPLOAD_CHUNK_BYTES,
                                          params.chunk_index),
                headers={'Content-Type': 'application/octet-stream'})
            if response.status_code == 200:
                data = response.json()
                status = data.get('status')
                msg = f'Uploaded chunk: {chunk_desc}'
                if status == 'uploading':
                    missing_chunks = data.get('missing_chunks')
                    if missing_chunks:
                        msg += f' - Waiting for chunks: {missing_chunks}'
                upload_logger.info(msg)
                return
            elif attempt < max_attempts - 1:
                # Decode leniently: an undecodable error body must not abort
                # the retry loop while it is merely being logged.
                response_text = response.content.decode('utf-8',
                                                        errors='replace')
                upload_logger.error(
                    f'Failed to upload chunk: {chunk_desc}: {response_text}')
                upload_logger.info(
                    f'Retrying... ({attempt + 1} / {max_attempts})')
                time.sleep(1)
            else:
                # The error body is not guaranteed to be a JSON object (a
                # proxy in front of the server may answer with HTML or plain
                # text); fall back to the raw text so the real failure is
                # reported instead of a JSON decoding error.
                try:
                    detail = response.json().get('detail')
                except (ValueError, AttributeError):
                    detail = response.text
                error_msg = f'Failed to upload chunk: {chunk_desc}: {detail}'
                upload_logger.error(error_msg)
                with ux_utils.print_exception_no_traceback():
                    raise RuntimeError(
                        ux_utils.error_message(error_msg + '\n',
                                               params.log_file,
                                               is_local=True))
|
205
|
+
|
206
|
+
|
207
|
+
@contextlib.contextmanager
|
208
|
+
def _setup_upload_logger(
|
209
|
+
log_file: str) -> Generator[logging.Logger, None, None]:
|
210
|
+
try:
|
211
|
+
upload_logger = logging.getLogger('sky.upload')
|
212
|
+
upload_logger.propagate = False
|
213
|
+
handler = logging.FileHandler(os.path.expanduser(log_file),
|
214
|
+
encoding='utf-8')
|
215
|
+
handler.setFormatter(sky_logging.FORMATTER)
|
216
|
+
upload_logger.addHandler(handler)
|
217
|
+
upload_logger.setLevel(logging.DEBUG)
|
218
|
+
yield upload_logger
|
219
|
+
finally:
|
220
|
+
upload_logger.removeHandler(handler)
|
221
|
+
handler.close()
|
222
|
+
|
223
|
+
|
224
|
+
def upload_mounts_to_api_server(dag: 'sky.Dag',
                                workdir_only: bool = False) -> 'dag_lib.Dag':
    """Upload user files to remote API server.

    This function needs to be called after sdk.validate(),
    as the file paths need to be expanded to keep file_mounts_mapping
    aligned with the actual task uploaded to SkyPilot API server.

    We don't use FastAPI's built-in multipart upload, as nginx's
    client_max_body_size can block the request due to large request body, i.e.,
    even though the multipart upload streams the file to the server, there is
    only one HTTP request, and a large request body will be blocked by nginx.

    The local files are zipped into a temporary file, which is uploaded in
    chunks of `_UPLOAD_CHUNK_BYTES` and removed afterwards, whether or not
    the upload succeeded. Nothing is uploaded when the API server is local.

    Args:
        dag: The dag where the file mounts are defined.
        workdir_only: Whether to only upload the workdir, which is used for
            `exec`, as it does not need other files/folders in file_mounts.

    Returns:
        The dag with the file_mounts_mapping updated, which maps the original
        file paths to the full path, so that on API server, the file paths can
        be retrieved by adding prefix to the full path.
    """

    if server_common.is_api_server_local():
        return dag

    def _full_path(src: str) -> str:
        return os.path.abspath(os.path.expanduser(src))

    upload_list = []
    for task_ in dag.tasks:
        task_.file_mounts_mapping = {}
        if task_.workdir:
            workdir = task_.workdir
            assert os.path.isabs(workdir)
            upload_list.append(workdir)
            task_.file_mounts_mapping[workdir] = workdir
        if workdir_only:
            continue
        if task_.file_mounts is not None:
            for src in task_.file_mounts.values():
                if not data_utils.is_cloud_store_url(src):
                    assert os.path.isabs(src)
                    upload_list.append(src)
                    task_.file_mounts_mapping[src] = src
                if src == constants.LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER:
                    # The placeholder for the local skypilot config path is in
                    # file mounts for controllers. It will be replaced with the
                    # real path for config file on API server.
                    pass
        if task_.storage_mounts is not None:
            for storage in task_.storage_mounts.values():
                storage_source = storage.source
                is_cloud_store_url = (
                    isinstance(storage_source, str) and
                    data_utils.is_cloud_store_url(storage_source))
                if storage_source is not None and not is_cloud_store_url:
                    if isinstance(storage_source, str):
                        storage_source = [storage_source]
                    for src in storage_source:
                        upload_list.append(_full_path(src))
                        task_.file_mounts_mapping[src] = _full_path(src)
        if (task_.service is not None and
                task_.service.tls_credential is not None):
            tls_credential = task_.service.tls_credential
            for credential_path in (tls_credential.keyfile,
                                    tls_credential.certfile):
                upload_list.append(credential_path)
                task_.file_mounts_mapping[credential_path] = credential_path

    if not upload_list:
        return dag

    os.makedirs(os.path.expanduser(FILE_UPLOAD_LOGS_DIR), exist_ok=True)
    upload_id = sky_logging.get_run_timestamp()
    upload_id = f'{upload_id}-{uuid.uuid4().hex[:8]}'
    log_file = os.path.join(FILE_UPLOAD_LOGS_DIR, f'{upload_id}.log')

    logger.info(ux_utils.starting_message('Uploading files to API server'))
    with rich_utils.client_status(
            ux_utils.spinner_message(
                'Uploading files to API server (1/2 - Zipping)',
                log_file,
                is_local=True)) as status, _setup_upload_logger(
                    log_file) as upload_logger:
        temp_zip_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix='.zip',
                                             delete=False) as temp_zip_file:
                temp_zip_path = temp_zip_file.name
                upload_logger.info(
                    f'Zipping files to be uploaded: {upload_list}')
                storage_utils.zip_files_and_folders(upload_list,
                                                    temp_zip_path)
                upload_logger.info(f'Zipped files to: {temp_zip_path}')

            zip_file_size = os.path.getsize(temp_zip_path)
            # Per chunk size 512 MB
            total_chunks = int(math.ceil(zip_file_size / _UPLOAD_CHUNK_BYTES))
            timeout = httpx.Timeout(None, read=180.0)
            status.update(
                ux_utils.spinner_message(
                    'Uploading files to API server (2/2 - Uploading)',
                    log_file,
                    is_local=True))

            with httpx.Client(timeout=timeout) as client:
                chunk_params = [
                    UploadChunkParams(client, upload_id, chunk_index,
                                      total_chunks, temp_zip_path,
                                      upload_logger, log_file)
                    for chunk_index in range(total_chunks)
                ]
                subprocess_utils.run_in_parallel(_upload_chunk_with_retry,
                                                 chunk_params)
            # Logged while the file handler is still attached, so the
            # message ends up in the upload log.
            upload_logger.info(f'Uploaded files: {upload_list}')
        finally:
            # The temporary zip is created with delete=False, so it has to
            # be removed explicitly, also when zipping or uploading fails.
            if temp_zip_path is not None:
                with contextlib.suppress(FileNotFoundError):
                    os.unlink(temp_zip_path)
    logger.info(
        ux_utils.finishing_message('Files uploaded', log_file, is_local=True))

    return dag
|