dstack 0.19.7__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. Click here for more details.

Files changed (52)
  1. dstack/_internal/cli/services/args.py +2 -2
  2. dstack/_internal/cli/services/configurators/run.py +38 -2
  3. dstack/_internal/cli/utils/run.py +3 -3
  4. dstack/_internal/core/backends/aws/compute.py +13 -1
  5. dstack/_internal/core/backends/azure/compute.py +42 -13
  6. dstack/_internal/core/backends/azure/configurator.py +21 -0
  7. dstack/_internal/core/backends/azure/models.py +9 -0
  8. dstack/_internal/core/backends/base/compute.py +101 -27
  9. dstack/_internal/core/backends/base/offers.py +13 -3
  10. dstack/_internal/core/backends/cudo/compute.py +2 -0
  11. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  12. dstack/_internal/core/backends/gcp/auth.py +1 -1
  13. dstack/_internal/core/backends/gcp/compute.py +51 -35
  14. dstack/_internal/core/backends/lambdalabs/compute.py +20 -8
  15. dstack/_internal/core/backends/local/compute.py +2 -0
  16. dstack/_internal/core/backends/nebius/compute.py +95 -1
  17. dstack/_internal/core/backends/nebius/configurator.py +11 -0
  18. dstack/_internal/core/backends/nebius/fabrics.py +47 -0
  19. dstack/_internal/core/backends/nebius/models.py +8 -0
  20. dstack/_internal/core/backends/nebius/resources.py +29 -0
  21. dstack/_internal/core/backends/oci/compute.py +2 -0
  22. dstack/_internal/core/backends/remote/provisioning.py +27 -2
  23. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  24. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  25. dstack/_internal/core/backends/vultr/compute.py +5 -1
  26. dstack/_internal/core/models/instances.py +2 -1
  27. dstack/_internal/core/models/resources.py +78 -3
  28. dstack/_internal/core/models/runs.py +7 -2
  29. dstack/_internal/core/models/volumes.py +1 -1
  30. dstack/_internal/server/background/tasks/process_fleets.py +4 -13
  31. dstack/_internal/server/background/tasks/process_instances.py +176 -55
  32. dstack/_internal/server/background/tasks/process_placement_groups.py +1 -1
  33. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +5 -2
  34. dstack/_internal/server/models.py +1 -0
  35. dstack/_internal/server/services/fleets.py +9 -26
  36. dstack/_internal/server/services/instances.py +0 -2
  37. dstack/_internal/server/services/offers.py +15 -0
  38. dstack/_internal/server/services/placement.py +27 -6
  39. dstack/_internal/server/services/resources.py +21 -0
  40. dstack/_internal/server/services/runs.py +16 -6
  41. dstack/_internal/server/testing/common.py +35 -26
  42. dstack/_internal/utils/common.py +13 -1
  43. dstack/_internal/utils/json_schema.py +6 -3
  44. dstack/api/__init__.py +1 -0
  45. dstack/api/server/_fleets.py +16 -0
  46. dstack/api/server/_runs.py +44 -3
  47. dstack/version.py +1 -1
  48. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/METADATA +3 -1
  49. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/RECORD +52 -50
  50. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/WHEEL +0 -0
  51. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/entry_points.txt +0 -0
  52. {dstack-0.19.7.dist-info → dstack-0.19.8.dist-info}/licenses/LICENSE.md +0 -0
@@ -19,8 +19,8 @@ def port_mapping(v: str) -> PortMapping:
19
19
  return PortMapping.parse(v)
20
20
 
21
21
 
22
- def cpu_spec(v: str) -> resources.Range[int]:
23
- return parse_obj_as(resources.Range[int], v)
22
+ def cpu_spec(v: str) -> dict:
23
+ return resources.CPUSpec.parse(v)
24
24
 
25
25
 
26
26
  def memory_spec(v: str) -> resources.Range[resources.Memory]:
@@ -6,9 +6,10 @@ from pathlib import Path
6
6
  from typing import Dict, List, Optional, Set, Tuple
7
7
 
8
8
  import gpuhunt
9
+ from pydantic import parse_obj_as
9
10
 
10
11
  import dstack._internal.core.models.resources as resources
