dstack 0.19.17__py3-none-any.whl → 0.19.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dstack has been flagged as potentially problematic; consult the package registry's advisory page for more details.

Files changed (43)
  1. dstack/_internal/cli/services/configurators/fleet.py +13 -1
  2. dstack/_internal/core/backends/aws/compute.py +237 -18
  3. dstack/_internal/core/backends/base/compute.py +20 -2
  4. dstack/_internal/core/backends/cudo/compute.py +23 -9
  5. dstack/_internal/core/backends/gcp/compute.py +13 -7
  6. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  7. dstack/_internal/core/compatibility/fleets.py +12 -11
  8. dstack/_internal/core/compatibility/gateways.py +9 -8
  9. dstack/_internal/core/compatibility/logs.py +4 -3
  10. dstack/_internal/core/compatibility/runs.py +17 -20
  11. dstack/_internal/core/compatibility/volumes.py +9 -8
  12. dstack/_internal/core/errors.py +4 -0
  13. dstack/_internal/core/models/common.py +7 -0
  14. dstack/_internal/core/services/diff.py +36 -3
  15. dstack/_internal/server/app.py +20 -0
  16. dstack/_internal/server/background/__init__.py +61 -37
  17. dstack/_internal/server/background/tasks/process_fleets.py +19 -3
  18. dstack/_internal/server/background/tasks/process_gateways.py +1 -1
  19. dstack/_internal/server/background/tasks/process_instances.py +13 -2
  20. dstack/_internal/server/background/tasks/process_placement_groups.py +4 -2
  21. dstack/_internal/server/background/tasks/process_running_jobs.py +14 -3
  22. dstack/_internal/server/background/tasks/process_runs.py +8 -4
  23. dstack/_internal/server/background/tasks/process_submitted_jobs.py +36 -7
  24. dstack/_internal/server/background/tasks/process_terminating_jobs.py +5 -3
  25. dstack/_internal/server/background/tasks/process_volumes.py +2 -2
  26. dstack/_internal/server/services/fleets.py +5 -4
  27. dstack/_internal/server/services/gateways/__init__.py +4 -2
  28. dstack/_internal/server/services/jobs/configurators/base.py +5 -1
  29. dstack/_internal/server/services/locking.py +101 -12
  30. dstack/_internal/server/services/runs.py +24 -40
  31. dstack/_internal/server/services/volumes.py +2 -2
  32. dstack/_internal/server/settings.py +18 -4
  33. dstack/_internal/server/statics/index.html +1 -1
  34. dstack/_internal/server/statics/{main-d151637af20f70b2e796.js → main-d1ac2e8c38ed5f08a114.js} +68 -64
  35. dstack/_internal/server/statics/{main-d151637af20f70b2e796.js.map → main-d1ac2e8c38ed5f08a114.js.map} +1 -1
  36. dstack/_internal/server/statics/{main-d48635d8fe670d53961c.css → main-d58fc0460cb0eae7cb5c.css} +1 -1
  37. dstack/_internal/server/testing/common.py +7 -3
  38. dstack/version.py +1 -1
  39. {dstack-0.19.17.dist-info → dstack-0.19.18.dist-info}/METADATA +11 -10
  40. {dstack-0.19.17.dist-info → dstack-0.19.18.dist-info}/RECORD +43 -43
  41. {dstack-0.19.17.dist-info → dstack-0.19.18.dist-info}/WHEEL +0 -0
  42. {dstack-0.19.17.dist-info → dstack-0.19.18.dist-info}/entry_points.txt +0 -0
  43. {dstack-0.19.17.dist-info → dstack-0.19.18.dist-info}/licenses/LICENSE.md +0 -0
@@ -35,6 +35,7 @@ from dstack._internal.core.models.fleets import (
35
35
  )
36
36
  from dstack._internal.core.models.instances import InstanceAvailability, InstanceStatus, SSHKey
37
37
  from dstack._internal.core.models.repos.base import Repo
38
+ from dstack._internal.core.services.diff import diff_models
38
39
  from dstack._internal.utils.common import local_time
39
40
  from dstack._internal.utils.logging import get_logger
40
41
  from dstack._internal.utils.ssh import convert_ssh_key_to_pem, generate_public_key, pkey_from_str
@@ -82,7 +83,18 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
82
83
  confirm_message += "Create the fleet?"
83
84
  else:
