dstack 0.19.28__py3-none-any.whl → 0.19.30__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Files changed (51)
  1. dstack/_internal/cli/main.py +3 -1
  2. dstack/_internal/cli/services/configurators/fleet.py +20 -6
  3. dstack/_internal/cli/utils/gpu.py +2 -2
  4. dstack/_internal/core/backends/aws/compute.py +62 -41
  5. dstack/_internal/core/backends/aws/resources.py +11 -6
  6. dstack/_internal/core/backends/azure/compute.py +25 -13
  7. dstack/_internal/core/backends/base/compute.py +121 -14
  8. dstack/_internal/core/backends/base/offers.py +34 -4
  9. dstack/_internal/core/backends/cloudrift/compute.py +5 -7
  10. dstack/_internal/core/backends/cudo/compute.py +4 -2
  11. dstack/_internal/core/backends/datacrunch/compute.py +13 -11
  12. dstack/_internal/core/backends/digitalocean_base/compute.py +4 -5
  13. dstack/_internal/core/backends/gcp/compute.py +25 -11
  14. dstack/_internal/core/backends/hotaisle/compute.py +4 -7
  15. dstack/_internal/core/backends/kubernetes/compute.py +6 -4
  16. dstack/_internal/core/backends/lambdalabs/compute.py +4 -5
  17. dstack/_internal/core/backends/local/compute.py +1 -3
  18. dstack/_internal/core/backends/nebius/compute.py +10 -7
  19. dstack/_internal/core/backends/oci/compute.py +15 -8
  20. dstack/_internal/core/backends/oci/resources.py +8 -3
  21. dstack/_internal/core/backends/runpod/compute.py +15 -6
  22. dstack/_internal/core/backends/template/compute.py.jinja +3 -1
  23. dstack/_internal/core/backends/tensordock/compute.py +1 -3
  24. dstack/_internal/core/backends/tensordock/models.py +2 -0
  25. dstack/_internal/core/backends/vastai/compute.py +7 -3
  26. dstack/_internal/core/backends/vultr/compute.py +5 -5
  27. dstack/_internal/core/consts.py +2 -0
  28. dstack/_internal/core/models/projects.py +8 -0
  29. dstack/_internal/core/services/repos.py +101 -10
  30. dstack/_internal/server/background/tasks/process_instances.py +3 -2
  31. dstack/_internal/server/background/tasks/process_running_jobs.py +1 -1
  32. dstack/_internal/server/background/tasks/process_submitted_jobs.py +100 -47
  33. dstack/_internal/server/services/backends/__init__.py +1 -1
  34. dstack/_internal/server/services/projects.py +11 -3
  35. dstack/_internal/server/services/runs.py +2 -0
  36. dstack/_internal/server/statics/index.html +1 -1
  37. dstack/_internal/server/statics/main-56191fbfe77f49b251de.css +3 -0
  38. dstack/_internal/server/statics/{main-a2a16772fbf11a14d191.js → main-c51afa7f243e24d3e446.js} +61081 -49037
  39. dstack/_internal/server/statics/{main-a2a16772fbf11a14d191.js.map → main-c51afa7f243e24d3e446.js.map} +1 -1
  40. dstack/_internal/utils/ssh.py +22 -2
  41. dstack/version.py +2 -2
  42. {dstack-0.19.28.dist-info → dstack-0.19.30.dist-info}/METADATA +8 -6
  43. {dstack-0.19.28.dist-info → dstack-0.19.30.dist-info}/RECORD +46 -50
  44. dstack/_internal/core/backends/tensordock/__init__.py +0 -0
  45. dstack/_internal/core/backends/tensordock/api_client.py +0 -104
  46. dstack/_internal/core/backends/tensordock/backend.py +0 -16
  47. dstack/_internal/core/backends/tensordock/configurator.py +0 -74
  48. dstack/_internal/server/statics/main-5e0d56245c4bd241ec27.css +0 -3
  49. {dstack-0.19.28.dist-info → dstack-0.19.30.dist-info}/WHEEL +0 -0
  50. {dstack-0.19.28.dist-info → dstack-0.19.30.dist-info}/entry_points.txt +0 -0
  51. {dstack-0.19.28.dist-info → dstack-0.19.30.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/cli/main.py
@@ -22,7 +22,7 @@ from dstack._internal.cli.commands.server import ServerCommand
 from dstack._internal.cli.commands.stats import StatsCommand
 from dstack._internal.cli.commands.stop import StopCommand
 from dstack._internal.cli.commands.volume import VolumeCommand
-from dstack._internal.cli.utils.common import _colors, console
+from dstack._internal.cli.utils.common import _colors, configure_logging, console
 from dstack._internal.cli.utils.updates import check_for_updates
 from dstack._internal.core.errors import ClientError, CLIError, ConfigurationError, SSHError
 from dstack._internal.core.services.ssh.client import get_ssh_client_info
@@ -39,6 +39,8 @@ def main():
     RichHelpFormatter.styles["argparse.groups"] = "bold grey74"
     RichHelpFormatter.styles["argparse.text"] = "grey74"

+    configure_logging()
+
     parser = argparse.ArgumentParser(
         description=(
             "Not sure where to start?"
dstack/_internal/cli/services/configurators/fleet.py
@@ -159,12 +159,19 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[Fle
         console.print(
             get_fleets_table(
                 [fleet],
-                verbose=_failed_provisioning(fleet),
+                verbose=_fleet_has_failed_instances(fleet),
                 format_date=local_time,
             )
         )
-        if _failed_provisioning(fleet):
-            console.print("\n[error]Some instances failed. Check the table above for errors.[/]")
+        if _fleet_has_failed_instances(fleet):
+            if _fleet_retrying(fleet):
+                console.print(
+                    "\n[error]Some instances failed. Provisioning will be retried in the background.[/]"
+                )
+            else:
+                console.print(
+                    "\n[error]Some instances failed. Check the table above for errors.[/]"
+                )
            exit(1)

    def _apply_plan_on_old_server(self, plan: FleetPlan, command_args: argparse.Namespace):
@@ -253,11 +260,11 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[Fle
         console.print(
             get_fleets_table(
                 [fleet],
-                verbose=_failed_provisioning(fleet),
+                verbose=_fleet_has_failed_instances(fleet),
                 format_date=local_time,
             )
         )
-        if _failed_provisioning(fleet):
+        if _fleet_has_failed_instances(fleet):
            console.print("\n[error]Some instances failed. Check the table above for errors.[/]")
            exit(1)

@@ -462,13 +469,20 @@ def _finished_provisioning(fleet: Fleet) -> bool:
     return True


-def _failed_provisioning(fleet: Fleet) -> bool:
+def _fleet_has_failed_instances(fleet: Fleet) -> bool:
     for instance in fleet.instances:
         if instance.status == InstanceStatus.TERMINATED:
             return True
     return False


+def _fleet_retrying(fleet: Fleet) -> bool:
+    if fleet.spec.configuration.nodes is None:
+        return False
+    active_instances = [i for i in fleet.instances if i.status.is_active()]
+    return len(active_instances) < fleet.spec.configuration.nodes.min
+
+
 def _apply_plan(api: Client, plan: FleetPlan) -> Fleet:
     try:
         return api.client.fleets.apply_plan(
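
The new retry condition above compares the number of active instances against the fleet's configured minimum node count. A minimal self-contained sketch of that condition, using hypothetical stand-in classes rather than dstack's Fleet/Instance models:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _Nodes:  # stand-in for fleet.spec.configuration.nodes
    min: int


@dataclass
class _Instance:  # stand-in for an instance; the real code checks status.is_active()
    active: bool


def fleet_retrying(nodes: Optional[_Nodes], instances: List[_Instance]) -> bool:
    # Mirrors _fleet_retrying(): no nodes config means no background retries;
    # otherwise the fleet keeps retrying while active instances < nodes.min.
    if nodes is None:
        return False
    return sum(1 for i in instances if i.active) < nodes.min


assert fleet_retrying(_Nodes(min=2), [_Instance(active=True)]) is True
assert fleet_retrying(_Nodes(min=1), [_Instance(active=True)]) is False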
dstack/_internal/cli/utils/gpu.py
@@ -9,7 +9,7 @@ from dstack._internal.core.models.runs import Requirements, RunSpec, get_policy_
 from dstack._internal.server.schemas.gpus import GpuGroup


-def print_gpu_json(gpu_response, run_spec, group_by_cli, api_project):
+def print_gpu_json(gpus, run_spec, group_by_cli, api_project):
     """Print GPU information in JSON format."""
     req = Requirements(
         resources=run_spec.configuration.resources,
@@ -36,7 +36,7 @@ def print_gpu_json(gpu_response, run_spec, group_by_cli, api_project):
         "gpus": [],
     }

-    for gpu_group in gpu_response.gpus:
+    for gpu_group in gpus:
         gpu_data = {
             "name": gpu_group.name,
             "memory_mib": gpu_group.memory_mib,
dstack/_internal/core/backends/aws/compute.py
@@ -1,6 +1,6 @@
 import threading
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import boto3
 import botocore.client
@@ -18,6 +18,7 @@ from dstack._internal.core.backends.aws.models import (
 )
 from dstack._internal.core.backends.base.compute import (
     Compute,
+    ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
     ComputeWithGatewaySupport,
     ComputeWithMultinodeSupport,
@@ -32,7 +33,7 @@ from dstack._internal.core.backends.base.compute import (
     get_user_data,
     merge_tags,
 )
-from dstack._internal.core.backends.base.offers import get_catalog_offers
+from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
 from dstack._internal.core.errors import (
     ComputeError,
     NoCapacityError,
@@ -87,6 +88,7 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):


 class AWSCompute(
+    ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
     ComputeWithMultinodeSupport,
     ComputeWithReservationSupport,
@@ -109,6 +111,8 @@ class AWSCompute(
         # Caches to avoid redundant API calls when provisioning many instances
         # get_offers is already cached but we still cache its sub-functions
         # with more aggressive/longer caches.
+        self._offers_post_filter_cache_lock = threading.Lock()
+        self._offers_post_filter_cache = TTLCache(maxsize=10, ttl=180)
         self._get_regions_to_quotas_cache_lock = threading.Lock()
         self._get_regions_to_quotas_execution_lock = threading.Lock()
         self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300)
@@ -125,43 +129,11 @@ class AWSCompute(
         self._get_image_id_and_username_cache_lock = threading.Lock()
         self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600)

-    def get_offers(
-        self, requirements: Optional[Requirements] = None
-    ) -> List[InstanceOfferWithAvailability]:
-        filter = _supported_instances
-        if requirements and requirements.reservation:
-            region_to_reservation = {}
-            for region in self.config.regions:
-                reservation = aws_resources.get_reservation(
-                    ec2_client=self.session.client("ec2", region_name=region),
-                    reservation_id=requirements.reservation,
-                    instance_count=1,
-                )
-                if reservation is not None:
-                    region_to_reservation[region] = reservation
-
-            def _supported_instances_with_reservation(offer: InstanceOffer) -> bool:
-                # Filter: only instance types supported by dstack
-                if not _supported_instances(offer):
-                    return False
-                # Filter: Spot instances can't be used with reservations
-                if offer.instance.resources.spot:
-                    return False
-                region = offer.region
-                reservation = region_to_reservation.get(region)
-                # Filter: only instance types matching the capacity reservation
-                if not bool(reservation and offer.instance.name == reservation["InstanceType"]):
-                    return False
-                return True
-
-            filter = _supported_instances_with_reservation
-
+    def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
         offers = get_catalog_offers(
             backend=BackendType.AWS,
             locations=self.config.regions,
-            requirements=requirements,
-            configurable_disk_size=CONFIGURABLE_DISK_SIZE,
-            extra_filter=filter,
+            extra_filter=_supported_instances,
         )
         regions = list(set(i.region for i in offers))
         with self._get_regions_to_quotas_execution_lock:
@@ -185,6 +157,49 @@ class AWSCompute(
         )
         return availability_offers

+    def get_offers_modifier(
+        self, requirements: Requirements
+    ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
+        return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
+
+    def _get_offers_cached_key(self, requirements: Requirements) -> int:
+        # Requirements is not hashable, so we use a hack to get arguments hash
+        return hash(requirements.json())
+
+    @cachedmethod(
+        cache=lambda self: self._offers_post_filter_cache,
+        key=_get_offers_cached_key,
+        lock=lambda self: self._offers_post_filter_cache_lock,
+    )
+    def get_offers_post_filter(
+        self, requirements: Requirements
+    ) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]:
+        if requirements.reservation:
+            region_to_reservation = {}
+            for region in get_or_error(self.config.regions):
+                reservation = aws_resources.get_reservation(
+                    ec2_client=self.session.client("ec2", region_name=region),
+                    reservation_id=requirements.reservation,
+                    instance_count=1,
+                )
+                if reservation is not None:
+                    region_to_reservation[region] = reservation
+
+            def reservation_filter(offer: InstanceOfferWithAvailability) -> bool:
+                # Filter: Spot instances can't be used with reservations
+                if offer.instance.resources.spot:
+                    return False
+                region = offer.region
+                reservation = region_to_reservation.get(region)
+                # Filter: only instance types matching the capacity reservation
+                if not bool(reservation and offer.instance.name == reservation["InstanceType"]):
+                    return False
+                return True
+
+            return reservation_filter
+
+        return None
+
     def terminate_instance(
         self, instance_id: str, region: str, backend_data: Optional[str] = None
     ) -> None:
@@ -276,7 +291,11 @@ class AWSCompute(
         image_id, username = self._get_image_id_and_username(
             ec2_client=ec2_client,
             region=instance_offer.region,
-            cuda=len(instance_offer.instance.resources.gpus) > 0,
+            gpu_name=(
+                instance_offer.instance.resources.gpus[0].name
+                if len(instance_offer.instance.resources.gpus) > 0
+                else None
+            ),
             instance_type=instance_offer.instance.name,
             image_config=self.config.os_images,
         )
@@ -882,11 +901,13 @@ class AWSCompute(
         self,
         ec2_client: botocore.client.BaseClient,
         region: str,
-        cuda: bool,
+        gpu_name: Optional[str],
         instance_type: str,
         image_config: Optional[AWSOSImageConfig] = None,
     ) -> tuple:
-        return hashkey(region, cuda, instance_type, image_config.json() if image_config else None)
+        return hashkey(
+            region, gpu_name, instance_type, image_config.json() if image_config else None
+        )

     @cachedmethod(
         cache=lambda self: self._get_image_id_and_username_cache,
@@ -897,13 +918,13 @@ class AWSCompute(
         self,
         ec2_client: botocore.client.BaseClient,
         region: str,
-        cuda: bool,
+        gpu_name: Optional[str],
         instance_type: str,
         image_config: Optional[AWSOSImageConfig] = None,
     ) -> tuple[str, str]:
         return aws_resources.get_image_id_and_username(
             ec2_client=ec2_client,
-            cuda=cuda,
+            gpu_name=gpu_name,
             instance_type=instance_type,
             image_config=image_config,
         )
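
The `_offers_post_filter_cache` above (like the offers caches in base/compute.py further down) relies on one trick: cachetools' @cachedmethod needs a hashable key, pydantic models such as Requirements are not hashable, so the key function hashes the model's JSON serialization instead. A minimal sketch of the pattern; `Req`, `Backend`, and `expensive_lookup` are hypothetical stand-ins:

import threading
from typing import Optional

from cachetools import TTLCache, cachedmethod
from pydantic import BaseModel


class Req(BaseModel):  # stand-in for dstack's Requirements model
    reservation: Optional[str] = None


class Backend:
    def __init__(self) -> None:
        self._cache = TTLCache(maxsize=10, ttl=180)
        self._lock = threading.Lock()

    def _key(self, req: Req) -> int:
        # Models aren't hashable, so hash their JSON serialization instead.
        return hash(req.json())

    @cachedmethod(cache=lambda self: self._cache, key=_key, lock=lambda self: self._lock)
    def expensive_lookup(self, req: Req) -> str:
        # Runs once per distinct Req within the TTL; cached thereafter.
        return f"computed for {req.reservation}"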
dstack/_internal/core/backends/aws/resources.py
@@ -6,6 +6,8 @@ import botocore.exceptions

 import dstack.version as version
 from dstack._internal.core.backends.aws.models import AWSOSImageConfig
+from dstack._internal.core.backends.base.compute import requires_nvidia_proprietary_kernel_modules
+from dstack._internal.core.consts import DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES
 from dstack._internal.core.errors import BackendError, ComputeError, ComputeResourceNotFoundError
 from dstack._internal.utils.logging import get_logger

@@ -17,14 +19,14 @@ DLAMI_OWNER_ACCOUNT_ID = "898082745236"

 def get_image_id_and_username(
     ec2_client: botocore.client.BaseClient,
-    cuda: bool,
+    gpu_name: Optional[str],
     instance_type: str,
     image_config: Optional[AWSOSImageConfig] = None,
 ) -> tuple[str, str]:
     if image_config is not None:
-        image = image_config.nvidia if cuda else image_config.cpu
+        image = image_config.nvidia if gpu_name else image_config.cpu
         if image is None:
-            logger.warning("%s image not configured", "nvidia" if cuda else "cpu")
+            logger.warning("%s image not configured", "nvidia" if gpu_name else "cpu")
             raise ComputeResourceNotFoundError()
         image_name = image.name
         image_owner = image.owner
@@ -35,9 +37,12 @@ def get_image_id_and_username(
         image_owner = DLAMI_OWNER_ACCOUNT_ID
         username = "ubuntu"
     else:
-        image_name = (
-            f"dstack-{version.base_image}" if not cuda else f"dstack-cuda-{version.base_image}"
-        )
+        if gpu_name is None:
+            image_name = f"dstack-{version.base_image}"
+        elif not requires_nvidia_proprietary_kernel_modules(gpu_name):
+            image_name = f"dstack-cuda-{version.base_image}"
+        else:
+            image_name = f"dstack-cuda-{DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES}"
         image_owner = DSTACK_ACCOUNT_ID
         username = "ubuntu"
     response = ec2_client.describe_images(
dstack/_internal/core/backends/azure/compute.py
@@ -2,7 +2,7 @@ import base64
 import enum
 import re
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 from azure.core.credentials import TokenCredential
 from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
@@ -39,6 +39,7 @@ from dstack._internal.core.backends.azure import utils as azure_utils
 from dstack._internal.core.backends.azure.models import AzureConfig
 from dstack._internal.core.backends.base.compute import (
     Compute,
+    ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
     ComputeWithGatewaySupport,
     ComputeWithMultinodeSupport,
@@ -47,8 +48,10 @@ from dstack._internal.core.backends.base.compute import (
     get_gateway_user_data,
     get_user_data,
     merge_tags,
+    requires_nvidia_proprietary_kernel_modules,
 )
-from dstack._internal.core.backends.base.offers import get_catalog_offers
+from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
+from dstack._internal.core.consts import DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES
 from dstack._internal.core.errors import ComputeError, NoCapacityError
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.gateways import (
@@ -73,6 +76,7 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("30GB"), max=Memory.pars


 class AzureCompute(
+    ComputeWithAllOffersCached,
     ComputeWithCreateInstanceSupport,
     ComputeWithMultinodeSupport,
     ComputeWithGatewaySupport,
@@ -89,14 +93,10 @@ class AzureCompute(
             credential=credential, subscription_id=config.subscription_id
         )

-    def get_offers(
-        self, requirements: Optional[Requirements] = None
-    ) -> List[InstanceOfferWithAvailability]:
+    def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
         offers = get_catalog_offers(
             backend=BackendType.AZURE,
             locations=self.config.regions,
-            requirements=requirements,
-            configurable_disk_size=CONFIGURABLE_DISK_SIZE,
             extra_filter=_supported_instances,
         )
         offers_with_availability = _get_offers_with_availability(
@@ -106,6 +106,11 @@ class AzureCompute(
         )
         return offers_with_availability

+    def get_offers_modifier(
+        self, requirements: Requirements
+    ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
+        return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
+
     def create_instance(
         self,
         instance_offer: InstanceOfferWithAvailability,
@@ -369,6 +374,7 @@ def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]:
 class VMImageVariant(enum.Enum):
     GRID = enum.auto()
     CUDA = enum.auto()
+    CUDA_WITH_PROPRIETARY_KERNEL_MODULES = enum.auto()
     STANDARD = enum.auto()

     @classmethod
@@ -376,18 +382,24 @@ class VMImageVariant(enum.Enum):
         if "_A10_v5" in instance.name:
             return cls.GRID
         elif len(instance.resources.gpus) > 0:
-            return cls.CUDA
+            if not requires_nvidia_proprietary_kernel_modules(instance.resources.gpus[0].name):
+                return cls.CUDA
+            else:
+                return cls.CUDA_WITH_PROPRIETARY_KERNEL_MODULES
         else:
             return cls.STANDARD

     def get_image_name(self) -> str:
-        name = "dstack-"
         if self is self.GRID:
-            name += "grid-"
+            return f"dstack-grid-{version.base_image}"
         elif self is self.CUDA:
-            name += "cuda-"
-        name += version.base_image
-        return name
+            return f"dstack-cuda-{version.base_image}"
+        elif self is self.CUDA_WITH_PROPRIETARY_KERNEL_MODULES:
+            return f"dstack-cuda-{DSTACK_OS_IMAGE_WITH_PROPRIETARY_NVIDIA_KERNEL_MODULES}"
+        elif self is self.STANDARD:
+            return f"dstack-{version.base_image}"
+        else:
+            raise ValueError(f"Unexpected image variant {self!r}")


 _SUPPORTED_VM_SERIES_PATTERNS = [
dstack/_internal/core/backends/base/compute.py
@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from functools import lru_cache
 from pathlib import Path
-from typing import Dict, List, Literal, Optional
+from typing import Callable, Dict, List, Literal, Optional

 import git
 import requests
@@ -15,6 +15,7 @@ import yaml
 from cachetools import TTLCache, cachedmethod

 from dstack._internal import settings
+from dstack._internal.core.backends.base.offers import filter_offers_by_requirements
 from dstack._internal.core.consts import (
     DSTACK_RUNNER_HTTP_PORT,
     DSTACK_RUNNER_SSH_PORT,
@@ -47,6 +48,22 @@ logger = get_logger(__name__)
 DSTACK_SHIM_BINARY_NAME = "dstack-shim"
 DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
 DEFAULT_PRIVATE_SUBNETS = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")
+NVIDIA_GPUS_REQUIRING_PROPRIETARY_KERNEL_MODULES = frozenset(
+    # All NVIDIA architectures prior to Turing do not support Open Kernel Modules and require
+    # proprietary modules. This list is incomplete, update when necessary.
+    [
+        "v100",
+        "p100",
+        "p40",
+        "p4",
+        "m60",
+        "m40",
+        "m4",
+        "k80",
+        "k40",
+        "k20",
+    ]
+)

 GoArchType = Literal["amd64", "arm64"]

@@ -57,14 +74,8 @@ class Compute(ABC):
     If a compute supports additional features, it must also subclass `ComputeWith*` classes.
     """

-    def __init__(self):
-        self._offers_cache_lock = threading.Lock()
-        self._offers_cache = TTLCache(maxsize=10, ttl=180)
-
     @abstractmethod
-    def get_offers(
-        self, requirements: Optional[Requirements] = None
-    ) -> List[InstanceOfferWithAvailability]:
+    def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
         """
         Returns offers with availability matching `requirements`.
         If the provider is added to gpuhunt, typically gets offers using `base.offers.get_catalog_offers()`
@@ -121,10 +132,97 @@ class Compute(ABC):
         """
         pass

-    def _get_offers_cached_key(self, requirements: Optional[Requirements] = None) -> int:
+
+class ComputeWithAllOffersCached(ABC):
+    """
+    Provides common `get_offers()` implementation for backends
+    whose offers do not depend on requirements.
+    It caches all offers with availability and post-filters by requirements.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._offers_cache_lock = threading.Lock()
+        self._offers_cache = TTLCache(maxsize=1, ttl=180)
+
+    @abstractmethod
+    def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
+        """
+        Returns all backend offers with availability.
+        """
+        pass
+
+    def get_offers_modifier(
+        self, requirements: Requirements
+    ) -> Optional[
+        Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]
+    ]:
+        """
+        Returns a modifier function that modifies offers before they are filtered by requirements.
+        Can return `None` to exclude the offer.
+        E.g. can be used to set appropriate disk size based on requirements.
+        """
+        return None
+
+    def get_offers_post_filter(
+        self, requirements: Requirements
+    ) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]:
+        """
+        Returns a filter function to apply to offers based on requirements.
+        This allows backends to implement custom post-filtering logic for specific requirements.
+        """
+        return None
+
+    def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
+        offers = self._get_all_offers_with_availability_cached()
+        modifier = self.get_offers_modifier(requirements)
+        if modifier is not None:
+            modified_offers = []
+            for o in offers:
+                modified_offer = modifier(o)
+                if modified_offer is not None:
+                    modified_offers.append(modified_offer)
+            offers = modified_offers
+        offers = filter_offers_by_requirements(offers, requirements)
+        post_filter = self.get_offers_post_filter(requirements)
+        if post_filter is not None:
+            offers = [o for o in offers if post_filter(o)]
+        return offers
+
+    @cachedmethod(
+        cache=lambda self: self._offers_cache,
+        lock=lambda self: self._offers_cache_lock,
+    )
+    def _get_all_offers_with_availability_cached(self) -> List[InstanceOfferWithAvailability]:
+        return self.get_all_offers_with_availability()
+
+
+class ComputeWithFilteredOffersCached(ABC):
+    """
+    Provides common `get_offers()` implementation for backends
+    whose offers depend on requirements.
+    It caches offers using requirements as key.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._offers_cache_lock = threading.Lock()
+        self._offers_cache = TTLCache(maxsize=10, ttl=180)
+
+    @abstractmethod
+    def get_offers_by_requirements(
+        self, requirements: Requirements
+    ) -> List[InstanceOfferWithAvailability]:
+        """
+        Returns backend offers with availability matching requirements.
+        """
+        pass
+
+    def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
+        return self._get_offers_cached(requirements)
+
+    def _get_offers_cached_key(self, requirements: Requirements) -> int:
         # Requirements is not hashable, so we use a hack to get arguments hash
-        if requirements is None:
-            return hash(None)
         return hash(requirements.json())

     @cachedmethod(
@@ -132,10 +230,10 @@
         key=_get_offers_cached_key,
         lock=lambda self: self._offers_cache_lock,
     )
-    def get_offers_cached(
-        self, requirements: Optional[Requirements] = None
+    def _get_offers_cached(
+        self, requirements: Requirements
     ) -> List[InstanceOfferWithAvailability]:
-        return self.get_offers(requirements)
+        return self.get_offers_by_requirements(requirements)


 class ComputeWithCreateInstanceSupport(ABC):
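
Taken together, the two new mixins split backends into those that can list all offers cheaply and post-filter per request (ComputeWithAllOffersCached) and those that must query per requirements (ComputeWithFilteredOffersCached). A hedged sketch of how a backend adopts the first mixin; `ExampleCompute` and its empty offer list are hypothetical, while the hook names are the ones defined above:

from typing import List

from dstack._internal.core.backends.base.compute import (
    Compute,
    ComputeWithAllOffersCached,
)
from dstack._internal.core.models.instances import InstanceOfferWithAvailability


class ExampleCompute(ComputeWithAllOffersCached, Compute):
    """Hypothetical backend sketch; only the offer-listing hook is shown."""

    def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
        # Called at most once per cache TTL (180s). The mixin's get_offers()
        # then applies get_offers_modifier(), filter_offers_by_requirements(),
        # and get_offers_post_filter() to this cached list on every request.
        return []  # a real backend would query its provider API here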
@@ -887,3 +985,12 @@ def merge_tags(
     for k, v in resource_tags.items():
         res.setdefault(k, v)
     return res
+
+
+def requires_nvidia_proprietary_kernel_modules(gpu_name: str) -> bool:
+    """
+    Returns:
+        Whether this NVIDIA GPU requires NVIDIA proprietary kernel modules
+        instead of open kernel modules.
+    """
+    return gpu_name.lower() in NVIDIA_GPUS_REQUIRING_PROPRIETARY_KERNEL_MODULES
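
A quick illustration of the helper's contract, using GPU names from the frozenset added above (the check lowercases its input, so it is case-insensitive):

from dstack._internal.core.backends.base.compute import (
    requires_nvidia_proprietary_kernel_modules,
)

# Pre-Turing GPUs are in the frozenset; newer architectures fall through to False.
assert requires_nvidia_proprietary_kernel_modules("V100") is True
assert requires_nvidia_proprietary_kernel_modules("v100") is True
assert requires_nvidia_proprietary_kernel_modules("A100") is False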
dstack/_internal/core/backends/base/offers.py
@@ -1,5 +1,5 @@
 from dataclasses import asdict
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, TypeVar

 import gpuhunt
 from pydantic import parse_obj_as
@@ -9,11 +9,13 @@ from dstack._internal.core.models.instances import (
     Disk,
     Gpu,
     InstanceOffer,
+    InstanceOfferWithAvailability,
     InstanceType,
     Resources,
 )
 from dstack._internal.core.models.resources import DEFAULT_DISK, CPUSpec, Memory, Range
 from dstack._internal.core.models.runs import Requirements
+from dstack._internal.utils.common import get_or_error

 # Offers not supported by all dstack versions are hidden behind one or more flags.
 # This list enables the flags that are currently supported.
@@ -163,9 +165,13 @@ def requirements_to_query_filter(req: Optional[Requirements]) -> gpuhunt.QueryFi
     return q


-def match_requirements(
-    offers: List[InstanceOffer], requirements: Optional[Requirements]
-) -> List[InstanceOffer]:
+InstanceOfferT = TypeVar("InstanceOfferT", InstanceOffer, InstanceOfferWithAvailability)
+
+
+def filter_offers_by_requirements(
+    offers: List[InstanceOfferT],
+    requirements: Optional[Requirements],
+) -> List[InstanceOfferT]:
     query_filter = requirements_to_query_filter(requirements)
     filtered_offers = []
     for offer in offers:
@@ -190,3 +196,27 @@ def choose_disk_size_mib(
         disk_size_gib = disk_size_range.min

     return round(disk_size_gib * 1024)
+
+
+def get_offers_disk_modifier(
+    configurable_disk_size: Range[Memory], requirements: Requirements
+) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
+    """
+    Returns a func that modifies offers disk by setting min value that satisfies both
+    `configurable_disk_size` and `requirements`.
+    """
+
+    def modifier(offer: InstanceOfferWithAvailability) -> Optional[InstanceOfferWithAvailability]:
+        requirements_disk_range = DEFAULT_DISK.size
+        if requirements.resources.disk is not None:
+            requirements_disk_range = requirements.resources.disk.size
+        disk_size_range = requirements_disk_range.intersect(configurable_disk_size)
+        if disk_size_range is None:
+            return None
+        offer_copy = offer.copy(deep=True)
+        offer_copy.instance.resources.disk = Disk(
+            size_mib=get_or_error(disk_size_range.min) * 1024
+        )
+        return offer_copy
+
+    return modifier
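
To make the modifier's arithmetic concrete: it intersects the run's requested disk range with the backend's configurable range, then pins the offer's disk to the intersection's minimum, converted to MiB. An illustrative sketch under assumed values (a hypothetical 30GB-2000GB backend range and a run requesting at least 100GB):

from dstack._internal.core.models.resources import Memory, Range

configurable = Range[Memory](min=Memory.parse("30GB"), max=Memory.parse("2000GB"))
requested = Range[Memory](min=Memory.parse("100GB"), max=None)

# Same logic as get_offers_disk_modifier(): take the intersection, keep its min.
chosen = requested.intersect(configurable)
assert chosen is not None
assert chosen.min == Memory.parse("100GB")

size_mib = int(chosen.min * 1024)  # as in Disk(size_mib=disk_size_range.min * 1024)
assert size_mib == 100 * 1024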