skypilot-nightly 1.0.0.dev20250219__py3-none-any.whl → 1.0.0.dev20250221__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- sky/__init__.py +4 -2
- sky/adaptors/nebius.py +85 -0
- sky/backends/backend_utils.py +8 -0
- sky/backends/cloud_vm_ray_backend.py +10 -2
- sky/client/sdk.py +8 -3
- sky/clouds/__init__.py +2 -0
- sky/clouds/nebius.py +294 -0
- sky/clouds/service_catalog/constants.py +1 -1
- sky/clouds/service_catalog/nebius_catalog.py +116 -0
- sky/jobs/controller.py +17 -0
- sky/jobs/server/core.py +31 -3
- sky/provision/__init__.py +1 -0
- sky/provision/kubernetes/instance.py +5 -1
- sky/provision/kubernetes/utils.py +8 -7
- sky/provision/nebius/__init__.py +11 -0
- sky/provision/nebius/config.py +11 -0
- sky/provision/nebius/instance.py +285 -0
- sky/provision/nebius/utils.py +310 -0
- sky/server/common.py +5 -7
- sky/server/requests/executor.py +94 -87
- sky/server/server.py +10 -5
- sky/server/stream_utils.py +8 -11
- sky/setup_files/dependencies.py +9 -1
- sky/skylet/constants.py +3 -6
- sky/task.py +6 -0
- sky/templates/jobs-controller.yaml.j2 +3 -0
- sky/templates/nebius-ray.yml.j2 +79 -0
- sky/utils/common_utils.py +38 -0
- sky/utils/controller_utils.py +66 -2
- {skypilot_nightly-1.0.0.dev20250219.dist-info → skypilot_nightly-1.0.0.dev20250221.dist-info}/METADATA +8 -4
- {skypilot_nightly-1.0.0.dev20250219.dist-info → skypilot_nightly-1.0.0.dev20250221.dist-info}/RECORD +35 -27
- {skypilot_nightly-1.0.0.dev20250219.dist-info → skypilot_nightly-1.0.0.dev20250221.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250219.dist-info → skypilot_nightly-1.0.0.dev20250221.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250219.dist-info → skypilot_nightly-1.0.0.dev20250221.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250219.dist-info → skypilot_nightly-1.0.0.dev20250221.dist-info}/top_level.txt +0 -0
sky/provision/nebius/utils.py
ADDED
@@ -0,0 +1,310 @@
+"""Nebius library wrapper for SkyPilot."""
+import time
+from typing import Any, Dict
+import uuid
+
+from sky import sky_logging
+from sky.adaptors import nebius
+from sky.utils import common_utils
+
+logger = sky_logging.init_logger(__name__)
+
+POLL_INTERVAL = 5
+
+
+def retry(func):
+    """Decorator to retry a function."""
+
+    def wrapper(*args, **kwargs):
+        """Wrapper for retrying a function."""
+        cnt = 0
+        while True:
+            try:
+                return func(*args, **kwargs)
+            except nebius.nebius.error.QueryError as e:
+                if cnt >= 3:
+                    raise
+                logger.warning('Retrying for exception: '
+                               f'{common_utils.format_exception(e)}.')
+                time.sleep(POLL_INTERVAL)
+
+    return wrapper
+
+
+def get_project_by_region(region: str) -> str:
+    service = nebius.iam().ProjectServiceClient(nebius.sdk())
+    projects = service.list(nebius.iam().ListProjectsRequest(
+        parent_id=nebius.get_tenant_id())).wait()
+    # To find a project in a specific region, we rely on the project ID to
+    # deduce the region, since there is currently no method to retrieve region
+    # information directly from the project. Additionally, there is only one
+    # project per region, and projects cannot be created at this time.
+    # The region is determined from the project ID using a region-specific
+    # identifier embedded in it.
+    # Project id looks like project-e00xxxxxxxxxxxxxx where
+    # e00 - id of region 'eu-north1'
+    # e01 - id of region 'eu-west1'
+    # TODO(SalikovAlex): fix when info about region will be in projects list
+    # Currently, Nebius cloud supports 2 regions. We manually enumerate
+    # them here. Reference: https://docs.nebius.com/overview/regions
+    for project in projects.items:
+        if region == 'eu-north1' and project.metadata.id[8:11] == 'e00':
+            return project.metadata.id
+        if region == 'eu-west1' and project.metadata.id[8:11] == 'e01':
+            return project.metadata.id
+    raise Exception(f'No project found for region "{region}".')
+
+
+def get_or_create_gpu_cluster(name: str, region: str) -> str:
+    """Creates a GPU cluster.
+    When creating a GPU cluster, select an InfiniBand fabric for it:
+
+    fabric-2, fabric-3 or fabric-4 for projects in the eu-north1 region.
+    fabric-5 for projects in the eu-west1 region.
+
+    https://docs.nebius.com/compute/clusters/gpu
+    """
+    project_id = get_project_by_region(region)
+    service = nebius.compute().GpuClusterServiceClient(nebius.sdk())
+    try:
+        cluster = service.get_by_name(nebius.nebius_common().GetByNameRequest(
+            parent_id=project_id,
+            name=name,
+        )).wait()
+        cluster_id = cluster.metadata.id
+    except nebius.request_error() as no_cluster_found_error:
+        if region == 'eu-north1':
+            fabric = 'fabric-4'
+        elif region == 'eu-west1':
+            fabric = 'fabric-5'
+        else:
+            raise RuntimeError(
+                f'Unsupported region {region}.') from no_cluster_found_error
+        cluster = service.create(nebius.compute().CreateGpuClusterRequest(
+            metadata=nebius.nebius_common().ResourceMetadata(
+                parent_id=project_id,
+                name=name,
+            ),
+            spec=nebius.compute().GpuClusterSpec(
+                infiniband_fabric=fabric))).wait()
+        cluster_id = cluster.resource_id
+    return cluster_id
+
+
+def delete_cluster(name: str, region: str) -> None:
+    """Delete a GPU cluster."""
+    project_id = get_project_by_region(region)
+    service = nebius.compute().GpuClusterServiceClient(nebius.sdk())
+    try:
+        cluster = service.get_by_name(nebius.nebius_common().GetByNameRequest(
+            parent_id=project_id,
+            name=name,
+        )).wait()
+        cluster_id = cluster.metadata.id
+        logger.debug(f'Found GPU Cluster : {cluster_id}.')
+        service.delete(
+            nebius.compute().DeleteGpuClusterRequest(id=cluster_id)).wait()
+        logger.debug(f'Deleted GPU Cluster : {cluster_id}.')
+    except nebius.request_error():
+        logger.debug('GPU Cluster does not exist.')
+
+
+def list_instances(project_id: str) -> Dict[str, Dict[str, Any]]:
+    """Lists instances associated with API key."""
+    service = nebius.compute().InstanceServiceClient(nebius.sdk())
+    result = service.list(
+        nebius.compute().ListInstancesRequest(parent_id=project_id)).wait()
+
+    instances = result
+
+    instance_dict: Dict[str, Dict[str, Any]] = {}
+    for instance in instances.items:
+        info = {}
+        info['status'] = instance.status.state.name
+        info['name'] = instance.metadata.name
+        if instance.status.network_interfaces:
+            info['external_ip'] = instance.status.network_interfaces[
+                0].public_ip_address.address.split('/')[0]
+            info['internal_ip'] = instance.status.network_interfaces[
+                0].ip_address.address.split('/')[0]
+        instance_dict[instance.metadata.id] = info
+
+    return instance_dict
+
+
+def stop(instance_id: str) -> None:
+    service = nebius.compute().InstanceServiceClient(nebius.sdk())
+    service.stop(nebius.compute().StopInstanceRequest(id=instance_id)).wait()
+    retry_count = 0
+    while retry_count < nebius.MAX_RETRIES_TO_INSTANCE_STOP:
+        service = nebius.compute().InstanceServiceClient(nebius.sdk())
+        instance = service.get(nebius.compute().GetInstanceRequest(
+            id=instance_id,)).wait()
+        if instance.status.state.name == 'STOPPED':
+            break
+        time.sleep(POLL_INTERVAL)
+        logger.debug(f'Waiting for instance {instance_id} stopping.')
+        retry_count += 1
+
+    if retry_count == nebius.MAX_RETRIES_TO_INSTANCE_STOP:
+        raise TimeoutError(
+            f'Exceeded maximum retries '
+            f'({nebius.MAX_RETRIES_TO_INSTANCE_STOP * POLL_INTERVAL}'
+            f' seconds) while waiting for instance {instance_id}'
+            f' to be stopped.')
+
+
+def start(instance_id: str) -> None:
+    service = nebius.compute().InstanceServiceClient(nebius.sdk())
+    service.start(nebius.compute().StartInstanceRequest(id=instance_id)).wait()
+    retry_count = 0
+    while retry_count < nebius.MAX_RETRIES_TO_INSTANCE_START:
+        service = nebius.compute().InstanceServiceClient(nebius.sdk())
+        instance = service.get(nebius.compute().GetInstanceRequest(
+            id=instance_id,)).wait()
+        if instance.status.state.name == 'RUNNING':
+            break
+        time.sleep(POLL_INTERVAL)
+        logger.debug(f'Waiting for instance {instance_id} starting.')
+        retry_count += 1
+
+    if retry_count == nebius.MAX_RETRIES_TO_INSTANCE_START:
+        raise TimeoutError(
+            f'Exceeded maximum retries '
+            f'({nebius.MAX_RETRIES_TO_INSTANCE_START * POLL_INTERVAL}'
+            f' seconds) while waiting for instance {instance_id}'
+            f' to be ready.')
+
+
+def launch(cluster_name_on_cloud: str, node_type: str, platform: str,
+           preset: str, region: str, image_family: str, disk_size: int,
+           user_data: str) -> str:
+    # Each node must have a unique name to avoid conflicts between
+    # multiple worker VMs. To ensure uniqueness, a UUID is appended
+    # to the node name.
+    instance_name = (f'{cluster_name_on_cloud}-'
+                     f'{uuid.uuid4().hex[:4]}-{node_type}')
+    logger.debug(f'Launching instance: {instance_name}')
+
+    disk_name = 'disk-' + instance_name
+    cluster_id = None
+    # 8 GPU virtual machines can be grouped into a GPU cluster.
+    # The GPU clusters are built with InfiniBand secure high-speed networking.
+    # https://docs.nebius.com/compute/clusters/gpu
+    if platform in ('gpu-h100-sxm', 'gpu-h200-sxm'):
+        if preset == '8gpu-128vcpu-1600gb':
+            cluster_id = get_or_create_gpu_cluster(cluster_name_on_cloud,
+                                                   region)
+
+    project_id = get_project_by_region(region)
+    service = nebius.compute().DiskServiceClient(nebius.sdk())
+    disk = service.create(nebius.compute().CreateDiskRequest(
+        metadata=nebius.nebius_common().ResourceMetadata(
+            parent_id=project_id,
+            name=disk_name,
+        ),
+        spec=nebius.compute().DiskSpec(
+            source_image_family=nebius.compute().SourceImageFamily(
+                image_family=image_family),
+            size_gibibytes=disk_size,
+            type=nebius.compute().DiskSpec.DiskType.NETWORK_SSD,
+        ))).wait()
+    disk_id = disk.resource_id
+    retry_count = 0
+    while retry_count < nebius.MAX_RETRIES_TO_DISK_CREATE:
+        disk = service.get_by_name(nebius.nebius_common().GetByNameRequest(
+            parent_id=project_id,
+            name=disk_name,
+        )).wait()
+        if disk.status.state.name == 'READY':
+            break
+        logger.debug(f'Waiting for disk {disk_name} to be ready.')
+        time.sleep(POLL_INTERVAL)
+        retry_count += 1
+
+    if retry_count == nebius.MAX_RETRIES_TO_DISK_CREATE:
+        raise TimeoutError(
+            f'Exceeded maximum retries '
+            f'({nebius.MAX_RETRIES_TO_DISK_CREATE * POLL_INTERVAL}'
+            f' seconds) while waiting for disk {disk_name}'
+            f' to be ready.')
+
+    service = nebius.vpc().SubnetServiceClient(nebius.sdk())
+    sub_net = service.list(nebius.vpc().ListSubnetsRequest(
+        parent_id=project_id,)).wait()
+
+    service = nebius.compute().InstanceServiceClient(nebius.sdk())
+    service.create(nebius.compute().CreateInstanceRequest(
+        metadata=nebius.nebius_common().ResourceMetadata(
+            parent_id=project_id,
+            name=instance_name,
+        ),
+        spec=nebius.compute().InstanceSpec(
+            gpu_cluster=nebius.compute().InstanceGpuClusterSpec(id=cluster_id,)
+            if cluster_id is not None else None,
+            boot_disk=nebius.compute().AttachedDiskSpec(
+                attach_mode=nebius.compute(
+                ).AttachedDiskSpec.AttachMode.READ_WRITE,
+                existing_disk=nebius.compute().ExistingDisk(id=disk_id)),
+            cloud_init_user_data=user_data,
+            resources=nebius.compute().ResourcesSpec(platform=platform,
+                                                     preset=preset),
+            network_interfaces=[
+                nebius.compute().NetworkInterfaceSpec(
+                    subnet_id=sub_net.items[0].metadata.id,
+                    ip_address=nebius.compute().IPAddress(),
+                    name='network-interface-0',
+                    public_ip_address=nebius.compute().PublicIPAddress())
+            ]))).wait()
+    instance_id = ''
+    retry_count = 0
+    while retry_count < nebius.MAX_RETRIES_TO_INSTANCE_READY:
+        service = nebius.compute().InstanceServiceClient(nebius.sdk())
+        instance = service.get_by_name(nebius.nebius_common().GetByNameRequest(
+            parent_id=project_id,
+            name=instance_name,
+        )).wait()
+        if instance.status.state.name == 'STARTING':
+            instance_id = instance.metadata.id
+            break
+        time.sleep(POLL_INTERVAL)
+        logger.debug(f'Waiting for instance {instance_name} start running.')
+        retry_count += 1
+
+    if retry_count == nebius.MAX_RETRIES_TO_INSTANCE_READY:
+        raise TimeoutError(
+            f'Exceeded maximum retries '
+            f'({nebius.MAX_RETRIES_TO_INSTANCE_READY * POLL_INTERVAL}'
+            f' seconds) while waiting for instance {instance_name}'
+            f' to be ready.')
+    return instance_id
+
+
+def remove(instance_id: str) -> None:
+    """Terminates the given instance."""
+    service = nebius.compute().InstanceServiceClient(nebius.sdk())
+    result = service.get(
+        nebius.compute().GetInstanceRequest(id=instance_id)).wait()
+    disk_id = result.spec.boot_disk.existing_disk.id
+    service.delete(
+        nebius.compute().DeleteInstanceRequest(id=instance_id)).wait()
+    retry_count = 0
+    # The instance begins deleting and attempts to delete the disk.
+    # Must wait until the disk is unlocked and becomes deletable.
+    while retry_count < nebius.MAX_RETRIES_TO_DISK_DELETE:
+        try:
+            service = nebius.compute().DiskServiceClient(nebius.sdk())
+            service.delete(
+                nebius.compute().DeleteDiskRequest(id=disk_id)).wait()
+            break
+        except nebius.request_error():
+            logger.debug('Waiting for disk deletion.')
+            time.sleep(POLL_INTERVAL)
+            retry_count += 1
+
+    if retry_count == nebius.MAX_RETRIES_TO_DISK_DELETE:
+        raise TimeoutError(
+            f'Exceeded maximum retries '
+            f'({nebius.MAX_RETRIES_TO_DISK_DELETE * POLL_INTERVAL}'
+            f' seconds) while waiting for disk {disk_id}'
+            f' to be deleted.')
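All of the lifecycle helpers in this new module (`stop`, `start`, `launch`, `remove`) share one poll-until-state loop: issue the request, then re-read the resource every `POLL_INTERVAL` seconds up to a retry cap, and raise `TimeoutError` when the cap is hit. A minimal sketch of that pattern; the names `wait_until`, `fetch_state`, `max_retries`, and the demo instance are illustrative, not part of the package:

```python
import time
from typing import Callable


def wait_until(fetch_state: Callable[[], str], target: str, resource: str,
               max_retries: int = 120, poll_interval: float = 5) -> None:
    """Polls fetch_state() until it returns `target`, else raises TimeoutError."""
    for _ in range(max_retries):
        if fetch_state() == target:
            return
        time.sleep(poll_interval)
    raise TimeoutError(f'Exceeded maximum retries '
                       f'({max_retries * poll_interval} seconds) while '
                       f'waiting for {resource} to reach {target!r}.')


# Example: a fake instance that reports RUNNING on the third poll.
states = iter(['STARTING', 'STARTING', 'RUNNING'])
wait_until(lambda: next(states), 'RUNNING', 'instance demo-0001',
           poll_interval=0)
```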
sky/server/common.py
CHANGED
@@ -15,7 +15,6 @@ import uuid
 
 import colorama
 import filelock
-import psutil
 import pydantic
 import requests
 
@@ -146,13 +145,14 @@ def get_api_server_status(endpoint: Optional[str] = None) -> ApiServerInfo:
     return ApiServerInfo(status=ApiServerStatus.UNHEALTHY, api_version=None)
 
 
-def start_uvicorn_in_background(deploy: bool = False, host: str = '127.0.0.1'):
+def start_api_server_in_background(deploy: bool = False,
+                                   host: str = '127.0.0.1'):
     if not is_api_server_local():
         raise RuntimeError(
             f'Cannot start API server: {get_server_url()} is not a local URL')
 
     # Check available memory before starting the server.
-    avail_mem_size_gb: float =
+    avail_mem_size_gb: float = common_utils.get_mem_size_gb()
     if avail_mem_size_gb <= server_constants.MIN_AVAIL_MEM_GB:
         logger.warning(
             f'{colorama.Fore.YELLOW}Your SkyPilot API server machine only has '
@@ -163,8 +163,6 @@ def start_uvicorn_in_background(deploy: bool = False, host: str = '127.0.0.1'):
     log_path = os.path.expanduser(constants.API_SERVER_LOGS)
     os.makedirs(os.path.dirname(log_path), exist_ok=True)
 
-    # The command to run uvicorn. Adjust the app:app to your application's
-    # location.
     api_server_cmd = API_SERVER_CMD
     if deploy:
         api_server_cmd += ' --deploy'
@@ -172,7 +170,7 @@ def start_uvicorn_in_background(deploy: bool = False, host: str = '127.0.0.1'):
         api_server_cmd += f' --host {host}'
     cmd = f'{sys.executable} {api_server_cmd} > {log_path} 2>&1'
 
-    # Start the
+    # Start the API server process in the background and don't wait for it.
     # If this is called from a CLI invocation, we need start_new_session=True so
     # that SIGINT on the CLI will not also kill the API server.
     subprocess.Popen(cmd, shell=True, start_new_session=True)
@@ -232,7 +230,7 @@ def _start_api_server(deploy: bool = False, host: str = '127.0.0.1'):
                     f'SkyPilot API server at {server_url}. '
                     'Starting a local server.'
                     f'{colorama.Style.RESET_ALL}')
-        start_uvicorn_in_background(deploy=deploy, host=host)
+        start_api_server_in_background(deploy=deploy, host=host)
         logger.info(ux_utils.finishing_message('SkyPilot API server started.'))
 
 
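The `start_new_session=True` argument in the hunk above is what the retained comment describes: the child is placed in its own session, outside the CLI's process group, so a Ctrl-C (SIGINT) delivered to the CLI does not also terminate the background API server. A small sketch of the same detachment pattern; the command and log path below are placeholders, not SkyPilot's actual server command:

```python
import subprocess
import sys

# Placeholder command; SkyPilot substitutes its API server command here.
cmd = f'{sys.executable} -m http.server 8080 > /tmp/demo-server.log 2>&1'

# start_new_session=True detaches the child into its own session, so SIGINT
# sent to the launching shell's process group does not reach it.
proc = subprocess.Popen(cmd, shell=True, start_new_session=True)
print(f'Started background server with PID {proc.pid}')
```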
sky/server/requests/executor.py
CHANGED
@@ -32,7 +32,6 @@ import traceback
 import typing
 from typing import Any, Callable, Generator, List, Optional, TextIO, Tuple
 
-import psutil
 import setproctitle
 
 from sky import global_user_state
@@ -70,18 +69,36 @@ logger = sky_logging.init_logger(__name__)
 # platforms, including macOS.
 multiprocessing.set_start_method('spawn', force=True)
 
-# Constants based on profiling the peak memory usage
-#
-#
-
-
-#
-
-
-#
+# Constants based on profiling the peak memory usage while serving various
+# sky commands. These estimation are highly related to usage patterns
+# (clouds enabled, type of requests, etc. see `tests/load_tests` for details.),
+# the profiling covers major clouds and common usage patterns. For user has
+# deviated usage pattern, they can override the default estimation by
+# environment variables.
+# NOTE(dev): update these constants for each release according to the load
+# test results.
+# TODO(aylei): maintaining these constants is error-prone, we may need to
+# automatically tune parallelism at runtime according to system usage stats
+# in the future.
+_LONG_WORKER_MEM_GB = 0.4
+_SHORT_WORKER_MEM_GB = 0.25
+# To control the number of long workers.
+_CPU_MULTIPLIER_FOR_LONG_WORKERS = 2
+# Limit the number of long workers of local API server, since local server is
+# typically:
+# 1. launched automatically in an environment with high resource contention
+# (e.g. Laptop)
+# 2. used by a single user
+_MAX_LONG_WORKERS_LOCAL = 4
+# Percentage of memory for long requests
 # from the memory reserved for SkyPilot.
-# This is to reserve some memory for
+# This is to reserve some memory for short requests.
 _MAX_MEM_PERCENT_FOR_BLOCKING = 0.6
+# Minimal number of long workers to ensure responsiveness.
+_MIN_LONG_WORKERS = 1
+# Minimal number of short workers, there is a daemon task running on short
+# workers so at least 2 workers are needed to ensure responsiveness.
+_MIN_SHORT_WORKERS = 2
 
 
 class QueueBackend(enum.Enum):
@@ -301,34 +318,32 @@ def schedule_request(request_id: str,
     _get_queue(schedule_type).put(input_tuple)
 
 
+def executor_initializer(proc_group: str):
+    setproctitle.setproctitle(f'SkyPilot:executor:{proc_group}:'
+                              f'{multiprocessing.current_process().pid}')
+
+
 def request_worker(worker: RequestWorker, max_parallel_size: int) -> None:
     """Worker for the requests.
 
     Args:
         max_parallel_size: Maximum number of parallel jobs this worker can run.
     """
-
-
-    setproctitle.setproctitle(
-        f'SkyPilot:worker:{worker.schedule_type.value}-{worker.id}')
+    proc_group = f'{worker.schedule_type.value}-{worker.id}'
+    setproctitle.setproctitle(f'SkyPilot:worker:{proc_group}')
     queue = _get_queue(worker.schedule_type)
-
-
-
-    # We use executor instead of individual multiprocessing.Process to avoid
-    # the overhead of forking a new process for each request, which can be about
-    # 1s delay.
-    with concurrent.futures.ProcessPoolExecutor(
-            max_workers=max_parallel_size) as executor:
-        while True:
+
+    def process_request(executor: concurrent.futures.ProcessPoolExecutor):
+        try:
             request_element = queue.get()
             if request_element is None:
                 time.sleep(0.1)
-
+                return
             request_id, ignore_return_value = request_element
             request = api_requests.get_request(request_id)
+            assert request is not None, f'Request with ID {request_id} is None'
             if request.status == api_requests.RequestStatus.CANCELLED:
-
+                return
             logger.info(f'[{worker}] Submitting request: {request_id}')
             # Start additional process to run the request, so that it can be
            # cancelled when requested by a user.
@@ -347,60 +362,49 @@ def request_worker(worker: RequestWorker, max_parallel_size: int) -> None:
                 logger.info(f'[{worker}] Finished request: {request_id}')
             else:
                 logger.info(f'[{worker}] Submitted request: {request_id}')
+        except KeyboardInterrupt:
+            # Interrupt the worker process will stop request execution, but
+            # the SIGTERM request should be respected anyway since it might
+            # be explicitly sent by user.
+            # TODO(aylei): crash the API server or recreate the worker process
+            # to avoid broken state.
+            logger.error(f'[{worker}] Worker process interrupted')
+            raise
+        except (Exception, SystemExit) as e:  # pylint: disable=broad-except
+            # Catch any other exceptions to avoid crashing the worker process.
+            logger.error(
+                f'[{worker}] Error processing request {request_id}: '
+                f'{common_utils.format_exception(e, use_bracket=True)}')
 
-
-
-
-
-
-
-
-
-
-
-
-
-            with ux_utils.print_exception_no_traceback():
-                raise ValueError(
-                    f'Failed to parse the number of CPUs from {cpu_count}'
-                ) from e
-    return psutil.cpu_count()
-
-
-def _get_mem_size_gb() -> float:
-    """Get the memory size in GB.
-
-    If the API server is deployed as a pod in k8s cluster, we assume the
-    memory size is provided by the downward API.
-    """
-    mem_size = os.getenv('SKYPILOT_POD_MEMORY_GB_LIMIT')
-    if mem_size is not None:
-        try:
-            return float(mem_size)
-        except ValueError as e:
-            with ux_utils.print_exception_no_traceback():
-                raise ValueError(
-                    f'Failed to parse the memory size from {mem_size}') from e
-    return psutil.virtual_memory().total / (1024**3)
+    # Use concurrent.futures.ProcessPoolExecutor instead of multiprocessing.Pool
+    # because the former is more efficient with the support of lazy creation of
+    # worker processes.
+    # We use executor instead of individual multiprocessing.Process to avoid
+    # the overhead of forking a new process for each request, which can be about
+    # 1s delay.
+    with concurrent.futures.ProcessPoolExecutor(
+            max_workers=max_parallel_size,
+            initializer=executor_initializer,
+            initargs=(proc_group,)) as executor:
+        while True:
+            process_request(executor)
 
 
 def start(deploy: bool) -> List[multiprocessing.Process]:
     """Start the request workers."""
     # Determine the job capacity of the workers based on the system resources.
-    cpu_count =
-    mem_size_gb =
+    cpu_count = common_utils.get_cpu_count()
+    mem_size_gb = common_utils.get_mem_size_gb()
     mem_size_gb = max(0, mem_size_gb - server_constants.MIN_AVAIL_MEM_GB)
-
-
-
-
-
-    max_parallel_for_non_blocking = _max_parallel_size_for_non_blocking(
-        mem_size_gb, parallel_for_blocking)
+    max_parallel_for_long = _max_long_worker_parallism(cpu_count,
+                                                       mem_size_gb,
+                                                       local=not deploy)
+    max_parallel_for_short = _max_short_worker_parallism(
+        mem_size_gb, max_parallel_for_long)
     logger.info(
-        f'SkyPilot API server will start {
-        f'
-        f'{
+        f'SkyPilot API server will start {max_parallel_for_long} workers for '
+        f'long requests and will allow at max '
+        f'{max_parallel_for_short} short requests in parallel.')
 
     # Setup the queues.
     if queue_backend == QueueBackend.MULTIPROCESSING:
@@ -424,7 +428,7 @@ def start(deploy: bool) -> List[multiprocessing.Process]:
     logger.info('Request queues created')
 
     worker_procs = []
-    for worker_id in range(
+    for worker_id in range(max_parallel_for_long):
         worker = RequestWorker(id=worker_id,
                                schedule_type=api_requests.ScheduleType.LONG)
         worker_proc = multiprocessing.Process(target=request_worker,
@@ -432,31 +436,34 @@ def start(deploy: bool) -> List[multiprocessing.Process]:
         worker_proc.start()
         worker_procs.append(worker_proc)
 
-    # Start a
+    # Start a worker for short requests.
     worker = RequestWorker(id=1, schedule_type=api_requests.ScheduleType.SHORT)
     worker_proc = multiprocessing.Process(target=request_worker,
-                                          args=(worker,
-                                                max_parallel_for_non_blocking))
+                                          args=(worker, max_parallel_for_short))
     worker_proc.start()
     worker_procs.append(worker_proc)
     return worker_procs
 
 
 @annotations.lru_cache(scope='global', maxsize=1)
-def
-
-
+def _max_long_worker_parallism(cpu_count: int,
+                               mem_size_gb: float,
+                               local=False) -> int:
+    """Max parallelism for long workers."""
+    cpu_based_max_parallel = cpu_count * _CPU_MULTIPLIER_FOR_LONG_WORKERS
     mem_based_max_parallel = int(mem_size_gb * _MAX_MEM_PERCENT_FOR_BLOCKING /
-
-    n = max(
+                                 _LONG_WORKER_MEM_GB)
+    n = max(_MIN_LONG_WORKERS,
+            min(cpu_based_max_parallel, mem_based_max_parallel))
+    if local:
+        return min(n, _MAX_LONG_WORKERS_LOCAL)
     return n
 
 
 @annotations.lru_cache(scope='global', maxsize=1)
-def
-
-    """Max parallelism for
-    available_mem = mem_size_gb - (
-
-    n = max(1, int(available_mem / _PER_NON_BLOCKING_REQUEST_MEM_GB))
+def _max_short_worker_parallism(mem_size_gb: float,
+                                long_worker_parallism: int) -> int:
+    """Max parallelism for short workers."""
+    available_mem = mem_size_gb - (long_worker_parallism * _LONG_WORKER_MEM_GB)
+    n = max(_MIN_SHORT_WORKERS, int(available_mem / _SHORT_WORKER_MEM_GB))
     return n
sky/server/server.py
CHANGED
@@ -57,7 +57,9 @@ P = ParamSpec('P')
 
 def _add_timestamp_prefix_for_server_logs() -> None:
     server_logger = sky_logging.init_logger('sky.server')
-    #
+    # Clear existing handlers first to prevent duplicates
+    server_logger.handlers.clear()
+    # Disable propagation to avoid the root logger of SkyPilot being affected
     server_logger.propagate = False
     # Add date prefix to the log message printed by loggers under
     # server.
@@ -460,6 +462,7 @@ async def launch(launch_body: payloads.LaunchBody,
                  request: fastapi.Request) -> None:
     """Launches a cluster or task."""
     request_id = request.state.request_id
+    logger.info(f'Launching request: {request_id}')
     executor.schedule_request(
         request_id,
         request_name='launch',
@@ -627,6 +630,9 @@ async def logs(
         request_name='logs',
         request_body=cluster_job_body,
         func=core.tail_logs,
+        # TODO(aylei): We have tail logs scheduled as SHORT request, because it
+        # should be responsive. However, it can be long running if the user's
+        # job keeps running, and we should avoid it taking the SHORT worker.
         schedule_type=requests_lib.ScheduleType.SHORT,
         request_cluster_name=cluster_job_body.cluster_name,
     )
@@ -794,10 +800,9 @@ async def api_get(request_id: str) -> requests_lib.RequestPayload:
                     detail=dataclasses.asdict(
                         request_task.encode()))
             return request_task.encode()
-        #
-        #
-
-        await asyncio.sleep(0)
+        # yield control to allow other coroutines to run, sleep shortly
+        # to avoid storming the DB and CPU in the meantime
+        await asyncio.sleep(0.1)
 
 
 @app.get('/api/stream')