dstack 0.19.21__py3-none-any.whl → 0.19.23rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dstack might be problematic. Click here for more details.
- dstack/_internal/cli/commands/apply.py +8 -3
- dstack/_internal/cli/services/configurators/__init__.py +8 -0
- dstack/_internal/cli/services/configurators/fleet.py +1 -1
- dstack/_internal/cli/services/configurators/gateway.py +1 -1
- dstack/_internal/cli/services/configurators/run.py +11 -1
- dstack/_internal/cli/services/configurators/volume.py +1 -1
- dstack/_internal/cli/utils/common.py +48 -5
- dstack/_internal/cli/utils/fleet.py +5 -5
- dstack/_internal/cli/utils/run.py +32 -0
- dstack/_internal/core/backends/configurators.py +9 -0
- dstack/_internal/core/backends/hotaisle/__init__.py +1 -0
- dstack/_internal/core/backends/hotaisle/api_client.py +109 -0
- dstack/_internal/core/backends/hotaisle/backend.py +16 -0
- dstack/_internal/core/backends/hotaisle/compute.py +225 -0
- dstack/_internal/core/backends/hotaisle/configurator.py +60 -0
- dstack/_internal/core/backends/hotaisle/models.py +45 -0
- dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
- dstack/_internal/core/backends/models.py +8 -0
- dstack/_internal/core/backends/nebius/compute.py +8 -2
- dstack/_internal/core/backends/nebius/fabrics.py +1 -0
- dstack/_internal/core/backends/nebius/resources.py +9 -0
- dstack/_internal/core/compatibility/runs.py +8 -0
- dstack/_internal/core/models/backends/base.py +2 -0
- dstack/_internal/core/models/configurations.py +139 -1
- dstack/_internal/core/models/health.py +28 -0
- dstack/_internal/core/models/instances.py +2 -0
- dstack/_internal/core/models/logs.py +2 -1
- dstack/_internal/core/models/runs.py +21 -1
- dstack/_internal/core/services/ssh/tunnel.py +7 -0
- dstack/_internal/server/app.py +4 -0
- dstack/_internal/server/background/__init__.py +4 -0
- dstack/_internal/server/background/tasks/process_instances.py +107 -56
- dstack/_internal/server/background/tasks/process_probes.py +164 -0
- dstack/_internal/server/background/tasks/process_running_jobs.py +13 -0
- dstack/_internal/server/background/tasks/process_runs.py +21 -14
- dstack/_internal/server/migrations/versions/25479f540245_add_probes.py +43 -0
- dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py +50 -0
- dstack/_internal/server/models.py +41 -0
- dstack/_internal/server/routers/instances.py +33 -5
- dstack/_internal/server/schemas/health/dcgm.py +56 -0
- dstack/_internal/server/schemas/instances.py +32 -0
- dstack/_internal/server/schemas/runner.py +5 -0
- dstack/_internal/server/services/instances.py +103 -1
- dstack/_internal/server/services/jobs/__init__.py +8 -1
- dstack/_internal/server/services/jobs/configurators/base.py +26 -0
- dstack/_internal/server/services/logging.py +4 -2
- dstack/_internal/server/services/logs/aws.py +13 -1
- dstack/_internal/server/services/logs/gcp.py +16 -1
- dstack/_internal/server/services/probes.py +6 -0
- dstack/_internal/server/services/projects.py +16 -4
- dstack/_internal/server/services/runner/client.py +52 -20
- dstack/_internal/server/services/runner/ssh.py +4 -4
- dstack/_internal/server/services/runs.py +49 -13
- dstack/_internal/server/services/ssh.py +66 -0
- dstack/_internal/server/settings.py +13 -0
- dstack/_internal/server/statics/index.html +1 -1
- dstack/_internal/server/statics/{main-8f9ee218d3eb45989682.css → main-03e818b110e1d5705378.css} +1 -1
- dstack/_internal/server/statics/{main-39a767528976f8078166.js → main-cc067b7fd1a8f33f97da.js} +26 -15
- dstack/_internal/server/statics/{main-39a767528976f8078166.js.map → main-cc067b7fd1a8f33f97da.js.map} +1 -1
- dstack/_internal/server/testing/common.py +44 -0
- dstack/_internal/{core/backends/remote → server/utils}/provisioning.py +22 -17
- dstack/_internal/settings.py +3 -0
- dstack/_internal/utils/common.py +15 -0
- dstack/api/server/__init__.py +1 -1
- dstack/version.py +1 -1
- {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/METADATA +14 -14
- {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/RECORD +71 -58
- /dstack/_internal/{core/backends/remote → server/schemas/health}/__init__.py +0 -0
- {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/WHEEL +0 -0
- {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/entry_points.txt +0 -0
- {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import argparse
|
|
2
|
-
from pathlib import Path
|
|
3
2
|
|
|
4
3
|
from argcomplete import FilesCompleter
|
|
5
4
|
|
|
6
5
|
from dstack._internal.cli.commands import APIBaseCommand
|
|
7
6
|
from dstack._internal.cli.services.configurators import (
|
|
7
|
+
APPLY_STDIN_NAME,
|
|
8
8
|
get_apply_configurator_class,
|
|
9
9
|
load_apply_configuration,
|
|
10
10
|
)
|
|
@@ -40,9 +40,12 @@ class ApplyCommand(APIBaseCommand):
|
|
|
40
40
|
self._parser.add_argument(
|
|
41
41
|
"-f",
|
|
42
42
|
"--file",
|
|
43
|
-
type=Path,
|
|
44
43
|
metavar="FILE",
|
|
45
|
-
help=
|
|
44
|
+
help=(
|
|
45
|
+
"The path to the configuration file."
|
|
46
|
+
" Specify [code]-[/] to read configuration from stdin."
|
|
47
|
+
" Defaults to [code]$PWD/.dstack.yml[/]"
|
|
48
|
+
),
|
|
46
49
|
dest="configuration_file",
|
|
47
50
|
).completer = FilesCompleter(allowednames=["*.yml", "*.yaml"])
|
|
48
51
|
self._parser.add_argument(
|
|
@@ -104,6 +107,8 @@ class ApplyCommand(APIBaseCommand):
|
|
|
104
107
|
return
|
|
105
108
|
|
|
106
109
|
super()._command(args)
|
|
110
|
+
if not args.yes and args.configuration_file == APPLY_STDIN_NAME:
|
|
111
|
+
raise CLIError("Cannot read configuration from stdin if -y/--yes is not specified")
|
|
107
112
|
if args.repo and args.no_repo:
|
|
108
113
|
raise CLIError("Either --repo or --no-repo can be specified")
|
|
109
114
|
repo = None
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import sys
|
|
1
2
|
from pathlib import Path
|
|
2
3
|
from typing import Dict, Optional, Tuple, Type
|
|
3
4
|
|
|
@@ -20,6 +21,9 @@ from dstack._internal.core.models.configurations import (
|
|
|
20
21
|
parse_apply_configuration,
|
|
21
22
|
)
|
|
22
23
|
|
|
24
|
+
APPLY_STDIN_NAME = "-"
|
|
25
|
+
|
|
26
|
+
|
|
23
27
|
apply_configurators_mapping: Dict[ApplyConfigurationType, Type[BaseApplyConfigurator]] = {
|
|
24
28
|
cls.TYPE: cls
|
|
25
29
|
for cls in [
|
|
@@ -62,6 +66,8 @@ def load_apply_configuration(
|
|
|
62
66
|
raise ConfigurationError(
|
|
63
67
|
"No configuration file specified via `-f` and no default .dstack.yml configuration found"
|
|
64
68
|
)
|
|
69
|
+
elif configuration_file == APPLY_STDIN_NAME:
|
|
70
|
+
configuration_path = sys.stdin.fileno()
|
|
65
71
|
else:
|
|
66
72
|
configuration_path = Path(configuration_file)
|
|
67
73
|
if not configuration_path.exists():
|
|
@@ -71,4 +77,6 @@ def load_apply_configuration(
|
|
|
71
77
|
conf = parse_apply_configuration(yaml.safe_load(f))
|
|
72
78
|
except OSError:
|
|
73
79
|
raise ConfigurationError(f"Failed to load configuration from {configuration_path}")
|
|
80
|
+
if isinstance(configuration_path, int):
|
|
81
|
+
return APPLY_STDIN_NAME, conf
|
|
74
82
|
return str(configuration_path.absolute().relative_to(Path.cwd())), conf
|
|
@@ -151,7 +151,7 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
|
|
|
151
151
|
time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
|
|
152
152
|
fleet = self.api.client.fleets.get(self.api.project, fleet.name)
|
|
153
153
|
except KeyboardInterrupt:
|
|
154
|
-
if confirm_ask("Delete the fleet before exiting?"):
|
|
154
|
+
if not command_args.yes and confirm_ask("Delete the fleet before exiting?"):
|
|
155
155
|
with console.status("Deleting fleet..."):
|
|
156
156
|
self.api.client.fleets.delete(
|
|
157
157
|
project_name=self.api.project, names=[fleet.name]
|
|
@@ -121,7 +121,7 @@ class GatewayConfigurator(BaseApplyConfigurator):
|
|
|
121
121
|
time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
|
|
122
122
|
gateway = self.api.client.gateways.get(self.api.project, gateway.name)
|
|
123
123
|
except KeyboardInterrupt:
|
|
124
|
-
if confirm_ask("Delete the gateway before exiting?"):
|
|
124
|
+
if not command_args.yes and confirm_ask("Delete the gateway before exiting?"):
|
|
125
125
|
with console.status("Deleting gateway..."):
|
|
126
126
|
self.api.client.gateways.delete(
|
|
127
127
|
project_name=self.api.project,
|
|
@@ -218,7 +218,9 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
|
|
|
218
218
|
exit(1)
|
|
219
219
|
except KeyboardInterrupt:
|
|
220
220
|
try:
|
|
221
|
-
if not confirm_ask(
|
|
221
|
+
if command_args.yes or not confirm_ask(
|
|
222
|
+
f"\nStop the run [code]{run.name}[/] before detaching?"
|
|
223
|
+
):
|
|
222
224
|
console.print("Detached")
|
|
223
225
|
abort_at_exit = False
|
|
224
226
|
return
|
|
@@ -339,6 +341,14 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
|
|
|
339
341
|
username=interpolator.interpolate_or_error(conf.registry_auth.username),
|
|
340
342
|
password=interpolator.interpolate_or_error(conf.registry_auth.password),
|
|
341
343
|
)
|
|
344
|
+
if isinstance(conf, ServiceConfiguration):
|
|
345
|
+
for probe in conf.probes:
|
|
346
|
+
for header in probe.headers:
|
|
347
|
+
header.value = interpolator.interpolate_or_error(header.value)
|
|
348
|
+
if probe.url:
|
|
349
|
+
probe.url = interpolator.interpolate_or_error(probe.url)
|
|
350
|
+
if probe.body:
|
|
351
|
+
probe.body = interpolator.interpolate_or_error(probe.body)
|
|
342
352
|
except InterpolatorError as e:
|
|
343
353
|
raise ConfigurationError(e.args[0])
|
|
344
354
|
|
|
@@ -110,7 +110,7 @@ class VolumeConfigurator(BaseApplyConfigurator):
|
|
|
110
110
|
time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
|
|
111
111
|
volume = self.api.client.volumes.get(self.api.project, volume.name)
|
|
112
112
|
except KeyboardInterrupt:
|
|
113
|
-
if confirm_ask("Delete the volume before exiting?"):
|
|
113
|
+
if not command_args.yes and confirm_ask("Delete the volume before exiting?"):
|
|
114
114
|
with console.status("Deleting volume..."):
|
|
115
115
|
self.api.client.volumes.delete(
|
|
116
116
|
project_name=self.api.project, names=[volume.name]
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
import
|
|
2
|
+
from datetime import datetime, timezone
|
|
3
|
+
from pathlib import Path
|
|
3
4
|
from typing import Any, Dict, Union
|
|
4
5
|
|
|
5
6
|
from rich.console import Console
|
|
@@ -7,8 +8,10 @@ from rich.prompt import Confirm
|
|
|
7
8
|
from rich.table import Table
|
|
8
9
|
from rich.theme import Theme
|
|
9
10
|
|
|
11
|
+
from dstack._internal import settings
|
|
10
12
|
from dstack._internal.cli.utils.rich import DstackRichHandler
|
|
11
13
|
from dstack._internal.core.errors import CLIError, DstackError
|
|
14
|
+
from dstack._internal.utils.common import get_dstack_dir
|
|
12
15
|
|
|
13
16
|
_colors = {
|
|
14
17
|
"secondary": "grey58",
|
|
@@ -35,12 +38,52 @@ def cli_error(e: DstackError) -> CLIError:
|
|
|
35
38
|
return CLIError(*e.args)
|
|
36
39
|
|
|
37
40
|
|
|
41
|
+
def _get_cli_log_file() -> Path:
    """Return the path of the active CLI log file (``<dstack dir>/logs/cli/latest.log``).

    If ``latest.log`` exists and was last modified on an earlier UTC day, it is
    first renamed to ``<YYYY-MM-DD>.log`` using that modification date. If that
    name is taken, ``<YYYY-MM-DD>-1.log``, ``-2.log``, ... are tried. This way
    each day starts with a fresh file. The log directory is created if missing.
    """
    logs_dir = get_dstack_dir() / "logs" / "cli"
    latest = logs_dir / "latest.log"

    if latest.exists():
        modified_at = datetime.fromtimestamp(latest.stat().st_mtime, tz=timezone.utc)
        today = datetime.now(timezone.utc).date()
        if modified_at.date() < today:
            day_stamp = modified_at.strftime("%Y-%m-%d")
            target = logs_dir / f"{day_stamp}.log"
            # Find the first free name so earlier rotations of the same day survive.
            suffix = 1
            while target.exists():
                target = logs_dir / f"{day_stamp}-{suffix}.log"
                suffix += 1
            latest.rename(target)

    logs_dir.mkdir(parents=True, exist_ok=True)
    return latest
|
|
63
|
+
|
|
64
|
+
|
|
38
65
|
def configure_logging():
    """Configure the ``dstack`` logger for the CLI.

    Any handlers already attached to the logger are removed and closed. Two
    handlers are then installed:

    * a rich console handler (message only), filtered at ``settings.CLI_LOG_LEVEL``;
    * a file handler writing timestamped records to the file returned by
      ``_get_cli_log_file()``, filtered at ``settings.CLI_FILE_LOG_LEVEL``.

    The logger itself is set to DEBUG so that each handler applies its own level.
    """
    dstack_logger = logging.getLogger("dstack")
    # Close handlers from a previous call instead of just dropping them. Otherwise
    # an earlier FileHandler keeps its log file open. That leaks a descriptor and,
    # on Windows, makes renaming latest.log during rotation fail.
    for old_handler in list(dstack_logger.handlers):
        dstack_logger.removeHandler(old_handler)
        old_handler.close()

    log_file = _get_cli_log_file()

    stdout_handler = DstackRichHandler(console=console)
    stdout_handler.setFormatter(logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
    stdout_handler.setLevel(settings.CLI_LOG_LEVEL)
    dstack_logger.addHandler(stdout_handler)

    # Write UTF-8 explicitly. The locale default encoding (e.g. cp1252 on Windows)
    # cannot encode every character that may appear in log messages.
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
    )
    file_handler.setLevel(settings.CLI_FILE_LOG_LEVEL)
    dstack_logger.addHandler(file_handler)

    # the logger allows all messages, filtering is done by the handlers
    dstack_logger.setLevel(logging.DEBUG)
|
|
44
87
|
|
|
45
88
|
|
|
46
89
|
def confirm_ask(prompt, **kwargs) -> bool:
|
|
@@ -51,11 +51,11 @@ def get_fleets_table(
|
|
|
51
51
|
and total_blocks > 1
|
|
52
52
|
):
|
|
53
53
|
status = f"{busy_blocks}/{total_blocks} {InstanceStatus.BUSY.value}"
|
|
54
|
-
if
|
|
55
|
-
instance.
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
54
|
+
if instance.status in [InstanceStatus.IDLE, InstanceStatus.BUSY]:
|
|
55
|
+
if instance.unreachable:
|
|
56
|
+
status += "\n(unreachable)"
|
|
57
|
+
elif not instance.health_status.is_healthy():
|
|
58
|
+
status += f"\n({instance.health_status.value})"
|
|
59
59
|
|
|
60
60
|
backend = instance.backend or ""
|
|
61
61
|
if backend == "remote":
|
|
@@ -12,11 +12,15 @@ from dstack._internal.core.models.profiles import (
|
|
|
12
12
|
TerminationPolicy,
|
|
13
13
|
)
|
|
14
14
|
from dstack._internal.core.models.runs import (
|
|
15
|
+
JobStatus,
|
|
16
|
+
Probe,
|
|
17
|
+
ProbeSpec,
|
|
15
18
|
RunPlan,
|
|
16
19
|
)
|
|
17
20
|
from dstack._internal.core.services.profiles import get_termination
|
|
18
21
|
from dstack._internal.utils.common import (
|
|
19
22
|
DateFormatter,
|
|
23
|
+
batched,
|
|
20
24
|
format_duration_multiunit,
|
|
21
25
|
format_pretty_duration,
|
|
22
26
|
pretty_date,
|
|
@@ -156,6 +160,12 @@ def get_runs_table(
|
|
|
156
160
|
table.add_column("INSTANCE TYPE", no_wrap=True, ratio=1)
|
|
157
161
|
table.add_column("PRICE", style="grey58", ratio=1)
|
|
158
162
|
table.add_column("STATUS", no_wrap=True, ratio=1)
|
|
163
|
+
if verbose or any(
|
|
164
|
+
run._run.is_deployment_in_progress()
|
|
165
|
+
and any(job.job_submissions[-1].probes for job in run._run.jobs)
|
|
166
|
+
for run in runs
|
|
167
|
+
):
|
|
168
|
+
table.add_column("PROBES", ratio=1)
|
|
159
169
|
table.add_column("SUBMITTED", style="grey58", no_wrap=True, ratio=1)
|
|
160
170
|
if verbose:
|
|
161
171
|
table.add_column("ERROR", no_wrap=True, ratio=2)
|
|
@@ -198,6 +208,9 @@ def get_runs_table(
|
|
|
198
208
|
else ""
|
|
199
209
|
),
|
|
200
210
|
"STATUS": latest_job_submission.status_message,
|
|
211
|
+
"PROBES": _format_job_probes(
|
|
212
|
+
job.job_spec.probes, latest_job_submission.probes, latest_job_submission.status
|
|
213
|
+
),
|
|
201
214
|
"SUBMITTED": format_date(latest_job_submission.submitted_at),
|
|
202
215
|
"ERROR": latest_job_submission.error,
|
|
203
216
|
}
|
|
@@ -226,3 +239,22 @@ def get_runs_table(
|
|
|
226
239
|
add_row_from_dict(table, job_row, style="secondary" if len(run.jobs) != 1 else None)
|
|
227
240
|
|
|
228
241
|
return table
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _format_job_probes(
    probe_specs: list[ProbeSpec], probes: list[Probe], job_status: JobStatus
) -> str:
    """Render one status glyph per probe of a running job.

    Returns an empty string when there are no probes or the job is not RUNNING.
    Otherwise each (spec, probe) pair is rendered as follows:

    * ``✓`` when the success streak has reached ``ready_after``;
    * ``~`` when there is a positive streak that has not yet reached it;
    * ``×`` otherwise.

    Glyphs are emitted in whitespace-separated groups of five.
    """
    if job_status != JobStatus.RUNNING or not probes:
        return ""

    def _glyph(spec: ProbeSpec, probe: Probe) -> str:
        # NOTE: the symbols are documented in concepts/services.md, keep in sync.
        streak = probe.success_streak
        if streak >= spec.ready_after:
            return "[code]✓[/]"
        if streak > 0:
            return "[warning]~[/]"
        return "[error]×[/]"

    glyphs = [_glyph(spec, probe) for spec, probe in zip(probe_specs, probes)]
    # split into whitespace-delimited batches to allow column wrapping
    return " ".join("".join(chunk) for chunk in batched(glyphs, 5))
|
|
@@ -54,6 +54,15 @@ try:
|
|
|
54
54
|
except ImportError:
|
|
55
55
|
pass
|
|
56
56
|
|
|
57
|
+
try:
|
|
58
|
+
from dstack._internal.core.backends.hotaisle.configurator import (
|
|
59
|
+
HotAisleConfigurator,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
_CONFIGURATOR_CLASSES.append(HotAisleConfigurator)
|
|
63
|
+
except ImportError:
|
|
64
|
+
pass
|
|
65
|
+
|
|
57
66
|
try:
|
|
58
67
|
from dstack._internal.core.backends.kubernetes.configurator import (
|
|
59
68
|
KubernetesConfigurator,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Hot Aisle backend for dstack
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
|
+
|
|
3
|
+
import requests
|
|
4
|
+
|
|
5
|
+
from dstack._internal.core.backends.base.configurator import raise_invalid_credentials_error
|
|
6
|
+
from dstack._internal.utils.logging import get_logger
|
|
7
|
+
|
|
8
|
+
API_URL = "https://admin.hotaisle.app/api"
|
|
9
|
+
|
|
10
|
+
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class HotAisleAPIClient:
    """Thin client for the Hot Aisle admin API at ``API_URL``.

    Every request authenticates with ``Authorization: Token <api_key>``.
    Virtual-machine endpoints are scoped to ``team_handle``.
    """

    def __init__(self, api_key: str, team_handle: str):
        self.api_key = api_key
        self.team_handle = team_handle

    def validate_api_key(self) -> bool:
        """Check that the API key works and the user belongs to ``team_handle``.

        Returns:
            ``True`` if validation succeeds.

        Raises:
            Whatever ``raise_invalid_credentials_error`` raises in these cases:
            a 401/403 response, a user with no teams, or an unknown team handle.
            Any other ``requests.HTTPError`` or ``ValueError`` is re-raised unchanged.
        """
        try:
            self._validate_user_and_team()
            return True
        except requests.HTTPError as e:
            # ``e.response`` can be None for an HTTPError not produced by raise_for_status().
            status_code = e.response.status_code if e.response is not None else None
            if status_code == 401:
                raise_invalid_credentials_error(
                    fields=[["creds", "api_key"]], details="Invalid API key"
                )
            elif status_code == 403:
                raise_invalid_credentials_error(
                    fields=[["creds", "api_key"]],
                    details="Authenticated user does not have required permissions",
                )
            # Bare ``raise`` re-raises the active exception with its original traceback.
            raise
        except ValueError as e:
            # _validate_user_and_team reports team problems through ValueError messages.
            error_message = str(e)
            if "No Hot Aisle teams found" in error_message:
                raise_invalid_credentials_error(
                    fields=[["creds", "api_key"]],
                    details="Valid API key but no teams found for this user",
                )
            elif "not found" in error_message:
                raise_invalid_credentials_error(
                    fields=[["team_handle"]], details=f"Team handle '{self.team_handle}' not found"
                )
            raise

    def _validate_user_and_team(self) -> None:
        """Fetch the current user and check that ``team_handle`` is one of their teams.

        Raises:
            requests.HTTPError: if the ``/user/`` request fails.
            ValueError: if the user has no teams, or none with handle ``team_handle``.
        """
        url = f"{API_URL}/user/"
        response = self._make_request("GET", url)
        response.raise_for_status()
        user_data = response.json()

        teams = user_data.get("teams", [])
        if not teams:
            raise ValueError("No Hot Aisle teams found for this user")

        available_teams = [team["handle"] for team in teams]
        if self.team_handle not in available_teams:
            raise ValueError(f"Hot Aisle team '{self.team_handle}' not found.")

    def upload_ssh_key(self, public_key: str) -> bool:
        """Register ``public_key`` with the user's account.

        A 409 response (key already registered) is treated as success.

        Raises:
            requests.HTTPError: for any other error status.
        """
        url = f"{API_URL}/user/ssh_keys/"
        payload = {"authorized_key": public_key}

        response = self._make_request("POST", url, json=payload)

        if response.status_code == 409:
            return True  # Key already exists - success
        response.raise_for_status()
        return True

    def create_virtual_machine(self, vm_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Create a VM in the team from ``vm_payload`` and return the decoded response body.

        Raises:
            requests.HTTPError: if the API returns an error status.
        """
        url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/"
        response = self._make_request("POST", url, json=vm_payload)
        response.raise_for_status()
        vm_data = response.json()
        return vm_data

    def get_vm_state(self, vm_name: str) -> str:
        """Return the ``state`` field reported for VM ``vm_name``.

        Raises:
            requests.HTTPError: if the API returns an error status.
        """
        url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/state/"
        response = self._make_request("GET", url)
        response.raise_for_status()
        state_data = response.json()
        return state_data["state"]

    def terminate_virtual_machine(self, vm_name: str) -> None:
        """Delete VM ``vm_name``. A 404 (already gone) is logged and ignored.

        Raises:
            requests.HTTPError: for any other error status.
        """
        url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/"
        response = self._make_request("DELETE", url)
        if response.status_code == 404:
            logger.debug("Hot Aisle virtual machine %s not found", vm_name)
            return
        response.raise_for_status()

    def _make_request(
        self, method: str, url: str, json: Optional[Dict[str, Any]] = None, timeout: int = 30
    ) -> requests.Response:
        """Send an authenticated request and return the raw response.

        HTTP error statuses are NOT raised here; callers inspect the response.
        """
        headers = {
            "accept": "application/json",
            "Authorization": f"Token {self.api_key}",
        }
        if json is not None:
            headers["Content-Type"] = "application/json"

        return requests.request(
            method=method,
            url=url,
            headers=headers,
            json=json,
            timeout=timeout,
        )
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from dstack._internal.core.backends.base.backend import Backend
|
|
2
|
+
from dstack._internal.core.backends.hotaisle.compute import HotAisleCompute
|
|
3
|
+
from dstack._internal.core.backends.hotaisle.models import HotAisleConfig
|
|
4
|
+
from dstack._internal.core.models.backends.base import BackendType
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class HotAisleBackend(Backend):
    """dstack backend for Hot Aisle; all provisioning goes through ``HotAisleCompute``."""

    TYPE = BackendType.HOTAISLE
    COMPUTE_CLASS = HotAisleCompute

    def __init__(self, config: HotAisleConfig):
        self.config = config
        # One compute object is built up front and handed out by compute().
        self._compute = HotAisleCompute(config)

    def compute(self) -> HotAisleCompute:
        """Return the compute implementation bound to this backend's config."""
        return self._compute
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
import shlex
|
|
2
|
+
import subprocess
|
|
3
|
+
import tempfile
|
|
4
|
+
from threading import Thread
|
|
5
|
+
from typing import List, Optional
|
|
6
|
+
|
|
7
|
+
import gpuhunt
|
|
8
|
+
from gpuhunt.providers.hotaisle import HotAisleProvider
|
|
9
|
+
|
|
10
|
+
from dstack._internal.core.backends.base.compute import (
|
|
11
|
+
Compute,
|
|
12
|
+
ComputeWithCreateInstanceSupport,
|
|
13
|
+
get_shim_commands,
|
|
14
|
+
)
|
|
15
|
+
from dstack._internal.core.backends.base.offers import get_catalog_offers
|
|
16
|
+
from dstack._internal.core.backends.hotaisle.api_client import HotAisleAPIClient
|
|
17
|
+
from dstack._internal.core.backends.hotaisle.models import HotAisleConfig
|
|
18
|
+
from dstack._internal.core.models.backends.base import BackendType
|
|
19
|
+
from dstack._internal.core.models.common import CoreModel
|
|
20
|
+
from dstack._internal.core.models.instances import (
|
|
21
|
+
InstanceAvailability,
|
|
22
|
+
InstanceConfiguration,
|
|
23
|
+
InstanceOfferWithAvailability,
|
|
24
|
+
)
|
|
25
|
+
from dstack._internal.core.models.placement import PlacementGroup
|
|
26
|
+
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
|
|
27
|
+
from dstack._internal.utils.logging import get_logger
|
|
28
|
+
|
|
29
|
+
logger = get_logger(__name__)
|
|
30
|
+
|
|
31
|
+
MAX_INSTANCE_NAME_LEN = 60


# CPU details per Hot Aisle instance type. HotAisleCompute.get_payload_from_offer
# uses them to build VM creation payloads. Offers whose instance type has no entry
# here are skipped by HotAisleCompute.get_offers.
INSTANCE_TYPE_SPECS = {
    "1x MI300X 8x Xeon Platinum 8462Y+": dict(
        cpu_model="Xeon Platinum 8462Y+",
        cpu_frequency=2_800_000_000,  # 2.8 GHz, presumably expressed in Hz
        cpu_manufacturer="Intel",
    ),
    "1x MI300X 13x Xeon Platinum 8470": dict(
        cpu_model="Xeon Platinum 8470",
        cpu_frequency=2_000_000_000,  # 2.0 GHz, presumably expressed in Hz
        cpu_manufacturer="Intel",
    ),
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class HotAisleCompute(
    ComputeWithCreateInstanceSupport,
    Compute,
):
    """Provision and manage Hot Aisle virtual machines for dstack.

    Offers come from a gpuhunt catalog backed by ``HotAisleProvider``. They are
    limited to the instance types listed in ``INSTANCE_TYPE_SPECS``. Created VMs
    are polled via ``update_provisioning_data``. Once a VM is ``running``, the
    dstack shim is launched on it over SSH from a daemon thread.
    """

    def __init__(self, config: HotAisleConfig):
        super().__init__()
        self.config = config
        self.api_client = HotAisleAPIClient(config.creds.api_key, config.team_handle)
        # Dedicated catalog containing only the Hot Aisle provider.
        self.catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False)
        self.catalog.add_provider(
            HotAisleProvider(api_key=config.creds.api_key, team_handle=config.team_handle)
        )

    def get_offers(
        self, requirements: Optional[Requirements] = None
    ) -> List[InstanceOfferWithAvailability]:
        """Return catalog offers that match ``requirements`` and can be provisioned.

        Every returned offer is marked ``AVAILABLE``. Offers whose instance type
        is not in ``INSTANCE_TYPE_SPECS`` are skipped with a warning, because
        ``get_payload_from_offer`` has no CPU details for them.

        Args:
            requirements: Optional resource requirements used to filter offers.
        """
        offers = get_catalog_offers(
            backend=BackendType.HOTAISLE,
            locations=self.config.regions or None,
            requirements=requirements,
            catalog=self.catalog,
        )

        supported_offers = []
        for offer in offers:
            if offer.instance.name not in INSTANCE_TYPE_SPECS:
                # Lazy %-formatting: the message is built only if the record is emitted.
                logger.warning(
                    "Skipping unsupported Hot Aisle instance type: %s", offer.instance.name
                )
                continue
            supported_offers.append(
                InstanceOfferWithAvailability(
                    **offer.dict(), availability=InstanceAvailability.AVAILABLE
                )
            )
        return supported_offers

    def get_payload_from_offer(self, instance_type) -> dict:
        """Build the Hot Aisle VM creation payload for ``instance_type``.

        Args:
            instance_type: The offer's instance type. Its ``name`` must be a key of
                ``INSTANCE_TYPE_SPECS`` and its resources must list at least one GPU.

        Returns:
            A dict describing the CPU, disk, RAM and GPU request. Disk and RAM are
            converted from MiB to bytes.

        Raises:
            KeyError: if the instance type is not in ``INSTANCE_TYPE_SPECS``.
        """
        resources = instance_type.resources
        cpu_specs = INSTANCE_TYPE_SPECS[instance_type.name]
        cpu_cores = resources.cpus
        gpus = resources.gpus

        return {
            "cpu_cores": cpu_cores,
            "cpus": {
                "count": 1,
                "manufacturer": cpu_specs["cpu_manufacturer"],
                "model": cpu_specs["cpu_model"],
                "cores": cpu_cores,
                "frequency": cpu_specs["cpu_frequency"],
            },
            "disk_capacity": resources.disk.size_mib * 1024**2,
            "ram_capacity": resources.memory_mib * 1024**2,
            "gpus": [
                {
                    "count": len(gpus),
                    # The first GPU's vendor/model is used for the whole group.
                    "manufacturer": gpus[0].vendor,
                    "model": gpus[0].name,
                }
            ],
        }

    def create_instance(
        self,
        instance_offer: InstanceOfferWithAvailability,
        instance_config: InstanceConfiguration,
        placement_group: Optional[PlacementGroup],
    ) -> JobProvisioningData:
        """Create a Hot Aisle VM for ``instance_offer``.

        The first SSH key in ``instance_config`` is uploaded to the Hot Aisle user
        account before the VM is created. The returned provisioning data has no
        hostname yet. The VM's IP address is stored in ``backend_data`` and copied
        to ``hostname`` by ``update_provisioning_data`` once the VM is running.
        ``placement_group`` is not used.
        """
        project_ssh_key = instance_config.ssh_keys[0]
        self.api_client.upload_ssh_key(project_ssh_key.public)
        vm_payload = self.get_payload_from_offer(instance_offer.instance)
        vm_data = self.api_client.create_virtual_machine(vm_payload)
        return JobProvisioningData(
            backend=instance_offer.backend,
            instance_type=instance_offer.instance,
            instance_id=vm_data["name"],
            hostname=None,
            internal_ip=None,
            region=instance_offer.region,
            price=instance_offer.price,
            username="hotaisle",
            ssh_port=22,
            dockerized=True,
            ssh_proxy=None,
            backend_data=HotAisleInstanceBackendData(
                ip_address=vm_data["ip_address"], vm_id=vm_data["name"]
            ).json(),
        )

    def update_provisioning_data(
        self,
        provisioning_data: JobProvisioningData,
        project_ssh_public_key: str,
        project_ssh_private_key: str,
    ):
        """Finish provisioning once the VM is running.

        The first call that sees the VM ``running`` while ``hostname`` is still
        unset does two things: it sets ``hostname`` from the stored IP address,
        and it starts the dstack shim over SSH in a daemon thread. Other calls
        only query the VM state.
        """
        vm_state = self.api_client.get_vm_state(provisioning_data.instance_id)
        if vm_state != "running":
            return
        if provisioning_data.hostname is not None or not provisioning_data.backend_data:
            return

        backend_data = HotAisleInstanceBackendData.load(provisioning_data.backend_data)
        # Setting hostname before launching prevents a second launch on later calls
        # (assuming the caller persists the updated provisioning data).
        provisioning_data.hostname = backend_data.ip_address
        commands = get_shim_commands(
            authorized_keys=[project_ssh_public_key],
            arch=provisioning_data.instance_type.resources.cpu_arch,
        )
        launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands))
        thread = Thread(
            target=_start_runner,
            kwargs={
                "hostname": provisioning_data.hostname,
                "project_ssh_private_key": project_ssh_private_key,
                "launch_command": launch_command,
            },
            daemon=True,
        )
        thread.start()

    def terminate_instance(
        self, instance_id: str, region: str, backend_data: Optional[str] = None
    ):
        """Delete the VM named ``instance_id``.

        A VM that is already gone (HTTP 404) is not an error. ``region`` and
        ``backend_data`` are unused.
        """
        self.api_client.terminate_virtual_machine(instance_id)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _start_runner(
    hostname: str,
    project_ssh_private_key: str,
    launch_command: str,
):
    """Thread target: launch the runner on ``hostname`` via ``_launch_runner``."""
    _launch_runner(hostname, project_ssh_private_key, launch_command)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _launch_runner(
    hostname: str,
    ssh_private_key: str,
    launch_command: str,
):
    """Run ``launch_command`` on ``hostname`` in the background over SSH.

    Any trailing ``&`` characters are stripped. The command is then backgrounded
    with stdout/stderr redirected to ``/tmp/dstack-shim.log`` and ``disown``-ed,
    presumably so the SSH session can end while the command keeps running.
    """
    base_command = launch_command.rstrip("&")
    remote_command = base_command + " >/tmp/dstack-shim.log 2>&1 & disown"
    _run_ssh_command(hostname=hostname, ssh_private_key=ssh_private_key, command=remote_command)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _run_ssh_command(hostname: str, ssh_private_key: str, command: str):
    """Run ``command`` as ``hotaisle@hostname`` over SSH using ``ssh_private_key``.

    The key is written to a temporary file that is deleted when the call returns.
    Host key checking is disabled and SSH output is discarded. A non-zero exit
    status is logged as a warning rather than raised, because within this module
    the function runs from a daemon thread where an exception would be lost.
    """
    # NamedTemporaryFile already creates the file readable/writable only by the
    # current user (via mkstemp), which ssh requires for identity files. Its second
    # positional parameter is ``buffering``, not a permission mode, so 0o600 must
    # not be passed there.
    with tempfile.NamedTemporaryFile("w+") as key_file:
        key_file.write(ssh_private_key)
        key_file.flush()
        result = subprocess.run(
            [
                "ssh",
                "-F",
                "none",
                "-o",
                "StrictHostKeyChecking=no",
                "-i",
                key_file.name,
                f"hotaisle@{hostname}",
                command,
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    if result.returncode != 0:
        logger.warning(
            "Failed to run command on Hot Aisle instance %s: ssh exited with code %s",
            hostname,
            result.returncode,
        )
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class HotAisleInstanceBackendData(CoreModel):
    """Hot Aisle data serialized into ``JobProvisioningData.backend_data``."""

    # VM IP address returned at creation; later copied into the SSH hostname.
    ip_address: str
    # Hot Aisle VM name (same value as the instance_id).
    vm_id: Optional[str] = None

    @classmethod
    def load(cls, raw: Optional[str]) -> "HotAisleInstanceBackendData":
        """Parse backend data previously produced by ``.json()``.

        Raises:
            ValueError: if ``raw`` is ``None``. An explicit check is used rather
                than ``assert``, which is stripped under ``python -O``.
        """
        if raw is None:
            raise ValueError("Hot Aisle instance backend data is missing")
        # NOTE(review): __response__ is presumably CoreModel's lenient parsing
        # variant (tolerates extra fields) - confirm in core.models.common.
        return cls.__response__.parse_raw(raw)
|