dstack 0.19.26__py3-none-any.whl → 0.19.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dstack might be problematic. Click here for more details.
- dstack/_internal/cli/commands/init.py +2 -2
- dstack/_internal/cli/services/configurators/run.py +114 -16
- dstack/_internal/cli/services/repos.py +1 -18
- dstack/_internal/core/backends/amddevcloud/__init__.py +1 -0
- dstack/_internal/core/backends/amddevcloud/backend.py +16 -0
- dstack/_internal/core/backends/amddevcloud/compute.py +5 -0
- dstack/_internal/core/backends/amddevcloud/configurator.py +29 -0
- dstack/_internal/core/backends/aws/compute.py +6 -1
- dstack/_internal/core/backends/base/compute.py +33 -5
- dstack/_internal/core/backends/base/offers.py +2 -0
- dstack/_internal/core/backends/configurators.py +15 -0
- dstack/_internal/core/backends/digitalocean/__init__.py +1 -0
- dstack/_internal/core/backends/digitalocean/backend.py +16 -0
- dstack/_internal/core/backends/digitalocean/compute.py +5 -0
- dstack/_internal/core/backends/digitalocean/configurator.py +31 -0
- dstack/_internal/core/backends/digitalocean_base/__init__.py +1 -0
- dstack/_internal/core/backends/digitalocean_base/api_client.py +104 -0
- dstack/_internal/core/backends/digitalocean_base/backend.py +5 -0
- dstack/_internal/core/backends/digitalocean_base/compute.py +173 -0
- dstack/_internal/core/backends/digitalocean_base/configurator.py +57 -0
- dstack/_internal/core/backends/digitalocean_base/models.py +43 -0
- dstack/_internal/core/backends/gcp/compute.py +32 -8
- dstack/_internal/core/backends/hotaisle/api_client.py +25 -33
- dstack/_internal/core/backends/hotaisle/compute.py +1 -6
- dstack/_internal/core/backends/models.py +7 -0
- dstack/_internal/core/backends/nebius/compute.py +0 -7
- dstack/_internal/core/backends/oci/compute.py +4 -5
- dstack/_internal/core/backends/vultr/compute.py +1 -5
- dstack/_internal/core/compatibility/fleets.py +5 -0
- dstack/_internal/core/compatibility/runs.py +8 -1
- dstack/_internal/core/models/backends/base.py +5 -1
- dstack/_internal/core/models/configurations.py +21 -7
- dstack/_internal/core/models/files.py +1 -1
- dstack/_internal/core/models/fleets.py +75 -2
- dstack/_internal/core/models/runs.py +24 -5
- dstack/_internal/core/services/repos.py +85 -80
- dstack/_internal/server/background/tasks/process_fleets.py +109 -13
- dstack/_internal/server/background/tasks/process_instances.py +12 -71
- dstack/_internal/server/background/tasks/process_running_jobs.py +2 -0
- dstack/_internal/server/background/tasks/process_runs.py +2 -0
- dstack/_internal/server/background/tasks/process_submitted_jobs.py +18 -6
- dstack/_internal/server/migrations/versions/2498ab323443_add_fleetmodel_consolidation_attempt_.py +44 -0
- dstack/_internal/server/models.py +5 -2
- dstack/_internal/server/schemas/runner.py +1 -0
- dstack/_internal/server/services/fleets.py +23 -25
- dstack/_internal/server/services/instances.py +3 -3
- dstack/_internal/server/services/jobs/configurators/base.py +46 -6
- dstack/_internal/server/services/jobs/configurators/dev.py +4 -4
- dstack/_internal/server/services/jobs/configurators/extensions/cursor.py +3 -5
- dstack/_internal/server/services/jobs/configurators/extensions/vscode.py +4 -6
- dstack/_internal/server/services/jobs/configurators/service.py +0 -3
- dstack/_internal/server/services/jobs/configurators/task.py +0 -3
- dstack/_internal/server/services/runs.py +16 -0
- dstack/_internal/server/statics/index.html +1 -1
- dstack/_internal/server/statics/{main-d151b300fcac3933213d.js → main-4eecc75fbe64067eb1bc.js} +1146 -899
- dstack/_internal/server/statics/{main-d151b300fcac3933213d.js.map → main-4eecc75fbe64067eb1bc.js.map} +1 -1
- dstack/_internal/server/statics/{main-aec4762350e34d6fbff9.css → main-56191c63d516fd0041c4.css} +1 -1
- dstack/_internal/server/testing/common.py +6 -3
- dstack/_internal/utils/path.py +8 -1
- dstack/_internal/utils/ssh.py +7 -0
- dstack/api/_public/repos.py +41 -6
- dstack/api/_public/runs.py +14 -1
- dstack/version.py +1 -1
- {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/METADATA +2 -2
- {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/RECORD +68 -53
- {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/WHEEL +0 -0
- {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/entry_points.txt +0 -0
- {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -6,12 +6,12 @@ from typing import Optional
|
|
|
6
6
|
from dstack._internal.cli.commands import BaseCommand
|
|
7
7
|
from dstack._internal.cli.services.repos import (
|
|
8
8
|
get_repo_from_dir,
|
|
9
|
-
get_repo_from_url,
|
|
10
9
|
is_git_repo_url,
|
|
11
10
|
register_init_repo_args,
|
|
12
11
|
)
|
|
13
12
|
from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn
|
|
14
13
|
from dstack._internal.core.errors import ConfigurationError
|
|
14
|
+
from dstack._internal.core.models.repos.remote import RemoteRepo
|
|
15
15
|
from dstack._internal.core.services.configs import ConfigManager
|
|
16
16
|
from dstack.api import Client
|
|
17
17
|
|
|
@@ -101,7 +101,7 @@ class InitCommand(BaseCommand):
|
|
|
101
101
|
if repo_url is not None:
|
|
102
102
|
# Dummy repo branch to avoid autodetection that fails on private repos.
|
|
103
103
|
# We don't need branch/hash for repo_id anyway.
|
|
104
|
-
repo =
|
|
104
|
+
repo = RemoteRepo.from_url(repo_url, repo_branch="master")
|
|
105
105
|
elif repo_path is not None:
|
|
106
106
|
repo = get_repo_from_dir(repo_path, local=local)
|
|
107
107
|
else:
|
|
@@ -2,7 +2,7 @@ import argparse
|
|
|
2
2
|
import subprocess
|
|
3
3
|
import sys
|
|
4
4
|
import time
|
|
5
|
-
from pathlib import Path
|
|
5
|
+
from pathlib import Path, PurePosixPath
|
|
6
6
|
from typing import Dict, List, Optional, Set, TypeVar
|
|
7
7
|
|
|
8
8
|
import gpuhunt
|
|
@@ -17,7 +17,6 @@ from dstack._internal.cli.services.configurators.base import (
|
|
|
17
17
|
from dstack._internal.cli.services.profile import apply_profile_args, register_profile_args
|
|
18
18
|
from dstack._internal.cli.services.repos import (
|
|
19
19
|
get_repo_from_dir,
|
|
20
|
-
get_repo_from_url,
|
|
21
20
|
init_default_virtual_repo,
|
|
22
21
|
is_git_repo_url,
|
|
23
22
|
register_init_repo_args,
|
|
@@ -33,6 +32,7 @@ from dstack._internal.core.errors import (
|
|
|
33
32
|
)
|
|
34
33
|
from dstack._internal.core.models.common import ApplyAction, RegistryAuth
|
|
35
34
|
from dstack._internal.core.models.configurations import (
|
|
35
|
+
LEGACY_REPO_DIR,
|
|
36
36
|
AnyRunConfiguration,
|
|
37
37
|
ApplyConfigurationType,
|
|
38
38
|
ConfigurationWithPortsParams,
|
|
@@ -42,19 +42,27 @@ from dstack._internal.core.models.configurations import (
|
|
|
42
42
|
ServiceConfiguration,
|
|
43
43
|
TaskConfiguration,
|
|
44
44
|
)
|
|
45
|
+
from dstack._internal.core.models.repos import RepoHeadWithCreds
|
|
45
46
|
from dstack._internal.core.models.repos.base import Repo
|
|
46
47
|
from dstack._internal.core.models.repos.local import LocalRepo
|
|
48
|
+
from dstack._internal.core.models.repos.remote import RemoteRepo, RemoteRepoCreds
|
|
47
49
|
from dstack._internal.core.models.resources import CPUSpec
|
|
48
50
|
from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus
|
|
49
51
|
from dstack._internal.core.services.configs import ConfigManager
|
|
50
52
|
from dstack._internal.core.services.diff import diff_models
|
|
51
|
-
from dstack._internal.core.services.repos import
|
|
53
|
+
from dstack._internal.core.services.repos import (
|
|
54
|
+
InvalidRepoCredentialsError,
|
|
55
|
+
get_repo_creds_and_default_branch,
|
|
56
|
+
load_repo,
|
|
57
|
+
)
|
|
52
58
|
from dstack._internal.utils.common import local_time
|
|
53
59
|
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
|
|
54
60
|
from dstack._internal.utils.logging import get_logger
|
|
55
61
|
from dstack._internal.utils.nested_list import NestedList, NestedListItem
|
|
62
|
+
from dstack._internal.utils.path import is_absolute_posix_path
|
|
56
63
|
from dstack.api._public.repos import get_ssh_keypair
|
|
57
64
|
from dstack.api._public.runs import Run
|
|
65
|
+
from dstack.api.server import APIClient
|
|
58
66
|
from dstack.api.utils import load_profile
|
|
59
67
|
|
|
60
68
|
_KNOWN_AMD_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_AMD_GPUS}
|
|
@@ -89,6 +97,43 @@ class BaseRunConfigurator(
|
|
|
89
97
|
self.validate_gpu_vendor_and_image(conf)
|
|
90
98
|
self.validate_cpu_arch_and_image(conf)
|
|
91
99
|
|
|
100
|
+
working_dir = conf.working_dir
|
|
101
|
+
if working_dir is None:
|
|
102
|
+
# Use the default working dir for the image for tasks and services if `commands`
|
|
103
|
+
# is not set (emulate pre-0.19.27 JobConfigutor logic), otherwise fall back to
|
|
104
|
+
# `/workflow`.
|
|
105
|
+
if isinstance(conf, DevEnvironmentConfiguration) or conf.commands:
|
|
106
|
+
# relative path for compatibility with pre-0.19.27 servers
|
|
107
|
+
conf.working_dir = "."
|
|
108
|
+
warn(
|
|
109
|
+
f'The [code]working_dir[/code] is not set — using legacy default [code]"{LEGACY_REPO_DIR}"[/code].'
|
|
110
|
+
" Future versions will default to the [code]image[/code]'s working directory."
|
|
111
|
+
)
|
|
112
|
+
elif not is_absolute_posix_path(working_dir):
|
|
113
|
+
legacy_working_dir = PurePosixPath(LEGACY_REPO_DIR) / working_dir
|
|
114
|
+
warn(
|
|
115
|
+
"[code]working_dir[/code] is relative."
|
|
116
|
+
f" Using legacy working directory [code]{legacy_working_dir}[/code]\n\n"
|
|
117
|
+
"Future versions will require absolute path\n"
|
|
118
|
+
f"To keep using legacy working directory, set"
|
|
119
|
+
f" [code]working_dir[/code] to [code]{legacy_working_dir}[/code]\n"
|
|
120
|
+
)
|
|
121
|
+
else:
|
|
122
|
+
# relative path for compatibility with pre-0.19.27 servers
|
|
123
|
+
try:
|
|
124
|
+
conf.working_dir = str(PurePosixPath(working_dir).relative_to(LEGACY_REPO_DIR))
|
|
125
|
+
except ValueError:
|
|
126
|
+
pass
|
|
127
|
+
|
|
128
|
+
if conf.repos and conf.repos[0].path is None:
|
|
129
|
+
warn(
|
|
130
|
+
"[code]repos[0].path[/code] is not set,"
|
|
131
|
+
f" using legacy repo path [code]{LEGACY_REPO_DIR}[/code]\n\n"
|
|
132
|
+
"In a future version the default value will be changed."
|
|
133
|
+
f" To keep using [code]{LEGACY_REPO_DIR}[/code], explicitly set"
|
|
134
|
+
f" [code]repos[0].path[/code] to [code]{LEGACY_REPO_DIR}[/code]\n"
|
|
135
|
+
)
|
|
136
|
+
|
|
92
137
|
config_manager = ConfigManager()
|
|
93
138
|
repo = self.get_repo(conf, configuration_path, configurator_args, config_manager)
|
|
94
139
|
self.api.ssh_identity_file = get_ssh_keypair(
|
|
@@ -184,6 +229,9 @@ class BaseRunConfigurator(
|
|
|
184
229
|
format_date=local_time,
|
|
185
230
|
)
|
|
186
231
|
)
|
|
232
|
+
|
|
233
|
+
_warn_fleet_autocreated(self.api.client, run)
|
|
234
|
+
|
|
187
235
|
console.print(
|
|
188
236
|
f"\n[code]{run.name}[/] provisioning completed [secondary]({run.status.value})[/]"
|
|
189
237
|
)
|
|
@@ -486,15 +534,17 @@ class BaseRunConfigurator(
|
|
|
486
534
|
return init_default_virtual_repo(api=self.api)
|
|
487
535
|
|
|
488
536
|
repo: Optional[Repo] = None
|
|
537
|
+
repo_head: Optional[RepoHeadWithCreds] = None
|
|
489
538
|
repo_branch: Optional[str] = configurator_args.repo_branch
|
|
490
539
|
repo_hash: Optional[str] = configurator_args.repo_hash
|
|
540
|
+
repo_creds: Optional[RemoteRepoCreds] = None
|
|
541
|
+
git_identity_file: Optional[str] = configurator_args.git_identity_file
|
|
542
|
+
git_private_key: Optional[str] = None
|
|
543
|
+
oauth_token: Optional[str] = configurator_args.gh_token
|
|
491
544
|
# Should we (re)initialize the repo?
|
|
492
545
|
# If any Git credentials provided, we reinitialize the repo, as the user may have provided
|
|
493
546
|
# updated credentials.
|
|
494
|
-
init =
|
|
495
|
-
configurator_args.git_identity_file is not None
|
|
496
|
-
or configurator_args.gh_token is not None
|
|
497
|
-
)
|
|
547
|
+
init = git_identity_file is not None or oauth_token is not None
|
|
498
548
|
|
|
499
549
|
url: Optional[str] = None
|
|
500
550
|
local_path: Optional[Path] = None
|
|
@@ -527,15 +577,15 @@ class BaseRunConfigurator(
|
|
|
527
577
|
local_path = Path.cwd()
|
|
528
578
|
legacy_local_path = True
|
|
529
579
|
if url:
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
580
|
+
# "master" is a dummy value, we'll fetch the actual default branch later
|
|
581
|
+
repo = RemoteRepo.from_url(repo_url=url, repo_branch="master")
|
|
582
|
+
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
|
|
533
583
|
elif local_path:
|
|
534
584
|
if legacy_local_path:
|
|
535
585
|
if repo_config := config_manager.get_repo_config(local_path):
|
|
536
586
|
repo = load_repo(repo_config)
|
|
537
|
-
|
|
538
|
-
if
|
|
587
|
+
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
|
|
588
|
+
if repo_head is not None:
|
|
539
589
|
warn(
|
|
540
590
|
"The repo is not specified but found and will be used in the run\n"
|
|
541
591
|
"Future versions will not load repos automatically\n"
|
|
@@ -562,20 +612,55 @@ class BaseRunConfigurator(
|
|
|
562
612
|
)
|
|
563
613
|
local: bool = configurator_args.local
|
|
564
614
|
repo = get_repo_from_dir(local_path, local=local)
|
|
565
|
-
|
|
566
|
-
|
|
615
|
+
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
|
|
616
|
+
if isinstance(repo, RemoteRepo):
|
|
617
|
+
repo_branch = repo.run_repo_data.repo_branch
|
|
618
|
+
repo_hash = repo.run_repo_data.repo_hash
|
|
567
619
|
else:
|
|
568
620
|
assert False, "should not reach here"
|
|
569
621
|
|
|
570
622
|
if repo is None:
|
|
571
623
|
return init_default_virtual_repo(api=self.api)
|
|
572
624
|
|
|
625
|
+
if isinstance(repo, RemoteRepo):
|
|
626
|
+
assert repo.repo_url is not None
|
|
627
|
+
|
|
628
|
+
if repo_head is not None and repo_head.repo_creds is not None:
|
|
629
|
+
if git_identity_file is None and oauth_token is None:
|
|
630
|
+
git_private_key = repo_head.repo_creds.private_key
|
|
631
|
+
oauth_token = repo_head.repo_creds.oauth_token
|
|
632
|
+
else:
|
|
633
|
+
init = True
|
|
634
|
+
|
|
635
|
+
try:
|
|
636
|
+
repo_creds, default_repo_branch = get_repo_creds_and_default_branch(
|
|
637
|
+
repo_url=repo.repo_url,
|
|
638
|
+
identity_file=git_identity_file,
|
|
639
|
+
private_key=git_private_key,
|
|
640
|
+
oauth_token=oauth_token,
|
|
641
|
+
)
|
|
642
|
+
except InvalidRepoCredentialsError as e:
|
|
643
|
+
raise CLIError(*e.args) from e
|
|
644
|
+
|
|
645
|
+
if repo_branch is None and repo_hash is None:
|
|
646
|
+
repo_branch = default_repo_branch
|
|
647
|
+
if repo_branch is None:
|
|
648
|
+
raise CLIError(
|
|
649
|
+
"Failed to automatically detect remote repo branch."
|
|
650
|
+
" Specify branch or hash."
|
|
651
|
+
)
|
|
652
|
+
repo = RemoteRepo.from_url(
|
|
653
|
+
repo_url=repo.repo_url, repo_branch=repo_branch, repo_hash=repo_hash
|
|
654
|
+
)
|
|
655
|
+
|
|
573
656
|
if init:
|
|
574
657
|
self.api.repos.init(
|
|
575
658
|
repo=repo,
|
|
576
|
-
git_identity_file=
|
|
577
|
-
oauth_token=
|
|
659
|
+
git_identity_file=git_identity_file,
|
|
660
|
+
oauth_token=oauth_token,
|
|
661
|
+
creds=repo_creds,
|
|
578
662
|
)
|
|
663
|
+
|
|
579
664
|
if isinstance(repo, LocalRepo):
|
|
580
665
|
warn(
|
|
581
666
|
f"{repo.repo_dir} is a local repo\n"
|
|
@@ -827,3 +912,16 @@ def render_run_spec_diff(old_spec: RunSpec, new_spec: RunSpec) -> Optional[str]:
|
|
|
827
912
|
item = NestedListItem(spec_field.replace("_", " ").capitalize())
|
|
828
913
|
nested_list.children.append(item)
|
|
829
914
|
return nested_list.render()
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
def _warn_fleet_autocreated(api: APIClient, run: Run):
|
|
918
|
+
if run._run.fleet is None:
|
|
919
|
+
return
|
|
920
|
+
fleet = api.fleets.get(project_name=run._project, name=run._run.fleet.name)
|
|
921
|
+
if not fleet.spec.autocreated:
|
|
922
|
+
return
|
|
923
|
+
warn(
|
|
924
|
+
f"\nNo existing fleet matched, so the run created a new fleet [code]{fleet.name}[/code].\n"
|
|
925
|
+
"Future dstack versions won't create fleets automatically.\n"
|
|
926
|
+
"Create a fleet explicitly: https://dstack.ai/docs/concepts/fleets/"
|
|
927
|
+
)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import argparse
|
|
2
|
-
from typing import Literal,
|
|
2
|
+
from typing import Literal, Union, overload
|
|
3
3
|
|
|
4
4
|
import git
|
|
5
5
|
|
|
@@ -8,7 +8,6 @@ from dstack._internal.core.errors import CLIError
|
|
|
8
8
|
from dstack._internal.core.models.repos.local import LocalRepo
|
|
9
9
|
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError
|
|
10
10
|
from dstack._internal.core.models.repos.virtual import VirtualRepo
|
|
11
|
-
from dstack._internal.core.services.repos import get_default_branch
|
|
12
11
|
from dstack._internal.utils.path import PathLike
|
|
13
12
|
from dstack.api._public import Client
|
|
14
13
|
|
|
@@ -43,22 +42,6 @@ def init_default_virtual_repo(api: Client) -> VirtualRepo:
|
|
|
43
42
|
return repo
|
|
44
43
|
|
|
45
44
|
|
|
46
|
-
def get_repo_from_url(
|
|
47
|
-
repo_url: str, repo_branch: Optional[str] = None, repo_hash: Optional[str] = None
|
|
48
|
-
) -> RemoteRepo:
|
|
49
|
-
if repo_branch is None and repo_hash is None:
|
|
50
|
-
repo_branch = get_default_branch(repo_url)
|
|
51
|
-
if repo_branch is None:
|
|
52
|
-
raise CLIError(
|
|
53
|
-
"Failed to automatically detect remote repo branch. Specify branch or hash."
|
|
54
|
-
)
|
|
55
|
-
return RemoteRepo.from_url(
|
|
56
|
-
repo_url=repo_url,
|
|
57
|
-
repo_branch=repo_branch,
|
|
58
|
-
repo_hash=repo_hash,
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
|
|
62
45
|
@overload
|
|
63
46
|
def get_repo_from_dir(repo_dir: PathLike, local: Literal[False] = False) -> RemoteRepo: ...
|
|
64
47
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# This package contains the implementation for the AMDDevCloud backend.
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from dstack._internal.core.backends.amddevcloud.compute import AMDDevCloudCompute
|
|
2
|
+
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
|
|
3
|
+
from dstack._internal.core.backends.digitalocean_base.models import BaseDigitalOceanConfig
|
|
4
|
+
from dstack._internal.core.models.backends.base import BackendType
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AMDDevCloudBackend(BaseDigitalOceanBackend):
|
|
8
|
+
TYPE = BackendType.AMDDEVCLOUD
|
|
9
|
+
COMPUTE_CLASS = AMDDevCloudCompute
|
|
10
|
+
|
|
11
|
+
def __init__(self, config: BaseDigitalOceanConfig, api_url: str):
|
|
12
|
+
self.config = config
|
|
13
|
+
self._compute = AMDDevCloudCompute(self.config, api_url=api_url, type=self.TYPE)
|
|
14
|
+
|
|
15
|
+
def compute(self) -> AMDDevCloudCompute:
|
|
16
|
+
return self._compute
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from dstack._internal.core.backends.amddevcloud.backend import AMDDevCloudBackend
|
|
4
|
+
from dstack._internal.core.backends.base.configurator import BackendRecord
|
|
5
|
+
from dstack._internal.core.backends.digitalocean_base.api_client import DigitalOceanAPIClient
|
|
6
|
+
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
|
|
7
|
+
from dstack._internal.core.backends.digitalocean_base.configurator import (
|
|
8
|
+
BaseDigitalOceanConfigurator,
|
|
9
|
+
)
|
|
10
|
+
from dstack._internal.core.backends.digitalocean_base.models import AnyBaseDigitalOceanCreds
|
|
11
|
+
from dstack._internal.core.models.backends.base import (
|
|
12
|
+
BackendType,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AMDDevCloudConfigurator(BaseDigitalOceanConfigurator):
|
|
17
|
+
TYPE = BackendType.AMDDEVCLOUD
|
|
18
|
+
BACKEND_CLASS = AMDDevCloudBackend
|
|
19
|
+
API_URL = "https://api-amd.digitalocean.com"
|
|
20
|
+
|
|
21
|
+
def get_backend(self, record: BackendRecord) -> BaseDigitalOceanBackend:
|
|
22
|
+
config = self._get_config(record)
|
|
23
|
+
return AMDDevCloudBackend(config=config, api_url=self.API_URL)
|
|
24
|
+
|
|
25
|
+
def _validate_creds(self, creds: AnyBaseDigitalOceanCreds, project_name: Optional[str] = None):
|
|
26
|
+
api_client = DigitalOceanAPIClient(creds.api_key, self.API_URL)
|
|
27
|
+
api_client.validate_api_key()
|
|
28
|
+
if project_name:
|
|
29
|
+
api_client.validate_project_name(project_name)
|
|
@@ -292,7 +292,12 @@ class AWSCompute(
|
|
|
292
292
|
image_id=image_id,
|
|
293
293
|
instance_type=instance_offer.instance.name,
|
|
294
294
|
iam_instance_profile=self.config.iam_instance_profile,
|
|
295
|
-
user_data=get_user_data(
|
|
295
|
+
user_data=get_user_data(
|
|
296
|
+
authorized_keys=instance_config.get_public_keys(),
|
|
297
|
+
# Custom OS images may lack ufw, so don't attempt to set up the firewall.
|
|
298
|
+
# Rely on security groups and the image's built-in firewall rules instead.
|
|
299
|
+
skip_firewall_setup=self.config.os_images is not None,
|
|
300
|
+
),
|
|
296
301
|
tags=aws_resources.make_tags(tags),
|
|
297
302
|
security_group_id=security_group_id,
|
|
298
303
|
spot=instance_offer.instance.resources.spot,
|
|
@@ -4,6 +4,7 @@ import re
|
|
|
4
4
|
import string
|
|
5
5
|
import threading
|
|
6
6
|
from abc import ABC, abstractmethod
|
|
7
|
+
from collections.abc import Iterable
|
|
7
8
|
from functools import lru_cache
|
|
8
9
|
from pathlib import Path
|
|
9
10
|
from typing import Dict, List, Literal, Optional
|
|
@@ -19,7 +20,7 @@ from dstack._internal.core.consts import (
|
|
|
19
20
|
DSTACK_RUNNER_SSH_PORT,
|
|
20
21
|
DSTACK_SHIM_HTTP_PORT,
|
|
21
22
|
)
|
|
22
|
-
from dstack._internal.core.models.configurations import
|
|
23
|
+
from dstack._internal.core.models.configurations import LEGACY_REPO_DIR
|
|
23
24
|
from dstack._internal.core.models.gateways import (
|
|
24
25
|
GatewayComputeConfiguration,
|
|
25
26
|
GatewayProvisioningData,
|
|
@@ -45,6 +46,7 @@ logger = get_logger(__name__)
|
|
|
45
46
|
|
|
46
47
|
DSTACK_SHIM_BINARY_NAME = "dstack-shim"
|
|
47
48
|
DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
|
|
49
|
+
DEFAULT_PRIVATE_SUBNETS = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")
|
|
48
50
|
|
|
49
51
|
GoArchType = Literal["amd64", "arm64"]
|
|
50
52
|
|
|
@@ -507,12 +509,16 @@ def get_user_data(
|
|
|
507
509
|
base_path: Optional[PathLike] = None,
|
|
508
510
|
bin_path: Optional[PathLike] = None,
|
|
509
511
|
backend_shim_env: Optional[Dict[str, str]] = None,
|
|
512
|
+
skip_firewall_setup: bool = False,
|
|
513
|
+
firewall_allow_from_subnets: Iterable[str] = DEFAULT_PRIVATE_SUBNETS,
|
|
510
514
|
) -> str:
|
|
511
515
|
shim_commands = get_shim_commands(
|
|
512
516
|
authorized_keys=authorized_keys,
|
|
513
517
|
base_path=base_path,
|
|
514
518
|
bin_path=bin_path,
|
|
515
519
|
backend_shim_env=backend_shim_env,
|
|
520
|
+
skip_firewall_setup=skip_firewall_setup,
|
|
521
|
+
firewall_allow_from_subnets=firewall_allow_from_subnets,
|
|
516
522
|
)
|
|
517
523
|
commands = (backend_specific_commands or []) + shim_commands
|
|
518
524
|
return get_cloud_config(
|
|
@@ -554,8 +560,13 @@ def get_shim_commands(
|
|
|
554
560
|
bin_path: Optional[PathLike] = None,
|
|
555
561
|
backend_shim_env: Optional[Dict[str, str]] = None,
|
|
556
562
|
arch: Optional[str] = None,
|
|
563
|
+
skip_firewall_setup: bool = False,
|
|
564
|
+
firewall_allow_from_subnets: Iterable[str] = DEFAULT_PRIVATE_SUBNETS,
|
|
557
565
|
) -> List[str]:
|
|
558
|
-
commands = get_setup_cloud_instance_commands(
|
|
566
|
+
commands = get_setup_cloud_instance_commands(
|
|
567
|
+
skip_firewall_setup=skip_firewall_setup,
|
|
568
|
+
firewall_allow_from_subnets=firewall_allow_from_subnets,
|
|
569
|
+
)
|
|
559
570
|
commands += get_shim_pre_start_commands(
|
|
560
571
|
base_path=base_path,
|
|
561
572
|
bin_path=bin_path,
|
|
@@ -638,8 +649,11 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
|
|
|
638
649
|
return url_template.format(version=version, arch=arch)
|
|
639
650
|
|
|
640
651
|
|
|
641
|
-
def get_setup_cloud_instance_commands(
|
|
642
|
-
|
|
652
|
+
def get_setup_cloud_instance_commands(
|
|
653
|
+
skip_firewall_setup: bool,
|
|
654
|
+
firewall_allow_from_subnets: Iterable[str],
|
|
655
|
+
) -> list[str]:
|
|
656
|
+
commands = [
|
|
643
657
|
# Workaround for https://github.com/NVIDIA/nvidia-container-toolkit/issues/48
|
|
644
658
|
# Attempts to patch /etc/docker/daemon.json while keeping any custom settings it may have.
|
|
645
659
|
(
|
|
@@ -653,6 +667,19 @@ def get_setup_cloud_instance_commands() -> list[str]:
|
|
|
653
667
|
"'"
|
|
654
668
|
),
|
|
655
669
|
]
|
|
670
|
+
if not skip_firewall_setup:
|
|
671
|
+
commands += [
|
|
672
|
+
"ufw --force reset", # Some OS images have default rules like `allow 80`. Delete them
|
|
673
|
+
"ufw default deny incoming",
|
|
674
|
+
"ufw default allow outgoing",
|
|
675
|
+
"ufw allow ssh",
|
|
676
|
+
]
|
|
677
|
+
for subnet in firewall_allow_from_subnets:
|
|
678
|
+
commands.append(f"ufw allow from {subnet}")
|
|
679
|
+
commands += [
|
|
680
|
+
"ufw --force enable",
|
|
681
|
+
]
|
|
682
|
+
return commands
|
|
656
683
|
|
|
657
684
|
|
|
658
685
|
def get_shim_pre_start_commands(
|
|
@@ -773,7 +800,8 @@ def get_docker_commands(
|
|
|
773
800
|
f" --ssh-port {DSTACK_RUNNER_SSH_PORT}"
|
|
774
801
|
" --temp-dir /tmp/runner"
|
|
775
802
|
" --home-dir /root"
|
|
776
|
-
|
|
803
|
+
# TODO: Not used, left for compatibility with old runners. Remove eventually.
|
|
804
|
+
f" --working-dir {LEGACY_REPO_DIR}"
|
|
777
805
|
),
|
|
778
806
|
]
|
|
779
807
|
|
|
@@ -34,6 +34,8 @@ def get_catalog_offers(
|
|
|
34
34
|
provider = backend.value
|
|
35
35
|
if backend == BackendType.LAMBDA:
|
|
36
36
|
provider = "lambdalabs"
|
|
37
|
+
if backend == BackendType.AMDDEVCLOUD:
|
|
38
|
+
provider = "digitalocean"
|
|
37
39
|
q = requirements_to_query_filter(requirements)
|
|
38
40
|
q.provider = [provider]
|
|
39
41
|
offers = []
|
|
@@ -5,6 +5,12 @@ from dstack._internal.core.models.backends.base import BackendType
|
|
|
5
5
|
|
|
6
6
|
_CONFIGURATOR_CLASSES: List[Type[Configurator]] = []
|
|
7
7
|
|
|
8
|
+
try:
|
|
9
|
+
from dstack._internal.core.backends.amddevcloud.configurator import AMDDevCloudConfigurator
|
|
10
|
+
|
|
11
|
+
_CONFIGURATOR_CLASSES.append(AMDDevCloudConfigurator)
|
|
12
|
+
except ImportError:
|
|
13
|
+
pass
|
|
8
14
|
|
|
9
15
|
try:
|
|
10
16
|
from dstack._internal.core.backends.aws.configurator import AWSConfigurator
|
|
@@ -47,6 +53,15 @@ try:
|
|
|
47
53
|
except ImportError:
|
|
48
54
|
pass
|
|
49
55
|
|
|
56
|
+
try:
|
|
57
|
+
from dstack._internal.core.backends.digitalocean.configurator import (
|
|
58
|
+
DigitalOceanConfigurator,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
_CONFIGURATOR_CLASSES.append(DigitalOceanConfigurator)
|
|
62
|
+
except ImportError:
|
|
63
|
+
pass
|
|
64
|
+
|
|
50
65
|
try:
|
|
51
66
|
from dstack._internal.core.backends.gcp.configurator import GCPConfigurator
|
|
52
67
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# DigitalOcean backend for dstack
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from dstack._internal.core.backends.digitalocean.compute import DigitalOceanCompute
|
|
2
|
+
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
|
|
3
|
+
from dstack._internal.core.backends.digitalocean_base.models import BaseDigitalOceanConfig
|
|
4
|
+
from dstack._internal.core.models.backends.base import BackendType
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DigitalOceanBackend(BaseDigitalOceanBackend):
|
|
8
|
+
TYPE = BackendType.DIGITALOCEAN
|
|
9
|
+
COMPUTE_CLASS = DigitalOceanCompute
|
|
10
|
+
|
|
11
|
+
def __init__(self, config: BaseDigitalOceanConfig, api_url: str):
|
|
12
|
+
self.config = config
|
|
13
|
+
self._compute = DigitalOceanCompute(self.config, api_url=api_url, type=self.TYPE)
|
|
14
|
+
|
|
15
|
+
def compute(self) -> DigitalOceanCompute:
|
|
16
|
+
return self._compute
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from dstack._internal.core.backends.base.configurator import BackendRecord
|
|
4
|
+
from dstack._internal.core.backends.digitalocean.backend import DigitalOceanBackend
|
|
5
|
+
from dstack._internal.core.backends.digitalocean_base.api_client import DigitalOceanAPIClient
|
|
6
|
+
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
|
|
7
|
+
from dstack._internal.core.backends.digitalocean_base.configurator import (
|
|
8
|
+
BaseDigitalOceanConfigurator,
|
|
9
|
+
)
|
|
10
|
+
from dstack._internal.core.backends.digitalocean_base.models import (
|
|
11
|
+
AnyBaseDigitalOceanCreds,
|
|
12
|
+
)
|
|
13
|
+
from dstack._internal.core.models.backends.base import (
|
|
14
|
+
BackendType,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class DigitalOceanConfigurator(BaseDigitalOceanConfigurator):
|
|
19
|
+
TYPE = BackendType.DIGITALOCEAN
|
|
20
|
+
BACKEND_CLASS = DigitalOceanBackend
|
|
21
|
+
API_URL = "https://api.digitalocean.com"
|
|
22
|
+
|
|
23
|
+
def get_backend(self, record: BackendRecord) -> BaseDigitalOceanBackend:
|
|
24
|
+
config = self._get_config(record)
|
|
25
|
+
return DigitalOceanBackend(config=config, api_url=self.API_URL)
|
|
26
|
+
|
|
27
|
+
def _validate_creds(self, creds: AnyBaseDigitalOceanCreds, project_name: Optional[str] = None):
|
|
28
|
+
api_client = DigitalOceanAPIClient(creds.api_key, self.API_URL)
|
|
29
|
+
api_client.validate_api_key()
|
|
30
|
+
if project_name:
|
|
31
|
+
api_client.validate_project_name(project_name)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# This package contains the base classes for DigitalOcean and AMDDevCloud backends.
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
|
2
|
+
|
|
3
|
+
import requests
|
|
4
|
+
|
|
5
|
+
from dstack._internal.core.backends.base.configurator import raise_invalid_credentials_error
|
|
6
|
+
from dstack._internal.core.errors import NoCapacityError
|
|
7
|
+
from dstack._internal.utils.logging import get_logger
|
|
8
|
+
|
|
9
|
+
logger = get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DigitalOceanAPIClient:
|
|
13
|
+
def __init__(self, api_key: str, api_url: str):
|
|
14
|
+
self.api_key = api_key
|
|
15
|
+
self.base_url = api_url
|
|
16
|
+
|
|
17
|
+
def validate_api_key(self) -> bool:
|
|
18
|
+
try:
|
|
19
|
+
response = self._make_request("GET", "/v2/account")
|
|
20
|
+
response.raise_for_status()
|
|
21
|
+
return True
|
|
22
|
+
except requests.HTTPError as e:
|
|
23
|
+
status = e.response.status_code
|
|
24
|
+
if status == 401:
|
|
25
|
+
raise_invalid_credentials_error(
|
|
26
|
+
fields=[["creds", "api_key"]], details="Invaild API key"
|
|
27
|
+
)
|
|
28
|
+
raise e
|
|
29
|
+
|
|
30
|
+
def validate_project_name(self, project_name: str) -> bool:
|
|
31
|
+
if self.get_project_id(project_name) is None:
|
|
32
|
+
raise_invalid_credentials_error(
|
|
33
|
+
fields=[["project_name"]],
|
|
34
|
+
details=f"Project with name '{project_name}' does not exist",
|
|
35
|
+
)
|
|
36
|
+
return True
|
|
37
|
+
|
|
38
|
+
def list_ssh_keys(self) -> List[Dict[str, Any]]:
|
|
39
|
+
response = self._make_request("GET", "/v2/account/keys")
|
|
40
|
+
response.raise_for_status()
|
|
41
|
+
return response.json()["ssh_keys"]
|
|
42
|
+
|
|
43
|
+
def list_projects(self) -> List[Dict[str, Any]]:
|
|
44
|
+
response = self._make_request("GET", "/v2/projects")
|
|
45
|
+
response.raise_for_status()
|
|
46
|
+
return response.json()["projects"]
|
|
47
|
+
|
|
48
|
+
def get_project_id(self, project_name: str) -> Optional[str]:
|
|
49
|
+
projects = self.list_projects()
|
|
50
|
+
for project in projects:
|
|
51
|
+
if project["name"] == project_name:
|
|
52
|
+
return project["id"]
|
|
53
|
+
return None
|
|
54
|
+
|
|
55
|
+
def create_ssh_key(self, name: str, public_key: str) -> Dict[str, Any]:
|
|
56
|
+
payload = {"name": name, "public_key": public_key}
|
|
57
|
+
response = self._make_request("POST", "/v2/account/keys", json=payload)
|
|
58
|
+
response.raise_for_status()
|
|
59
|
+
return response.json()["ssh_key"]
|
|
60
|
+
|
|
61
|
+
def get_or_create_ssh_key(self, name: str, public_key: str) -> int:
|
|
62
|
+
ssh_keys = self.list_ssh_keys()
|
|
63
|
+
for ssh_key in ssh_keys:
|
|
64
|
+
if ssh_key["public_key"].strip() == public_key.strip():
|
|
65
|
+
return ssh_key["id"]
|
|
66
|
+
|
|
67
|
+
ssh_key = self.create_ssh_key(name, public_key)
|
|
68
|
+
return ssh_key["id"]
|
|
69
|
+
|
|
70
|
+
def create_droplet(self, droplet_config: Dict[str, Any]) -> Dict[str, Any]:
|
|
71
|
+
response = self._make_request("POST", "/v2/droplets", json=droplet_config)
|
|
72
|
+
if response.status_code == 422:
|
|
73
|
+
raise NoCapacityError(response.json()["message"])
|
|
74
|
+
response.raise_for_status()
|
|
75
|
+
return response.json()["droplet"]
|
|
76
|
+
|
|
77
|
+
def get_droplet(self, droplet_id: str) -> Dict[str, Any]:
|
|
78
|
+
response = self._make_request("GET", f"/v2/droplets/{droplet_id}")
|
|
79
|
+
response.raise_for_status()
|
|
80
|
+
return response.json()["droplet"]
|
|
81
|
+
|
|
82
|
+
def delete_droplet(self, droplet_id: str) -> None:
|
|
83
|
+
response = self._make_request("DELETE", f"/v2/droplets/{droplet_id}")
|
|
84
|
+
if response.status_code == 404:
|
|
85
|
+
logger.debug("DigitalOcean droplet %s not found", droplet_id)
|
|
86
|
+
return
|
|
87
|
+
response.raise_for_status()
|
|
88
|
+
|
|
89
|
+
def _make_request(
|
|
90
|
+
self, method: str, endpoint: str, json: Optional[Dict[str, Any]] = None, timeout: int = 30
|
|
91
|
+
) -> requests.Response:
|
|
92
|
+
url = f"{self.base_url}{endpoint}"
|
|
93
|
+
headers = {
|
|
94
|
+
"Authorization": f"Bearer {self.api_key}",
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
response = requests.request(
|
|
98
|
+
method=method,
|
|
99
|
+
url=url,
|
|
100
|
+
headers=headers,
|
|
101
|
+
json=json,
|
|
102
|
+
timeout=timeout,
|
|
103
|
+
)
|
|
104
|
+
return response
|