dstack 0.19.6rc1__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. See the package registry's release page for more details.

Files changed (69)
  1. dstack/_internal/cli/services/args.py +2 -2
  2. dstack/_internal/cli/services/configurators/fleet.py +3 -2
  3. dstack/_internal/cli/services/configurators/run.py +50 -4
  4. dstack/_internal/cli/utils/fleet.py +3 -1
  5. dstack/_internal/cli/utils/run.py +25 -28
  6. dstack/_internal/core/backends/aws/compute.py +13 -1
  7. dstack/_internal/core/backends/azure/compute.py +42 -13
  8. dstack/_internal/core/backends/azure/configurator.py +21 -0
  9. dstack/_internal/core/backends/azure/models.py +9 -0
  10. dstack/_internal/core/backends/base/compute.py +101 -27
  11. dstack/_internal/core/backends/base/offers.py +13 -3
  12. dstack/_internal/core/backends/cudo/compute.py +2 -0
  13. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  14. dstack/_internal/core/backends/gcp/auth.py +1 -1
  15. dstack/_internal/core/backends/gcp/compute.py +51 -35
  16. dstack/_internal/core/backends/gcp/resources.py +6 -1
  17. dstack/_internal/core/backends/lambdalabs/compute.py +20 -8
  18. dstack/_internal/core/backends/local/compute.py +2 -0
  19. dstack/_internal/core/backends/nebius/compute.py +95 -1
  20. dstack/_internal/core/backends/nebius/configurator.py +11 -0
  21. dstack/_internal/core/backends/nebius/fabrics.py +47 -0
  22. dstack/_internal/core/backends/nebius/models.py +8 -0
  23. dstack/_internal/core/backends/nebius/resources.py +29 -0
  24. dstack/_internal/core/backends/oci/compute.py +2 -0
  25. dstack/_internal/core/backends/remote/provisioning.py +27 -2
  26. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  27. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  28. dstack/_internal/core/backends/vastai/compute.py +2 -1
  29. dstack/_internal/core/backends/vultr/compute.py +5 -1
  30. dstack/_internal/core/errors.py +4 -0
  31. dstack/_internal/core/models/fleets.py +2 -0
  32. dstack/_internal/core/models/instances.py +4 -3
  33. dstack/_internal/core/models/resources.py +80 -3
  34. dstack/_internal/core/models/runs.py +10 -3
  35. dstack/_internal/core/models/volumes.py +1 -1
  36. dstack/_internal/server/background/tasks/process_fleets.py +4 -13
  37. dstack/_internal/server/background/tasks/process_instances.py +176 -55
  38. dstack/_internal/server/background/tasks/process_placement_groups.py +1 -1
  39. dstack/_internal/server/background/tasks/process_prometheus_metrics.py +5 -2
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +1 -1
  41. dstack/_internal/server/models.py +1 -0
  42. dstack/_internal/server/routers/gateways.py +2 -1
  43. dstack/_internal/server/services/config.py +7 -2
  44. dstack/_internal/server/services/fleets.py +24 -26
  45. dstack/_internal/server/services/gateways/__init__.py +17 -2
  46. dstack/_internal/server/services/instances.py +0 -2
  47. dstack/_internal/server/services/offers.py +15 -0
  48. dstack/_internal/server/services/placement.py +27 -6
  49. dstack/_internal/server/services/plugins.py +77 -0
  50. dstack/_internal/server/services/resources.py +21 -0
  51. dstack/_internal/server/services/runs.py +41 -17
  52. dstack/_internal/server/services/volumes.py +10 -1
  53. dstack/_internal/server/testing/common.py +35 -26
  54. dstack/_internal/utils/common.py +22 -9
  55. dstack/_internal/utils/json_schema.py +6 -3
  56. dstack/api/__init__.py +1 -0
  57. dstack/api/server/__init__.py +8 -1
  58. dstack/api/server/_fleets.py +16 -0
  59. dstack/api/server/_runs.py +44 -3
  60. dstack/plugins/__init__.py +8 -0
  61. dstack/plugins/_base.py +72 -0
  62. dstack/plugins/_models.py +8 -0
  63. dstack/plugins/_utils.py +19 -0
  64. dstack/version.py +1 -1
  65. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/METADATA +14 -2
  66. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/RECORD +69 -62
  67. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/WHEEL +0 -0
  68. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/entry_points.txt +0 -0
  69. {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/licenses/LICENSE.md +0 -0
@@ -19,8 +19,8 @@ def port_mapping(v: str) -> PortMapping:
19
19
  return PortMapping.parse(v)
20
20
 
21
21
 
22
- def cpu_spec(v: str) -> resources.Range[int]:
23
- return parse_obj_as(resources.Range[int], v)
22
+ def cpu_spec(v: str) -> dict:
23
+ return resources.CPUSpec.parse(v)
24
24
 
25
25
 
26
26
  def memory_spec(v: str) -> resources.Range[resources.Memory]:
@@ -20,6 +20,7 @@ from dstack._internal.cli.utils.rich import MultiItemStatus
20
20
  from dstack._internal.core.errors import (
21
21
  CLIError,
22
22
  ConfigurationError,
23
+ MethodNotAllowedError,
23
24
  ResourceNotExistsError,
24
25
  ServerClientError,
25
26
  URLNotFoundError,
@@ -321,7 +322,7 @@ def _print_plan_header(plan: FleetPlan):
321
322
  offer.instance.name,
322
323
  resources.pretty_format(),
323
324
  "yes" if resources.spot else "no",
324
- f"${offer.price:g}",
325
+ f"${offer.price:3f}".rstrip("0").rstrip("."),
325
326
  availability,
326
327
  style=None if index == 1 else "secondary",
327
328
  )
@@ -367,7 +368,7 @@ def _apply_plan(api: Client, plan: FleetPlan) -> Fleet:
367
368
  project_name=api.project,
368
369
  plan=plan,
369
370
  )
