dstack 0.19.30__py3-none-any.whl → 0.19.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dstack/_internal/cli/commands/__init__.py +8 -0
- dstack/_internal/cli/commands/project.py +27 -20
- dstack/_internal/cli/commands/server.py +5 -0
- dstack/_internal/cli/main.py +1 -3
- dstack/_internal/core/backends/aws/compute.py +2 -0
- dstack/_internal/core/backends/azure/compute.py +2 -0
- dstack/_internal/core/backends/base/compute.py +32 -9
- dstack/_internal/core/backends/base/offers.py +1 -0
- dstack/_internal/core/backends/cloudrift/compute.py +2 -0
- dstack/_internal/core/backends/cudo/compute.py +2 -0
- dstack/_internal/core/backends/datacrunch/compute.py +2 -0
- dstack/_internal/core/backends/digitalocean_base/compute.py +2 -0
- dstack/_internal/core/backends/features.py +5 -0
- dstack/_internal/core/backends/gcp/compute.py +74 -34
- dstack/_internal/core/backends/gcp/configurator.py +1 -1
- dstack/_internal/core/backends/gcp/models.py +14 -1
- dstack/_internal/core/backends/gcp/resources.py +35 -12
- dstack/_internal/core/backends/hotaisle/compute.py +2 -0
- dstack/_internal/core/backends/kubernetes/compute.py +466 -213
- dstack/_internal/core/backends/kubernetes/models.py +13 -16
- dstack/_internal/core/backends/kubernetes/utils.py +145 -8
- dstack/_internal/core/backends/lambdalabs/compute.py +2 -0
- dstack/_internal/core/backends/local/compute.py +2 -0
- dstack/_internal/core/backends/nebius/compute.py +2 -0
- dstack/_internal/core/backends/oci/compute.py +2 -0
- dstack/_internal/core/backends/template/compute.py.jinja +2 -0
- dstack/_internal/core/backends/tensordock/compute.py +2 -0
- dstack/_internal/core/backends/vultr/compute.py +2 -0
- dstack/_internal/server/background/tasks/common.py +2 -0
- dstack/_internal/server/background/tasks/process_instances.py +2 -2
- dstack/_internal/server/services/offers.py +7 -1
- dstack/_internal/server/testing/common.py +2 -0
- dstack/_internal/server/utils/provisioning.py +3 -10
- dstack/version.py +1 -1
- {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/METADATA +11 -9
- {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/RECORD +39 -39
- {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/WHEEL +0 -0
- {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/entry_points.txt +0 -0
- {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/cli/commands/__init__.py
CHANGED

```diff
@@ -7,6 +7,7 @@ from typing import ClassVar, Optional
 from rich_argparse import RichHelpFormatter
 
 from dstack._internal.cli.services.completion import ProjectNameCompleter
+from dstack._internal.cli.utils.common import configure_logging
 from dstack._internal.core.errors import CLIError
 from dstack.api import Client
 
@@ -52,9 +53,16 @@ class BaseCommand(ABC):
 
     @abstractmethod
    def _command(self, args: argparse.Namespace):
+        self._configure_logging()
         if not self.ACCEPT_EXTRA_ARGS and args.extra_args:
             raise CLIError(f"Unrecognized arguments: {shlex.join(args.extra_args)}")
 
+    def _configure_logging(self) -> None:
+        """
+        Override this method to configure command-specific logging
+        """
+        configure_logging()
+
 
 class APIBaseCommand(BaseCommand):
     api: Client
```
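Logging configuration now happens per command through the new `_configure_logging()` hook instead of once in `main()`. A minimal sketch of how a command could override the hook (the subclass, its `NAME`, and the logging level are hypothetical, not taken from this release):

```python
import argparse
import logging

from dstack._internal.cli.commands import BaseCommand


class QuietCommand(BaseCommand):
    # Hypothetical command that wants less log noise than the default setup.
    NAME = "quiet"

    def _configure_logging(self) -> None:
        # Replace the default configure_logging() call with a stricter level.
        logging.basicConfig(level=logging.WARNING)

    def _command(self, args: argparse.Namespace):
        # The base implementation calls self._configure_logging() and validates extra args.
        super()._command(args)
        print("running with reduced log output")
```

The sketch omits the argument-parser wiring a real command needs; it only illustrates that subclasses can now swap in their own logging setup before the command body runs.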
dstack/_internal/cli/commands/project.py
CHANGED

```diff
@@ -1,11 +1,12 @@
 import argparse
+from typing import Any, Union
 
 from requests import HTTPError
 from rich.table import Table
 
 import dstack.api.server
 from dstack._internal.cli.commands import BaseCommand
-from dstack._internal.cli.utils.common import confirm_ask, console
+from dstack._internal.cli.utils.common import add_row_from_dict, confirm_ask, console
 from dstack._internal.core.errors import ClientError, CLIError
 from dstack._internal.core.services.configs import ConfigManager
 from dstack._internal.utils.logging import get_logger
@@ -58,6 +59,10 @@ class ProjectCommand(BaseCommand):
         # List subcommand
         list_parser = subparsers.add_parser("list", help="List configured projects")
         list_parser.set_defaults(subfunc=self._list)
+        for parser in [self._parser, list_parser]:
+            parser.add_argument(
+                "-v", "--verbose", action="store_true", help="Show more information"
+            )
 
         # Set default subcommand
         set_default_parser = subparsers.add_parser("set-default", help="Set default project")
@@ -122,30 +127,32 @@
         table = Table(box=None)
         table.add_column("PROJECT", style="bold", no_wrap=True)
         table.add_column("URL", style="grey58")
-        table.add_column("USER", style="grey58")
+        if args.verbose:
+            table.add_column("USER", style="grey58")
         table.add_column("DEFAULT", justify="center")
 
         for project_config in config_manager.list_project_configs():
             project_name = project_config.name
             is_default = project_name == default_project.name if default_project else False
+            row: dict[Union[str, int], Any] = {
+                "PROJECT": project_name,
+                "URL": project_config.url,
+                "DEFAULT": "✓" if is_default else "",
+            }
+
+            if args.verbose:
+                # Get username from API
+                try:
+                    api_client = dstack.api.server.APIClient(
+                        base_url=project_config.url, token=project_config.token
+                    )
+                    user_info = api_client.users.get_my_user()
+                    username = user_info.username
+                except ClientError:
+                    username = "(invalid token)"
+                row["USER"] = username
+
+            add_row_from_dict(table, row, style="bold" if is_default else None)
 
         console.print(table)
 
```
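With this change the USER column is only filled in verbose mode, so the default `dstack project list` no longer makes one API call per configured project. Rows are now built as dicts keyed by column header and passed to `add_row_from_dict`. The helper itself lives in `dstack._internal.cli.utils.common`; the following is a hypothetical re-implementation for illustration only, assuming keys may be either column headers or column indices:

```python
from typing import Any, Dict, Optional, Union

from rich.table import Table


def add_row_from_dict(table: Table, row: Dict[Union[str, int], Any], **kwargs) -> None:
    # Hypothetical sketch: map dict keys to the table's column headers (or indices)
    # and fill columns that are absent from the dict with empty strings.
    values = []
    for index, column in enumerate(table.columns):
        if column.header in row:
            values.append(str(row[column.header]))
        elif index in row:
            values.append(str(row[index]))
        else:
            values.append("")
    table.add_row(*values, **kwargs)
```

The dict-driven approach keeps the row construction valid whether or not the optional USER column was added to the table.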
dstack/_internal/cli/main.py
CHANGED

```diff
@@ -22,7 +22,7 @@ from dstack._internal.cli.commands.server import ServerCommand
 from dstack._internal.cli.commands.stats import StatsCommand
 from dstack._internal.cli.commands.stop import StopCommand
 from dstack._internal.cli.commands.volume import VolumeCommand
-from dstack._internal.cli.utils.common import _colors, configure_logging, console
+from dstack._internal.cli.utils.common import _colors, console
 from dstack._internal.cli.utils.updates import check_for_updates
 from dstack._internal.core.errors import ClientError, CLIError, ConfigurationError, SSHError
 from dstack._internal.core.services.ssh.client import get_ssh_client_info
@@ -39,8 +39,6 @@ def main():
     RichHelpFormatter.styles["argparse.groups"] = "bold grey74"
     RichHelpFormatter.styles["argparse.text"] = "grey74"
 
-    configure_logging()
-
     parser = argparse.ArgumentParser(
         description=(
             "Not sure where to start?"
```
dstack/_internal/core/backends/aws/compute.py
CHANGED

```diff
@@ -24,6 +24,7 @@ from dstack._internal.core.backends.base.compute import (
     ComputeWithMultinodeSupport,
     ComputeWithPlacementGroupSupport,
     ComputeWithPrivateGatewaySupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithReservationSupport,
     ComputeWithVolumeSupport,
     generate_unique_gateway_instance_name,
@@ -90,6 +91,7 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
 class AWSCompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithMultinodeSupport,
     ComputeWithReservationSupport,
     ComputeWithPlacementGroupSupport,
```
dstack/_internal/core/backends/azure/compute.py
CHANGED

```diff
@@ -43,6 +43,7 @@ from dstack._internal.core.backends.base.compute import (
     ComputeWithCreateInstanceSupport,
     ComputeWithGatewaySupport,
     ComputeWithMultinodeSupport,
+    ComputeWithPrivilegedSupport,
     generate_unique_gateway_instance_name,
     generate_unique_instance_name,
     get_gateway_user_data,
@@ -78,6 +79,7 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("30GB"), max=Memory.pars
 class AzureCompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithMultinodeSupport,
     ComputeWithGatewaySupport,
     Compute,
```
dstack/_internal/core/backends/base/compute.py
CHANGED

```diff
@@ -5,14 +5,16 @@ import string
 import threading
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
+from enum import Enum
 from functools import lru_cache
 from pathlib import Path
-from typing import Callable, Dict, List,
+from typing import Callable, Dict, List, Optional
 
 import git
 import requests
 import yaml
 from cachetools import TTLCache, cachedmethod
+from gpuhunt import CPUArchitecture
 
 from dstack._internal import settings
 from dstack._internal.core.backends.base.offers import filter_offers_by_requirements
@@ -65,7 +67,21 @@ NVIDIA_GPUS_REQUIRING_PROPRIETARY_KERNEL_MODULES = frozenset(
     ]
 )
 
+
+class GoArchType(str, Enum):
+    """
+    A subset of GOARCH values
+    """
+
+    AMD64 = "amd64"
+    ARM64 = "arm64"
+
+    def to_cpu_architecture(self) -> CPUArchitecture:
+        if self == self.AMD64:
+            return CPUArchitecture.X86
+        if self == self.ARM64:
+            return CPUArchitecture.ARM
+        assert False, self
 
 
 class Compute(ABC):
@@ -304,6 +320,15 @@ class ComputeWithCreateInstanceSupport(ABC):
     ]
 
 
+class ComputeWithPrivilegedSupport:
+    """
+    Must be subclassed to support runs with `privileged: true`.
+    All VM-based Computes (that is, Computes that use the shim) should subclass this mixin.
+    """
+
+    pass
+
+
 class ComputeWithMultinodeSupport:
     """
     Must be subclassed to support multinode tasks and cluster fleets.
@@ -704,14 +729,14 @@ def normalize_arch(arch: Optional[str] = None) -> GoArchType:
     If the arch is not specified, falls back to `amd64`.
     """
     if not arch:
-        return
+        return GoArchType.AMD64
     arch_lower = arch.lower()
     if "32" in arch_lower or arch_lower in ["i386", "i686"]:
         raise ValueError(f"32-bit architectures are not supported: {arch}")
     if arch_lower.startswith("x86") or arch_lower.startswith("amd"):
-        return
+        return GoArchType.AMD64
     if arch_lower.startswith("arm") or arch_lower.startswith("aarch"):
-        return
+        return GoArchType.ARM64
     raise ValueError(f"Unsupported architecture: {arch}")
 
 
@@ -727,8 +752,7 @@ def get_dstack_runner_download_url(arch: Optional[str] = None) -> str:
         "/{version}/binaries/dstack-runner-linux-{arch}"
     )
     version = get_dstack_runner_version()
-
-    return url_template.format(version=version, arch=arch)
+    return url_template.format(version=version, arch=normalize_arch(arch).value)
 
 
 def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
@@ -743,8 +767,7 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
         "/{version}/binaries/dstack-shim-linux-{arch}"
     )
     version = get_dstack_runner_version()
-
-    return url_template.format(version=version, arch=arch)
+    return url_template.format(version=version, arch=normalize_arch(arch).value)
 
 
 def get_setup_cloud_instance_commands(
```
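`GoArchType` is now a proper enum and `normalize_arch()` returns one of its members, so the runner and shim download URLs are always built from a normalized GOARCH name rather than whatever string the caller passed. A small sketch of the resulting behavior, based on the hunks above (treat it as an illustration rather than a test from the release):

```python
from dstack._internal.core.backends.base.compute import (
    GoArchType,
    get_dstack_runner_download_url,
    normalize_arch,
)

# Common aliases are mapped onto the two supported GOARCH values.
assert normalize_arch("x86_64") is GoArchType.AMD64
assert normalize_arch("aarch64") is GoArchType.ARM64
assert normalize_arch(None) is GoArchType.AMD64  # unspecified falls back to amd64

# The download helpers now normalize the arch themselves, so callers may pass
# raw architecture strings; the URL ends in ...-linux-amd64 or ...-linux-arm64.
url = get_dstack_runner_download_url("aarch64")
assert url.endswith("-linux-arm64")
```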
dstack/_internal/core/backends/cloudrift/compute.py
CHANGED

```diff
@@ -4,6 +4,7 @@ from dstack._internal.core.backends.base.compute import (
     Compute,
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     get_shim_commands,
 )
 from dstack._internal.core.backends.base.offers import get_catalog_offers
@@ -27,6 +28,7 @@ logger = get_logger(__name__)
 class CloudRiftCompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     Compute,
 ):
     def __init__(self, config: CloudRiftConfig):
```
dstack/_internal/core/backends/cudo/compute.py
CHANGED

```diff
@@ -6,6 +6,7 @@ from dstack._internal.core.backends.base.backend import Compute
 from dstack._internal.core.backends.base.compute import (
     ComputeWithCreateInstanceSupport,
     ComputeWithFilteredOffersCached,
+    ComputeWithPrivilegedSupport,
     generate_unique_instance_name,
     get_shim_commands,
 )
@@ -32,6 +33,7 @@ MAX_RESOURCE_NAME_LEN = 30
 class CudoCompute(
     ComputeWithFilteredOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     Compute,
 ):
     def __init__(self, config: CudoConfig):
```
dstack/_internal/core/backends/datacrunch/compute.py
CHANGED

```diff
@@ -8,6 +8,7 @@ from dstack._internal.core.backends.base.backend import Compute
 from dstack._internal.core.backends.base.compute import (
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     generate_unique_instance_name,
     get_shim_commands,
 )
@@ -39,6 +40,7 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=IMAGE_SIZE, max=None)
 class DataCrunchCompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     Compute,
 ):
     def __init__(self, config: DataCrunchConfig):
```
dstack/_internal/core/backends/digitalocean_base/compute.py
CHANGED

```diff
@@ -7,6 +7,7 @@ from dstack._internal.core.backends.base.backend import Compute
 from dstack._internal.core.backends.base.compute import (
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     generate_unique_instance_name,
     get_user_data,
 )
@@ -40,6 +41,7 @@ DOCKER_INSTALL_COMMANDS = [
 class BaseDigitalOceanCompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     Compute,
 ):
     def __init__(self, config: BaseDigitalOceanConfig, api_url: str, type: BackendType):
```
dstack/_internal/core/backends/features.py
CHANGED

```diff
@@ -4,6 +4,7 @@ from dstack._internal.core.backends.base.compute import (
     ComputeWithMultinodeSupport,
     ComputeWithPlacementGroupSupport,
     ComputeWithPrivateGatewaySupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithReservationSupport,
     ComputeWithVolumeSupport,
 )
@@ -38,6 +39,10 @@ BACKENDS_WITH_CREATE_INSTANCE_SUPPORT = _get_backends_with_compute_feature(
     configurator_classes=_configurator_classes,
     compute_feature_class=ComputeWithCreateInstanceSupport,
 )
+BACKENDS_WITH_PRIVILEGED_SUPPORT = _get_backends_with_compute_feature(
+    configurator_classes=_configurator_classes,
+    compute_feature_class=ComputeWithPrivilegedSupport,
+)
 BACKENDS_WITH_MULTINODE_SUPPORT = [BackendType.REMOTE] + _get_backends_with_compute_feature(
     configurator_classes=_configurator_classes,
     compute_feature_class=ComputeWithMultinodeSupport,
```
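`ComputeWithPrivilegedSupport` follows the same capability-mixin pattern as the other `ComputeWith*` markers: a backend opts in by inheritance, and the server derives the backend list from the class hierarchy. A minimal, self-contained sketch of that pattern (the classes below are hypothetical stand-ins, not the real dstack backends):

```python
class ComputeWithPrivilegedSupport:
    """Marker: the backend can run jobs with `privileged: true`."""


class FakeVMCompute(ComputeWithPrivilegedSupport):
    # A VM-based (shim-using) backend would subclass the marker.
    pass


class FakeContainerCompute:
    # A backend without privileged support simply does not inherit it.
    pass


def supports_privileged(compute_class: type) -> bool:
    # Feature lists like BACKENDS_WITH_PRIVILEGED_SUPPORT are built the same way:
    # by checking whether a backend's Compute class subclasses the marker mixin.
    return issubclass(compute_class, ComputeWithPrivilegedSupport)


assert supports_privileged(FakeVMCompute)
assert not supports_privileged(FakeContainerCompute)
```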
dstack/_internal/core/backends/gcp/compute.py
CHANGED

```diff
@@ -23,6 +23,7 @@ from dstack._internal.core.backends.base.compute import (
     ComputeWithMultinodeSupport,
     ComputeWithPlacementGroupSupport,
     ComputeWithPrivateGatewaySupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithVolumeSupport,
     generate_unique_gateway_instance_name,
     generate_unique_instance_name,
@@ -90,6 +91,7 @@ class GCPVolumeDiskBackendData(CoreModel):
 class GCPCompute(
     ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
+    ComputeWithPrivilegedSupport,
     ComputeWithMultinodeSupport,
     ComputeWithPlacementGroupSupport,
     ComputeWithGatewaySupport,
@@ -111,8 +113,8 @@ class GCPCompute(
         self.resource_policies_client = compute_v1.ResourcePoliciesClient(
             credentials=self.credentials
         )
-        self.
-        self.
+        self._usable_subnets_cache_lock = threading.Lock()
+        self._usable_subnets_cache = TTLCache(maxsize=1, ttl=120)
 
     def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
         regions = get_or_error(self.config.regions)
@@ -203,12 +205,12 @@
         disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
         # Choose any usable subnet in a VPC.
         # Configuring a specific subnet per region is not supported yet.
-        subnetwork = _get_vpc_subnet(
-            subnetworks_client=self.subnetworks_client,
-            config=self.config,
+        subnetwork = self._get_vpc_subnet(instance_offer.region)
+        extra_subnets = self._get_extra_subnets(
             region=instance_offer.region,
+            instance_type_name=instance_offer.instance.name,
         )
+        roce_subnets = self._get_roce_subnets(
             region=instance_offer.region,
             instance_type_name=instance_offer.instance.name,
         )
@@ -330,6 +332,7 @@
             network=self.config.vpc_resource_name,
             subnetwork=subnetwork,
             extra_subnetworks=extra_subnets,
+            roce_subnetworks=roce_subnets,
             allocate_public_ip=allocate_public_ip,
             placement_policy=placement_policy,
         )
@@ -339,6 +342,13 @@
                 # If the request succeeds, we'll probably timeout and update_provisioning_data() will get hostname.
                 operation = self.instances_client.insert(request=request)
                 gcp_resources.wait_for_extended_operation(operation, timeout=30)
+            except google.api_core.exceptions.BadRequest as e:
+                if "Network profile only allows resource creation in location" in e.message:
+                    # A hack to find the correct RoCE VPC zone by trial and error.
+                    # Could be better to find it via the API.
+                    logger.debug("Got GCP error when provisioning a VM: %s", e)
+                    continue
+                raise
             except (
                 google.api_core.exceptions.ServiceUnavailable,
                 google.api_core.exceptions.NotFound,
@@ -487,11 +497,7 @@
         )
         # Choose any usable subnet in a VPC.
         # Configuring a specific subnet per region is not supported yet.
-        subnetwork = _get_vpc_subnet(
-            subnetworks_client=self.subnetworks_client,
-            config=self.config,
-            region=configuration.region,
-        )
+        subnetwork = self._get_vpc_subnet(configuration.region)
 
         labels = {
             "owner": "dstack",
@@ -793,10 +799,6 @@
             instance_id,
         )
 
-    @cachedmethod(
-        cache=lambda self: self._extra_subnets_cache,
-        lock=lambda self: self._extra_subnets_cache_lock,
-    )
     def _get_extra_subnets(
         self,
         region: str,
@@ -808,15 +810,16 @@
             subnets_num = 8
         elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
             subnets_num = 4
+        elif instance_type_name == "a4-highgpu-8g":
+            subnets_num = 1  # 1 main + 1 extra + 8 RoCE
         else:
             return []
         extra_subnets = []
         for vpc_name in self.config.extra_vpcs[:subnets_num]:
             subnet = gcp_resources.get_vpc_subnet_or_error(
-                subnetworks_client=self.subnetworks_client,
-                vpc_project_id=self.config.vpc_project_id or self.config.project_id,
                 vpc_name=vpc_name,
                 region=region,
+                usable_subnets=self._list_usable_subnets(),
             )
             vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
                 project_id=self.config.vpc_project_id or self.config.project_id,
@@ -825,6 +828,58 @@
             extra_subnets.append((vpc_resource_name, subnet))
         return extra_subnets
 
+    def _get_roce_subnets(
+        self,
+        region: str,
+        instance_type_name: str,
+    ) -> List[Tuple[str, str]]:
+        if not self.config.roce_vpcs:
+            return []
+        if instance_type_name == "a4-highgpu-8g":
+            nics_num = 8
+        else:
+            return []
+        roce_vpc = self.config.roce_vpcs[0]  # roce_vpcs is validated to have at most 1 item
+        subnets = gcp_resources.get_vpc_subnets(
+            vpc_name=roce_vpc,
+            region=region,
+            usable_subnets=self._list_usable_subnets(),
+        )
+        if len(subnets) < nics_num:
+            raise ComputeError(
+                f"{instance_type_name} requires {nics_num} RoCE subnets,"
+                f" but only {len(subnets)} are available in VPC {roce_vpc}"
+            )
+        vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
+            project_id=self.config.vpc_project_id or self.config.project_id,
+            vpc_name=roce_vpc,
+        )
+        nic_subnets = []
+        for subnet in subnets[:nics_num]:
+            nic_subnets.append((vpc_resource_name, subnet))
+        return nic_subnets
+
+    @cachedmethod(
+        cache=lambda self: self._usable_subnets_cache,
+        lock=lambda self: self._usable_subnets_cache_lock,
+    )
+    def _list_usable_subnets(self) -> list[compute_v1.UsableSubnetwork]:
+        # To avoid hitting the `ListUsable requests per minute` system limit, we fetch all subnets
+        # at once and cache them
+        return gcp_resources.list_project_usable_subnets(
+            subnetworks_client=self.subnetworks_client,
+            project_id=self.config.vpc_project_id or self.config.project_id,
+        )
+
+    def _get_vpc_subnet(self, region: str) -> Optional[str]:
+        if self.config.vpc_name is None:
+            return None
+        return gcp_resources.get_vpc_subnet_or_error(
+            vpc_name=self.config.vpc_name,
+            region=region,
+            usable_subnets=self._list_usable_subnets(),
+        )
+
 
 def _supported_instances_and_zones(
     regions: List[str],
@@ -867,8 +922,8 @@ def _has_gpu_quota(quotas: Dict[str, float], resources: Resources) -> bool:
     gpu = resources.gpus[0]
     if _is_tpu(gpu.name):
         return True
-    if gpu.name
-        # H100 and H100_MEGA quotas are not returned by `regions_client.list`
+    if gpu.name in ["B200", "H100"]:
+        # B200, H100 and H100_MEGA quotas are not returned by `regions_client.list`
         return True
     quota_name = f"NVIDIA_{gpu.name}_GPUS"
     if gpu.name == "A100" and gpu.memory_mib == 80 * 1024:
@@ -889,21 +944,6 @@ def _unique_instance_name(instance: InstanceType) -> str:
     return f"{name}-{gpu.name}-{gpu.memory_mib}"
 
 
-def _get_vpc_subnet(
-    subnetworks_client: compute_v1.SubnetworksClient,
-    config: GCPConfig,
-    region: str,
-) -> Optional[str]:
-    if config.vpc_name is None:
-        return None
-    return gcp_resources.get_vpc_subnet_or_error(
-        subnetworks_client=subnetworks_client,
-        vpc_project_id=config.vpc_project_id or config.project_id,
-        vpc_name=config.vpc_name,
-        region=region,
-    )
-
-
 @dataclass
 class GCPImage:
     id: str
```
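The usable-subnet lookup is now cached per `GCPCompute` instance with `cachetools`, the same `TTLCache` plus `cachedmethod` pattern the class already used for extra subnets. A self-contained sketch of that pattern (the fetch method below is a stand-in, not a GCP API call):

```python
import threading

from cachetools import TTLCache, cachedmethod


class SubnetLister:
    def __init__(self) -> None:
        # One cached value, refreshed at most every 120 seconds, guarded by a lock
        # so concurrent callers do not trigger duplicate fetches.
        self._cache_lock = threading.Lock()
        self._cache = TTLCache(maxsize=1, ttl=120)
        self.calls = 0

    @cachedmethod(cache=lambda self: self._cache, lock=lambda self: self._cache_lock)
    def list_usable_subnets(self) -> list:
        self.calls += 1
        return ["subnet-a", "subnet-b"]  # stand-in for the real ListUsable request


lister = SubnetLister()
lister.list_usable_subnets()
lister.list_usable_subnets()
assert lister.calls == 1  # the second call within the TTL is served from the cache
```

Because the cache key ignores `self`'s arguments here, all subnet helpers on the instance share a single fetched result until the TTL expires, which is how the code avoids the per-minute quota on `ListUsable` requests.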
dstack/_internal/core/backends/gcp/configurator.py
CHANGED

```diff
@@ -202,5 +202,5 @@ class GCPConfigurator(
             )
         except BackendError as e:
             raise ServerClientError(e.args[0])
-        # Not checking config.
+        # Not checking config.extra_vpcs and config.roce_vpcs so that users are not required to configure subnets for all regions
         # but only for regions they intend to use. Validation will be done on provisioning.
```
dstack/_internal/core/backends/gcp/models.py
CHANGED

```diff
@@ -41,11 +41,24 @@ class GCPBackendConfig(CoreModel):
         Optional[List[str]],
         Field(
             description=(
-                "The names of additional VPCs used for
+                "The names of additional VPCs used for multi-NIC instances, such as those that support GPUDirect."
+                " Specify eight VPCs to maximize bandwidth in clusters with eight-GPU instances."
                 " Each VPC must have a subnet and a firewall rule allowing internal traffic across all subnets"
             )
         ),
     ] = None
+    roce_vpcs: Annotated[
+        Optional[List[str]],
+        Field(
+            description=(
+                "The names of additional VPCs with the RoCE network profile."
+                " Used for RDMA on GPU instances that support the MRDMA interface type."
+                " A VPC should have eight subnets to maximize the bandwidth in clusters"
+                " with eight-GPU instances."
+            ),
+            max_items=1,  # The currently supported instance types only need one VPC with eight subnets.
+        ),
+    ] = None
     vpc_project_id: Annotated[
         Optional[str],
         Field(description="The shared VPC hosted project ID. Required for shared VPC only"),
```