dstack-0.19.27-py3-none-any.whl → dstack-0.19.29-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. See the package's registry page for more details.

Files changed (74)
  1. dstack/_internal/cli/commands/__init__.py +11 -8
  2. dstack/_internal/cli/commands/apply.py +6 -3
  3. dstack/_internal/cli/commands/completion.py +3 -1
  4. dstack/_internal/cli/commands/config.py +1 -0
  5. dstack/_internal/cli/commands/init.py +2 -2
  6. dstack/_internal/cli/commands/offer.py +1 -1
  7. dstack/_internal/cli/commands/project.py +1 -0
  8. dstack/_internal/cli/commands/server.py +2 -2
  9. dstack/_internal/cli/main.py +1 -1
  10. dstack/_internal/cli/services/configurators/base.py +2 -4
  11. dstack/_internal/cli/services/configurators/fleet.py +4 -5
  12. dstack/_internal/cli/services/configurators/gateway.py +3 -5
  13. dstack/_internal/cli/services/configurators/run.py +51 -27
  14. dstack/_internal/cli/services/configurators/volume.py +3 -5
  15. dstack/_internal/core/backends/aws/compute.py +51 -36
  16. dstack/_internal/core/backends/azure/compute.py +10 -7
  17. dstack/_internal/core/backends/base/compute.py +96 -14
  18. dstack/_internal/core/backends/base/offers.py +34 -4
  19. dstack/_internal/core/backends/cloudrift/compute.py +5 -7
  20. dstack/_internal/core/backends/cudo/compute.py +4 -2
  21. dstack/_internal/core/backends/datacrunch/compute.py +13 -11
  22. dstack/_internal/core/backends/digitalocean_base/compute.py +4 -5
  23. dstack/_internal/core/backends/gcp/compute.py +12 -7
  24. dstack/_internal/core/backends/hotaisle/compute.py +4 -7
  25. dstack/_internal/core/backends/kubernetes/compute.py +6 -4
  26. dstack/_internal/core/backends/lambdalabs/compute.py +4 -5
  27. dstack/_internal/core/backends/local/compute.py +1 -3
  28. dstack/_internal/core/backends/nebius/compute.py +10 -7
  29. dstack/_internal/core/backends/oci/compute.py +10 -7
  30. dstack/_internal/core/backends/runpod/compute.py +15 -6
  31. dstack/_internal/core/backends/template/compute.py.jinja +3 -1
  32. dstack/_internal/core/backends/tensordock/compute.py +1 -3
  33. dstack/_internal/core/backends/tensordock/models.py +2 -0
  34. dstack/_internal/core/backends/vastai/compute.py +7 -3
  35. dstack/_internal/core/backends/vultr/compute.py +5 -5
  36. dstack/_internal/core/compatibility/runs.py +2 -0
  37. dstack/_internal/core/models/common.py +67 -43
  38. dstack/_internal/core/models/configurations.py +88 -62
  39. dstack/_internal/core/models/fleets.py +41 -24
  40. dstack/_internal/core/models/instances.py +5 -5
  41. dstack/_internal/core/models/profiles.py +66 -47
  42. dstack/_internal/core/models/projects.py +8 -0
  43. dstack/_internal/core/models/repos/remote.py +21 -16
  44. dstack/_internal/core/models/resources.py +69 -65
  45. dstack/_internal/core/models/runs.py +17 -9
  46. dstack/_internal/server/app.py +5 -0
  47. dstack/_internal/server/background/tasks/process_fleets.py +8 -0
  48. dstack/_internal/server/background/tasks/process_instances.py +3 -2
  49. dstack/_internal/server/background/tasks/process_submitted_jobs.py +97 -34
  50. dstack/_internal/server/models.py +6 -5
  51. dstack/_internal/server/schemas/gateways.py +10 -9
  52. dstack/_internal/server/services/backends/__init__.py +1 -1
  53. dstack/_internal/server/services/backends/handlers.py +2 -0
  54. dstack/_internal/server/services/docker.py +8 -7
  55. dstack/_internal/server/services/projects.py +63 -4
  56. dstack/_internal/server/services/runs.py +2 -0
  57. dstack/_internal/server/settings.py +46 -0
  58. dstack/_internal/server/statics/index.html +1 -1
  59. dstack/_internal/server/statics/main-56191fbfe77f49b251de.css +3 -0
  60. dstack/_internal/server/statics/{main-4eecc75fbe64067eb1bc.js → main-c51afa7f243e24d3e446.js} +61115 -49101
  61. dstack/_internal/server/statics/{main-4eecc75fbe64067eb1bc.js.map → main-c51afa7f243e24d3e446.js.map} +1 -1
  62. dstack/_internal/utils/env.py +85 -11
  63. dstack/version.py +1 -1
  64. {dstack-0.19.27.dist-info → dstack-0.19.29.dist-info}/METADATA +1 -1
  65. {dstack-0.19.27.dist-info → dstack-0.19.29.dist-info}/RECORD +68 -73
  66. dstack/_internal/core/backends/tensordock/__init__.py +0 -0
  67. dstack/_internal/core/backends/tensordock/api_client.py +0 -104
  68. dstack/_internal/core/backends/tensordock/backend.py +0 -16
  69. dstack/_internal/core/backends/tensordock/configurator.py +0 -74
  70. dstack/_internal/server/statics/main-56191c63d516fd0041c4.css +0 -3
  71. dstack/_internal/server/statics/static/media/github.1f7102513534c83a9d8d735d2b8c12a2.svg +0 -3
  72. {dstack-0.19.27.dist-info → dstack-0.19.29.dist-info}/WHEEL +0 -0
  73. {dstack-0.19.27.dist-info → dstack-0.19.29.dist-info}/entry_points.txt +0 -0
  74. {dstack-0.19.27.dist-info → dstack-0.19.29.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,20 +1,22 @@
1
1
  import argparse
2
2
  import os
3
+ import shlex
3
4
  from abc import ABC, abstractmethod
4
- from typing import List, Optional
5
+ from typing import ClassVar, Optional
5
6
 
6
7
  from rich_argparse import RichHelpFormatter
7
8
 
8
9
  from dstack._internal.cli.services.completion import ProjectNameCompleter
9
- from dstack._internal.cli.utils.common import configure_logging
10
+ from dstack._internal.core.errors import CLIError
10
11
  from dstack.api import Client
11
12
 
12
13
 
13
14
  class BaseCommand(ABC):
14
- NAME: str = "name the command"
15
- DESCRIPTION: str = "describe the command"
16
- DEFAULT_HELP: bool = True
17
- ALIASES: Optional[List[str]] = None
15
+ NAME: ClassVar[str] = "name the command"
16
+ DESCRIPTION: ClassVar[str] = "describe the command"
17
+ DEFAULT_HELP: ClassVar[bool] = True
18
+ ALIASES: ClassVar[Optional[list[str]]] = None
19
+ ACCEPT_EXTRA_ARGS: ClassVar[bool] = False
18
20
 
19
21
  def __init__(self, parser: argparse.ArgumentParser):
20
22
  self._parser = parser
@@ -50,7 +52,8 @@ class BaseCommand(ABC):
50
52
 
51
53
  @abstractmethod
52
54
  def _command(self, args: argparse.Namespace):
53
- pass
55
+ if not self.ACCEPT_EXTRA_ARGS and args.extra_args:
56
+ raise CLIError(f"Unrecognized arguments: {shlex.join(args.extra_args)}")
54
57
 
55
58
 
56
59
  class APIBaseCommand(BaseCommand):
@@ -65,5 +68,5 @@ class APIBaseCommand(BaseCommand):
65
68
  ).completer = ProjectNameCompleter() # type: ignore[attr-defined]
