dstack 0.19.17__py3-none-any.whl → 0.19.19__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (86)
  1. dstack/_internal/cli/services/configurators/fleet.py +111 -1
  2. dstack/_internal/cli/services/profile.py +1 -1
  3. dstack/_internal/core/backends/aws/compute.py +237 -18
  4. dstack/_internal/core/backends/base/compute.py +20 -2
  5. dstack/_internal/core/backends/cudo/compute.py +23 -9
  6. dstack/_internal/core/backends/gcp/compute.py +13 -7
  7. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  8. dstack/_internal/core/compatibility/fleets.py +12 -11
  9. dstack/_internal/core/compatibility/gateways.py +9 -8
  10. dstack/_internal/core/compatibility/logs.py +4 -3
  11. dstack/_internal/core/compatibility/runs.py +29 -21
  12. dstack/_internal/core/compatibility/volumes.py +11 -8
  13. dstack/_internal/core/errors.py +4 -0
  14. dstack/_internal/core/models/common.py +45 -2
  15. dstack/_internal/core/models/configurations.py +9 -1
  16. dstack/_internal/core/models/fleets.py +2 -1
  17. dstack/_internal/core/models/profiles.py +8 -5
  18. dstack/_internal/core/models/resources.py +15 -8
  19. dstack/_internal/core/models/runs.py +41 -138
  20. dstack/_internal/core/models/volumes.py +14 -0
  21. dstack/_internal/core/services/diff.py +56 -3
  22. dstack/_internal/core/services/ssh/attach.py +2 -0
  23. dstack/_internal/server/app.py +37 -9
  24. dstack/_internal/server/background/__init__.py +66 -40
  25. dstack/_internal/server/background/tasks/process_fleets.py +19 -3
  26. dstack/_internal/server/background/tasks/process_gateways.py +47 -29
  27. dstack/_internal/server/background/tasks/process_idle_volumes.py +139 -0
  28. dstack/_internal/server/background/tasks/process_instances.py +13 -2
  29. dstack/_internal/server/background/tasks/process_placement_groups.py +4 -2
  30. dstack/_internal/server/background/tasks/process_running_jobs.py +14 -3
  31. dstack/_internal/server/background/tasks/process_runs.py +8 -4
  32. dstack/_internal/server/background/tasks/process_submitted_jobs.py +38 -7
  33. dstack/_internal/server/background/tasks/process_terminating_jobs.py +5 -3
  34. dstack/_internal/server/background/tasks/process_volumes.py +2 -2
  35. dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py +6 -6
  36. dstack/_internal/server/migrations/versions/d5863798bf41_add_volumemodel_last_job_processed_at.py +40 -0
  37. dstack/_internal/server/models.py +1 -0
  38. dstack/_internal/server/routers/backends.py +23 -16
  39. dstack/_internal/server/routers/files.py +7 -6
  40. dstack/_internal/server/routers/fleets.py +47 -36
  41. dstack/_internal/server/routers/gateways.py +27 -18
  42. dstack/_internal/server/routers/instances.py +18 -13
  43. dstack/_internal/server/routers/logs.py +7 -3
  44. dstack/_internal/server/routers/metrics.py +14 -8
  45. dstack/_internal/server/routers/projects.py +33 -22
  46. dstack/_internal/server/routers/repos.py +7 -6
  47. dstack/_internal/server/routers/runs.py +49 -28
  48. dstack/_internal/server/routers/secrets.py +20 -15
  49. dstack/_internal/server/routers/server.py +7 -4
  50. dstack/_internal/server/routers/users.py +22 -19
  51. dstack/_internal/server/routers/volumes.py +34 -25
  52. dstack/_internal/server/schemas/logs.py +2 -2
  53. dstack/_internal/server/schemas/runs.py +17 -5
  54. dstack/_internal/server/services/fleets.py +358 -75
  55. dstack/_internal/server/services/gateways/__init__.py +17 -6
  56. dstack/_internal/server/services/gateways/client.py +5 -3
  57. dstack/_internal/server/services/instances.py +8 -0
  58. dstack/_internal/server/services/jobs/__init__.py +45 -0
  59. dstack/_internal/server/services/jobs/configurators/base.py +12 -1
  60. dstack/_internal/server/services/locking.py +104 -13
  61. dstack/_internal/server/services/logging.py +4 -2
  62. dstack/_internal/server/services/logs/__init__.py +15 -2
  63. dstack/_internal/server/services/logs/aws.py +2 -4
  64. dstack/_internal/server/services/logs/filelog.py +33 -27
  65. dstack/_internal/server/services/logs/gcp.py +3 -5
  66. dstack/_internal/server/services/proxy/repo.py +4 -1
  67. dstack/_internal/server/services/runs.py +139 -72
  68. dstack/_internal/server/services/services/__init__.py +2 -1
  69. dstack/_internal/server/services/users.py +3 -1
  70. dstack/_internal/server/services/volumes.py +15 -2
  71. dstack/_internal/server/settings.py +25 -6
  72. dstack/_internal/server/statics/index.html +1 -1
  73. dstack/_internal/server/statics/{main-d151637af20f70b2e796.js → main-64f8273740c4b52c18f5.js} +71 -67
  74. dstack/_internal/server/statics/{main-d151637af20f70b2e796.js.map → main-64f8273740c4b52c18f5.js.map} +1 -1
  75. dstack/_internal/server/statics/{main-d48635d8fe670d53961c.css → main-d58fc0460cb0eae7cb5c.css} +1 -1
  76. dstack/_internal/server/testing/common.py +48 -8
  77. dstack/_internal/server/utils/routers.py +31 -8
  78. dstack/_internal/utils/json_utils.py +54 -0
  79. dstack/api/_public/runs.py +13 -2
  80. dstack/api/server/_runs.py +12 -2
  81. dstack/version.py +1 -1
  82. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/METADATA +17 -14
  83. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/RECORD +86 -83
  84. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/WHEEL +0 -0
  85. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/entry_points.txt +0 -0
  86. {dstack-0.19.17.dist-info → dstack-0.19.19.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/cli/services/configurators/fleet.py

@@ -25,6 +25,7 @@ from dstack._internal.core.errors import (
     ServerClientError,
     URLNotFoundError,
 )
+from dstack._internal.core.models.common import ApplyAction
 from dstack._internal.core.models.configurations import ApplyConfigurationType
 from dstack._internal.core.models.fleets import (
     Fleet,
@@ -35,6 +36,7 @@ from dstack._internal.core.models.fleets import (
 )
 from dstack._internal.core.models.instances import InstanceAvailability, InstanceStatus, SSHKey
 from dstack._internal.core.models.repos.base import Repo
+from dstack._internal.core.services.diff import diff_models
 from dstack._internal.utils.common import local_time
 from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.ssh import convert_ssh_key_to_pem, generate_public_key, pkey_from_str
@@ -71,7 +73,14 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
             spec=spec,
         )
         _print_plan_header(plan)
+        if plan.action is not None:
+            self._apply_plan(plan, command_args)
+        else:
+            # Old servers don't support spec update
+            self._apply_plan_on_old_server(plan, command_args)
 
+    def _apply_plan(self, plan: FleetPlan, command_args: argparse.Namespace):
+        delete_fleet_name: Optional[str] = None
         action_message = ""
         confirm_message = ""
         if plan.current_resource is None:
@@ -82,7 +91,108 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
             confirm_message += "Create the fleet?"
         else:
             action_message += f"Found fleet [code]{plan.spec.configuration.name}[/]."
-            if plan.current_resource.spec.configuration == plan.spec.configuration:
+            if plan.action == ApplyAction.CREATE:
+                delete_fleet_name = plan.current_resource.name
+                action_message += (
+                    " Configuration changes detected. Cannot update the fleet in-place"
+                )
+                confirm_message += "Re-create the fleet?"
+            elif plan.current_resource.spec == plan.effective_spec:
+                if command_args.yes and not command_args.force:
+                    # --force is required only with --yes,
+                    # otherwise we may ask for force apply interactively.
+                    console.print(
+                        "No configuration changes detected. Use --force to apply anyway."
+                    )
+                    return
+                delete_fleet_name = plan.current_resource.name
+                action_message += " No configuration changes detected."
+                confirm_message += "Re-create the fleet?"
+            else:
+                action_message += " Configuration changes detected."
+                confirm_message += "Update the fleet in-place?"
+
+        console.print(action_message)
+        if not command_args.yes and not confirm_ask(confirm_message):
+            console.print("\nExiting...")
+            return
+
+        if delete_fleet_name is not None:
+            with console.status("Deleting existing fleet..."):
+                self.api.client.fleets.delete(
+                    project_name=self.api.project, names=[delete_fleet_name]
+                )
+                # Fleet deletion is async. Wait for fleet to be deleted.
+                while True:
+                    try:
+                        self.api.client.fleets.get(
+                            project_name=self.api.project, name=delete_fleet_name
+                        )
+                    except ResourceNotExistsError:
+                        break
+                    else:
+                        time.sleep(1)
+
+        try:
+            with console.status("Applying plan..."):
+                fleet = self.api.client.fleets.apply_plan(project_name=self.api.project, plan=plan)
+        except ServerClientError as e:
+            raise CLIError(e.msg)
+        if command_args.detach:
+            console.print("Fleet configuration submitted. Exiting...")
+            return
+        try:
+            with MultiItemStatus(
+                f"Provisioning [code]{fleet.name}[/]...", console=console
+            ) as live:
+                while not _finished_provisioning(fleet):
+                    table = get_fleets_table([fleet])
+                    live.update(table)
+                    time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
+                    fleet = self.api.client.fleets.get(self.api.project, fleet.name)
+        except KeyboardInterrupt:
+            if confirm_ask("Delete the fleet before exiting?"):
+                with console.status("Deleting fleet..."):
+                    self.api.client.fleets.delete(
+                        project_name=self.api.project, names=[fleet.name]
+                    )
+            else:
+                console.print("Exiting... Fleet provisioning will continue in the background.")
+                return
+        console.print(
+            get_fleets_table(
+                [fleet],
+                verbose=_failed_provisioning(fleet),
+                format_date=local_time,
+            )
+        )
+        if _failed_provisioning(fleet):
+            console.print("\n[error]Some instances failed. Check the table above for errors.[/]")
+            exit(1)
+
+    def _apply_plan_on_old_server(self, plan: FleetPlan, command_args: argparse.Namespace):
+        action_message = ""
+        confirm_message = ""
+        if plan.current_resource is None:
+            if plan.spec.configuration.name is not None:
+                action_message += (
+                    f"Fleet [code]{plan.spec.configuration.name}[/] does not exist yet."
+                )
+            confirm_message += "Create the fleet?"
+        else:
+            action_message += f"Found fleet [code]{plan.spec.configuration.name}[/]."
+            diff = diff_models(
+                old=plan.current_resource.spec.configuration,
+                new=plan.spec.configuration,
+                reset={
+                    "ssh_config": {
+                        "ssh_key": True,
+                        "proxy_jump": {"ssh_key"},
+                        "hosts": {"__all__": {"ssh_key": True, "proxy_jump": {"ssh_key"}}},
+                    }
+                },
+            )
+            if not diff:
                 if command_args.yes and not command_args.force:
                     # --force is required only with --yes,
                     # otherwise we may ask for force apply interactively.
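
Note on the fallback above: the old-server path compares configurations with diff_models (dstack/_internal/core/services/diff.py, +56 -3 in this release), using the reset spec to blank out generated SSH keys so they don't register as changes. Below is a minimal sketch of how such a reset-aware model diff can work, assuming pydantic v1-style models; reset_fields, the model classes, and the key values are illustrative, not dstack's actual implementation (which, per the diff above, also supports an "__all__" wildcard for list items):

from typing import Optional

from pydantic import BaseModel


class ProxyJump(BaseModel):
    hostname: str
    ssh_key: Optional[str] = None


class SSHConfig(BaseModel):
    user: str
    ssh_key: Optional[str] = None
    proxy_jump: Optional[ProxyJump] = None


def reset_fields(model: BaseModel, reset) -> BaseModel:
    # Clear the fields named in `reset` before comparison: True clears the
    # field, a nested dict recurses into a sub-model, and a set is shorthand
    # for {name: True, ...}. (Sketch only: no "__all__" handling for lists.)
    if isinstance(reset, set):
        reset = {name: True for name in reset}
    update = {}
    for name, spec in reset.items():
        value = getattr(model, name)
        if spec is True or value is None:
            update[name] = None
        else:
            update[name] = reset_fields(value, spec)
    return model.copy(update=update)


def diff_models(old: BaseModel, new: BaseModel, reset=None) -> dict:
    # Return the fields whose values differ, ignoring any `reset` fields.
    if reset is not None:
        old = reset_fields(old, reset)
        new = reset_fields(new, reset)
    return {
        name: {"old": getattr(old, name), "new": getattr(new, name)}
        for name in type(old).__fields__
        if getattr(old, name) != getattr(new, name)
    }


a = SSHConfig(user="ubuntu", ssh_key="generated-key-a")
b = SSHConfig(user="ubuntu", ssh_key="generated-key-b")
assert diff_models(a, b)                               # ssh_key differs
assert not diff_models(a, b, reset={"ssh_key": True})  # ignored -> no changes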
dstack/_internal/cli/services/profile.py

@@ -159,7 +159,7 @@ def apply_profile_args(
     if args.idle_duration is not None:
         profile_settings.idle_duration = args.idle_duration
     elif args.dont_destroy:
-        profile_settings.idle_duration = False
+        profile_settings.idle_duration = "off"
     if args.creation_policy_reuse:
         profile_settings.creation_policy = CreationPolicy.REUSE
 
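Context for the False -> "off" change: idle_duration is now modeled as a duration or the literal "off" rather than a boolean (the model itself changes in dstack/_internal/core/models/profiles.py, +8 -5 above). A hypothetical sketch of such a field; the exact shape is an assumption, not dstack's actual definition:

from typing import Literal, Optional, Union

from pydantic import BaseModel


class ProfileSettings(BaseModel):
    # Hypothetical shape: a duration ("30m", "2h", or seconds as an int), or
    # the literal "off" meaning idle instances are never destroyed.
    idle_duration: Optional[Union[Literal["off"], int, str]] = None


ProfileSettings(idle_duration="30m")  # destroy after 30 minutes idle
ProfileSettings(idle_duration="off")  # keep instances indefinitely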
dstack/_internal/core/backends/aws/compute.py

@@ -1,14 +1,21 @@
+import threading
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any, Dict, List, Optional, Tuple
 
 import boto3
 import botocore.client
 import botocore.exceptions
+from cachetools import Cache, TTLCache, cachedmethod
+from cachetools.keys import hashkey
 from pydantic import ValidationError
 
 import dstack._internal.core.backends.aws.resources as aws_resources
 from dstack._internal import settings
-from dstack._internal.core.backends.aws.models import AWSAccessKeyCreds, AWSConfig
+from dstack._internal.core.backends.aws.models import (
+    AWSAccessKeyCreds,
+    AWSConfig,
+    AWSOSImageConfig,
+)
 from dstack._internal.core.backends.base.compute import (
     Compute,
     ComputeWithCreateInstanceSupport,
@@ -26,7 +33,12 @@ from dstack._internal.core.backends.base.compute import (
     merge_tags,
 )
 from dstack._internal.core.backends.base.offers import get_catalog_offers
-from dstack._internal.core.errors import ComputeError, NoCapacityError, PlacementGroupInUseError
+from dstack._internal.core.errors import (
+    ComputeError,
+    NoCapacityError,
+    PlacementGroupInUseError,
+    PlacementGroupNotSupportedError,
+)
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.common import CoreModel
 from dstack._internal.core.models.gateways import (
@@ -39,7 +51,11 @@ from dstack._internal.core.models.instances import (
     InstanceOffer,
     InstanceOfferWithAvailability,
 )
-from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
+from dstack._internal.core.models.placement import (
+    PlacementGroup,
+    PlacementGroupProvisioningData,
+    PlacementStrategy,
+)
 from dstack._internal.core.models.resources import Memory, Range
 from dstack._internal.core.models.runs import JobProvisioningData, Requirements
 from dstack._internal.core.models.volumes import (
@@ -66,6 +82,10 @@ class AWSVolumeBackendData(CoreModel):
     iops: int
 
 
+def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
+    return hashkey(*args, **kwargs)
+
+
 class AWSCompute(
     ComputeWithCreateInstanceSupport,
     ComputeWithMultinodeSupport,
@@ -86,6 +106,24 @@ class AWSCompute(
             )
         else:  # default creds
            self.session = boto3.Session()
+        # Caches to avoid redundant API calls when provisioning many instances
+        # get_offers is already cached but we still cache its sub-functions
+        # with more aggressive/longer caches.
+        self._get_regions_to_quotas_cache_lock = threading.Lock()
+        self._get_regions_to_quotas_execution_lock = threading.Lock()
+        self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300)
+        self._get_regions_to_zones_cache_lock = threading.Lock()
+        self._get_regions_to_zones_cache = Cache(maxsize=10)
+        self._get_vpc_id_subnet_id_or_error_cache_lock = threading.Lock()
+        self._get_vpc_id_subnet_id_or_error_cache = TTLCache(maxsize=100, ttl=600)
+        self._get_maximum_efa_interfaces_cache_lock = threading.Lock()
+        self._get_maximum_efa_interfaces_cache = Cache(maxsize=100)
+        self._get_subnets_availability_zones_cache_lock = threading.Lock()
+        self._get_subnets_availability_zones_cache = Cache(maxsize=100)
+        self._create_security_group_cache_lock = threading.Lock()
+        self._create_security_group_cache = TTLCache(maxsize=100, ttl=600)
+        self._get_image_id_and_username_cache_lock = threading.Lock()
+        self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600)
 
     def get_offers(
         self, requirements: Optional[Requirements] = None
@@ -126,8 +164,11 @@
             extra_filter=filter,
         )
         regions = list(set(i.region for i in offers))
-        regions_to_quotas = _get_regions_to_quotas(self.session, regions)
-        regions_to_zones = _get_regions_to_zones(self.session, regions)
+        with self._get_regions_to_quotas_execution_lock:
+            # Cache lock does not prevent concurrent execution.
+            # We use a separate lock to avoid requesting quotas in parallel and hitting rate limits.
+            regions_to_quotas = self._get_regions_to_quotas(self.session, regions)
+        regions_to_zones = self._get_regions_to_zones(self.session, regions)
 
         availability_offers = []
         for offer in offers:
@@ -186,21 +227,24 @@
         tags = aws_resources.filter_invalid_tags(tags)
 
         disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
-        max_efa_interfaces = _get_maximum_efa_interfaces(
-            ec2_client=ec2_client, instance_type=instance_offer.instance.name
+        max_efa_interfaces = self._get_maximum_efa_interfaces(
+            ec2_client=ec2_client,
+            region=instance_offer.region,
+            instance_type=instance_offer.instance.name,
         )
         enable_efa = max_efa_interfaces > 0
         is_capacity_block = False
         try:
-            vpc_id, subnet_ids = get_vpc_id_subnet_id_or_error(
+            vpc_id, subnet_ids = self._get_vpc_id_subnet_id_or_error(
                 ec2_client=ec2_client,
                 config=self.config,
                 region=instance_offer.region,
                 allocate_public_ip=allocate_public_ip,
                 availability_zones=zones,
             )
-            subnet_id_to_az_map = aws_resources.get_subnets_availability_zones(
+            subnet_id_to_az_map = self._get_subnets_availability_zones(
                 ec2_client=ec2_client,
+                region=instance_offer.region,
                 subnet_ids=subnet_ids,
             )
             if instance_config.reservation:
@@ -229,12 +273,19 @@
                 tried_zones.add(az)
                 try:
                     logger.debug("Trying provisioning %s in %s", instance_offer.instance.name, az)
-                    image_id, username = aws_resources.get_image_id_and_username(
+                    image_id, username = self._get_image_id_and_username(
                         ec2_client=ec2_client,
+                        region=instance_offer.region,
                         cuda=len(instance_offer.instance.resources.gpus) > 0,
                         instance_type=instance_offer.instance.name,
                         image_config=self.config.os_images,
                     )
+                    security_group_id = self._create_security_group(
+                        ec2_client=ec2_client,
+                        region=instance_offer.region,
+                        project_id=project_name,
+                        vpc_id=vpc_id,
+                    )
                     response = ec2_resource.create_instances(
                         **aws_resources.create_instances_struct(
                             disk_size=disk_size,
@@ -243,11 +294,7 @@
                             iam_instance_profile=self.config.iam_instance_profile,
                             user_data=get_user_data(authorized_keys=instance_config.get_public_keys()),
                             tags=aws_resources.make_tags(tags),
-                            security_group_id=aws_resources.create_security_group(
-                                ec2_client=ec2_client,
-                                project_id=project_name,
-                                vpc_id=vpc_id,
-                            ),
+                            security_group_id=security_group_id,
                             spot=instance_offer.instance.resources.spot,
                             subnet_id=subnet_id,
                             allocate_public_ip=allocate_public_ip,
@@ -296,6 +343,8 @@
         placement_group: PlacementGroup,
         master_instance_offer: InstanceOffer,
     ) -> PlacementGroupProvisioningData:
+        if not _offer_supports_placement_group(master_instance_offer, placement_group):
+            raise PlacementGroupNotSupportedError()
         ec2_client = self.session.client("ec2", region_name=placement_group.configuration.region)
         logger.debug("Creating placement group %s...", placement_group.name)
         ec2_client.create_placement_group(
@@ -332,6 +381,8 @@
         placement_group: PlacementGroup,
         instance_offer: InstanceOffer,
     ) -> bool:
+        if not _offer_supports_placement_group(instance_offer, placement_group):
+            return False
         return (
             placement_group.configuration.backend == BackendType.AWS
             and placement_group.configuration.region == instance_offer.region
@@ -361,7 +412,7 @@
         tags = aws_resources.filter_invalid_tags(tags)
         tags = aws_resources.make_tags(tags)
 
-        vpc_id, subnets_ids = get_vpc_id_subnet_id_or_error(
+        vpc_id, subnets_ids = self._get_vpc_id_subnet_id_or_error(
             ec2_client=ec2_client,
             config=self.config,
             region=configuration.region,
@@ -696,6 +747,165 @@
                 return True
         return True
 
+    def _get_regions_to_quotas_key(
+        self,
+        session: boto3.Session,
+        regions: List[str],
+    ) -> tuple:
+        return hashkey(tuple(regions))
+
+    @cachedmethod(
+        cache=lambda self: self._get_regions_to_quotas_cache,
+        key=_get_regions_to_quotas_key,
+        lock=lambda self: self._get_regions_to_quotas_cache_lock,
+    )
+    def _get_regions_to_quotas(
+        self,
+        session: boto3.Session,
+        regions: List[str],
+    ) -> Dict[str, Dict[str, int]]:
+        return _get_regions_to_quotas(session=session, regions=regions)
+
+    def _get_regions_to_zones_key(
+        self,
+        session: boto3.Session,
+        regions: List[str],
+    ) -> tuple:
+        return hashkey(tuple(regions))
+
+    @cachedmethod(
+        cache=lambda self: self._get_regions_to_zones_cache,
+        key=_get_regions_to_zones_key,
+        lock=lambda self: self._get_regions_to_zones_cache_lock,
+    )
+    def _get_regions_to_zones(
+        self,
+        session: boto3.Session,
+        regions: List[str],
+    ) -> Dict[str, List[str]]:
+        return _get_regions_to_zones(session=session, regions=regions)
+
+    def _get_vpc_id_subnet_id_or_error_cache_key(
+        self,
+        ec2_client: botocore.client.BaseClient,
+        config: AWSConfig,
+        region: str,
+        allocate_public_ip: bool,
+        availability_zones: Optional[List[str]] = None,
+    ) -> tuple:
+        return hashkey(
+            region, allocate_public_ip, tuple(availability_zones) if availability_zones else None
+        )
+
+    @cachedmethod(
+        cache=lambda self: self._get_vpc_id_subnet_id_or_error_cache,
+        key=_get_vpc_id_subnet_id_or_error_cache_key,
+        lock=lambda self: self._get_vpc_id_subnet_id_or_error_cache_lock,
+    )
+    def _get_vpc_id_subnet_id_or_error(
+        self,
+        ec2_client: botocore.client.BaseClient,
+        config: AWSConfig,
+        region: str,
+        allocate_public_ip: bool,
+        availability_zones: Optional[List[str]] = None,
+    ) -> Tuple[str, List[str]]:
+        return get_vpc_id_subnet_id_or_error(
+            ec2_client=ec2_client,
+            config=config,
+            region=region,
+            allocate_public_ip=allocate_public_ip,
+            availability_zones=availability_zones,
+        )
+
+    @cachedmethod(
+        cache=lambda self: self._get_maximum_efa_interfaces_cache,
+        key=_ec2client_cache_methodkey,
+        lock=lambda self: self._get_maximum_efa_interfaces_cache_lock,
+    )
+    def _get_maximum_efa_interfaces(
+        self,
+        ec2_client: botocore.client.BaseClient,
+        region: str,
+        instance_type: str,
+    ) -> int:
+        return _get_maximum_efa_interfaces(
+            ec2_client=ec2_client,
+            instance_type=instance_type,
+        )
+
+    def _get_subnets_availability_zones_key(
+        self,
+        ec2_client: botocore.client.BaseClient,
+        region: str,
+        subnet_ids: List[str],
+    ) -> tuple:
+        return hashkey(region, tuple(subnet_ids))
+
+    @cachedmethod(
+        cache=lambda self: self._get_subnets_availability_zones_cache,
+        key=_get_subnets_availability_zones_key,
+        lock=lambda self: self._get_subnets_availability_zones_cache_lock,
+    )
+    def _get_subnets_availability_zones(
+        self,
+        ec2_client: botocore.client.BaseClient,
+        region: str,
+        subnet_ids: List[str],
+    ) -> Dict[str, str]:
+        return aws_resources.get_subnets_availability_zones(
+            ec2_client=ec2_client,
+            subnet_ids=subnet_ids,
+        )
+
+    @cachedmethod(
+        cache=lambda self: self._create_security_group_cache,
+        key=_ec2client_cache_methodkey,
+        lock=lambda self: self._create_security_group_cache_lock,
+    )
+    def _create_security_group(
+        self,
+        ec2_client: botocore.client.BaseClient,
+        region: str,
+        project_id: str,
+        vpc_id: Optional[str],
+    ) -> str:
+        return aws_resources.create_security_group(
+            ec2_client=ec2_client,
+            project_id=project_id,
+            vpc_id=vpc_id,
+        )
+
+    def _get_image_id_and_username_cache_key(
+        self,
+        ec2_client: botocore.client.BaseClient,
+        region: str,
+        cuda: bool,
+        instance_type: str,
+        image_config: Optional[AWSOSImageConfig] = None,
+    ) -> tuple:
+        return hashkey(region, cuda, instance_type, image_config.json() if image_config else None)
+
+    @cachedmethod(
+        cache=lambda self: self._get_image_id_and_username_cache,
+        key=_get_image_id_and_username_cache_key,
+        lock=lambda self: self._get_image_id_and_username_cache_lock,
+    )
+    def _get_image_id_and_username(
+        self,
+        ec2_client: botocore.client.BaseClient,
+        region: str,
+        cuda: bool,
+        instance_type: str,
+        image_config: Optional[AWSOSImageConfig] = None,
+    ) -> tuple[str, str]:
+        return aws_resources.get_image_id_and_username(
+            ec2_client=ec2_client,
+            cuda=cuda,
+            instance_type=instance_type,
+            image_config=image_config,
+        )
+
 
 def get_vpc_id_subnet_id_or_error(
     ec2_client: botocore.client.BaseClient,
@@ -798,7 +1008,7 @@ def _get_regions_to_quotas(
         return region_quotas
 
     regions_to_quotas = {}
-    with ThreadPoolExecutor(max_workers=8) as executor:
+    with ThreadPoolExecutor(max_workers=12) as executor:
         future_to_region = {}
         for region in regions:
             future = executor.submit(
@@ -823,7 +1033,7 @@ def _has_quota(quotas: Dict[str, int], instance_name: str) -> Optional[bool]:
 
 def _get_regions_to_zones(session: boto3.Session, regions: List[str]) -> Dict[str, List[str]]:
     regions_to_zones = {}
-    with ThreadPoolExecutor(max_workers=8) as executor:
+    with ThreadPoolExecutor(max_workers=12) as executor:
         future_to_region = {}
         for region in regions:
             future = executor.submit(
@@ -862,6 +1072,15 @@ def _supported_instances(offer: InstanceOffer) -> bool:
     return False
 
 
+def _offer_supports_placement_group(offer: InstanceOffer, placement_group: PlacementGroup) -> bool:
+    if placement_group.configuration.placement_strategy != PlacementStrategy.CLUSTER:
+        return True
+    for family in ["t3.", "t2."]:
+        if offer.instance.name.startswith(family):
+            return False
+    return True
+
+
 def _get_maximum_efa_interfaces(ec2_client: botocore.client.BaseClient, instance_type: str) -> int:
     try:
         response = ec2_client.describe_instance_types(
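
A note on the caching pattern above: each method gets its own cachetools cache plus lock stored on the instance, wired up with cachedmethod via lambdas that receive self, and custom key functions that skip the unhashable ec2_client argument. A minimal self-contained sketch of the same pattern, assuming cachetools 5.x (where cachedmethod passes self to the key and lock callables); the class and method names here are illustrative:

import threading

from cachetools import TTLCache, cachedmethod
from cachetools.keys import hashkey


class RegionCatalog:
    def __init__(self):
        # Per-instance cache and lock, as in AWSCompute.__init__ above.
        self._zones_cache = TTLCache(maxsize=10, ttl=300)
        self._zones_cache_lock = threading.Lock()

    def _zones_key(self, client, region: str):
        # The client handle is unhashable; key only on the hashable args.
        return hashkey(region)

    @cachedmethod(
        cache=lambda self: self._zones_cache,
        key=_zones_key,
        lock=lambda self: self._zones_cache_lock,
    )
    def get_zones(self, client, region: str) -> list:
        print(f"fetching zones for {region}")  # runs once per region per TTL
        return [f"{region}a", f"{region}b"]


catalog = RegionCatalog()
catalog.get_zones(object(), "us-east-1")  # triggers the fetch
catalog.get_zones(object(), "us-east-1")  # served from the TTL cache

Keeping the caches on the instance rather than at module level means each AWSCompute object (one per configured backend) ages out its entries independently, and the lock callables let the decorator synchronize on per-instance locks.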
dstack/_internal/core/backends/base/compute.py

@@ -57,7 +57,7 @@ class Compute(ABC):
 
     def __init__(self):
         self._offers_cache_lock = threading.Lock()
-        self._offers_cache = TTLCache(maxsize=5, ttl=30)
+        self._offers_cache = TTLCache(maxsize=10, ttl=180)
 
     @abstractmethod
     def get_offers(
@@ -559,7 +559,8 @@
     backend_shim_env: Optional[Dict[str, str]] = None,
     arch: Optional[str] = None,
 ) -> List[str]:
-    commands = get_shim_pre_start_commands(
+    commands = get_setup_cloud_instance_commands()
+    commands += get_shim_pre_start_commands(
         base_path=base_path,
         bin_path=bin_path,
         arch=arch,
@@ -641,6 +642,23 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
     return url_template.format(version=version, arch=arch)
 
 
+def get_setup_cloud_instance_commands() -> list[str]:
+    return [
+        # Workaround for https://github.com/NVIDIA/nvidia-container-toolkit/issues/48
+        # Attempts to patch /etc/docker/daemon.json while keeping any custom settings it may have.
+        (
+            "/bin/sh -c '"  # wrap in /bin/sh to avoid interfering with other cloud init commands
+            " grep -q nvidia /etc/docker/daemon.json"
+            " && ! grep -q native.cgroupdriver /etc/docker/daemon.json"
+            " && jq '\\''.\"exec-opts\" = ((.\"exec-opts\" // []) + [\"native.cgroupdriver=cgroupfs\"])'\\'' /etc/docker/daemon.json > /tmp/daemon.json"
+            " && sudo mv /tmp/daemon.json /etc/docker/daemon.json"
+            " && sudo service docker restart"
+            " || true"
+            "'"
+        ),
+    ]
+
+
 def get_shim_pre_start_commands(
     base_path: Optional[PathLike] = None,
     bin_path: Optional[PathLike] = None,
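
The quoted jq one-liner above is dense, so here is a sketch, in plain Python, of the transformation it applies to /etc/docker/daemon.json once the grep guards pass (an NVIDIA runtime is configured and no cgroupdriver is set yet):

import json


def patch_daemon_json(daemon: dict) -> dict:
    # Equivalent of the jq expression
    #   ."exec-opts" = ((."exec-opts" // []) + ["native.cgroupdriver=cgroupfs"])
    # i.e. append to "exec-opts", creating the list if absent, and leave
    # every other daemon.json setting untouched.
    daemon = dict(daemon)
    daemon["exec-opts"] = list(daemon.get("exec-opts", [])) + ["native.cgroupdriver=cgroupfs"]
    return daemon


before = {"runtimes": {"nvidia": {"path": "nvidia-container-runtime"}}}
print(json.dumps(patch_daemon_json(before), indent=2))
# Output keeps "runtimes" as-is and adds:
#   "exec-opts": ["native.cgroupdriver=cgroupfs"]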
dstack/_internal/core/backends/cudo/compute.py

@@ -65,12 +65,13 @@
         public_keys = instance_config.get_public_keys()
         memory_size = round(instance_offer.instance.resources.memory_mib / 1024)
         disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
-        commands = get_shim_commands(authorized_keys=public_keys)
         gpus_no = len(instance_offer.instance.resources.gpus)
-        shim_commands = " ".join([" && ".join(commands)])
-        startup_script = (
-            shim_commands if gpus_no > 0 else f"{install_docker_script()} && {shim_commands}"
-        )
+        if gpus_no > 0:
+            # we'll need jq for patching /etc/docker/daemon.json, see get_shim_commands()
+            commands = install_jq_commands()
+        else:
+            commands = install_docker_commands()
+        commands += get_shim_commands(authorized_keys=public_keys)
 
         try:
             resp_data = self.api_client.create_virtual_machine(
@@ -85,7 +86,7 @@
                 memory_gib=memory_size,
                 vcpus=instance_offer.instance.resources.cpus,
                 vm_id=vm_id,
-                start_script=startup_script,
+                start_script=" && ".join(commands),
                 password=None,
                 customSshKeys=public_keys,
             )
@@ -151,6 +152,19 @@ def _get_image_id(cuda: bool) -> str:
     return image_name
 
 
-def install_docker_script():
-    commands = 'export DEBIAN_FRONTEND="noninteractive" && mkdir -p /etc/apt/keyrings && curl --max-time 60 -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null && apt-get update && apt-get --assume-yes install docker-ce docker-ce-cli containerd.io docker-compose-plugin'
-    return commands
+def install_jq_commands():
+    return [
+        "export DEBIAN_FRONTEND=noninteractive",
+        "apt-get --assume-yes install jq",
+    ]
+
+
+def install_docker_commands():
+    return [
+        "export DEBIAN_FRONTEND=noninteractive",
+        "mkdir -p /etc/apt/keyrings",
+        "curl --max-time 60 -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg",
+        'echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null',
+        "apt-get update",
+        "apt-get --assume-yes install docker-ce docker-ce-cli containerd.io docker-compose-plugin",
+    ]
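
For illustration, the Cudo start_script is now assembled by concatenating per-step command lists and joining them with " && ", so each step only runs if the previous one succeeded. A sketch (the echo line is a stand-in for the real get_shim_commands() output):

def install_jq_commands():
    return [
        "export DEBIAN_FRONTEND=noninteractive",
        "apt-get --assume-yes install jq",
    ]


# GPU images already ship Docker, so only jq is installed before the shim setup.
commands = install_jq_commands()
commands += ["echo 'shim setup here'"]  # stand-in for get_shim_commands(...)
start_script = " && ".join(commands)
print(start_script)
# export DEBIAN_FRONTEND=noninteractive && apt-get --assume-yes install jq && echo 'shim setup here'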