11
- from dstack._internal.cli.services.args import disk_spec, gpu_spec, port_mapping
12
+ from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec, port_mapping
12
13
  from dstack._internal.cli.services.configurators.base import (
13
14
  ApplyEnvVarsConfiguratorMixin,
14
15
  BaseApplyConfigurator,
@@ -39,6 +40,7 @@ from dstack._internal.core.models.configurations import (
39
40
  TaskConfiguration,
40
41
  )
41
42
  from dstack._internal.core.models.repos.base import Repo
43
+ from dstack._internal.core.models.resources import CPUSpec
42
44
  from dstack._internal.core.models.runs import JobSubmission, JobTerminationReason, RunStatus
43
45
  from dstack._internal.core.services.configs import ConfigManager
44
46
  from dstack._internal.core.services.diff import diff_models
@@ -72,6 +74,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
72
74
  ):
73
75
  self.apply_args(conf, configurator_args, unknown_args)
74
76
  self.validate_gpu_vendor_and_image(conf)
77
+ self.validate_cpu_arch_and_image(conf)
75
78
  if repo is None:
76
79
  repo = self.api.repos.load(Path.cwd())
77
80
  config_manager = ConfigManager()
@@ -289,6 +292,14 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
289
292
  default=default_max_offers,
290
293
  )
291
294
  cls.register_env_args(configuration_group)
295
+ configuration_group.add_argument(
296
+ "--cpu",
297
+ type=cpu_spec,
298
+ help="Request CPU for the run. "
299
+ "The format is [code]ARCH[/]:[code]COUNT[/] (all parts are optional)",
300
+ dest="cpu_spec",
301
+ metavar="SPEC",
302
+ )
292
303
  configuration_group.add_argument(
293
304
  "--gpu",
294
305
  type=gpu_spec,
@@ -310,6 +321,8 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
310
321
  apply_profile_args(args, conf)
311
322
  if args.run_name:
312
323
  conf.name = args.run_name
324
+ if args.cpu_spec:
325
+ conf.resources.cpu = resources.CPUSpec.parse_obj(args.cpu_spec)
313
326
  if args.gpu_spec:
314
327
  conf.resources.gpu = resources.GPUSpec.parse_obj(args.gpu_spec)
315
328
  if args.disk_spec:
@@ -342,7 +355,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
342
355
 
343
356
  def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
344
357
  """
345
- Infers `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
358
+ Infers and sets `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
346
359
  """
347
360
  gpu_spec = conf.resources.gpu
348
361
  if gpu_spec is None:
@@ -400,6 +413,29 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
400
413
  "`image` is required if `resources.gpu.vendor` is `tenstorrent`"
401
414
  )
402
415
 
416
+ def validate_cpu_arch_and_image(self, conf: BaseRunConfiguration) -> None:
417
+ """
418
+ Infers `resources.cpu.arch` if not set, requires `image` if the architecture is ARM.
419
+ """
420
+ # TODO: Remove in 0.20. Use conf.resources.cpu directly
421
+ cpu_spec = parse_obj_as(CPUSpec, conf.resources.cpu)
422
+ arch = cpu_spec.arch
423
+ if arch is None:
424
+ gpu_spec = conf.resources.gpu
425
+ if (
426
+ gpu_spec is not None
427
+ and gpu_spec.vendor in [None, gpuhunt.AcceleratorVendor.NVIDIA]
428
+ and gpu_spec.name
429
+ and any(map(gpuhunt.is_nvidia_superchip, gpu_spec.name))
430
+ ):
431
+ arch = gpuhunt.CPUArchitecture.ARM
432
+ else:
433
+ arch = gpuhunt.CPUArchitecture.X86
434
+ # NOTE: We don't set the inferred resources.cpu.arch for compatibility with older servers.
435
+ # Servers with ARM support set the arch using the same logic.
436
+ if arch == gpuhunt.CPUArchitecture.ARM and conf.image is None:
437
+ raise ConfigurationError("`image` is required if `resources.cpu.arch` is `arm`")
438
+
403
439
 
404
440
  class RunWithPortsConfigurator(BaseRunConfigurator):
405
441
  @classmethod
@@ -1,4 +1,4 @@
1
- import os
1
+ import shutil
2
2
  from typing import Any, Dict, List, Optional, Union
3
3
 
4
4
  from rich.markup import escape