66
69
 
67
70
  def _command(self, args: argparse.Namespace):
68
- configure_logging()
71
+ super()._command(args)
69
72
  self.api = Client.from_config(project_name=args.project)
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ import shlex
2
3
 
3
4
  from argcomplete import FilesCompleter # type: ignore[attr-defined]
4
5
 
@@ -19,6 +20,7 @@ class ApplyCommand(APIBaseCommand):
19
20
  NAME = "apply"
20
21
  DESCRIPTION = "Apply a configuration"
21
22
  DEFAULT_HELP = False
23
+ ACCEPT_EXTRA_ARGS = True
22
24
 
23
25
  def _register(self):
24
26
  super()._register()
@@ -84,13 +86,14 @@ class ApplyCommand(APIBaseCommand):
84
86
  configurator_class = get_apply_configurator_class(configuration.type)
85
87
  configurator = configurator_class(api_client=self.api)
86
88
  configurator_parser = configurator.get_parser()
87
- known, unknown = configurator_parser.parse_known_args(args.unknown)
89
+ configurator_args, unknown_args = configurator_parser.parse_known_args(args.extra_args)
90
+ if unknown_args:
91
+ raise CLIError(f"Unrecognized arguments: {shlex.join(unknown_args)}")
88
92
  configurator.apply_configuration(
89
93
  conf=configuration,
90
94
  configuration_path=configuration_path,
91
95
  command_args=args,
92
- configurator_args=known,
93
- unknown_args=unknown,
96
+ configurator_args=configurator_args,
94
97
  )
