dstack 0.19.1__py3-none-any.whl → 0.19.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. Click here for more details.

Files changed (68) hide show
  1. dstack/_internal/cli/commands/metrics.py +138 -0
  2. dstack/_internal/cli/commands/stats.py +5 -119
  3. dstack/_internal/cli/main.py +2 -0
  4. dstack/_internal/cli/services/profile.py +9 -0
  5. dstack/_internal/core/backends/aws/configurator.py +1 -0
  6. dstack/_internal/core/backends/base/compute.py +4 -1
  7. dstack/_internal/core/backends/base/models.py +7 -7
  8. dstack/_internal/core/backends/configurators.py +9 -0
  9. dstack/_internal/core/backends/cudo/compute.py +2 -0
  10. dstack/_internal/core/backends/cudo/configurator.py +0 -13
  11. dstack/_internal/core/backends/datacrunch/compute.py +118 -32
  12. dstack/_internal/core/backends/datacrunch/configurator.py +16 -11
  13. dstack/_internal/core/backends/gcp/compute.py +140 -26
  14. dstack/_internal/core/backends/gcp/configurator.py +2 -0
  15. dstack/_internal/core/backends/gcp/features/__init__.py +0 -0
  16. dstack/_internal/core/backends/gcp/features/tcpx.py +34 -0
  17. dstack/_internal/core/backends/gcp/models.py +13 -1
  18. dstack/_internal/core/backends/gcp/resources.py +64 -27
  19. dstack/_internal/core/backends/lambdalabs/compute.py +2 -4
  20. dstack/_internal/core/backends/lambdalabs/configurator.py +0 -21
  21. dstack/_internal/core/backends/models.py +8 -0
  22. dstack/_internal/core/backends/nebius/__init__.py +0 -0
  23. dstack/_internal/core/backends/nebius/backend.py +16 -0
  24. dstack/_internal/core/backends/nebius/compute.py +272 -0
  25. dstack/_internal/core/backends/nebius/configurator.py +74 -0
  26. dstack/_internal/core/backends/nebius/models.py +108 -0
  27. dstack/_internal/core/backends/nebius/resources.py +240 -0
  28. dstack/_internal/core/backends/tensordock/api_client.py +5 -4
  29. dstack/_internal/core/backends/tensordock/compute.py +2 -15
  30. dstack/_internal/core/errors.py +14 -0
  31. dstack/_internal/core/models/backends/base.py +2 -0
  32. dstack/_internal/core/models/profiles.py +3 -0
  33. dstack/_internal/proxy/lib/schemas/model_proxy.py +3 -3
  34. dstack/_internal/server/background/tasks/process_instances.py +12 -7
  35. dstack/_internal/server/background/tasks/process_running_jobs.py +20 -0
  36. dstack/_internal/server/background/tasks/process_submitted_jobs.py +3 -2
  37. dstack/_internal/server/routers/prometheus.py +5 -0
  38. dstack/_internal/server/security/permissions.py +19 -1
  39. dstack/_internal/server/services/instances.py +14 -6
  40. dstack/_internal/server/services/jobs/__init__.py +3 -3
  41. dstack/_internal/server/services/offers.py +4 -2
  42. dstack/_internal/server/services/runs.py +0 -2
  43. dstack/_internal/server/statics/index.html +1 -1
  44. dstack/_internal/server/statics/{main-da9f8c06a69c20dac23e.css → main-8f9c66f404e9c7e7e020.css} +1 -1
  45. dstack/_internal/server/statics/{main-4a0fe83e84574654e397.js → main-e190de603dc1e9f485ec.js} +7306 -149
  46. dstack/_internal/server/statics/{main-4a0fe83e84574654e397.js.map → main-e190de603dc1e9f485ec.js.map} +1 -1
  47. dstack/_internal/utils/common.py +8 -2
  48. dstack/_internal/utils/event_loop.py +30 -0
  49. dstack/_internal/utils/ignore.py +2 -0
  50. dstack/api/server/_fleets.py +3 -5
  51. dstack/api/server/_runs.py +6 -7
  52. dstack/version.py +1 -1
  53. {dstack-0.19.1.dist-info → dstack-0.19.3.dist-info}/METADATA +27 -11
  54. {dstack-0.19.1.dist-info → dstack-0.19.3.dist-info}/RECORD +67 -57
  55. tests/_internal/core/backends/datacrunch/test_configurator.py +6 -2
  56. tests/_internal/server/background/tasks/test_process_instances.py +4 -2
  57. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +29 -0
  58. tests/_internal/server/routers/test_backends.py +116 -0
  59. tests/_internal/server/routers/test_fleets.py +2 -0
  60. tests/_internal/server/routers/test_prometheus.py +21 -0
  61. tests/_internal/server/routers/test_runs.py +4 -0
  62. tests/_internal/utils/test_common.py +16 -1
  63. tests/_internal/utils/test_event_loop.py +18 -0
  64. dstack/_internal/core/backends/datacrunch/api_client.py +0 -77
  65. {dstack-0.19.1.dist-info → dstack-0.19.3.dist-info}/LICENSE.md +0 -0
  66. {dstack-0.19.1.dist-info → dstack-0.19.3.dist-info}/WHEEL +0 -0
  67. {dstack-0.19.1.dist-info → dstack-0.19.3.dist-info}/entry_points.txt +0 -0
  68. {dstack-0.19.1.dist-info → dstack-0.19.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,138 @@
1
+ import argparse
2
+ import time
3
+ from typing import Any, Dict, List, Optional, Union
4
+
5
+ from rich.live import Live
6
+ from rich.table import Table
7
+
8
+ from dstack._internal.cli.commands import APIBaseCommand
9
+ from dstack._internal.cli.services.completion import RunNameCompleter
10
+ from dstack._internal.cli.utils.common import (
11
+ LIVE_TABLE_PROVISION_INTERVAL_SECS,
12
+ LIVE_TABLE_REFRESH_RATE_PER_SEC,
13
+ add_row_from_dict,
14
+ console,
15
+ )
16
+ from dstack._internal.core.errors import CLIError
17
+ from dstack._internal.core.models.instances import Resources
18
+ from dstack._internal.core.models.metrics import JobMetrics
19
+ from dstack.api._public import Client
20
+ from dstack.api._public.runs import Run
21
+
22
+
23
class MetricsCommand(APIBaseCommand):
    """CLI command that prints a table of metrics for a run that is not finished.

    With ``-w``/``--watch`` the table is redrawn in place until the user
    interrupts the command or the run can no longer be shown.
    """

    NAME = "metrics"
    DESCRIPTION = "Show run metrics"

    def _register(self):
        """Add the positional ``run_name`` argument and the ``--watch`` flag."""
        super()._register()
        run_name_argument = self._parser.add_argument("run_name")
        run_name_argument.completer = RunNameCompleter()
        self._parser.add_argument(
            "-w",
            "--watch",
            action="store_true",
            help="Watch run metrics in realtime",
        )

    def _command(self, args: argparse.Namespace):
        """Print the metrics table once, or keep refreshing it in watch mode.

        Raises:
            CLIError: if the run does not exist or is already finished, both on
                the first lookup and on any later refresh in watch mode.
        """
        super()._command(args)
        run, metrics = self._load_active_run(args.run_name)

        if args.watch:
            try:
                with Live(
                    console=console, refresh_per_second=LIVE_TABLE_REFRESH_RATE_PER_SEC
                ) as live:
                    while True:
                        live.update(_get_metrics_table(run, metrics))
                        time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
                        run, metrics = self._load_active_run(args.run_name)
            except KeyboardInterrupt:
                # Ctrl+C is the normal way to leave watch mode.
                pass
            return

        console.print(_get_metrics_table(run, metrics))

    def _load_active_run(self, run_name: str):
        """Fetch the run named ``run_name`` together with the metrics of its jobs.

        Raises:
            CLIError: if the run is not found or is finished.
        """
        run = self.api.runs.get(run_name=run_name)
        if run is None:
            raise CLIError(f"Run {run_name} not found")
        if run.status.is_finished():
            raise CLIError(f"Run {run_name} is finished")
        return run, _get_run_jobs_metrics(api=self.api, run=run)
63
+
64
+
65
def _get_run_jobs_metrics(api: Client, run: Run) -> List[JobMetrics]:
    """Request metrics for every job of ``run``.

    One ``get_job_metrics`` call is made per job; the returned list follows the
    order of ``run._run.jobs``.
    """
    return [
        api.client.metrics.get_job_metrics(
            project_name=api.project,
            run_name=run.name,
            replica_num=job.job_spec.replica_num,
            job_num=job.job_spec.job_num,
        )
        for job in run._run.jobs
    ]
76
+
77
+
78
def _get_metrics_table(run: Run, metrics: List[JobMetrics]) -> Table:
    """Build the table of per-job CPU, memory, and GPU usage for ``run``.

    Args:
        run: the run whose jobs are listed.
        metrics: metrics of the jobs, paired positionally with ``run._run.jobs``.

    Returns:
        A table with NAME, CPU, MEMORY, and GPU columns. A run with exactly one
        job gets a single row named after the run; otherwise a row with the run
        name is followed by one row per job. Missing values are shown as "-".
    """
    table = Table(box=None)
    table.add_column("NAME", style="bold", no_wrap=True)
    table.add_column("CPU")
    table.add_column("MEMORY")
    table.add_column("GPU")

    jobs = run._run.jobs
    single_job = len(jobs) == 1
    run_row: Dict[Union[str, int], Any] = {"NAME": run.name}
    if not single_job:
        add_row_from_dict(table, run_row)

    for job, job_metrics in zip(jobs, metrics):
        resources = _get_job_resources(job)
        job_row: Dict[Union[str, int], Any] = {
            "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}",
            "CPU": _format_cpu_usage(job_metrics, resources) or "-",
            "MEMORY": _format_memory_usage(job_metrics, resources) or "-",
            "GPU": _format_gpu_metrics(job_metrics, resources) or "-",
        }
        if single_job:
            # The only job is shown under the run name instead of replica/job numbers.
            job_row.update(run_row)
        add_row_from_dict(table, job_row)

    return table


def _get_job_resources(job: Any) -> Optional[Resources]:
    """Return the resources of the job's latest submission, or None if unknown."""
    if not job.job_submissions:
        return None
    latest_submission = job.job_submissions[-1]
    jrd = latest_submission.job_runtime_data
    # Runtime data is preferred over provisioning data when it carries an offer.
    if jrd is not None and jrd.offer is not None:
        return jrd.offer.instance.resources
    jpd = latest_submission.job_provisioning_data
    if jpd is not None:
        return jpd.instance_type.resources
    return None


def _format_cpu_usage(job_metrics: JobMetrics, resources: Optional[Resources]) -> Optional[str]:
    """Format CPU usage as a whole percentage, divided by the CPU count when known."""
    cpu_usage = _get_metric_value(job_metrics, "cpu_usage_percent")
    if cpu_usage is None:
        return None
    # Skip the division when the CPU count is unknown or zero to avoid ZeroDivisionError.
    if resources is not None and resources.cpus:
        cpu_usage = cpu_usage / resources.cpus
    return f"{cpu_usage:.0f}%"


def _format_memory_usage(
    job_metrics: JobMetrics, resources: Optional[Resources]
) -> Optional[str]:
    """Format the memory working set in MB, followed by total memory when known."""
    memory_usage = _get_metric_value(job_metrics, "memory_working_set_bytes")
    if memory_usage is None:
        return None
    text = f"{round(memory_usage / 1024 / 1024)}MB"
    if resources is not None:
        text += f"/{resources.memory_mib}MB"
    return text


def _format_gpu_metrics(job_metrics: JobMetrics, resources: Optional[Resources]) -> str:
    """Format one line per detected GPU that has a memory usage sample."""
    gpus_detected_num = _get_metric_value(job_metrics, "gpus_detected_num")
    if gpus_detected_num is None:
        return ""
    known_gpus = (resources.gpus or ()) if resources is not None else ()
    lines = []
    for i in range(gpus_detected_num):
        gpu_memory_usage = _get_metric_value(job_metrics, f"gpu_memory_usage_bytes_gpu{i}")
        gpu_util_percent = _get_metric_value(job_metrics, f"gpu_util_percent_gpu{i}")
        if gpu_memory_usage is None:
            continue
        line = f"#{i} {round(gpu_memory_usage / 1024 / 1024)}MB"
        # More GPUs may be detected than the resources list describes.
        if i < len(known_gpus):
            line += f"/{known_gpus[i].memory_mib}MB"
        if gpu_util_percent is not None:
            line += f" {gpu_util_percent}% Util"
        lines.append(line)
    # Joining avoids a stray leading newline when GPU 0 has no sample.
    return "\n".join(lines)
132
+
133
+
134
def _get_metric_value(job_metrics: JobMetrics, name: str) -> Optional[Any]:
    """Return the last recorded value of the metric called ``name``.

    Args:
        job_metrics: object whose ``metrics`` items expose ``name`` and ``values``.
        name: metric name to look up.

    Returns:
        The final element of the first matching metric's ``values``
        (presumably the most recent sample), or None when the metric is
        absent or has no values.
    """
    for metric in job_metrics.metrics:
        if metric.name == name:
            # An empty series has no last value; treat it as missing instead
            # of raising IndexError.
            return metric.values[-1] if metric.values else None
    return None
@@ -1,128 +1,14 @@
1
1
  import argparse
2
- import time
3
- from typing import Any, Dict, List, Optional, Union
4
2
 
5
- from rich.live import Live
6
- from rich.table import Table
3
+ from dstack._internal.cli.commands.metrics import MetricsCommand
4
+ from dstack._internal.utils.logging import get_logger
7
5
 
8
- from dstack._internal.cli.commands import APIBaseCommand
9
- from dstack._internal.cli.services.completion import RunNameCompleter
10
- from dstack._internal.cli.utils.common import (
11
- LIVE_TABLE_PROVISION_INTERVAL_SECS,
12
- LIVE_TABLE_REFRESH_RATE_PER_SEC,
13
- add_row_from_dict,
14
- console,
15
- )
16
- from dstack._internal.core.errors import CLIError
17
- from dstack._internal.core.models.metrics import JobMetrics
18
- from dstack.api._public import Client
19
- from dstack.api._public.runs import Run
6
+ logger = get_logger(__name__)
20
7
 
21
8
 
22
- class StatsCommand(APIBaseCommand):
9
+ class StatsCommand(MetricsCommand):
23
10
  NAME = "stats"
24
- DESCRIPTION = "Show run stats"
25
-
26
- def _register(self):
27
- super()._register()
28
- self._parser.add_argument("run_name").completer = RunNameCompleter()
29
- self._parser.add_argument(
30
- "-w",
31
- "--watch",
32
- help="Watch run stats in realtime",
33
- action="store_true",
34
- )
35
11
 
36
12
  def _command(self, args: argparse.Namespace):
13
+ logger.warning("`dstack stats` is deprecated in favor of `dstack metrics`")
37
14
  super()._command(args)
38
- run = self.api.runs.get(run_name=args.run_name)
39
- if run is None:
40
- raise CLIError(f"Run {args.run_name} not found")
41
- if run.status.is_finished():
42
- raise CLIError(f"Run {args.run_name} is finished")
43
- metrics = _get_run_jobs_metrics(api=self.api, run=run)
44
-
45
- if not args.watch:
46
- console.print(_get_stats_table(run, metrics))
47
- return
48
-
49
- try:
50
- with Live(console=console, refresh_per_second=LIVE_TABLE_REFRESH_RATE_PER_SEC) as live:
51
- while True:
52
- live.update(_get_stats_table(run, metrics))
53
- time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
54
- run = self.api.runs.get(run_name=args.run_name)
55
- if run is None:
56
- raise CLIError(f"Run {args.run_name} not found")
57
- if run.status.is_finished():
58
- raise CLIError(f"Run {args.run_name} is finished")
59
- metrics = _get_run_jobs_metrics(api=self.api, run=run)
60
- except KeyboardInterrupt:
61
- pass
62
-
63
-
64
- def _get_run_jobs_metrics(api: Client, run: Run) -> List[JobMetrics]:
65
- metrics = []
66
- for job in run._run.jobs:
67
- job_metrics = api.client.metrics.get_job_metrics(
68
- project_name=api.project,
69
- run_name=run.name,
70
- replica_num=job.job_spec.replica_num,
71
- job_num=job.job_spec.job_num,
72
- )
73
- metrics.append(job_metrics)
74
- return metrics
75
-
76
-
77
- def _get_stats_table(run: Run, metrics: List[JobMetrics]) -> Table:
78
- table = Table(box=None)
79
- table.add_column("NAME", style="bold", no_wrap=True)
80
- table.add_column("CPU")
81
- table.add_column("MEMORY")
82
- table.add_column("GPU")
83
-
84
- run_row: Dict[Union[str, int], Any] = {"NAME": run.name}
85
- if len(run._run.jobs) != 1:
86
- add_row_from_dict(table, run_row)
87
-
88
- for job, job_metrics in zip(run._run.jobs, metrics):
89
- cpu_usage = _get_metric_value(job_metrics, "cpu_usage_percent")
90
- if cpu_usage is not None:
91
- cpu_usage = f"{cpu_usage}%"
92
- memory_usage = _get_metric_value(job_metrics, "memory_working_set_bytes")
93
- if memory_usage is not None:
94
- memory_usage = f"{round(memory_usage / 1024 / 1024)}MB"
95
- if job.job_submissions[-1].job_provisioning_data is not None:
96
- memory_usage += f"/{job.job_submissions[-1].job_provisioning_data.instance_type.resources.memory_mib}MB"
97
- gpu_stats = ""
98
- gpus_detected_num = _get_metric_value(job_metrics, "gpus_detected_num")
99
- if gpus_detected_num is not None:
100
- for i in range(gpus_detected_num):
101
- gpu_memory_usage = _get_metric_value(job_metrics, f"gpu_memory_usage_bytes_gpu{i}")
102
- gpu_util_percent = _get_metric_value(job_metrics, f"gpu_util_percent_gpu{i}")
103
- if gpu_memory_usage is not None:
104
- if i != 0:
105
- gpu_stats += "\n"
106
- gpu_stats += f"#{i} {round(gpu_memory_usage / 1024 / 1024)}MB"
107
- if job.job_submissions[-1].job_provisioning_data is not None:
108
- gpu_stats += f"/{job.job_submissions[-1].job_provisioning_data.instance_type.resources.gpus[i].memory_mib}MB"
109
- gpu_stats += f" {gpu_util_percent}% Util"
110
-
111
- job_row: Dict[Union[str, int], Any] = {
112
- "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}",
113
- "CPU": cpu_usage or "-",
114
- "MEMORY": memory_usage or "-",
115
- "GPU": gpu_stats or "-",
116
- }
117
- if len(run._run.jobs) == 1:
118
- job_row.update(run_row)
119
- add_row_from_dict(table, job_row)
120
-
121
- return table
122
-
123
-
124
- def _get_metric_value(job_metrics: JobMetrics, name: str) -> Optional[Any]:
125
- for metric in job_metrics.metrics:
126
- if metric.name == name:
127
- return metric.values[-1]
128
- return None
@@ -13,6 +13,7 @@ from dstack._internal.cli.commands.fleet import FleetCommand
13
13
  from dstack._internal.cli.commands.gateway import GatewayCommand
14
14
  from dstack._internal.cli.commands.init import InitCommand
15
15
  from dstack._internal.cli.commands.logs import LogsCommand
16
+ from dstack._internal.cli.commands.metrics import MetricsCommand
16
17
  from dstack._internal.cli.commands.ps import PsCommand
17
18
  from dstack._internal.cli.commands.server import ServerCommand
18
19
  from dstack._internal.cli.commands.stats import StatsCommand
@@ -65,6 +66,7 @@ def main():
65
66
  GatewayCommand.register(subparsers)
66
67
  InitCommand.register(subparsers)
67
68
  LogsCommand.register(subparsers)
69
+ MetricsCommand.register(subparsers)
68
70
  PsCommand.register(subparsers)
69
71
  ServerCommand.register(subparsers)
70
72
  StatsCommand.register(subparsers)
@@ -65,6 +65,13 @@ def register_profile_args(parser: argparse.ArgumentParser):
65
65
  )
