dstack 0.19.30__py3-none-any.whl → 0.19.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. Click here for more details.

Files changed (39):
  1. dstack/_internal/cli/commands/__init__.py +8 -0
  2. dstack/_internal/cli/commands/project.py +27 -20
  3. dstack/_internal/cli/commands/server.py +5 -0
  4. dstack/_internal/cli/main.py +1 -3
  5. dstack/_internal/core/backends/aws/compute.py +2 -0
  6. dstack/_internal/core/backends/azure/compute.py +2 -0
  7. dstack/_internal/core/backends/base/compute.py +32 -9
  8. dstack/_internal/core/backends/base/offers.py +1 -0
  9. dstack/_internal/core/backends/cloudrift/compute.py +2 -0
  10. dstack/_internal/core/backends/cudo/compute.py +2 -0
  11. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  12. dstack/_internal/core/backends/digitalocean_base/compute.py +2 -0
  13. dstack/_internal/core/backends/features.py +5 -0
  14. dstack/_internal/core/backends/gcp/compute.py +74 -34
  15. dstack/_internal/core/backends/gcp/configurator.py +1 -1
  16. dstack/_internal/core/backends/gcp/models.py +14 -1
  17. dstack/_internal/core/backends/gcp/resources.py +35 -12
  18. dstack/_internal/core/backends/hotaisle/compute.py +2 -0
  19. dstack/_internal/core/backends/kubernetes/compute.py +466 -213
  20. dstack/_internal/core/backends/kubernetes/models.py +13 -16
  21. dstack/_internal/core/backends/kubernetes/utils.py +145 -8
  22. dstack/_internal/core/backends/lambdalabs/compute.py +2 -0
  23. dstack/_internal/core/backends/local/compute.py +2 -0
  24. dstack/_internal/core/backends/nebius/compute.py +2 -0
  25. dstack/_internal/core/backends/oci/compute.py +2 -0
  26. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  27. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  28. dstack/_internal/core/backends/vultr/compute.py +2 -0
  29. dstack/_internal/server/background/tasks/common.py +2 -0
  30. dstack/_internal/server/background/tasks/process_instances.py +2 -2
  31. dstack/_internal/server/services/offers.py +7 -1
  32. dstack/_internal/server/testing/common.py +2 -0
  33. dstack/_internal/server/utils/provisioning.py +3 -10
  34. dstack/version.py +1 -1
  35. {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/METADATA +11 -9
  36. {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/RECORD +39 -39
  37. {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/WHEEL +0 -0
  38. {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/entry_points.txt +0 -0
  39. {dstack-0.19.30.dist-info → dstack-0.19.31.dist-info}/licenses/LICENSE.md +0 -0
@@ -7,6 +7,7 @@ from typing import ClassVar, Optional
7
7
  from rich_argparse import RichHelpFormatter
8
8
 
9
9
  from dstack._internal.cli.services.completion import ProjectNameCompleter
10
+ from dstack._internal.cli.utils.common import configure_logging
10
11
  from dstack._internal.core.errors import CLIError
11
12
  from dstack.api import Client
12
13
 
@@ -52,9 +53,16 @@ class BaseCommand(ABC):
52
53
 
53
54
  @abstractmethod
54
55
  def _command(self, args: argparse.Namespace):
56
+ self._configure_logging()
55
57
  if not self.ACCEPT_EXTRA_ARGS and args.extra_args:
56
58
  raise CLIError(f"Unrecognized arguments: {shlex.join(args.extra_args)}")
57
59
 
60
+ def _configure_logging(self) -> None:
61
+ """
62
+ Override this method to configure command-specific logging
63
+ """
64
+ configure_logging()
65
+
58
66
 
59
67
  class APIBaseCommand(BaseCommand):
60
68
  api: Client
@@ -1,11 +1,12 @@
1
1
  import argparse
2
+ from typing import Any, Union
2
3
 
3
4
  from requests import HTTPError
4
5
  from rich.table import Table
5
6
 
6
7
  import dstack.api.server
7
8
  from dstack._internal.cli.commands import BaseCommand
8
- from dstack._internal.cli.utils.common import confirm_ask, console
9
+ from dstack._internal.cli.utils.common import add_row_from_dict, confirm_ask, console
9
10
  from dstack._internal.core.errors import ClientError, CLIError
10
11
  from dstack._internal.core.services.configs import ConfigManager
11
12
  from dstack._internal.utils.logging import get_logger
@@ -58,6 +59,10 @@ class ProjectCommand(BaseCommand):
58
59
  # List subcommand
59
60
  list_parser = subparsers.add_parser("list", help="List configured projects")
60
61
  list_parser.set_defaults(subfunc=self._list)
62
+ for parser in [self._parser, list_parser]:
63
+ parser.add_argument(
64
+ "-v", "--verbose", action="store_true", help="Show more information"
65
+ )
61
66
 
62
67
  # Set default subcommand
63
68
  set_default_parser = subparsers.add_parser("set-default", help="Set default project")
@@ -122,30 +127,32 @@ class ProjectCommand(BaseCommand):
122
127
  table = Table(box=None)
123
128
  table.add_column("PROJECT", style="bold", no_wrap=True)
124
129
  table.add_column("URL", style="grey58")
125
- table.add_column("USER", style="grey58")
130
+ if args.verbose:
131
+ table.add_column("USER", style="grey58")
126
132
  table.add_column("DEFAULT", justify="center")
127
133
 
128
134
  for project_config in config_manager.list_project_configs():
129
135
  project_name = project_config.name
130
136
  is_default = project_name == default_project.name if default_project else False
131
-
132
- # Get username from API
133
- try:
134
- api_client = dstack.api.server.APIClient(
135
- base_url=project_config.url, token=project_config.token
136
- )
137
- user_info = api_client.users.get_my_user()
138
- username = user_info.username
139
- except ClientError:
140
- username = "(invalid token)"
141
-
142
- table.add_row(
143
- project_name,
144
- project_config.url,
145
- username,
146
- "✓" if is_default else "",
147
- style="bold" if is_default else None,
148
- )
137
+ row: dict[Union[str, int], Any] = {
138
+ "PROJECT": project_name,
139
+ "URL": project_config.url,
140
+ "DEFAULT": "✓" if is_default else "",
141
+ }
142
+
143
+ if args.verbose:
144
+ # Get username from API
145
+ try:
146
+ api_client = dstack.api.server.APIClient(
147
+ base_url=project_config.url, token=project_config.token
148
+ )
149
+ user_info = api_client.users.get_my_user()
150
+ username = user_info.username
151
+ except ClientError:
152
+ username = "(invalid token)"
153
+ row["USER"] = username
154
+
155
+ add_row_from_dict(table, row, style="bold" if is_default else None)
149
156
 
150
157
  console.print(table)
151
158
 
@@ -82,3 +82,8 @@ class ServerCommand(BaseCommand):
82
82
  log_level=uvicorn_log_level,
83
83
  workers=1,
84
84
  )
85
+
86
+ def _configure_logging(self) -> None:
87
+ # Server logging is configured in the FastAPI lifespan function.
88
+ # No need to configure CLI logging.
89
+ pass
@@ -22,7 +22,7 @@ from dstack._internal.cli.commands.server import ServerCommand
22
22
  from dstack._internal.cli.commands.stats import StatsCommand
23
23
  from dstack._internal.cli.commands.stop import StopCommand
24
24
  from dstack._internal.cli.commands.volume import VolumeCommand
25
- from dstack._internal.cli.utils.common import _colors, configure_logging, console
25
+ from dstack._internal.cli.utils.common import _colors, console
26
26
  from dstack._internal.cli.utils.updates import check_for_updates
27
27
  from dstack._internal.core.errors import ClientError, CLIError, ConfigurationError, SSHError
28
28
  from dstack._internal.core.services.ssh.client import get_ssh_client_info
@@ -39,8 +39,6 @@ def main():
39
39
  RichHelpFormatter.styles["argparse.groups"] = "bold grey74"
40
40
  RichHelpFormatter.styles["argparse.text"] = "grey74"
41
41
 
42
- configure_logging()
43
-
44
42
  parser = argparse.ArgumentParser(
45
43
  description=(
46
44
  "Not sure where to start?"
@@ -24,6 +24,7 @@ from dstack._internal.core.backends.base.compute import (
24
24
  ComputeWithMultinodeSupport,
25
25
  ComputeWithPlacementGroupSupport,
26
26
  ComputeWithPrivateGatewaySupport,
27
+ ComputeWithPrivilegedSupport,
27
28
  ComputeWithReservationSupport,
28
29
  ComputeWithVolumeSupport,
29
30
  generate_unique_gateway_instance_name,
@@ -90,6 +91,7 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
90
91
  class AWSCompute(
91
92
  ComputeWithAllOffersCached,
92
93
  ComputeWithCreateInstanceSupport,
94
+ ComputeWithPrivilegedSupport,
93
95
  ComputeWithMultinodeSupport,
94
96
  ComputeWithReservationSupport,
95
97
  ComputeWithPlacementGroupSupport,
@@ -43,6 +43,7 @@ from dstack._internal.core.backends.base.compute import (
43
43
  ComputeWithCreateInstanceSupport,
44
44
  ComputeWithGatewaySupport,
45
45
  ComputeWithMultinodeSupport,
46
+ ComputeWithPrivilegedSupport,
46
47
  generate_unique_gateway_instance_name,
47
48
  generate_unique_instance_name,
48
49
  get_gateway_user_data,
@@ -78,6 +79,7 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("30GB"), max=Memory.pars
78
79
  class AzureCompute(
79
80
  ComputeWithAllOffersCached,
80
81
  ComputeWithCreateInstanceSupport,
82
+ ComputeWithPrivilegedSupport,
81
83
  ComputeWithMultinodeSupport,
82
84
  ComputeWithGatewaySupport,
83
85
  Compute,
@@ -5,14 +5,16 @@ import string
5
5
  import threading
6
6
  from abc import ABC, abstractmethod
7
7
  from collections.abc import Iterable
8
+ from enum import Enum
8
9
  from functools import lru_cache
9
10
  from pathlib import Path
10
- from typing import Callable, Dict, List, Literal, Optional
11
+ from typing import Callable, Dict, List, Optional
11
12
 
12
13
  import git
13
14
  import requests
14
15
  import yaml
15
16
  from cachetools import TTLCache, cachedmethod
17
+ from gpuhunt import CPUArchitecture
16
18
 
17
19
  from dstack._internal import settings
18
20
  from dstack._internal.core.backends.base.offers import filter_offers_by_requirements
@@ -65,7 +67,21 @@ NVIDIA_GPUS_REQUIRING_PROPRIETARY_KERNEL_MODULES = frozenset(
65
67
  ]
66
68
  )
67
69
 
68
- GoArchType = Literal["amd64", "arm64"]
70
+
71
+ class GoArchType(str, Enum):
72
+ """
73
+ A subset of GOARCH values
74
+ """
75
+
76
+ AMD64 = "amd64"
77
+ ARM64 = "arm64"
78
+
79
+ def to_cpu_architecture(self) -> CPUArchitecture:
80
+ if self == self.AMD64:
81
+ return CPUArchitecture.X86
82
+ if self == self.ARM64:
83
+ return CPUArchitecture.ARM
84
+ assert False, self
69
85
 
70
86
 
71
87
  class Compute(ABC):
@@ -304,6 +320,15 @@ class ComputeWithCreateInstanceSupport(ABC):
304
320
  ]
305
321
 
306
322
 
323
+ class ComputeWithPrivilegedSupport:
324
+ """
325
+ Must be subclassed to support runs with `privileged: true`.
326
+ All VM-based Computes (that is, Computes that use the shim) should subclass this mixin.
327
+ """
328
+
329
+ pass
330
+
331
+
307
332
  class ComputeWithMultinodeSupport:
308
333
  """
309
334
  Must be subclassed to support multinode tasks and cluster fleets.
@@ -704,14 +729,14 @@ def normalize_arch(arch: Optional[str] = None) -> GoArchType:
704
729
  If the arch is not specified, falls back to `amd64`.
705
730
  """
706
731
  if not arch:
707
- return "amd64"
732
+ return GoArchType.AMD64
708
733
  arch_lower = arch.lower()
709
734
  if "32" in arch_lower or arch_lower in ["i386", "i686"]:
710
735
  raise ValueError(f"32-bit architectures are not supported: {arch}")
711
736
  if arch_lower.startswith("x86") or arch_lower.startswith("amd"):
712
- return "amd64"
737
+ return GoArchType.AMD64
713
738
  if arch_lower.startswith("arm") or arch_lower.startswith("aarch"):
714
- return "arm64"
739
+ return GoArchType.ARM64
715
740
  raise ValueError(f"Unsupported architecture: {arch}")
716
741
 
717
742
 
@@ -727,8 +752,7 @@ def get_dstack_runner_download_url(arch: Optional[str] = None) -> str:
727
752
  "/{version}/binaries/dstack-runner-linux-{arch}"
728
753
  )
729
754
  version = get_dstack_runner_version()
730
- arch = normalize_arch(arch)
731
- return url_template.format(version=version, arch=arch)
755
+ return url_template.format(version=version, arch=normalize_arch(arch).value)
732
756
 
733
757
 
734
758
  def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
@@ -743,8 +767,7 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
743
767
  "/{version}/binaries/dstack-shim-linux-{arch}"
744
768
  )
745
769
  version = get_dstack_runner_version()
746
- arch = normalize_arch(arch)
747
- return url_template.format(version=version, arch=arch)
770
+ return url_template.format(version=version, arch=normalize_arch(arch).value)
748
771
 
749
772
 
750
773
  def get_setup_cloud_instance_commands(
@@ -22,6 +22,7 @@ from dstack._internal.utils.common import get_or_error
22
22
  SUPPORTED_GPUHUNT_FLAGS = [
23
23
  "oci-spot",
24
24
  "lambda-arm",
25
+ "gcp-a4",
25
26
  ]
26
27
 
27
28
 
@@ -4,6 +4,7 @@ from dstack._internal.core.backends.base.compute import (
4
4
  Compute,
5
5
  ComputeWithAllOffersCached,
6
6
  ComputeWithCreateInstanceSupport,
7
+ ComputeWithPrivilegedSupport,
7
8
  get_shim_commands,
8
9
  )
9
10
  from dstack._internal.core.backends.base.offers import get_catalog_offers
@@ -27,6 +28,7 @@ logger = get_logger(__name__)
27
28
  class CloudRiftCompute(
28
29
  ComputeWithAllOffersCached,
29
30
  ComputeWithCreateInstanceSupport,
31
+ ComputeWithPrivilegedSupport,
30
32
  Compute,
31
33
  ):
32
34
  def __init__(self, config: CloudRiftConfig):
@@ -6,6 +6,7 @@ from dstack._internal.core.backends.base.backend import Compute
6
6
  from dstack._internal.core.backends.base.compute import (
7
7
  ComputeWithCreateInstanceSupport,
8
8
  ComputeWithFilteredOffersCached,
9
+ ComputeWithPrivilegedSupport,
9
10
  generate_unique_instance_name,
10
11
  get_shim_commands,
11
12
  )
@@ -32,6 +33,7 @@ MAX_RESOURCE_NAME_LEN = 30
32
33
  class CudoCompute(
33
34
  ComputeWithFilteredOffersCached,
34
35
  ComputeWithCreateInstanceSupport,
36
+ ComputeWithPrivilegedSupport,
35
37
  Compute,
36
38
  ):
37
39
  def __init__(self, config: CudoConfig):
@@ -8,6 +8,7 @@ from dstack._internal.core.backends.base.backend import Compute
8
8
  from dstack._internal.core.backends.base.compute import (
9
9
  ComputeWithAllOffersCached,
10
10
  ComputeWithCreateInstanceSupport,
11
+ ComputeWithPrivilegedSupport,
11
12
  generate_unique_instance_name,
12
13
  get_shim_commands,
13
14
  )
@@ -39,6 +40,7 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=IMAGE_SIZE, max=None)
39
40
  class DataCrunchCompute(
40
41
  ComputeWithAllOffersCached,
41
42
  ComputeWithCreateInstanceSupport,
43
+ ComputeWithPrivilegedSupport,
42
44
  Compute,
43
45
  ):
44
46
  def __init__(self, config: DataCrunchConfig):
@@ -7,6 +7,7 @@ from dstack._internal.core.backends.base.backend import Compute
7
7
  from dstack._internal.core.backends.base.compute import (
8
8
  ComputeWithAllOffersCached,
9
9
  ComputeWithCreateInstanceSupport,
10
+ ComputeWithPrivilegedSupport,
10
11
  generate_unique_instance_name,
11
12
  get_user_data,
12
13
  )
@@ -40,6 +41,7 @@ DOCKER_INSTALL_COMMANDS = [
40
41
  class BaseDigitalOceanCompute(
41
42
  ComputeWithAllOffersCached,
42
43
  ComputeWithCreateInstanceSupport,
44
+ ComputeWithPrivilegedSupport,
43
45
  Compute,
44
46
  ):
45
47
  def __init__(self, config: BaseDigitalOceanConfig, api_url: str, type: BackendType):
@@ -4,6 +4,7 @@ from dstack._internal.core.backends.base.compute import (
4
4
  ComputeWithMultinodeSupport,
5
5
  ComputeWithPlacementGroupSupport,
6
6
  ComputeWithPrivateGatewaySupport,
7
+ ComputeWithPrivilegedSupport,
7
8
  ComputeWithReservationSupport,
8
9
  ComputeWithVolumeSupport,
9
10
  )
@@ -38,6 +39,10 @@ BACKENDS_WITH_CREATE_INSTANCE_SUPPORT = _get_backends_with_compute_feature(
38
39
  configurator_classes=_configurator_classes,
39
40
  compute_feature_class=ComputeWithCreateInstanceSupport,
40
41
  )
42
+ BACKENDS_WITH_PRIVILEGED_SUPPORT = _get_backends_with_compute_feature(
43
+ configurator_classes=_configurator_classes,
44
+ compute_feature_class=ComputeWithPrivilegedSupport,
45
+ )
41
46
  BACKENDS_WITH_MULTINODE_SUPPORT = [BackendType.REMOTE] + _get_backends_with_compute_feature(
42
47
  configurator_classes=_configurator_classes,
43
48
  compute_feature_class=ComputeWithMultinodeSupport,
@@ -23,6 +23,7 @@ from dstack._internal.core.backends.base.compute import (
23
23
  ComputeWithMultinodeSupport,
24
24
  ComputeWithPlacementGroupSupport,
25
25
  ComputeWithPrivateGatewaySupport,
26
+ ComputeWithPrivilegedSupport,
26
27
  ComputeWithVolumeSupport,
27
28
  generate_unique_gateway_instance_name,
28
29
  generate_unique_instance_name,
@@ -90,6 +91,7 @@ class GCPVolumeDiskBackendData(CoreModel):
90
91
  class GCPCompute(
91
92
  ComputeWithAllOffersCached,
92
93
  ComputeWithCreateInstanceSupport,
94
+ ComputeWithPrivilegedSupport,
93
95
  ComputeWithMultinodeSupport,
94
96
  ComputeWithPlacementGroupSupport,
95
97
  ComputeWithGatewaySupport,
@@ -111,8 +113,8 @@ class GCPCompute(
111
113
  self.resource_policies_client = compute_v1.ResourcePoliciesClient(
112
114
  credentials=self.credentials
113
115
  )
114
- self._extra_subnets_cache_lock = threading.Lock()
115
- self._extra_subnets_cache = TTLCache(maxsize=30, ttl=60)
116
+ self._usable_subnets_cache_lock = threading.Lock()
117
+ self._usable_subnets_cache = TTLCache(maxsize=1, ttl=120)
116
118
 
117
119
  def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
118
120
  regions = get_or_error(self.config.regions)
@@ -203,12 +205,12 @@ class GCPCompute(
203
205
  disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
204
206
  # Choose any usable subnet in a VPC.
205
207
  # Configuring a specific subnet per region is not supported yet.
206
- subnetwork = _get_vpc_subnet(
207
- subnetworks_client=self.subnetworks_client,
208
- config=self.config,
208
+ subnetwork = self._get_vpc_subnet(instance_offer.region)
209
+ extra_subnets = self._get_extra_subnets(
209
210
  region=instance_offer.region,
211
+ instance_type_name=instance_offer.instance.name,
210
212
  )
211
- extra_subnets = self._get_extra_subnets(
213
+ roce_subnets = self._get_roce_subnets(
212
214
  region=instance_offer.region,
213
215
  instance_type_name=instance_offer.instance.name,
214
216
  )
@@ -330,6 +332,7 @@ class GCPCompute(
330
332
  network=self.config.vpc_resource_name,
331
333
  subnetwork=subnetwork,
332
334
  extra_subnetworks=extra_subnets,
335
+ roce_subnetworks=roce_subnets,
333
336
  allocate_public_ip=allocate_public_ip,
334
337
  placement_policy=placement_policy,
335
338
  )
@@ -339,6 +342,13 @@ class GCPCompute(
339
342
  # If the request succeeds, we'll probably timeout and update_provisioning_data() will get hostname.
340
343
  operation = self.instances_client.insert(request=request)
341
344
  gcp_resources.wait_for_extended_operation(operation, timeout=30)
345
+ except google.api_core.exceptions.BadRequest as e:
346
+ if "Network profile only allows resource creation in location" in e.message:
347
+ # A hack to find the correct RoCE VPC zone by trial and error.
348
+ # Could be better to find it via the API.
349
+ logger.debug("Got GCP error when provisioning a VM: %s", e)
350
+ continue
351
+ raise
342
352
  except (
343
353
  google.api_core.exceptions.ServiceUnavailable,
344
354
  google.api_core.exceptions.NotFound,
@@ -487,11 +497,7 @@ class GCPCompute(
487
497
  )
488
498
  # Choose any usable subnet in a VPC.
489
499
  # Configuring a specific subnet per region is not supported yet.
490
- subnetwork = _get_vpc_subnet(
491
- subnetworks_client=self.subnetworks_client,
492
- config=self.config,
493
- region=configuration.region,
494
- )
500
+ subnetwork = self._get_vpc_subnet(configuration.region)
495
501
 
496
502
  labels = {
497
503
  "owner": "dstack",
@@ -793,10 +799,6 @@ class GCPCompute(
793
799
  instance_id,
794
800
  )
795
801
 
796
- @cachedmethod(
797
- cache=lambda self: self._extra_subnets_cache,
798
- lock=lambda self: self._extra_subnets_cache_lock,
799
- )
800
802
  def _get_extra_subnets(
801
803
  self,
802
804
  region: str,
@@ -808,15 +810,16 @@ class GCPCompute(
808
810
  subnets_num = 8
809
811
  elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
810
812
  subnets_num = 4
813
+ elif instance_type_name == "a4-highgpu-8g":
814
+ subnets_num = 1 # 1 main + 1 extra + 8 RoCE
811
815
  else:
812
816
  return []
813
817
  extra_subnets = []
814
818
  for vpc_name in self.config.extra_vpcs[:subnets_num]:
815
819
  subnet = gcp_resources.get_vpc_subnet_or_error(
816
- subnetworks_client=self.subnetworks_client,
817
- vpc_project_id=self.config.vpc_project_id or self.config.project_id,
818
820
  vpc_name=vpc_name,
819
821
  region=region,
822
+ usable_subnets=self._list_usable_subnets(),
820
823
  )
821
824
  vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
822
825
  project_id=self.config.vpc_project_id or self.config.project_id,
@@ -825,6 +828,58 @@ class GCPCompute(
825
828
  extra_subnets.append((vpc_resource_name, subnet))
826
829
  return extra_subnets
827
830
 
831
+ def _get_roce_subnets(
832
+ self,
833
+ region: str,
834
+ instance_type_name: str,
835
+ ) -> List[Tuple[str, str]]:
836
+ if not self.config.roce_vpcs:
837
+ return []
838
+ if instance_type_name == "a4-highgpu-8g":
839
+ nics_num = 8
840
+ else:
841
+ return []
842
+ roce_vpc = self.config.roce_vpcs[0] # roce_vpcs is validated to have at most 1 item
843
+ subnets = gcp_resources.get_vpc_subnets(
844
+ vpc_name=roce_vpc,
845
+ region=region,
846
+ usable_subnets=self._list_usable_subnets(),
847
+ )
848
+ if len(subnets) < nics_num:
849
+ raise ComputeError(
850
+ f"{instance_type_name} requires {nics_num} RoCE subnets,"
851
+ f" but only {len(subnets)} are available in VPC {roce_vpc}"
852
+ )
853
+ vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
854
+ project_id=self.config.vpc_project_id or self.config.project_id,
855
+ vpc_name=roce_vpc,
856
+ )
857
+ nic_subnets = []
858
+ for subnet in subnets[:nics_num]:
859
+ nic_subnets.append((vpc_resource_name, subnet))
860
+ return nic_subnets
861
+
862
+ @cachedmethod(
863
+ cache=lambda self: self._usable_subnets_cache,
864
+ lock=lambda self: self._usable_subnets_cache_lock,
865
+ )
866
+ def _list_usable_subnets(self) -> list[compute_v1.UsableSubnetwork]:
867
+ # To avoid hitting the `ListUsable requests per minute` system limit, we fetch all subnets
868
+ # at once and cache them
869
+ return gcp_resources.list_project_usable_subnets(
870
+ subnetworks_client=self.subnetworks_client,
871
+ project_id=self.config.vpc_project_id or self.config.project_id,
872
+ )
873
+
874
+ def _get_vpc_subnet(self, region: str) -> Optional[str]:
875
+ if self.config.vpc_name is None:
876
+ return None
877
+ return gcp_resources.get_vpc_subnet_or_error(
878
+ vpc_name=self.config.vpc_name,
879
+ region=region,
880
+ usable_subnets=self._list_usable_subnets(),
881
+ )
882
+
828
883
 
829
884
  def _supported_instances_and_zones(
830
885
  regions: List[str],
@@ -867,8 +922,8 @@ def _has_gpu_quota(quotas: Dict[str, float], resources: Resources) -> bool:
867
922
  gpu = resources.gpus[0]
868
923
  if _is_tpu(gpu.name):
869
924
  return True
870
- if gpu.name == "H100":
871
- # H100 and H100_MEGA quotas are not returned by `regions_client.list`
925
+ if gpu.name in ["B200", "H100"]:
926
+ # B200, H100 and H100_MEGA quotas are not returned by `regions_client.list`
872
927
  return True
873
928
  quota_name = f"NVIDIA_{gpu.name}_GPUS"
874
929
  if gpu.name == "A100" and gpu.memory_mib == 80 * 1024:
@@ -889,21 +944,6 @@ def _unique_instance_name(instance: InstanceType) -> str:
889
944
  return f"{name}-{gpu.name}-{gpu.memory_mib}"
890
945
 
891
946
 
892
- def _get_vpc_subnet(
893
- subnetworks_client: compute_v1.SubnetworksClient,
894
- config: GCPConfig,
895
- region: str,
896
- ) -> Optional[str]:
897
- if config.vpc_name is None:
898
- return None
899
- return gcp_resources.get_vpc_subnet_or_error(
900
- subnetworks_client=subnetworks_client,
901
- vpc_project_id=config.vpc_project_id or config.project_id,
902
- vpc_name=config.vpc_name,
903
- region=region,
904
- )
905
-
906
-
907
947
  @dataclass
908
948
  class GCPImage:
909
949
  id: str
@@ -202,5 +202,5 @@ class GCPConfigurator(
202
202
  )
203
203
  except BackendError as e:
204
204
  raise ServerClientError(e.args[0])
205
- # Not checking config.extra_vpc so that users are not required to configure subnets for all regions
205
+ # Not checking config.extra_vpcs and config.roce_vpcs so that users are not required to configure subnets for all regions
206
206
  # but only for regions they intend to use. Validation will be done on provisioning.
@@ -41,11 +41,24 @@ class GCPBackendConfig(CoreModel):
41
41
  Optional[List[str]],
42
42
  Field(
43
43
  description=(
44
- "The names of additional VPCs used for GPUDirect. Specify eight VPCs to maximize bandwidth."
44
+ "The names of additional VPCs used for multi-NIC instances, such as those that support GPUDirect."
45
+ " Specify eight VPCs to maximize bandwidth in clusters with eight-GPU instances."
45
46
  " Each VPC must have a subnet and a firewall rule allowing internal traffic across all subnets"
46
47
  )
47
48
  ),
48
49
  ] = None
50
+ roce_vpcs: Annotated[
51
+ Optional[List[str]],
52
+ Field(
53
+ description=(
54
+ "The names of additional VPCs with the RoCE network profile."
55
+ " Used for RDMA on GPU instances that support the MRDMA interface type."
56
+ " A VPC should have eight subnets to maximize the bandwidth in clusters"
57
+ " with eight-GPU instances."
58
+ ),
59
+ max_items=1, # The currently supported instance types only need one VPC with eight subnets.
60
+ ),
61
+ ] = None
49
62
  vpc_project_id: Annotated[
50
63
  Optional[str],
51
64
  Field(description="The shared VPC hosted project ID. Required for shared VPC only"),