95
98
  except KeyboardInterrupt:
96
99
  console.print("\nOperation interrupted by user. Exiting...")
@@ -1,3 +1,5 @@
1
+ import argparse
2
+
1
3
  import argcomplete
2
4
 
3
5
  from dstack._internal.cli.commands import BaseCommand
@@ -15,6 +17,6 @@ class CompletionCommand(BaseCommand):
15
17
  choices=["bash", "zsh"],
16
18
  )
17
19
 
18
- def _command(self, args):
20
+ def _command(self, args: argparse.Namespace):
19
21
  super()._command(args)
20
22
  print(argcomplete.shellcode(["dstack"], shell=args.shell)) # type: ignore[attr-defined]
@@ -40,6 +40,7 @@ class ConfigCommand(BaseCommand):
40
40
  )
41
41
 
42
42
  def _command(self, args: argparse.Namespace):
43
+ super()._command(args)
43
44
  config_manager = ConfigManager()
44
45
  if args.remove:
45
46
  config_manager.delete_project(args.project)
@@ -9,7 +9,7 @@ from dstack._internal.cli.services.repos import (
9
9
  is_git_repo_url,
10
10
  register_init_repo_args,
11
11
  )
12
- from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn
12
+ from dstack._internal.cli.utils.common import confirm_ask, console, warn
13
13
  from dstack._internal.core.errors import ConfigurationError
14
14
  from dstack._internal.core.models.repos.remote import RemoteRepo
15
15
  from dstack._internal.core.services.configs import ConfigManager
@@ -52,7 +52,7 @@ class InitCommand(BaseCommand):
52
52
  )
53
53
 
54
54
  def _command(self, args: argparse.Namespace):
55
- configure_logging()
55
+ super()._command(args)
56
56
 
57
57
  repo_path: Optional[Path] = None
58
58
  repo_url: Optional[str] = None
@@ -99,7 +99,7 @@ class OfferCommand(APIBaseCommand):
99
99
  conf = TaskConfiguration(commands=[":"])
100
100
 
101
101
  configurator = OfferConfigurator(api_client=self.api)
102
- configurator.apply_args(conf, args, [])
102
+ configurator.apply_args(conf, args)
103
103
  profile = load_profile(Path.cwd(), profile_name=args.profile)
104
104
 
105
105
  run_spec = RunSpec(
@@ -67,6 +67,7 @@ class ProjectCommand(BaseCommand):
67
67
  set_default_parser.set_defaults(subfunc=self._set_default)
68
68
 
69
69
  def _command(self, args: argparse.Namespace):
70
+ super()._command(args)
70
71
  if not hasattr(args, "subfunc"):
71
72
  args.subfunc = self._list
72
73
  args.subfunc(args)
@@ -1,5 +1,5 @@
1
+ import argparse
1
2
  import os
2
- from argparse import Namespace
3
3
 
4
4
  from dstack._internal import settings
5
5
  from dstack._internal.cli.commands import BaseCommand
@@ -53,7 +53,7 @@ class ServerCommand(BaseCommand):
53
53
  )
54
54
  self._parser.add_argument("--token", type=str, help="The admin user token")
55
55
 
56
- def _command(self, args: Namespace):
56
+ def _command(self, args: argparse.Namespace):
57
57
  super()._command(args)
58
58
 
59
59
  if not UVICORN_INSTALLED:
@@ -83,7 +83,7 @@ def main():
83
83
  argcomplete.autocomplete(parser, always_complete_options=False)
84
84
 
85
85
  args, unknown_args = parser.parse_known_args()
86
- args.unknown = unknown_args
86
+ args.extra_args = unknown_args
87
87
 
88
88
  try:
89
89
  check_for_updates()
@@ -1,7 +1,7 @@
1
1
  import argparse
2
2
  import os