@@ -95,7 +95,7 @@ def print_run_plan(
95
95
  props.add_row(th("Inactivity duration"), inactivity_duration)
96
96
  props.add_row(th("Reservation"), run_spec.configuration.reservation or "-")
97
97
 
98
- offers = Table(box=None, expand=os.get_terminal_size()[0] <= 110)
98
+ offers = Table(box=None, expand=shutil.get_terminal_size(fallback=(120, 40)).columns <= 110)
99
99
  offers.add_column("#")
100
100
  offers.add_column("BACKEND", style="grey58", ratio=2)
101
101
  offers.add_column("RESOURCES", ratio=4)
@@ -149,7 +149,7 @@ def print_run_plan(
149
149
  def get_runs_table(
150
150
  runs: List[Run], verbose: bool = False, format_date: DateFormatter = pretty_date
151
151
  ) -> Table:
152
- table = Table(box=None, expand=os.get_terminal_size()[0] <= 110)
152
+ table = Table(box=None, expand=shutil.get_terminal_size(fallback=(120, 40)).columns <= 110)
153
153
  table.add_column("NAME", style="bold", no_wrap=True, ratio=2)
154
154
  table.add_column("BACKEND", style="grey58", ratio=2)
155
155
  table.add_column("RESOURCES", ratio=3 if not verbose else 2)
@@ -159,6 +159,7 @@ class AWSCompute(
159
159
  self,
160
160
  instance_offer: InstanceOfferWithAvailability,
161
161
  instance_config: InstanceConfiguration,
162
+ placement_group: Optional[PlacementGroup],
162
163
  ) -> JobProvisioningData:
163
164
  project_name = instance_config.project_name
164
165
  ec2_resource = self.session.resource("ec2", region_name=instance_offer.region)
@@ -248,7 +249,7 @@ class AWSCompute(
248
249
  spot=instance_offer.instance.resources.spot,
249
250
  subnet_id=subnet_id,
250
251
  allocate_public_ip=allocate_public_ip,
251
- placement_group_name=instance_config.placement_group_name,
252
+ placement_group_name=placement_group.name if placement_group else None,
252
253
  enable_efa=enable_efa,
253
254
  max_efa_interfaces=max_efa_interfaces,
254
255
  reservation_id=instance_config.reservation,
@@ -291,6 +292,7 @@ class AWSCompute(
291
292
  def create_placement_group(
292
293
  self,
293
294
  placement_group: PlacementGroup,
295
+ master_instance_offer: InstanceOffer,
294
296
  ) -> PlacementGroupProvisioningData:
295
297
  ec2_client = self.session.client("ec2", region_name=placement_group.configuration.region)
296
298
  logger.debug("Creating placement group %s...", placement_group.name)
@@ -323,6 +325,16 @@ class AWSCompute(
323
325
  raise e
324
326
  logger.debug("Deleted placement group %s", placement_group.name)
325
327
 
328
+ def is_suitable_placement_group(
329
+ self,
330
+ placement_group: PlacementGroup,
331
+ instance_offer: InstanceOffer,
332
+ ) -> bool:
333
+ return (
334
+ placement_group.configuration.backend == BackendType.AWS
335
+ and placement_group.configuration.region == instance_offer.region
336
+ )
337
+
326
338
  def create_gateway(
327
339
  self,
328
340
  configuration: GatewayComputeConfiguration,
@@ -62,6 +62,7 @@ from dstack._internal.core.models.instances import (
62
62
  InstanceOfferWithAvailability,
63
63
  InstanceType,
64
64
  )
65
+ from dstack._internal.core.models.placement import PlacementGroup
65
66
  from dstack._internal.core.models.resources import Memory, Range
66
67
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
67
68
  from dstack._internal.utils.logging import get_logger
@@ -109,6 +110,7 @@ class AzureCompute(
109
110
  self,
110
111
  instance_offer: InstanceOfferWithAvailability,
111
112
  instance_config: InstanceConfiguration,
113
+ placement_group: Optional[PlacementGroup],
112
114
  ) -> JobProvisioningData:
113
115
  instance_name = generate_unique_instance_name(
114
116
  instance_config, max_length=azure_resources.MAX_RESOURCE_NAME_LEN
@@ -136,6 +138,10 @@ class AzureCompute(
136
138
  location=location,
137
139
  )
138
140
 
141
+ managed_identity_resource_group, managed_identity_name = parse_vm_managed_identity(
142
+ self.config.vm_managed_identity
143
+ )
144
+
139
145
  base_tags = {
140
146
  "owner": "dstack",
141
147
  "dstack_project": instance_config.project_name,
@@ -159,7 +165,8 @@ class AzureCompute(
159
165
  network_security_group=network_security_group,
160
166
  network=network,
161
167
  subnet=subnet,
162
- managed_identity=None,
168
+ managed_identity_name=managed_identity_name,
169
+ managed_identity_resource_group=managed_identity_resource_group,
163
170
  image_reference=_get_image_ref(
164
171
  compute_client=self._compute_client,
165
172
  location=location,
@@ -255,7 +262,8 @@ class AzureCompute(
255
262
  network_security_group=network_security_group,
256
263
  network=network,
257
264
  subnet=subnet,
258
- managed_identity=None,
265
+ managed_identity_name=None,
266
+ managed_identity_resource_group=None,
259
267
  image_reference=_get_gateway_image_ref(),
260
268
  vm_size="Standard_B1ms",
261
269
  instance_name=instance_name,
@@ -338,6 +346,21 @@ def get_resource_group_network_subnet_or_error(
338
346
  return resource_group, network_name, subnet_name
339
347
 
340
348
 
349
+ def parse_vm_managed_identity(
350
+ vm_managed_identity: Optional[str],
351
+ ) -> Tuple[Optional[str], Optional[str]]:
352
+ if vm_managed_identity is None:
353
+ return None, None
354
+ try:
355
+ resource_group, managed_identity = vm_managed_identity.split("/")
356
+ return resource_group, managed_identity
357
+ except Exception:
358
+ raise ComputeError(
359
+ "`vm_managed_identity` specified in incorrect format."
360
+ " Supported format: 'managedIdentityResourceGroup/managedIdentityName'"
361
+ )
362
+
363
+
341
364
  def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]:
342
365
  resource_group, network_name = vpc_id.split("/")
343
366
  return resource_group, network_name
@@ -466,7 +489,8 @@ def _launch_instance(
466
489
  network_security_group: str,
467
490
  network: str,
468
491
  subnet: str,
469
- managed_identity: Optional[str],
492
+ managed_identity_name: Optional[str],
493
+ managed_identity_resource_group: Optional[str],
470
494
  image_reference: ImageReference,
471
495
  vm_size: str,
472
496
  instance_name: str,
@@ -488,6 +512,20 @@ def _launch_instance(
488
512
  public_ip_address_configuration = VirtualMachinePublicIPAddressConfiguration(
489
513
  name="public_ip_config",
490
514
  )
515
+ managed_identity = None
516
+ if managed_identity_name is not None:
517
+ if managed_identity_resource_group is None:
518
+ managed_identity_resource_group = resource_group
519
+ managed_identity = VirtualMachineIdentity(
520
+ type=ResourceIdentityType.USER_ASSIGNED,
521
+ user_assigned_identities={
522
+ azure_utils.get_managed_identity_id(
523
+ subscription_id,
524
+ managed_identity_resource_group,
525
+ managed_identity_name,
526
+ ): UserAssignedIdentitiesValue(),
527
+ },
528
+ )
491
529
  try:
492
530
  poller = compute_client.virtual_machines.begin_create_or_update(
493
531
  resource_group,
@@ -552,16 +590,7 @@ def _launch_instance(
552
590
  ),
553
591
  priority="Spot" if spot else "Regular",
554
592
  eviction_policy="Delete" if spot else None,
555
- identity=None
556
- if managed_identity is None
557
- else VirtualMachineIdentity(
558
- type=ResourceIdentityType.USER_ASSIGNED,
559
- user_assigned_identities={
560
- azure_utils.get_managed_identity_id(
561
- subscription_id, resource_group, managed_identity
562
- ): UserAssignedIdentitiesValue()
563
- },
564
- ),
593
+ identity=managed_identity,
565
594
  user_data=base64.b64encode(user_data.encode()).decode(),
566
595
  tags=tags,
567
596
  ),
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
4
4
 
5
5
  import azure.core.exceptions
6
6
  from azure.core.credentials import TokenCredential
7
+ from azure.mgmt import msi as msi_mgmt
7
8
  from azure.mgmt import network as network_mgmt
8
9
  from azure.mgmt import resource as resource_mgmt
9
10
  from azure.mgmt import subscription as subscription_mgmt
@@ -97,6 +98,7 @@ class AzureConfigurator(Configurator):
97
98
  self._check_config_locations(config)
98
99
  self._check_config_tags(config)
99
100
  self._check_config_resource_group(config=config, credential=credential)
101
+ self._check_config_vm_managed_identity(config=config, credential=credential)
100
102
  self._check_config_vpc(config=config, credential=credential)
101
103
 
102
104
  def create_backend(
@@ -260,6 +262,25 @@ class AzureConfigurator(Configurator):
260
262
  except BackendError as e:
261
263
  raise ServerClientError(e.args[0])
262
264
 
265
+ def _check_config_vm_managed_identity(
266
+ self, config: AzureBackendConfigWithCreds, credential: auth.AzureCredential
267
+ ):
268
+ try:
269
+ resource_group, identity_name = compute.parse_vm_managed_identity(
270
+ config.vm_managed_identity
271
+ )
272
+ except BackendError as e:
273
+ raise ServerClientError(e.args[0])
274
+ if resource_group is None or identity_name is None:
275
+ return
276
+ msi_client = msi_mgmt.ManagedServiceIdentityClient(credential, config.subscription_id)
277
+ try:
278
+ msi_client.user_assigned_identities.get(resource_group, identity_name)
279
+ except azure.core.exceptions.ResourceNotFoundError:
280
+ raise ServerClientError(
281
+ f"Managed identity {identity_name} not found in resource group {resource_group}"
282
+ )
283
+
263
284
  def _set_client_creds_tenant_id(
264
285
  self,
265
286
  creds: AzureClientCreds,
@@ -62,6 +62,15 @@ class AzureBackendConfig(CoreModel):
62
62
  )
63
63
  ),
64
64
  ] = None
65
+ vm_managed_identity: Annotated[
66
+ Optional[str],
67
+ Field(
68
+ description=(
69
+ "The managed identity to associate with provisioned VMs."
70
+ " Must have a format `managedIdentityResourceGroup/managedIdentityName`"
71
+ )
72
+ ),
73
+ ] = None
65
74
  tags: Annotated[
66
75
  Optional[Dict[str, str]],
67
76
  Field(description="The tags that will be assigned to resources created by `dstack`"),
@@ -6,7 +6,7 @@ import threading
6
6
  from abc import ABC, abstractmethod
7
7
  from functools import lru_cache
8
8
  from pathlib import Path
9
- from typing import Dict, List, Optional
9
+ from typing import Dict, List, Literal, Optional
10
10
 
11
11
  import git
12
12
  import requests
@@ -25,6 +25,7 @@ from dstack._internal.core.models.gateways import (
25
25
  )
26
26
  from dstack._internal.core.models.instances import (
27
27
  InstanceConfiguration,
28
+ InstanceOffer,
28
29
  InstanceOfferWithAvailability,
29
30
  SSHKey,
30
31
  )
@@ -44,6 +45,8 @@ logger = get_logger(__name__)
44
45
  DSTACK_SHIM_BINARY_NAME = "dstack-shim"
45
46
  DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
46
47
 
48
+ GoArchType = Literal["amd64", "arm64"]
49
+
47
50
 
48
51
  class Compute(ABC):
49
52
  """
@@ -144,6 +147,7 @@ class ComputeWithCreateInstanceSupport(ABC):
144
147
  self,
145
148
  instance_offer: InstanceOfferWithAvailability,
146
149
  instance_config: InstanceConfiguration,
150
+ placement_group: Optional[PlacementGroup],
147
151
  ) -> JobProvisioningData:
148
152
  """
149
153
  Launches a new instance. It should return `JobProvisioningData` ASAP.
@@ -176,7 +180,7 @@ class ComputeWithCreateInstanceSupport(ABC):
176
180
  )
177
181
  instance_offer = instance_offer.copy()
178
182
  self._restrict_instance_offer_az_to_volumes_az(instance_offer, volumes)
179
- return self.create_instance(instance_offer, instance_config)
183
+ return self.create_instance(instance_offer, instance_config, placement_group=None)
180
184
 
181
185
  def _restrict_instance_offer_az_to_volumes_az(
182
186
  self,
@@ -225,9 +229,15 @@ class ComputeWithPlacementGroupSupport(ABC):
225
229
  def create_placement_group(
226
230
  self,
227
231
  placement_group: PlacementGroup,
232
+ master_instance_offer: InstanceOffer,
228
233
  ) -> PlacementGroupProvisioningData:
229
234
  """
230
235
  Creates a placement group.
236
+
237
+ Args:
238
+ placement_group: details about the placement group to be created
239
+ master_instance_offer: the first instance dstack will attempt to add
240
+ to the placement group
231
241
  """
232
242
  pass
233
243
 
@@ -242,10 +252,27 @@ class ComputeWithPlacementGroupSupport(ABC):
242
252
  """
243
253
  pass
244
254
 
255
+ @abstractmethod
256
+ def is_suitable_placement_group(
257
+ self,
258
+ placement_group: PlacementGroup,
259
+ instance_offer: InstanceOffer,
260
+ ) -> bool:
261
+ """
262
+ Checks if the instance offer can be provisioned in the placement group.
263
+
264
+ Should return immediately, without performing API calls.
265
+
266
+ Can be called with an offer originating from a different backend, because some backends
267
+ (BackendType.DSTACK) produce offers on behalf of other backends. Should return `False`
268
+ in that case.
269
+ """
270
+ pass
271
+
245
272
 
246
273
  class ComputeWithGatewaySupport(ABC):
247
274
  """
248
- Must be subclassed and imlemented to support gateways.
275
+ Must be subclassed and implemented to support gateways.
249
276
  """
250
277
 
251
278
  @abstractmethod
@@ -418,6 +445,21 @@ def generate_unique_volume_name(
418
445
  )
419
446
 
420
447
 
448
+ def generate_unique_placement_group_name(
449
+ project_name: str,
450
+ fleet_name: str,
451
+ max_length: int = _DEFAULT_MAX_RESOURCE_NAME_LEN,
452
+ ) -> str:
453
+ """
454
+ Generates a unique placement group name valid across all backends.
455
+ """
456
+ return generate_unique_backend_name(
457
+ resource_name=fleet_name,
458
+ project_name=project_name,
459
+ max_length=max_length,
460
+ )
461
+
462
+
421
463
  def generate_unique_backend_name(
422
464
  resource_name: str,
423
465
  project_name: Optional[str],
@@ -483,13 +525,14 @@ def get_shim_env(
483
525
  base_path: Optional[PathLike] = None,
484
526
  bin_path: Optional[PathLike] = None,
485
527
  backend_shim_env: Optional[Dict[str, str]] = None,
528
+ arch: Optional[str] = None,
486
529
  ) -> Dict[str, str]:
487
530
  log_level = "6" # Trace
488
531
  envs = {
489
532
  "DSTACK_SHIM_HOME": get_dstack_working_dir(base_path),
490
533
  "DSTACK_SHIM_HTTP_PORT": str(DSTACK_SHIM_HTTP_PORT),
491
534
  "DSTACK_SHIM_LOG_LEVEL": log_level,
492
- "DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(),
535
+ "DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(arch),
493
536
  "DSTACK_RUNNER_BINARY_PATH": get_dstack_runner_binary_path(bin_path),
494
537
  "DSTACK_RUNNER_HTTP_PORT": str(DSTACK_RUNNER_HTTP_PORT),
495
538
  "DSTACK_RUNNER_SSH_PORT": str(DSTACK_RUNNER_SSH_PORT),
@@ -509,16 +552,19 @@ def get_shim_commands(
509
552
  base_path: Optional[PathLike] = None,
510
553
  bin_path: Optional[PathLike] = None,
511
554
  backend_shim_env: Optional[Dict[str, str]] = None,
555
+ arch: Optional[str] = None,
512
556
  ) -> List[str]:
513
557
  commands = get_shim_pre_start_commands(
514
558
  base_path=base_path,
515
559
  bin_path=bin_path,
560
+ arch=arch,
516
561
  )
517
562
  shim_env = get_shim_env(
518
563
  authorized_keys=authorized_keys,
519
564
  base_path=base_path,
520
565
  bin_path=bin_path,
521
566
  backend_shim_env=backend_shim_env,
567
+ arch=arch,
522
568
  )
523
569
  for k, v in shim_env.items():
524
570
  commands += [f'export "{k}={v}"']
@@ -539,35 +585,63 @@ def get_dstack_runner_version() -> str:
539
585
  return version or "latest"
540
586
 
541
587
 
542
- def get_dstack_runner_download_url() -> str:
543
- if url := os.environ.get("DSTACK_RUNNER_DOWNLOAD_URL"):
544
- return url
545
- build = get_dstack_runner_version()
546
- if settings.DSTACK_VERSION is not None:
547
- bucket = "dstack-runner-downloads"
548
- else:
549
- bucket = "dstack-runner-downloads-stgn"
550
- return (
551
- f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64"
552
- )
553
-
554
-
555
- def get_dstack_shim_download_url() -> str:
556
- if url := os.environ.get("DSTACK_SHIM_DOWNLOAD_URL"):
557
- return url
558
- build = get_dstack_runner_version()
559
- if settings.DSTACK_VERSION is not None:
560
- bucket = "dstack-runner-downloads"
561
- else:
562
- bucket = "dstack-runner-downloads-stgn"
563
- return f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64"
588
+ def normalize_arch(arch: Optional[str] = None) -> GoArchType:
589
+ """
590
+ Converts the given free-form architecture string to the Go GOARCH format.
591
+ Only 64-bit x86 and ARM are supported. If the word size is not specified (e.g., `x86`, `arm`),
592
+ 64-bit is implied.
593
+ If the arch is not specified, falls back to `amd64`.
594
+ """
595
+ if not arch:
596
+ return "amd64"
597
+ arch_lower = arch.lower()
598
+ if "32" in arch_lower or arch_lower in ["i386", "i686"]:
599
+ raise ValueError(f"32-bit architectures are not supported: {arch}")
600
+ if arch_lower.startswith("x86") or arch_lower.startswith("amd"):
601
+ return "amd64"
602
+ if arch_lower.startswith("arm") or arch_lower.startswith("aarch"):
603
+ return "arm64"
604
+ raise ValueError(f"Unsupported architecture: {arch}")
605
+
606
+
607
+ def get_dstack_runner_download_url(arch: Optional[str] = None) -> str:
608
+ url_template = os.environ.get("DSTACK_RUNNER_DOWNLOAD_URL")
609
+ if not url_template:
610
+ if settings.DSTACK_VERSION is not None:
611
+ bucket = "dstack-runner-downloads"
612
+ else:
613
+ bucket = "dstack-runner-downloads-stgn"
614
+ url_template = (
615
+ f"https://{bucket}.s3.eu-west-1.amazonaws.com"
616
+ "/{version}/binaries/dstack-runner-linux-{arch}"
617
+ )
618
+ version = get_dstack_runner_version()
619
+ arch = normalize_arch(arch)
620
+ return url_template.format(version=version, arch=arch)
621
+
622
+
623
+ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
624
+ url_template = os.environ.get("DSTACK_SHIM_DOWNLOAD_URL")
625
+ if not url_template:
626
+ if settings.DSTACK_VERSION is not None:
627
+ bucket = "dstack-runner-downloads"
628
+ else:
629
+ bucket = "dstack-runner-downloads-stgn"
630
+ url_template = (
631
+ f"https://{bucket}.s3.eu-west-1.amazonaws.com"
632
+ "/{version}/binaries/dstack-shim-linux-{arch}"
633
+ )
634
+ version = get_dstack_runner_version()
635
+ arch = normalize_arch(arch)
636
+ return url_template.format(version=version, arch=arch)
564
637
 
565
638
 
566
639
  def get_shim_pre_start_commands(
567
640
  base_path: Optional[PathLike] = None,
568
641
  bin_path: Optional[PathLike] = None,
642
+ arch: Optional[str] = None,
569
643
  ) -> List[str]:
570
- url = get_dstack_shim_download_url()
644
+ url = get_dstack_shim_download_url(arch)
571
645
  dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path)
572
646
  dstack_working_dir = get_dstack_working_dir(base_path)
573
647
  return [
@@ -2,6 +2,7 @@ from dataclasses import asdict
2
2
  from typing import Callable, List, Optional
3
3
 
4
4
  import gpuhunt
5
+ from pydantic import parse_obj_as
5
6
 
6
7
  from dstack._internal.core.models.backends.base import BackendType
7
8
  from dstack._internal.core.models.instances import (
@@ -11,13 +12,14 @@ from dstack._internal.core.models.instances import (
11
12
  InstanceType,
12
13
  Resources,
13
14
  )
14
- from dstack._internal.core.models.resources import DEFAULT_DISK, Memory, Range
15
+ from dstack._internal.core.models.resources import DEFAULT_DISK, CPUSpec, Memory, Range
15
16
  from dstack._internal.core.models.runs import Requirements
16
17
 
17
18
  # Offers not supported by all dstack versions are hidden behind one or more flags.
18
19
  # This list enables the flags that are currently supported.
19
20
  SUPPORTED_GPUHUNT_FLAGS = [
20
21
  "oci-spot",
22
+ "lambda-arm",
21
23
  ]
22
24
 
23
25
 
@@ -71,6 +73,7 @@ def catalog_item_to_offer(
71
73
  if disk_size_mib is None:
72
74
  return None
73
75
  resources = Resources(
76
+ cpu_arch=item.cpu_arch,
74
77
  cpus=item.cpu,
75
78
  memory_mib=round(item.memory * 1024),
76
79
  gpus=gpus,
@@ -90,6 +93,9 @@ def catalog_item_to_offer(
90
93
 
91
94
 
92
95
  def offer_to_catalog_item(offer: InstanceOffer) -> gpuhunt.CatalogItem:
96
+ cpu_arch = offer.instance.resources.cpu_arch
97
+ if cpu_arch is None:
98
+ cpu_arch = gpuhunt.CPUArchitecture.X86
93
99
  gpu_count = len(offer.instance.resources.gpus)
94
100
  gpu_vendor = None
95
101
  gpu_name = None
@@ -104,6 +110,7 @@ def offer_to_catalog_item(offer: InstanceOffer) -> gpuhunt.CatalogItem:
104
110
  instance_name=offer.instance.name,
105
111
  location=offer.region,
106
112
  price=offer.price,
113
+ cpu_arch=cpu_arch,
107
114
  cpu=offer.instance.resources.cpus,
108
115
  memory=offer.instance.resources.memory_mib / 1024,
109
116
  gpu_count=gpu_count,
@@ -125,8 +132,11 @@ def requirements_to_query_filter(req: Optional[Requirements]) -> gpuhunt.QueryFi
125
132
 
126
133
  res = req.resources
127
134
  if res.cpu:
128
- q.min_cpu = res.cpu.min
129
- q.max_cpu = res.cpu.max
135
+ # TODO: Remove in 0.20. Use res.cpu directly
136
+ cpu = parse_obj_as(CPUSpec, res.cpu)
137
+ q.cpu_arch = cpu.arch
138
+ q.min_cpu = cpu.count.min
139
+ q.max_cpu = cpu.count.max
130
140
  if res.memory:
131
141
  q.min_memory = res.memory.min
132
142
  q.max_memory = res.memory.max
@@ -18,6 +18,7 @@ from dstack._internal.core.models.instances import (
18
18
  InstanceConfiguration,
19
19
  InstanceOfferWithAvailability,
20
20
  )
21
+ from dstack._internal.core.models.placement import PlacementGroup
21
22
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
22
23
  from dstack._internal.utils.logging import get_logger
23
24
 
@@ -58,6 +59,7 @@ class CudoCompute(
58
59
  self,
59
60
  instance_offer: InstanceOfferWithAvailability,
60
61
  instance_config: InstanceConfiguration,
62
+ placement_group: Optional[PlacementGroup],
61
63
  ) -> JobProvisioningData:
62
64
  vm_id = generate_unique_instance_name(instance_config, max_length=MAX_RESOURCE_NAME_LEN)
63
65
  public_keys = instance_config.get_public_keys()
@@ -20,6 +20,7 @@ from dstack._internal.core.models.instances import (
20
20
  InstanceOffer,
21
21
  InstanceOfferWithAvailability,
22
22
  )
23
+ from dstack._internal.core.models.placement import PlacementGroup
23
24
  from dstack._internal.core.models.resources import Memory, Range
24
25
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
25
26
  from dstack._internal.utils.logging import get_logger
@@ -85,6 +86,7 @@ class DataCrunchCompute(
85
86
  self,
86
87
  instance_offer: InstanceOfferWithAvailability,
87
88
  instance_config: InstanceConfiguration,
89
+ placement_group: Optional[PlacementGroup],
88
90
  ) -> JobProvisioningData:
89
91
  instance_name = generate_unique_instance_name(
90
92
  instance_config, max_length=MAX_INSTANCE_NAME_LEN
@@ -19,7 +19,7 @@ def authenticate(creds: AnyGCPCreds, project_id: Optional[str] = None) -> Tuple[
19
19
  credentials, credentials_project_id = get_credentials(creds)
20
20
  if project_id is None:
21
21
  # If project_id is not specified explicitly, try using credentials' project_id.
22
- # Explicit project_id takes precedence bacause credentials' project_id may be irrelevant.
22
+ # Explicit project_id takes precedence because credentials' project_id may be irrelevant.
23
23
  # For example, with Workload Identity Federation for GKE, it's cluster project_id.
24
24
  project_id = credentials_project_id
25
25
  if project_id is None: