dstack 0.19.30rc1__py3-none-any.whl → 0.19.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. Click here for more details.

Files changed (54):
  1. dstack/_internal/cli/commands/__init__.py +8 -0
  2. dstack/_internal/cli/commands/project.py +27 -20
  3. dstack/_internal/cli/commands/server.py +5 -0
  4. dstack/_internal/cli/services/configurators/fleet.py +20 -6
  5. dstack/_internal/cli/utils/gpu.py +2 -2
  6. dstack/_internal/core/backends/aws/compute.py +13 -5
  7. dstack/_internal/core/backends/aws/resources.py +11 -6
  8. dstack/_internal/core/backends/azure/compute.py +17 -6
  9. dstack/_internal/core/backends/base/compute.py +57 -9
  10. dstack/_internal/core/backends/base/offers.py +1 -0
  11. dstack/_internal/core/backends/cloudrift/compute.py +2 -0
  12. dstack/_internal/core/backends/cudo/compute.py +2 -0
  13. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  14. dstack/_internal/core/backends/digitalocean_base/compute.py +2 -0
  15. dstack/_internal/core/backends/features.py +5 -0
  16. dstack/_internal/core/backends/gcp/compute.py +87 -38
  17. dstack/_internal/core/backends/gcp/configurator.py +1 -1
  18. dstack/_internal/core/backends/gcp/models.py +14 -1
  19. dstack/_internal/core/backends/gcp/resources.py +35 -12
  20. dstack/_internal/core/backends/hotaisle/compute.py +22 -0
  21. dstack/_internal/core/backends/kubernetes/compute.py +531 -215
  22. dstack/_internal/core/backends/kubernetes/models.py +13 -16
  23. dstack/_internal/core/backends/kubernetes/utils.py +145 -8
  24. dstack/_internal/core/backends/lambdalabs/compute.py +2 -0
  25. dstack/_internal/core/backends/local/compute.py +2 -0
  26. dstack/_internal/core/backends/nebius/compute.py +17 -0
  27. dstack/_internal/core/backends/nebius/configurator.py +15 -0
  28. dstack/_internal/core/backends/nebius/models.py +57 -5
  29. dstack/_internal/core/backends/nebius/resources.py +45 -2
  30. dstack/_internal/core/backends/oci/compute.py +7 -1
  31. dstack/_internal/core/backends/oci/resources.py +8 -3
  32. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  33. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  34. dstack/_internal/core/backends/vultr/compute.py +2 -0
  35. dstack/_internal/core/compatibility/runs.py +8 -0
  36. dstack/_internal/core/consts.py +2 -0
  37. dstack/_internal/core/models/profiles.py +11 -4
  38. dstack/_internal/core/services/repos.py +101 -11
  39. dstack/_internal/server/background/tasks/common.py +2 -0
  40. dstack/_internal/server/background/tasks/process_fleets.py +75 -17
  41. dstack/_internal/server/background/tasks/process_instances.py +3 -5
  42. dstack/_internal/server/background/tasks/process_running_jobs.py +1 -1
  43. dstack/_internal/server/background/tasks/process_runs.py +27 -23
  44. dstack/_internal/server/background/tasks/process_submitted_jobs.py +107 -54
  45. dstack/_internal/server/services/offers.py +7 -1
  46. dstack/_internal/server/testing/common.py +2 -0
  47. dstack/_internal/server/utils/provisioning.py +3 -10
  48. dstack/_internal/utils/ssh.py +22 -2
  49. dstack/version.py +2 -2
  50. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/METADATA +20 -18
  51. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/RECORD +54 -54
  52. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/WHEEL +0 -0
  53. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/entry_points.txt +0 -0
  54. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/licenses/LICENSE.md +0 -0
@@ -7,6 +7,7 @@ from typing import ClassVar, Optional
7
7
  from rich_argparse import RichHelpFormatter
8
8
 
9
9
  from dstack._internal.cli.services.completion import ProjectNameCompleter
10
+ from dstack._internal.cli.utils.common import configure_logging
10
11
  from dstack._internal.core.errors import CLIError
11
12
  from dstack.api import Client
12
13
 
@@ -52,9 +53,16 @@ class BaseCommand(ABC):
52
53
 
53
54
  @abstractmethod
54
55
  def _command(self, args: argparse.Namespace):
56
+ self._configure_logging()
55
57
  if not self.ACCEPT_EXTRA_ARGS and args.extra_args:
56
58
  raise CLIError(f"Unrecognized arguments: {shlex.join(args.extra_args)}")
57
59
 
60
+ def _configure_logging(self) -> None:
61
+ """
62
+ Override this method to configure command-specific logging
63
+ """
64
+ configure_logging()
65
+
58
66
 
59
67
  class APIBaseCommand(BaseCommand):
60
68
  api: Client
@@ -1,11 +1,12 @@
1
1
  import argparse
2
+ from typing import Any, Union
2
3
 
3
4
  from requests import HTTPError
4
5
  from rich.table import Table
5
6
 
6
7
  import dstack.api.server
7
8
  from dstack._internal.cli.commands import BaseCommand
8
- from dstack._internal.cli.utils.common import confirm_ask, console
9
+ from dstack._internal.cli.utils.common import add_row_from_dict, confirm_ask, console
9
10
  from dstack._internal.core.errors import ClientError, CLIError
10
11
  from dstack._internal.core.services.configs import ConfigManager
11
12
  from dstack._internal.utils.logging import get_logger
@@ -58,6 +59,10 @@ class ProjectCommand(BaseCommand):
58
59
  # List subcommand
59
60
  list_parser = subparsers.add_parser("list", help="List configured projects")
60
61
  list_parser.set_defaults(subfunc=self._list)
62
+ for parser in [self._parser, list_parser]:
63
+ parser.add_argument(
64
+ "-v", "--verbose", action="store_true", help="Show more information"
65
+ )
61
66
 
62
67
  # Set default subcommand
63
68
  set_default_parser = subparsers.add_parser("set-default", help="Set default project")
@@ -122,30 +127,32 @@ class ProjectCommand(BaseCommand):
122
127
  table = Table(box=None)
123
128
  table.add_column("PROJECT", style="bold", no_wrap=True)
124
129
  table.add_column("URL", style="grey58")
125
- table.add_column("USER", style="grey58")
130
+ if args.verbose:
131
+ table.add_column("USER", style="grey58")
126
132
  table.add_column("DEFAULT", justify="center")
127
133
 
128
134
  for project_config in config_manager.list_project_configs():
129
135
  project_name = project_config.name
130
136
  is_default = project_name == default_project.name if default_project else False
131
-
132
- # Get username from API
133
- try:
134
- api_client = dstack.api.server.APIClient(
135
- base_url=project_config.url, token=project_config.token
136
- )
137
- user_info = api_client.users.get_my_user()
138
- username = user_info.username
139
- except ClientError:
140
- username = "(invalid token)"
141
-
142
- table.add_row(
143
- project_name,
144
- project_config.url,
145
- username,
146
- "✓" if is_default else "",
147
- style="bold" if is_default else None,
148
- )
137
+ row: dict[Union[str, int], Any] = {
138
+ "PROJECT": project_name,
139
+ "URL": project_config.url,
140
+ "DEFAULT": "✓" if is_default else "",
141
+ }
142
+
143
+ if args.verbose:
144
+ # Get username from API
145
+ try:
146
+ api_client = dstack.api.server.APIClient(
147
+ base_url=project_config.url, token=project_config.token
148
+ )
149
+ user_info = api_client.users.get_my_user()
150
+ username = user_info.username
151
+ except ClientError:
152
+ username = "(invalid token)"
153
+ row["USER"] = username
154
+
155
+ add_row_from_dict(table, row, style="bold" if is_default else None)
149
156
 
150
157
  console.print(table)
151
158
 
@@ -82,3 +82,8 @@ class ServerCommand(BaseCommand):
82
82
  log_level=uvicorn_log_level,
83
83
  workers=1,
84
84
  )
