dstack 0.19.11rc1__py3-none-any.whl → 0.19.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. See the package's registry page for more details.

Files changed (45)
  1. dstack/_internal/cli/commands/offer.py +2 -0
  2. dstack/_internal/cli/services/configurators/run.py +43 -42
  3. dstack/_internal/cli/utils/run.py +10 -26
  4. dstack/_internal/cli/utils/updates.py +13 -1
  5. dstack/_internal/core/backends/aws/compute.py +21 -9
  6. dstack/_internal/core/backends/base/compute.py +7 -3
  7. dstack/_internal/core/backends/gcp/compute.py +43 -20
  8. dstack/_internal/core/backends/gcp/resources.py +18 -2
  9. dstack/_internal/core/backends/local/compute.py +4 -2
  10. dstack/_internal/core/backends/template/configurator.py.jinja +1 -6
  11. dstack/_internal/core/backends/template/models.py.jinja +4 -0
  12. dstack/_internal/core/models/configurations.py +1 -1
  13. dstack/_internal/core/models/fleets.py +6 -1
  14. dstack/_internal/core/models/profiles.py +43 -3
  15. dstack/_internal/core/models/repos/local.py +19 -13
  16. dstack/_internal/core/models/runs.py +78 -45
  17. dstack/_internal/server/background/tasks/process_running_jobs.py +47 -12
  18. dstack/_internal/server/background/tasks/process_runs.py +14 -1
  19. dstack/_internal/server/background/tasks/process_submitted_jobs.py +3 -3
  20. dstack/_internal/server/routers/repos.py +9 -4
  21. dstack/_internal/server/services/fleets.py +2 -2
  22. dstack/_internal/server/services/gateways/__init__.py +1 -1
  23. dstack/_internal/server/services/jobs/__init__.py +4 -4
  24. dstack/_internal/server/services/plugins.py +64 -32
  25. dstack/_internal/server/services/runner/client.py +4 -1
  26. dstack/_internal/server/services/runs.py +2 -2
  27. dstack/_internal/server/services/volumes.py +1 -1
  28. dstack/_internal/server/statics/index.html +1 -1
  29. dstack/_internal/server/statics/{main-b4803049eac16aea9a49.js → main-b0e80f8e26a168c129e9.js} +72 -25
  30. dstack/_internal/server/statics/{main-b4803049eac16aea9a49.js.map → main-b0e80f8e26a168c129e9.js.map} +1 -1
  31. dstack/_internal/server/testing/common.py +2 -1
  32. dstack/_internal/utils/common.py +4 -0
  33. dstack/api/server/_fleets.py +5 -1
  34. dstack/api/server/_runs.py +8 -0
  35. dstack/plugins/builtin/__init__.py +0 -0
  36. dstack/plugins/builtin/rest_plugin/__init__.py +18 -0
  37. dstack/plugins/builtin/rest_plugin/_models.py +48 -0
  38. dstack/plugins/builtin/rest_plugin/_plugin.py +127 -0
  39. dstack/version.py +1 -1
  40. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/METADATA +2 -2
  41. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/RECORD +44 -41
  42. dstack/_internal/utils/ignore.py +0 -92
  43. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/WHEEL +0 -0
  44. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/entry_points.txt +0 -0
  45. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/licenses/LICENSE.md +0 -0
@@ -84,6 +84,8 @@ class OfferCommand(APIBaseCommand):
84
84
  job_plan = run_plan.job_plans[0]
85
85
 
86
86
  if args.format == "json":
