dstack 0.19.6rc1__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. Click here for more details.

Files changed (69)
  1. dstack/_internal/cli/services/args.py +2 -2
  2. dstack/_internal/cli/services/configurators/fleet.py +3 -2
  3. dstack/_internal/cli/services/configurators/run.py +50 -4
  4. dstack/_internal/cli/utils/fleet.py +3 -1
  5. dstack/_internal/cli/utils/run.py +25 -28
  6. dstack/_internal/core/backends/aws/compute.py +13 -1
  7. dstack/_internal/core/backends/azure/compute.py +42 -13
  8. dstack/_internal/core/backends/azure/configurator.py +21 -0
  9. dstack/_internal/core/backends/azure/models.py +9 -0
  10. dstack/_internal/core/backends/base/compute.py +101 -27
  11. dstack/_internal/core/backends/base/offers.py +13 -3
  12. dstack/_internal/core/backends/cudo/compute.py +2 -0
  13. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  14. dstack/_internal/core/backends/gcp/auth.py +1 -1
  15. dstack/_internal/core/backends/gcp/compute.py +51 -35
  16. dstack/_internal/core/backends/gcp/resources.py +6 -1
  17. dstack/_internal/core/backends/lambdalabs/compute.py +20 -8
  18. dstack/_internal/core/backends/local/compute.py +2 -0
  19. dstack/_internal/core/backends/nebius/compute.py +95 -1
  20. dstack/_internal/core/backends/nebius/configurator.py +11 -0
  21. dstack/_internal/core/backends/nebius/fabrics.py +47 -0
  22. dstack/_internal/core/backends/nebius/models.py +8 -0
  23. dstack/_internal/core/backends/nebius/resources.py +29 -0
  24. dstack/_internal/core/backends/oci/compute.py +2 -0
  25. dstack/_internal/core/backends/remote/provisioning.py +27 -2
  26. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  27. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  28. dstack/_internal/core/backends/vastai/compute.py +2 -1
  29. dstack/_internal/core/backends/vultr/compute.py +5 -1
  30. dstack/_internal/core/errors.py +4 -0
  31. dstack/_internal/core/models/fleets.py +2 -0
  32. dstack/_internal/core/models/instances.py +4 -3
  33. dstack/_internal/core/models/resources.py +80 -3
  34. dstack/_internal/core/models/runs.py +10 -3
  35. dstack/_internal/core/models/volumes.py +1 -1
  36. dstack/_internal/server/background/tasks/process_fleets.py +4 -13
  37. dstack/_internal/server/background/tasks/process_instances.py +176 -55
  38. dstack/_internal/server/background/tasks/process_placement_groups.py +1 -1
  39. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +5 -2
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +1 -1
  41. dstack/_internal/server/models.py +1 -0
  42. dstack/_internal/server/routers/gateways.py +2 -1
  43. dstack/_internal/server/services/config.py +7 -2
  44. dstack/_internal/server/services/fleets.py +24 -26
  45. dstack/_internal/server/services/gateways/__init__.py +17 -2
  46. dstack/_internal/server/services/instances.py +0 -2
  47. dstack/_internal/server/services/offers.py +15 -0
  48. dstack/_internal/server/services/placement.py +27 -6
  49. dstack/_internal/server/services/plugins.py +77 -0
  50. dstack/_internal/server/services/resources.py +21 -0
  51. dstack/_internal/server/services/runs.py +41 -17
  52. dstack/_internal/server/services/volumes.py +10 -1
  53. dstack/_internal/server/testing/common.py +35 -26
  54. dstack/_internal/utils/common.py +22 -9
  55. dstack/_internal/utils/json_schema.py +6 -3
  56. dstack/api/__init__.py +1 -0
  57. dstack/api/server/__init__.py +8 -1
  58. dstack/api/server/_fleets.py +16 -0
  59. dstack/api/server/_runs.py +44 -3
  60. dstack/plugins/__init__.py +8 -0
  61. dstack/plugins/_base.py +72 -0
  62. dstack/plugins/_models.py +8 -0
  63. dstack/plugins/_utils.py +19 -0
  64. dstack/version.py +1 -1
  65. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/METADATA +14 -2
  66. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/RECORD +69 -62
  67. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/WHEEL +0 -0
  68. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/entry_points.txt +0 -0
  69. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/licenses/LICENSE.md +0 -0