84
85
  action_message += f"Found fleet [code]{plan.spec.configuration.name}[/]."
85
- if plan.current_resource.spec.configuration == plan.spec.configuration:
86
+ diff = diff_models(
87
+ old=plan.current_resource.spec.configuration,
88
+ new=plan.spec.configuration,
89
+ ignore={
90
+ "ssh_config": {
91
+ "ssh_key": True,
92
+ "proxy_jump": {"ssh_key"},
93
+ "hosts": {"__all__": {"ssh_key": True, "proxy_jump": {"ssh_key"}}},
94
+ }
95
+ },
96
+ )
97
+ if not diff:
86
98
  if command_args.yes and not command_args.force:
87
99
  # --force is required only with --yes,
88
100
  # otherwise we may ask for force apply interactively.
@@ -1,14 +1,21 @@
1
+ import threading
1
2
  from concurrent.futures import ThreadPoolExecutor, as_completed
2
3
  from typing import Any, Dict, List, Optional, Tuple
3
4
 
4
5
  import boto3
5
6
  import botocore.client
6
7
  import botocore.exceptions
8
+ from cachetools import Cache, TTLCache, cachedmethod
9
+ from cachetools.keys import hashkey
7
10
  from pydantic import ValidationError
8
11
 
9
12
  import dstack._internal.core.backends.aws.resources as aws_resources
10
13
  from dstack._internal import settings
