dstack 0.19.6rc1__py3-none-any.whl → 0.19.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dstack might be problematic.
- dstack/_internal/cli/services/args.py +2 -2
- dstack/_internal/cli/services/configurators/fleet.py +3 -2
- dstack/_internal/cli/services/configurators/run.py +50 -4
- dstack/_internal/cli/utils/fleet.py +3 -1
- dstack/_internal/cli/utils/run.py +25 -28
- dstack/_internal/core/backends/aws/compute.py +13 -1
- dstack/_internal/core/backends/azure/compute.py +42 -13
- dstack/_internal/core/backends/azure/configurator.py +21 -0
- dstack/_internal/core/backends/azure/models.py +9 -0
- dstack/_internal/core/backends/base/compute.py +101 -27
- dstack/_internal/core/backends/base/offers.py +13 -3
- dstack/_internal/core/backends/cudo/compute.py +2 -0
- dstack/_internal/core/backends/datacrunch/compute.py +2 -0
- dstack/_internal/core/backends/gcp/auth.py +1 -1
- dstack/_internal/core/backends/gcp/compute.py +51 -35
- dstack/_internal/core/backends/gcp/resources.py +6 -1
- dstack/_internal/core/backends/lambdalabs/compute.py +20 -8
- dstack/_internal/core/backends/local/compute.py +2 -0
- dstack/_internal/core/backends/nebius/compute.py +95 -1
- dstack/_internal/core/backends/nebius/configurator.py +11 -0
- dstack/_internal/core/backends/nebius/fabrics.py +47 -0
- dstack/_internal/core/backends/nebius/models.py +8 -0
- dstack/_internal/core/backends/nebius/resources.py +29 -0
- dstack/_internal/core/backends/oci/compute.py +2 -0
- dstack/_internal/core/backends/remote/provisioning.py +27 -2
- dstack/_internal/core/backends/template/compute.py.jinja +2 -0
- dstack/_internal/core/backends/tensordock/compute.py +2 -0
- dstack/_internal/core/backends/vastai/compute.py +2 -1
- dstack/_internal/core/backends/vultr/compute.py +5 -1
- dstack/_internal/core/errors.py +4 -0
- dstack/_internal/core/models/fleets.py +2 -0
- dstack/_internal/core/models/instances.py +4 -3
- dstack/_internal/core/models/resources.py +80 -3
- dstack/_internal/core/models/runs.py +10 -3
- dstack/_internal/core/models/volumes.py +1 -1
- dstack/_internal/server/background/tasks/process_fleets.py +4 -13
- dstack/_internal/server/background/tasks/process_instances.py +176 -55
- dstack/_internal/server/background/tasks/process_placement_groups.py +1 -1
- dstack/_internal/server/background/tasks/process_prometheus_metrics.py +5 -2
- dstack/_internal/server/background/tasks/process_submitted_jobs.py +1 -1
- dstack/_internal/server/models.py +1 -0
- dstack/_internal/server/routers/gateways.py +2 -1
- dstack/_internal/server/services/config.py +7 -2
- dstack/_internal/server/services/fleets.py +24 -26
- dstack/_internal/server/services/gateways/__init__.py +17 -2
- dstack/_internal/server/services/instances.py +0 -2
- dstack/_internal/server/services/offers.py +15 -0
- dstack/_internal/server/services/placement.py +27 -6
- dstack/_internal/server/services/plugins.py +77 -0
- dstack/_internal/server/services/resources.py +21 -0
- dstack/_internal/server/services/runs.py +41 -17
- dstack/_internal/server/services/volumes.py +10 -1
- dstack/_internal/server/testing/common.py +35 -26
- dstack/_internal/utils/common.py +22 -9
- dstack/_internal/utils/json_schema.py +6 -3
- dstack/api/__init__.py +1 -0
- dstack/api/server/__init__.py +8 -1
- dstack/api/server/_fleets.py +16 -0
- dstack/api/server/_runs.py +44 -3
- dstack/plugins/__init__.py +8 -0
- dstack/plugins/_base.py +72 -0
- dstack/plugins/_models.py +8 -0
- dstack/plugins/_utils.py +19 -0
- dstack/version.py +1 -1
- {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/METADATA +14 -2
- {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/RECORD +69 -62
- {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/WHEEL +0 -0
- {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/entry_points.txt +0 -0
- {dstack-0.19.6rc1.dist-info → dstack-0.19.8.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/cli/services/args.py

@@ -19,8 +19,8 @@ def port_mapping(v: str) -> PortMapping:
     return PortMapping.parse(v)
 
 
-def cpu_spec(v: str) ->
-    return
+def cpu_spec(v: str) -> dict:
+    return resources.CPUSpec.parse(v)
 
 
 def memory_spec(v: str) -> resources.Range[resources.Memory]:
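For orientation, here is a rough sketch of how a value given to the new parser travels from a CLI string to a `CPUSpec`. Only `cpu_spec()`, `CPUSpec.parse()`, and `CPUSpec.parse_obj()` appear in this diff; the example spec string and the attribute names are assumptions based on the `--cpu` help text added in the run configurator below.

import dstack._internal.core.models.resources as resources
from dstack._internal.cli.services.args import cpu_spec

# "arm:4" follows the ARCH:COUNT format described by the new --cpu help text (assumed example).
raw = cpu_spec("arm:4")                  # returns a plain dict, per the new signature
spec = resources.CPUSpec.parse_obj(raw)  # the run configurator converts it the same way
print(spec.arch, spec.count)             # attribute names assumed from CPUSpec usage elsewhere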
dstack/_internal/cli/services/configurators/fleet.py

@@ -20,6 +20,7 @@ from dstack._internal.cli.utils.rich import MultiItemStatus
 from dstack._internal.core.errors import (
     CLIError,
     ConfigurationError,
+    MethodNotAllowedError,
     ResourceNotExistsError,
     ServerClientError,
     URLNotFoundError,

@@ -321,7 +322,7 @@ def _print_plan_header(plan: FleetPlan):
             offer.instance.name,
             resources.pretty_format(),
             "yes" if resources.spot else "no",
-            f"${offer.price:
+            f"${offer.price:3f}".rstrip("0").rstrip("."),
             availability,
             style=None if index == 1 else "secondary",
         )

@@ -367,7 +368,7 @@ def _apply_plan(api: Client, plan: FleetPlan) -> Fleet:
             project_name=api.project,
             plan=plan,
         )
-    except URLNotFoundError:
+    except (URLNotFoundError, MethodNotAllowedError):
        # TODO: Remove in 0.20
        return api.client.fleets.create(
            project_name=api.project,
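The widened `except` above is a client-side compatibility fallback: try the newer plan-based endpoint, and fall back to plain fleet creation when the server is too old to know the method. A minimal sketch of the pattern; the `apply_plan` method name and the `spec=` argument are assumptions, since the call inside the `try` block is truncated in this diff.

from dstack._internal.core.errors import MethodNotAllowedError, URLNotFoundError

def _apply_fleet_plan_compat(api, plan):
    try:
        # Newer servers accept the whole plan (method name assumed).
        return api.client.fleets.apply_plan(project_name=api.project, plan=plan)
    except (URLNotFoundError, MethodNotAllowedError):
        # Older servers only expose the create endpoint (argument name assumed).
        return api.client.fleets.create(project_name=api.project, spec=plan.spec)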
dstack/_internal/cli/services/configurators/run.py

@@ -6,9 +6,10 @@ from pathlib import Path
 from typing import Dict, List, Optional, Set, Tuple
 
 import gpuhunt
+from pydantic import parse_obj_as
 
 import dstack._internal.core.models.resources as resources
-from dstack._internal.cli.services.args import disk_spec, gpu_spec, port_mapping
+from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec, port_mapping
 from dstack._internal.cli.services.configurators.base import (
     ApplyEnvVarsConfiguratorMixin,
     BaseApplyConfigurator,

@@ -39,6 +40,7 @@ from dstack._internal.core.models.configurations import (
     TaskConfiguration,
 )
 from dstack._internal.core.models.repos.base import Repo
+from dstack._internal.core.models.resources import CPUSpec
 from dstack._internal.core.models.runs import JobSubmission, JobTerminationReason, RunStatus
 from dstack._internal.core.services.configs import ConfigManager
 from dstack._internal.core.services.diff import diff_models

@@ -52,7 +54,7 @@ from dstack.api.utils import load_profile
 _KNOWN_AMD_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_AMD_GPUS}
 _KNOWN_NVIDIA_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_NVIDIA_GPUS}
 _KNOWN_TPU_VERSIONS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_TPUS}
-
+_KNOWN_TENSTORRENT_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_TENSTORRENT_ACCELERATORS}
 _BIND_ADDRESS_ARG = "bind_address"
 
 logger = get_logger(__name__)

@@ -72,6 +74,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
     ):
         self.apply_args(conf, configurator_args, unknown_args)
         self.validate_gpu_vendor_and_image(conf)
+        self.validate_cpu_arch_and_image(conf)
         if repo is None:
             repo = self.api.repos.load(Path.cwd())
         config_manager = ConfigManager()

@@ -289,6 +292,14 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
             default=default_max_offers,
         )
         cls.register_env_args(configuration_group)
+        configuration_group.add_argument(
+            "--cpu",
+            type=cpu_spec,
+            help="Request CPU for the run. "
+            "The format is [code]ARCH[/]:[code]COUNT[/] (all parts are optional)",
+            dest="cpu_spec",
+            metavar="SPEC",
+        )
         configuration_group.add_argument(
             "--gpu",
             type=gpu_spec,

@@ -310,6 +321,8 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
         apply_profile_args(args, conf)
         if args.run_name:
             conf.name = args.run_name
+        if args.cpu_spec:
+            conf.resources.cpu = resources.CPUSpec.parse_obj(args.cpu_spec)
         if args.gpu_spec:
             conf.resources.gpu = resources.GPUSpec.parse_obj(args.gpu_spec)
         if args.disk_spec:

@@ -342,7 +355,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
 
     def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
         """
-        Infers `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
+        Infers and sets `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
         """
         gpu_spec = conf.resources.gpu
         if gpu_spec is None:

@@ -350,6 +363,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
         if gpu_spec.count.max == 0:
             return
         has_amd_gpu: bool
+        has_tt_gpu: bool
         vendor = gpu_spec.vendor
         if vendor is None:
             names = gpu_spec.name

@@ -362,6 +376,8 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
                         vendors.add(gpuhunt.AcceleratorVendor.NVIDIA)
                     elif name in _KNOWN_AMD_GPUS:
                         vendors.add(gpuhunt.AcceleratorVendor.AMD)
+                    elif name in _KNOWN_TENSTORRENT_GPUS:
+                        vendors.add(gpuhunt.AcceleratorVendor.TENSTORRENT)
                     else:
                         maybe_tpu_version, _, maybe_tpu_cores = name.partition("-")
                         if maybe_tpu_version in _KNOWN_TPU_VERSIONS and maybe_tpu_cores.isdigit():

@@ -380,15 +396,45 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
                 # to execute a run on an instance with an AMD accelerator with a default
                 # CUDA image, not a big deal.
                 has_amd_gpu = gpuhunt.AcceleratorVendor.AMD in vendors
+                has_tt_gpu = gpuhunt.AcceleratorVendor.TENSTORRENT in vendors
             else:
                 # If neither gpu.vendor nor gpu.name is set, assume Nvidia.
                 vendor = gpuhunt.AcceleratorVendor.NVIDIA
                 has_amd_gpu = False
+                has_tt_gpu = False
             gpu_spec.vendor = vendor
         else:
             has_amd_gpu = vendor == gpuhunt.AcceleratorVendor.AMD
+            has_tt_gpu = vendor == gpuhunt.AcceleratorVendor.TENSTORRENT
         if has_amd_gpu and conf.image is None:
-            raise ConfigurationError("`image` is required if `resources.gpu.vendor` is
+            raise ConfigurationError("`image` is required if `resources.gpu.vendor` is `amd`")
+        if has_tt_gpu and conf.image is None:
+            raise ConfigurationError(
+                "`image` is required if `resources.gpu.vendor` is `tenstorrent`"
+            )
+
+    def validate_cpu_arch_and_image(self, conf: BaseRunConfiguration) -> None:
+        """
+        Infers `resources.cpu.arch` if not set, requires `image` if the architecture is ARM.
+        """
+        # TODO: Remove in 0.20. Use conf.resources.cpu directly
+        cpu_spec = parse_obj_as(CPUSpec, conf.resources.cpu)
+        arch = cpu_spec.arch
+        if arch is None:
+            gpu_spec = conf.resources.gpu
+            if (
+                gpu_spec is not None
+                and gpu_spec.vendor in [None, gpuhunt.AcceleratorVendor.NVIDIA]
+                and gpu_spec.name
+                and any(map(gpuhunt.is_nvidia_superchip, gpu_spec.name))
+            ):
+                arch = gpuhunt.CPUArchitecture.ARM
+            else:
+                arch = gpuhunt.CPUArchitecture.X86
+        # NOTE: We don't set the inferred resources.cpu.arch for compatibility with older servers.
+        # Servers with ARM support set the arch using the same logic.
+        if arch == gpuhunt.CPUArchitecture.ARM and conf.image is None:
+            raise ConfigurationError("`image` is required if `resources.cpu.arch` is `arm`")
 
 
 class RunWithPortsConfigurator(BaseRunConfigurator):
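The new `validate_cpu_arch_and_image` encodes a simple rule: when `resources.cpu.arch` is not set, assume ARM only if the requested GPU is an NVIDIA superchip, otherwise x86, and require a custom `image` for ARM. A standalone sketch of that rule, using only the `gpuhunt` helpers referenced in the hunk (the GPU names in the example are illustrative):

import gpuhunt

def infer_cpu_arch(gpu_names: list) -> gpuhunt.CPUArchitecture:
    # Mirrors the inference in validate_cpu_arch_and_image:
    # ARM for NVIDIA superchips, x86 for everything else.
    if gpu_names and any(map(gpuhunt.is_nvidia_superchip, gpu_names)):
        return gpuhunt.CPUArchitecture.ARM
    return gpuhunt.CPUArchitecture.X86

print(infer_cpu_arch(["GH200"]))  # expected: CPUArchitecture.ARM (name is illustrative)
print(infer_cpu_arch(["H100"]))   # expected: CPUArchitecture.X86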
dstack/_internal/cli/utils/fleet.py

@@ -79,7 +79,9 @@ def get_fleets_table(
                 "BACKEND": backend,
                 "REGION": region,
                 "RESOURCES": resources,
-                "PRICE": f"${instance.price:.
+                "PRICE": f"${instance.price:.4f}".rstrip("0").rstrip(".")
+                if instance.price is not None
+                else "",
                 "STATUS": status,
                 "CREATED": format_date(instance.created),
                 "ERROR": error,
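The `PRICE` columns now render a fixed four-decimal price and trim trailing zeros instead of relying on `:g`. A quick self-contained illustration of what that idiom produces:

def fmt_price(price: float) -> str:
    # Same idiom as in the diff: four decimals, then strip trailing zeros and a dangling dot.
    return f"${price:.4f}".rstrip("0").rstrip(".")

assert fmt_price(0.25) == "$0.25"
assert fmt_price(10.0) == "$10"         # the dot is stripped once the zeros are gone
assert fmt_price(1.23456) == "$1.2346"  # rounded by the .4f format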
dstack/_internal/cli/utils/run.py

@@ -1,3 +1,4 @@
+import shutil
 from typing import Any, Dict, List, Optional, Union
 
 from rich.markup import escape

@@ -36,7 +37,7 @@ def print_run_plan(
 
     req = job_plan.job_spec.requirements
     pretty_req = req.pretty_format(resources_only=True)
-    max_price = f"${req.max_price:
+    max_price = f"${req.max_price:3f}".rstrip("0").rstrip(".") if req.max_price else "-"
     max_duration = (
         format_pretty_duration(job_plan.job_spec.max_duration)
         if job_plan.job_spec.max_duration

@@ -94,14 +95,12 @@ def print_run_plan(
     props.add_row(th("Inactivity duration"), inactivity_duration)
     props.add_row(th("Reservation"), run_spec.configuration.reservation or "-")
 
-    offers = Table(box=None)
+    offers = Table(box=None, expand=shutil.get_terminal_size(fallback=(120, 40)).columns <= 110)
     offers.add_column("#")
-    offers.add_column("BACKEND")
-    offers.add_column("
-    offers.add_column("INSTANCE TYPE")
-    offers.add_column("
-    offers.add_column("SPOT")
-    offers.add_column("PRICE")
+    offers.add_column("BACKEND", style="grey58", ratio=2)
+    offers.add_column("RESOURCES", ratio=4)
+    offers.add_column("INSTANCE TYPE", style="grey58", no_wrap=True, ratio=2)
+    offers.add_column("PRICE", style="grey58", ratio=1)
     offers.add_column()
 
     job_plan.offers = job_plan.offers[:max_offers] if max_offers else job_plan.offers

@@ -122,14 +121,12 @@ def print_run_plan(
             instance += f" ({offer.blocks}/{offer.total_blocks})"
         offers.add_row(
             f"{i}",
-            offer.backend.replace("remote", "ssh"),
-
+            offer.backend.replace("remote", "ssh") + " (" + offer.region + ")",
+            r.pretty_format(include_spot=True),
             instance,
-
-            "yes" if r.spot else "no",
-            f"${offer.price:g}",
+            f"${offer.price:.4f}".rstrip("0").rstrip("."),
             availability,
-            style=None if i == 1 else "secondary",
+            style=None if i == 1 or not include_run_properties else "secondary",
         )
         if job_plan.total_offers > len(job_plan.offers):
             offers.add_row("", "...", style="secondary")

@@ -141,7 +138,8 @@ def print_run_plan(
     if job_plan.total_offers > len(job_plan.offers):
         console.print(
             f"[secondary] Shown {len(job_plan.offers)} of {job_plan.total_offers} offers, "
-            f"${job_plan.max_price:
+            f"${job_plan.max_price:3f}".rstrip("0").rstrip(".")
+            + "max[/]"
         )
         console.print()
     else:

@@ -151,19 +149,18 @@ def print_run_plan(
 def get_runs_table(
     runs: List[Run], verbose: bool = False, format_date: DateFormatter = pretty_date
 ) -> Table:
-    table = Table(box=None)
-    table.add_column("NAME", style="bold", no_wrap=True)
-    table.add_column("BACKEND", style="grey58")
+    table = Table(box=None, expand=shutil.get_terminal_size(fallback=(120, 40)).columns <= 110)
+    table.add_column("NAME", style="bold", no_wrap=True, ratio=2)
+    table.add_column("BACKEND", style="grey58", ratio=2)
+    table.add_column("RESOURCES", ratio=3 if not verbose else 2)
     if verbose:
-        table.add_column("INSTANCE", no_wrap=True)
-
+        table.add_column("INSTANCE", no_wrap=True, ratio=1)
+        table.add_column("RESERVATION", no_wrap=True, ratio=1)
+    table.add_column("PRICE", style="grey58", ratio=1)
+    table.add_column("STATUS", no_wrap=True, ratio=1)
+    table.add_column("SUBMITTED", style="grey58", no_wrap=True, ratio=1)
     if verbose:
-        table.add_column("
-        table.add_column("PRICE", no_wrap=True)
-        table.add_column("STATUS", no_wrap=True)
-        table.add_column("SUBMITTED", style="grey58", no_wrap=True)
-        if verbose:
-            table.add_column("ERROR", no_wrap=True)
+        table.add_column("ERROR", no_wrap=True, ratio=2)
 
     for run in runs:
         run_error = _get_run_error(run)

@@ -202,10 +199,10 @@ def get_runs_table(
             job_row.update(
                 {
                     "BACKEND": f"{jpd.backend.value.replace('remote', 'ssh')} ({jpd.region})",
-                    "INSTANCE": instance,
                     "RESOURCES": resources.pretty_format(include_spot=True),
+                    "INSTANCE": instance,
                     "RESERVATION": jpd.reservation,
-                    "PRICE": f"${jpd.price:.
+                    "PRICE": f"${jpd.price:.4f}".rstrip("0").rstrip("."),
                 }
             )
             if len(run.jobs) == 1:
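`print_run_plan` and `get_runs_table` now build their `rich` tables with ratio-based columns and enable `expand` only when the terminal is 110 columns or narrower. A small self-contained sketch of that layout approach (the row values are made up):

import shutil

from rich.console import Console
from rich.table import Table

# Expand to the full terminal width only on narrow terminals, as in the diff.
narrow = shutil.get_terminal_size(fallback=(120, 40)).columns <= 110
table = Table(box=None, expand=narrow)
table.add_column("NAME", style="bold", no_wrap=True, ratio=2)
table.add_column("RESOURCES", ratio=3)
table.add_column("PRICE", style="grey58", ratio=1)
table.add_row("my-run", "4xCPU, 16GB", "$0.12")
Console().print(table)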
dstack/_internal/core/backends/aws/compute.py

@@ -159,6 +159,7 @@ class AWSCompute(
         self,
         instance_offer: InstanceOfferWithAvailability,
         instance_config: InstanceConfiguration,
+        placement_group: Optional[PlacementGroup],
     ) -> JobProvisioningData:
         project_name = instance_config.project_name
         ec2_resource = self.session.resource("ec2", region_name=instance_offer.region)

@@ -248,7 +249,7 @@ class AWSCompute(
                 spot=instance_offer.instance.resources.spot,
                 subnet_id=subnet_id,
                 allocate_public_ip=allocate_public_ip,
-                placement_group_name=
+                placement_group_name=placement_group.name if placement_group else None,
                 enable_efa=enable_efa,
                 max_efa_interfaces=max_efa_interfaces,
                 reservation_id=instance_config.reservation,

@@ -291,6 +292,7 @@ class AWSCompute(
     def create_placement_group(
         self,
         placement_group: PlacementGroup,
+        master_instance_offer: InstanceOffer,
     ) -> PlacementGroupProvisioningData:
         ec2_client = self.session.client("ec2", region_name=placement_group.configuration.region)
         logger.debug("Creating placement group %s...", placement_group.name)

@@ -323,6 +325,16 @@ class AWSCompute(
             raise e
         logger.debug("Deleted placement group %s", placement_group.name)
 
+    def is_suitable_placement_group(
+        self,
+        placement_group: PlacementGroup,
+        instance_offer: InstanceOffer,
+    ) -> bool:
+        return (
+            placement_group.configuration.backend == BackendType.AWS
+            and placement_group.configuration.region == instance_offer.region
+        )
+
     def create_gateway(
         self,
         configuration: GatewayComputeConfiguration,
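`is_suitable_placement_group` gives the server a cheap, backend-specific check before reusing an existing placement group for a new offer. A hedged sketch of how a caller might use it (not the actual server code from process_instances.py):

from typing import List, Optional

from dstack._internal.core.models.instances import InstanceOffer
from dstack._internal.core.models.placement import PlacementGroup

def pick_placement_group(
    compute, groups: List[PlacementGroup], offer: InstanceOffer
) -> Optional[PlacementGroup]:
    # Reuse the first group the backend considers suitable for this offer;
    # otherwise the caller would fall back to compute.create_placement_group(...).
    for group in groups:
        if compute.is_suitable_placement_group(group, offer):
            return group
    return None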
dstack/_internal/core/backends/azure/compute.py

@@ -62,6 +62,7 @@ from dstack._internal.core.models.instances import (
     InstanceOfferWithAvailability,
     InstanceType,
 )
+from dstack._internal.core.models.placement import PlacementGroup
 from dstack._internal.core.models.resources import Memory, Range
 from dstack._internal.core.models.runs import JobProvisioningData, Requirements
 from dstack._internal.utils.logging import get_logger

@@ -109,6 +110,7 @@ class AzureCompute(
         self,
         instance_offer: InstanceOfferWithAvailability,
         instance_config: InstanceConfiguration,
+        placement_group: Optional[PlacementGroup],
     ) -> JobProvisioningData:
         instance_name = generate_unique_instance_name(
             instance_config, max_length=azure_resources.MAX_RESOURCE_NAME_LEN

@@ -136,6 +138,10 @@ class AzureCompute(
             location=location,
         )
 
+        managed_identity_resource_group, managed_identity_name = parse_vm_managed_identity(
+            self.config.vm_managed_identity
+        )
+
         base_tags = {
             "owner": "dstack",
             "dstack_project": instance_config.project_name,

@@ -159,7 +165,8 @@ class AzureCompute(
             network_security_group=network_security_group,
             network=network,
             subnet=subnet,
-
+            managed_identity_name=managed_identity_name,
+            managed_identity_resource_group=managed_identity_resource_group,
             image_reference=_get_image_ref(
                 compute_client=self._compute_client,
                 location=location,

@@ -255,7 +262,8 @@ class AzureCompute(
             network_security_group=network_security_group,
             network=network,
             subnet=subnet,
-
+            managed_identity_name=None,
+            managed_identity_resource_group=None,
             image_reference=_get_gateway_image_ref(),
             vm_size="Standard_B1ms",
             instance_name=instance_name,

@@ -338,6 +346,21 @@ def get_resource_group_network_subnet_or_error(
     return resource_group, network_name, subnet_name
 
 
+def parse_vm_managed_identity(
+    vm_managed_identity: Optional[str],
+) -> Tuple[Optional[str], Optional[str]]:
+    if vm_managed_identity is None:
+        return None, None
+    try:
+        resource_group, managed_identity = vm_managed_identity.split("/")
+        return resource_group, managed_identity
+    except Exception:
+        raise ComputeError(
+            "`vm_managed_identity` specified in incorrect format."
+            " Supported format: 'managedIdentityResourceGroup/managedIdentityName'"
+        )
+
+
 def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]:
     resource_group, network_name = vpc_id.split("/")
     return resource_group, network_name

@@ -466,7 +489,8 @@ def _launch_instance(
     network_security_group: str,
     network: str,
     subnet: str,
-
+    managed_identity_name: Optional[str],
+    managed_identity_resource_group: Optional[str],
     image_reference: ImageReference,
     vm_size: str,
     instance_name: str,

@@ -488,6 +512,20 @@ def _launch_instance(
         public_ip_address_configuration = VirtualMachinePublicIPAddressConfiguration(
             name="public_ip_config",
         )
+    managed_identity = None
+    if managed_identity_name is not None:
+        if managed_identity_resource_group is None:
+            managed_identity_resource_group = resource_group
+        managed_identity = VirtualMachineIdentity(
+            type=ResourceIdentityType.USER_ASSIGNED,
+            user_assigned_identities={
+                azure_utils.get_managed_identity_id(
+                    subscription_id,
+                    managed_identity_resource_group,
+                    managed_identity_name,
+                ): UserAssignedIdentitiesValue(),
+            },
+        )
     try:
         poller = compute_client.virtual_machines.begin_create_or_update(
             resource_group,

@@ -552,16 +590,7 @@ def _launch_instance(
                 ),
                 priority="Spot" if spot else "Regular",
                 eviction_policy="Delete" if spot else None,
-                identity=
-                if managed_identity is None
-                else VirtualMachineIdentity(
-                    type=ResourceIdentityType.USER_ASSIGNED,
-                    user_assigned_identities={
-                        azure_utils.get_managed_identity_id(
-                            subscription_id, resource_group, managed_identity
-                        ): UserAssignedIdentitiesValue()
-                    },
-                ),
+                identity=managed_identity,
                 user_data=base64.b64encode(user_data.encode()).decode(),
                 tags=tags,
             ),
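For reference, the observable behavior of the `parse_vm_managed_identity` helper added above (a usage illustration of the code in the hunk, not additional dstack API):

from dstack._internal.core.backends.azure.compute import parse_vm_managed_identity

print(parse_vm_managed_identity("my-rg/my-identity"))  # ("my-rg", "my-identity")
print(parse_vm_managed_identity(None))                 # (None, None)
# Anything that does not split into exactly two parts on "/" raises ComputeError
# with a hint about the 'managedIdentityResourceGroup/managedIdentityName' format.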
dstack/_internal/core/backends/azure/configurator.py

@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
 
 import azure.core.exceptions
 from azure.core.credentials import TokenCredential
+from azure.mgmt import msi as msi_mgmt
 from azure.mgmt import network as network_mgmt
 from azure.mgmt import resource as resource_mgmt
 from azure.mgmt import subscription as subscription_mgmt

@@ -97,6 +98,7 @@ class AzureConfigurator(Configurator):
         self._check_config_locations(config)
         self._check_config_tags(config)
         self._check_config_resource_group(config=config, credential=credential)
+        self._check_config_vm_managed_identity(config=config, credential=credential)
         self._check_config_vpc(config=config, credential=credential)
 
     def create_backend(

@@ -260,6 +262,25 @@ class AzureConfigurator(Configurator):
         except BackendError as e:
             raise ServerClientError(e.args[0])
 
+    def _check_config_vm_managed_identity(
+        self, config: AzureBackendConfigWithCreds, credential: auth.AzureCredential
+    ):
+        try:
+            resource_group, identity_name = compute.parse_vm_managed_identity(
+                config.vm_managed_identity
+            )
+        except BackendError as e:
+            raise ServerClientError(e.args[0])
+        if resource_group is None or identity_name is None:
+            return
+        msi_client = msi_mgmt.ManagedServiceIdentityClient(credential, config.subscription_id)
+        try:
+            msi_client.user_assigned_identities.get(resource_group, identity_name)
+        except azure.core.exceptions.ResourceNotFoundError:
+            raise ServerClientError(
+                f"Managed identity {identity_name} not found in resource group {resource_group}"
+            )
+
     def _set_client_creds_tenant_id(
         self,
         creds: AzureClientCreds,
dstack/_internal/core/backends/azure/models.py

@@ -62,6 +62,15 @@ class AzureBackendConfig(CoreModel):
             )
         ),
     ] = None
+    vm_managed_identity: Annotated[
+        Optional[str],
+        Field(
+            description=(
+                "The managed identity to associate with provisioned VMs."
+                " Must have a format `managedIdentityResourceGroup/managedIdentityName`"
+            )
+        ),
+    ] = None
     tags: Annotated[
         Optional[Dict[str, str]],
         Field(description="The tags that will be assigned to resources created by `dstack`"),