dstack 0.19.7__py3-none-any.whl → 0.19.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic; see the registry's advisory page for more details.

Files changed (60)
  1. dstack/_internal/cli/services/args.py +2 -2
  2. dstack/_internal/cli/services/configurators/run.py +56 -13
  3. dstack/_internal/cli/utils/run.py +10 -5
  4. dstack/_internal/core/backends/aws/compute.py +13 -1
  5. dstack/_internal/core/backends/azure/compute.py +42 -13
  6. dstack/_internal/core/backends/azure/configurator.py +21 -0
  7. dstack/_internal/core/backends/azure/models.py +9 -0
  8. dstack/_internal/core/backends/base/compute.py +101 -27
  9. dstack/_internal/core/backends/base/offers.py +13 -3
  10. dstack/_internal/core/backends/cudo/compute.py +3 -1
  11. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  12. dstack/_internal/core/backends/gcp/auth.py +1 -1
  13. dstack/_internal/core/backends/gcp/compute.py +51 -35
  14. dstack/_internal/core/backends/lambdalabs/compute.py +20 -8
  15. dstack/_internal/core/backends/local/compute.py +2 -0
  16. dstack/_internal/core/backends/nebius/compute.py +95 -1
  17. dstack/_internal/core/backends/nebius/configurator.py +11 -0
  18. dstack/_internal/core/backends/nebius/fabrics.py +48 -0
  19. dstack/_internal/core/backends/nebius/models.py +9 -1
  20. dstack/_internal/core/backends/nebius/resources.py +29 -0
  21. dstack/_internal/core/backends/oci/compute.py +2 -0
  22. dstack/_internal/core/backends/remote/provisioning.py +27 -2
  23. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  24. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  25. dstack/_internal/core/backends/vultr/compute.py +5 -1
  26. dstack/_internal/core/models/instances.py +2 -1
  27. dstack/_internal/core/models/resources.py +79 -4
  28. dstack/_internal/core/models/runs.py +26 -9
  29. dstack/_internal/core/models/volumes.py +1 -1
  30. dstack/_internal/server/background/tasks/process_fleets.py +4 -13
  31. dstack/_internal/server/background/tasks/process_instances.py +176 -55
  32. dstack/_internal/server/background/tasks/process_metrics.py +26 -9
  33. dstack/_internal/server/background/tasks/process_placement_groups.py +1 -1
  34. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +5 -2
  35. dstack/_internal/server/background/tasks/process_running_jobs.py +56 -18
  36. dstack/_internal/server/migrations/versions/20166748b60c_add_jobmodel_disconnected_at.py +100 -0
  37. dstack/_internal/server/migrations/versions/6c1a9d6530ee_add_jobmodel_exit_status.py +26 -0
  38. dstack/_internal/server/models.py +6 -1
  39. dstack/_internal/server/schemas/runner.py +41 -8
  40. dstack/_internal/server/services/fleets.py +9 -26
  41. dstack/_internal/server/services/instances.py +0 -2
  42. dstack/_internal/server/services/jobs/__init__.py +1 -0
  43. dstack/_internal/server/services/offers.py +15 -0
  44. dstack/_internal/server/services/placement.py +27 -6
  45. dstack/_internal/server/services/resources.py +21 -0
  46. dstack/_internal/server/services/runner/client.py +7 -4
  47. dstack/_internal/server/services/runs.py +18 -8
  48. dstack/_internal/server/settings.py +20 -1
  49. dstack/_internal/server/testing/common.py +37 -26
  50. dstack/_internal/utils/common.py +13 -1
  51. dstack/_internal/utils/json_schema.py +6 -3
  52. dstack/api/__init__.py +1 -0
  53. dstack/api/server/_fleets.py +16 -0
  54. dstack/api/server/_runs.py +48 -3
  55. dstack/version.py +1 -1
  56. {dstack-0.19.7.dist-info → dstack-0.19.9.dist-info}/METADATA +38 -29
  57. {dstack-0.19.7.dist-info → dstack-0.19.9.dist-info}/RECORD +60 -56
  58. {dstack-0.19.7.dist-info → dstack-0.19.9.dist-info}/WHEEL +0 -0
  59. {dstack-0.19.7.dist-info → dstack-0.19.9.dist-info}/entry_points.txt +0 -0
  60. {dstack-0.19.7.dist-info → dstack-0.19.9.dist-info}/licenses/LICENSE.md +0 -0