370
- except URLNotFoundError:
371
+ except (URLNotFoundError, MethodNotAllowedError):
371
372
  # TODO: Remove in 0.20
372
373
  return api.client.fleets.create(
373
374
  project_name=api.project,
@@ -6,9 +6,10 @@ from pathlib import Path
6
6
  from typing import Dict, List, Optional, Set, Tuple
7
7
 
8
8
  import gpuhunt
9
+ from pydantic import parse_obj_as
9
10
 
10
11
  import dstack._internal.core.models.resources as resources
11
- from dstack._internal.cli.services.args import disk_spec, gpu_spec, port_mapping
12
+ from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec, port_mapping
12
13
  from dstack._internal.cli.services.configurators.base import (
13
14
  ApplyEnvVarsConfiguratorMixin,
14
15
  BaseApplyConfigurator,
@@ -39,6 +40,7 @@ from dstack._internal.core.models.configurations import (
39
40
  TaskConfiguration,
40
41
  )
41
42
  from dstack._internal.core.models.repos.base import Repo
43
+ from dstack._internal.core.models.resources import CPUSpec
42
44
  from dstack._internal.core.models.runs import JobSubmission, JobTerminationReason, RunStatus
43
45
  from dstack._internal.core.services.configs import ConfigManager
44
46
  from dstack._internal.core.services.diff import diff_models
@@ -52,7 +54,7 @@ from dstack.api.utils import load_profile
52
54
  _KNOWN_AMD_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_AMD_GPUS}
53
55
  _KNOWN_NVIDIA_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_NVIDIA_GPUS}
54
56
  _KNOWN_TPU_VERSIONS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_TPUS}
55
-
57
+ _KNOWN_TENSTORRENT_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_TENSTORRENT_ACCELERATORS}
56
58
  _BIND_ADDRESS_ARG = "bind_address"
57
59
 
58
60
  logger = get_logger(__name__)
@@ -72,6 +74,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
72
74
  ):
73
75
  self.apply_args(conf, configurator_args, unknown_args)
74
76
  self.validate_gpu_vendor_and_image(conf)
77
+ self.validate_cpu_arch_and_image(conf)
75
78
  if repo is None:
76
79
  repo = self.api.repos.load(Path.cwd())
77
80
  config_manager = ConfigManager()
@@ -289,6 +292,14 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
289
292
  default=default_max_offers,
290
293
  )
291
294
  cls.register_env_args(configuration_group)