87
+ # FIXME: Should use effective_run_spec from run_plan,
88
+ # since the spec can be changed by the server and plugins
87
89
  output = {
88
90
  "project": run_plan.project_name,
89
91
  "user": run_plan.user,
@@ -3,7 +3,7 @@ import subprocess
3
3
  import sys
4
4
  import time
5
5
  from pathlib import Path
6
- from typing import Dict, List, Optional, Set, Tuple
6
+ from typing import Dict, List, Optional, Set
7
7
 
8
8
  import gpuhunt
9
9
  from pydantic import parse_obj_as
@@ -41,7 +41,7 @@ from dstack._internal.core.models.configurations import (
41
41
  )
42
42
  from dstack._internal.core.models.repos.base import Repo
43
43
  from dstack._internal.core.models.resources import CPUSpec
44
- from dstack._internal.core.models.runs import JobSubmission, JobTerminationReason, RunStatus
44
+ from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunStatus
45
45
  from dstack._internal.core.services.configs import ConfigManager
46
46
  from dstack._internal.core.services.diff import diff_models
47
47
  from dstack._internal.utils.common import local_time
@@ -105,7 +105,7 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
105
105
  changed_fields = []
106
106
  if run_plan.action == ApplyAction.UPDATE:
107
107
  diff = diff_models(
108
- run_plan.run_spec.configuration,
108
+ run_plan.get_effective_run_spec().configuration,
109
109
  run_plan.current_resource.run_spec.configuration,
110
110
  )
111
111
  changed_fields = list(diff.keys())
@@ -553,35 +553,38 @@ def _print_service_urls(run: Run) -> None:
553
553
 
554
554
 
555
555
  def print_finished_message(run: Run):
556
+ status_message = (
557
+ run._run.latest_job_submission.status_message
558
+ if run._run.latest_job_submission
559
+ else run._run.status_message
560
+ )
561
+ error = (
562
+ run._run.latest_job_submission.error if run._run.latest_job_submission else run._run.error
563
+ )
564
+ termination_reason = (
565
+ run._run.latest_job_submission.termination_reason
566
+ if run._run.latest_job_submission
567
+ else None
568
+ )
569
+ termination_reason_message = (
570
+ run._run.latest_job_submission.termination_reason_message
571
+ if run._run.latest_job_submission
572
+ else None
573
+ )
556
574
  if run.status == RunStatus.DONE:
557
- console.print("[code]Done[/]")
575
+ console.print(f"[code]{status_message.capitalize()}[/code]")
558
576
  return
577
+ else:
578
+ str = f"[error]{status_message.capitalize()}[/error]"
579
+ if error:
580
+ str += f" ([error]{error.capitalize()}[/error])"
581
+ console.print(str)
559
582
 
560
- termination_reason, termination_reason_message, exit_status = (
561
- _get_run_termination_reason_and_exit_status(run)
562
- )
563
- message = "Run failed due to unknown reason. Check CLI, server, and run logs."
564
- if run.status == RunStatus.TERMINATED:
565
- message = "Run terminated due to unknown reason. Check CLI, server, and run logs."
566
-
567
- if termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY:
568
- message = (
569
- "All provisioning attempts failed. "
570
- "This is likely due to cloud providers not having enough capacity. "
571
- "Check CLI and server logs for more details."
572
- )
573
- elif termination_reason is not None:
574
- exit_status_details = f"Exit status: {exit_status}.\n" if exit_status else ""
575
- error_details = (
576
- f"Error: {termination_reason_message}\n" if termination_reason_message else ""
577
- )
578
- message = (
579
- f"Run failed with error code {termination_reason.name}.\n"
580
- f"{exit_status_details}"
581
- f"{error_details}"
582
- f"Check [bold]dstack logs -d {run.name}[/bold] for more details."
583
- )
584
- console.print(f"[error]{message}[/]")
583
+ if termination_reason_message:
584
+ console.print(f"[error]{termination_reason_message}[/error]")
585
+
586
+ if termination_reason:
587
+ console.print(f"Check [code]dstack logs -d {run.name}[/code] for more details.")
585
588
 
586
589
 
587
590
  def get_run_exit_code(run: Run) -> int:
@@ -590,19 +593,17 @@ def get_run_exit_code(run: Run) -> int:
590
593
  return 1
591
594
 
592
595
 
593
- def _get_run_termination_reason_and_exit_status(
594
- run: Run,
595
- ) -> Tuple[Optional[JobTerminationReason], Optional[str], Optional[int]]:
596
- if len(run._run.jobs) == 0:
597
- return None, None, None
598
- job = run._run.jobs[0]
599
- if len(job.job_submissions) == 0:
600
- return None, None, None
601
- job_submission = job.job_submissions[0]
602
- return (
603
- job_submission.termination_reason,
604
- job_submission.termination_reason_message,
605
- job_submission.exit_status,
596
+ def _is_ready_to_attach(run: Run) -> bool:
597
+ return not (
598
+ run.status
599
+ in [
600
+ RunStatus.SUBMITTED,
601
+ RunStatus.PENDING,
602
+ RunStatus.PROVISIONING,
603
+ RunStatus.TERMINATING,
604
+ ]
605
+ or run._run.jobs[0].job_submissions[-1].status
606
+ in [JobStatus.SUBMITTED, JobStatus.PROVISIONING, JobStatus.PULLING]
606
607
  )
607
608
 
608
609
 
@@ -12,7 +12,6 @@ from dstack._internal.core.models.profiles import (
12
12
  TerminationPolicy,
13
13
  )
14
14
  from dstack._internal.core.models.runs import (
15
- Job,
16
15
  RunPlan,
17
16
  )
18
17
  from dstack._internal.core.services.profiles import get_termination
@@ -154,8 +153,7 @@ def get_runs_table(
154
153
  table.add_column("BACKEND", style="grey58", ratio=2)
155
154
  table.add_column("RESOURCES", ratio=3 if not verbose else 2)
156
155
  if verbose:
157
- table.add_column("INSTANCE", no_wrap=True, ratio=1)
158
- table.add_column("RESERVATION", no_wrap=True, ratio=1)
156
+ table.add_column("INSTANCE TYPE", no_wrap=True, ratio=1)
159
157
  table.add_column("PRICE", style="grey58", ratio=1)
160
158
  table.add_column("STATUS", no_wrap=True, ratio=1)
161
159
  table.add_column("SUBMITTED", style="grey58", no_wrap=True, ratio=1)
@@ -163,14 +161,14 @@ def get_runs_table(
163
161
  table.add_column("ERROR", no_wrap=True, ratio=2)
164
162
 
165
163
  for run in runs:
166
- run_error = _get_run_error(run)
167
164
  run = run._run # TODO(egor-s): make public attribute
168
165
 
169
166
  run_row: Dict[Union[str, int], Any] = {
170
167
  "NAME": run.run_spec.run_name,
171
168
  "SUBMITTED": format_date(run.submitted_at),
172
- "ERROR": run_error,
173
169
  }
170
+ if run.error:
171
+ run_row["ERROR"] = run.error
174
172
  if len(run.jobs) != 1:
175
173
  run_row["STATUS"] = run.status
176
174
  add_row_from_dict(table, run_row)
@@ -183,25 +181,26 @@ def get_runs_table(
183
181
  status += f" (inactive for {inactive_for})"
184
182
  job_row: Dict[Union[str, int], Any] = {
185
183
  "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}",
186
- "STATUS": status,
184
+ "STATUS": latest_job_submission.status_message,
187
185
  "SUBMITTED": format_date(latest_job_submission.submitted_at),
188
- "ERROR": _get_job_error(job),
186
+ "ERROR": latest_job_submission.error,
189
187
  }
190
188
  jpd = latest_job_submission.job_provisioning_data
191
189
  if jpd is not None:
192
190
  resources = jpd.instance_type.resources
193
- instance = jpd.instance_type.name
191
+ instance_type = jpd.instance_type.name
194
192
  jrd = latest_job_submission.job_runtime_data
195
193
  if jrd is not None and jrd.offer is not None:
196
194
  resources = jrd.offer.instance.resources
197
195
  if jrd.offer.total_blocks > 1:
198
- instance += f" ({jrd.offer.blocks}/{jrd.offer.total_blocks})"
196
+ instance_type += f" ({jrd.offer.blocks}/{jrd.offer.total_blocks})"
197
+ if jpd.reservation:
198
+ instance_type += f" ({jpd.reservation})"
199
199
  job_row.update(
200
200
  {
201
201
  "BACKEND": f"{jpd.backend.value.replace('remote', 'ssh')} ({jpd.region})",
202
202
  "RESOURCES": resources.pretty_format(include_spot=True),
203
- "INSTANCE": instance,
204
- "RESERVATION": jpd.reservation,
203
+ "INSTANCE TYPE": instance_type,
205
204
  "PRICE": f"${jpd.price:.4f}".rstrip("0").rstrip("."),
206
205
  }
207
206
  )
@@ -211,18 +210,3 @@ def get_runs_table(
211
210
  add_row_from_dict(table, job_row, style="secondary" if len(run.jobs) != 1 else None)
212
211
 
213
212
  return table
214
-
215
-
216
- def _get_run_error(run: Run) -> str:
217
- return run._run.error or ""
218
-
219
-
220
- def _get_job_error(job: Job) -> str:
221
- job_submission = job.job_submissions[-1]
222
- termination_reason = job_submission.termination_reason
223
- exit_status = job_submission.exit_status
224
- if termination_reason is None:
225
- return ""
226
- if exit_status:
227
- return f"{termination_reason.name} {exit_status}"
228
- return termination_reason.name
@@ -57,10 +57,22 @@ def _is_last_check_time_outdated() -> bool:
57
57
  )
58
58
 
59
59
 
60
+ def is_update_available(current_version: str, latest_version: str) -> bool:
61
+ """
62
+ Return True if latest_version is newer than current_version.
63
+ Pre-releases are only considered if the current version is also a pre-release.
64
+ """
65
+ _current_version = pkg_version.parse(str(current_version))
66
+ _latest_version = pkg_version.parse(str(latest_version))
67
+ return _current_version < _latest_version and (
68
+ not _latest_version.is_prerelease or _current_version.is_prerelease
69
+ )
70
+
71
+
60
72
  def _check_version():
61
73
  latest_version = get_latest_version()
62
74
  if latest_version is not None:
63
- if pkg_version.parse(str(version.__version__)) < pkg_version.parse(latest_version):
75
+ if is_update_available(version.__version__, latest_version):
64
76
  console.print(f"A new version of dstack is available: [code]{latest_version}[/]\n")
65
77
 
66
78
 
@@ -611,9 +611,12 @@ class AWSCompute(
611
611
  raise e
612
612
  logger.debug("Deleted EBS volume %s", volume.configuration.name)
613
613
 
614
- def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
614
+ def attach_volume(
615
+ self, volume: Volume, provisioning_data: JobProvisioningData
616
+ ) -> VolumeAttachmentData:
615
617
  ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
616
618
 
619
+ instance_id = provisioning_data.instance_id
617
620
  device_names = aws_resources.list_available_device_names(
618
621
  ec2_client=ec2_client, instance_id=instance_id
619
622
  )
@@ -646,9 +649,12 @@ class AWSCompute(
646
649
  logger.debug("Attached EBS volume %s to instance %s", volume.volume_id, instance_id)
647
650
  return VolumeAttachmentData(device_name=device_name)
648
651
 
649
- def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
652
+ def detach_volume(
653
+ self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
654
+ ):
650
655
  ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
651
656
 
657
+ instance_id = provisioning_data.instance_id
652
658
  logger.debug("Detaching EBS volume %s from instance %s", volume.volume_id, instance_id)
653
659
  attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
654
660
  try:
@@ -667,9 +673,10 @@ class AWSCompute(
667
673
  raise e
668
674
  logger.debug("Detached EBS volume %s from instance %s", volume.volume_id, instance_id)
669
675
 
670
- def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
676
+ def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool:
671
677
  ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
672
678
 
679
+ instance_id = provisioning_data.instance_id
673
680
  logger.debug("Getting EBS volume %s status", volume.volume_id)
674
681
  response = ec2_client.describe_volumes(VolumeIds=[volume.volume_id])
675
682
  volumes_infos = response.get("Volumes")
@@ -819,18 +826,23 @@ def _get_regions_to_zones(session: boto3.Session, regions: List[str]) -> Dict[st
819
826
 
820
827
  def _supported_instances(offer: InstanceOffer) -> bool:
821
828
  for family in [
829
+ "m7i.",
830
+ "c7i.",
831
+ "r7i.",
832
+ "t3.",
822
833
  "t2.small",
823
834
  "c5.",
824
835
  "m5.",
825
- "g4dn.",
826
- "g5.",
836
+ "p5.",
837
+ "p5e.",
838
+ "p4d.",
839
+ "p4de.",
840
+ "p3.",
827
841
  "g6.",
828
842
  "g6e.",
829
843
  "gr6.",
830
- "p3.",
831
- "p4d.",
832
- "p4de.",
833
- "p5.",
844
+ "g5.",
845
+ "g4dn.",
834
846
  ]:
835
847
  if offer.instance.name.startswith(family):
836
848
  return True
@@ -336,7 +336,9 @@ class ComputeWithVolumeSupport(ABC):
336
336
  """
337
337
  raise NotImplementedError()
338
338
 
339
- def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
339
+ def attach_volume(
340
+ self, volume: Volume, provisioning_data: JobProvisioningData
341
+ ) -> VolumeAttachmentData:
340
342
  """
341
343
  Attaches a volume to the instance.
342
344
  If the volume is not found, it should raise `ComputeError()`.
@@ -345,7 +347,9 @@ class ComputeWithVolumeSupport(ABC):
345
347
  """
346
348
  raise NotImplementedError()
347
349
 
348
- def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
350
+ def detach_volume(
351
+ self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
352
+ ):
349
353
  """
350
354
  Detaches a volume from the instance.
351
355
  Implement only if compute may return `VolumeProvisioningData.detachable`.
@@ -353,7 +357,7 @@ class ComputeWithVolumeSupport(ABC):
353
357
  """
354
358
  raise NotImplementedError()
355
359
 
356
- def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
360
+ def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool:
357
361
  """
358
362
  Checks if a volume was detached from the instance.
359
363
  If `detach_volume()` may fail to detach volume,
@@ -649,13 +649,24 @@ class GCPCompute(
649
649
  pass
650
650
  logger.debug("Deleted persistent disk for volume %s", volume.name)
651
651
 
652
- def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
652
+ def attach_volume(
653
+ self, volume: Volume, provisioning_data: JobProvisioningData
654
+ ) -> VolumeAttachmentData:
655
+ instance_id = provisioning_data.instance_id
653
656
  logger.debug(
654
657
  "Attaching persistent disk for volume %s to instance %s",
655
658
  volume.volume_id,
656
659
  instance_id,
657
660
  )
661
+ if not gcp_resources.instance_type_supports_persistent_disk(
662
+ provisioning_data.instance_type.name
663
+ ):
664
+ raise ComputeError(
665
+ f"Instance type {provisioning_data.instance_type.name} does not support Persistent disk volumes"
666
+ )
667
+
658
668
  zone = get_or_error(volume.provisioning_data).availability_zone
669
+ is_tpu = _is_tpu_provisioning_data(provisioning_data)
659
670
  try:
660
671
  disk = self.disk_client.get(
661
672
  project=self.config.project_id,
@@ -663,18 +674,16 @@ class GCPCompute(
663
674
  disk=volume.volume_id,
664
675
  )
665
676
  disk_url = disk.self_link
677
+ except google.api_core.exceptions.NotFound:
678
+ raise ComputeError("Persistent disk not found")
666
679
 
667
- # This method has no information if the instance is a TPU or a VM,
668
- # so we first try to see if there is a TPU with such name
669
- try:
680
+ try:
681
+ if is_tpu:
670
682
  get_node_request = tpu_v2.GetNodeRequest(
671
683
  name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
672
684
  )
673
685
  tpu_node = self.tpu_client.get_node(get_node_request)
674
- except google.api_core.exceptions.NotFound:
675
- tpu_node = None
676
686
 
677
- if tpu_node is not None:
678
687
  # Python API to attach a disk to a TPU is not documented,
679
688
  # so we follow the code from the gcloud CLI:
680
689
  # https://github.com/twistedpair/google-cloud-sdk/blob/26ab5a281d56b384cc25750f3279a27afe5b499f/google-cloud-sdk/lib/googlecloudsdk/command_lib/compute/tpus/tpu_vm/util.py#L113
@@ -711,7 +720,6 @@ class GCPCompute(
711
720
  attached_disk.auto_delete = False
712
721
  attached_disk.device_name = f"pd-{volume.volume_id}"
713
722
  device_name = attached_disk.device_name
714
-
715
723
  operation = self.instances_client.attach_disk(
716
724
  project=self.config.project_id,
717
725
  zone=zone,
@@ -720,13 +728,16 @@ class GCPCompute(
720
728
  )
721
729
  gcp_resources.wait_for_extended_operation(operation, "persistent disk attachment")
722
730
  except google.api_core.exceptions.NotFound:
723
- raise ComputeError("Persistent disk or instance not found")
731
+ raise ComputeError("Disk or instance not found")
724
732
  logger.debug(
725
733
  "Attached persistent disk for volume %s to instance %s", volume.volume_id, instance_id
726
734
  )
727
735
  return VolumeAttachmentData(device_name=device_name)
728
736
 
729
- def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
737
+ def detach_volume(
738
+ self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
739
+ ):
740
+ instance_id = provisioning_data.instance_id
730
741
  logger.debug(
731
742
  "Detaching persistent disk for volume %s from instance %s",
732
743
  volume.volume_id,
@@ -734,17 +745,16 @@ class GCPCompute(
734
745
  )
735
746
  zone = get_or_error(volume.provisioning_data).availability_zone
736
747
  attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
737
- # This method has no information if the instance is a TPU or a VM,
738
- # so we first try to see if there is a TPU with such name
739
- try:
740
- get_node_request = tpu_v2.GetNodeRequest(
741
- name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
742
- )
743
- tpu_node = self.tpu_client.get_node(get_node_request)
744
- except google.api_core.exceptions.NotFound:
745
- tpu_node = None
748
+ is_tpu = _is_tpu_provisioning_data(provisioning_data)
749
+ if is_tpu:
750
+ try:
751
+ get_node_request = tpu_v2.GetNodeRequest(
752
+ name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
753
+ )
754
+ tpu_node = self.tpu_client.get_node(get_node_request)
755
+ except google.api_core.exceptions.NotFound:
756
+ raise ComputeError("Instance not found")
746
757
 
747
- if tpu_node is not None:
748
758
  source_disk = (
749
759
  f"projects/{self.config.project_id}/zones/{zone}/disks/{volume.volume_id}"
750
760
  )
@@ -815,6 +825,11 @@ def _supported_instances_and_zones(
815
825
  if _is_tpu(offer.instance.name) and not _is_single_host_tpu(offer.instance.name):
816
826
  return False
817
827
  for family in [
828
+ "m4-",
829
+ "c4-",
830
+ "n4-",
831
+ "h3-",
832
+ "n2-",
818
833
  "e2-medium",
819
834
  "e2-standard-",
820
835
  "e2-highmem-",
@@ -1001,3 +1016,11 @@ def _get_tpu_data_disk_for_volume(project_id: str, volume: Volume) -> tpu_v2.Att
1001
1016
  mode=tpu_v2.AttachedDisk.DiskMode.READ_WRITE,
1002
1017
  )
1003
1018
  return attached_disk
1019
+
1020
+
1021
+ def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool:
1022
+ is_tpu = False
1023
+ if provisioning_data.backend_data:
1024
+ backend_data_dict = json.loads(provisioning_data.backend_data)
1025
+ is_tpu = backend_data_dict.get("is_tpu", False)
1026
+ return is_tpu
@@ -140,7 +140,10 @@ def create_instance_struct(
140
140
  initialize_params = compute_v1.AttachedDiskInitializeParams()
141
141
  initialize_params.source_image = image_id
142
142
  initialize_params.disk_size_gb = disk_size
143
- initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
143
+ if instance_type_supports_persistent_disk(machine_type):
144
+ initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
145
+ else:
146
+ initialize_params.disk_type = f"zones/{zone}/diskTypes/hyperdisk-balanced"
144
147
  disk.initialize_params = initialize_params
145
148
  instance.disks = [disk]
146
149
 
@@ -421,7 +424,7 @@ def wait_for_extended_operation(
421
424
 
422
425
  if operation.error_code:
423
426
  # Write only debug logs here.
424
- # The unexpected errors will be propagated and logged appropriatly by the caller.
427
+ # The unexpected errors will be propagated and logged appropriately by the caller.
425
428
  logger.debug(
426
429
  "Error during %s: [Code: %s]: %s",
427
430
  verbose_name,
@@ -462,3 +465,16 @@ def get_placement_policy_resource_name(
462
465
  placement_policy: str,
463
466
  ) -> str:
464
467
  return f"projects/{project_id}/regions/{region}/resourcePolicies/{placement_policy}"
468
+
469
+
470
+ def instance_type_supports_persistent_disk(instance_type_name: str) -> bool:
471
+ return not any(
472
+ instance_type_name.startswith(series)
473
+ for series in [
474
+ "m4-",
475
+ "c4-",
476
+ "n4-",
477
+ "h3-",
478
+ "v6e",
479
+ ]
480
+ )
@@ -110,8 +110,10 @@ class LocalCompute(
110
110
  def delete_volume(self, volume: Volume):
111
111
  pass
112
112
 
113
- def attach_volume(self, volume: Volume, instance_id: str):
113
+ def attach_volume(self, volume: Volume, provisioning_data: JobProvisioningData):
114
114
  pass
115
115
 
116
- def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
116
+ def detach_volume(
117
+ self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
118
+ ):
117
119
  pass
@@ -19,9 +19,6 @@ from dstack._internal.core.models.backends.base import (
19
19
  BackendType,
20
20
  )
21
21
 
22
- # TODO: Add all supported regions and default regions
23
- REGIONS = []
24
-
25
22
 
26
23
  class {{ backend_name }}Configurator(Configurator):
27
24
  TYPE = BackendType.{{ backend_name|upper }}
@@ -31,13 +28,11 @@ class {{ backend_name }}Configurator(Configurator):
31
28
  self, config: {{ backend_name }}BackendConfigWithCreds, default_creds_enabled: bool
32
29
  ):
33
30
  self._validate_creds(config.creds)
34
- # TODO: Validate additional config parameters if any
31
+ # TODO: If possible, validate config.regions and any other config parameters
35
32
 
36
33
  def create_backend(
37
34
  self, project_name: str, config: {{ backend_name }}BackendConfigWithCreds
38
35
  ) -> BackendRecord:
39
- if config.regions is None:
40
- config.regions = REGIONS
41
36
  return BackendRecord(
42
37
  config={{ backend_name }}StoredConfig(
43
38
  **{{ backend_name }}BackendConfig.__response__.parse_obj(config).dict()
@@ -22,6 +22,7 @@ class {{ backend_name }}BackendConfig(CoreModel):
22
22
  It also serves as a base class for other backend config models.
23
23
  Should not include creds.
24
24
  """
25
+
25
26
  type: Annotated[
26
27
  Literal["{{ backend_name|lower }}"],
27
28
  Field(description="The type of backend"),
@@ -37,6 +38,7 @@ class {{ backend_name }}BackendConfigWithCreds({{ backend_name }}BackendConfig):
37
38
  """
38
39
  Same as `{{ backend_name }}BackendConfig` but also includes creds.
39
40
  """
41
+
40
42
  creds: Annotated[Any{{ backend_name }}Creds, Field(description="The credentials")]
41
43
 
42
44
 
@@ -48,6 +50,7 @@ class {{ backend_name }}StoredConfig({{ backend_name }}BackendConfig):
48
50
  The backend config used for config parameters in the DB.
49
51
  Can extend `{{ backend_name }}BackendConfig` with additional parameters.
50
52
  """
53
+
51
54
  pass
52
55
 
53
56
 
@@ -55,4 +58,5 @@ class {{ backend_name }}Config({{ backend_name }}StoredConfig):
55
58
  """
56
59
  The backend config used by `{{ backend_name }}Backend` and `{{ backend_name }}Compute`.
57
60
  """
61
+
58
62
  creds: Any{{ backend_name }}Creds
@@ -440,7 +440,7 @@ class ServiceConfigurationParams(CoreModel):
440
440
  raise ValueError("The minimum number of replicas must be greater than or equal to 0")
441
441
  if v.max < v.min:
442
442
  raise ValueError(
443
- "The maximum number of replicas must be greater than or equal to the minium number of replicas"
443
+ "The maximum number of replicas must be greater than or equal to the minimum number of replicas"
444
444
  )
445
445
  return v
446
446
 
@@ -20,6 +20,7 @@ from dstack._internal.core.models.profiles import (
20
20
  parse_idle_duration,
21
21
  )
22
22
  from dstack._internal.core.models.resources import Range, ResourcesSpec
23
+ from dstack._internal.utils.common import list_enum_values_for_annotation
23
24
  from dstack._internal.utils.json_schema import add_extra_schema_types
24
25
  from dstack._internal.utils.tags import tags_validator
25
26
 
@@ -207,7 +208,11 @@ class InstanceGroupParams(CoreModel):
207
208
  spot_policy: Annotated[
208
209
  Optional[SpotPolicy],
209
210
  Field(
210
- description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`"
211
+ description=(
212
+ "The policy for provisioning spot or on-demand instances:"
213
+ f" {list_enum_values_for_annotation(SpotPolicy)}."
214
+ f" Defaults to `{SpotPolicy.ONDEMAND.value}`"
215
+ )
211
216
  ),
212
217
  ] = None
213
218
  retry: Annotated[