@@ -19,8 +19,8 @@ def port_mapping(v: str) -> PortMapping:
19
19
  return PortMapping.parse(v)
20
20
 
21
21
 
22
- def cpu_spec(v: str) -> resources.Range[int]:
23
- return parse_obj_as(resources.Range[int], v)
22
+ def cpu_spec(v: str) -> dict:
23
+ return resources.CPUSpec.parse(v)
24
24
 
25
25
 
26
26
  def memory_spec(v: str) -> resources.Range[resources.Memory]:
@@ -6,9 +6,10 @@ from pathlib import Path
6
6
  from typing import Dict, List, Optional, Set, Tuple
7
7
 
8
8
  import gpuhunt
9
+ from pydantic import parse_obj_as
9
10
 
10
11
  import dstack._internal.core.models.resources as resources
11
- from dstack._internal.cli.services.args import disk_spec, gpu_spec, port_mapping
12
+ from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec, port_mapping
12
13
  from dstack._internal.cli.services.configurators.base import (
13
14
  ApplyEnvVarsConfiguratorMixin,
14
15
  BaseApplyConfigurator,
@@ -39,6 +40,7 @@ from dstack._internal.core.models.configurations import (
39
40
  TaskConfiguration,
40
41
  )
41
42
  from dstack._internal.core.models.repos.base import Repo
43
+ from dstack._internal.core.models.resources import CPUSpec
42
44
  from dstack._internal.core.models.runs import JobSubmission, JobTerminationReason, RunStatus
43
45
  from dstack._internal.core.services.configs import ConfigManager
44
46
  from dstack._internal.core.services.diff import diff_models
@@ -72,6 +74,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
72
74
  ):
73
75
  self.apply_args(conf, configurator_args, unknown_args)
74
76
  self.validate_gpu_vendor_and_image(conf)
77
+ self.validate_cpu_arch_and_image(conf)
75
78
  if repo is None:
76
79
  repo = self.api.repos.load(Path.cwd())
77
80
  config_manager = ConfigManager()
@@ -95,6 +98,8 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
95
98
  print_run_plan(run_plan, max_offers=configurator_args.max_offers)
96
99
 
97
100
  confirm_message = "Submit a new run?"
101
+ if conf.name:
102
+ confirm_message = f"Submit the run [code]{conf.name}[/]?"
98
103
  stop_run_name = None
99
104
  if run_plan.current_resource is not None:
100
105
  changed_fields = []
@@ -127,11 +132,6 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
127
132
  f"Active run [code]{conf.name}[/] already exists and cannot be updated in-place."
128
133
  )
129
134
  confirm_message = "Stop and override the run?"
130
- else:
131
- console.print(f"Finished run [code]{conf.name}[/] already exists.")
132
- confirm_message = "Override the run?"
133
- elif conf.name:
134
- confirm_message = f"Submit the run [code]{conf.name}[/]?"
135
135
 
136
136
  if not command_args.yes and not confirm_ask(confirm_message):
137
137
  console.print("\nExiting...")
@@ -289,6 +289,14 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
289
289
  default=default_max_offers,
290
290
  )
291
291
  cls.register_env_args(configuration_group)
292
+ configuration_group.add_argument(
293
+ "--cpu",
294
+ type=cpu_spec,
295
+ help="Request CPU for the run. "
296
+ "The format is [code]ARCH[/]:[code]COUNT[/] (all parts are optional)",
297
+ dest="cpu_spec",
298
+ metavar="SPEC",
299
+ )
292
300
  configuration_group.add_argument(
293
301
  "--gpu",
294
302
  type=gpu_spec,
@@ -310,6 +318,8 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
310
318
  apply_profile_args(args, conf)
311
319
  if args.run_name:
312
320
  conf.name = args.run_name
321
+ if args.cpu_spec:
322
+ conf.resources.cpu = resources.CPUSpec.parse_obj(args.cpu_spec)
313
323
  if args.gpu_spec:
314
324
  conf.resources.gpu = resources.GPUSpec.parse_obj(args.gpu_spec)
315
325
  if args.disk_spec:
@@ -342,7 +352,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
342
352
 
343
353
  def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
344
354
  """
345
- Infers `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
355
+ Infers and sets `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
346
356
  """
347
357
  gpu_spec = conf.resources.gpu
348
358
  if gpu_spec is None:
@@ -400,6 +410,29 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
400
410
  "`image` is required if `resources.gpu.vendor` is `tenstorrent`"
401
411
  )