11
- from dstack._internal.core.backends.aws.models import AWSAccessKeyCreds, AWSConfig
14
+ from dstack._internal.core.backends.aws.models import (
15
+ AWSAccessKeyCreds,
16
+ AWSConfig,
17
+ AWSOSImageConfig,
18
+ )
12
19
  from dstack._internal.core.backends.base.compute import (
13
20
  Compute,
14
21
  ComputeWithCreateInstanceSupport,
@@ -26,7 +33,12 @@ from dstack._internal.core.backends.base.compute import (
26
33
  merge_tags,
27
34
  )
28
35
  from dstack._internal.core.backends.base.offers import get_catalog_offers
29
- from dstack._internal.core.errors import ComputeError, NoCapacityError, PlacementGroupInUseError
36
+ from dstack._internal.core.errors import (
37
+ ComputeError,
38
+ NoCapacityError,
39
+ PlacementGroupInUseError,
40
+ PlacementGroupNotSupportedError,
41
+ )
30
42
  from dstack._internal.core.models.backends.base import BackendType
31
43
  from dstack._internal.core.models.common import CoreModel
32
44
  from dstack._internal.core.models.gateways import (
@@ -39,7 +51,11 @@ from dstack._internal.core.models.instances import (
39
51
  InstanceOffer,
40
52
  InstanceOfferWithAvailability,
41
53
  )
42
- from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
54
+ from dstack._internal.core.models.placement import (
55
+ PlacementGroup,
56
+ PlacementGroupProvisioningData,
57
+ PlacementStrategy,
58
+ )
43
59
  from dstack._internal.core.models.resources import Memory, Range
44
60
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
45
61
  from dstack._internal.core.models.volumes import (
@@ -66,6 +82,10 @@ class AWSVolumeBackendData(CoreModel):
66
82
  iops: int
67
83
 
68
84
 
85
+ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
86
+ return hashkey(*args, **kwargs)
87
+
88
+
69
89
  class AWSCompute(
70
90
  ComputeWithCreateInstanceSupport,
71
91
  ComputeWithMultinodeSupport,
@@ -86,6 +106,24 @@ class AWSCompute(
86
106
  )
87
107
  else: # default creds
88
108
  self.session = boto3.Session()
109
+ # Caches to avoid redundant API calls when provisioning many instances
110
+ # get_offers is already cached but we still cache its sub-functions
111
+ # with more aggressive/longer caches.
112
+ self._get_regions_to_quotas_cache_lock = threading.Lock()
113
+ self._get_regions_to_quotas_execution_lock = threading.Lock()
114
+ self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300)
115
+ self._get_regions_to_zones_cache_lock = threading.Lock()
116
+ self._get_regions_to_zones_cache = Cache(maxsize=10)
117
+ self._get_vpc_id_subnet_id_or_error_cache_lock = threading.Lock()
118
+ self._get_vpc_id_subnet_id_or_error_cache = TTLCache(maxsize=100, ttl=600)
119
+ self._get_maximum_efa_interfaces_cache_lock = threading.Lock()
120
+ self._get_maximum_efa_interfaces_cache = Cache(maxsize=100)
121
+ self._get_subnets_availability_zones_cache_lock = threading.Lock()
122
+ self._get_subnets_availability_zones_cache = Cache(maxsize=100)
123
+ self._create_security_group_cache_lock = threading.Lock()
124
+ self._create_security_group_cache = TTLCache(maxsize=100, ttl=600)
125
+ self._get_image_id_and_username_cache_lock = threading.Lock()
126
+ self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600)
89
127
 
90
128
  def get_offers(
91
129
  self, requirements: Optional[Requirements] = None
@@ -126,8 +164,11 @@ class AWSCompute(
126
164
  extra_filter=filter,
127
165
  )
128
166
  regions = list(set(i.region for i in offers))
129
- regions_to_quotas = _get_regions_to_quotas(self.session, regions)
130
- regions_to_zones = _get_regions_to_zones(self.session, regions)
167
+ with self._get_regions_to_quotas_execution_lock:
168
+ # Cache lock does not prevent concurrent execution.
169
+ # We use a separate lock to avoid requesting quotas in parallel and hitting rate limits.
170
+ regions_to_quotas = self._get_regions_to_quotas(self.session, regions)
171
+ regions_to_zones = self._get_regions_to_zones(self.session, regions)
131
172
 
132
173
  availability_offers = []
133
174
  for offer in offers:
@@ -186,21 +227,24 @@ class AWSCompute(
186
227
  tags = aws_resources.filter_invalid_tags(tags)
187
228
 
188
229
  disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
189
- max_efa_interfaces = _get_maximum_efa_interfaces(
190
- ec2_client=ec2_client, instance_type=instance_offer.instance.name
230
+ max_efa_interfaces = self._get_maximum_efa_interfaces(
231
+ ec2_client=ec2_client,
232
+ region=instance_offer.region,
233
+ instance_type=instance_offer.instance.name,
191
234
  )
192
235
  enable_efa = max_efa_interfaces > 0
193
236
  is_capacity_block = False
194
237
  try:
195
- vpc_id, subnet_ids = get_vpc_id_subnet_id_or_error(
238
+ vpc_id, subnet_ids = self._get_vpc_id_subnet_id_or_error(
196
239
  ec2_client=ec2_client,
197
240
  config=self.config,
198
241
  region=instance_offer.region,
199
242
  allocate_public_ip=allocate_public_ip,
200
243
  availability_zones=zones,
201
244
  )
202
- subnet_id_to_az_map = aws_resources.get_subnets_availability_zones(
245
+ subnet_id_to_az_map = self._get_subnets_availability_zones(
203
246
  ec2_client=ec2_client,
247
+ region=instance_offer.region,
204
248
  subnet_ids=subnet_ids,
205
249
  )
206
250
  if instance_config.reservation:
@@ -229,12 +273,19 @@ class AWSCompute(
229
273
  tried_zones.add(az)
230
274
  try:
231
275
  logger.debug("Trying provisioning %s in %s", instance_offer.instance.name, az)
232
- image_id, username = aws_resources.get_image_id_and_username(
276
+ image_id, username = self._get_image_id_and_username(
233
277
  ec2_client=ec2_client,
278
+ region=instance_offer.region,
234
279
  cuda=len(instance_offer.instance.resources.gpus) > 0,
235
280
  instance_type=instance_offer.instance.name,
236
281
  image_config=self.config.os_images,
237
282
  )
283
+ security_group_id = self._create_security_group(
284
+ ec2_client=ec2_client,
285
+ region=instance_offer.region,
286
+ project_id=project_name,
287
+ vpc_id=vpc_id,
288
+ )
238
289
  response = ec2_resource.create_instances(
239
290
  **aws_resources.create_instances_struct(
240
291
  disk_size=disk_size,
@@ -243,11 +294,7 @@ class AWSCompute(
243
294
  iam_instance_profile=self.config.iam_instance_profile,
244
295
  user_data=get_user_data(authorized_keys=instance_config.get_public_keys()),
245
296
  tags=aws_resources.make_tags(tags),
246
- security_group_id=aws_resources.create_security_group(
247
- ec2_client=ec2_client,
248
- project_id=project_name,
249
- vpc_id=vpc_id,
250
- ),
297
+ security_group_id=security_group_id,
251
298
  spot=instance_offer.instance.resources.spot,
252
299
  subnet_id=subnet_id,
253
300
  allocate_public_ip=allocate_public_ip,
@@ -296,6 +343,8 @@ class AWSCompute(
296
343
  placement_group: PlacementGroup,
297
344
  master_instance_offer: InstanceOffer,
298
345
  ) -> PlacementGroupProvisioningData:
346
+ if not _offer_supports_placement_group(master_instance_offer, placement_group):
347
+ raise PlacementGroupNotSupportedError()
299
348
  ec2_client = self.session.client("ec2", region_name=placement_group.configuration.region)
300
349
  logger.debug("Creating placement group %s...", placement_group.name)
301
350
  ec2_client.create_placement_group(
@@ -332,6 +381,8 @@ class AWSCompute(
332
381
  placement_group: PlacementGroup,
333
382
  instance_offer: InstanceOffer,
334
383
  ) -> bool:
384
+ if not _offer_supports_placement_group(instance_offer, placement_group):
385
+ return False
335
386
  return (
336
387
  placement_group.configuration.backend == BackendType.AWS
337
388
  and placement_group.configuration.region == instance_offer.region
@@ -361,7 +412,7 @@ class AWSCompute(
361
412
  tags = aws_resources.filter_invalid_tags(tags)
362
413
  tags = aws_resources.make_tags(tags)
363
414
 
364
- vpc_id, subnets_ids = get_vpc_id_subnet_id_or_error(
415
+ vpc_id, subnets_ids = self._get_vpc_id_subnet_id_or_error(
365
416
  ec2_client=ec2_client,
366
417
  config=self.config,
367
418
  region=configuration.region,
@@ -696,6 +747,165 @@ class AWSCompute(
696
747
  return True
697
748
  return True
698
749
 
750
+ def _get_regions_to_quotas_key(
751
+ self,
752
+ session: boto3.Session,
753
+ regions: List[str],
754
+ ) -> tuple:
755
+ return hashkey(tuple(regions))
756
+
757
+ @cachedmethod(
758
+ cache=lambda self: self._get_regions_to_quotas_cache,
759
+ key=_get_regions_to_quotas_key,
760
+ lock=lambda self: self._get_regions_to_quotas_cache_lock,
761
+ )
762
+ def _get_regions_to_quotas(
763
+ self,
764
+ session: boto3.Session,
765
+ regions: List[str],
766
+ ) -> Dict[str, Dict[str, int]]:
767
+ return _get_regions_to_quotas(session=session, regions=regions)
768
+
769
+ def _get_regions_to_zones_key(
770
+ self,
771
+ session: boto3.Session,
772
+ regions: List[str],
773
+ ) -> tuple:
774
+ return hashkey(tuple(regions))
775
+
776
+ @cachedmethod(
777
+ cache=lambda self: self._get_regions_to_zones_cache,
778
+ key=_get_regions_to_zones_key,
779
+ lock=lambda self: self._get_regions_to_zones_cache_lock,
780
+ )
781
+ def _get_regions_to_zones(
782
+ self,
783
+ session: boto3.Session,
784
+ regions: List[str],
785
+ ) -> Dict[str, List[str]]:
786
+ return _get_regions_to_zones(session=session, regions=regions)
787
+
788
+ def _get_vpc_id_subnet_id_or_error_cache_key(
789
+ self,
790
+ ec2_client: botocore.client.BaseClient,
791
+ config: AWSConfig,
792
+ region: str,
793
+ allocate_public_ip: bool,
794
+ availability_zones: Optional[List[str]] = None,
795
+ ) -> tuple:
796
+ return hashkey(
797
+ region, allocate_public_ip, tuple(availability_zones) if availability_zones else None
798
+ )
799
+
800
+ @cachedmethod(
801
+ cache=lambda self: self._get_vpc_id_subnet_id_or_error_cache,
802
+ key=_get_vpc_id_subnet_id_or_error_cache_key,
803
+ lock=lambda self: self._get_vpc_id_subnet_id_or_error_cache_lock,
804
+ )
805
+ def _get_vpc_id_subnet_id_or_error(
806
+ self,
807
+ ec2_client: botocore.client.BaseClient,
808
+ config: AWSConfig,
809
+ region: str,
810
+ allocate_public_ip: bool,
811
+ availability_zones: Optional[List[str]] = None,
812
+ ) -> Tuple[str, List[str]]:
813
+ return get_vpc_id_subnet_id_or_error(
814
+ ec2_client=ec2_client,
815
+ config=config,
816
+ region=region,
817
+ allocate_public_ip=allocate_public_ip,
818
+ availability_zones=availability_zones,
819
+ )
820
+
821
+ @cachedmethod(
822
+ cache=lambda self: self._get_maximum_efa_interfaces_cache,
823
+ key=_ec2client_cache_methodkey,
824
+ lock=lambda self: self._get_maximum_efa_interfaces_cache_lock,
825
+ )
826
+ def _get_maximum_efa_interfaces(
827
+ self,
828
+ ec2_client: botocore.client.BaseClient,
829
+ region: str,
830
+ instance_type: str,
831
+ ) -> int:
832
+ return _get_maximum_efa_interfaces(
833
+ ec2_client=ec2_client,
834
+ instance_type=instance_type,
835
+ )
836
+
837
+ def _get_subnets_availability_zones_key(
838
+ self,
839
+ ec2_client: botocore.client.BaseClient,
840
+ region: str,
841
+ subnet_ids: List[str],
842
+ ) -> tuple:
843
+ return hashkey(region, tuple(subnet_ids))
844
+
845
+ @cachedmethod(
846
+ cache=lambda self: self._get_subnets_availability_zones_cache,
847
+ key=_get_subnets_availability_zones_key,
848
+ lock=lambda self: self._get_subnets_availability_zones_cache_lock,
849
+ )
850
+ def _get_subnets_availability_zones(
851
+ self,
852
+ ec2_client: botocore.client.BaseClient,
853
+ region: str,
854
+ subnet_ids: List[str],
855
+ ) -> Dict[str, str]:
856
+ return aws_resources.get_subnets_availability_zones(
857
+ ec2_client=ec2_client,
858
+ subnet_ids=subnet_ids,
859
+ )
860
+
861
+ @cachedmethod(
862
+ cache=lambda self: self._create_security_group_cache,
863
+ key=_ec2client_cache_methodkey,
864
+ lock=lambda self: self._create_security_group_cache_lock,
865
+ )
866
+ def _create_security_group(
867
+ self,
868
+ ec2_client: botocore.client.BaseClient,
869
+ region: str,
870
+ project_id: str,
871
+ vpc_id: Optional[str],
872
+ ) -> str:
873
+ return aws_resources.create_security_group(
874
+ ec2_client=ec2_client,
875
+ project_id=project_id,
876
+ vpc_id=vpc_id,
877
+ )
878
+
879
+ def _get_image_id_and_username_cache_key(
880
+ self,
881
+ ec2_client: botocore.client.BaseClient,
882
+ region: str,
883
+ cuda: bool,
884
+ instance_type: str,
885
+ image_config: Optional[AWSOSImageConfig] = None,
886
+ ) -> tuple:
887
+ return hashkey(region, cuda, instance_type, image_config.json() if image_config else None)
888
+
889
+ @cachedmethod(
890
+ cache=lambda self: self._get_image_id_and_username_cache,
891
+ key=_get_image_id_and_username_cache_key,
892
+ lock=lambda self: self._get_image_id_and_username_cache_lock,
893
+ )
894
+ def _get_image_id_and_username(
895
+ self,
896
+ ec2_client: botocore.client.BaseClient,
897
+ region: str,
898
+ cuda: bool,
899
+ instance_type: str,
900
+ image_config: Optional[AWSOSImageConfig] = None,
901
+ ) -> tuple[str, str]:
902
+ return aws_resources.get_image_id_and_username(
903
+ ec2_client=ec2_client,
904
+ cuda=cuda,
905
+ instance_type=instance_type,
906
+ image_config=image_config,
907
+ )
908
+
699
909
 
700
910
  def get_vpc_id_subnet_id_or_error(
701
911
  ec2_client: botocore.client.BaseClient,
@@ -798,7 +1008,7 @@ def _get_regions_to_quotas(
798
1008
  return region_quotas
799
1009
 
800
1010
  regions_to_quotas = {}
801
- with ThreadPoolExecutor(max_workers=8) as executor:
1011
+ with ThreadPoolExecutor(max_workers=12) as executor:
802
1012
  future_to_region = {}
803
1013
  for region in regions:
804
1014
  future = executor.submit(
@@ -823,7 +1033,7 @@ def _has_quota(quotas: Dict[str, int], instance_name: str) -> Optional[bool]:
823
1033
 
824
1034
  def _get_regions_to_zones(session: boto3.Session, regions: List[str]) -> Dict[str, List[str]]:
825
1035
  regions_to_zones = {}
826
- with ThreadPoolExecutor(max_workers=8) as executor:
1036
+ with ThreadPoolExecutor(max_workers=12) as executor:
827
1037
  future_to_region = {}
828
1038
  for region in regions:
829
1039
  future = executor.submit(
@@ -862,6 +1072,15 @@ def _supported_instances(offer: InstanceOffer) -> bool:
862
1072
  return False
863
1073
 
864
1074
 
1075
+ def _offer_supports_placement_group(offer: InstanceOffer, placement_group: PlacementGroup) -> bool:
1076
+ if placement_group.configuration.placement_strategy != PlacementStrategy.CLUSTER:
1077
+ return True
1078
+ for family in ["t3.", "t2."]:
1079
+ if offer.instance.name.startswith(family):
1080
+ return False
1081
+ return True
1082
+
1083
+
865
1084
  def _get_maximum_efa_interfaces(ec2_client: botocore.client.BaseClient, instance_type: str) -> int:
866
1085
  try:
867
1086
  response = ec2_client.describe_instance_types(
@@ -57,7 +57,7 @@ class Compute(ABC):
57
57
 
58
58
  def __init__(self):
59
59
  self._offers_cache_lock = threading.Lock()
60
- self._offers_cache = TTLCache(maxsize=5, ttl=30)
60
+ self._offers_cache = TTLCache(maxsize=10, ttl=180)
61
61
 
62
62
  @abstractmethod
63
63
  def get_offers(
@@ -559,7 +559,8 @@ def get_shim_commands(
559
559
  backend_shim_env: Optional[Dict[str, str]] = None,
560
560
  arch: Optional[str] = None,
561
561
  ) -> List[str]:
562
- commands = get_shim_pre_start_commands(
562
+ commands = get_setup_cloud_instance_commands()
563
+ commands += get_shim_pre_start_commands(
563
564
  base_path=base_path,
564
565
  bin_path=bin_path,
565
566
  arch=arch,
@@ -641,6 +642,23 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
641
642
  return url_template.format(version=version, arch=arch)
642
643
 
643
644
 
645
+ def get_setup_cloud_instance_commands() -> list[str]:
646
+ return [
647
+ # Workaround for https://github.com/NVIDIA/nvidia-container-toolkit/issues/48
648
+ # Attempts to patch /etc/docker/daemon.json while keeping any custom settings it may have.
649
+ (
650
+ "/bin/sh -c '" # wrap in /bin/sh to avoid interfering with other cloud init commands
651
+ " grep -q nvidia /etc/docker/daemon.json"
652
+ " && ! grep -q native.cgroupdriver /etc/docker/daemon.json"
653
+ " && jq '\\''.\"exec-opts\" = ((.\"exec-opts\" // []) + [\"native.cgroupdriver=cgroupfs\"])'\\'' /etc/docker/daemon.json > /tmp/daemon.json"
654
+ " && sudo mv /tmp/daemon.json /etc/docker/daemon.json"
655
+ " && sudo service docker restart"
656
+ " || true"
657
+ "'"
658
+ ),
659
+ ]
660
+
661
+
644
662
  def get_shim_pre_start_commands(
645
663
  base_path: Optional[PathLike] = None,
646
664
  bin_path: Optional[PathLike] = None,
@@ -65,12 +65,13 @@ class CudoCompute(
65
65
  public_keys = instance_config.get_public_keys()
66
66
  memory_size = round(instance_offer.instance.resources.memory_mib / 1024)
67
67
  disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
68
- commands = get_shim_commands(authorized_keys=public_keys)
69
68
  gpus_no = len(instance_offer.instance.resources.gpus)
70
- shim_commands = " ".join([" && ".join(commands)])
71
- startup_script = (
72
- shim_commands if gpus_no > 0 else f"{install_docker_script()} && {shim_commands}"
73
- )
69
+ if gpus_no > 0:
70
+ # we'll need jq for patching /etc/docker/daemon.json, see get_shim_commands()
71
+ commands = install_jq_commands()
72
+ else:
73
+ commands = install_docker_commands()
74
+ commands += get_shim_commands(authorized_keys=public_keys)
74
75
 
75
76
  try:
76
77
  resp_data = self.api_client.create_virtual_machine(
@@ -85,7 +86,7 @@ class CudoCompute(
85
86
  memory_gib=memory_size,
86
87
  vcpus=instance_offer.instance.resources.cpus,
87
88
  vm_id=vm_id,
88
- start_script=startup_script,
89
+ start_script=" && ".join(commands),
89
90
  password=None,
90
91
  customSshKeys=public_keys,
91
92
  )
@@ -151,6 +152,19 @@ def _get_image_id(cuda: bool) -> str:
151
152
  return image_name
152
153
 
153
154
 
154
- def install_docker_script():
155
- commands = 'export DEBIAN_FRONTEND="noninteractive" && mkdir -p /etc/apt/keyrings && curl --max-time 60 -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null && apt-get update && apt-get --assume-yes install docker-ce docker-ce-cli containerd.io docker-compose-plugin'
156
- return commands
155
+ def install_jq_commands():
156
+ return [
157
+ "export DEBIAN_FRONTEND=noninteractive",
158
+ "apt-get --assume-yes install jq",
159
+ ]
160
+
161
+
162
+ def install_docker_commands():
163
+ return [
164
+ "export DEBIAN_FRONTEND=noninteractive",
165
+ "mkdir -p /etc/apt/keyrings",
166
+ "curl --max-time 60 -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg",
167
+ 'echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null',
168
+ "apt-get update",
169
+ "apt-get --assume-yes install docker-ce docker-ce-cli containerd.io docker-compose-plugin",
170
+ ]
@@ -8,6 +8,7 @@ import google.api_core.exceptions
8
8
  import google.cloud.compute_v1 as compute_v1
9
9
  from cachetools import TTLCache, cachedmethod
10
10
  from google.cloud import tpu_v2
11
+ from google.cloud.compute_v1.types.compute import Instance
11
12
  from gpuhunt import KNOWN_TPUS
12
13
 
13
14
  import dstack._internal.core.backends.gcp.auth as auth
@@ -19,6 +20,7 @@ from dstack._internal.core.backends.base.compute import (
19
20
  ComputeWithGatewaySupport,
20
21
  ComputeWithMultinodeSupport,
21
22
  ComputeWithPlacementGroupSupport,
23
+ ComputeWithPrivateGatewaySupport,
22
24
  ComputeWithVolumeSupport,
23
25
  generate_unique_gateway_instance_name,
24
26
  generate_unique_instance_name,
@@ -83,6 +85,7 @@ class GCPCompute(
83
85
  ComputeWithMultinodeSupport,
84
86
  ComputeWithPlacementGroupSupport,
85
87
  ComputeWithGatewaySupport,
88
+ ComputeWithPrivateGatewaySupport,
86
89
  ComputeWithVolumeSupport,
87
90
  Compute,
88
91
  ):
@@ -395,11 +398,7 @@ class GCPCompute(
395
398
  if instance.status in ["PROVISIONING", "STAGING"]:
396
399
  return
397
400
  if instance.status == "RUNNING":
398
- if allocate_public_ip:
399
- hostname = instance.network_interfaces[0].access_configs[0].nat_i_p
400
- else:
401
- hostname = instance.network_interfaces[0].network_i_p
402
- provisioning_data.hostname = hostname
401
+ provisioning_data.hostname = _get_instance_ip(instance, allocate_public_ip)
403
402
  provisioning_data.internal_ip = instance.network_interfaces[0].network_i_p
404
403
  return
405
404
  raise ProvisioningError(
@@ -500,7 +499,7 @@ class GCPCompute(
500
499
  request.instance_resource = gcp_resources.create_instance_struct(
501
500
  disk_size=10,
502
501
  image_id=_get_gateway_image_id(),
503
- machine_type="e2-small",
502
+ machine_type="e2-medium",
504
503
  accelerators=[],
505
504
  spot=False,
506
505
  user_data=get_gateway_user_data(configuration.ssh_key_pub),
@@ -512,6 +511,7 @@ class GCPCompute(
512
511
  service_account=self.config.vm_service_account,
513
512
  network=self.config.vpc_resource_name,
514
513
  subnetwork=subnetwork,
514
+ allocate_public_ip=configuration.public_ip,
515
515
  )
516
516
  operation = self.instances_client.insert(request=request)
517
517
  gcp_resources.wait_for_extended_operation(operation, "instance creation")
@@ -522,7 +522,7 @@ class GCPCompute(
522
522
  instance_id=instance_name,
523
523
  region=configuration.region, # used for instance termination
524
524
  availability_zone=zone,
525
- ip_address=instance.network_interfaces[0].access_configs[0].nat_i_p,
525
+ ip_address=_get_instance_ip(instance, configuration.public_ip),
526
526
  backend_data=json.dumps({"zone": zone}),
527
527
  )
528
528
 
@@ -1024,3 +1024,9 @@ def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool:
1024
1024
  backend_data_dict = json.loads(provisioning_data.backend_data)
1025
1025
  is_tpu = backend_data_dict.get("is_tpu", False)
1026
1026
  return is_tpu
1027
+
1028
+
1029
+ def _get_instance_ip(instance: Instance, public_ip: bool) -> str:
1030
+ if public_ip:
1031
+ return instance.network_interfaces[0].access_configs[0].nat_i_p
1032
+ return instance.network_interfaces[0].network_i_p
@@ -1,4 +1,5 @@
1
1
  import hashlib
2
+ import shlex
2
3
  import subprocess
3
4
  import tempfile
4
5
  from threading import Thread
@@ -98,7 +99,7 @@ class LambdaCompute(
98
99
  arch=provisioning_data.instance_type.resources.cpu_arch,
99
100
  )
100
101
  # shim is assumed to be run under root
101
- launch_command = "sudo sh -c '" + "&& ".join(commands) + "'"
102
+ launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands))
102
103
  thread = Thread(
103
104
  target=_start_runner,
104
105
  kwargs={
@@ -1,19 +1,20 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import Optional
2
2
 
3
+ from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType
3
4
  from dstack._internal.core.models.fleets import ApplyFleetPlanInput, FleetSpec
4
5
  from dstack._internal.core.models.instances import Instance
5
6
 
6
7
 
7
- def get_get_plan_excludes(fleet_spec: FleetSpec) -> Dict:
8
- get_plan_excludes = {}
8
+ def get_get_plan_excludes(fleet_spec: FleetSpec) -> IncludeExcludeDictType:
9
+ get_plan_excludes: IncludeExcludeDictType = {}
9
10
  spec_excludes = get_fleet_spec_excludes(fleet_spec)
10
11
  if spec_excludes:
11
12
  get_plan_excludes["spec"] = spec_excludes
12
13
  return get_plan_excludes
13
14
 
14
15
 
15
- def get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> Dict:
16
- apply_plan_excludes = {}
16
+ def get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> IncludeExcludeDictType:
17
+ apply_plan_excludes: IncludeExcludeDictType = {}
17
18
  spec_excludes = get_fleet_spec_excludes(plan_input.spec)
18
19
  if spec_excludes:
19
20
  apply_plan_excludes["spec"] = spec_excludes
@@ -28,23 +29,23 @@ def get_apply_plan_excludes(plan_input: ApplyFleetPlanInput) -> Dict:
28
29
  return {"plan": apply_plan_excludes}
29
30
 
30
31
 
31
- def get_create_fleet_excludes(fleet_spec: FleetSpec) -> Dict:
32
- create_fleet_excludes = {}
32
+ def get_create_fleet_excludes(fleet_spec: FleetSpec) -> IncludeExcludeDictType:
33
+ create_fleet_excludes: IncludeExcludeDictType = {}
33
34
  spec_excludes = get_fleet_spec_excludes(fleet_spec)
34
35
  if spec_excludes:
35
36
  create_fleet_excludes["spec"] = spec_excludes
36
37
  return create_fleet_excludes
37
38
 
38
39
 
39
- def get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[Dict]:
40
+ def get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[IncludeExcludeDictType]:
40
41
  """
41
42
  Returns `fleet_spec` exclude mapping to exclude certain fields from the request.
42
43
  Use this method to exclude new fields when they are not set to keep
43
44
  clients backward-compatibility with older servers.
44
45
  """
45
- spec_excludes: Dict[str, Any] = {}
46
- configuration_excludes: Dict[str, Any] = {}
47
- profile_excludes: set[str] = set()
46
+ spec_excludes: IncludeExcludeDictType = {}
47
+ configuration_excludes: IncludeExcludeDictType = {}
48
+ profile_excludes: IncludeExcludeSetType = set()
48
49
  profile = fleet_spec.profile
49
50
  if profile.fleets is None:
50
51
  profile_excludes.add("fleets")