3
3
  from abc import ABC, abstractmethod
4
- from typing import Generic, List, TypeVar, Union, cast
4
+ from typing import ClassVar, Generic, List, TypeVar, Union, cast
5
5
 
6
6
  from dstack._internal.cli.services.args import env_var
7
7
  from dstack._internal.core.errors import ConfigurationError
@@ -18,7 +18,7 @@ ApplyConfigurationT = TypeVar("ApplyConfigurationT", bound=AnyApplyConfiguration
18
18
 
19
19
 
20
20
  class BaseApplyConfigurator(ABC, Generic[ApplyConfigurationT]):
21
- TYPE: ApplyConfigurationType
21
+ TYPE: ClassVar[ApplyConfigurationType]
22
22
 
23
23
  def __init__(self, api_client: Client):
24
24
  self.api = api_client
@@ -30,7 +30,6 @@ class BaseApplyConfigurator(ABC, Generic[ApplyConfigurationT]):
30
30
  configuration_path: str,
31
31
  command_args: argparse.Namespace,
32
32
  configurator_args: argparse.Namespace,
33
- unknown_args: List[str],
34
33
  ):
35
34
  """
36
35
  Implements `dstack apply` for a given configuration type.
@@ -40,7 +39,6 @@ class BaseApplyConfigurator(ABC, Generic[ApplyConfigurationT]):
40
39
  configuration_path: The path to the configuration file.
41
40
  command_args: The args parsed by `dstack apply`.
42
41
  configurator_args: The known args parsed by `cls.get_parser()`.
43
- unknown_args: The unknown args after parsing by `cls.get_parser()`.
44
42
  """
45
43
  pass
46
44
 
@@ -1,7 +1,7 @@
1
1
  import argparse
2
2
  import time
3
3
  from pathlib import Path
4
- from typing import List, Optional
4
+ from typing import Optional
5
5
 
6
6
  from rich.table import Table
7
7
 
@@ -46,7 +46,7 @@ logger = get_logger(__name__)
46
46
 
47
47
 
48
48
  class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[FleetConfiguration]):
49
- TYPE: ApplyConfigurationType = ApplyConfigurationType.FLEET
49
+ TYPE = ApplyConfigurationType.FLEET
50
50
 
51
51
  def apply_configuration(
52
52
  self,
@@ -54,9 +54,8 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[Fle
54
54
  configuration_path: str,
55
55
  command_args: argparse.Namespace,
56
56
  configurator_args: argparse.Namespace,
57
- unknown_args: List[str],
58
57
  ):
59
- self.apply_args(conf, configurator_args, unknown_args)
58
+ self.apply_args(conf, configurator_args)
60
59
  profile = load_profile(Path.cwd(), None)
61
60
  spec = FleetSpec(
62
61
  configuration=conf,
@@ -309,7 +308,7 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[Fle
309
308
  )
310
309
  cls.register_env_args(configuration_group)
311
310
 
312
- def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace, unknown: List[str]):
311
+ def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace):
313
312
  if args.name:
314
313
  conf.name = args.name
315
314
  self.apply_env_vars(conf.env, args)
@@ -1,6 +1,5 @@
1
1
  import argparse
2
2
  import time
3
- from typing import List
4
3
 
5
4
  from rich.table import Table
6
5
 
@@ -27,7 +26,7 @@ from dstack.api._public import Client
27
26
 
28
27
 
29
28
  class GatewayConfigurator(BaseApplyConfigurator[GatewayConfiguration]):
30
- TYPE: ApplyConfigurationType = ApplyConfigurationType.GATEWAY
29
+ TYPE = ApplyConfigurationType.GATEWAY
31
30
 
32
31
  def apply_configuration(
33
32
  self,
@@ -35,9 +34,8 @@ class GatewayConfigurator(BaseApplyConfigurator[GatewayConfiguration]):
35
34
  configuration_path: str,
36
35
  command_args: argparse.Namespace,
37
36
  configurator_args: argparse.Namespace,
38
- unknown_args: List[str],
39
37
  ):
40
- self.apply_args(conf, configurator_args, unknown_args)
38
+ self.apply_args(conf, configurator_args)
41
39
  spec = GatewaySpec(
42
40
  configuration=conf,
43
41
  configuration_path=configuration_path,
@@ -179,7 +177,7 @@ class GatewayConfigurator(BaseApplyConfigurator[GatewayConfiguration]):
179
177
  help="The gateway name",
180
178
  )