402
412
 
413
+ def validate_cpu_arch_and_image(self, conf: BaseRunConfiguration) -> None:
414
+ """
415
+ Infers `resources.cpu.arch` if not set, requires `image` if the architecture is ARM.
416
+ """
417
+ # TODO: Remove in 0.20. Use conf.resources.cpu directly
418
+ cpu_spec = parse_obj_as(CPUSpec, conf.resources.cpu)
419
+ arch = cpu_spec.arch
420
+ if arch is None:
421
+ gpu_spec = conf.resources.gpu
422
+ if (
423
+ gpu_spec is not None
424
+ and gpu_spec.vendor in [None, gpuhunt.AcceleratorVendor.NVIDIA]
425
+ and gpu_spec.name
426
+ and any(map(gpuhunt.is_nvidia_superchip, gpu_spec.name))
427
+ ):
428
+ arch = gpuhunt.CPUArchitecture.ARM
429
+ else:
430
+ arch = gpuhunt.CPUArchitecture.X86
431
+ # NOTE: We don't set the inferred resources.cpu.arch for compatibility with older servers.
432
+ # Servers with ARM support set the arch using the same logic.
433
+ if arch == gpuhunt.CPUArchitecture.ARM and conf.image is None:
434
+ raise ConfigurationError("`image` is required if `resources.cpu.arch` is `arm`")
435
+
403
436
 
404
437
  class RunWithPortsConfigurator(BaseRunConfigurator):
405
438
  @classmethod
@@ -524,7 +557,9 @@ def print_finished_message(run: Run):
524
557
  console.print("[code]Done[/]")
525
558
  return
526
559
 
527
- termination_reason, termination_reason_message = _get_run_termination_reason(run)
560
+ termination_reason, termination_reason_message, exit_status = (
561
+ _get_run_termination_reason_and_exit_status(run)
562
+ )
528
563
  message = "Run failed due to unknown reason. Check CLI, server, and run logs."
529
564
  if run.status == RunStatus.TERMINATED:
530
565
  message = "Run terminated due to unknown reason. Check CLI, server, and run logs."
@@ -536,13 +571,15 @@ def print_finished_message(run: Run):
536
571
  "Check CLI and server logs for more details."
537
572
  )
538
573
  elif termination_reason is not None:
574
+ exit_status_details = f"Exit status: {exit_status}.\n" if exit_status else ""
539
575
  error_details = (
540
576
  f"Error: {termination_reason_message}\n" if termination_reason_message else ""
541
577
  )
542
578
  message = (
543
579
  f"Run failed with error code {termination_reason.name}.\n"
580
+ f"{exit_status_details}"
544
581
  f"{error_details}"
545
- "Check CLI, server, and run logs for more details."
582
+ f"Check [bold]dstack logs -d {run.name}[/bold] for more details."
546
583
  )
547
584
  console.print(f"[error]{message}[/]")
548
585
 
@@ -553,14 +590,20 @@ def get_run_exit_code(run: Run) -> int:
553
590
  return 1
554
591
 
555
592
 
556
- def _get_run_termination_reason(run: Run) -> Tuple[Optional[JobTerminationReason], Optional[str]]:
593
+ def _get_run_termination_reason_and_exit_status(
594
+ run: Run,
595
+ ) -> Tuple[Optional[JobTerminationReason], Optional[str], Optional[int]]:
557
596
  if len(run._run.jobs) == 0:
558
- return None, None
597
+ return None, None, None
559
598
  job = run._run.jobs[0]
560
599
  if len(job.job_submissions) == 0:
561
- return None, None
600
+ return None, None, None
562
601
  job_submission = job.job_submissions[0]
563
- return job_submission.termination_reason, job_submission.termination_reason_message
602
+ return (
603
+ job_submission.termination_reason,
604
+ job_submission.termination_reason_message,
605
+ job_submission.exit_status,
606
+ )
564
607
 
565
608
 
566
609
  def _run_resubmitted(run: Run, current_job_submission: Optional[JobSubmission]) -> bool:
@@ -1,4 +1,4 @@
1
- import os
1
+ import shutil
2
2
  from typing import Any, Dict, List, Optional, Union
3
3
 
4
4
  from rich.markup import escape