66
66
 
67
67
  fleets_group = parser.add_argument_group("Fleets")
68
+ fleets_group.add_argument(
69
+ "--fleet",
70
+ action="append",
71
+ metavar="NAME",
72
+ dest="fleets",
73
+ help="Consider only instances from the specified fleet(s) for reuse",
74
+ )
68
75
  fleets_group_exc = fleets_group.add_mutually_exclusive_group()
69
76
  fleets_group_exc.add_argument(
70
77
  "-R",
@@ -147,6 +154,8 @@ def apply_profile_args(
147
154
  if args.max_duration is not None:
148
155
  profile_settings.max_duration = args.max_duration
149
156
 
157
+ if args.fleets:
158
+ profile_settings.fleets = args.fleets
150
159
  if args.idle_duration is not None:
151
160
  profile_settings.idle_duration = args.idle_duration
152
161
  elif args.dont_destroy:
@@ -34,6 +34,7 @@ from dstack._internal.utils.logging import get_logger
34
34
 
35
35
  logger = get_logger(__name__)
36
36
 
37
+ # where dstack OS images are published
37
38
  REGIONS = [
38
39
  ("US East, N. Virginia", "us-east-1"),
39
40
  ("US East, Ohio", "us-east-2"),
@@ -94,6 +94,9 @@ class Compute(ABC):
94
94
  """
95
95
  Terminates an instance by `instance_id`.
96
96
  If the instance does not exist, it should not raise errors but return silently.
97
+
98
+ Should return ASAP. If required to wait for some operation, raise `NotYetTerminated`.
99
+ In this case, the method will be called again after a few seconds.
97
100
  """
98
101
  pass
99
102
 
@@ -525,7 +528,7 @@ def get_run_shim_script(is_privileged: bool, pjrt_device: Optional[str]) -> List
525
528
  pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else ""
526
529
 
527
530
  return [
528
- f"nohup dstack-shim {privileged_flag} {pjrt_device_env} &",
531
+ f"nohup {DSTACK_SHIM_BINARY_PATH} {privileged_flag} {pjrt_device_env} &",
529
532
  ]
530
533
 
531
534
 
@@ -1,14 +1,14 @@
1
1
  from pathlib import Path
2
2
 
3
3
 
4
def fill_data(values: dict, filename_field: str = "filename", data_field: str = "data") -> dict:
    """Populate ``values[data_field]`` from the file named by ``values[filename_field]``.

    Args:
        values: mapping to complete; it is modified in place and returned.
        filename_field: key holding the path of the file to read.
        data_field: key that receives the file contents.

    Returns:
        ``values`` unchanged when ``data_field`` is already set to a non-None
        value; otherwise ``values`` with ``data_field`` set to the file's text.

    Raises:
        ValueError: if neither field is set, or if the file cannot be read
            (the underlying ``OSError`` is attached as the cause).
    """
    if values.get(data_field) is not None:
        return values
    if (filename := values.get(filename_field)) is None:
        raise ValueError(f"Either `{filename_field}` or `{data_field}` must be specified")
    try:
        # expanduser() lets the path start with "~".
        with open(Path(filename).expanduser()) as f:
            values[data_field] = f.read()
    except OSError as e:
        # Chain explicitly so the real I/O failure stays visible to the caller.
        raise ValueError(f"No such file {filename}") from e
    return values
@@ -63,6 +63,15 @@ try:
63
63
  except ImportError:
64
64
  pass
65
65
 
66
+ try:
67
+ from dstack._internal.core.backends.nebius.configurator import (
68
+ NebiusConfigurator,
69
+ )
70
+
71
+ _CONFIGURATOR_CLASSES.append(NebiusConfigurator)
72
+ except ImportError:
73
+ pass
74
+
66
75
  try:
67
76
  from dstack._internal.core.backends.oci.configurator import OCIConfigurator
68
77
 
@@ -41,6 +41,7 @@ class CudoCompute(
41
41
  ) -> List[InstanceOfferWithAvailability]:
42
42
  offers = get_catalog_offers(
43
43
  backend=BackendType.CUDO,
44
+ locations=self.config.regions,
44
45
  requirements=requirements,
45
46
  )
46
47
  offers = [
@@ -48,6 +49,7 @@ class CudoCompute(
48
49
  **offer.dict(), availability=InstanceAvailability.AVAILABLE
49
50
  )
50
51
  for offer in offers
52
+ # in-hyderabad-1 is known to have provisioning issues
51
53
  if offer.region not in ["in-hyderabad-1"]
52
54
  ]
53
55
  return offers
@@ -17,17 +17,6 @@ from dstack._internal.core.backends.cudo.models import (
17
17
  )
18
18
  from dstack._internal.core.models.backends.base import BackendType
19
19
 
20
- REGIONS = [
21
- "no-luster-1",
22
- "se-smedjebacken-1",
23
- "gb-london-1",
24
- "se-stockholm-1",
25
- "us-newyork-1",
26
- "us-santaclara-1",
27
- ]
28
-
29
- DEFAULT_REGION = "no-luster-1"
30
-
31
20
 
32
21
  class CudoConfigurator(Configurator):
33
22
  TYPE = BackendType.CUDO
@@ -39,8 +28,6 @@ class CudoConfigurator(Configurator):
39
28
  def create_backend(
40
29
  self, project_name: str, config: CudoBackendConfigWithCreds
41
30
  ) -> BackendRecord:
42
- if config.regions is None:
43
- config.regions = REGIONS
44
31
  return BackendRecord(
45
32
  config=CudoStoredConfig(
46
33
  **CudoBackendConfig.__response__.parse_obj(config).dict()
@@ -1,5 +1,9 @@
1
1
  from typing import Dict, List, Optional
2
2
 
3
+ from datacrunch import DataCrunchClient
4
+ from datacrunch.exceptions import APIException
5
+ from datacrunch.instances.instances import Instance
6
+
3
7
  from dstack._internal.core.backends.base.backend import Compute
4
8
  from dstack._internal.core.backends.base.compute import (
5
9
  ComputeWithCreateInstanceSupport,
@@ -7,8 +11,8 @@ from dstack._internal.core.backends.base.compute import (
7
11
  get_shim_commands,
8
12
  )
9
13
  from dstack._internal.core.backends.base.offers import get_catalog_offers
10
- from dstack._internal.core.backends.datacrunch.api_client import DataCrunchAPIClient
11
14
  from dstack._internal.core.backends.datacrunch.models import DataCrunchConfig
15
+ from dstack._internal.core.errors import NoCapacityError
12
16
  from dstack._internal.core.models.backends.base import BackendType
13
17
  from dstack._internal.core.models.instances import (
14
18
  InstanceAvailability,
@@ -19,14 +23,12 @@ from dstack._internal.core.models.instances import (
19
23
  from dstack._internal.core.models.resources import Memory, Range
20
24
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
21
25
  from dstack._internal.utils.logging import get_logger
26
+ from dstack._internal.utils.ssh import get_public_key_fingerprint
22
27
 
23
28
  logger = get_logger("datacrunch.compute")
24
29
 
25
30
  MAX_INSTANCE_NAME_LEN = 60
26
31
 
27
- # Ubuntu 22.04 + CUDA 12.0 + Docker
28
- # from API https://datacrunch.stoplight.io/docs/datacrunch-public/c46ab45dbc508-get-all-image-types
29
- IMAGE_ID = "2088da25-bb0d-41cc-a191-dccae45d96fd"
30
32
  IMAGE_SIZE = Memory.parse("50GB")
31
33
 
32
34
  CONFIGURABLE_DISK_SIZE = Range[Memory](min=IMAGE_SIZE, max=None)
@@ -39,7 +41,10 @@ class DataCrunchCompute(
39
41
  def __init__(self, config: DataCrunchConfig):
40
42
  super().__init__()
41
43
  self.config = config
42
- self.api_client = DataCrunchAPIClient(config.creds.client_id, config.creds.client_secret)
44
+ self.client = DataCrunchClient(
45
+ client_id=self.config.creds.client_id,
46
+ client_secret=self.config.creds.client_secret,
47
+ )
43
48
 
44
49
  def get_offers(
45
50
  self, requirements: Optional[Requirements] = None
@@ -56,14 +61,12 @@ class DataCrunchCompute(
56
61
  def _get_offers_with_availability(
57
62
  self, offers: List[InstanceOffer]
58
63
  ) -> List[InstanceOfferWithAvailability]:
59
- raw_availabilities: List[Dict] = self.api_client.client.instances.get_availabilities()
64
+ raw_availabilities: List[Dict] = self.client.instances.get_availabilities()
60
65
 
61
66
  region_availabilities = {}
62
67
  for location in raw_availabilities:
63
68
  location_code = location["location_code"]
64
69
  availabilities = location["availabilities"]
65
- if location_code not in self.config.regions:
66
- continue
67
70
  for name in availabilities:
68
71
  key = (name, location_code)
69
72
  region_availabilities[key] = InstanceAvailability.AVAILABLE
@@ -91,50 +94,50 @@ class DataCrunchCompute(
91
94
  for ssh_public_key in public_keys:
92
95
  ssh_ids.append(
93
96
  # datacrunch allows you to use the same name
94
- self.api_client.get_or_create_ssh_key(
97
+ _get_or_create_ssh_key(
98
+ client=self.client,
95
99
  name=f"dstack-{instance_config.instance_name}.key",
96
100
  public_key=ssh_public_key,
97
101
  )
98
102
  )
99
103
 
100
104
  commands = get_shim_commands(authorized_keys=public_keys)
101
-
102
105
  startup_script = " ".join([" && ".join(commands)])
103
106
  script_name = f"dstack-{instance_config.instance_name}.sh"
104
-
105
- logger.debug("startup script:", startup_script)
106
-
107
- startup_script_ids = self.api_client.get_or_create_startup_scrpit(
108
- name=script_name, script=startup_script
107
+ startup_script_ids = _get_or_create_startup_scrpit(
108
+ client=self.client,
109
+ name=script_name,
110
+ script=startup_script,
109
111
  )
110
112
 
111
113
  disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
112
-
113
- instance = self.api_client.deploy_instance(
114
- instance_type=instance_offer.instance.name,
115
- ssh_key_ids=ssh_ids,
116
- startup_script_id=startup_script_ids,
117
- hostname=instance_name,
118
- description=instance_name,
119
- image=IMAGE_ID,
120
- disk_size=disk_size,
121
- location=instance_offer.region,
122
- )
114
+ image_id = _get_vm_image_id(instance_offer)
123
115
 
124
116
  logger.debug(
125
- "deploy_instance",
117
+ "Deploying datacrunch instance",
126
118
  {
127
119
  "instance_type": instance_offer.instance.name,
128
120
  "ssh_key_ids": ssh_ids,
129
121
  "startup_script_id": startup_script_ids,
130
122
  "hostname": instance_name,
131
123
  "description": instance_name,
132
- "image": IMAGE_ID,
124
+ "image": image_id,
133
125
  "disk_size": disk_size,
134
126
  "location": instance_offer.region,
135
127
  },
136
128
  )
137
-
129
+ instance = _deploy_instance(
130
+ client=self.client,
131
+ instance_type=instance_offer.instance.name,
132
+ ssh_key_ids=ssh_ids,
133
+ startup_script_id=startup_script_ids,
134
+ hostname=instance_name,
135
+ description=instance_name,
136
+ image=image_id,
137
+ disk_size=disk_size,
138
+ is_spot=instance_offer.instance.resources.spot,
139
+ location=instance_offer.region,
140
+ )
138
141
  return JobProvisioningData(
139
142
  backend=instance_offer.backend,
140
143
  instance_type=instance_offer.instance,
@@ -152,8 +155,14 @@ class DataCrunchCompute(
152
155
 
153
156
  def terminate_instance(
154
157
  self, instance_id: str, region: str, backend_data: Optional[str] = None
155
- ) -> None:
156
- self.api_client.delete_instance(instance_id)
158
+ ):
159
+ try:
160
+ self.client.instances.action(id_list=[instance_id], action="delete")
161
+ except APIException as e:
162
+ if e.message == "Invalid instance id":
163
+ logger.debug("Skipping instance %s termination. Instance not found.", instance_id)
164
+ return
165
+ raise
157
166
 
158
167
  def update_provisioning_data(
159
168
  self,
@@ -161,6 +170,83 @@ class DataCrunchCompute(
161
170
  project_ssh_public_key: str,
162
171
  project_ssh_private_key: str,
163
172
  ):
164
- instance = self.api_client.get_instance_by_id(provisioning_data.instance_id)
173
+ instance = _get_instance_by_id(self.client, provisioning_data.instance_id)
165
174
  if instance is not None and instance.status == "running":
166
175
  provisioning_data.hostname = instance.ip
176
+
177
+
178
def _get_vm_image_id(instance_offer: InstanceOfferWithAvailability) -> str:
    """Choose the DataCrunch OS image ID for the offered instance.

    Offers whose first GPU is named "V100" get the CUDA 12.0 image; every other
    offer, including offers without GPUs, gets the CUDA 12.8 image.
    """
    # Image IDs are listed at https://api.datacrunch.io/v1/images
    gpus = instance_offer.instance.resources.gpus
    if len(gpus) == 0 or gpus[0].name != "V100":
        # Ubuntu 24.04 + CUDA 12.8 Open + Docker
        return "77777777-4f48-4249-82b3-f199fb9b701b"
    # Ubuntu 22.04 + CUDA 12.0 + Docker
    return "2088da25-bb0d-41cc-a191-dccae45d96fd"
188
+
189
+
190
def _get_or_create_ssh_key(client: DataCrunchClient, name: str, public_key: str) -> str:
    """Return the ID of the account's SSH key matching ``public_key``.

    Existing keys are matched by fingerprint; when none matches, a new key
    called ``name`` is created and its ID is returned.
    """
    wanted_fingerprint = get_public_key_fingerprint(public_key)
    matching_keys = []
    for existing_key in client.ssh_keys.get():
        if get_public_key_fingerprint(existing_key.public_key) == wanted_fingerprint:
            matching_keys.append(existing_key)
    if not matching_keys:
        return client.ssh_keys.create(name, public_key).id
    return matching_keys[0].id
199
+
200
+
201
def _get_or_create_startup_scrpit(client: DataCrunchClient, name: str, script: str) -> str:
    """Return the ID of a startup script whose content equals ``script``.

    Args:
        client: DataCrunch client used to list and create startup scripts.
        name: name given to the script if it has to be created.
        script: full text of the startup script.

    Returns:
        The ID of the first existing script with identical content, or the ID
        of a newly created script when none matches.
    """
    for existing_script in client.startup_scripts.get():
        # Compare the script text. Comparing the text with the script object
        # itself never matches, which created a duplicate script on every call.
        if existing_script.script == script:
            return existing_script.id
    return client.startup_scripts.create(name, script).id
210
+
211
+
212
def _get_instance_by_id(
    client: DataCrunchClient,
    instance_id: str,
) -> Optional[Instance]:
    """Look up a DataCrunch instance by ID.

    Returns None when the API rejects the ID with the message
    "Invalid instance id"; any other API error is re-raised.
    """
    try:
        instance = client.instances.get_by_id(instance_id)
    except APIException as api_error:
        if api_error.message != "Invalid instance id":
            raise
        return None
    return instance
222
+
223
+
224
def _deploy_instance(
    client: DataCrunchClient,
    instance_type: str,
    image: str,
    ssh_key_ids: List[str],
    hostname: str,
    description: str,
    startup_script_id: str,
    disk_size: int,
    is_spot: bool,
    location: str,
) -> Instance:
    """Create a DataCrunch instance with an OS volume of ``disk_size``.

    All arguments are forwarded to ``client.instances.create``; ``disk_size``
    is passed as the size of a volume named "OS volume".

    Returns:
        The instance object returned by the DataCrunch client.

    Raises:
        NoCapacityError: on any ``APIException`` from the create call, with the
            original exception attached as the cause.
    """
    try:
        return client.instances.create(
            instance_type=instance_type,
            image=image,
            ssh_key_ids=ssh_key_ids,
            hostname=hostname,
            description=description,
            startup_script_id=startup_script_id,
            is_spot=is_spot,
            location=location,
            os_volume={"name": "OS volume", "size": disk_size},
        )
    except APIException as e:
        # FIXME: Catch only no capacity errors
        # Chain explicitly so the original API error stays attached.
        raise NoCapacityError(f"DataCrunch API error: {e.message}") from e