dstack 0.19.1__py3-none-any.whl → 0.19.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. dstack/_internal/cli/commands/metrics.py +138 -0
  2. dstack/_internal/cli/commands/stats.py +5 -119
  3. dstack/_internal/cli/main.py +2 -0
  4. dstack/_internal/core/backends/base/compute.py +3 -0
  5. dstack/_internal/core/backends/base/models.py +7 -7
  6. dstack/_internal/core/backends/configurators.py +9 -0
  7. dstack/_internal/core/backends/models.py +8 -0
  8. dstack/_internal/core/backends/nebius/__init__.py +0 -0
  9. dstack/_internal/core/backends/nebius/backend.py +16 -0
  10. dstack/_internal/core/backends/nebius/compute.py +270 -0
  11. dstack/_internal/core/backends/nebius/configurator.py +74 -0
  12. dstack/_internal/core/backends/nebius/models.py +108 -0
  13. dstack/_internal/core/backends/nebius/resources.py +222 -0
  14. dstack/_internal/core/errors.py +14 -0
  15. dstack/_internal/core/models/backends/base.py +2 -0
  16. dstack/_internal/proxy/lib/schemas/model_proxy.py +3 -3
  17. dstack/_internal/server/background/tasks/process_instances.py +12 -7
  18. dstack/_internal/server/routers/prometheus.py +5 -0
  19. dstack/_internal/server/security/permissions.py +19 -1
  20. dstack/_internal/server/statics/index.html +1 -1
  21. dstack/_internal/server/statics/{main-4a0fe83e84574654e397.js → main-bcb3228138bc8483cc0b.js} +7268 -125
  22. dstack/_internal/server/statics/{main-4a0fe83e84574654e397.js.map → main-bcb3228138bc8483cc0b.js.map} +1 -1
  23. dstack/_internal/server/statics/{main-da9f8c06a69c20dac23e.css → main-c0bdaac8f1ea67d499eb.css} +1 -1
  24. dstack/_internal/utils/event_loop.py +30 -0
  25. dstack/version.py +1 -1
  26. {dstack-0.19.1.dist-info → dstack-0.19.2.dist-info}/METADATA +27 -11
  27. {dstack-0.19.1.dist-info → dstack-0.19.2.dist-info}/RECORD +35 -26
  28. tests/_internal/server/background/tasks/test_process_instances.py +4 -2
  29. tests/_internal/server/routers/test_backends.py +116 -0
  30. tests/_internal/server/routers/test_prometheus.py +21 -0
  31. tests/_internal/utils/test_event_loop.py +18 -0
  32. {dstack-0.19.1.dist-info → dstack-0.19.2.dist-info}/LICENSE.md +0 -0
  33. {dstack-0.19.1.dist-info → dstack-0.19.2.dist-info}/WHEEL +0 -0
  34. {dstack-0.19.1.dist-info → dstack-0.19.2.dist-info}/entry_points.txt +0 -0
  35. {dstack-0.19.1.dist-info → dstack-0.19.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,138 @@
1
+ import argparse
2
+ import time
3
+ from typing import Any, Dict, List, Optional, Union
4
+
5
+ from rich.live import Live
6
+ from rich.table import Table
7
+
8
+ from dstack._internal.cli.commands import APIBaseCommand
9
+ from dstack._internal.cli.services.completion import RunNameCompleter
10
+ from dstack._internal.cli.utils.common import (
11
+ LIVE_TABLE_PROVISION_INTERVAL_SECS,
12
+ LIVE_TABLE_REFRESH_RATE_PER_SEC,
13
+ add_row_from_dict,
14
+ console,
15
+ )
16
+ from dstack._internal.core.errors import CLIError
17
+ from dstack._internal.core.models.instances import Resources
18
+ from dstack._internal.core.models.metrics import JobMetrics
19
+ from dstack.api._public import Client
20
+ from dstack.api._public.runs import Run
21
+
22
+
23
class MetricsCommand(APIBaseCommand):
    """CLI command that prints (or live-watches) per-job metrics of a run."""

    NAME = "metrics"
    DESCRIPTION = "Show run metrics"

    def _register(self):
        super()._register()
        self._parser.add_argument("run_name").completer = RunNameCompleter()
        self._parser.add_argument(
            "-w",
            "--watch",
            help="Watch run metrics in realtime",
            action="store_true",
        )

    def _command(self, args: argparse.Namespace):
        super()._command(args)
        run = self._get_active_run(args.run_name)
        metrics = _get_run_jobs_metrics(api=self.api, run=run)

        if not args.watch:
            console.print(_get_metrics_table(run, metrics))
            return

        try:
            with Live(console=console, refresh_per_second=LIVE_TABLE_REFRESH_RATE_PER_SEC) as live:
                while True:
                    live.update(_get_metrics_table(run, metrics))
                    time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
                    # Re-fetch so --watch reflects the run's current state and
                    # aborts with a CLIError once the run disappears or finishes.
                    run = self._get_active_run(args.run_name)
                    metrics = _get_run_jobs_metrics(api=self.api, run=run)
        except KeyboardInterrupt:
            pass

    def _get_active_run(self, run_name: str) -> "Run":
        """Fetch the run by name, raising CLIError if it is missing or already finished."""
        run = self.api.runs.get(run_name=run_name)
        if run is None:
            raise CLIError(f"Run {run_name} not found")
        if run.status.is_finished():
            raise CLIError(f"Run {run_name} is finished")
        return run
63
+
64
+
65
+ def _get_run_jobs_metrics(api: Client, run: Run) -> List[JobMetrics]:
66
+ metrics = []
67
+ for job in run._run.jobs:
68
+ job_metrics = api.client.metrics.get_job_metrics(
69
+ project_name=api.project,
70
+ run_name=run.name,
71
+ replica_num=job.job_spec.replica_num,
72
+ job_num=job.job_spec.job_num,
73
+ )
74
+ metrics.append(job_metrics)
75
+ return metrics
76
+
77
+
78
def _get_metrics_table(run: Run, metrics: List[JobMetrics]) -> Table:
    """Build a rich Table with one row per job, plus a run header row for multi-job runs."""
    table = Table(box=None)
    table.add_column("NAME", style="bold", no_wrap=True)
    table.add_column("CPU")
    table.add_column("MEMORY")
    table.add_column("GPU")

    run_row: Dict[Union[str, int], Any] = {"NAME": run.name}
    if len(run._run.jobs) != 1:
        # Multi-job run: show the run name as its own header row above the job rows.
        add_row_from_dict(table, run_row)

    for job, job_metrics in zip(run._run.jobs, metrics):
        jrd = job.job_submissions[-1].job_runtime_data
        jpd = job.job_submissions[-1].job_provisioning_data
        # Prefer the offer resources from runtime data; fall back to provisioning data.
        # Resources may be unavailable for either, in which case totals are omitted.
        resources: Optional[Resources] = None
        if jrd is not None and jrd.offer is not None:
            resources = jrd.offer.instance.resources
        elif jpd is not None:
            resources = jpd.instance_type.resources
        cpu_usage = _get_metric_value(job_metrics, "cpu_usage_percent")
        if cpu_usage is not None:
            if resources is not None:
                # Normalize total CPU percent by the CPU count to get per-instance usage.
                cpu_usage = cpu_usage / resources.cpus
            cpu_usage = f"{cpu_usage:.0f}%"
        memory_usage = _get_metric_value(job_metrics, "memory_working_set_bytes")
        if memory_usage is not None:
            memory_usage = f"{round(memory_usage / 1024 / 1024)}MB"
            if resources is not None:
                # Append "used/total" when the instance memory size is known.
                memory_usage += f"/{resources.memory_mib}MB"
        gpu_metrics = ""
        gpus_detected_num = _get_metric_value(job_metrics, "gpus_detected_num")
        if gpus_detected_num is not None:
            # One line per detected GPU: "#<i> <mem>MB[/<total>MB] <util>% Util".
            for i in range(gpus_detected_num):
                gpu_memory_usage = _get_metric_value(job_metrics, f"gpu_memory_usage_bytes_gpu{i}")
                gpu_util_percent = _get_metric_value(job_metrics, f"gpu_util_percent_gpu{i}")
                if gpu_memory_usage is not None:
                    if i != 0:
                        gpu_metrics += "\n"
                    gpu_metrics += f"#{i} {round(gpu_memory_usage / 1024 / 1024)}MB"
                    if resources is not None:
                        gpu_metrics += f"/{resources.gpus[i].memory_mib}MB"
                    gpu_metrics += f" {gpu_util_percent}% Util"

        job_row: Dict[Union[str, int], Any] = {
            "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}",
            "CPU": cpu_usage or "-",
            "MEMORY": memory_usage or "-",
            "GPU": gpu_metrics or "-",
        }
        if len(run._run.jobs) == 1:
            # Single-job run: collapse into one row labeled with the run name.
            job_row.update(run_row)
        add_row_from_dict(table, job_row)

    return table
132
+
133
+
134
+ def _get_metric_value(job_metrics: JobMetrics, name: str) -> Optional[Any]:
135
+ for metric in job_metrics.metrics:
136
+ if metric.name == name:
137
+ return metric.values[-1]
138
+ return None
@@ -1,128 +1,14 @@
1
1
  import argparse
2
- import time
3
- from typing import Any, Dict, List, Optional, Union
4
2
 
5
- from rich.live import Live
6
- from rich.table import Table
3
+ from dstack._internal.cli.commands.metrics import MetricsCommand
4
+ from dstack._internal.utils.logging import get_logger
7
5
 
8
- from dstack._internal.cli.commands import APIBaseCommand
9
- from dstack._internal.cli.services.completion import RunNameCompleter
10
- from dstack._internal.cli.utils.common import (
11
- LIVE_TABLE_PROVISION_INTERVAL_SECS,
12
- LIVE_TABLE_REFRESH_RATE_PER_SEC,
13
- add_row_from_dict,
14
- console,
15
- )
16
- from dstack._internal.core.errors import CLIError
17
- from dstack._internal.core.models.metrics import JobMetrics
18
- from dstack.api._public import Client
19
- from dstack.api._public.runs import Run
6
+ logger = get_logger(__name__)
20
7
 
21
8
 
22
- class StatsCommand(APIBaseCommand):
9
class StatsCommand(MetricsCommand):
    # Deprecated alias for `dstack metrics`: inherits all behavior from
    # MetricsCommand and only adds a deprecation warning before running.
    NAME = "stats"

    def _command(self, args: argparse.Namespace):
        logger.warning("`dstack stats` is deprecated in favor of `dstack metrics`")
        super()._command(args)
38
- run = self.api.runs.get(run_name=args.run_name)
39
- if run is None:
40
- raise CLIError(f"Run {args.run_name} not found")
41
- if run.status.is_finished():
42
- raise CLIError(f"Run {args.run_name} is finished")
43
- metrics = _get_run_jobs_metrics(api=self.api, run=run)
44
-
45
- if not args.watch:
46
- console.print(_get_stats_table(run, metrics))
47
- return
48
-
49
- try:
50
- with Live(console=console, refresh_per_second=LIVE_TABLE_REFRESH_RATE_PER_SEC) as live:
51
- while True:
52
- live.update(_get_stats_table(run, metrics))
53
- time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
54
- run = self.api.runs.get(run_name=args.run_name)
55
- if run is None:
56
- raise CLIError(f"Run {args.run_name} not found")
57
- if run.status.is_finished():
58
- raise CLIError(f"Run {args.run_name} is finished")
59
- metrics = _get_run_jobs_metrics(api=self.api, run=run)
60
- except KeyboardInterrupt:
61
- pass
62
-
63
-
64
- def _get_run_jobs_metrics(api: Client, run: Run) -> List[JobMetrics]:
65
- metrics = []
66
- for job in run._run.jobs:
67
- job_metrics = api.client.metrics.get_job_metrics(
68
- project_name=api.project,
69
- run_name=run.name,
70
- replica_num=job.job_spec.replica_num,
71
- job_num=job.job_spec.job_num,
72
- )
73
- metrics.append(job_metrics)
74
- return metrics
75
-
76
-
77
- def _get_stats_table(run: Run, metrics: List[JobMetrics]) -> Table:
78
- table = Table(box=None)
79
- table.add_column("NAME", style="bold", no_wrap=True)
80
- table.add_column("CPU")
81
- table.add_column("MEMORY")
82
- table.add_column("GPU")
83
-
84
- run_row: Dict[Union[str, int], Any] = {"NAME": run.name}
85
- if len(run._run.jobs) != 1:
86
- add_row_from_dict(table, run_row)
87
-
88
- for job, job_metrics in zip(run._run.jobs, metrics):
89
- cpu_usage = _get_metric_value(job_metrics, "cpu_usage_percent")
90
- if cpu_usage is not None:
91
- cpu_usage = f"{cpu_usage}%"
92
- memory_usage = _get_metric_value(job_metrics, "memory_working_set_bytes")
93
- if memory_usage is not None:
94
- memory_usage = f"{round(memory_usage / 1024 / 1024)}MB"
95
- if job.job_submissions[-1].job_provisioning_data is not None:
96
- memory_usage += f"/{job.job_submissions[-1].job_provisioning_data.instance_type.resources.memory_mib}MB"
97
- gpu_stats = ""
98
- gpus_detected_num = _get_metric_value(job_metrics, "gpus_detected_num")
99
- if gpus_detected_num is not None:
100
- for i in range(gpus_detected_num):
101
- gpu_memory_usage = _get_metric_value(job_metrics, f"gpu_memory_usage_bytes_gpu{i}")
102
- gpu_util_percent = _get_metric_value(job_metrics, f"gpu_util_percent_gpu{i}")
103
- if gpu_memory_usage is not None:
104
- if i != 0:
105
- gpu_stats += "\n"
106
- gpu_stats += f"#{i} {round(gpu_memory_usage / 1024 / 1024)}MB"
107
- if job.job_submissions[-1].job_provisioning_data is not None:
108
- gpu_stats += f"/{job.job_submissions[-1].job_provisioning_data.instance_type.resources.gpus[i].memory_mib}MB"
109
- gpu_stats += f" {gpu_util_percent}% Util"
110
-
111
- job_row: Dict[Union[str, int], Any] = {
112
- "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}",
113
- "CPU": cpu_usage or "-",
114
- "MEMORY": memory_usage or "-",
115
- "GPU": gpu_stats or "-",
116
- }
117
- if len(run._run.jobs) == 1:
118
- job_row.update(run_row)
119
- add_row_from_dict(table, job_row)
120
-
121
- return table
122
-
123
-
124
- def _get_metric_value(job_metrics: JobMetrics, name: str) -> Optional[Any]:
125
- for metric in job_metrics.metrics:
126
- if metric.name == name:
127
- return metric.values[-1]
128
- return None
@@ -13,6 +13,7 @@ from dstack._internal.cli.commands.fleet import FleetCommand
13
13
  from dstack._internal.cli.commands.gateway import GatewayCommand
14
14
  from dstack._internal.cli.commands.init import InitCommand
15
15
  from dstack._internal.cli.commands.logs import LogsCommand
16
+ from dstack._internal.cli.commands.metrics import MetricsCommand
16
17
  from dstack._internal.cli.commands.ps import PsCommand
17
18
  from dstack._internal.cli.commands.server import ServerCommand
18
19
  from dstack._internal.cli.commands.stats import StatsCommand
@@ -65,6 +66,7 @@ def main():
65
66
  GatewayCommand.register(subparsers)
66
67
  InitCommand.register(subparsers)
67
68
  LogsCommand.register(subparsers)
69
+ MetricsCommand.register(subparsers)
68
70
  PsCommand.register(subparsers)
69
71
  ServerCommand.register(subparsers)
70
72
  StatsCommand.register(subparsers)
@@ -94,6 +94,9 @@ class Compute(ABC):
94
94
  """
95
95
  Terminates an instance by `instance_id`.
96
96
  If the instance does not exist, it should not raise errors but return silently.
97
+
98
+ Should return ASAP. If required to wait for some operation, raise `NotYetTerminated`.
99
+ In this case, the method will be called again after a few seconds.
97
100
  """
98
101
  pass
99
102
 
@@ -1,14 +1,14 @@
1
1
  from pathlib import Path
2
2
 
3
3
 
4
def fill_data(values: dict, filename_field: str = "filename", data_field: str = "data") -> dict:
    """Populate `values[data_field]` from the file named by `values[filename_field]`.

    If `data_field` is already set, `values` is returned unchanged. Otherwise the
    file path (with `~` expansion) is read and its contents stored under `data_field`.

    Raises:
        ValueError: if neither field is present, or the file cannot be read.
    """
    if values.get(data_field) is not None:
        return values
    if (filename := values.get(filename_field)) is None:
        raise ValueError(f"Either `{filename_field}` or `{data_field}` must be specified")
    try:
        with open(Path(filename).expanduser()) as f:
            values[data_field] = f.read()
    except OSError:
        # Include the offending path so validation errors are actionable.
        raise ValueError(f"No such file {filename}")
    return values
@@ -63,6 +63,15 @@ try:
63
63
  except ImportError:
64
64
  pass
65
65
 
66
+ try:
67
+ from dstack._internal.core.backends.nebius.configurator import (
68
+ NebiusConfigurator,
69
+ )
70
+
71
+ _CONFIGURATOR_CLASSES.append(NebiusConfigurator)
72
+ except ImportError:
73
+ pass
74
+
66
75
  try:
67
76
  from dstack._internal.core.backends.oci.configurator import OCIConfigurator
68
77
 
@@ -34,6 +34,11 @@ from dstack._internal.core.backends.lambdalabs.models import (
34
34
  LambdaBackendConfig,
35
35
  LambdaBackendConfigWithCreds,
36
36
  )
37
+ from dstack._internal.core.backends.nebius.models import (
38
+ NebiusBackendConfig,
39
+ NebiusBackendConfigWithCreds,
40
+ NebiusBackendFileConfigWithCreds,
41
+ )
37
42
  from dstack._internal.core.backends.oci.models import (
38
43
  OCIBackendConfig,
39
44
  OCIBackendConfigWithCreds,
@@ -65,6 +70,7 @@ AnyBackendConfigWithoutCreds = Union[
65
70
  GCPBackendConfig,
66
71
  KubernetesBackendConfig,
67
72
  LambdaBackendConfig,
73
+ NebiusBackendConfig,
68
74
  OCIBackendConfig,
69
75
  RunpodBackendConfig,
70
76
  TensorDockBackendConfig,
@@ -86,6 +92,7 @@ AnyBackendConfigWithCreds = Union[
86
92
  KubernetesBackendConfigWithCreds,
87
93
  LambdaBackendConfigWithCreds,
88
94
  OCIBackendConfigWithCreds,
95
+ NebiusBackendConfigWithCreds,
89
96
  RunpodBackendConfigWithCreds,
90
97
  TensorDockBackendConfigWithCreds,
91
98
  VastAIBackendConfigWithCreds,
@@ -105,6 +112,7 @@ AnyBackendFileConfigWithCreds = Union[
105
112
  KubernetesBackendFileConfigWithCreds,
106
113
  LambdaBackendConfigWithCreds,
107
114
  OCIBackendConfigWithCreds,
115
+ NebiusBackendFileConfigWithCreds,
108
116
  RunpodBackendConfigWithCreds,
109
117
  TensorDockBackendConfigWithCreds,
110
118
  VastAIBackendConfigWithCreds,
File without changes
@@ -0,0 +1,16 @@
1
+ from dstack._internal.core.backends.base.backend import Backend
2
+ from dstack._internal.core.backends.nebius.compute import NebiusCompute
3
+ from dstack._internal.core.backends.nebius.models import NebiusConfig
4
+ from dstack._internal.core.models.backends.base import BackendType
5
+
6
+
7
class NebiusBackend(Backend):
    """Nebius AI Cloud backend: wires a NebiusConfig to a single NebiusCompute."""

    TYPE = BackendType.NEBIUS
    COMPUTE_CLASS = NebiusCompute

    def __init__(self, config: NebiusConfig):
        self.config = config
        # One Compute instance per backend; it caches the SDK client and lookups.
        self._compute = NebiusCompute(self.config)

    def compute(self) -> NebiusCompute:
        """Return the shared Compute implementation for this backend."""
        return self._compute
@@ -0,0 +1,270 @@
1
+ import json
2
+ import shlex
3
+ import time
4
+ from functools import cached_property
5
+ from typing import List, Optional
6
+
7
+ from nebius.aio.operation import Operation as SDKOperation
8
+ from nebius.aio.service_error import RequestError, StatusCode
9
+ from nebius.api.nebius.common.v1 import Operation
10
+ from nebius.sdk import SDK
11
+
12
+ from dstack._internal.core.backends.base.backend import Compute
13
+ from dstack._internal.core.backends.base.compute import (
14
+ ComputeWithCreateInstanceSupport,
15
+ ComputeWithMultinodeSupport,
16
+ generate_unique_instance_name,
17
+ get_user_data,
18
+ )
19
+ from dstack._internal.core.backends.base.offers import get_catalog_offers
20
+ from dstack._internal.core.backends.nebius import resources
21
+ from dstack._internal.core.backends.nebius.models import NebiusConfig, NebiusServiceAccountCreds
22
+ from dstack._internal.core.errors import BackendError, NotYetTerminated, ProvisioningError
23
+ from dstack._internal.core.models.backends.base import BackendType
24
+ from dstack._internal.core.models.common import CoreModel
25
+ from dstack._internal.core.models.instances import (
26
+ InstanceAvailability,
27
+ InstanceConfiguration,
28
+ InstanceOffer,
29
+ InstanceOfferWithAvailability,
30
+ )
31
+ from dstack._internal.core.models.resources import Memory, Range
32
+ from dstack._internal.core.models.runs import JobProvisioningData, Requirements
33
+ from dstack._internal.utils.logging import get_logger
34
+
35
logger = get_logger(__name__)
# Range that offer disk sizes may be configured within.
CONFIGURABLE_DISK_SIZE = Range[Memory](
    min=Memory.parse("40GB"),  # min for the ubuntu22.04-cuda12 image
    max=Memory.parse("8192GB"),  # max for the NETWORK_SSD disk type
)
# Timeouts and the poll interval below are in seconds.
WAIT_FOR_DISK_TIMEOUT = 20
WAIT_FOR_INSTANCE_TIMEOUT = 30
WAIT_FOR_INSTANCE_UPDATE_INTERVAL = 2.5
DELETE_INSTANCE_TIMEOUT = 25
# Written to /etc/docker/daemon.json on the instance (see SETUP_COMMANDS).
DOCKER_DAEMON_CONFIG = {
    "runtimes": {"nvidia": {"args": [], "path": "nvidia-container-runtime"}},
    # Workaround for https://github.com/NVIDIA/nvidia-container-toolkit/issues/48
    "exec-opts": ["native.cgroupdriver=cgroupfs"],
}
# Shell commands run on first boot via user data: restrict inbound traffic to
# SSH and the private network, enable SSH forwarding, and configure Docker.
SETUP_COMMANDS = [
    "ufw allow ssh",
    "ufw allow from 192.168.0.0/16",
    "ufw default deny incoming",
    "ufw default allow outgoing",
    "ufw enable",
    'sed -i "s/.*AllowTcpForwarding.*/AllowTcpForwarding yes/g" /etc/ssh/sshd_config',
    "service ssh restart",
    f"echo {shlex.quote(json.dumps(DOCKER_DAEMON_CONFIG))} > /etc/docker/daemon.json",
    "service docker restart",
]
# Platform names (first word of the offer instance name) that dstack supports.
SUPPORTED_PLATFORMS = [
    "gpu-h100-sxm",
    "gpu-h200-sxm",
    "gpu-l40s-a",
    "gpu-l40s-d",
    "cpu-d3",
    "cpu-e2",
]
68
+
69
+
70
class NebiusCompute(
    ComputeWithCreateInstanceSupport,
    ComputeWithMultinodeSupport,
    Compute,
):
    """Compute implementation for the Nebius AI Cloud backend."""

    def __init__(self, config: NebiusConfig):
        super().__init__()
        self.config = config
        # region -> default subnet ID, populated lazily by _get_subnet_id().
        self._subnet_id_cache: dict[str, str] = {}

    @cached_property
    def _sdk(self) -> SDK:
        # SDK client built once on first use; requires service-account credentials.
        assert isinstance(self.config.creds, NebiusServiceAccountCreds)
        return resources.make_sdk(self.config.creds)

    @cached_property
    def _region_to_project_id(self) -> dict[str, str]:
        # Fetched once; its keys also serve as the set of known regions.
        return resources.get_region_to_project_id_map(self._sdk)

    def _get_subnet_id(self, region: str) -> str:
        """Return the default subnet ID for `region`, caching the SDK lookup."""
        if region not in self._subnet_id_cache:
            self._subnet_id_cache[region] = resources.get_default_subnet(
                self._sdk, self._region_to_project_id[region]
            ).metadata.id
        return self._subnet_id_cache[region]

    def get_offers(
        self, requirements: Optional[Requirements] = None
    ) -> List[InstanceOfferWithAvailability]:
        """List catalog offers for supported non-spot platforms in the configured regions.

        Availability is always reported as UNKNOWN since the catalog is static.
        """
        offers = get_catalog_offers(
            backend=BackendType.NEBIUS,
            locations=self.config.regions or list(self._region_to_project_id),
            requirements=requirements,
            extra_filter=_supported_instances,
            configurable_disk_size=CONFIGURABLE_DISK_SIZE,
        )
        return [
            InstanceOfferWithAvailability(
                **offer.dict(),
                availability=InstanceAvailability.UNKNOWN,
            )
            for offer in offers
        ]

    def create_instance(
        self,
        instance_offer: InstanceOfferWithAvailability,
        instance_config: InstanceConfiguration,
    ) -> JobProvisioningData:
        """Create a boot disk and an instance for the offer, cleaning both up on failure."""
        # NOTE: This method can block for a long time as it waits for the boot disk to be created
        # and the instance to enter the STARTING state. This has to be done in create_instance so
        # that we can handle quota and availability errors that may occur even after creating an
        # instance.
        instance_name = generate_unique_instance_name(instance_config)
        # Offer names are "<platform> <preset>" — see _supported_instances.
        platform, preset = instance_offer.instance.name.split()
        create_disk_op = resources.create_disk(
            sdk=self._sdk,
            name=instance_name,
            project_id=self._region_to_project_id[instance_offer.region],
            size_mib=instance_offer.instance.resources.disk.size_mib,
            image_family="ubuntu22.04-cuda12",
        )
        create_instance_op = None
        try:
            logger.debug("Blocking until disk %s is created", create_disk_op.resource_id)
            resources.wait_for_operation(create_disk_op, timeout=WAIT_FOR_DISK_TIMEOUT)
            if not create_disk_op.successful():
                raw_op = create_disk_op.raw()
                raise ProvisioningError(
                    f"Create disk operation failed. Message: {raw_op.status.message}."
                    f" Details: {raw_op.status.details}"
                )
            create_instance_op = resources.create_instance(
                sdk=self._sdk,
                name=instance_name,
                project_id=self._region_to_project_id[instance_offer.region],
                user_data=get_user_data(
                    instance_config.get_public_keys(),
                    backend_specific_commands=SETUP_COMMANDS,
                ),
                platform=platform,
                preset=preset,
                disk_id=create_disk_op.resource_id,
                subnet_id=self._get_subnet_id(instance_offer.region),
            )
            _wait_for_instance(self._sdk, create_instance_op)
        except BaseException:
            # Best-effort cleanup on any failure (including KeyboardInterrupt):
            # delete the instance first (if it was created), then the boot disk.
            if create_instance_op is not None:
                try:
                    with resources.ignore_errors([StatusCode.NOT_FOUND]):
                        delete_instance_op = resources.delete_instance(
                            self._sdk, create_instance_op.resource_id
                        )
                        resources.wait_for_operation(
                            delete_instance_op, timeout=DELETE_INSTANCE_TIMEOUT
                        )
                except Exception as e:
                    logger.exception(
                        "Could not delete instance %s: %s", create_instance_op.resource_id, e
                    )
            try:
                with resources.ignore_errors([StatusCode.NOT_FOUND]):
                    resources.delete_disk(self._sdk, create_disk_op.resource_id)
            except Exception as e:
                logger.exception(
                    "Could not delete boot disk %s: %s", create_disk_op.resource_id, e
                )
            raise
        return JobProvisioningData(
            backend=instance_offer.backend,
            instance_type=instance_offer.instance,
            instance_id=create_instance_op.resource_id,
            hostname=None,  # filled in later by update_provisioning_data()
            region=instance_offer.region,
            price=instance_offer.price,
            ssh_port=22,
            username="ubuntu",
            dockerized=True,
            backend_data=NebiusInstanceBackendData(boot_disk_id=create_disk_op.resource_id).json(),
        )

    def update_provisioning_data(
        self, provisioning_data, project_ssh_public_key, project_ssh_private_key
    ):
        """Fill in the public/internal IPs once the instance's network interface exists."""
        instance = resources.get_instance(self._sdk, provisioning_data.instance_id)
        if not instance.status.network_interfaces:
            # Interfaces not attached yet; the caller will poll again later.
            return
        interface = instance.status.network_interfaces[0]
        # Addresses carry a "/suffix" (presumably a netmask — TODO confirm);
        # keep only the address part.
        provisioning_data.hostname, _ = interface.public_ip_address.address.split("/")
        provisioning_data.internal_ip, _ = interface.ip_address.address.split("/")

    def terminate_instance(
        self, instance_id: str, region: str, backend_data: Optional[str] = None
    ):
        """Request instance deletion, then delete the boot disk once the instance is gone.

        Raises NotYetTerminated while deletion is still in progress; the caller is
        expected to retry later (see the base Compute.terminate_instance contract).
        """
        backend_data_parsed = NebiusInstanceBackendData.load(backend_data)
        try:
            instance = resources.get_instance(self._sdk, instance_id)
        except RequestError as e:
            if e.status.code != StatusCode.NOT_FOUND:
                raise
            # Instance already gone — fall through to boot disk cleanup.
            instance = None
        if instance is not None:
            if instance.status.state != instance.status.InstanceState.DELETING:
                resources.delete_instance(self._sdk, instance_id)
                raise NotYetTerminated(
                    "Requested instance deletion."
                    " Will wait for deletion before deleting the boot disk."
                    f" Instance state was: {instance.status.state.name}"
                )
            else:
                raise NotYetTerminated(
                    "Waiting for instance deletion before deleting the boot disk."
                    f" Instance state: {instance.status.state.name}"
                )
        with resources.ignore_errors([StatusCode.NOT_FOUND]):
            resources.delete_disk(self._sdk, backend_data_parsed.boot_disk_id)
226
+
227
+
228
class NebiusInstanceBackendData(CoreModel):
    """Backend-specific data stored with the instance; tracks the boot disk for cleanup."""

    # ID of the boot disk created alongside the instance (see create_instance).
    boot_disk_id: str

    @classmethod
    def load(cls, raw: Optional[str]) -> "NebiusInstanceBackendData":
        # `raw` is always set for Nebius instances, since create_instance
        # populates backend_data unconditionally.
        assert raw is not None
        return cls.__response__.parse_raw(raw)
235
+
236
+
237
def _wait_for_instance(sdk: SDK, op: SDKOperation[Operation]) -> None:
    """Poll until the instance created by `op` reaches the STARTING or RUNNING state.

    Raises:
        ProvisioningError: if the create-instance operation finished unsuccessfully.
        BackendError: if the instance does not start booting within
            WAIT_FOR_INSTANCE_TIMEOUT seconds.
    """
    start = time.monotonic()
    while True:
        # Check for operation failure first so errors surface before the timeout.
        if op.done() and not op.successful():
            raise ProvisioningError(
                f"Create instance operation failed. Message: {op.raw().status.message}."
                f" Details: {op.raw().status.details}"
            )
        instance = resources.get_instance(sdk, op.resource_id)
        if instance.status.state in [
            instance.status.InstanceState.STARTING,
            instance.status.InstanceState.RUNNING,
        ]:
            break
        if time.monotonic() - start > WAIT_FOR_INSTANCE_TIMEOUT:
            raise BackendError(
                f"Instance {instance.metadata.id} did not start booting in time."
                f" Status: {instance.status.state.name}"
            )
        logger.debug(
            "Waiting for instance %s. Status: %s. Operation status: %s",
            instance.metadata.name,
            instance.status.state.name,
            op.status(),
        )
        time.sleep(WAIT_FOR_INSTANCE_UPDATE_INTERVAL)
        # Refresh the operation so done()/successful() reflect the latest server state.
        resources.LOOP.await_(
            op.update(timeout=resources.REQUEST_TIMEOUT, metadata=resources.REQUEST_MD)
        )
266
+
267
+
268
def _supported_instances(offer: "InstanceOffer") -> bool:
    """Catalog filter: keep only non-spot offers on a supported platform."""
    platform = offer.instance.name.split()[0]
    if offer.instance.resources.spot:
        return False
    return platform in SUPPORTED_PLATFORMS