dstack 0.19.9__py3-none-any.whl → 0.19.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. Click here for more details.

Files changed (53)
  1. dstack/_internal/cli/commands/config.py +1 -1
  2. dstack/_internal/cli/commands/metrics.py +25 -10
  3. dstack/_internal/cli/commands/offer.py +2 -0
  4. dstack/_internal/cli/commands/project.py +161 -0
  5. dstack/_internal/cli/commands/ps.py +9 -2
  6. dstack/_internal/cli/main.py +2 -0
  7. dstack/_internal/cli/services/configurators/run.py +1 -1
  8. dstack/_internal/cli/utils/updates.py +13 -1
  9. dstack/_internal/core/backends/aws/compute.py +21 -9
  10. dstack/_internal/core/backends/azure/compute.py +8 -3
  11. dstack/_internal/core/backends/base/compute.py +9 -4
  12. dstack/_internal/core/backends/gcp/compute.py +43 -20
  13. dstack/_internal/core/backends/gcp/resources.py +18 -2
  14. dstack/_internal/core/backends/local/compute.py +4 -2
  15. dstack/_internal/core/models/configurations.py +21 -4
  16. dstack/_internal/core/models/runs.py +2 -1
  17. dstack/_internal/proxy/gateway/resources/nginx/00-log-format.conf +11 -1
  18. dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +12 -6
  19. dstack/_internal/proxy/gateway/services/stats.py +17 -3
  20. dstack/_internal/server/background/tasks/process_metrics.py +23 -21
  21. dstack/_internal/server/background/tasks/process_submitted_jobs.py +24 -15
  22. dstack/_internal/server/migrations/versions/bca2fdf130bf_add_runmodel_priority.py +34 -0
  23. dstack/_internal/server/models.py +1 -0
  24. dstack/_internal/server/routers/repos.py +13 -4
  25. dstack/_internal/server/services/fleets.py +2 -2
  26. dstack/_internal/server/services/gateways/__init__.py +1 -1
  27. dstack/_internal/server/services/instances.py +6 -2
  28. dstack/_internal/server/services/jobs/__init__.py +4 -4
  29. dstack/_internal/server/services/jobs/configurators/base.py +18 -4
  30. dstack/_internal/server/services/jobs/configurators/extensions/cursor.py +3 -1
  31. dstack/_internal/server/services/jobs/configurators/extensions/vscode.py +3 -1
  32. dstack/_internal/server/services/plugins.py +64 -32
  33. dstack/_internal/server/services/runs.py +33 -20
  34. dstack/_internal/server/services/volumes.py +1 -1
  35. dstack/_internal/server/settings.py +1 -0
  36. dstack/_internal/server/statics/index.html +1 -1
  37. dstack/_internal/server/statics/{main-b4f65323f5df007e1664.js → main-5b9786c955b42bf93581.js} +8 -8
  38. dstack/_internal/server/statics/{main-b4f65323f5df007e1664.js.map → main-5b9786c955b42bf93581.js.map} +1 -1
  39. dstack/_internal/server/testing/common.py +2 -0
  40. dstack/_internal/server/utils/routers.py +3 -6
  41. dstack/_internal/settings.py +4 -0
  42. dstack/api/_public/runs.py +6 -3
  43. dstack/api/server/_runs.py +2 -0
  44. dstack/plugins/builtin/__init__.py +0 -0
  45. dstack/plugins/builtin/rest_plugin/__init__.py +18 -0
  46. dstack/plugins/builtin/rest_plugin/_models.py +48 -0
  47. dstack/plugins/builtin/rest_plugin/_plugin.py +127 -0
  48. dstack/version.py +2 -2
  49. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/METADATA +10 -6
  50. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/RECORD +53 -47
  51. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/WHEEL +0 -0
  52. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/entry_points.txt +0 -0
  53. {dstack-0.19.9.dist-info → dstack-0.19.11.dist-info}/licenses/LICENSE.md +0 -0
@@ -14,7 +14,7 @@ logger = get_logger(__name__)
14
14
 
15
15
  class ConfigCommand(BaseCommand):
16
16
  NAME = "config"
17
- DESCRIPTION = "Configure CLI"
17
+ DESCRIPTION = "Configure CLI (deprecated; use `dstack project`)"
18
18
 