@@ -95,7 +95,7 @@ def print_run_plan(
95
95
  props.add_row(th("Inactivity duration"), inactivity_duration)
96
96
  props.add_row(th("Reservation"), run_spec.configuration.reservation or "-")
97
97
 
98
- offers = Table(box=None, expand=os.get_terminal_size()[0] <= 110)
98
+ offers = Table(box=None, expand=shutil.get_terminal_size(fallback=(120, 40)).columns <= 110)
99
99
  offers.add_column("#")
100
100
  offers.add_column("BACKEND", style="grey58", ratio=2)
101
101
  offers.add_column("RESOURCES", ratio=4)
@@ -149,7 +149,7 @@ def print_run_plan(
149
149
  def get_runs_table(
150
150
  runs: List[Run], verbose: bool = False, format_date: DateFormatter = pretty_date
151
151
  ) -> Table:
152
- table = Table(box=None, expand=os.get_terminal_size()[0] <= 110)
152
+ table = Table(box=None, expand=shutil.get_terminal_size(fallback=(120, 40)).columns <= 110)
153
153
  table.add_column("NAME", style="bold", no_wrap=True, ratio=2)
154
154
  table.add_column("BACKEND", style="grey58", ratio=2)
155
155
  table.add_column("RESOURCES", ratio=3 if not verbose else 2)
@@ -218,6 +218,11 @@ def _get_run_error(run: Run) -> str:
218
218
 
219
219
 
220
220
  def _get_job_error(job: Job) -> str:
221
- if job.job_submissions[-1].termination_reason is None:
221
+ job_submission = job.job_submissions[-1]
222
+ termination_reason = job_submission.termination_reason
223
+ exit_status = job_submission.exit_status
224
+ if termination_reason is None:
222
225
  return ""
223
- return job.job_submissions[-1].termination_reason.name
226
+ if exit_status:
227
+ return f"{termination_reason.name} {exit_status}"
228
+ return termination_reason.name
@@ -159,6 +159,7 @@ class AWSCompute(
159
159
  self,
160
160
  instance_offer: InstanceOfferWithAvailability,
161
161
  instance_config: InstanceConfiguration,
162
+ placement_group: Optional[PlacementGroup],
162
163
  ) -> JobProvisioningData:
163
164
  project_name = instance_config.project_name
164
165
  ec2_resource = self.session.resource("ec2", region_name=instance_offer.region)
@@ -248,7 +249,7 @@ class AWSCompute(
248
249
  spot=instance_offer.instance.resources.spot,
249
250
  subnet_id=subnet_id,
250
251
  allocate_public_ip=allocate_public_ip,
251
- placement_group_name=instance_config.placement_group_name,
252
+ placement_group_name=placement_group.name if placement_group else None,
252
253
  enable_efa=enable_efa,
253
254
  max_efa_interfaces=max_efa_interfaces,
254
255
  reservation_id=instance_config.reservation,
@@ -291,6 +292,7 @@ class AWSCompute(
291
292
  def create_placement_group(
292
293
  self,
293
294
  placement_group: PlacementGroup,
295
+ master_instance_offer: InstanceOffer,
294
296
  ) -> PlacementGroupProvisioningData:
295
297
  ec2_client = self.session.client("ec2", region_name=placement_group.configuration.region)
296
298
  logger.debug("Creating placement group %s...", placement_group.name)
@@ -323,6 +325,16 @@ class AWSCompute(
323
325
  raise e
324
326
  logger.debug("Deleted placement group %s", placement_group.name)
325
327
 
328
+ def is_suitable_placement_group(
329
+ self,
330
+ placement_group: PlacementGroup,
331
+ instance_offer: InstanceOffer,
332
+ ) -> bool:
333
+ return (
334
+ placement_group.configuration.backend == BackendType.AWS
335
+ and placement_group.configuration.region == instance_offer.region
336
+ )
337
+
326
338
  def create_gateway(
327
339
  self,
328
340
  configuration: GatewayComputeConfiguration,
@@ -62,6 +62,7 @@ from dstack._internal.core.models.instances import (
62
62
  InstanceOfferWithAvailability,
63
63
  InstanceType,
64
64
  )
65
+ from dstack._internal.core.models.placement import PlacementGroup
65
66
  from dstack._internal.core.models.resources import Memory, Range
66
67
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
67
68
  from dstack._internal.utils.logging import get_logger
@@ -109,6 +110,7 @@ class AzureCompute(
109
110
  self,
110
111
  instance_offer: InstanceOfferWithAvailability,
111
112
  instance_config: InstanceConfiguration,
113
+ placement_group: Optional[PlacementGroup],
112
114
  ) -> JobProvisioningData:
113
115
  instance_name = generate_unique_instance_name(
114
116
  instance_config, max_length=azure_resources.MAX_RESOURCE_NAME_LEN
@@ -136,6 +138,10 @@ class AzureCompute(
136
138
  location=location,
137
139
  )
138
140
 
141
+ managed_identity_resource_group, managed_identity_name = parse_vm_managed_identity(
142
+ self.config.vm_managed_identity
143
+ )
144
+
139
145
  base_tags = {
140
146
  "owner": "dstack",
141
147
  "dstack_project": instance_config.project_name,
@@ -159,7 +165,8 @@ class AzureCompute(
159
165
  network_security_group=network_security_group,
160
166
  network=network,
161
167
  subnet=subnet,
162
- managed_identity=None,
168
+ managed_identity_name=managed_identity_name,
169
+ managed_identity_resource_group=managed_identity_resource_group,
163
170
  image_reference=_get_image_ref(
164
171
  compute_client=self._compute_client,
165
172
  location=location,
@@ -255,7 +262,8 @@ class AzureCompute(
255
262
  network_security_group=network_security_group,
256
263
  network=network,
257
264
  subnet=subnet,
258
- managed_identity=None,
265
+ managed_identity_name=None,
266
+ managed_identity_resource_group=None,
259
267
  image_reference=_get_gateway_image_ref(),
260
268
  vm_size="Standard_B1ms",
261
269
  instance_name=instance_name,
@@ -338,6 +346,21 @@ def get_resource_group_network_subnet_or_error(
338
346
  return resource_group, network_name, subnet_name
339
347
 
340
348
 
349
+ def parse_vm_managed_identity(
350
+ vm_managed_identity: Optional[str],
351
+ ) -> Tuple[Optional[str], Optional[str]]:
352
+ if vm_managed_identity is None:
353
+ return None, None
354
+ try:
355
+ resource_group, managed_identity = vm_managed_identity.split("/")
356
+ return resource_group, managed_identity
357
+ except Exception:
358
+ raise ComputeError(
359
+ "`vm_managed_identity` specified in incorrect format."
360
+ " Supported format: 'managedIdentityResourceGroup/managedIdentityName'"
361
+ )
362
+
363
+
341
364
  def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]:
342
365
  resource_group, network_name = vpc_id.split("/")
343
366
  return resource_group, network_name
@@ -466,7 +489,8 @@ def _launch_instance(
466
489
  network_security_group: str,
467
490
  network: str,
468
491
  subnet: str,
469
- managed_identity: Optional[str],
492
+ managed_identity_name: Optional[str],
493
+ managed_identity_resource_group: Optional[str],
470
494
  image_reference: ImageReference,
471
495
  vm_size: str,
472
496
  instance_name: str,
@@ -488,6 +512,20 @@ def _launch_instance(
488
512
  public_ip_address_configuration = VirtualMachinePublicIPAddressConfiguration(
489
513
  name="public_ip_config",
490
514
  )
515
+ managed_identity = None
516
+ if managed_identity_name is not None:
517
+ if managed_identity_resource_group is None:
518
+ managed_identity_resource_group = resource_group
519
+ managed_identity = VirtualMachineIdentity(
520
+ type=ResourceIdentityType.USER_ASSIGNED,
521
+ user_assigned_identities={
522
+ azure_utils.get_managed_identity_id(
523
+ subscription_id,
524
+ managed_identity_resource_group,
525
+ managed_identity_name,
526
+ ): UserAssignedIdentitiesValue(),
527
+ },
528
+ )
491
529
  try:
492
530
  poller = compute_client.virtual_machines.begin_create_or_update(
493
531
  resource_group,
@@ -552,16 +590,7 @@ def _launch_instance(
552
590
  ),
553
591
  priority="Spot" if spot else "Regular",
554
592
  eviction_policy="Delete" if spot else None,
555
- identity=None
556
- if managed_identity is None
557
- else VirtualMachineIdentity(
558
- type=ResourceIdentityType.USER_ASSIGNED,
559
- user_assigned_identities={
560
- azure_utils.get_managed_identity_id(
561
- subscription_id, resource_group, managed_identity
562
- ): UserAssignedIdentitiesValue()
563
- },
564
- ),
593
+ identity=managed_identity,
565
594
  user_data=base64.b64encode(user_data.encode()).decode(),
566
595
  tags=tags,
567
596
  ),
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
4
4
 
5
5
  import azure.core.exceptions
6
6
  from azure.core.credentials import TokenCredential
7
+ from azure.mgmt import msi as msi_mgmt
7
8
  from azure.mgmt import network as network_mgmt
8
9
  from azure.mgmt import resource as resource_mgmt
9
10
  from azure.mgmt import subscription as subscription_mgmt
@@ -97,6 +98,7 @@ class AzureConfigurator(Configurator):
97
98
  self._check_config_locations(config)
98
99
  self._check_config_tags(config)
99
100
  self._check_config_resource_group(config=config, credential=credential)
101
+ self._check_config_vm_managed_identity(config=config, credential=credential)
100
102
  self._check_config_vpc(config=config, credential=credential)
101
103
 
102
104
  def create_backend(
@@ -260,6 +262,25 @@ class AzureConfigurator(Configurator):
260
262
  except BackendError as e:
261
263
  raise ServerClientError(e.args[0])
262
264
 
265
+ def _check_config_vm_managed_identity(
266
+ self, config: AzureBackendConfigWithCreds, credential: auth.AzureCredential
267
+ ):
268
+ try:
269
+ resource_group, identity_name = compute.parse_vm_managed_identity(
270
+ config.vm_managed_identity
271
+ )
272
+ except BackendError as e:
273
+ raise ServerClientError(e.args[0])
274
+ if resource_group is None or identity_name is None:
275
+ return
276
+ msi_client = msi_mgmt.ManagedServiceIdentityClient(credential, config.subscription_id)
277
+ try:
278
+ msi_client.user_assigned_identities.get(resource_group, identity_name)
279
+ except azure.core.exceptions.ResourceNotFoundError:
280
+ raise ServerClientError(
281
+ f"Managed identity {identity_name} not found in resource group {resource_group}"
282
+ )
283
+
263
284
  def _set_client_creds_tenant_id(
264
285
  self,
265
286
  creds: AzureClientCreds,
@@ -62,6 +62,15 @@ class AzureBackendConfig(CoreModel):
62
62
  )
63
63
  ),
64
64
  ] = None
65
+ vm_managed_identity: Annotated[
66
+ Optional[str],
67
+ Field(
68
+ description=(
69
+ "The managed identity to associate with provisioned VMs."
70
+ " Must have a format `managedIdentityResourceGroup/managedIdentityName`"
71
+ )
72
+ ),
73
+ ] = None
65
74
  tags: Annotated[
66
75
  Optional[Dict[str, str]],
67
76
  Field(description="The tags that will be assigned to resources created by `dstack`"),
@@ -6,7 +6,7 @@ import threading
6
6
  from abc import ABC, abstractmethod
7
7
  from functools import lru_cache
8
8
  from pathlib import Path
9
- from typing import Dict, List, Optional
9
+ from typing import Dict, List, Literal, Optional
10
10
 
11
11
  import git
12
12
  import requests
@@ -25,6 +25,7 @@ from dstack._internal.core.models.gateways import (
25
25
  )
26
26
  from dstack._internal.core.models.instances import (
27
27
  InstanceConfiguration,
28
+ InstanceOffer,
28
29
  InstanceOfferWithAvailability,
29
30
  SSHKey,
30
31
  )
@@ -44,6 +45,8 @@ logger = get_logger(__name__)
44
45
  DSTACK_SHIM_BINARY_NAME = "dstack-shim"
45
46
  DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
46
47
 
48
+ GoArchType = Literal["amd64", "arm64"]
49
+
47
50
 
48
51
  class Compute(ABC):
49
52
  """
@@ -144,6 +147,7 @@ class ComputeWithCreateInstanceSupport(ABC):
144
147
  self,
145
148
  instance_offer: InstanceOfferWithAvailability,
146
149
  instance_config: InstanceConfiguration,
150
+ placement_group: Optional[PlacementGroup],
147
151
  ) -> JobProvisioningData:
148
152
  """
149
153
  Launches a new instance. It should return `JobProvisioningData` ASAP.
@@ -176,7 +180,7 @@ class ComputeWithCreateInstanceSupport(ABC):
176
180
  )
177
181
  instance_offer = instance_offer.copy()
178
182
  self._restrict_instance_offer_az_to_volumes_az(instance_offer, volumes)
179
- return self.create_instance(instance_offer, instance_config)
183
+ return self.create_instance(instance_offer, instance_config, placement_group=None)
180
184
 
181
185
  def _restrict_instance_offer_az_to_volumes_az(
182
186
  self,
@@ -225,9 +229,15 @@ class ComputeWithPlacementGroupSupport(ABC):
225
229
  def create_placement_group(
226
230
  self,
227
231
  placement_group: PlacementGroup,
232
+ master_instance_offer: InstanceOffer,
228
233
  ) -> PlacementGroupProvisioningData:
229
234
  """
230
235
  Creates a placement group.
236
+
237
+ Args:
238
+ placement_group: details about the placement group to be created
239
+ master_instance_offer: the first instance dstack will attempt to add
240
+ to the placement group
231
241
  """
232
242
  pass
233
243
 
@@ -242,10 +252,27 @@ class ComputeWithPlacementGroupSupport(ABC):
242
252
  """
243
253
  pass
244
254
 
255
+ @abstractmethod
256
+ def is_suitable_placement_group(
257
+ self,
258
+ placement_group: PlacementGroup,
259
+ instance_offer: InstanceOffer,
260
+ ) -> bool:
261
+ """
262
+ Checks if the instance offer can be provisioned in the placement group.
263
+
264
+ Should return immediately, without performing API calls.
265
+
266
+ Can be called with an offer originating from a different backend, because some backends
267
+ (BackendType.DSTACK) produce offers on behalf of other backends. Should return `False`
268
+ in that case.
269
+ """
270
+ pass
271
+
245
272
 
246
273
  class ComputeWithGatewaySupport(ABC):
247
274
  """
248
- Must be subclassed and imlemented to support gateways.
275
+ Must be subclassed and implemented to support gateways.
249
276
  """
250
277
 
251
278
  @abstractmethod
@@ -418,6 +445,21 @@ def generate_unique_volume_name(
418
445
  )
419
446
 
420
447
 
448
+ def generate_unique_placement_group_name(
449
+ project_name: str,
450
+ fleet_name: str,
451
+ max_length: int = _DEFAULT_MAX_RESOURCE_NAME_LEN,
452
+ ) -> str:
453
+ """
454
+ Generates a unique placement group name valid across all backends.
455
+ """
456
+ return generate_unique_backend_name(
457
+ resource_name=fleet_name,
458
+ project_name=project_name,
459
+ max_length=max_length,
460
+ )
461
+
462
+
421
463
  def generate_unique_backend_name(
422
464
  resource_name: str,
423
465
  project_name: Optional[str],
@@ -483,13 +525,14 @@ def get_shim_env(
483
525
  base_path: Optional[PathLike] = None,
484
526
  bin_path: Optional[PathLike] = None,
485
527
  backend_shim_env: Optional[Dict[str, str]] = None,
528
+ arch: Optional[str] = None,
486
529
  ) -> Dict[str, str]:
487
530
  log_level = "6" # Trace
488
531
  envs = {
489
532
  "DSTACK_SHIM_HOME": get_dstack_working_dir(base_path),
490
533
  "DSTACK_SHIM_HTTP_PORT": str(DSTACK_SHIM_HTTP_PORT),
491
534
  "DSTACK_SHIM_LOG_LEVEL": log_level,
492
- "DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(),
535
+ "DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(arch),
493
536
  "DSTACK_RUNNER_BINARY_PATH": get_dstack_runner_binary_path(bin_path),
494
537
  "DSTACK_RUNNER_HTTP_PORT": str(DSTACK_RUNNER_HTTP_PORT),
495
538
  "DSTACK_RUNNER_SSH_PORT": str(DSTACK_RUNNER_SSH_PORT),
@@ -509,16 +552,19 @@ def get_shim_commands(
509
552
  base_path: Optional[PathLike] = None,
510
553
  bin_path: Optional[PathLike] = None,
511
554
  backend_shim_env: Optional[Dict[str, str]] = None,
555
+ arch: Optional[str] = None,
512
556
  ) -> List[str]:
513
557
  commands = get_shim_pre_start_commands(
514
558
  base_path=base_path,
515
559
  bin_path=bin_path,
560
+ arch=arch,
516
561
  )
517
562
  shim_env = get_shim_env(
518
563
  authorized_keys=authorized_keys,
519
564
  base_path=base_path,
520
565
  bin_path=bin_path,
521
566
  backend_shim_env=backend_shim_env,
567
+ arch=arch,
522
568
  )
523
569
  for k, v in shim_env.items():
524
570
  commands += [f'export "{k}={v}"']
@@ -539,35 +585,63 @@ def get_dstack_runner_version() -> str:
539
585
  return version or "latest"
540
586
 
541
587
 
542
- def get_dstack_runner_download_url() -> str:
543
- if url := os.environ.get("DSTACK_RUNNER_DOWNLOAD_URL"):
544
- return url
545
- build = get_dstack_runner_version()
546
- if settings.DSTACK_VERSION is not None:
547
- bucket = "dstack-runner-downloads"
548
- else:
549
- bucket = "dstack-runner-downloads-stgn"
550
- return (
551
- f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64"
552
- )
553
-
554
-
555
- def get_dstack_shim_download_url() -> str:
556
- if url := os.environ.get("DSTACK_SHIM_DOWNLOAD_URL"):
557
- return url
558
- build = get_dstack_runner_version()
559
- if settings.DSTACK_VERSION is not None:
560
- bucket = "dstack-runner-downloads"
561
- else:
562
- bucket = "dstack-runner-downloads-stgn"
563
- return f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64"
588
+ def normalize_arch(arch: Optional[str] = None) -> GoArchType:
589
+ """
590
+ Converts the given free-form architecture string to the Go GOARCH format.
591
+ Only 64-bit x86 and ARM are supported. If the word size is not specified (e.g., `x86`, `arm`),
592
+ 64-bit is implied.
593
+ If the arch is not specified, falls back to `amd64`.
594
+ """
595
+ if not arch:
596
+ return "amd64"
597
+ arch_lower = arch.lower()
598
+ if "32" in arch_lower or arch_lower in ["i386", "i686"]:
599
+ raise ValueError(f"32-bit architectures are not supported: {arch}")
600
+ if arch_lower.startswith("x86") or arch_lower.startswith("amd"):
601
+ return "amd64"
602
+ if arch_lower.startswith("arm") or arch_lower.startswith("aarch"):
603
+ return "arm64"
604
+ raise ValueError(f"Unsupported architecture: {arch}")
605
+
606
+
607
+ def get_dstack_runner_download_url(arch: Optional[str] = None) -> str:
608
+ url_template = os.environ.get("DSTACK_RUNNER_DOWNLOAD_URL")
609
+ if not url_template:
610
+ if settings.DSTACK_VERSION is not None:
611
+ bucket = "dstack-runner-downloads"
612
+ else:
613
+ bucket = "dstack-runner-downloads-stgn"
614
+ url_template = (
615
+ f"https://{bucket}.s3.eu-west-1.amazonaws.com"
616
+ "/{version}/binaries/dstack-runner-linux-{arch}"
617
+ )
618
+ version = get_dstack_runner_version()
619
+ arch = normalize_arch(arch)
620
+ return url_template.format(version=version, arch=arch)
621
+
622
+
623
+ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
624
+ url_template = os.environ.get("DSTACK_SHIM_DOWNLOAD_URL")
625
+ if not url_template:
626
+ if settings.DSTACK_VERSION is not None:
627
+ bucket = "dstack-runner-downloads"
628
+ else:
629
+ bucket = "dstack-runner-downloads-stgn"
630
+ url_template = (
631
+ f"https://{bucket}.s3.eu-west-1.amazonaws.com"
632
+ "/{version}/binaries/dstack-shim-linux-{arch}"
633
+ )
634
+ version = get_dstack_runner_version()
635
+ arch = normalize_arch(arch)
636
+ return url_template.format(version=version, arch=arch)
564
637
 
565
638
 
566
639
  def get_shim_pre_start_commands(
567
640
  base_path: Optional[PathLike] = None,
568
641
  bin_path: Optional[PathLike] = None,
642
+ arch: Optional[str] = None,
569
643
  ) -> List[str]:
570
- url = get_dstack_shim_download_url()
644
+ url = get_dstack_shim_download_url(arch)
571
645
  dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path)
572
646
  dstack_working_dir = get_dstack_working_dir(base_path)
573
647
  return [