@@ -6,7 +6,7 @@ import threading
6
6
  from abc import ABC, abstractmethod
7
7
  from functools import lru_cache
8
8
  from pathlib import Path
9
- from typing import Dict, List, Optional
9
+ from typing import Dict, List, Literal, Optional
10
10
 
11
11
  import git
12
12
  import requests
@@ -25,6 +25,7 @@ from dstack._internal.core.models.gateways import (
25
25
  )
26
26
  from dstack._internal.core.models.instances import (
27
27
  InstanceConfiguration,
28
+ InstanceOffer,
28
29
  InstanceOfferWithAvailability,
29
30
  SSHKey,
30
31
  )
@@ -44,6 +45,8 @@ logger = get_logger(__name__)
44
45
  DSTACK_SHIM_BINARY_NAME = "dstack-shim"
45
46
  DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
46
47
 
48
+ GoArchType = Literal["amd64", "arm64"]
49
+
47
50
 
48
51
  class Compute(ABC):
49
52
  """
@@ -144,6 +147,7 @@ class ComputeWithCreateInstanceSupport(ABC):
144
147
  self,
145
148
  instance_offer: InstanceOfferWithAvailability,
146
149
  instance_config: InstanceConfiguration,
150
+ placement_group: Optional[PlacementGroup],
147
151
  ) -> JobProvisioningData:
148
152
  """
149
153
  Launches a new instance. It should return `JobProvisioningData` ASAP.
@@ -176,7 +180,7 @@ class ComputeWithCreateInstanceSupport(ABC):
176
180
  )
177
181
  instance_offer = instance_offer.copy()
178
182
  self._restrict_instance_offer_az_to_volumes_az(instance_offer, volumes)
179
- return self.create_instance(instance_offer, instance_config)
183
+ return self.create_instance(instance_offer, instance_config, placement_group=None)
180
184
 