19
19
  def _register(self):
20
20
  super()._register()
@@ -39,8 +39,6 @@ class MetricsCommand(APIBaseCommand):
39
39
  run = self.api.runs.get(run_name=args.run_name)
40
40
  if run is None:
41
41
  raise CLIError(f"Run {args.run_name} not found")
42
- if run.status.is_finished():
43
- raise CLIError(f"Run {args.run_name} is finished")
44
42
  metrics = _get_run_jobs_metrics(api=self.api, run=run)
45
43
 
46
44
  if not args.watch:
@@ -55,8 +53,6 @@ class MetricsCommand(APIBaseCommand):
55
53
  run = self.api.runs.get(run_name=args.run_name)
56
54
  if run is None:
57
55
  raise CLIError(f"Run {args.run_name} not found")
58
- if run.status.is_finished():
59
- raise CLIError(f"Run {args.run_name} is finished")
60
56
  metrics = _get_run_jobs_metrics(api=self.api, run=run)
61
57
  except KeyboardInterrupt:
62
58
  pass
@@ -78,11 +74,12 @@ def _get_run_jobs_metrics(api: Client, run: Run) -> List[JobMetrics]:
78
74
  def _get_metrics_table(run: Run, metrics: List[JobMetrics]) -> Table:
79
75
  table = Table(box=None)
80
76
  table.add_column("NAME", style="bold", no_wrap=True)
77
+ table.add_column("STATUS")
81
78
  table.add_column("CPU")
82
79
  table.add_column("MEMORY")
83
80
  table.add_column("GPU")
84
81
 
85
- run_row: Dict[Union[str, int], Any] = {"NAME": run.name}
82
+ run_row: Dict[Union[str, int], Any] = {"NAME": run.name, "STATUS": run.status.value}
86
83
  if len(run._run.jobs) != 1:
87
84
  add_row_from_dict(table, run_row)
88
85
 
@@ -101,9 +98,9 @@ def _get_metrics_table(run: Run, metrics: List[JobMetrics]) -> Table:
101
98
  cpu_usage = f"{cpu_usage:.0f}%"
102
99
  memory_usage = _get_metric_value(job_metrics, "memory_working_set_bytes")
103
100
  if memory_usage is not None:
104
- memory_usage = f"{round(memory_usage / 1024 / 1024)}MB"
101
+ memory_usage = _format_memory(memory_usage, 2)
105
102
  if resources is not None:
106
- memory_usage += f"/{resources.memory_mib}MB"
103
+ memory_usage += f"/{_format_memory(resources.memory_mib * 1024 * 1024, 2)}"
107
104
  gpu_metrics = ""
108
105
  gpus_detected_num = _get_metric_value(job_metrics, "gpus_detected_num")
109
106
  if gpus_detected_num is not None:
@@ -113,13 +110,16 @@ def _get_metrics_table(run: Run, metrics: List[JobMetrics]) -> Table:
113
110
  if gpu_memory_usage is not None:
114
111
  if i != 0:
115
112
  gpu_metrics += "\n"
116
- gpu_metrics += f"#{i} {round(gpu_memory_usage / 1024 / 1024)}MB"
113
+ gpu_metrics += f"gpu={i} mem={_format_memory(gpu_memory_usage, 2)}"
117
114
  if resources is not None:
118
- gpu_metrics += f"/{resources.gpus[i].memory_mib}MB"
119
- gpu_metrics += f" {gpu_util_percent}% Util"
115
+ gpu_metrics += (
116
+ f"/{_format_memory(resources.gpus[i].memory_mib * 1024 * 1024, 2)}"
117
+ )
118
+ gpu_metrics += f" util={gpu_util_percent}%"
120
119
 