181
179
 
182
- def apply_args(self, conf: GatewayConfiguration, args: argparse.Namespace, unknown: List[str]):
180
+ def apply_args(self, conf: GatewayConfiguration, args: argparse.Namespace):
183
181
  if args.name:
184
182
  conf.name = args.name
185
183
 
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ import shlex
2
3
  import subprocess
3
4
  import sys
4
5
  import time
@@ -35,6 +36,7 @@ from dstack._internal.core.models.configurations import (
35
36
  LEGACY_REPO_DIR,
36
37
  AnyRunConfiguration,
37
38
  ApplyConfigurationType,
39
+ ConfigurationWithCommandsParams,
38
40
  ConfigurationWithPortsParams,
39
41
  DevEnvironmentConfiguration,
40
42
  PortMapping,
@@ -80,20 +82,17 @@ class BaseRunConfigurator(
80
82
  ApplyEnvVarsConfiguratorMixin,
81
83
  BaseApplyConfigurator[RunConfigurationT],
82
84
  ):
83
- TYPE: ApplyConfigurationType
84
-
85
85
  def apply_configuration(
86
86
  self,
87
87
  conf: RunConfigurationT,
88
88
  configuration_path: str,
89
89
  command_args: argparse.Namespace,
90
90
  configurator_args: argparse.Namespace,
91
- unknown_args: List[str],
92
91
  ):
93
92
  if configurator_args.repo and configurator_args.no_repo:
94
93
  raise CLIError("Either --repo or --no-repo can be specified")
95
94
 
96
- self.apply_args(conf, configurator_args, unknown_args)
95
+ self.apply_args(conf, configurator_args)
97
96
  self.validate_gpu_vendor_and_image(conf)
98
97
  self.validate_cpu_arch_and_image(conf)
99
98
 
@@ -395,7 +394,7 @@ class BaseRunConfigurator(
395
394
  )
396
395
  register_init_repo_args(repo_group)
397
396
 
398
- def apply_args(self, conf: RunConfigurationT, args: argparse.Namespace, unknown: List[str]):
397
+ def apply_args(self, conf: RunConfigurationT, args: argparse.Namespace):
399
398
  apply_profile_args(args, conf)
400
399
  if args.run_name:
401
400
  conf.name = args.run_name
@@ -408,16 +407,6 @@ class BaseRunConfigurator(
408
407
 
409
408
  self.apply_env_vars(conf.env, args)
410
409
  self.interpolate_env(conf)
411
- self.interpolate_run_args(conf.setup, unknown)
412
-
413
- def interpolate_run_args(self, value: List[str], unknown):
414
- run_args = " ".join(unknown)
415
- interpolator = VariablesInterpolator({"run": {"args": run_args}}, skip=["secrets"])
416
- try:
417
- for i in range(len(value)):
418
- value[i] = interpolator.interpolate_or_error(value[i])
419
- except InterpolatorError as e:
420
- raise ConfigurationError(e.args[0])
421
410
 
422
411
  def interpolate_env(self, conf: RunConfigurationT):
423
412
  env_dict = conf.env.as_dict()
@@ -701,18 +690,50 @@ class RunWithPortsConfiguratorMixin:
701
690
  conf.ports = list(_merge_ports(conf.ports, args.ports).values())
702
691
 
703
692
 
704
- class TaskConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigurator):
693
+ class RunWithCommandsConfiguratorMixin:
694
+ @classmethod
695
+ def register_commands_args(cls, parser: argparse.ArgumentParser):
696
+ parser.add_argument(
697
+ "run_args",
698
+ help=(
699
+ "Run arguments. Available in the configuration [code]commands[/code] as"
700
+ " [code]${{ run.args }}[/code]."
701
+ " Use [code]--[/code] to separate run options from [code]dstack[/code] options"
702
+ ),
703
+ nargs="*",
704
+ metavar="RUN_ARGS",
705
+ )
706
+
707
+ def apply_commands_args(
708
+ self,
709
+ conf: ConfigurationWithCommandsParams,
710
+ args: argparse.Namespace,
711
+ ):
712
+ commands = conf.commands
713
+ run_args = shlex.join(args.run_args)
714
+ interpolator = VariablesInterpolator({"run": {"args": run_args}}, skip=["secrets"])
715
+ try:
716
+ for i, command in enumerate(commands):
717
+ commands[i] = interpolator.interpolate_or_error(command)
718
+ except InterpolatorError as e:
719
+ raise ConfigurationError(e.args[0])
720
+
721
+
722
+ class TaskConfigurator(
723
+ RunWithPortsConfiguratorMixin, RunWithCommandsConfiguratorMixin, BaseRunConfigurator
724
+ ):
705
725
  TYPE = ApplyConfigurationType.TASK
706
726
 
707
727
  @classmethod
708
728
  def register_args(cls, parser: argparse.ArgumentParser):
709
729
  super().register_args(parser)
710
730
  cls.register_ports_args(parser)
731
+ cls.register_commands_args(parser)
711
732
 
712
- def apply_args(self, conf: TaskConfiguration, args: argparse.Namespace, unknown: List[str]):
713
- super().apply_args(conf, args, unknown)
733
+ def apply_args(self, conf: TaskConfiguration, args: argparse.Namespace):
734
+ super().apply_args(conf, args)
714
735
  self.apply_ports_args(conf, args)
715
- self.interpolate_run_args(conf.commands, unknown)
736
+ self.apply_commands_args(conf, args)
716
737
 
717
738
 
718
739
  class DevEnvironmentConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigurator):
