dstack 0.19.21__py3-none-any.whl → 0.19.23rc1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

Files changed (71)
  1. dstack/_internal/cli/commands/apply.py +8 -3
  2. dstack/_internal/cli/services/configurators/__init__.py +8 -0
  3. dstack/_internal/cli/services/configurators/fleet.py +1 -1
  4. dstack/_internal/cli/services/configurators/gateway.py +1 -1
  5. dstack/_internal/cli/services/configurators/run.py +11 -1
  6. dstack/_internal/cli/services/configurators/volume.py +1 -1
  7. dstack/_internal/cli/utils/common.py +48 -5
  8. dstack/_internal/cli/utils/fleet.py +5 -5
  9. dstack/_internal/cli/utils/run.py +32 -0
  10. dstack/_internal/core/backends/configurators.py +9 -0
  11. dstack/_internal/core/backends/hotaisle/__init__.py +1 -0
  12. dstack/_internal/core/backends/hotaisle/api_client.py +109 -0
  13. dstack/_internal/core/backends/hotaisle/backend.py +16 -0
  14. dstack/_internal/core/backends/hotaisle/compute.py +225 -0
  15. dstack/_internal/core/backends/hotaisle/configurator.py +60 -0
  16. dstack/_internal/core/backends/hotaisle/models.py +45 -0
  17. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  18. dstack/_internal/core/backends/models.py +8 -0
  19. dstack/_internal/core/backends/nebius/compute.py +8 -2
  20. dstack/_internal/core/backends/nebius/fabrics.py +1 -0
  21. dstack/_internal/core/backends/nebius/resources.py +9 -0
  22. dstack/_internal/core/compatibility/runs.py +8 -0
  23. dstack/_internal/core/models/backends/base.py +2 -0
  24. dstack/_internal/core/models/configurations.py +139 -1
  25. dstack/_internal/core/models/health.py +28 -0
  26. dstack/_internal/core/models/instances.py +2 -0
  27. dstack/_internal/core/models/logs.py +2 -1
  28. dstack/_internal/core/models/runs.py +21 -1
  29. dstack/_internal/core/services/ssh/tunnel.py +7 -0
  30. dstack/_internal/server/app.py +4 -0
  31. dstack/_internal/server/background/__init__.py +4 -0
  32. dstack/_internal/server/background/tasks/process_instances.py +107 -56
  33. dstack/_internal/server/background/tasks/process_probes.py +164 -0
  34. dstack/_internal/server/background/tasks/process_running_jobs.py +13 -0
  35. dstack/_internal/server/background/tasks/process_runs.py +21 -14
  36. dstack/_internal/server/migrations/versions/25479f540245_add_probes.py +43 -0
  37. dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py +50 -0
  38. dstack/_internal/server/models.py +41 -0
  39. dstack/_internal/server/routers/instances.py +33 -5
  40. dstack/_internal/server/schemas/health/dcgm.py +56 -0
  41. dstack/_internal/server/schemas/instances.py +32 -0
  42. dstack/_internal/server/schemas/runner.py +5 -0
  43. dstack/_internal/server/services/instances.py +103 -1
  44. dstack/_internal/server/services/jobs/__init__.py +8 -1
  45. dstack/_internal/server/services/jobs/configurators/base.py +26 -0
  46. dstack/_internal/server/services/logging.py +4 -2
  47. dstack/_internal/server/services/logs/aws.py +13 -1
  48. dstack/_internal/server/services/logs/gcp.py +16 -1
  49. dstack/_internal/server/services/probes.py +6 -0
  50. dstack/_internal/server/services/projects.py +16 -4
  51. dstack/_internal/server/services/runner/client.py +52 -20
  52. dstack/_internal/server/services/runner/ssh.py +4 -4
  53. dstack/_internal/server/services/runs.py +49 -13
  54. dstack/_internal/server/services/ssh.py +66 -0
  55. dstack/_internal/server/settings.py +13 -0
  56. dstack/_internal/server/statics/index.html +1 -1
  57. dstack/_internal/server/statics/{main-8f9ee218d3eb45989682.css → main-03e818b110e1d5705378.css} +1 -1
  58. dstack/_internal/server/statics/{main-39a767528976f8078166.js → main-cc067b7fd1a8f33f97da.js} +26 -15
  59. dstack/_internal/server/statics/{main-39a767528976f8078166.js.map → main-cc067b7fd1a8f33f97da.js.map} +1 -1
  60. dstack/_internal/server/testing/common.py +44 -0
  61. dstack/_internal/{core/backends/remote → server/utils}/provisioning.py +22 -17
  62. dstack/_internal/settings.py +3 -0
  63. dstack/_internal/utils/common.py +15 -0
  64. dstack/api/server/__init__.py +1 -1
  65. dstack/version.py +1 -1
  66. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/METADATA +14 -14
  67. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/RECORD +71 -58
  68. /dstack/_internal/{core/backends/remote → server/schemas/health}/__init__.py +0 -0
  69. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/WHEEL +0 -0
  70. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/entry_points.txt +0 -0
  71. {dstack-0.19.21.dist-info → dstack-0.19.23rc1.dist-info}/licenses/LICENSE.md +0 -0

dstack/_internal/cli/commands/apply.py
@@ -1,10 +1,10 @@
 import argparse
-from pathlib import Path

 from argcomplete import FilesCompleter

 from dstack._internal.cli.commands import APIBaseCommand
 from dstack._internal.cli.services.configurators import (
+    APPLY_STDIN_NAME,
     get_apply_configurator_class,
     load_apply_configuration,
 )
@@ -40,9 +40,12 @@ class ApplyCommand(APIBaseCommand):
         self._parser.add_argument(
             "-f",
             "--file",
-            type=Path,
             metavar="FILE",
-            help="The path to the configuration file. Defaults to [code]$PWD/.dstack.yml[/]",
+            help=(
+                "The path to the configuration file."
+                " Specify [code]-[/] to read configuration from stdin."
+                " Defaults to [code]$PWD/.dstack.yml[/]"
+            ),
             dest="configuration_file",
         ).completer = FilesCompleter(allowednames=["*.yml", "*.yaml"])
         self._parser.add_argument(
@@ -104,6 +107,8 @@ class ApplyCommand(APIBaseCommand):
             return

         super()._command(args)
+        if not args.yes and args.configuration_file == APPLY_STDIN_NAME:
+            raise CLIError("Cannot read configuration from stdin if -y/--yes is not specified")
         if args.repo and args.no_repo:
             raise CLIError("Either --repo or --no-repo can be specified")
         repo = None

dstack/_internal/cli/services/configurators/__init__.py
@@ -1,3 +1,4 @@
+import sys
 from pathlib import Path
 from typing import Dict, Optional, Tuple, Type

@@ -20,6 +21,9 @@ from dstack._internal.core.models.configurations import (
     parse_apply_configuration,
 )

+APPLY_STDIN_NAME = "-"
+
+
 apply_configurators_mapping: Dict[ApplyConfigurationType, Type[BaseApplyConfigurator]] = {
     cls.TYPE: cls
     for cls in [
@@ -62,6 +66,8 @@ def load_apply_configuration(
            raise ConfigurationError(
                "No configuration file specified via `-f` and no default .dstack.yml configuration found"
            )
+    elif configuration_file == APPLY_STDIN_NAME:
+        configuration_path = sys.stdin.fileno()
     else:
         configuration_path = Path(configuration_file)
         if not configuration_path.exists():
@@ -71,4 +77,6 @@ def load_apply_configuration(
             conf = parse_apply_configuration(yaml.safe_load(f))
     except OSError:
         raise ConfigurationError(f"Failed to load configuration from {configuration_path}")
+    if isinstance(configuration_path, int):
+        return APPLY_STDIN_NAME, conf
     return str(configuration_path.absolute().relative_to(Path.cwd())), conf
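
The stdin support in load_apply_configuration works because Python's built-in open() also accepts an integer file descriptor, so sys.stdin.fileno() can stand in for a Path. A minimal standalone sketch of the same pattern, assuming only PyYAML (an illustration, not the dstack code itself):

import sys

import yaml

APPLY_STDIN_NAME = "-"


def load_configuration(configuration_file: str) -> dict:
    # "-" means: read the YAML configuration from stdin instead of a file path.
    if configuration_file == APPLY_STDIN_NAME:
        source = sys.stdin.fileno()  # open() accepts an int file descriptor
    else:
        source = configuration_file
    with open(source) as f:
        return yaml.safe_load(f)


if __name__ == "__main__":
    # e.g.: echo "type: task" | python example.py -
    print(load_configuration(sys.argv[1]))

Since the configuration consumes stdin, there is no interactive input left for a confirmation prompt, which is presumably why apply.py now requires -y/--yes when reading from stdin.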

dstack/_internal/cli/services/configurators/fleet.py
@@ -151,7 +151,7 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
                time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
                fleet = self.api.client.fleets.get(self.api.project, fleet.name)
        except KeyboardInterrupt:
-            if confirm_ask("Delete the fleet before exiting?"):
+            if not command_args.yes and confirm_ask("Delete the fleet before exiting?"):
                with console.status("Deleting fleet..."):
                    self.api.client.fleets.delete(
                        project_name=self.api.project, names=[fleet.name]

dstack/_internal/cli/services/configurators/gateway.py
@@ -121,7 +121,7 @@ class GatewayConfigurator(BaseApplyConfigurator):
                time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
                gateway = self.api.client.gateways.get(self.api.project, gateway.name)
        except KeyboardInterrupt:
-            if confirm_ask("Delete the gateway before exiting?"):
+            if not command_args.yes and confirm_ask("Delete the gateway before exiting?"):
                with console.status("Deleting gateway..."):
                    self.api.client.gateways.delete(
                        project_name=self.api.project,

dstack/_internal/cli/services/configurators/run.py
@@ -218,7 +218,9 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
                exit(1)
        except KeyboardInterrupt:
            try:
-                if not confirm_ask(f"\nStop the run [code]{run.name}[/] before detaching?"):
+                if command_args.yes or not confirm_ask(
+                    f"\nStop the run [code]{run.name}[/] before detaching?"
+                ):
                    console.print("Detached")
                    abort_at_exit = False
                    return
@@ -339,6 +341,14 @@ class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
                    username=interpolator.interpolate_or_error(conf.registry_auth.username),
                    password=interpolator.interpolate_or_error(conf.registry_auth.password),
                )
+            if isinstance(conf, ServiceConfiguration):
+                for probe in conf.probes:
+                    for header in probe.headers:
+                        header.value = interpolator.interpolate_or_error(header.value)
+                    if probe.url:
+                        probe.url = interpolator.interpolate_or_error(probe.url)
+                    if probe.body:
+                        probe.body = interpolator.interpolate_or_error(probe.body)
        except InterpolatorError as e:
            raise ConfigurationError(e.args[0])


dstack/_internal/cli/services/configurators/volume.py
@@ -110,7 +110,7 @@ class VolumeConfigurator(BaseApplyConfigurator):
                time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
                volume = self.api.client.volumes.get(self.api.project, volume.name)
        except KeyboardInterrupt:
-            if confirm_ask("Delete the volume before exiting?"):
+            if not command_args.yes and confirm_ask("Delete the volume before exiting?"):
                with console.status("Deleting volume..."):
                    self.api.client.volumes.delete(
                        project_name=self.api.project, names=[volume.name]

dstack/_internal/cli/utils/common.py
@@ -1,5 +1,6 @@
 import logging
-import os
+from datetime import datetime, timezone
+from pathlib import Path
 from typing import Any, Dict, Union

 from rich.console import Console
@@ -7,8 +8,10 @@ from rich.prompt import Confirm
 from rich.table import Table
 from rich.theme import Theme

+from dstack._internal import settings
 from dstack._internal.cli.utils.rich import DstackRichHandler
 from dstack._internal.core.errors import CLIError, DstackError
+from dstack._internal.utils.common import get_dstack_dir

 _colors = {
     "secondary": "grey58",
@@ -35,12 +38,52 @@ def cli_error(e: DstackError) -> CLIError:
     return CLIError(*e.args)


+def _get_cli_log_file() -> Path:
+    """Get the CLI log file path, rotating the previous log if needed."""
+    log_dir = get_dstack_dir() / "logs" / "cli"
+    log_file = log_dir / "latest.log"
+
+    if log_file.exists():
+        file_mtime = datetime.fromtimestamp(log_file.stat().st_mtime, tz=timezone.utc)
+        current_date = datetime.now(timezone.utc).date()
+
+        if file_mtime.date() < current_date:
+            date_str = file_mtime.strftime("%Y-%m-%d")
+            rotated_file = log_dir / f"{date_str}.log"
+
+            counter = 1
+            while rotated_file.exists():
+                rotated_file = log_dir / f"{date_str}-{counter}.log"
+                counter += 1
+
+            log_file.rename(rotated_file)
+
+    log_dir.mkdir(parents=True, exist_ok=True)
+    return log_file
+
+
 def configure_logging():
     dstack_logger = logging.getLogger("dstack")
-    dstack_logger.setLevel(os.getenv("DSTACK_CLI_LOG_LEVEL", "INFO").upper())
-    handler = DstackRichHandler(console=console)
-    handler.setFormatter(logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
-    dstack_logger.addHandler(handler)
+    dstack_logger.handlers.clear()
+
+    log_file = _get_cli_log_file()
+
+    stdout_handler = DstackRichHandler(console=console)
+    stdout_handler.setFormatter(logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
+    stdout_handler.setLevel(settings.CLI_LOG_LEVEL)
+    dstack_logger.addHandler(stdout_handler)
+
+    file_handler = logging.FileHandler(log_file)
+    file_handler.setFormatter(
+        logging.Formatter(
+            fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+        )
+    )
+    file_handler.setLevel(settings.CLI_FILE_LOG_LEVEL)
+    dstack_logger.addHandler(file_handler)
+
+    # the logger allows all messages, filtering is done by the handlers
+    dstack_logger.setLevel(logging.DEBUG)


 def confirm_ask(prompt, **kwargs) -> bool:
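
configure_logging now splits output between a console handler and a file handler with independent levels, while the logger itself is set to DEBUG so that filtering happens only in the handlers. A standalone sketch of that standard pattern (the names and levels here are illustrative, not dstack's settings):

import logging

logger = logging.getLogger("example")
logger.setLevel(logging.DEBUG)  # let everything through; the handlers decide what to keep

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)  # terse console output
logger.addHandler(console_handler)

file_handler = logging.FileHandler("example.log")
file_handler.setLevel(logging.DEBUG)  # full detail goes to the file
file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(file_handler)

logger.debug("written to the file only")
logger.info("written to both the console and the file")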

dstack/_internal/cli/utils/fleet.py
@@ -51,11 +51,11 @@ def get_fleets_table(
            and total_blocks > 1
        ):
            status = f"{busy_blocks}/{total_blocks} {InstanceStatus.BUSY.value}"
-        if (
-            instance.status in [InstanceStatus.IDLE, InstanceStatus.BUSY]
-            and instance.unreachable
-        ):
-            status += "\n(unreachable)"
+        if instance.status in [InstanceStatus.IDLE, InstanceStatus.BUSY]:
+            if instance.unreachable:
+                status += "\n(unreachable)"
+            elif not instance.health_status.is_healthy():
+                status += f"\n({instance.health_status.value})"

        backend = instance.backend or ""
        if backend == "remote":

dstack/_internal/cli/utils/run.py
@@ -12,11 +12,15 @@ from dstack._internal.core.models.profiles import (
     TerminationPolicy,
 )
 from dstack._internal.core.models.runs import (
+    JobStatus,
+    Probe,
+    ProbeSpec,
     RunPlan,
 )
 from dstack._internal.core.services.profiles import get_termination
 from dstack._internal.utils.common import (
     DateFormatter,
+    batched,
     format_duration_multiunit,
     format_pretty_duration,
     pretty_date,
@@ -156,6 +160,12 @@ def get_runs_table(
     table.add_column("INSTANCE TYPE", no_wrap=True, ratio=1)
     table.add_column("PRICE", style="grey58", ratio=1)
     table.add_column("STATUS", no_wrap=True, ratio=1)
+    if verbose or any(
+        run._run.is_deployment_in_progress()
+        and any(job.job_submissions[-1].probes for job in run._run.jobs)
+        for run in runs
+    ):
+        table.add_column("PROBES", ratio=1)
     table.add_column("SUBMITTED", style="grey58", no_wrap=True, ratio=1)
     if verbose:
         table.add_column("ERROR", no_wrap=True, ratio=2)
@@ -198,6 +208,9 @@ def get_runs_table(
                else ""
            ),
            "STATUS": latest_job_submission.status_message,
+            "PROBES": _format_job_probes(
+                job.job_spec.probes, latest_job_submission.probes, latest_job_submission.status
+            ),
            "SUBMITTED": format_date(latest_job_submission.submitted_at),
            "ERROR": latest_job_submission.error,
        }
@@ -226,3 +239,22 @@ def get_runs_table(
        add_row_from_dict(table, job_row, style="secondary" if len(run.jobs) != 1 else None)

    return table
+
+
+def _format_job_probes(
+    probe_specs: list[ProbeSpec], probes: list[Probe], job_status: JobStatus
+) -> str:
+    if not probes or job_status != JobStatus.RUNNING:
+        return ""
+    statuses = []
+    for probe_spec, probe in zip(probe_specs, probes):
+        # NOTE: the symbols are documented in concepts/services.md, keep in sync.
+        if probe.success_streak >= probe_spec.ready_after:
+            status = "[code]✓[/]"
+        elif probe.success_streak > 0:
+            status = "[warning]~[/]"
+        else:
+            status = "[error]×[/]"
+        statuses.append(status)
+    # split into whitespace-delimited batches to allow column wrapping
+    return " ".join("".join(batch) for batch in batched(statuses, 5))

dstack/_internal/core/backends/configurators.py
@@ -54,6 +54,15 @@ try:
 except ImportError:
     pass

+try:
+    from dstack._internal.core.backends.hotaisle.configurator import (
+        HotAisleConfigurator,
+    )
+
+    _CONFIGURATOR_CLASSES.append(HotAisleConfigurator)
+except ImportError:
+    pass
+
 try:
     from dstack._internal.core.backends.kubernetes.configurator import (
         KubernetesConfigurator,

dstack/_internal/core/backends/hotaisle/__init__.py
@@ -0,0 +1 @@
+# Hotaisle backend for dstack

dstack/_internal/core/backends/hotaisle/api_client.py
@@ -0,0 +1,109 @@
+from typing import Any, Dict, Optional
+
+import requests
+
+from dstack._internal.core.backends.base.configurator import raise_invalid_credentials_error
+from dstack._internal.utils.logging import get_logger
+
+API_URL = "https://admin.hotaisle.app/api"
+
+logger = get_logger(__name__)
+
+
+class HotAisleAPIClient:
+    def __init__(self, api_key: str, team_handle: str):
+        self.api_key = api_key
+        self.team_handle = team_handle
+
+    def validate_api_key(self) -> bool:
+        try:
+            self._validate_user_and_team()
+            return True
+        except requests.HTTPError as e:
+            if e.response.status_code == 401:
+                raise_invalid_credentials_error(
+                    fields=[["creds", "api_key"]], details="Invalid API key"
+                )
+            elif e.response.status_code == 403:
+                raise_invalid_credentials_error(
+                    fields=[["creds", "api_key"]],
+                    details="Authenticated user does note have required permissions",
+                )
+            raise e
+        except ValueError as e:
+            error_message = str(e)
+            if "No Hot Aisle teams found" in error_message:
+                raise_invalid_credentials_error(
+                    fields=[["creds", "api_key"]],
+                    details="Valid API key but no teams found for this user",
+                )
+            elif "not found" in error_message:
+                raise_invalid_credentials_error(
+                    fields=[["team_handle"]], details=f"Team handle '{self.team_handle}' not found"
+                )
+            raise e
+
+    def _validate_user_and_team(self) -> None:
+        url = f"{API_URL}/user/"
+        response = self._make_request("GET", url)
+        response.raise_for_status()
+        user_data = response.json()
+
+        teams = user_data.get("teams", [])
+        if not teams:
+            raise ValueError("No Hot Aisle teams found for this user")
+
+        available_teams = [team["handle"] for team in teams]
+        if self.team_handle not in available_teams:
+            raise ValueError(f"Hot Aisle team '{self.team_handle}' not found.")
+
+    def upload_ssh_key(self, public_key: str) -> bool:
+        url = f"{API_URL}/user/ssh_keys/"
+        payload = {"authorized_key": public_key}
+
+        response = self._make_request("POST", url, json=payload)
+
+        if response.status_code == 409:
+            return True  # Key already exists - success
+        response.raise_for_status()
+        return True
+
+    def create_virtual_machine(self, vm_payload: Dict[str, Any]) -> Dict[str, Any]:
+        url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/"
+        response = self._make_request("POST", url, json=vm_payload)
+        response.raise_for_status()
+        vm_data = response.json()
+        return vm_data
+
+    def get_vm_state(self, vm_name: str) -> str:
+        url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/state/"
+        response = self._make_request("GET", url)
+        response.raise_for_status()
+        state_data = response.json()
+        return state_data["state"]
+
+    def terminate_virtual_machine(self, vm_name: str) -> None:
+        url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/"
+        response = self._make_request("DELETE", url)
+        if response.status_code == 404:
+            logger.debug("Hot Aisle virtual machine %s not found", vm_name)
+            return
+        response.raise_for_status()
+
+    def _make_request(
+        self, method: str, url: str, json: Optional[Dict[str, Any]] = None, timeout: int = 30
+    ) -> requests.Response:
+        headers = {
+            "accept": "application/json",
+            "Authorization": f"Token {self.api_key}",
+        }
+        if json is not None:
+            headers["Content-Type"] = "application/json"
+
+        return requests.request(
+            method=method,
+            url=url,
+            headers=headers,
+            json=json,
+            timeout=timeout,
+        )
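
A hedged usage sketch for the HotAisleAPIClient defined above; the API key, team handle, SSH key, and payload values are placeholders, and the payload shape mirrors what HotAisleCompute.get_payload_from_offer builds further down:

from dstack._internal.core.backends.hotaisle.api_client import HotAisleAPIClient

client = HotAisleAPIClient(api_key="<hot-aisle-api-key>", team_handle="<team-handle>")
client.validate_api_key()  # raises on an invalid key, missing teams, or an unknown team handle

client.upload_ssh_key("ssh-ed25519 AAAA... dstack")  # a 409 (key already exists) counts as success

vm = client.create_virtual_machine(
    {
        # Illustrative values only; see get_payload_from_offer() below for the real mapping.
        "cpu_cores": 13,
        "cpus": {
            "count": 1,
            "manufacturer": "Intel",
            "model": "Xeon Platinum 8470",
            "cores": 13,
            "frequency": 2000000000,
        },
        "disk_capacity": 1024**4,
        "ram_capacity": 224 * 1024**3,
        "gpus": [{"count": 1, "manufacturer": "AMD", "model": "MI300X"}],
    }
)
print(vm["name"], vm["ip_address"])
print(client.get_vm_state(vm["name"]))  # e.g. "running" once the VM is up
client.terminate_virtual_machine(vm["name"])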

dstack/_internal/core/backends/hotaisle/backend.py
@@ -0,0 +1,16 @@
+from dstack._internal.core.backends.base.backend import Backend
+from dstack._internal.core.backends.hotaisle.compute import HotAisleCompute
+from dstack._internal.core.backends.hotaisle.models import HotAisleConfig
+from dstack._internal.core.models.backends.base import BackendType
+
+
+class HotAisleBackend(Backend):
+    TYPE = BackendType.HOTAISLE
+    COMPUTE_CLASS = HotAisleCompute
+
+    def __init__(self, config: HotAisleConfig):
+        self.config = config
+        self._compute = HotAisleCompute(self.config)
+
+    def compute(self) -> HotAisleCompute:
+        return self._compute

dstack/_internal/core/backends/hotaisle/compute.py
@@ -0,0 +1,225 @@
+import shlex
+import subprocess
+import tempfile
+from threading import Thread
+from typing import List, Optional
+
+import gpuhunt
+from gpuhunt.providers.hotaisle import HotAisleProvider
+
+from dstack._internal.core.backends.base.compute import (
+    Compute,
+    ComputeWithCreateInstanceSupport,
+    get_shim_commands,
+)
+from dstack._internal.core.backends.base.offers import get_catalog_offers
+from dstack._internal.core.backends.hotaisle.api_client import HotAisleAPIClient
+from dstack._internal.core.backends.hotaisle.models import HotAisleConfig
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.common import CoreModel
+from dstack._internal.core.models.instances import (
+    InstanceAvailability,
+    InstanceConfiguration,
+    InstanceOfferWithAvailability,
+)
+from dstack._internal.core.models.placement import PlacementGroup
+from dstack._internal.core.models.runs import JobProvisioningData, Requirements
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+MAX_INSTANCE_NAME_LEN = 60
+
+
+INSTANCE_TYPE_SPECS = {
+    "1x MI300X 8x Xeon Platinum 8462Y+": {
+        "cpu_model": "Xeon Platinum 8462Y+",
+        "cpu_frequency": 2800000000,
+        "cpu_manufacturer": "Intel",
+    },
+    "1x MI300X 13x Xeon Platinum 8470": {
+        "cpu_model": "Xeon Platinum 8470",
+        "cpu_frequency": 2000000000,
+        "cpu_manufacturer": "Intel",
+    },
+}
+
+
+class HotAisleCompute(
+    ComputeWithCreateInstanceSupport,
+    Compute,
+):
+    def __init__(self, config: HotAisleConfig):
+        super().__init__()
+        self.config = config
+        self.api_client = HotAisleAPIClient(config.creds.api_key, config.team_handle)
+        self.catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False)
+        self.catalog.add_provider(
+            HotAisleProvider(api_key=config.creds.api_key, team_handle=config.team_handle)
+        )
+
+    def get_offers(
+        self, requirements: Optional[Requirements] = None
+    ) -> List[InstanceOfferWithAvailability]:
+        offers = get_catalog_offers(
+            backend=BackendType.HOTAISLE,
+            locations=self.config.regions or None,
+            requirements=requirements,
+            catalog=self.catalog,
+        )
+
+        supported_offers = []
+        for offer in offers:
+            if offer.instance.name in INSTANCE_TYPE_SPECS:
+                supported_offers.append(
+                    InstanceOfferWithAvailability(
+                        **offer.dict(), availability=InstanceAvailability.AVAILABLE
+                    )
+                )
+            else:
+                logger.warning(
+                    f"Skipping unsupported Hot Aisle instance type: {offer.instance.name}"
+                )
+
+        return supported_offers
+
+    def get_payload_from_offer(self, instance_type) -> dict:
+        instance_type_name = instance_type.name
+        cpu_specs = INSTANCE_TYPE_SPECS[instance_type_name]
+        cpu_cores = instance_type.resources.cpus
+
+        return {
+            "cpu_cores": cpu_cores,
+            "cpus": {
+                "count": 1,
+                "manufacturer": cpu_specs["cpu_manufacturer"],
+                "model": cpu_specs["cpu_model"],
+                "cores": cpu_cores,
+                "frequency": cpu_specs["cpu_frequency"],
+            },
+            "disk_capacity": instance_type.resources.disk.size_mib * 1024**2,
+            "ram_capacity": instance_type.resources.memory_mib * 1024**2,
+            "gpus": [
+                {
+                    "count": len(instance_type.resources.gpus),
+                    "manufacturer": instance_type.resources.gpus[0].vendor,
+                    "model": instance_type.resources.gpus[0].name,
+                }
+            ],
+        }
+
+    def create_instance(
+        self,
+        instance_offer: InstanceOfferWithAvailability,
+        instance_config: InstanceConfiguration,
+        placement_group: Optional[PlacementGroup],
+    ) -> JobProvisioningData:
+        project_ssh_key = instance_config.ssh_keys[0]
+        self.api_client.upload_ssh_key(project_ssh_key.public)
+        vm_payload = self.get_payload_from_offer(instance_offer.instance)
+        vm_data = self.api_client.create_virtual_machine(vm_payload)
+        return JobProvisioningData(
+            backend=instance_offer.backend,
+            instance_type=instance_offer.instance,
+            instance_id=vm_data["name"],
+            hostname=None,
+            internal_ip=None,
+            region=instance_offer.region,
+            price=instance_offer.price,
+            username="hotaisle",
+            ssh_port=22,
+            dockerized=True,
+            ssh_proxy=None,
+            backend_data=HotAisleInstanceBackendData(
+                ip_address=vm_data["ip_address"], vm_id=vm_data["name"]
+            ).json(),
+        )
+
+    def update_provisioning_data(
+        self,
+        provisioning_data: JobProvisioningData,
+        project_ssh_public_key: str,
+        project_ssh_private_key: str,
+    ):
+        vm_state = self.api_client.get_vm_state(provisioning_data.instance_id)
+        if vm_state == "running":
+            if provisioning_data.hostname is None and provisioning_data.backend_data:
+                backend_data = HotAisleInstanceBackendData.load(provisioning_data.backend_data)
+                provisioning_data.hostname = backend_data.ip_address
+            commands = get_shim_commands(
+                authorized_keys=[project_ssh_public_key],
+                arch=provisioning_data.instance_type.resources.cpu_arch,
+            )
+            launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands))
+            thread = Thread(
+                target=_start_runner,
+                kwargs={
+                    "hostname": provisioning_data.hostname,
+                    "project_ssh_private_key": project_ssh_private_key,
+                    "launch_command": launch_command,
+                },
+                daemon=True,
+            )
+            thread.start()
+
+    def terminate_instance(
+        self, instance_id: str, region: str, backend_data: Optional[str] = None
+    ):
+        vm_name = instance_id
+        self.api_client.terminate_virtual_machine(vm_name)
+
+
+def _start_runner(
+    hostname: str,
+    project_ssh_private_key: str,
+    launch_command: str,
+):
+    _launch_runner(
+        hostname=hostname,
+        ssh_private_key=project_ssh_private_key,
+        launch_command=launch_command,
+    )
+
+
+def _launch_runner(
+    hostname: str,
+    ssh_private_key: str,
+    launch_command: str,
+):
+    daemonized_command = f"{launch_command.rstrip('&')} >/tmp/dstack-shim.log 2>&1 & disown"
+    _run_ssh_command(
+        hostname=hostname,
+        ssh_private_key=ssh_private_key,
+        command=daemonized_command,
+    )
+
+
+def _run_ssh_command(hostname: str, ssh_private_key: str, command: str):
+    with tempfile.NamedTemporaryFile("w+", 0o600) as f:
+        f.write(ssh_private_key)
+        f.flush()
+        subprocess.run(
+            [
+                "ssh",
+                "-F",
+                "none",
+                "-o",
+                "StrictHostKeyChecking=no",
+                "-i",
+                f.name,
+                f"hotaisle@{hostname}",
+                command,
+            ],
+            stdout=subprocess.DEVNULL,
+            stderr=subprocess.DEVNULL,
+        )
+
+
+class HotAisleInstanceBackendData(CoreModel):
+    ip_address: str
+    vm_id: Optional[str] = None
+
+    @classmethod
+    def load(cls, raw: Optional[str]) -> "HotAisleInstanceBackendData":
+        assert raw is not None
+        return cls.__response__.parse_raw(raw)
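
update_provisioning_data assembles the remote shim launch as a single sudo sh -c invocation, and _launch_runner then daemonizes it over SSH so it survives the session. A standalone sketch of just that string construction (the command list is an illustrative stand-in for get_shim_commands(), which is not called here):

import shlex

# Stand-ins for the commands returned by get_shim_commands().
commands = [
    "curl -fsSL https://example.com/dstack-shim -o /usr/local/bin/dstack-shim",
    "chmod +x /usr/local/bin/dstack-shim",
    "/usr/local/bin/dstack-shim",
]

# One quoted sudo invocation, exactly as update_provisioning_data builds it.
launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands))

# _launch_runner drops a trailing '&' if present, redirects output to a log file,
# and disowns the process so it keeps running after the SSH session ends.
daemonized_command = f"{launch_command.rstrip('&')} >/tmp/dstack-shim.log 2>&1 & disown"
print(daemonized_command)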