121
120
  job_row: Dict[Union[str, int], Any] = {
122
121
  "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}",
122
+ "STATUS": job.job_submissions[-1].status.value,
123
123
  "CPU": cpu_usage or "-",
124
124
  "MEMORY": memory_usage or "-",
125
125
  "GPU": gpu_metrics or "-",
@@ -136,3 +136,18 @@ def _get_metric_value(job_metrics: JobMetrics, name: str) -> Optional[Any]:
136
136
  if metric.name == name:
137
137
  return metric.values[-1]
138
138
  return None
139
+
140
+
141
+ def _format_memory(memory_bytes: int, decimal_places: int) -> str:
142
+ """See test_format_memory in tests/_internal/cli/commands/test_metrics.py for examples."""
143
+ memory_mb = memory_bytes / 1024 / 1024
144
+ if memory_mb >= 1024:
145
+ value = memory_mb / 1024
146
+ unit = "GB"
147
+ else:
148
+ value = memory_mb
149
+ unit = "MB"
150
+
151
+ if decimal_places == 0:
152
+ return f"{round(value)}{unit}"
153
+ return f"{value:.{decimal_places}f}".rstrip("0").rstrip(".") + unit
@@ -84,6 +84,8 @@ class OfferCommand(APIBaseCommand):
84
84
  job_plan = run_plan.job_plans[0]
85
85
 
86
86
  if args.format == "json":
87
+ # FIXME: Should use effective_run_spec from run_plan,
88
+ # since the spec can be changed by the server and plugins
87
89
  output = {
88
90
  "project": run_plan.project_name,
89
91
  "user": run_plan.user,
@@ -0,0 +1,161 @@
1
+ import argparse
2
+
3
+ from requests import HTTPError
4
+ from rich.table import Table
5
+
6
+ import dstack.api.server
7
+ from dstack._internal.cli.commands import BaseCommand
8
+ from dstack._internal.cli.utils.common import confirm_ask, console
9
+ from dstack._internal.core.errors import ClientError, CLIError
10
+ from dstack._internal.core.services.configs import ConfigManager
11
+ from dstack._internal.utils.logging import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
+ class ProjectCommand(BaseCommand):
17
+ NAME = "project"
18
+ DESCRIPTION = "Manage projects configs"
19
+
20
+ def _register(self):
21
+ super()._register()
22
+ subparsers = self._parser.add_subparsers(dest="subcommand", help="Command to execute")
23
+
24
+ # Add subcommand
25
+ add_parser = subparsers.add_parser("add", help="Add or update a project config")
26
+ add_parser.add_argument(
27
+ "--name", type=str, help="The name of the project to configure", required=True
28
+ )
29
+ add_parser.add_argument("--url", type=str, help="Server url", required=True)
30
+ add_parser.add_argument("--token", type=str, help="User token", required=True)
31
+ add_parser.add_argument(
32
+ "-y",
33
+ "--yes",
34
+ help="Don't ask for confirmation (e.g. update the config)",
35
+ action="store_true",
36
+ )
37
+ add_parser.add_argument(
38
+ "-n",
39
+ "--no",
40
+ help="Don't ask for confirmation (e.g. do not update the config)",
41
+ action="store_true",
42
+ )
43
+ add_parser.set_defaults(subfunc=self._add)
44
+
45
+ # Delete subcommand
46
+ delete_parser = subparsers.add_parser("delete", help="Delete a project config")
47
+ delete_parser.add_argument(
48
+ "--name", type=str, help="The name of the project to delete", required=True
49
+ )
50
+ delete_parser.add_argument(
51
+ "-y",
52
+ "--yes",
53
+ help="Don't ask for confirmation",
54
+ action="store_true",
55
+ )
56
+ delete_parser.set_defaults(subfunc=self._delete)
57
+
58
+ # List subcommand
59
+ list_parser = subparsers.add_parser("list", help="List configured projects")
60
+ list_parser.set_defaults(subfunc=self._list)
61
+
62
+ # Set default subcommand
63
+ set_default_parser = subparsers.add_parser("set-default", help="Set default project")
64
+ set_default_parser.add_argument(
65
+ "name", type=str, help="The name of the project to set as default"
66
+ )
67
+ set_default_parser.set_defaults(subfunc=self._set_default)
68
+
69
+ def _command(self, args: argparse.Namespace):
70
+ if not hasattr(args, "subfunc"):
71
+ args.subfunc = self._list
72
+ args.subfunc(args)
73
+
74
+ def _add(self, args: argparse.Namespace):
75
+ config_manager = ConfigManager()
76
+ api_client = dstack.api.server.APIClient(base_url=args.url, token=args.token)
77
+ try:
78
+ api_client.projects.get(args.name)
79
+ except HTTPError as e:
80
+ if e.response.status_code == 403:
81
+ raise CLIError("Forbidden. Ensure the token is valid.")
82
+ elif e.response.status_code == 404:
83
+ raise CLIError(f"Project '{args.name}' not found.")
84
+ else:
85
+ raise e
86
+ default_project = config_manager.get_project_config()
87
+ if (
88
+ default_project is None
89
+ or default_project.name != args.name
90
+ or default_project.url != args.url
91
+ or default_project.token != args.token
92
+ ):
93
+ set_it_as_default = (
94
+ (
95
+ args.yes
96
+ or not default_project
97
+ or confirm_ask(f"Set '{args.name}' as your default project?")
98
+ )
99
+ if not args.no
100
+ else False
101
+ )
102
+ config_manager.configure_project(
103
+ name=args.name, url=args.url, token=args.token, default=set_it_as_default
104
+ )
105
+ config_manager.save()
106
+ logger.info(
107
+ f"Configuration updated at {config_manager.config_filepath}", {"show_path": False}
108
+ )
109
+
110
+ def _delete(self, args: argparse.Namespace):
111
+ config_manager = ConfigManager()
112
+ if args.yes or confirm_ask(f"Are you sure you want to delete project '{args.name}'?"):
113
+ config_manager.delete_project(args.name)
114
+ config_manager.save()
115
+ console.print("[grey58]OK[/]")
116
+
117
+ def _list(self, args: argparse.Namespace):
118
+ config_manager = ConfigManager()
119
+ default_project = config_manager.get_project_config()
120
+
121
+ table = Table(box=None)
122
+ table.add_column("PROJECT", style="bold", no_wrap=True)
123
+ table.add_column("URL", style="grey58")
124
+ table.add_column("USER", style="grey58")
125
+ table.add_column("DEFAULT", justify="center")
126
+
127
+ for project_name in config_manager.list_projects():
128
+ project_config = config_manager.get_project_config(project_name)
129
+ is_default = project_name == default_project.name if default_project else False
130
+
131
+ # Get username from API
132
+ try:
133
+ api_client = dstack.api.server.APIClient(
134
+ base_url=project_config.url, token=project_config.token
135
+ )
136
+ user_info = api_client.users.get_my_user()
137
+ username = user_info.username
138
+ except ClientError:
139
+ username = "(invalid token)"
140
+
141
+ table.add_row(
142
+ project_name,
143
+ project_config.url,
144
+ username,
145
+ "✓" if is_default else "",
146
+ style="bold" if is_default else None,
147
+ )
148
+
149
+ console.print(table)
150
+
151
+ def _set_default(self, args: argparse.Namespace):
152
+ config_manager = ConfigManager()
153
+ project_config = config_manager.get_project_config(args.name)
154
+ if project_config is None:
155
+ raise CLIError(f"Project '{args.name}' not found")
156
+
157
+ config_manager.configure_project(
158
+ name=args.name, url=project_config.url, token=project_config.token, default=True
159
+ )
160
+ config_manager.save()
161
+ console.print("[grey58]OK[/]")
@@ -36,10 +36,17 @@ class PsCommand(APIBaseCommand):
36
36
  help="Watch statuses of runs in realtime",
37
37
  action="store_true",
38
38
  )
39
+ self._parser.add_argument(
40
+ "-n",
41
+ "--last",
42
+ help="Show only the last N runs. Implies --all",
43
+ type=int,
44
+ default=None,
45
+ )
39
46
 
40
47
  def _command(self, args: argparse.Namespace):
41
48
  super()._command(args)
42
- runs = self.api.runs.list(all=args.all)
49
+ runs = self.api.runs.list(all=args.all, limit=args.last)
43
50
  if not args.watch:
44
51
  console.print(run_utils.get_runs_table(runs, verbose=args.verbose))
45
52
  return
@@ -49,6 +56,6 @@ class PsCommand(APIBaseCommand):
49
56
  while True:
50
57
  live.update(run_utils.get_runs_table(runs, verbose=args.verbose))
51
58
  time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
52
- runs = self.api.runs.list(all=args.all)
59
+ runs = self.api.runs.list(all=args.all, limit=args.last)
53
60
  except KeyboardInterrupt:
54
61
  pass
@@ -15,6 +15,7 @@ from dstack._internal.cli.commands.init import InitCommand
15
15
  from dstack._internal.cli.commands.logs import LogsCommand
16
16
  from dstack._internal.cli.commands.metrics import MetricsCommand
17
17
  from dstack._internal.cli.commands.offer import OfferCommand
18
+ from dstack._internal.cli.commands.project import ProjectCommand
18
19
  from dstack._internal.cli.commands.ps import PsCommand
19
20
  from dstack._internal.cli.commands.server import ServerCommand
20
21
  from dstack._internal.cli.commands.stats import StatsCommand
@@ -69,6 +70,7 @@ def main():
69
70
  OfferCommand.register(subparsers)
70
71
  LogsCommand.register(subparsers)
71
72
  MetricsCommand.register(subparsers)
73
+ ProjectCommand.register(subparsers)
72
74
  PsCommand.register(subparsers)
73
75
  ServerCommand.register(subparsers)
74
76
  StatsCommand.register(subparsers)
@@ -105,7 +105,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
105
105
  changed_fields = []
106
106
  if run_plan.action == ApplyAction.UPDATE:
107
107
  diff = diff_models(
108
- run_plan.run_spec.configuration,
108
+ run_plan.get_effective_run_spec().configuration,
109
109
  run_plan.current_resource.run_spec.configuration,
110
110
  )
111
111
  changed_fields = list(diff.keys())
@@ -57,10 +57,22 @@ def _is_last_check_time_outdated() -> bool:
57
57
  )