@@ -723,10 +744,8 @@ class DevEnvironmentConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigura
723
744
  super().register_args(parser)
724
745
  cls.register_ports_args(parser)
725
746
 
726
- def apply_args(
727
- self, conf: DevEnvironmentConfiguration, args: argparse.Namespace, unknown: List[str]
728
- ):
729
- super().apply_args(conf, args, unknown)
747
+ def apply_args(self, conf: DevEnvironmentConfiguration, args: argparse.Namespace):
748
+ super().apply_args(conf, args)
730
749
  self.apply_ports_args(conf, args)
731
750
  if conf.ide == "vscode" and conf.version is None:
732
751
  conf.version = _detect_vscode_version()
@@ -746,12 +765,17 @@ class DevEnvironmentConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigura
746
765
  )
747
766
 
748
767
 
749
- class ServiceConfigurator(BaseRunConfigurator):
768
+ class ServiceConfigurator(RunWithCommandsConfiguratorMixin, BaseRunConfigurator):
750
769
  TYPE = ApplyConfigurationType.SERVICE
751
770
 
752
- def apply_args(self, conf: ServiceConfiguration, args: argparse.Namespace, unknown: List[str]):
753
- super().apply_args(conf, args, unknown)
754
- self.interpolate_run_args(conf.commands, unknown)
771
+ @classmethod
772
+ def register_args(cls, parser: argparse.ArgumentParser):
773
+ super().register_args(parser)
774
+ cls.register_commands_args(parser)
775
+
776
+ def apply_args(self, conf: TaskConfiguration, args: argparse.Namespace):
777
+ super().apply_args(conf, args)
778
+ self.apply_commands_args(conf, args)
755
779
 
756
780
 
757
781
  def _merge_ports(conf: List[PortMapping], args: List[PortMapping]) -> Dict[int, PortMapping]:
@@ -1,6 +1,5 @@
1
1
  import argparse
2
2
  import time
3
- from typing import List
4
3
 
5
4
  from rich.table import Table
6
5
 
@@ -26,7 +25,7 @@ from dstack.api._public import Client
26
25
 
27
26
 
28
27
  class VolumeConfigurator(BaseApplyConfigurator[VolumeConfiguration]):
29
- TYPE: ApplyConfigurationType = ApplyConfigurationType.VOLUME
28
+ TYPE = ApplyConfigurationType.VOLUME
30
29
 
31
30
  def apply_configuration(
32
31
  self,
@@ -34,9 +33,8 @@ class VolumeConfigurator(BaseApplyConfigurator[VolumeConfiguration]):
34
33
  configuration_path: str,
35
34
  command_args: argparse.Namespace,
36
35
  configurator_args: argparse.Namespace,
37
- unknown_args: List[str],
38
36
  ):
39
- self.apply_args(conf, configurator_args, unknown_args)
37
+ self.apply_args(conf, configurator_args)
40
38
  spec = VolumeSpec(
41
39
  configuration=conf,
42
40
  configuration_path=configuration_path,
@@ -167,7 +165,7 @@ class VolumeConfigurator(BaseApplyConfigurator[VolumeConfiguration]):
167
165
  help="The volume name",
168
166
  )