295
+ configuration_group.add_argument(
296
+ "--cpu",
297
+ type=cpu_spec,
298
+ help="Request CPU for the run. "
299
+ "The format is [code]ARCH[/]:[code]COUNT[/] (all parts are optional)",
300
+ dest="cpu_spec",
301
+ metavar="SPEC",
302
+ )
292
303
  configuration_group.add_argument(
293
304
  "--gpu",
294
305
  type=gpu_spec,
@@ -310,6 +321,8 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
310
321
  apply_profile_args(args, conf)
311
322
  if args.run_name:
312
323
  conf.name = args.run_name
324
+ if args.cpu_spec:
325
+ conf.resources.cpu = resources.CPUSpec.parse_obj(args.cpu_spec)
313
326
  if args.gpu_spec:
314
327
  conf.resources.gpu = resources.GPUSpec.parse_obj(args.gpu_spec)
315
328
  if args.disk_spec:
@@ -342,7 +355,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
342
355
 
343
356
  def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
344
357
  """
345
- Infers `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
358
+ Infers and sets `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
346
359
  """
347
360
  gpu_spec = conf.resources.gpu
348
361
  if gpu_spec is None:
@@ -350,6 +363,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
350
363
  if gpu_spec.count.max == 0:
351
364
  return
352
365
  has_amd_gpu: bool
366
+ has_tt_gpu: bool
353
367
  vendor = gpu_spec.vendor
354
368
  if vendor is None:
355
369
  names = gpu_spec.name
@@ -362,6 +376,8 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
362
376
  vendors.add(gpuhunt.AcceleratorVendor.NVIDIA)
363
377
  elif name in _KNOWN_AMD_GPUS:
364
378
  vendors.add(gpuhunt.AcceleratorVendor.AMD)
379
+ elif name in _KNOWN_TENSTORRENT_GPUS:
380
+ vendors.add(gpuhunt.AcceleratorVendor.TENSTORRENT)
365
381
  else:
366
382
  maybe_tpu_version, _, maybe_tpu_cores = name.partition("-")
367
383
  if maybe_tpu_version in _KNOWN_TPU_VERSIONS and maybe_tpu_cores.isdigit():
@@ -380,15 +396,45 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
380
396
  # to execute a run on an instance with an AMD accelerator with a default
381
397
  # CUDA image, not a big deal.
382
398
  has_amd_gpu = gpuhunt.AcceleratorVendor.AMD in vendors
399
+ has_tt_gpu = gpuhunt.AcceleratorVendor.TENSTORRENT in vendors
383
400
  else:
384
401
  # If neither gpu.vendor nor gpu.name is set, assume Nvidia.
385
402
  vendor = gpuhunt.AcceleratorVendor.NVIDIA
386
403
  has_amd_gpu = False
404
+ has_tt_gpu = False
387
405
  gpu_spec.vendor = vendor
388
406
  else:
389
407
  has_amd_gpu = vendor == gpuhunt.AcceleratorVendor.AMD
408
+ has_tt_gpu = vendor == gpuhunt.AcceleratorVendor.TENSTORRENT
390
409
  if has_amd_gpu and conf.image is None:
391
- raise ConfigurationError("`image` is required if `resources.gpu.vendor` is AMD.")
410
+ raise ConfigurationError("`image` is required if `resources.gpu.vendor` is `amd`")
411
+ if has_tt_gpu and conf.image is None:
412
+ raise ConfigurationError(
413
+ "`image` is required if `resources.gpu.vendor` is `tenstorrent`"
414
+ )
415
+
416
+ def validate_cpu_arch_and_image(self, conf: BaseRunConfiguration) -> None:
417
+ """
418
+ Infers `resources.cpu.arch` if not set, requires `image` if the architecture is ARM.
419
+ """
420
+ # TODO: Remove in 0.20. Use conf.resources.cpu directly
421
+ cpu_spec = parse_obj_as(CPUSpec, conf.resources.cpu)
422
+ arch = cpu_spec.arch
423
+ if arch is None:
424
+ gpu_spec = conf.resources.gpu
425
+ if (
426
+ gpu_spec is not None
427
+ and gpu_spec.vendor in [None, gpuhunt.AcceleratorVendor.NVIDIA]
428
+ and gpu_spec.name
429
+ and any(map(gpuhunt.is_nvidia_superchip, gpu_spec.name))
430
+ ):
431
+ arch = gpuhunt.CPUArchitecture.ARM
432
+ else:
433
+ arch = gpuhunt.CPUArchitecture.X86
434
+ # NOTE: We don't set the inferred resources.cpu.arch for compatibility with older servers.
435
+ # Servers with ARM support set the arch using the same logic.
436
+ if arch == gpuhunt.CPUArchitecture.ARM and conf.image is None:
437
+ raise ConfigurationError("`image` is required if `resources.cpu.arch` is `arm`")
392
438
 
393
439
 
394
440
  class RunWithPortsConfigurator(BaseRunConfigurator):
@@ -79,7 +79,9 @@ def get_fleets_table(
79
79
  "BACKEND": backend,
80
80
  "REGION": region,
81
81
  "RESOURCES": resources,
82
- "PRICE": f"${instance.price:.4}" if instance.price is not None else "",
82
+ "PRICE": f"${instance.price:.4f}".rstrip("0").rstrip(".")
83
+ if instance.price is not None
84
+ else "",
83
85
  "STATUS": status,
84
86
  "CREATED": format_date(instance.created),
85
87
  "ERROR": error,
@@ -1,3 +1,4 @@
1
+ import shutil
1
2
  from typing import Any, Dict, List, Optional, Union
2
3
 
3
4
  from rich.markup import escape
@@ -36,7 +37,7 @@ def print_run_plan(
36
37
 
37
38
  req = job_plan.job_spec.requirements
38
39
  pretty_req = req.pretty_format(resources_only=True)
39
- max_price = f"${req.max_price:g}" if req.max_price else "-"
40
+ max_price = f"${req.max_price:3f}".rstrip("0").rstrip(".") if req.max_price else "-"
40
41
  max_duration = (
41
42
  format_pretty_duration(job_plan.job_spec.max_duration)
42
43
  if job_plan.job_spec.max_duration
@@ -94,14 +95,12 @@ def print_run_plan(
94
95
  props.add_row(th("Inactivity duration"), inactivity_duration)
95
96
  props.add_row(th("Reservation"), run_spec.configuration.reservation or "-")
96
97
 
97
- offers = Table(box=None)
98
+ offers = Table(box=None, expand=shutil.get_terminal_size(fallback=(120, 40)).columns <= 110)
98
99
  offers.add_column("#")
99
- offers.add_column("BACKEND")
100
- offers.add_column("REGION")
101
- offers.add_column("INSTANCE TYPE")
102
- offers.add_column("RESOURCES")
103
- offers.add_column("SPOT")
104
- offers.add_column("PRICE")
100
+ offers.add_column("BACKEND", style="grey58", ratio=2)
101
+ offers.add_column("RESOURCES", ratio=4)
102
+ offers.add_column("INSTANCE TYPE", style="grey58", no_wrap=True, ratio=2)
103
+ offers.add_column("PRICE", style="grey58", ratio=1)
105
104
  offers.add_column()
106
105
 
107
106
  job_plan.offers = job_plan.offers[:max_offers] if max_offers else job_plan.offers
@@ -122,14 +121,12 @@ def print_run_plan(
122
121
  instance += f" ({offer.blocks}/{offer.total_blocks})"
123
122
  offers.add_row(
124
123
  f"{i}",
125
- offer.backend.replace("remote", "ssh"),
126
- offer.region,
124
+ offer.backend.replace("remote", "ssh") + " (" + offer.region + ")",
125
+ r.pretty_format(include_spot=True),
127
126
  instance,
128
- r.pretty_format(),
129
- "yes" if r.spot else "no",
130
- f"${offer.price:g}",
127
+ f"${offer.price:.4f}".rstrip("0").rstrip("."),
131
128
  availability,
132
- style=None if i == 1 else "secondary",
129
+ style=None if i == 1 or not include_run_properties else "secondary",
133
130
  )
134
131
  if job_plan.total_offers > len(job_plan.offers):
135
132
  offers.add_row("", "...", style="secondary")
@@ -141,7 +138,8 @@ def print_run_plan(
141
138
  if job_plan.total_offers > len(job_plan.offers):
142
139
  console.print(
143
140
  f"[secondary] Shown {len(job_plan.offers)} of {job_plan.total_offers} offers, "
144
- f"${job_plan.max_price:g} max[/]"
141
+ f"${job_plan.max_price:3f}".rstrip("0").rstrip(".")
142
+ + "max[/]"
145
143
  )
146
144
  console.print()
147
145
  else:
@@ -151,19 +149,18 @@ def print_run_plan(
151
149
  def get_runs_table(
152
150
  runs: List[Run], verbose: bool = False, format_date: DateFormatter = pretty_date
153
151
  ) -> Table:
154
- table = Table(box=None)
155
- table.add_column("NAME", style="bold", no_wrap=True)
156
- table.add_column("BACKEND", style="grey58")
152
+ table = Table(box=None, expand=shutil.get_terminal_size(fallback=(120, 40)).columns <= 110)
153
+ table.add_column("NAME", style="bold", no_wrap=True, ratio=2)
154
+ table.add_column("BACKEND", style="grey58", ratio=2)
155
+ table.add_column("RESOURCES", ratio=3 if not verbose else 2)
157
156
  if verbose:
158
- table.add_column("INSTANCE", no_wrap=True)
159
- table.add_column("RESOURCES")
157
+ table.add_column("INSTANCE", no_wrap=True, ratio=1)
158
+ table.add_column("RESERVATION", no_wrap=True, ratio=1)
159
+ table.add_column("PRICE", style="grey58", ratio=1)
160
+ table.add_column("STATUS", no_wrap=True, ratio=1)
161
+ table.add_column("SUBMITTED", style="grey58", no_wrap=True, ratio=1)
160
162
  if verbose:
161
- table.add_column("RESERVATION", no_wrap=True)
162
- table.add_column("PRICE", no_wrap=True)
163
- table.add_column("STATUS", no_wrap=True)
164
- table.add_column("SUBMITTED", style="grey58", no_wrap=True)
165
- if verbose:
166
- table.add_column("ERROR", no_wrap=True)
163
+ table.add_column("ERROR", no_wrap=True, ratio=2)
167
164
 
168
165
  for run in runs:
169
166
  run_error = _get_run_error(run)
@@ -202,10 +199,10 @@ def get_runs_table(
202
199
  job_row.update(
203
200
  {
204
201
  "BACKEND": f"{jpd.backend.value.replace('remote', 'ssh')} ({jpd.region})",
205
- "INSTANCE": instance,
206
202
  "RESOURCES": resources.pretty_format(include_spot=True),
203
+ "INSTANCE": instance,
207
204
  "RESERVATION": jpd.reservation,
208
- "PRICE": f"${jpd.price:.4}",
205
+ "PRICE": f"${jpd.price:.4f}".rstrip("0").rstrip("."),
209
206
  }
210
207
  )
211
208
  if len(run.jobs) == 1:
@@ -159,6 +159,7 @@ class AWSCompute(
159
159
  self,
160
160
  instance_offer: InstanceOfferWithAvailability,
161
161
  instance_config: InstanceConfiguration,
162
+ placement_group: Optional[PlacementGroup],
162
163
  ) -> JobProvisioningData:
163
164
  project_name = instance_config.project_name
164
165
  ec2_resource = self.session.resource("ec2", region_name=instance_offer.region)
@@ -248,7 +249,7 @@ class AWSCompute(
248
249
  spot=instance_offer.instance.resources.spot,
249
250
  subnet_id=subnet_id,
250
251
  allocate_public_ip=allocate_public_ip,
251
- placement_group_name=instance_config.placement_group_name,
252
+ placement_group_name=placement_group.name if placement_group else None,
252
253
  enable_efa=enable_efa,
253
254
  max_efa_interfaces=max_efa_interfaces,
254
255
  reservation_id=instance_config.reservation,
@@ -291,6 +292,7 @@ class AWSCompute(
291
292
  def create_placement_group(
292
293
  self,
293
294
  placement_group: PlacementGroup,
295
+ master_instance_offer: InstanceOffer,
294
296
  ) -> PlacementGroupProvisioningData:
295
297
  ec2_client = self.session.client("ec2", region_name=placement_group.configuration.region)
296
298
  logger.debug("Creating placement group %s...", placement_group.name)
@@ -323,6 +325,16 @@ class AWSCompute(
323
325
  raise e
324
326
  logger.debug("Deleted placement group %s", placement_group.name)
325
327
 
328
+ def is_suitable_placement_group(
329
+ self,
330
+ placement_group: PlacementGroup,
331
+ instance_offer: InstanceOffer,
332
+ ) -> bool:
333
+ return (
334
+ placement_group.configuration.backend == BackendType.AWS
335
+ and placement_group.configuration.region == instance_offer.region
336
+ )
337
+
326
338
  def create_gateway(
327
339
  self,
328
340
  configuration: GatewayComputeConfiguration,
@@ -62,6 +62,7 @@ from dstack._internal.core.models.instances import (
62
62
  InstanceOfferWithAvailability,
63
63
  InstanceType,
64
64
  )
65
+ from dstack._internal.core.models.placement import PlacementGroup
65
66
  from dstack._internal.core.models.resources import Memory, Range
66
67
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
67
68
  from dstack._internal.utils.logging import get_logger
@@ -109,6 +110,7 @@ class AzureCompute(
109
110
  self,
110
111
  instance_offer: InstanceOfferWithAvailability,
111
112
  instance_config: InstanceConfiguration,
113
+ placement_group: Optional[PlacementGroup],
112
114
  ) -> JobProvisioningData:
113
115
  instance_name = generate_unique_instance_name(
114
116
  instance_config, max_length=azure_resources.MAX_RESOURCE_NAME_LEN
@@ -136,6 +138,10 @@ class AzureCompute(
136
138
  location=location,
137
139
  )
138
140
 
141
+ managed_identity_resource_group, managed_identity_name = parse_vm_managed_identity(
142
+ self.config.vm_managed_identity
143
+ )
144
+
139
145
  base_tags = {
140
146
  "owner": "dstack",
141
147
  "dstack_project": instance_config.project_name,
@@ -159,7 +165,8 @@ class AzureCompute(
159
165
  network_security_group=network_security_group,
160
166
  network=network,
161
167
  subnet=subnet,
162
- managed_identity=None,
168
+ managed_identity_name=managed_identity_name,
169
+ managed_identity_resource_group=managed_identity_resource_group,
163
170
  image_reference=_get_image_ref(
164
171
  compute_client=self._compute_client,
165
172
  location=location,
@@ -255,7 +262,8 @@ class AzureCompute(
255
262
  network_security_group=network_security_group,
256
263
  network=network,
257
264
  subnet=subnet,
258
- managed_identity=None,
265
+ managed_identity_name=None,
266
+ managed_identity_resource_group=None,
259
267
  image_reference=_get_gateway_image_ref(),
260
268
  vm_size="Standard_B1ms",
261
269
  instance_name=instance_name,
@@ -338,6 +346,21 @@ def get_resource_group_network_subnet_or_error(
338
346
  return resource_group, network_name, subnet_name
339
347
 
340
348
 
349
+ def parse_vm_managed_identity(
350
+ vm_managed_identity: Optional[str],
351
+ ) -> Tuple[Optional[str], Optional[str]]:
352
+ if vm_managed_identity is None:
353
+ return None, None
354
+ try:
355
+ resource_group, managed_identity = vm_managed_identity.split("/")
356
+ return resource_group, managed_identity
357
+ except Exception:
358
+ raise ComputeError(
359
+ "`vm_managed_identity` specified in incorrect format."
360
+ " Supported format: 'managedIdentityResourceGroup/managedIdentityName'"
361
+ )
362
+
363
+
341
364
  def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]:
342
365
  resource_group, network_name = vpc_id.split("/")
343
366
  return resource_group, network_name
@@ -466,7 +489,8 @@ def _launch_instance(
466
489
  network_security_group: str,
467
490
  network: str,
468
491
  subnet: str,
469
- managed_identity: Optional[str],
492
+ managed_identity_name: Optional[str],
493
+ managed_identity_resource_group: Optional[str],
470
494
  image_reference: ImageReference,
471
495
  vm_size: str,
472
496
  instance_name: str,
@@ -488,6 +512,20 @@ def _launch_instance(
488
512
  public_ip_address_configuration = VirtualMachinePublicIPAddressConfiguration(
489
513
  name="public_ip_config",
490
514
  )
515
+ managed_identity = None
516
+ if managed_identity_name is not None:
517
+ if managed_identity_resource_group is None:
518
+ managed_identity_resource_group = resource_group
519
+ managed_identity = VirtualMachineIdentity(
520
+ type=ResourceIdentityType.USER_ASSIGNED,
521
+ user_assigned_identities={
522
+ azure_utils.get_managed_identity_id(
523
+ subscription_id,
524
+ managed_identity_resource_group,
525
+ managed_identity_name,
526
+ ): UserAssignedIdentitiesValue(),
527
+ },
528
+ )
491
529
  try:
492
530
  poller = compute_client.virtual_machines.begin_create_or_update(
493
531
  resource_group,
@@ -552,16 +590,7 @@ def _launch_instance(
552
590
  ),
553
591
  priority="Spot" if spot else "Regular",
554
592
  eviction_policy="Delete" if spot else None,
555
- identity=None
556
- if managed_identity is None
557
- else VirtualMachineIdentity(
558
- type=ResourceIdentityType.USER_ASSIGNED,
559
- user_assigned_identities={
560
- azure_utils.get_managed_identity_id(
561
- subscription_id, resource_group, managed_identity
562
- ): UserAssignedIdentitiesValue()
563
- },
564
- ),
593
+ identity=managed_identity,
565
594
  user_data=base64.b64encode(user_data.encode()).decode(),
566
595
  tags=tags,
567
596
  ),
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
4
4
 
5
5
  import azure.core.exceptions
6
6
  from azure.core.credentials import TokenCredential
7
+ from azure.mgmt import msi as msi_mgmt
7
8
  from azure.mgmt import network as network_mgmt
8
9
  from azure.mgmt import resource as resource_mgmt
9
10
  from azure.mgmt import subscription as subscription_mgmt
@@ -97,6 +98,7 @@ class AzureConfigurator(Configurator):
97
98
  self._check_config_locations(config)
98
99
  self._check_config_tags(config)
99
100
  self._check_config_resource_group(config=config, credential=credential)
101
+ self._check_config_vm_managed_identity(config=config, credential=credential)
100
102
  self._check_config_vpc(config=config, credential=credential)
101
103
 
102
104
  def create_backend(
@@ -260,6 +262,25 @@ class AzureConfigurator(Configurator):
260
262
  except BackendError as e:
261
263
  raise ServerClientError(e.args[0])
262
264
 
265
+ def _check_config_vm_managed_identity(
266
+ self, config: AzureBackendConfigWithCreds, credential: auth.AzureCredential
267
+ ):
268
+ try:
269
+ resource_group, identity_name = compute.parse_vm_managed_identity(
270
+ config.vm_managed_identity
271
+ )
272
+ except BackendError as e:
273
+ raise ServerClientError(e.args[0])
274
+ if resource_group is None or identity_name is None:
275
+ return
276
+ msi_client = msi_mgmt.ManagedServiceIdentityClient(credential, config.subscription_id)
277
+ try:
278
+ msi_client.user_assigned_identities.get(resource_group, identity_name)
279
+ except azure.core.exceptions.ResourceNotFoundError:
280
+ raise ServerClientError(
281
+ f"Managed identity {identity_name} not found in resource group {resource_group}"
282
+ )
283
+
263
284
  def _set_client_creds_tenant_id(
264
285
  self,
265
286
  creds: AzureClientCreds,
@@ -62,6 +62,15 @@ class AzureBackendConfig(CoreModel):
62
62
  )
63
63
  ),
64
64
  ] = None
65
+ vm_managed_identity: Annotated[
66
+ Optional[str],
67
+ Field(
68
+ description=(
69
+ "The managed identity to associate with provisioned VMs."
70
+ " Must have a format `managedIdentityResourceGroup/managedIdentityName`"
71
+ )
72
+ ),
73
+ ] = None
65
74
  tags: Annotated[
66
75
  Optional[Dict[str, str]],
67
76
  Field(description="The tags that will be assigned to resources created by `dstack`"),