58
58
 
59
59
 
60
+ def is_update_available(current_version: str, latest_version: str) -> bool:
61
+ """
62
+ Return True if latest_version is newer than current_version.
63
+ Pre-releases are only considered if the current version is also a pre-release.
64
+ """
65
+ _current_version = pkg_version.parse(str(current_version))
66
+ _latest_version = pkg_version.parse(str(latest_version))
67
+ return _current_version < _latest_version and (
68
+ not _latest_version.is_prerelease or _current_version.is_prerelease
69
+ )
70
+
71
+
60
72
  def _check_version():
61
73
  latest_version = get_latest_version()
62
74
  if latest_version is not None:
63
- if pkg_version.parse(str(version.__version__)) < pkg_version.parse(latest_version):
75
+ if is_update_available(version.__version__, latest_version):
64
76
  console.print(f"A new version of dstack is available: [code]{latest_version}[/]\n")
65
77
 
66
78
 
@@ -611,9 +611,12 @@ class AWSCompute(
611
611
  raise e
612
612
  logger.debug("Deleted EBS volume %s", volume.configuration.name)
613
613
 
614
- def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
614
+ def attach_volume(
615
+ self, volume: Volume, provisioning_data: JobProvisioningData
616
+ ) -> VolumeAttachmentData:
615
617
  ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
616
618
 
619
+ instance_id = provisioning_data.instance_id
617
620
  device_names = aws_resources.list_available_device_names(
618
621
  ec2_client=ec2_client, instance_id=instance_id
619
622
  )
@@ -646,9 +649,12 @@ class AWSCompute(
646
649
  logger.debug("Attached EBS volume %s to instance %s", volume.volume_id, instance_id)
647
650
  return VolumeAttachmentData(device_name=device_name)
648
651
 
649
- def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
652
+ def detach_volume(
653
+ self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
654
+ ):
650
655
  ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
651
656
 
657
+ instance_id = provisioning_data.instance_id
652
658
  logger.debug("Detaching EBS volume %s from instance %s", volume.volume_id, instance_id)
653
659
  attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
654
660
  try:
@@ -667,9 +673,10 @@ class AWSCompute(
667
673
  raise e
668
674
  logger.debug("Detached EBS volume %s from instance %s", volume.volume_id, instance_id)
669
675
 
670
- def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
676
+ def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool:
671
677
  ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
672
678
 
679
+ instance_id = provisioning_data.instance_id
673
680
  logger.debug("Getting EBS volume %s status", volume.volume_id)
674
681
  response = ec2_client.describe_volumes(VolumeIds=[volume.volume_id])
675
682
  volumes_infos = response.get("Volumes")
@@ -819,18 +826,23 @@ def _get_regions_to_zones(session: boto3.Session, regions: List[str]) -> Dict[st
819
826
 
820
827
  def _supported_instances(offer: InstanceOffer) -> bool:
821
828
  for family in [
829
+ "m7i.",
830
+ "c7i.",
831
+ "r7i.",
832
+ "t3.",
822
833
  "t2.small",
823
834
  "c5.",
824
835
  "m5.",
825
- "g4dn.",
826
- "g5.",
836
+ "p5.",
837
+ "p5e.",
838
+ "p4d.",
839
+ "p4de.",
840
+ "p3.",
827
841
  "g6.",
828
842
  "g6e.",
829
843
  "gr6.",
830
- "p3.",
831
- "p4d.",
832
- "p4de.",
833
- "p5.",
844
+ "g5.",
845
+ "g4dn.",
834
846
  ]:
835
847
  if offer.instance.name.startswith(family):
836
848
  return True
@@ -391,9 +391,9 @@ class VMImageVariant(enum.Enum):
391
391
 
392
392
 
393
393
  _SUPPORTED_VM_SERIES_PATTERNS = [
394
- r"D(\d+)s_v3", # Dsv3-series
395
- r"E(\d+)i?s_v4", # Esv4-series
396
- r"E(\d+)-(\d+)s_v4", # Esv4-series (constrained vCPU)
394
+ r"D(\d+)s_v6", # Dsv6-series (general purpose)
395
+ r"E(\d+)i?s_v6", # Esv6-series (memory optimized)
396
+ r"F(\d+)s_v2", # Fsv2-series (compute optimized)
397
397
  r"NC(\d+)s_v3", # NCv3-series [V100 16GB]
398
398
  r"NC(\d+)as_T4_v3", # NCasT4_v3-series [T4]
399
399
  r"ND(\d+)rs_v2", # NDv2-series [8xV100 32GB]
@@ -401,6 +401,11 @@ _SUPPORTED_VM_SERIES_PATTERNS = [
401
401
  r"NC(\d+)ads_A100_v4", # NC A100 v4-series [A100 80GB]
402
402
  r"ND(\d+)asr_v4", # ND A100 v4-series [8xA100 40GB]
403
403
  r"ND(\d+)amsr_A100_v4", # NDm A100 v4-series [8xA100 80GB]
404
+ # Deprecated series
405
+ # TODO: Remove after several releases
406
+ r"D(\d+)s_v3", # Dsv3-series (general purpose)
407
+ r"E(\d+)i?s_v4", # Esv4-series (memory optimized)
408
+ r"E(\d+)-(\d+)s_v4", # Esv4-series (constrained vCPU)
404
409
  ]
405
410
  _SUPPORTED_VM_SERIES_PATTERN = (
406
411
  "^Standard_(" + "|".join(f"({s})" for s in _SUPPORTED_VM_SERIES_PATTERNS) + ")$"
@@ -19,6 +19,7 @@ from dstack._internal.core.consts import (
19
19
  DSTACK_RUNNER_SSH_PORT,
20
20
  DSTACK_SHIM_HTTP_PORT,
21
21
  )
22
+ from dstack._internal.core.models.configurations import DEFAULT_REPO_DIR
22
23
  from dstack._internal.core.models.gateways import (
23
24
  GatewayComputeConfiguration,
24
25
  GatewayProvisioningData,
@@ -335,7 +336,9 @@ class ComputeWithVolumeSupport(ABC):
335
336
  """
336
337
  raise NotImplementedError()
337
338
 
338
- def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
339
+ def attach_volume(
340
+ self, volume: Volume, provisioning_data: JobProvisioningData
341
+ ) -> VolumeAttachmentData:
339
342
  """
340
343
  Attaches a volume to the instance.
341
344
  If the volume is not found, it should raise `ComputeError()`.
@@ -344,7 +347,9 @@ class ComputeWithVolumeSupport(ABC):
344
347
  """
345
348
  raise NotImplementedError()
346
349
 
347
- def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
350
+ def detach_volume(
351
+ self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
352
+ ):
348
353
  """
349
354
  Detaches a volume from the instance.
350
355
  Implement only if compute may return `VolumeProvisioningData.detachable`.
@@ -352,7 +357,7 @@ class ComputeWithVolumeSupport(ABC):
352
357
  """
353
358
  raise NotImplementedError()
354
359
 
355
- def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
360
+ def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool:
356
361
  """
357
362
  Checks if a volume was detached from the instance.
358
363
  If `detach_volume()` may fail to detach volume,
@@ -754,7 +759,7 @@ def get_docker_commands(
754
759
  f" --ssh-port {DSTACK_RUNNER_SSH_PORT}"
755
760
  " --temp-dir /tmp/runner"
756
761
  " --home-dir /root"
757
- " --working-dir /workflow"
762
+ f" --working-dir {DEFAULT_REPO_DIR}"
758
763
  ),
759
764
  ]
760
765
 
@@ -649,13 +649,24 @@ class GCPCompute(
649
649
  pass
650
650
  logger.debug("Deleted persistent disk for volume %s", volume.name)
651
651
 
652
- def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
652
+ def attach_volume(
653
+ self, volume: Volume, provisioning_data: JobProvisioningData
654
+ ) -> VolumeAttachmentData:
655
+ instance_id = provisioning_data.instance_id
653
656
  logger.debug(
654
657
  "Attaching persistent disk for volume %s to instance %s",
655
658
  volume.volume_id,
656
659
  instance_id,
657
660
  )
661
+ if not gcp_resources.instance_type_supports_persistent_disk(
662
+ provisioning_data.instance_type.name
663
+ ):
664
+ raise ComputeError(
665
+ f"Instance type {provisioning_data.instance_type.name} does not support Persistent disk volumes"
666
+ )
667
+
658
668
  zone = get_or_error(volume.provisioning_data).availability_zone
669
+ is_tpu = _is_tpu_provisioning_data(provisioning_data)
659
670
  try:
660
671
  disk = self.disk_client.get(
661
672
  project=self.config.project_id,
@@ -663,18 +674,16 @@ class GCPCompute(
663
674
  disk=volume.volume_id,
664
675
  )
665
676
  disk_url = disk.self_link
677
+ except google.api_core.exceptions.NotFound:
678
+ raise ComputeError("Persistent disk not found")
666
679
 
667
- # This method has no information if the instance is a TPU or a VM,
668
- # so we first try to see if there is a TPU with such name
669
- try:
680
+ try:
681
+ if is_tpu:
670
682
  get_node_request = tpu_v2.GetNodeRequest(
671
683
  name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
672
684
  )
673
685
  tpu_node = self.tpu_client.get_node(get_node_request)
674
- except google.api_core.exceptions.NotFound:
675
- tpu_node = None
676
686
 
677
- if tpu_node is not None:
678
687
  # Python API to attach a disk to a TPU is not documented,
679
688
  # so we follow the code from the gcloud CLI:
680
689
  # https://github.com/twistedpair/google-cloud-sdk/blob/26ab5a281d56b384cc25750f3279a27afe5b499f/google-cloud-sdk/lib/googlecloudsdk/command_lib/compute/tpus/tpu_vm/util.py#L113
@@ -711,7 +720,6 @@ class GCPCompute(
711
720
  attached_disk.auto_delete = False
712
721
  attached_disk.device_name = f"pd-{volume.volume_id}"
713
722
  device_name = attached_disk.device_name
714
-
715
723
  operation = self.instances_client.attach_disk(
716
724
  project=self.config.project_id,
717
725
  zone=zone,
@@ -720,13 +728,16 @@ class GCPCompute(
720
728
  )
721
729
  gcp_resources.wait_for_extended_operation(operation, "persistent disk attachment")
722
730
  except google.api_core.exceptions.NotFound:
723
- raise ComputeError("Persistent disk or instance not found")
731
+ raise ComputeError("Disk or instance not found")
724
732
  logger.debug(
725
733
  "Attached persistent disk for volume %s to instance %s", volume.volume_id, instance_id
726
734
  )
727
735
  return VolumeAttachmentData(device_name=device_name)
728
736
 
729
- def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
737
+ def detach_volume(
738
+ self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
739
+ ):
740
+ instance_id = provisioning_data.instance_id
730
741
  logger.debug(
731
742
  "Detaching persistent disk for volume %s from instance %s",
732
743
  volume.volume_id,
@@ -734,17 +745,16 @@ class GCPCompute(
734
745
  )
735
746
  zone = get_or_error(volume.provisioning_data).availability_zone
736
747
  attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
737
- # This method has no information if the instance is a TPU or a VM,
738
- # so we first try to see if there is a TPU with such name
739
- try:
740
- get_node_request = tpu_v2.GetNodeRequest(
741
- name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
742
- )
743
- tpu_node = self.tpu_client.get_node(get_node_request)
744
- except google.api_core.exceptions.NotFound:
745
- tpu_node = None
748
+ is_tpu = _is_tpu_provisioning_data(provisioning_data)
749
+ if is_tpu:
750
+ try:
751
+ get_node_request = tpu_v2.GetNodeRequest(
752
+ name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
753
+ )
754
+ tpu_node = self.tpu_client.get_node(get_node_request)
755
+ except google.api_core.exceptions.NotFound:
756
+ raise ComputeError("Instance not found")
746
757
 
747
- if tpu_node is not None:
748
758
  source_disk = (
749
759
  f"projects/{self.config.project_id}/zones/{zone}/disks/{volume.volume_id}"
750
760
  )
@@ -815,6 +825,11 @@ def _supported_instances_and_zones(
815
825
  if _is_tpu(offer.instance.name) and not _is_single_host_tpu(offer.instance.name):
816
826
  return False
817
827
  for family in [
828
+ "m4-",
829
+ "c4-",
830
+ "n4-",
831
+ "h3-",
832
+ "n2-",
818
833
  "e2-medium",
819
834
  "e2-standard-",
820
835
  "e2-highmem-",
@@ -1001,3 +1016,11 @@ def _get_tpu_data_disk_for_volume(project_id: str, volume: Volume) -> tpu_v2.Att
1001
1016
  mode=tpu_v2.AttachedDisk.DiskMode.READ_WRITE,
1002
1017
  )
1003
1018
  return attached_disk
1019
+
1020
+
1021
+ def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool:
1022
+ is_tpu = False
1023
+ if provisioning_data.backend_data:
1024
+ backend_data_dict = json.loads(provisioning_data.backend_data)
1025
+ is_tpu = backend_data_dict.get("is_tpu", False)
1026
+ return is_tpu
@@ -140,7 +140,10 @@ def create_instance_struct(
140
140
  initialize_params = compute_v1.AttachedDiskInitializeParams()
141
141
  initialize_params.source_image = image_id
142
142
  initialize_params.disk_size_gb = disk_size
143
- initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
143
+ if instance_type_supports_persistent_disk(machine_type):
144
+ initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
145
+ else:
146
+ initialize_params.disk_type = f"zones/{zone}/diskTypes/hyperdisk-balanced"
144
147
  disk.initialize_params = initialize_params
145
148
  instance.disks = [disk]
146
149
 
@@ -421,7 +424,7 @@ def wait_for_extended_operation(
421
424
 
422
425
  if operation.error_code:
423
426
  # Write only debug logs here.
424
- # The unexpected errors will be propagated and logged appropriatly by the caller.
427
+ # The unexpected errors will be propagated and logged appropriately by the caller.
425
428
  logger.debug(
426
429
  "Error during %s: [Code: %s]: %s",
427
430
  verbose_name,
@@ -462,3 +465,16 @@ def get_placement_policy_resource_name(
462
465
  placement_policy: str,
463
466
  ) -> str:
464
467
  return f"projects/{project_id}/regions/{region}/resourcePolicies/{placement_policy}"
468
+
469
+
470
+ def instance_type_supports_persistent_disk(instance_type_name: str) -> bool:
471
+ return not any(
472
+ instance_type_name.startswith(series)
473
+ for series in [
474
+ "m4-",
475
+ "c4-",
476
+ "n4-",
477
+ "h3-",
478
+ "v6e",
479
+ ]
480
+ )
@@ -110,8 +110,10 @@ class LocalCompute(
110
110
  def delete_volume(self, volume: Volume):
111
111
  pass
112
112
 
113
- def attach_volume(self, volume: Volume, instance_id: str):
113
+ def attach_volume(self, volume: Volume, provisioning_data: JobProvisioningData):
114
114
  pass
115
115
 
116
- def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
116
+ def detach_volume(
117
+ self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
118
+ ):
117
119
  pass