181
185
  def _restrict_instance_offer_az_to_volumes_az(
182
186
  self,
@@ -225,9 +229,15 @@ class ComputeWithPlacementGroupSupport(ABC):
225
229
  def create_placement_group(
226
230
  self,
227
231
  placement_group: PlacementGroup,
232
+ master_instance_offer: InstanceOffer,
228
233
  ) -> PlacementGroupProvisioningData:
229
234
  """
230
235
  Creates a placement group.
236
+
237
+ Args:
238
+ placement_group: details about the placement group to be created
239
+ master_instance_offer: the first instance dstack will attempt to add
240
+ to the placement group
231
241
  """
232
242
  pass
233
243
 
@@ -242,10 +252,27 @@ class ComputeWithPlacementGroupSupport(ABC):
242
252
  """
243
253
  pass
244
254
 
255
+ @abstractmethod
256
+ def is_suitable_placement_group(
257
+ self,
258
+ placement_group: PlacementGroup,
259
+ instance_offer: InstanceOffer,
260
+ ) -> bool:
261
+ """
262
+ Checks if the instance offer can be provisioned in the placement group.
263
+
264
+ Should return immediately, without performing API calls.
265
+
266
+ Can be called with an offer originating from a different backend, because some backends
267
+ (BackendType.DSTACK) produce offers on behalf of other backends. Should return `False`
268
+ in that case.
269
+ """
270
+ pass
271
+
245
272
 
246
273
  class ComputeWithGatewaySupport(ABC):
247
274
  """
248
- Must be subclassed and imlemented to support gateways.
275
+ Must be subclassed and implemented to support gateways.
249
276
  """
250
277
 
251
278
  @abstractmethod
@@ -418,6 +445,21 @@ def generate_unique_volume_name(
418
445
  )
419
446
 
420
447
 
448
+ def generate_unique_placement_group_name(
449
+ project_name: str,
450
+ fleet_name: str,
451
+ max_length: int = _DEFAULT_MAX_RESOURCE_NAME_LEN,
452
+ ) -> str:
453
+ """
454
+ Generates a unique placement group name valid across all backends.
455
+ """
456
+ return generate_unique_backend_name(
457
+ resource_name=fleet_name,
458
+ project_name=project_name,
459
+ max_length=max_length,
460
+ )
461
+
462
+
421
463
  def generate_unique_backend_name(
422
464
  resource_name: str,
423
465
  project_name: Optional[str],
@@ -483,13 +525,14 @@ def get_shim_env(
483
525
  base_path: Optional[PathLike] = None,
484
526
  bin_path: Optional[PathLike] = None,
485
527
  backend_shim_env: Optional[Dict[str, str]] = None,
528
+ arch: Optional[str] = None,
486
529
  ) -> Dict[str, str]:
487
530
  log_level = "6" # Trace
488
531
  envs = {
489
532
  "DSTACK_SHIM_HOME": get_dstack_working_dir(base_path),
490
533
  "DSTACK_SHIM_HTTP_PORT": str(DSTACK_SHIM_HTTP_PORT),
491
534
  "DSTACK_SHIM_LOG_LEVEL": log_level,
492
- "DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(),
535
+ "DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(arch),
493
536
  "DSTACK_RUNNER_BINARY_PATH": get_dstack_runner_binary_path(bin_path),
494
537
  "DSTACK_RUNNER_HTTP_PORT": str(DSTACK_RUNNER_HTTP_PORT),
495
538
  "DSTACK_RUNNER_SSH_PORT": str(DSTACK_RUNNER_SSH_PORT),
@@ -509,16 +552,19 @@ def get_shim_commands(
509
552
  base_path: Optional[PathLike] = None,
510
553
  bin_path: Optional[PathLike] = None,
511
554
  backend_shim_env: Optional[Dict[str, str]] = None,
555
+ arch: Optional[str] = None,
512
556
  ) -> List[str]:
513
557
  commands = get_shim_pre_start_commands(
514
558
  base_path=base_path,
515
559
  bin_path=bin_path,
560
+ arch=arch,
516
561
  )
517
562
  shim_env = get_shim_env(
518
563
  authorized_keys=authorized_keys,
519
564
  base_path=base_path,
520
565
  bin_path=bin_path,
521
566
  backend_shim_env=backend_shim_env,
567
+ arch=arch,
522
568
  )
523
569
  for k, v in shim_env.items():
524
570
  commands += [f'export "{k}={v}"']
@@ -539,35 +585,63 @@ def get_dstack_runner_version() -> str:
539
585
  return version or "latest"
540
586
 
541
587
 
542
- def get_dstack_runner_download_url() -> str:
543
- if url := os.environ.get("DSTACK_RUNNER_DOWNLOAD_URL"):
544
- return url
545
- build = get_dstack_runner_version()
546
- if settings.DSTACK_VERSION is not None:
547
- bucket = "dstack-runner-downloads"
548
- else:
549
- bucket = "dstack-runner-downloads-stgn"
550
- return (
551
- f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64"
552
- )
553
-
554
-
555
- def get_dstack_shim_download_url() -> str:
556
- if url := os.environ.get("DSTACK_SHIM_DOWNLOAD_URL"):
557
- return url
558
- build = get_dstack_runner_version()
559
- if settings.DSTACK_VERSION is not None:
560
- bucket = "dstack-runner-downloads"
561
- else:
562
- bucket = "dstack-runner-downloads-stgn"
563
- return f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64"
588
+ def normalize_arch(arch: Optional[str] = None) -> GoArchType:
589
+ """
590
+ Converts the given free-form architecture string to the Go GOARCH format.
591
+ Only 64-bit x86 and ARM are supported. If the word size is not specified (e.g., `x86`, `arm`),
592
+ 64-bit is implied.
593
+ If the arch is not specified, falls back to `amd64`.
594
+ """
595
+ if not arch:
596
+ return "amd64"
597
+ arch_lower = arch.lower()
598
+ if "32" in arch_lower or arch_lower in ["i386", "i686"]:
599
+ raise ValueError(f"32-bit architectures are not supported: {arch}")
600
+ if arch_lower.startswith("x86") or arch_lower.startswith("amd"):
601
+ return "amd64"
602
+ if arch_lower.startswith("arm") or arch_lower.startswith("aarch"):
603
+ return "arm64"
604
+ raise ValueError(f"Unsupported architecture: {arch}")
605
+
606
+
607
+ def get_dstack_runner_download_url(arch: Optional[str] = None) -> str:
608
+ url_template = os.environ.get("DSTACK_RUNNER_DOWNLOAD_URL")
609
+ if not url_template:
610
+ if settings.DSTACK_VERSION is not None:
611
+ bucket = "dstack-runner-downloads"
612
+ else:
613
+ bucket = "dstack-runner-downloads-stgn"
614
+ url_template = (
615
+ f"https://{bucket}.s3.eu-west-1.amazonaws.com"
616
+ "/{version}/binaries/dstack-runner-linux-{arch}"
617
+ )
618
+ version = get_dstack_runner_version()
619
+ arch = normalize_arch(arch)
620
+ return url_template.format(version=version, arch=arch)
621
+
622
+
623
+ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
624
+ url_template = os.environ.get("DSTACK_SHIM_DOWNLOAD_URL")
625
+ if not url_template:
626
+ if settings.DSTACK_VERSION is not None:
627
+ bucket = "dstack-runner-downloads"
628
+ else:
629
+ bucket = "dstack-runner-downloads-stgn"
630
+ url_template = (
631
+ f"https://{bucket}.s3.eu-west-1.amazonaws.com"
632
+ "/{version}/binaries/dstack-shim-linux-{arch}"
633
+ )
634
+ version = get_dstack_runner_version()
635
+ arch = normalize_arch(arch)
636
+ return url_template.format(version=version, arch=arch)
564
637
 
565
638
 
566
639
  def get_shim_pre_start_commands(
567
640
  base_path: Optional[PathLike] = None,
568
641
  bin_path: Optional[PathLike] = None,
642
+ arch: Optional[str] = None,
569
643
  ) -> List[str]:
570
- url = get_dstack_shim_download_url()
644
+ url = get_dstack_shim_download_url(arch)
571
645
  dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path)
572
646
  dstack_working_dir = get_dstack_working_dir(base_path)
573
647
  return [
@@ -2,6 +2,7 @@ from dataclasses import asdict
2
2
  from typing import Callable, List, Optional
3
3
 
4
4
  import gpuhunt
5
+ from pydantic import parse_obj_as
5
6
 
6
7
  from dstack._internal.core.models.backends.base import BackendType
7
8
  from dstack._internal.core.models.instances import (
@@ -11,13 +12,14 @@ from dstack._internal.core.models.instances import (
11
12
  InstanceType,
12
13
  Resources,
13
14
  )
14
- from dstack._internal.core.models.resources import DEFAULT_DISK, Memory, Range
15
+ from dstack._internal.core.models.resources import DEFAULT_DISK, CPUSpec, Memory, Range
15
16
  from dstack._internal.core.models.runs import Requirements
16
17
 
17
18
  # Offers not supported by all dstack versions are hidden behind one or more flags.
18
19
  # This list enables the flags that are currently supported.
19
20
  SUPPORTED_GPUHUNT_FLAGS = [
20
21
  "oci-spot",
22
+ "lambda-arm",
21
23
  ]
22
24
 
23
25
 
@@ -71,6 +73,7 @@ def catalog_item_to_offer(
71
73
  if disk_size_mib is None:
72
74
  return None
73
75
  resources = Resources(
76
+ cpu_arch=item.cpu_arch,
74
77
  cpus=item.cpu,
75
78
  memory_mib=round(item.memory * 1024),
76
79
  gpus=gpus,
@@ -90,6 +93,9 @@ def catalog_item_to_offer(
90
93
 
91
94
 
92
95
  def offer_to_catalog_item(offer: InstanceOffer) -> gpuhunt.CatalogItem:
96
+ cpu_arch = offer.instance.resources.cpu_arch
97
+ if cpu_arch is None:
98
+ cpu_arch = gpuhunt.CPUArchitecture.X86
93
99
  gpu_count = len(offer.instance.resources.gpus)
94
100
  gpu_vendor = None
95
101
  gpu_name = None
@@ -104,6 +110,7 @@ def offer_to_catalog_item(offer: InstanceOffer) -> gpuhunt.CatalogItem:
104
110
  instance_name=offer.instance.name,
105
111
  location=offer.region,
106
112
  price=offer.price,
113
+ cpu_arch=cpu_arch,
107
114
  cpu=offer.instance.resources.cpus,
108
115
  memory=offer.instance.resources.memory_mib / 1024,
109
116
  gpu_count=gpu_count,
@@ -125,8 +132,11 @@ def requirements_to_query_filter(req: Optional[Requirements]) -> gpuhunt.QueryFi
125
132
 
126
133
  res = req.resources
127
134
  if res.cpu:
128
- q.min_cpu = res.cpu.min
129
- q.max_cpu = res.cpu.max
135
+ # TODO: Remove in 0.20. Use res.cpu directly
136
+ cpu = parse_obj_as(CPUSpec, res.cpu)
137
+ q.cpu_arch = cpu.arch
138
+ q.min_cpu = cpu.count.min
139
+ q.max_cpu = cpu.count.max
130
140
  if res.memory:
131
141
  q.min_memory = res.memory.min
132
142
  q.max_memory = res.memory.max
@@ -18,6 +18,7 @@ from dstack._internal.core.models.instances import (
18
18
  InstanceConfiguration,
19
19
  InstanceOfferWithAvailability,
20
20
  )
21
+ from dstack._internal.core.models.placement import PlacementGroup
21
22
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
22
23
  from dstack._internal.utils.logging import get_logger
23
24
 
@@ -58,6 +59,7 @@ class CudoCompute(
58
59
  self,
59
60
  instance_offer: InstanceOfferWithAvailability,
60
61
  instance_config: InstanceConfiguration,
62
+ placement_group: Optional[PlacementGroup],
61
63
  ) -> JobProvisioningData:
62
64
  vm_id = generate_unique_instance_name(instance_config, max_length=MAX_RESOURCE_NAME_LEN)
63
65
  public_keys = instance_config.get_public_keys()
@@ -20,6 +20,7 @@ from dstack._internal.core.models.instances import (
20
20
  InstanceOffer,
21
21
  InstanceOfferWithAvailability,
22
22
  )
23
+ from dstack._internal.core.models.placement import PlacementGroup
23
24
  from dstack._internal.core.models.resources import Memory, Range
24
25
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
25
26
  from dstack._internal.utils.logging import get_logger
@@ -85,6 +86,7 @@ class DataCrunchCompute(
85
86
  self,
86
87
  instance_offer: InstanceOfferWithAvailability,
87
88
  instance_config: InstanceConfiguration,
89
+ placement_group: Optional[PlacementGroup],
88
90
  ) -> JobProvisioningData:
89
91
  instance_name = generate_unique_instance_name(
90
92
  instance_config, max_length=MAX_INSTANCE_NAME_LEN
@@ -19,7 +19,7 @@ def authenticate(creds: AnyGCPCreds, project_id: Optional[str] = None) -> Tuple[
19
19
  credentials, credentials_project_id = get_credentials(creds)
20
20
  if project_id is None:
21
21
  # If project_id is not specified explicitly, try using credentials' project_id.
22
- # Explicit project_id takes precedence bacause credentials' project_id may be irrelevant.
22
+ # Explicit project_id takes precedence because credentials' project_id may be irrelevant.
23
23
  # For example, with Workload Identity Federation for GKE, it's cluster project_id.
24
24
  project_id = credentials_project_id
25
25
  if project_id is None:
@@ -1,10 +1,12 @@
1
1
  import concurrent.futures
2
2
  import json
3
+ import threading
3
4
  from collections import defaultdict
4
5
  from typing import Callable, Dict, List, Literal, Optional, Tuple
5
6
 
6
7
  import google.api_core.exceptions
7
8
  import google.cloud.compute_v1 as compute_v1
9
+ from cachetools import TTLCache, cachedmethod
8
10
  from google.cloud import tpu_v2
9
11
  from gpuhunt import KNOWN_TPUS
10
12
 
@@ -98,6 +100,8 @@ class GCPCompute(
98
100
  self.resource_policies_client = compute_v1.ResourcePoliciesClient(
99
101
  credentials=self.credentials
100
102
  )
103
+ self._extra_subnets_cache_lock = threading.Lock()
104
+ self._extra_subnets_cache = TTLCache(maxsize=30, ttl=60)
101
105
 
102
106
  def get_offers(
103
107
  self, requirements: Optional[Requirements] = None
@@ -166,6 +170,7 @@ class GCPCompute(
166
170
  self,
167
171
  instance_offer: InstanceOfferWithAvailability,
168
172
  instance_config: InstanceConfiguration,
173
+ placement_group: Optional[PlacementGroup],
169
174
  ) -> JobProvisioningData:
170
175
  instance_name = generate_unique_instance_name(
171
176
  instance_config, max_length=gcp_resources.MAX_RESOURCE_NAME_LEN
@@ -192,18 +197,16 @@ class GCPCompute(
192
197
  config=self.config,
193
198
  region=instance_offer.region,
194
199
  )
195
- extra_subnets = _get_extra_subnets(
196
- subnetworks_client=self.subnetworks_client,
197
- config=self.config,
200
+ extra_subnets = self._get_extra_subnets(
198
201
  region=instance_offer.region,
199
202
  instance_type_name=instance_offer.instance.name,
200
203
  )
201
204
  placement_policy = None
202
- if instance_config.placement_group_name is not None:
205
+ if placement_group is not None:
203
206
  placement_policy = gcp_resources.get_placement_policy_resource_name(
204
207
  project_id=self.config.project_id,
205
208
  region=instance_offer.region,
206
- placement_policy=instance_config.placement_group_name,
209
+ placement_policy=placement_group.name,
207
210
  )
208
211
  labels = {
209
212
  "owner": "dstack",
@@ -406,6 +409,7 @@ class GCPCompute(
406
409
  def create_placement_group(
407
410
  self,
408
411
  placement_group: PlacementGroup,
412
+ master_instance_offer: InstanceOffer,
409
413
  ) -> PlacementGroupProvisioningData:
410
414
  policy = compute_v1.ResourcePolicy(
411
415
  name=placement_group.name,
@@ -440,6 +444,16 @@ class GCPCompute(
440
444
  raise PlacementGroupInUseError()
441
445
  raise
442
446
 
447
+ def is_suitable_placement_group(
448
+ self,
449
+ placement_group: PlacementGroup,
450
+ instance_offer: InstanceOffer,
451
+ ) -> bool:
452
+ return (
453
+ placement_group.configuration.backend == BackendType.GCP
454
+ and placement_group.configuration.region == instance_offer.region
455
+ )
456
+
443
457
  def create_gateway(
444
458
  self,
445
459
  configuration: GatewayComputeConfiguration,
@@ -757,6 +771,38 @@ class GCPCompute(
757
771
  instance_id,
758
772
  )
759
773
 
774
+ @cachedmethod(
775
+ cache=lambda self: self._extra_subnets_cache,
776
+ lock=lambda self: self._extra_subnets_cache_lock,
777
+ )
778
+ def _get_extra_subnets(
779
+ self,
780
+ region: str,
781
+ instance_type_name: str,
782
+ ) -> List[Tuple[str, str]]:
783
+ if self.config.extra_vpcs is None:
784
+ return []
785
+ if instance_type_name == "a3-megagpu-8g":
786
+ subnets_num = 8
787
+ elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
788
+ subnets_num = 4
789
+ else:
790
+ return []
791
+ extra_subnets = []
792
+ for vpc_name in self.config.extra_vpcs[:subnets_num]:
793
+ subnet = gcp_resources.get_vpc_subnet_or_error(
794
+ subnetworks_client=self.subnetworks_client,
795
+ vpc_project_id=self.config.vpc_project_id or self.config.project_id,
796
+ vpc_name=vpc_name,
797
+ region=region,
798
+ )
799
+ vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
800
+ project_id=self.config.vpc_project_id or self.config.project_id,
801
+ vpc_name=vpc_name,
802
+ )
803
+ extra_subnets.append((vpc_resource_name, subnet))
804
+ return extra_subnets
805
+
760
806
 
761
807
  def _supported_instances_and_zones(
762
808
  regions: List[str],
@@ -831,36 +877,6 @@ def _get_vpc_subnet(
831
877
  )
832
878
 
833
879
 
834
- def _get_extra_subnets(
835
- subnetworks_client: compute_v1.SubnetworksClient,
836
- config: GCPConfig,
837
- region: str,
838
- instance_type_name: str,
839
- ) -> List[Tuple[str, str]]:
840
- if config.extra_vpcs is None:
841
- return []
842
- if instance_type_name == "a3-megagpu-8g":
843
- subnets_num = 8
844
- elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
845
- subnets_num = 4
846
- else:
847
- return []
848
- extra_subnets = []
849
- for vpc_name in config.extra_vpcs[:subnets_num]:
850
- subnet = gcp_resources.get_vpc_subnet_or_error(
851
- subnetworks_client=subnetworks_client,
852
- vpc_project_id=config.vpc_project_id or config.project_id,
853
- vpc_name=vpc_name,
854
- region=region,
855
- )
856
- vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
857
- project_id=config.vpc_project_id or config.project_id,
858
- vpc_name=vpc_name,
859
- )
860
- extra_subnets.append((vpc_resource_name, subnet))
861
- return extra_subnets
862
-
863
-
864
880
  def _get_image_id(instance_type_name: str, cuda: bool) -> str:
865
881
  if instance_type_name == "a3-megagpu-8g":
866
882
  image_name = "dstack-a3mega-5"
@@ -205,12 +205,17 @@ def _get_network_interfaces(
205
205
  else:
206
206
  network_interface.access_configs = []
207
207
 
208
+ if extra_subnetworks:
209
+ # Multiple interfaces are set only for GPU VM that require gVNIC for best performance
210
+ network_interface.nic_type = compute_v1.NetworkInterface.NicType.GVNIC.name
211
+
208
212
  network_interfaces = [network_interface]
209
213
  for network, subnetwork in extra_subnetworks or []:
210
214
  network_interfaces.append(
211
215
  compute_v1.NetworkInterface(
212
216
  network=network,
213
217
  subnetwork=subnetwork,
218
+ nic_type=compute_v1.NetworkInterface.NicType.GVNIC.name,
214
219
  )
215
220
  )
216
221
  return network_interfaces
@@ -437,7 +442,7 @@ def wait_for_operation(operation: Operation, verbose_name: str = "operation", ti
437
442
  raise
438
443
  except Exception as e:
439
444
  # Write only debug logs here.
440
- # The unexpected errors will be propagated and logged appropriatly by the caller.
445
+ # The unexpected errors will be propagated and logged appropriately by the caller.
441
446
  logger.debug("Error during %s: %s", verbose_name, e)
442
447
  raise operation.exception() or e
443
448
  return result
@@ -20,6 +20,7 @@ from dstack._internal.core.models.instances import (
20
20
  InstanceOffer,
21
21
  InstanceOfferWithAvailability,
22
22
  )
23
+ from dstack._internal.core.models.placement import PlacementGroup
23
24
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
24
25
 
25
26
  MAX_INSTANCE_NAME_LEN = 60
@@ -46,7 +47,10 @@ class LambdaCompute(
46
47
  return offers_with_availability
47
48
 
48
49
  def create_instance(
49
- self, instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration
50
+ self,
51
+ instance_offer: InstanceOfferWithAvailability,
52
+ instance_config: InstanceConfiguration,
53
+ placement_group: Optional[PlacementGroup],
50
54
  ) -> JobProvisioningData:
51
55
  instance_name = generate_unique_instance_name(
52
56
  instance_config, max_length=MAX_INSTANCE_NAME_LEN
@@ -89,7 +93,10 @@ class LambdaCompute(
89
93
  instance_info = _get_instance_info(self.api_client, provisioning_data.instance_id)
90
94
  if instance_info is not None and instance_info["status"] != "booting":
91
95
  provisioning_data.hostname = instance_info["ip"]
92
- commands = get_shim_commands(authorized_keys=[project_ssh_public_key])
96
+ commands = get_shim_commands(
97
+ authorized_keys=[project_ssh_public_key],
98
+ arch=provisioning_data.instance_type.resources.cpu_arch,
99
+ )
93
100
  # shim is assumed to be run under root
94
101
  launch_command = "sudo sh -c '" + "&& ".join(commands) + "'"
95
102
  thread = Thread(
@@ -179,13 +186,18 @@ def _setup_instance(
179
186
  ssh_private_key: str,
180
187
  ):
181
188
  setup_commands = (
182
- "mkdir /home/ubuntu/.dstack && "
183
- "sudo apt-get update && "
184
- "sudo apt-get install -y --no-install-recommends nvidia-container-toolkit && "
185
- "sudo nvidia-ctk runtime configure --runtime=docker && "
186
- "sudo pkill -SIGHUP dockerd"
189
+ "mkdir /home/ubuntu/.dstack",
190
+ "sudo apt-get update",
191
+ "sudo apt-get install -y --no-install-recommends nvidia-container-toolkit",
192
+ "sudo install -d -m 0755 /etc/docker",
193
+ # Workaround for https://github.com/NVIDIA/nvidia-container-toolkit/issues/48
194
+ """echo '{"exec-opts":["native.cgroupdriver=cgroupfs"]}' | sudo tee /etc/docker/daemon.json""",
195
+ "sudo nvidia-ctk runtime configure --runtime=docker",
196
+ "sudo systemctl restart docker.service", # `systemctl reload` (`kill -HUP`) won't work
197
+ )
198
+ _run_ssh_command(
199
+ hostname=hostname, ssh_private_key=ssh_private_key, command=" && ".join(setup_commands)
187
200
  )
188
- _run_ssh_command(hostname=hostname, ssh_private_key=ssh_private_key, command=setup_commands)
189
201
 
190
202
 
191
203
  def _launch_runner(
@@ -15,6 +15,7 @@ from dstack._internal.core.models.instances import (
15
15
  InstanceType,
16
16
  Resources,
17
17
  )
18
+ from dstack._internal.core.models.placement import PlacementGroup
18
19
  from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
19
20
  from dstack._internal.core.models.volumes import Volume, VolumeProvisioningData
20
21
  from dstack._internal.utils.logging import get_logger
@@ -53,6 +54,7 @@ class LocalCompute(
53
54
  self,
54
55
  instance_offer: InstanceOfferWithAvailability,
55
56
  instance_config: InstanceConfiguration,
57
+ placement_group: Optional[PlacementGroup],
56
58
  ) -> JobProvisioningData:
57
59
  return JobProvisioningData(
58
60
  backend=instance_offer.backend,