skypilot-nightly 1.0.0.dev20251210__py3-none-any.whl → 1.0.0.dev20260112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +4 -2
- sky/adaptors/slurm.py +159 -72
- sky/backends/backend_utils.py +52 -10
- sky/backends/cloud_vm_ray_backend.py +192 -32
- sky/backends/task_codegen.py +40 -2
- sky/catalog/data_fetchers/fetch_gcp.py +9 -1
- sky/catalog/data_fetchers/fetch_nebius.py +1 -1
- sky/catalog/data_fetchers/fetch_vast.py +4 -2
- sky/catalog/seeweb_catalog.py +30 -15
- sky/catalog/shadeform_catalog.py +5 -2
- sky/catalog/slurm_catalog.py +0 -7
- sky/catalog/vast_catalog.py +30 -6
- sky/check.py +11 -8
- sky/client/cli/command.py +106 -54
- sky/client/interactive_utils.py +190 -0
- sky/client/sdk.py +8 -0
- sky/client/sdk_async.py +9 -0
- sky/clouds/aws.py +60 -2
- sky/clouds/azure.py +2 -0
- sky/clouds/kubernetes.py +2 -0
- sky/clouds/runpod.py +38 -7
- sky/clouds/slurm.py +44 -12
- sky/clouds/ssh.py +1 -1
- sky/clouds/vast.py +30 -17
- sky/core.py +69 -1
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/3nu-b8raeKRNABZ2d4GAG/_buildManifest.js +1 -0
- sky/dashboard/out/_next/static/chunks/1871-0565f8975a7dcd10.js +6 -0
- sky/dashboard/out/_next/static/chunks/2109-55a1546d793574a7.js +11 -0
- sky/dashboard/out/_next/static/chunks/2521-099b07cd9e4745bf.js +26 -0
- sky/dashboard/out/_next/static/chunks/2755.a636e04a928a700e.js +31 -0
- sky/dashboard/out/_next/static/chunks/3495.05eab4862217c1a5.js +6 -0
- sky/dashboard/out/_next/static/chunks/3785.cfc5dcc9434fd98c.js +1 -0
- sky/dashboard/out/_next/static/chunks/3981.645d01bf9c8cad0c.js +21 -0
- sky/dashboard/out/_next/static/chunks/4083-0115d67c1fb57d6c.js +21 -0
- sky/dashboard/out/_next/static/chunks/{8640.5b9475a2d18c5416.js → 429.a58e9ba9742309ed.js} +2 -2
- sky/dashboard/out/_next/static/chunks/4555.8e221537181b5dc1.js +6 -0
- sky/dashboard/out/_next/static/chunks/4725.937865b81fdaaebb.js +6 -0
- sky/dashboard/out/_next/static/chunks/6082-edabd8f6092300ce.js +25 -0
- sky/dashboard/out/_next/static/chunks/6989-49cb7dca83a7a62d.js +1 -0
- sky/dashboard/out/_next/static/chunks/6990-630bd2a2257275f8.js +1 -0
- sky/dashboard/out/_next/static/chunks/7248-a99800d4db8edabd.js +1 -0
- sky/dashboard/out/_next/static/chunks/754-cfc5d4ad1b843d29.js +18 -0
- sky/dashboard/out/_next/static/chunks/8050-dd8aa107b17dce00.js +16 -0
- sky/dashboard/out/_next/static/chunks/8056-d4ae1e0cb81e7368.js +1 -0
- sky/dashboard/out/_next/static/chunks/8555.011023e296c127b3.js +6 -0
- sky/dashboard/out/_next/static/chunks/8821-93c25df904a8362b.js +1 -0
- sky/dashboard/out/_next/static/chunks/8969-0662594b69432ade.js +1 -0
- sky/dashboard/out/_next/static/chunks/9025.f15c91c97d124a5f.js +6 -0
- sky/dashboard/out/_next/static/chunks/{9353-8369df1cf105221c.js → 9353-7ad6bd01858556f1.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/_app-5a86569acad99764.js +34 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-8297476714acb4ac.js +6 -0
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-337c3ba1085f1210.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/{clusters-9e5d47818b9bdadd.js → clusters-57632ff3684a8b5c.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-5fd3a453c079c2ea.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/infra-9f85c02c9c6cae9e.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-90f16972cbecf354.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-2dd42fc37aad427a.js +16 -0
- sky/dashboard/out/_next/static/chunks/pages/jobs-ed806aeace26b972.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/users-bec34706b36f3524.js +1 -0
- sky/dashboard/out/_next/static/chunks/pages/{volumes-ef19d49c6d0e8500.js → volumes-a83ba9b38dff7ea9.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-96e0f298308da7e2.js → [name]-c781e9c3e52ef9fc.js} +1 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-91e0942f47310aae.js +1 -0
- sky/dashboard/out/_next/static/chunks/webpack-cfe59cf684ee13b9.js +1 -0
- sky/dashboard/out/_next/static/css/b0dbca28f027cc19.css +3 -0
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/plugins/[...slug].html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/volumes.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/data/data_utils.py +26 -12
- sky/data/mounting_utils.py +29 -4
- sky/global_user_state.py +108 -16
- sky/jobs/client/sdk.py +8 -3
- sky/jobs/controller.py +191 -31
- sky/jobs/recovery_strategy.py +109 -11
- sky/jobs/server/core.py +81 -4
- sky/jobs/server/server.py +14 -0
- sky/jobs/state.py +417 -19
- sky/jobs/utils.py +73 -80
- sky/models.py +9 -0
- sky/optimizer.py +2 -1
- sky/provision/__init__.py +11 -9
- sky/provision/kubernetes/utils.py +122 -15
- sky/provision/kubernetes/volume.py +52 -17
- sky/provision/provisioner.py +2 -1
- sky/provision/runpod/instance.py +3 -1
- sky/provision/runpod/utils.py +13 -1
- sky/provision/runpod/volume.py +25 -9
- sky/provision/slurm/instance.py +75 -29
- sky/provision/slurm/utils.py +213 -107
- sky/provision/vast/utils.py +1 -0
- sky/resources.py +135 -13
- sky/schemas/api/responses.py +4 -0
- sky/schemas/db/global_user_state/010_save_ssh_key.py +1 -1
- sky/schemas/db/spot_jobs/008_add_full_resources.py +34 -0
- sky/schemas/db/spot_jobs/009_job_events.py +32 -0
- sky/schemas/db/spot_jobs/010_job_events_timestamp_with_timezone.py +43 -0
- sky/schemas/db/spot_jobs/011_add_links.py +34 -0
- sky/schemas/generated/jobsv1_pb2.py +9 -5
- sky/schemas/generated/jobsv1_pb2.pyi +12 -0
- sky/schemas/generated/jobsv1_pb2_grpc.py +44 -0
- sky/schemas/generated/managed_jobsv1_pb2.py +32 -28
- sky/schemas/generated/managed_jobsv1_pb2.pyi +11 -2
- sky/serve/serve_utils.py +232 -40
- sky/server/common.py +17 -0
- sky/server/constants.py +1 -1
- sky/server/metrics.py +6 -3
- sky/server/plugins.py +16 -0
- sky/server/requests/payloads.py +18 -0
- sky/server/requests/request_names.py +2 -0
- sky/server/requests/requests.py +28 -10
- sky/server/requests/serializers/encoders.py +5 -0
- sky/server/requests/serializers/return_value_serializers.py +14 -4
- sky/server/server.py +434 -107
- sky/server/uvicorn.py +5 -0
- sky/setup_files/MANIFEST.in +1 -0
- sky/setup_files/dependencies.py +21 -10
- sky/sky_logging.py +2 -1
- sky/skylet/constants.py +22 -5
- sky/skylet/executor/slurm.py +4 -6
- sky/skylet/job_lib.py +89 -4
- sky/skylet/services.py +18 -3
- sky/ssh_node_pools/deploy/tunnel/cleanup-tunnel.sh +62 -0
- sky/ssh_node_pools/deploy/tunnel/ssh-tunnel.sh +379 -0
- sky/templates/kubernetes-ray.yml.j2 +4 -6
- sky/templates/slurm-ray.yml.j2 +32 -2
- sky/templates/websocket_proxy.py +18 -41
- sky/users/permission.py +61 -51
- sky/utils/auth_utils.py +42 -0
- sky/utils/cli_utils/status_utils.py +19 -5
- sky/utils/cluster_utils.py +10 -3
- sky/utils/command_runner.py +256 -94
- sky/utils/command_runner.pyi +16 -0
- sky/utils/common_utils.py +30 -29
- sky/utils/context.py +32 -0
- sky/utils/db/db_utils.py +36 -6
- sky/utils/db/migration_utils.py +41 -21
- sky/utils/infra_utils.py +5 -1
- sky/utils/instance_links.py +139 -0
- sky/utils/interactive_utils.py +49 -0
- sky/utils/kubernetes/generate_kubeconfig.sh +42 -33
- sky/utils/kubernetes/rsync_helper.sh +5 -1
- sky/utils/plugin_extensions/__init__.py +14 -0
- sky/utils/plugin_extensions/external_failure_source.py +176 -0
- sky/utils/resources_utils.py +10 -8
- sky/utils/rich_utils.py +9 -11
- sky/utils/schemas.py +63 -20
- sky/utils/status_lib.py +7 -0
- sky/utils/subprocess_utils.py +17 -0
- sky/volumes/client/sdk.py +6 -3
- sky/volumes/server/core.py +65 -27
- sky_templates/ray/start_cluster +8 -4
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/METADATA +53 -57
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/RECORD +172 -162
- sky/dashboard/out/_next/static/KYAhEFa3FTfq4JyKVgo-s/_buildManifest.js +0 -1
- sky/dashboard/out/_next/static/chunks/1141-9c810f01ff4f398a.js +0 -11
- sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +0 -6
- sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +0 -1
- sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +0 -1
- sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +0 -15
- sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +0 -26
- sky/dashboard/out/_next/static/chunks/3294.ddda8c6c6f9f24dc.js +0 -1
- sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +0 -1
- sky/dashboard/out/_next/static/chunks/3800-b589397dc09c5b4e.js +0 -1
- sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +0 -1
- sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +0 -15
- sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +0 -13
- sky/dashboard/out/_next/static/chunks/6856-da20c5fd999f319c.js +0 -1
- sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +0 -1
- sky/dashboard/out/_next/static/chunks/6990-09cbf02d3cd518c3.js +0 -1
- sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +0 -30
- sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +0 -41
- sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +0 -1
- sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +0 -1
- sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +0 -6
- sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +0 -31
- sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +0 -30
- sky/dashboard/out/_next/static/chunks/pages/_app-68b647e26f9d2793.js +0 -34
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-33f525539665fdfd.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-a7565f586ef86467.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/infra/[context]-12c559ec4d81fdbd.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/infra-d187cd0413d72475.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-895847b6cf200b04.js +0 -16
- sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-8d0f4655400b4eb9.js +0 -21
- sky/dashboard/out/_next/static/chunks/pages/jobs-e5a98f17f8513a96.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/users-2f7646eb77785a2c.js +0 -1
- sky/dashboard/out/_next/static/chunks/pages/workspaces-cb4da3abe08ebf19.js +0 -1
- sky/dashboard/out/_next/static/chunks/webpack-fba3de387ff6bb08.js +0 -1
- sky/dashboard/out/_next/static/css/c5a4cfd2600fc715.css +0 -3
- /sky/dashboard/out/_next/static/{KYAhEFa3FTfq4JyKVgo-s → 3nu-b8raeKRNABZ2d4GAG}/_ssgManifest.js +0 -0
- /sky/dashboard/out/_next/static/chunks/pages/plugins/{[...slug]-4f46050ca065d8f8.js → [...slug]-449a9f5a3bb20fb3.js} +0 -0
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20251210.dist-info → skypilot_nightly-1.0.0.dev20260112.dist-info}/top_level.txt +0 -0
sky/provision/slurm/utils.py
CHANGED
@@ -1,7 +1,10 @@
 """Slurm utilities for SkyPilot."""
+import json
 import math
 import os
 import re
+import shlex
+import time
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from paramiko.config import SSHConfig
@@ -9,15 +12,39 @@ from paramiko.config import SSHConfig
 from sky import exceptions
 from sky import sky_logging
 from sky.adaptors import slurm
+from sky.skylet import constants
 from sky.utils import annotations
 from sky.utils import common_utils
+from sky.utils.db import kv_cache
 
 logger = sky_logging.init_logger(__name__)
 
-# TODO(jwj): Choose commonly used default values.
 DEFAULT_SLURM_PATH = '~/.slurm/config'
-
-
+SLURM_MARKER_FILE = '.sky_slurm_cluster'
+
+# Regex pattern for parsing GPU GRES strings.
+# Format: 'gpu[:acc_type]:acc_count(optional_extra_info)'
+# Examples: 'gpu:8', 'gpu:H100:8', 'gpu:nvidia_h100_80gb_hbm3:8(S:0-1)'
+_GRES_GPU_PATTERN = re.compile(r'\bgpu:(?:(?P<type>[^:(]+):)?(?P<count>\d+)',
+                               re.IGNORECASE)
+
+_SLURM_NODES_INFO_CACHE_TTL = 30 * 60
+
+
+def get_gpu_type_and_count(gres_str: str) -> Tuple[Optional[str], int]:
+    """Parses GPU type and count from a GRES string.
+
+    Returns:
+        A tuple of (GPU type, GPU count). If no GPU is found, returns (None, 0).
+    """
+    match = _GRES_GPU_PATTERN.search(gres_str)
+    if not match:
+        return None, 0
+    return match.group('type'), int(match.group('count'))
+
+
+# SSH host key filename for sshd.
+SLURM_SSHD_HOST_KEY_FILENAME = 'skypilot_host_key'
 
 
 def get_slurm_ssh_config() -> SSHConfig:
@@ -27,6 +54,42 @@ def get_slurm_ssh_config() -> SSHConfig:
     return slurm_config
 
 
+@annotations.lru_cache(scope='request')
+def _get_slurm_nodes_info(cluster: str) -> List[slurm.NodeInfo]:
+    cache_key = f'slurm:nodes_info:{cluster}'
+    cached = kv_cache.get_cache_entry(cache_key)
+    if cached is not None:
+        logger.debug(f'Slurm nodes info found in cache ({cache_key})')
+        return [slurm.NodeInfo(**item) for item in json.loads(cached)]
+
+    ssh_config = get_slurm_ssh_config()
+    ssh_config_dict = ssh_config.lookup(cluster)
+    client = slurm.SlurmClient(
+        ssh_config_dict['hostname'],
+        int(ssh_config_dict.get('port', 22)),
+        ssh_config_dict['user'],
+        ssh_config_dict['identityfile'][0],
+        ssh_proxy_command=ssh_config_dict.get('proxycommand', None),
+        ssh_proxy_jump=ssh_config_dict.get('proxyjump', None),
+    )
+    nodes_info = client.info_nodes()
+
+    try:
+        # Nodes in a cluster are unlikely to change frequently, so cache
+        # the result for a short period of time.
+        kv_cache.add_or_update_cache_entry(
+            cache_key, json.dumps([n._asdict() for n in nodes_info]),
+            time.time() + _SLURM_NODES_INFO_CACHE_TTL)
+    except Exception as e:  # pylint: disable=broad-except
+        # Catch the error and continue.
+        # Failure to cache the result is not critical to the
+        # success of this function.
+        logger.debug(f'Failed to cache slurm nodes info for {cluster}: '
+                     f'{common_utils.format_exception(e)}')
+
+    return nodes_info
+
+
 class SlurmInstanceType:
     """Class to represent the "Instance Type" in a Slurm cluster.
 
@@ -170,35 +233,23 @@ def instance_id(job_id: str, node: str) -> str:
     return f'job{job_id}-{node}'
 
 
-def get_cluster_name_from_config(provider_config: Dict[str, Any]) -> str:
-    """Return the cluster name from the provider config.
-
-    The concept of cluster can be mapped to a cloud region.
-    """
-    return provider_config.get('cluster', DEFAULT_CLUSTER_NAME)
-
-
 def get_partition_from_config(provider_config: Dict[str, Any]) -> str:
     """Return the partition from the provider config.
 
     The concept of partition can be mapped to a cloud zone.
     """
-
+    partition = provider_config.get('partition')
+    if partition is None:
+        raise ValueError('Partition not specified in provider config.')
+    return partition
 
 
 @annotations.lru_cache(scope='request')
-def get_cluster_default_partition(cluster_name: str) -> str:
+def get_cluster_default_partition(cluster_name: str) -> Optional[str]:
     """Get the default partition for a Slurm cluster.
 
     Queries the Slurm cluster for the partition marked with an asterisk (*)
-    in sinfo output.
-    no default partition is found.
-
-    Args:
-        cluster_name: Name of the Slurm cluster.
-
-    Returns:
-        The default partition name for the cluster.
+    in sinfo output. If no default partition is marked, returns None.
     """
     try:
         ssh_config = get_slurm_ssh_config()
@@ -214,16 +265,10 @@ def get_cluster_default_partition(cluster_name: str) -> str:
             ssh_config_dict['user'],
             ssh_config_dict['identityfile'][0],
             ssh_proxy_command=ssh_config_dict.get('proxycommand', None),
+            ssh_proxy_jump=ssh_config_dict.get('proxyjump', None),
         )
 
-
-        if default_partition is None:
-            # TODO(kevin): Have a way to specify default partition in
-            # ~/.sky/config.yaml if needed, in case a Slurm cluster
-            # really does not have a default partition.
-            raise ValueError('No default partition found for cluster '
-                             f'{cluster_name}.')
-        return default_partition
+        return client.get_default_partition()
 
 
 def get_all_slurm_cluster_names() -> List[str]:
@@ -296,7 +341,7 @@ def check_instance_fits(
     """
    # Get Slurm node list in the given cluster (region).
     try:
-
+        nodes = _get_slurm_nodes_info(cluster)
     except FileNotFoundError:
         return (False, f'Could not query Slurm cluster {cluster} '
                 f'because the Slurm configuration file '
@@ -305,20 +350,13 @@ def check_instance_fits(
         return (False, f'Could not query Slurm cluster {cluster} '
                 f'because Slurm SSH configuration at {DEFAULT_SLURM_PATH} '
                 f'could not be loaded: {common_utils.format_exception(e)}.')
-    ssh_config_dict = ssh_config.lookup(cluster)
-
-    client = slurm.SlurmClient(
-        ssh_config_dict['hostname'],
-        int(ssh_config_dict.get('port', 22)),
-        ssh_config_dict['user'],
-        ssh_config_dict['identityfile'][0],
-        ssh_proxy_command=ssh_config_dict.get('proxycommand', None),
-    )
 
-    nodes = client.info_nodes()
     default_partition = get_cluster_default_partition(cluster)
 
     def is_default_partition(node_partition: str) -> bool:
+        if default_partition is None:
+            return False
+
         # info_nodes does not strip the '*' from the default partition name.
         # But non-default partition names can also end with '*',
         # so we need to check whether the partition name without the '*'
@@ -352,27 +390,18 @@ def check_instance_fits(
     assert acc_count is not None, (acc_type, acc_count)
 
     gpu_nodes = []
-    # GRES string format: 'gpu:acc_type:acc_count(optional_extra_info)'
-    # Examples:
-    # - gpu:nvidia_h100_80gb_hbm3:8(S:0-1)
-    # - gpu:a10g:8
-    # - gpu:l4:1
-    gres_pattern = re.compile(r'^gpu:([^:]+):(\d+)')
     for node_info in nodes:
-        gres_str = node_info.gres
         # Extract the GPU type and count from the GRES string
-
-
+        node_acc_type, node_acc_count = get_gpu_type_and_count(
+            node_info.gres)
+        if node_acc_type is None:
             continue
 
-        node_acc_type = match.group(1).lower()
-        node_acc_count = int(match.group(2))
-
         # TODO(jwj): Handle status check.
 
         # Check if the node has the requested GPU type and at least the
         # requested count
-        if (node_acc_type == acc_type.lower() and
+        if (node_acc_type.lower() == acc_type.lower() and
                 node_acc_count >= acc_count):
             gpu_nodes.append(node_info)
     if len(gpu_nodes) == 0:
@@ -394,6 +423,51 @@ def check_instance_fits(
     return fits, reason
 
 
+# GRES names are highly unlikely to change within a cluster.
+# TODO(kevin): Cache using sky/utils/db/kv_cache.py too.
+@annotations.lru_cache(scope='global', maxsize=10)
+def get_gres_gpu_type(cluster: str, requested_gpu_type: str) -> str:
+    """Get the actual GPU type as it appears in the cluster's GRES.
+
+    Args:
+        cluster: Name of the Slurm cluster.
+        requested_gpu_type: The GPU type requested by the user.
+
+    Returns:
+        The actual GPU type as it appears in the cluster's GRES string.
+        Falls back to the requested type if not found.
+    """
+    try:
+        ssh_config = get_slurm_ssh_config()
+        ssh_config_dict = ssh_config.lookup(cluster)
+        client = slurm.SlurmClient(
+            ssh_config_dict['hostname'],
+            int(ssh_config_dict.get('port', 22)),
+            ssh_config_dict['user'],
+            ssh_config_dict['identityfile'][0],
+            ssh_proxy_command=ssh_config_dict.get('proxycommand', None),
+            ssh_proxy_jump=ssh_config_dict.get('proxyjump', None),
+        )
+
+        nodes = client.info_nodes()
+
+        for node_info in nodes:
+            node_gpu_type, _ = get_gpu_type_and_count(node_info.gres)
+            if node_gpu_type is None:
+                continue
+            if node_gpu_type.lower() == requested_gpu_type.lower():
+                return node_gpu_type
+    except Exception as e:  # pylint: disable=broad-except
+        logger.warning(
+            'Failed to determine the exact GPU GRES type from the Slurm '
+            f'cluster {cluster!r}. Falling back to '
+            f'{requested_gpu_type.lower()!r}. This may cause issues if the '
+            f'casing is incorrect. Error: {common_utils.format_exception(e)}')
+
+    # GRES names are more commonly in lowercase from what we've seen so far.
+    return requested_gpu_type.lower()
+
+
 def _get_slurm_node_info_list(
         slurm_cluster_name: Optional[str] = None) -> List[Dict[str, Any]]:
     """Gathers detailed information about each node in the Slurm cluster.
@@ -423,6 +497,7 @@ def _get_slurm_node_info_list(
         slurm_config_dict['user'],
         slurm_config_dict['identityfile'][0],
         ssh_proxy_command=slurm_config_dict.get('proxycommand', None),
+        ssh_proxy_jump=slurm_config_dict.get('proxyjump', None),
     )
     node_infos = slurm_client.info_nodes()
 
@@ -434,8 +509,8 @@ def _get_slurm_node_info_list(
 
     # 2. Process each node, aggregating partitions per node
     slurm_nodes_info: Dict[str, Dict[str, Any]] = {}
-    gres_gpu_pattern = re.compile(r'((gpu)(?::([^:]+))?:(\d+))')
 
+    nodes_to_jobs_gres = slurm_client.get_all_jobs_gres()
     for node_info in node_infos:
         node_name = node_info.node
         state = node_info.state
@@ -447,43 +522,27 @@ def _get_slurm_node_info_list(
             continue
 
         # Extract GPU info from GRES
-
-
-
-
-
-
-
-
-                gpu_type_from_sinfo = gres_match.group(3).upper()
-            # If total_gpus > 0 but no type, default to 'GPU'
-            elif total_gpus > 0:
-                gpu_type_from_sinfo = 'GPU'
-        except ValueError:
-            logger.warning(
-                f'Could not parse GPU count from GRES for {node_name}.')
-
-        # Get allocated GPUs via squeue
+        node_gpu_type, total_gpus = get_gpu_type_and_count(gres_str)
+        if total_gpus > 0:
+            if node_gpu_type is not None:
+                node_gpu_type = node_gpu_type.upper()
+            else:
+                node_gpu_type = 'GPU'
+
+        # Get allocated GPUs
         allocated_gpus = 0
         # TODO(zhwu): move to enum
         if state in ('alloc', 'mix', 'drain', 'drng', 'drained', 'resv',
                      'comp'):
-
-
-
-
-
-
-
-
-
-            if state == 'alloc':
-                # We can infer allocated GPUs only if the node is
-                # in 'alloc' state.
-                allocated_gpus = total_gpus
-            else:
-                # Otherwise, just raise the error.
-                raise e
+            jobs_gres = nodes_to_jobs_gres.get(node_name, [])
+            if jobs_gres:
+                for job_line in jobs_gres:
+                    _, job_gpu_count = get_gpu_type_and_count(job_line)
+                    allocated_gpus += job_gpu_count
+            elif state == 'alloc':
+                # If no GRES info found but node is fully allocated,
+                # assume all GPUs are in use.
+                allocated_gpus = total_gpus
         elif state == 'idle':
            allocated_gpus = 0
 
@@ -493,27 +552,16 @@ def _get_slurm_node_info_list(
                      'maint') else 0
         free_gpus = max(0, free_gpus)
 
-        # Get CPU/Mem info via scontrol
-        vcpu_total = 0
-        mem_gb = 0.0
-        try:
-            node_details = slurm_client.node_details(node_name)
-            vcpu_total = int(node_details.get('CPUTot', '0'))
-            mem_gb = float(node_details.get('RealMemory', '0')) / 1024.0
-        except Exception as e:  # pylint: disable=broad-except
-            logger.warning(
-                f'Failed to get CPU/memory info for {node_name}: {e}')
-
         slurm_nodes_info[node_name] = {
             'node_name': node_name,
             'slurm_cluster_name': slurm_cluster_name,
             'partitions': [partition],
             'node_state': state,
-            'gpu_type':
+            'gpu_type': node_gpu_type,
             'total_gpus': total_gpus,
             'free_gpus': free_gpus,
-            'vcpu_count':
-            'memory_gb': round(
+            'vcpu_count': node_info.cpus,
+            'memory_gb': round(node_info.memory_gb, 2),
         }
 
     for node_info in slurm_nodes_info.values():
@@ -539,10 +587,15 @@ def slurm_node_info(
     return node_list
 
 
-def
-
+def is_inside_slurm_cluster() -> bool:
+    # Check for the marker file in the current home directory. When run by
+    # the skylet on a compute node, the HOME environment variable is set to
+    # the cluster's sky home directory by the SlurmCommandRunner.
+    marker_file = os.path.join(os.path.expanduser('~'), SLURM_MARKER_FILE)
+    return os.path.exists(marker_file)
 
 
+@annotations.lru_cache(scope='request')
 def get_partitions(cluster_name: str) -> List[str]:
     """Get unique partition names available in a Slurm cluster.
 
@@ -565,6 +618,7 @@ def get_partitions(cluster_name: str) -> List[str]:
             slurm_config_dict['user'],
             slurm_config_dict['identityfile'][0],
             ssh_proxy_command=slurm_config_dict.get('proxycommand', None),
+            ssh_proxy_jump=slurm_config_dict.get('proxyjump', None),
        )
 
         partitions_info = client.get_partitions_info()
@@ -577,7 +631,59 @@ def get_partitions(cluster_name: str) -> List[str]:
             other_partitions.append(partition.name)
         return default_partitions + sorted(other_partitions)
     except Exception as e:  # pylint: disable=broad-except
-
-            f'Failed to get partitions for cluster
-
-
+        raise ValueError(
+            f'Failed to get partitions for cluster '
+            f'{cluster_name}: {common_utils.format_exception(e)}') from e
+
+
+def srun_sshd_command(
+    job_id: str,
+    target_node: str,
+    unix_user: str,
+) -> str:
+    """Build srun command for launching sshd -i inside a Slurm job.
+
+    This is used by the API server to proxy SSH connections to Slurm jobs
+    via sshd running in inetd mode within srun.
+
+    Args:
+        job_id: The Slurm job ID
+        target_node: The target compute node hostname
+        unix_user: The Unix user for the job
+
+    Returns:
+        List of command arguments to be extended to ssh base command
+    """
+    # We use ~username to ensure we use the real home of the user ssh'ing in,
+    # because we override the home directory in SlurmCommandRunner.run.
+    user_home_ssh_dir = f'~{unix_user}/.ssh'
+    return shlex.join([
+        'srun',
+        '--quiet',
+        '--unbuffered',
+        '--overlap',
+        '--jobid',
+        job_id,
+        '-w',
+        target_node,
+        '/usr/sbin/sshd',
+        '-i',  # Uses stdin/stdout
+        '-e',  # Writes errors to stderr
+        '-f',  # Use /dev/null to avoid reading system sshd_config
+        '/dev/null',
+        '-h',
+        f'{user_home_ssh_dir}/{SLURM_SSHD_HOST_KEY_FILENAME}',
+        '-o',
+        f'AuthorizedKeysFile={user_home_ssh_dir}/authorized_keys',
+        '-o',
+        'PasswordAuthentication=no',
+        '-o',
+        'PubkeyAuthentication=yes',
+        # If UsePAM is enabled, we will not be able to run sshd(8)
+        # as a non-root user.
+        # See https://man7.org/linux/man-pages/man5/sshd_config.5.html
+        '-o',
+        'UsePAM=no',
+        '-o',
+        f'AcceptEnv={constants.SKY_CLUSTER_NAME_ENV_VAR_KEY}',
    ])
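
To make the new GRES parsing above concrete, here is a small standalone sketch (independent of the skypilot package; only the regex and helper body are copied from the hunk) showing what get_gpu_type_and_count returns for the example GRES strings quoted in the diff's own comments.

import re
from typing import Optional, Tuple

# Same pattern as in the hunk above: the GPU type between 'gpu:' and the
# count is optional, and matching is case-insensitive.
_GRES_GPU_PATTERN = re.compile(r'\bgpu:(?:(?P<type>[^:(]+):)?(?P<count>\d+)',
                               re.IGNORECASE)


def get_gpu_type_and_count(gres_str: str) -> Tuple[Optional[str], int]:
    """Returns (GPU type, GPU count); (None, 0) if no GPU GRES is present."""
    match = _GRES_GPU_PATTERN.search(gres_str)
    if not match:
        return None, 0
    return match.group('type'), int(match.group('count'))


# GRES examples taken from the comments in the diff.
assert get_gpu_type_and_count('gpu:8') == (None, 8)
assert get_gpu_type_and_count('gpu:H100:8') == ('H100', 8)
assert get_gpu_type_and_count(
    'gpu:nvidia_h100_80gb_hbm3:8(S:0-1)') == ('nvidia_h100_80gb_hbm3', 8)
assert get_gpu_type_and_count('(null)') == (None, 0)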
sky/provision/vast/utils.py
CHANGED
@@ -98,6 +98,7 @@ def launch(name: str, instance_type: str, region: str, disk_size: int,
     ]
     if secure_only:
         query.append('datacenter=true')
+        query.append('hosting_type>=1')
     query_str = ' '.join(query)
 
     instance_list = vast.vast().search_offers(query=query_str)
sky/resources.py
CHANGED
@@ -219,6 +219,9 @@ class Resources:
         - strategy: the recovery strategy to use.
         - max_restarts_on_errors: the max number of restarts on user code
           errors.
+        - recover_on_exit_codes: a list of exit codes that should trigger
+          job recovery. If any task exits with a code in this list, the job
+          will be recovered regardless of max_restarts_on_errors limit.
 
       region: the region to use. Deprecated. Use `infra` instead.
       zone: the zone to use. Deprecated. Use `infra` instead.
@@ -569,7 +572,8 @@
         if self.cloud is not None and self._instance_type is not None:
             vcpus, _ = self.cloud.get_vcpus_mem_from_instance_type(
                 self._instance_type)
-
+            if vcpus is not None:
+                return str(vcpus)
         return None
 
     @property
@@ -1645,6 +1649,7 @@
         other: Union[List['Resources'], 'Resources'],
         requested_num_nodes: int = 1,
         check_ports: bool = False,
+        check_cloud: bool = True,
     ) -> bool:
         """Returns whether this resources is less demanding than the other.
 
@@ -1654,24 +1659,29 @@
             requested_num_nodes: Number of nodes that the current task
                 requests from the cluster.
             check_ports: Whether to check the ports field.
+            check_cloud: Whether we check the cloud/region/zone fields. Useful
+                for resources that don't have cloud specified, like some launched
+                resources.
         """
         if isinstance(other, list):
             resources_list = [self.less_demanding_than(o) for o in other]
             return requested_num_nodes <= sum(resources_list)
 
-
+        if check_cloud:
+            assert other.cloud is not None, 'Other cloud must be specified'
 
-
-
-
+            if self.cloud is not None and not self.cloud.is_same_cloud(
+                    other.cloud):
+                return False
+            # self.cloud <= other.cloud
 
-
-
-
+            if self.region is not None and self.region != other.region:
+                return False
+            # self.region <= other.region
 
-
-
-
+            if self.zone is not None and self.zone != other.zone:
+                return False
+            # self.zone <= other.zone
 
         if self.image_id is not None:
             if other.image_id is None:
@@ -1743,8 +1753,10 @@
         # On Kubernetes, we can't launch a task that requires FUSE on a pod
         # that wasn't initialized with FUSE support at the start.
         # Other clouds don't have this limitation.
-        if
-
+        if check_cloud:
+            assert other.cloud is not None
+            if other.cloud.is_same_cloud(clouds.Kubernetes()):
+                return False
 
         # self <= other
         return True
@@ -1792,6 +1804,101 @@
             self._docker_login_config is None,
         ])
 
+    def __add__(self, other: Optional['Resources']) -> Optional['Resources']:
+        """Add two Resources objects together.
+
+        Args:
+            other: Another Resources object to add (may be None)
+
+        Returns:
+            New Resources object with summed resources, or None if other is None
+        """
+        if other is None:
+            return self
+
+        # Sum CPUs
+        self_cpus = _parse_value(self.cpus)
+        other_cpus = _parse_value(other.cpus)
+        total_cpus = None
+        if self_cpus is not None or other_cpus is not None:
+            total_cpus = (self_cpus or 0) + (other_cpus or 0)
+
+        # Sum memory
+        self_memory = _parse_value(self.memory)
+        other_memory = _parse_value(other.memory)
+        total_memory = None
+        if self_memory is not None or other_memory is not None:
+            total_memory = (self_memory or 0) + (other_memory or 0)
+
+        # Sum accelerators
+        total_accelerators = {}
+        if self.accelerators:
+            for acc_type, count in self.accelerators.items():
+                total_accelerators[acc_type] = float(count)
+        if other.accelerators:
+            for acc_type, count in other.accelerators.items():
+                if acc_type not in total_accelerators:
+                    total_accelerators[acc_type] = 0
+                total_accelerators[acc_type] += float(count)
+
+        return Resources(
+            cpus=str(total_cpus) if total_cpus is not None else None,
+            memory=str(total_memory) if total_memory is not None else None,
+            accelerators=total_accelerators if total_accelerators else None)
+
+    def __sub__(self, other: Optional['Resources']) -> 'Resources':
+        """Subtract another Resources object from this one.
+
+        Args:
+            other: Resources to subtract (may be None)
+
+        Returns:
+            New Resources object with subtracted resources. If the result for a
+            resource is negative, it will be set to 0.
+        """
+        if other is None:
+            return self
+
+        # Subtract CPUs
+        self_cpus = _parse_value(self.cpus)
+        other_cpus = _parse_value(other.cpus)
+        free_cpus = None
+        if self_cpus is not None:
+            if other_cpus is not None:
+                free_cpus = max(0, self_cpus - other_cpus)
+            else:
+                free_cpus = self_cpus
+
+        # Subtract memory
+        self_memory = _parse_value(self.memory)
+        other_memory = _parse_value(other.memory)
+        free_memory = None
+        if self_memory is not None:
+            if other_memory is not None:
+                free_memory = max(0, self_memory - other_memory)
+            else:
+                free_memory = self_memory
+
+        # Subtract accelerators
+        free_accelerators = {}
+        if self.accelerators:
+            for acc_type, total_count in self.accelerators.items():
+                used_count = (other.accelerators.get(acc_type, 0)
+                              if other.accelerators else 0)
+                free_count = max(0, float(total_count) - float(used_count))
+                if free_count > 0:
+                    free_accelerators[acc_type] = free_count
+
+        # If all resources are exhausted, return None
+        # Check if we have any free resources
+        free_cpus = None if free_cpus == 0 else free_cpus
+        free_memory = None if free_memory == 0 else free_memory
+        free_accelerators = None if not free_accelerators else free_accelerators
+
+        return Resources(cpus=free_cpus,
+                         memory=free_memory,
+                         accelerators=free_accelerators)
+
     def copy(self, **override) -> 'Resources':
         """Returns a copy of the given Resources."""
         use_spot = self.use_spot if self._use_spot_specified else None
@@ -2456,3 +2563,18 @@ def _maybe_add_docker_prefix_to_image_id(
     for k, v in image_id_dict.items():
         if not v.startswith('docker:'):
             image_id_dict[k] = f'docker:{v}'
+
+
+def _parse_value(val):
+    if val is None:
+        return None
+    if isinstance(val, (int, float)):
+        return float(val)
+    if isinstance(val, str):
+        # Remove '+' suffix if present
+        val = val.rstrip('+')
+        try:
+            return float(val)
+        except ValueError:
+            return None
+    return None
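
A small self-contained sketch of the arithmetic introduced above: the _parse_value body is copied from the hunk, and the add/subtract lines mirror the __add__/__sub__ logic without constructing Resources objects.

from typing import Optional, Union


def _parse_value(val: Union[None, int, float, str]) -> Optional[float]:
    # Copied from the hunk: normalize 4, 4.0, '4', or '4+' to a float.
    if val is None:
        return None
    if isinstance(val, (int, float)):
        return float(val)
    if isinstance(val, str):
        # Remove '+' suffix if present (in SkyPilot, '4+' means "at least 4").
        val = val.rstrip('+')
        try:
            return float(val)
        except ValueError:
            return None
    return None


assert _parse_value('8+') == 8.0
assert _parse_value(4) == 4.0
assert _parse_value(None) is None

# __add__ semantics: a missing side counts as 0 when the other side is set.
self_cpus, other_cpus = _parse_value('8+'), _parse_value(None)
total_cpus = None
if self_cpus is not None or other_cpus is not None:
    total_cpus = (self_cpus or 0) + (other_cpus or 0)
assert total_cpus == 8.0

# __sub__ semantics: results are clamped at 0, and a fully consumed resource
# is reported as None rather than 0.
free_cpus = max(0, _parse_value('4') - _parse_value('6'))
free_cpus = None if free_cpus == 0 else free_cpus
assert free_cpus is None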