85
+
86
+ def _configure_logging(self) -> None:
87
+ # Server logging is configured in the FastAPI lifespan function.
88
+ # No need to configure CLI logging.
89
+ pass
@@ -159,12 +159,19 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[Fle
159
159
  console.print(
160
160
  get_fleets_table(
161
161
  [fleet],
162
- verbose=_failed_provisioning(fleet),
162
+ verbose=_fleet_has_failed_instances(fleet),
163
163
  format_date=local_time,
164
164
  )
165
165
  )
166
- if _failed_provisioning(fleet):
167
- console.print("\n[error]Some instances failed. Check the table above for errors.[/]")
166
+ if _fleet_has_failed_instances(fleet):
167
+ if _fleet_retrying(fleet):
168
+ console.print(
169
+ "\n[error]Some instances failed. Provisioning will be retried in the background.[/]"
170
+ )
171
+ else:
172
+ console.print(
173
+ "\n[error]Some instances failed. Check the table above for errors.[/]"
174
+ )
168
175
  exit(1)
169
176
 
170
177
  def _apply_plan_on_old_server(self, plan: FleetPlan, command_args: argparse.Namespace):
@@ -253,11 +260,11 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[Fle
253
260
  console.print(
254
261
  get_fleets_table(
255
262
  [fleet],
256
- verbose=_failed_provisioning(fleet),
263
+ verbose=_fleet_has_failed_instances(fleet),
257
264
  format_date=local_time,
258
265
  )
259
266
  )
260
- if _failed_provisioning(fleet):
267
+ if _fleet_has_failed_instances(fleet):
261
268
  console.print("\n[error]Some instances failed. Check the table above for errors.[/]")
262
269
  exit(1)
263
270
 
@@ -462,13 +469,20 @@ def _finished_provisioning(fleet: Fleet) -> bool:
462
469
  return True
463
470
 
464
471
 
465
- def _failed_provisioning(fleet: Fleet) -> bool:
472
+ def _fleet_has_failed_instances(fleet: Fleet) -> bool:
466
473
  for instance in fleet.instances:
467
474
  if instance.status == InstanceStatus.TERMINATED:
468
475
  return True
469
476
  return False
470
477
 
471
478
 
479
+ def _fleet_retrying(fleet: Fleet) -> bool:
480
+ if fleet.spec.configuration.nodes is None:
481
+ return False
482
+ active_instances = [i for i in fleet.instances if i.status.is_active()]
483
+ return len(active_instances) < fleet.spec.configuration.nodes.min
484
+
485
+
472
486
  def _apply_plan(api: Client, plan: FleetPlan) -> Fleet:
473
487
  try:
474
488
  return api.client.fleets.apply_plan(
@@ -9,7 +9,7 @@ from dstack._internal.core.models.runs import Requirements, RunSpec, get_policy_
9
9
  from dstack._internal.server.schemas.gpus import GpuGroup
10
10
 
11
11
 
12
- def print_gpu_json(gpu_response, run_spec, group_by_cli, api_project):
12
+ def print_gpu_json(gpus, run_spec, group_by_cli, api_project):
13
13
  """Print GPU information in JSON format."""
14
14
  req = Requirements(
15
15
  resources=run_spec.configuration.resources,
@@ -36,7 +36,7 @@ def print_gpu_json(gpu_response, run_spec, group_by_cli, api_project):
36
36
  "gpus": [],
37
37
  }
38
38
 
39
- for gpu_group in gpu_response.gpus:
39
+ for gpu_group in gpus:
40
40
  gpu_data = {
41
41
  "name": gpu_group.name,
42
42
  "memory_mib": gpu_group.memory_mib,
@@ -24,6 +24,7 @@ from dstack._internal.core.backends.base.compute import (
24
24
  ComputeWithMultinodeSupport,
25
25
  ComputeWithPlacementGroupSupport,
26
26
  ComputeWithPrivateGatewaySupport,
27
+ ComputeWithPrivilegedSupport,
27
28
  ComputeWithReservationSupport,
28
29
  ComputeWithVolumeSupport,
29
30
  generate_unique_gateway_instance_name,
@@ -90,6 +91,7 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
90
91
  class AWSCompute(
91
92
  ComputeWithAllOffersCached,
92
93
  ComputeWithCreateInstanceSupport,
94
+ ComputeWithPrivilegedSupport,
93
95
  ComputeWithMultinodeSupport,
94
96
  ComputeWithReservationSupport,
95
97
  ComputeWithPlacementGroupSupport,
@@ -291,7 +293,11 @@ class AWSCompute(
291
293
  image_id, username = self._get_image_id_and_username(
292
294
  ec2_client=ec2_client,
293
295
  region=instance_offer.region,
294
- cuda=len(instance_offer.instance.resources.gpus) > 0,
296
+ gpu_name=(
297
+ instance_offer.instance.resources.gpus[0].name
298
+ if len(instance_offer.instance.resources.gpus) > 0
299
+ else None
300
+ ),
295
301
  instance_type=instance_offer.instance.name,
296
302
  image_config=self.config.os_images,
297
303
  )
@@ -897,11 +903,13 @@ class AWSCompute(
897
903
  self,
898
904
  ec2_client: botocore.client.BaseClient,
899
905
  region: str,
900
- cuda: bool,
906
+ gpu_name: Optional[str],
901
907
  instance_type: str,
902
908
  image_config: Optional[AWSOSImageConfig] = None,
903
909
  ) -> tuple:
904
- return hashkey(region, cuda, instance_type, image_config.json() if image_config else None)
910
+ return hashkey(
911
+ region, gpu_name, instance_type, image_config.json() if image_config else None
912
+ )
905
913
 
906
914
  @cachedmethod(
907
915
  cache=lambda self: self._get_image_id_and_username_cache,
@@ -912,13 +920,13 @@ class AWSCompute(
912
920
  self,
913
921
  ec2_client: botocore.client.BaseClient,
914
922
  region: str,
915
- cuda: bool,
923
+ gpu_name: Optional[str],
916
924
  instance_type: str,
917
925
  image_config: Optional[AWSOSImageConfig] = None,
918
926
  ) -> tuple[str, str]:
919
927
  return aws_resources.get_image_id_and_username(
920
928
  ec2_client=ec2_client,
921
- cuda=cuda,
929
+ gpu_name=gpu_name,
922
930
  instance_type=instance_type,
923
931
  image_config=image_config,
924
932
  )
@@ -6,6 +6,8 @@ import botocore.exceptions
6
6
 
7
7
  import dstack.version as version
8
8
  from dstack._internal.core.backends.aws.models import AWSOSImageConfig
9
+ from dstack._internal.core.backends.base.compute import requires_nvidia_proprietary_kernel_modules
10
+ from dstack._internal.core.consts import DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES
9
11
  from dstack._internal.core.errors import BackendError, ComputeError, ComputeResourceNotFoundError
10
12
  from dstack._internal.utils.logging import get_logger
11
13
 
@@ -17,14 +19,14 @@ DLAMI_OWNER_ACCOUNT_ID = "898082745236"
17
19
 
18
20
  def get_image_id_and_username(
19
21
  ec2_client: botocore.client.BaseClient,
20
- cuda: bool,
22
+ gpu_name: Optional[str],
21
23
  instance_type: str,
22
24
  image_config: Optional[AWSOSImageConfig] = None,
23
25
  ) -> tuple[str, str]:
24
26
  if image_config is not None:
25
- image = image_config.nvidia if cuda else image_config.cpu
27
+ image = image_config.nvidia if gpu_name else image_config.cpu
26
28
  if image is None:
27
- logger.warning("%s image not configured", "nvidia" if cuda else "cpu")
29
+ logger.warning("%s image not configured", "nvidia" if gpu_name else "cpu")
28
30
  raise ComputeResourceNotFoundError()
29
31
  image_name = image.name
30
32
  image_owner = image.owner
@@ -35,9 +37,12 @@ def get_image_id_and_username(
35
37
  image_owner = DLAMI_OWNER_ACCOUNT_ID
36
38
  username = "ubuntu"
37
39
  else:
38
- image_name = (
39
- f"dstack-{version.base_image}" if not cuda else f"dstack-cuda-{version.base_image}"
40
- )
40
+ if gpu_name is None:
41
+ image_name = f"dstack-{version.base_image}"
42
+ elif not requires_nvidia_proprietary_kernel_modules(gpu_name):
43
+ image_name = f"dstack-cuda-{version.base_image}"
44
+ else:
45
+ image_name = f"dstack-cuda-{DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES}"
41
46
  image_owner = DSTACK_ACCOUNT_ID
42
47
  username = "ubuntu"
43
48
  response = ec2_client.describe_images(
@@ -43,13 +43,16 @@ from dstack._internal.core.backends.base.compute import (
43
43
  ComputeWithCreateInstanceSupport,
44
44
  ComputeWithGatewaySupport,
45
45
  ComputeWithMultinodeSupport,
46
+ ComputeWithPrivilegedSupport,
46
47
  generate_unique_gateway_instance_name,
47
48
  generate_unique_instance_name,
48
49
  get_gateway_user_data,
49
50
  get_user_data,
50
51
  merge_tags,
52
+ requires_nvidia_proprietary_kernel_modules,
51
53
  )
52
54
  from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
55
+ from dstack._internal.core.consts import DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES
53
56
  from dstack._internal.core.errors import ComputeError, NoCapacityError
54
57
  from dstack._internal.core.models.backends.base import BackendType
55
58
  from dstack._internal.core.models.gateways import (
@@ -76,6 +79,7 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("30GB"), max=Memory.pars
76
79
  class AzureCompute(
77
80
  ComputeWithAllOffersCached,
78
81
  ComputeWithCreateInstanceSupport,
82
+ ComputeWithPrivilegedSupport,
79
83
  ComputeWithMultinodeSupport,
80
84
  ComputeWithGatewaySupport,
81
85
  Compute,
@@ -372,6 +376,7 @@ def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]:
372
376
  class VMImageVariant(enum.Enum):
373
377
  GRID = enum.auto()
374
378
  CUDA = enum.auto()
379
+ CUDA_WITH_PROPRIETARY_KERNEL_MODULES = enum.auto()
375
380
  STANDARD = enum.auto()
376
381
 
377
382
  @classmethod
@@ -379,18 +384,24 @@ class VMImageVariant(enum.Enum):
379
384
  if "_A10_v5" in instance.name:
380
385
  return cls.GRID
381
386
  elif len(instance.resources.gpus) > 0:
382
- return cls.CUDA
387
+ if not requires_nvidia_proprietary_kernel_modules(instance.resources.gpus[0].name):
388
+ return cls.CUDA
389
+ else:
390
+ return cls.CUDA_WITH_PROPRIETARY_KERNEL_MODULES
383
391
  else:
384
392
  return cls.STANDARD
385
393
 
386
394
  def get_image_name(self) -> str:
387
- name = "dstack-"
388
395
  if self is self.GRID:
389
- name += "grid-"
396
+ return f"dstack-grid-{version.base_image}"
390
397
  elif self is self.CUDA:
391
- name += "cuda-"
392
- name += version.base_image
393
- return name
398
+ return f"dstack-cuda-{version.base_image}"
399
+ elif self is self.CUDA_WITH_PROPRIETARY_KERNEL_MODULES:
400
+ return f"dstack-cuda-{DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES}"
401
+ elif self is self.STANDARD:
402
+ return f"dstack-{version.base_image}"
403
+ else:
404
+ raise ValueError(f"Unexpected image variant {self!r}")
394
405
 
395
406
 
396
407
  _SUPPORTED_VM_SERIES_PATTERNS = [
@@ -5,14 +5,16 @@ import string
5
5
  import threading
6
6
  from abc import ABC, abstractmethod
7
7
  from collections.abc import Iterable
8
+ from enum import Enum
8
9
  from functools import lru_cache
9
10
  from pathlib import Path
10
- from typing import Callable, Dict, List, Literal, Optional
11
+ from typing import Callable, Dict, List, Optional
11
12
 
12
13
  import git
13
14
  import requests
14
15
  import yaml
15
16
  from cachetools import TTLCache, cachedmethod
17
+ from gpuhunt import CPUArchitecture
16
18
 
17
19
  from dstack._internal import settings
18
20
  from dstack._internal.core.backends.base.offers import filter_offers_by_requirements
@@ -48,8 +50,38 @@ logger = get_logger(__name__)
48
50
  DSTACK_SHIM_BINARY_NAME = "dstack-shim"
49
51
  DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
50
52
  DEFAULT_PRIVATE_SUBNETS = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")
53
+ NVIDIA_GPUS_REQUIRING_PROPRIETARY_KERNEL_MODULES = frozenset(
54
+ # All NVIDIA architectures prior to Turing do not support Open Kernel Modules and require
55
+ # proprietary modules. This list is incomplete, update when necessary.
56
+ [
57
+ "v100",
58
+ "p100",
59
+ "p40",
60
+ "p4",
61
+ "m60",
62
+ "m40",
63
+ "m4",
64
+ "k80",
65
+ "k40",
66
+ "k20",
67
+ ]
68
+ )
69
+
51
70
 
52
- GoArchType = Literal["amd64", "arm64"]
71
+ class GoArchType(str, Enum):
72
+ """
73
+ A subset of GOARCH values
74
+ """
75
+
76
+ AMD64 = "amd64"
77
+ ARM64 = "arm64"
78
+
79
+ def to_cpu_architecture(self) -> CPUArchitecture:
80
+ if self == self.AMD64:
81
+ return CPUArchitecture.X86
82
+ if self == self.ARM64:
83
+ return CPUArchitecture.ARM
84
+ assert False, self
53
85
 
54
86
 
55
87
  class Compute(ABC):
@@ -288,6 +320,15 @@ class ComputeWithCreateInstanceSupport(ABC):
288
320
  ]
289
321
 
290
322
 
323
+ class ComputeWithPrivilegedSupport:
324
+ """
325
+ Must be subclassed to support runs with `privileged: true`.
326
+ All VM-based Computes (that is, Computes that use the shim) should subclass this mixin.
327
+ """
328
+
329
+ pass
330
+
331
+
291
332
  class ComputeWithMultinodeSupport:
292
333
  """
293
334
  Must be subclassed to support multinode tasks and cluster fleets.
@@ -688,14 +729,14 @@ def normalize_arch(arch: Optional[str] = None) -> GoArchType:
688
729
  If the arch is not specified, falls back to `amd64`.
689
730
  """
690
731
  if not arch:
691
- return "amd64"
732
+ return GoArchType.AMD64
692
733
  arch_lower = arch.lower()
693
734
  if "32" in arch_lower or arch_lower in ["i386", "i686"]:
694
735
  raise ValueError(f"32-bit architectures are not supported: {arch}")
695
736
  if arch_lower.startswith("x86") or arch_lower.startswith("amd"):
696
- return "amd64"
737
+ return GoArchType.AMD64
697
738
  if arch_lower.startswith("arm") or arch_lower.startswith("aarch"):
698
- return "arm64"
739
+ return GoArchType.ARM64
699
740
  raise ValueError(f"Unsupported architecture: {arch}")
700
741
 
701
742
 
@@ -711,8 +752,7 @@ def get_dstack_runner_download_url(arch: Optional[str] = None) -> str:
711
752
  "/{version}/binaries/dstack-runner-linux-{arch}"
712
753
  )
713
754
  version = get_dstack_runner_version()
714
- arch = normalize_arch(arch)
715
- return url_template.format(version=version, arch=arch)
755
+ return url_template.format(version=version, arch=normalize_arch(arch).value)
716
756
 
717
757
 
718
758
  def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
@@ -727,8 +767,7 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
727
767
  "/{version}/binaries/dstack-shim-linux-{arch}"
728
768
  )
729
769
  version = get_dstack_runner_version()
730
- arch = normalize_arch(arch)
731
- return url_template.format(version=version, arch=arch)
770
+ return url_template.format(version=version, arch=normalize_arch(arch).value)
732
771
 
733
772
 
734
773
  def get_setup_cloud_instance_commands(
@@ -969,3 +1008,12 @@ def merge_tags(
969
1008
  for k, v in resource_tags.items():
970
1009
  res.setdefault(k, v)
971
1010
  return res
1011
+
1012
+
1013
+ def requires_nvidia_proprietary_kernel_modules(gpu_name: str) -> bool:
1014
+ """
1015
+ Returns:
1016
+ Whether this NVIDIA GPU requires NVIDIA proprietary kernel modules
1017
+ instead of open kernel modules.
1018
+ """
1019
+ return gpu_name.lower() in NVIDIA_GPUS_REQUIRING_PROPRIETARY_KERNEL_MODULES
@@ -22,6 +22,7 @@ from dstack._internal.utils.common import get_or_error
22
22
  SUPPORTED_GPUHUNT_FLAGS = [
23
23
  "oci-spot",
24
24
  "lambda-arm",
25
+ "gcp-a4",
25
26
  ]
26
27
 
27
28
 
@@ -4,6 +4,7 @@ from dstack._internal.core.backends.base.compute import (
4
4
  Compute,
5
5
  ComputeWithAllOffersCached,
6
6
  ComputeWithCreateInstanceSupport,
7
+ ComputeWithPrivilegedSupport,
7
8
  get_shim_commands,
8
9
  )
9
10
  from dstack._internal.core.backends.base.offers import get_catalog_offers
@@ -27,6 +28,7 @@ logger = get_logger(__name__)
27
28
  class CloudRiftCompute(
28
29
  ComputeWithAllOffersCached,
29
30
  ComputeWithCreateInstanceSupport,
31
+ ComputeWithPrivilegedSupport,
30
32
  Compute,
31
33
  ):
32
34
  def __init__(self, config: CloudRiftConfig):
@@ -6,6 +6,7 @@ from dstack._internal.core.backends.base.backend import Compute
6
6
  from dstack._internal.core.backends.base.compute import (
7
7
  ComputeWithCreateInstanceSupport,
8
8
  ComputeWithFilteredOffersCached,
9
+ ComputeWithPrivilegedSupport,
9
10
  generate_unique_instance_name,
10
11
  get_shim_commands,
11
12
  )
@@ -32,6 +33,7 @@ MAX_RESOURCE_NAME_LEN = 30
32
33
  class CudoCompute(
33
34
  ComputeWithFilteredOffersCached,
34
35
  ComputeWithCreateInstanceSupport,
36
+ ComputeWithPrivilegedSupport,
35
37
  Compute,
36
38
  ):
37
39
  def __init__(self, config: CudoConfig):
@@ -8,6 +8,7 @@ from dstack._internal.core.backends.base.backend import Compute
8
8
  from dstack._internal.core.backends.base.compute import (
9
9
  ComputeWithAllOffersCached,
10
10
  ComputeWithCreateInstanceSupport,
11
+ ComputeWithPrivilegedSupport,
11
12
  generate_unique_instance_name,
12
13
  get_shim_commands,
13
14
  )
@@ -39,6 +40,7 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=IMAGE_SIZE, max=None)
39
40
  class DataCrunchCompute(
40
41
  ComputeWithAllOffersCached,
41
42
  ComputeWithCreateInstanceSupport,
43
+ ComputeWithPrivilegedSupport,
42
44
  Compute,
43
45
  ):
44
46
  def __init__(self, config: DataCrunchConfig):
@@ -7,6 +7,7 @@ from dstack._internal.core.backends.base.backend import Compute
7
7
  from dstack._internal.core.backends.base.compute import (
8
8
  ComputeWithAllOffersCached,
9
9
  ComputeWithCreateInstanceSupport,
10
+ ComputeWithPrivilegedSupport,
10
11
  generate_unique_instance_name,
11
12
  get_user_data,
12
13
  )
@@ -40,6 +41,7 @@ DOCKER_INSTALL_COMMANDS = [
40
41
  class BaseDigitalOceanCompute(
41
42
  ComputeWithAllOffersCached,
42
43
  ComputeWithCreateInstanceSupport,
44
+ ComputeWithPrivilegedSupport,
43
45
  Compute,
44
46
  ):
45
47
  def __init__(self, config: BaseDigitalOceanConfig, api_url: str, type: BackendType):
@@ -4,6 +4,7 @@ from dstack._internal.core.backends.base.compute import (
4
4
  ComputeWithMultinodeSupport,
5
5
  ComputeWithPlacementGroupSupport,
6
6
  ComputeWithPrivateGatewaySupport,
7
+ ComputeWithPrivilegedSupport,
7
8
  ComputeWithReservationSupport,
8
9
  ComputeWithVolumeSupport,
9
10
  )
@@ -38,6 +39,10 @@ BACKENDS_WITH_CREATE_INSTANCE_SUPPORT = _get_backends_with_compute_feature(
38
39
  configurator_classes=_configurator_classes,
39
40
  compute_feature_class=ComputeWithCreateInstanceSupport,
40
41
  )
42
+ BACKENDS_WITH_PRIVILEGED_SUPPORT = _get_backends_with_compute_feature(
43
+ configurator_classes=_configurator_classes,
44
+ compute_feature_class=ComputeWithPrivilegedSupport,
45
+ )
41
46
  BACKENDS_WITH_MULTINODE_SUPPORT = [BackendType.REMOTE] + _get_backends_with_compute_feature(
42
47
  configurator_classes=_configurator_classes,
43
48
  compute_feature_class=ComputeWithMultinodeSupport,