169
167
 
170
- def apply_args(self, conf: VolumeConfiguration, args: argparse.Namespace, unknown: List[str]):
168
+ def apply_args(self, conf: VolumeConfiguration, args: argparse.Namespace):
171
169
  if args.name:
172
170
  conf.name = args.name
173
171
 
@@ -1,6 +1,6 @@
1
1
  import threading
2
2
  from concurrent.futures import ThreadPoolExecutor, as_completed
3
- from typing import Any, Dict, List, Optional, Tuple
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple
4
4
 
5
5
  import boto3
6
6
  import botocore.client
@@ -18,6 +18,7 @@ from dstack._internal.core.backends.aws.models import (
18
18
  )
19
19
  from dstack._internal.core.backends.base.compute import (
20
20
  Compute,
21
+ ComputeWithAllOffersCached,
21
22
  ComputeWithCreateInstanceSupport,
22
23
  ComputeWithGatewaySupport,
23
24
  ComputeWithMultinodeSupport,
@@ -32,7 +33,7 @@ from dstack._internal.core.backends.base.compute import (
32
33
  get_user_data,
33
34
  merge_tags,
34
35
  )
35
- from dstack._internal.core.backends.base.offers import get_catalog_offers
36
+ from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
36
37
  from dstack._internal.core.errors import (
37
38
  ComputeError,
38
39
  NoCapacityError,
@@ -87,6 +88,7 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
87
88
 
88
89
 
89
90
  class AWSCompute(
91
+ ComputeWithAllOffersCached,
90
92
  ComputeWithCreateInstanceSupport,
91
93
  ComputeWithMultinodeSupport,
92
94
  ComputeWithReservationSupport,
@@ -109,6 +111,8 @@ class AWSCompute(
109
111
  # Caches to avoid redundant API calls when provisioning many instances
110
112
  # get_offers is already cached but we still cache its sub-functions
111
113
  # with more aggressive/longer caches.
114
+ self._offers_post_filter_cache_lock = threading.Lock()
115
+ self._offers_post_filter_cache = TTLCache(maxsize=10, ttl=180)
112
116
  self._get_regions_to_quotas_cache_lock = threading.Lock()
113
117
  self._get_regions_to_quotas_execution_lock = threading.Lock()
114
118
  self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300)
@@ -125,43 +129,11 @@ class AWSCompute(
125
129
  self._get_image_id_and_username_cache_lock = threading.Lock()
126
130
  self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600)
127
131
 
128
- def get_offers(
129
- self, requirements: Optional[Requirements] = None
130
- ) -> List[InstanceOfferWithAvailability]:
131
- filter = _supported_instances
132
- if requirements and requirements.reservation:
133
- region_to_reservation = {}
134
- for region in self.config.regions:
135
- reservation = aws_resources.get_reservation(
136
- ec2_client=self.session.client("ec2", region_name=region),
137
- reservation_id=requirements.reservation,
138
- instance_count=1,
139
- )
140
- if reservation is not None:
141
- region_to_reservation[region] = reservation
142
-
143
- def _supported_instances_with_reservation(offer: InstanceOffer) -> bool:
144
- # Filter: only instance types supported by dstack
145
- if not _supported_instances(offer):
146
- return False
147
- # Filter: Spot instances can't be used with reservations
148
- if offer.instance.resources.spot:
149
- return False
150
- region = offer.region
151
- reservation = region_to_reservation.get(region)
152
- # Filter: only instance types matching the capacity reservation
153
- if not bool(reservation and offer.instance.name == reservation["InstanceType"]):
154
- return False
155
- return True
156
-
157
- filter = _supported_instances_with_reservation
158
-
132
+ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
159
133
  offers = get_catalog_offers(
160
134
  backend=BackendType.AWS,
161
135
  locations=self.config.regions,
162
- requirements=requirements,
163
- configurable_disk_size=CONFIGURABLE_DISK_SIZE,
164
- extra_filter=filter,
136
+ extra_filter=_supported_instances,
165
137
  )
166
138
  regions = list(set(i.region for i in offers))
167
139
  with self._get_regions_to_quotas_execution_lock:
@@ -185,6 +157,49 @@ class AWSCompute(
185
157
  )
186
158
  return availability_offers
187
159
 
160
+ def get_offers_modifier(
161
+ self, requirements: Requirements
162
+ ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
163
+ return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
164
+
165
+ def _get_offers_cached_key(self, requirements: Requirements) -> int:
166
+ # Requirements is not hashable, so we use a hack to get arguments hash
167
+ return hash(requirements.json())
168
+
169
+ @cachedmethod(
170
+ cache=lambda self: self._offers_post_filter_cache,
171
+ key=_get_offers_cached_key,
172
+ lock=lambda self: self._offers_post_filter_cache_lock,
173
+ )
174
+ def get_offers_post_filter(
175
+ self, requirements: Requirements
176
+ ) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]:
177
+ if requirements.reservation:
178
+ region_to_reservation = {}
179
+ for region in get_or_error(self.config.regions):
180
+ reservation = aws_resources.get_reservation(
181
+ ec2_client=self.session.client("ec2", region_name=region),
182
+ reservation_id=requirements.reservation,
183
+ instance_count=1,
184
+ )
185
+ if reservation is not None:
186
+ region_to_reservation[region] = reservation
187
+
188
+ def reservation_filter(offer: InstanceOfferWithAvailability) -> bool:
189
+ # Filter: Spot instances can't be used with reservations
190
+ if offer.instance.resources.spot:
191
+ return False
192
+ region = offer.region
193
+ reservation = region_to_reservation.get(region)
194
+ # Filter: only instance types matching the capacity reservation
195
+ if not bool(reservation and offer.instance.name == reservation["InstanceType"]):
196
+ return False
197
+ return True
198
+
199
+ return reservation_filter
200
+
201
+ return None
202
+
188
203
  def terminate_instance(
189
204
  self, instance_id: str, region: str, backend_data: Optional[str] = None
190
205
  ) -> None:
@@ -2,7 +2,7 @@ import base64
2
2
  import enum
3
3
  import re
4
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
5
- from typing import Dict, List, Optional, Tuple
5
+ from typing import Callable, Dict, List, Optional, Tuple
6
6
 
7
7
  from azure.core.credentials import TokenCredential
8
8
  from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
@@ -39,6 +39,7 @@ from dstack._internal.core.backends.azure import utils as azure_utils
39
39
  from dstack._internal.core.backends.azure.models import AzureConfig
40
40
  from dstack._internal.core.backends.base.compute import (
41
41
  Compute,
42
+ ComputeWithAllOffersCached,
42
43
  ComputeWithCreateInstanceSupport,
43
44
  ComputeWithGatewaySupport,
44
45
  ComputeWithMultinodeSupport,
@@ -48,7 +49,7 @@ from dstack._internal.core.backends.base.compute import (
48
49
  get_user_data,
49
50
  merge_tags,
50
51
  )
51
- from dstack._internal.core.backends.base.offers import get_catalog_offers
52
+ from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
52
53
  from dstack._internal.core.errors import ComputeError, NoCapacityError
53
54
  from dstack._internal.core.models.backends.base import BackendType
54
55
  from dstack._internal.core.models.gateways import (
@@ -73,6 +74,7 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("30GB"), max=Memory.pars
73
74
 
74
75
 
75
76
  class AzureCompute(
77
+ ComputeWithAllOffersCached,
76
78
  ComputeWithCreateInstanceSupport,
77
79
  ComputeWithMultinodeSupport,
78
80
  ComputeWithGatewaySupport,
@@ -89,14 +91,10 @@ class AzureCompute(
89
91
  credential=credential, subscription_id=config.subscription_id
90
92
  )
91
93
 
92
- def get_offers(
93
- self, requirements: Optional[Requirements] = None
94
- ) -> List[InstanceOfferWithAvailability]:
94
+ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
95
95
  offers = get_catalog_offers(
96
96
  backend=BackendType.AZURE,
97
97
  locations=self.config.regions,
98
- requirements=requirements,
99
- configurable_disk_size=CONFIGURABLE_DISK_SIZE,
100
98
  extra_filter=_supported_instances,
101
99
  )
102
100
  offers_with_availability = _get_offers_with_availability(
@@ -106,6 +104,11 @@ class AzureCompute(
106
104
  )
107
105
  return offers_with_availability
108
106
 
107
+ def get_offers_modifier(
108
+ self, requirements: Requirements
109
+ ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
110
+ return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
111
+
109
112
  def create_instance(
110
113
  self,
111
114
  instance_offer: InstanceOfferWithAvailability,