dstack 0.19.26__py3-none-any.whl → 0.19.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. Click here for more details.

Files changed (68):
  1. dstack/_internal/cli/commands/init.py +2 -2
  2. dstack/_internal/cli/services/configurators/run.py +114 -16
  3. dstack/_internal/cli/services/repos.py +1 -18
  4. dstack/_internal/core/backends/amddevcloud/__init__.py +1 -0
  5. dstack/_internal/core/backends/amddevcloud/backend.py +16 -0
  6. dstack/_internal/core/backends/amddevcloud/compute.py +5 -0
  7. dstack/_internal/core/backends/amddevcloud/configurator.py +29 -0
  8. dstack/_internal/core/backends/aws/compute.py +6 -1
  9. dstack/_internal/core/backends/base/compute.py +33 -5
  10. dstack/_internal/core/backends/base/offers.py +2 -0
  11. dstack/_internal/core/backends/configurators.py +15 -0
  12. dstack/_internal/core/backends/digitalocean/__init__.py +1 -0
  13. dstack/_internal/core/backends/digitalocean/backend.py +16 -0
  14. dstack/_internal/core/backends/digitalocean/compute.py +5 -0
  15. dstack/_internal/core/backends/digitalocean/configurator.py +31 -0
  16. dstack/_internal/core/backends/digitalocean_base/__init__.py +1 -0
  17. dstack/_internal/core/backends/digitalocean_base/api_client.py +104 -0
  18. dstack/_internal/core/backends/digitalocean_base/backend.py +5 -0
  19. dstack/_internal/core/backends/digitalocean_base/compute.py +173 -0
  20. dstack/_internal/core/backends/digitalocean_base/configurator.py +57 -0
  21. dstack/_internal/core/backends/digitalocean_base/models.py +43 -0
  22. dstack/_internal/core/backends/gcp/compute.py +32 -8
  23. dstack/_internal/core/backends/hotaisle/api_client.py +25 -33
  24. dstack/_internal/core/backends/hotaisle/compute.py +1 -6
  25. dstack/_internal/core/backends/models.py +7 -0
  26. dstack/_internal/core/backends/nebius/compute.py +0 -7
  27. dstack/_internal/core/backends/oci/compute.py +4 -5
  28. dstack/_internal/core/backends/vultr/compute.py +1 -5
  29. dstack/_internal/core/compatibility/fleets.py +5 -0
  30. dstack/_internal/core/compatibility/runs.py +8 -1
  31. dstack/_internal/core/models/backends/base.py +5 -1
  32. dstack/_internal/core/models/configurations.py +21 -7
  33. dstack/_internal/core/models/files.py +1 -1
  34. dstack/_internal/core/models/fleets.py +75 -2
  35. dstack/_internal/core/models/runs.py +24 -5
  36. dstack/_internal/core/services/repos.py +85 -80
  37. dstack/_internal/server/background/tasks/process_fleets.py +109 -13
  38. dstack/_internal/server/background/tasks/process_instances.py +12 -71
  39. dstack/_internal/server/background/tasks/process_running_jobs.py +2 -0
  40. dstack/_internal/server/background/tasks/process_runs.py +2 -0
  41. dstack/_internal/server/background/tasks/process_submitted_jobs.py +18 -6
  42. dstack/_internal/server/migrations/versions/2498ab323443_add_fleetmodel_consolidation_attempt_.py +44 -0
  43. dstack/_internal/server/models.py +5 -2
  44. dstack/_internal/server/schemas/runner.py +1 -0
  45. dstack/_internal/server/services/fleets.py +23 -25
  46. dstack/_internal/server/services/instances.py +3 -3
  47. dstack/_internal/server/services/jobs/configurators/base.py +46 -6
  48. dstack/_internal/server/services/jobs/configurators/dev.py +4 -4
  49. dstack/_internal/server/services/jobs/configurators/extensions/cursor.py +3 -5
  50. dstack/_internal/server/services/jobs/configurators/extensions/vscode.py +4 -6
  51. dstack/_internal/server/services/jobs/configurators/service.py +0 -3
  52. dstack/_internal/server/services/jobs/configurators/task.py +0 -3
  53. dstack/_internal/server/services/runs.py +16 -0
  54. dstack/_internal/server/statics/index.html +1 -1
  55. dstack/_internal/server/statics/{main-d151b300fcac3933213d.js → main-4eecc75fbe64067eb1bc.js} +1146 -899
  56. dstack/_internal/server/statics/{main-d151b300fcac3933213d.js.map → main-4eecc75fbe64067eb1bc.js.map} +1 -1
  57. dstack/_internal/server/statics/{main-aec4762350e34d6fbff9.css → main-56191c63d516fd0041c4.css} +1 -1
  58. dstack/_internal/server/testing/common.py +6 -3
  59. dstack/_internal/utils/path.py +8 -1
  60. dstack/_internal/utils/ssh.py +7 -0
  61. dstack/api/_public/repos.py +41 -6
  62. dstack/api/_public/runs.py +14 -1
  63. dstack/version.py +1 -1
  64. {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/METADATA +2 -2
  65. {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/RECORD +68 -53
  66. {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/WHEEL +0 -0
  67. {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/entry_points.txt +0 -0
  68. {dstack-0.19.26.dist-info → dstack-0.19.27.dist-info}/licenses/LICENSE.md +0 -0
@@ -6,12 +6,12 @@ from typing import Optional
6
6
  from dstack._internal.cli.commands import BaseCommand
7
7
  from dstack._internal.cli.services.repos import (
8
8
  get_repo_from_dir,
9
- get_repo_from_url,
10
9
  is_git_repo_url,
11
10
  register_init_repo_args,
12
11
  )
13
12
  from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn
14
13
  from dstack._internal.core.errors import ConfigurationError
14
+ from dstack._internal.core.models.repos.remote import RemoteRepo
15
15
  from dstack._internal.core.services.configs import ConfigManager
16
16
  from dstack.api import Client
17
17
 
@@ -101,7 +101,7 @@ class InitCommand(BaseCommand):
101
101
  if repo_url is not None:
102
102
  # Dummy repo branch to avoid autodetection that fails on private repos.
103
103
  # We don't need branch/hash for repo_id anyway.
104
- repo = get_repo_from_url(repo_url, repo_branch="master")
104
+ repo = RemoteRepo.from_url(repo_url, repo_branch="master")
105
105
  elif repo_path is not None:
106
106
  repo = get_repo_from_dir(repo_path, local=local)
107
107
  else:
@@ -2,7 +2,7 @@ import argparse
2
2
  import subprocess
3
3
  import sys
4
4
  import time
5
- from pathlib import Path
5
+ from pathlib import Path, PurePosixPath
6
6
  from typing import Dict, List, Optional, Set, TypeVar
7
7
 
8
8
  import gpuhunt
@@ -17,7 +17,6 @@ from dstack._internal.cli.services.configurators.base import (
17
17
  from dstack._internal.cli.services.profile import apply_profile_args, register_profile_args
18
18
  from dstack._internal.cli.services.repos import (
19
19
  get_repo_from_dir,
20
- get_repo_from_url,
21
20
  init_default_virtual_repo,
22
21
  is_git_repo_url,
23
22
  register_init_repo_args,
@@ -33,6 +32,7 @@ from dstack._internal.core.errors import (
33
32
  )
34
33
  from dstack._internal.core.models.common import ApplyAction, RegistryAuth
35
34
  from dstack._internal.core.models.configurations import (
35
+ LEGACY_REPO_DIR,
36
36
  AnyRunConfiguration,
37
37
  ApplyConfigurationType,
38
38
  ConfigurationWithPortsParams,
@@ -42,19 +42,27 @@ from dstack._internal.core.models.configurations import (
42
42
  ServiceConfiguration,
43
43
  TaskConfiguration,
44
44
  )
45
+ from dstack._internal.core.models.repos import RepoHeadWithCreds
45
46
  from dstack._internal.core.models.repos.base import Repo
46
47
  from dstack._internal.core.models.repos.local import LocalRepo
48
+ from dstack._internal.core.models.repos.remote import RemoteRepo, RemoteRepoCreds
47
49
  from dstack._internal.core.models.resources import CPUSpec
48
50
  from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus
49
51
  from dstack._internal.core.services.configs import ConfigManager
50
52
  from dstack._internal.core.services.diff import diff_models
51
- from dstack._internal.core.services.repos import load_repo
53
+ from dstack._internal.core.services.repos import (
54
+ InvalidRepoCredentialsError,
55
+ get_repo_creds_and_default_branch,
56
+ load_repo,
57
+ )
52
58
  from dstack._internal.utils.common import local_time
53
59
  from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
54
60
  from dstack._internal.utils.logging import get_logger
55
61
  from dstack._internal.utils.nested_list import NestedList, NestedListItem
62
+ from dstack._internal.utils.path import is_absolute_posix_path
56
63
  from dstack.api._public.repos import get_ssh_keypair
57
64
  from dstack.api._public.runs import Run
65
+ from dstack.api.server import APIClient
58
66
  from dstack.api.utils import load_profile
59
67
 
60
68
  _KNOWN_AMD_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_AMD_GPUS}
@@ -89,6 +97,43 @@ class BaseRunConfigurator(
89
97
  self.validate_gpu_vendor_and_image(conf)
90
98
  self.validate_cpu_arch_and_image(conf)
91
99
 
100
+ working_dir = conf.working_dir
101
+ if working_dir is None:
102
+ # Use the default working dir for the image for tasks and services if `commands`
103
+ # is not set (emulate pre-0.19.27 JobConfigurator logic), otherwise fall back to
104
+ # `/workflow`.
105
+ if isinstance(conf, DevEnvironmentConfiguration) or conf.commands:
106
+ # relative path for compatibility with pre-0.19.27 servers
107
+ conf.working_dir = "."
108
+ warn(
109
+ f'The [code]working_dir[/code] is not set — using legacy default [code]"{LEGACY_REPO_DIR}"[/code].'
110
+ " Future versions will default to the [code]image[/code]'s working directory."
111
+ )
112
+ elif not is_absolute_posix_path(working_dir):
113
+ legacy_working_dir = PurePosixPath(LEGACY_REPO_DIR) / working_dir
114
+ warn(
115
+ "[code]working_dir[/code] is relative."
116
+ f" Using legacy working directory [code]{legacy_working_dir}[/code]\n\n"
117
+ "Future versions will require absolute path\n"
118
+ f"To keep using legacy working directory, set"
119
+ f" [code]working_dir[/code] to [code]{legacy_working_dir}[/code]\n"
120
+ )
121
+ else:
122
+ # relative path for compatibility with pre-0.19.27 servers
123
+ try:
124
+ conf.working_dir = str(PurePosixPath(working_dir).relative_to(LEGACY_REPO_DIR))
125
+ except ValueError:
126
+ pass
127
+
128
+ if conf.repos and conf.repos[0].path is None:
129
+ warn(
130
+ "[code]repos[0].path[/code] is not set,"
131
+ f" using legacy repo path [code]{LEGACY_REPO_DIR}[/code]\n\n"
132
+ "In a future version the default value will be changed."
133
+ f" To keep using [code]{LEGACY_REPO_DIR}[/code], explicitly set"
134
+ f" [code]repos[0].path[/code] to [code]{LEGACY_REPO_DIR}[/code]\n"
135
+ )
136
+
92
137
  config_manager = ConfigManager()
93
138
  repo = self.get_repo(conf, configuration_path, configurator_args, config_manager)
94
139
  self.api.ssh_identity_file = get_ssh_keypair(
@@ -184,6 +229,9 @@ class BaseRunConfigurator(
184
229
  format_date=local_time,
185
230
  )
186
231
  )
232
+
233
+ _warn_fleet_autocreated(self.api.client, run)
234
+
187
235
  console.print(
188
236
  f"\n[code]{run.name}[/] provisioning completed [secondary]({run.status.value})[/]"
189
237
  )
@@ -486,15 +534,17 @@ class BaseRunConfigurator(
486
534
  return init_default_virtual_repo(api=self.api)
487
535
 
488
536
  repo: Optional[Repo] = None
537
+ repo_head: Optional[RepoHeadWithCreds] = None
489
538
  repo_branch: Optional[str] = configurator_args.repo_branch
490
539
  repo_hash: Optional[str] = configurator_args.repo_hash
540
+ repo_creds: Optional[RemoteRepoCreds] = None
541
+ git_identity_file: Optional[str] = configurator_args.git_identity_file
542
+ git_private_key: Optional[str] = None
543
+ oauth_token: Optional[str] = configurator_args.gh_token
491
544
  # Should we (re)initialize the repo?
492
545
  # If any Git credentials provided, we reinitialize the repo, as the user may have provided
493
546
  # updated credentials.
494
- init = (
495
- configurator_args.git_identity_file is not None
496
- or configurator_args.gh_token is not None
497
- )
547
+ init = git_identity_file is not None or oauth_token is not None
498
548
 
499
549
  url: Optional[str] = None
500
550
  local_path: Optional[Path] = None
@@ -527,15 +577,15 @@ class BaseRunConfigurator(
527
577
  local_path = Path.cwd()
528
578
  legacy_local_path = True
529
579
  if url:
530
- repo = get_repo_from_url(repo_url=url, repo_branch=repo_branch, repo_hash=repo_hash)
531
- if not self.api.repos.is_initialized(repo, by_user=True):
532
- init = True
580
+ # "master" is a dummy value, we'll fetch the actual default branch later
581
+ repo = RemoteRepo.from_url(repo_url=url, repo_branch="master")
582
+ repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
533
583
  elif local_path:
534
584
  if legacy_local_path:
535
585
  if repo_config := config_manager.get_repo_config(local_path):
536
586
  repo = load_repo(repo_config)
537
- # allow users with legacy configurations use shared repo creds
538
- if self.api.repos.is_initialized(repo, by_user=False):
587
+ repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
588
+ if repo_head is not None:
539
589
  warn(
540
590
  "The repo is not specified but found and will be used in the run\n"
541
591
  "Future versions will not load repos automatically\n"
@@ -562,20 +612,55 @@ class BaseRunConfigurator(
562
612
  )
563
613
  local: bool = configurator_args.local
564
614
  repo = get_repo_from_dir(local_path, local=local)
565
- if not self.api.repos.is_initialized(repo, by_user=True):
566
- init = True
615
+ repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
616
+ if isinstance(repo, RemoteRepo):
617
+ repo_branch = repo.run_repo_data.repo_branch
618
+ repo_hash = repo.run_repo_data.repo_hash
567
619
  else:
568
620
  assert False, "should not reach here"
569
621
 
570
622
  if repo is None:
571
623
  return init_default_virtual_repo(api=self.api)
572
624
 
625
+ if isinstance(repo, RemoteRepo):
626
+ assert repo.repo_url is not None
627
+
628
+ if repo_head is not None and repo_head.repo_creds is not None:
629
+ if git_identity_file is None and oauth_token is None:
630
+ git_private_key = repo_head.repo_creds.private_key
631
+ oauth_token = repo_head.repo_creds.oauth_token
632
+ else:
633
+ init = True
634
+
635
+ try:
636
+ repo_creds, default_repo_branch = get_repo_creds_and_default_branch(
637
+ repo_url=repo.repo_url,
638
+ identity_file=git_identity_file,
639
+ private_key=git_private_key,
640
+ oauth_token=oauth_token,
641
+ )
642
+ except InvalidRepoCredentialsError as e:
643
+ raise CLIError(*e.args) from e
644
+
645
+ if repo_branch is None and repo_hash is None:
646
+ repo_branch = default_repo_branch
647
+ if repo_branch is None:
648
+ raise CLIError(
649
+ "Failed to automatically detect remote repo branch."
650
+ " Specify branch or hash."
651
+ )
652
+ repo = RemoteRepo.from_url(
653
+ repo_url=repo.repo_url, repo_branch=repo_branch, repo_hash=repo_hash
654
+ )
655
+
573
656
  if init:
574
657
  self.api.repos.init(
575
658
  repo=repo,
576
- git_identity_file=configurator_args.git_identity_file,
577
- oauth_token=configurator_args.gh_token,
659
+ git_identity_file=git_identity_file,
660
+ oauth_token=oauth_token,
661
+ creds=repo_creds,
578
662
  )
663
+
579
664
  if isinstance(repo, LocalRepo):
580
665
  warn(
581
666
  f"{repo.repo_dir} is a local repo\n"
@@ -827,3 +912,16 @@ def render_run_spec_diff(old_spec: RunSpec, new_spec: RunSpec) -> Optional[str]:
827
912
  item = NestedListItem(spec_field.replace("_", " ").capitalize())
828
913
  nested_list.children.append(item)
829
914
  return nested_list.render()
915
+
916
+
917
def _warn_fleet_autocreated(api: APIClient, run: Run):
    """Warn the user when the submitted run ended up in an autocreated fleet.

    Returns silently if the run has no fleet or its fleet spec is not marked
    ``autocreated``; otherwise prints a deprecation-style warning.
    """
    run_fleet = run._run.fleet
    if run_fleet is not None:
        fleet = api.fleets.get(project_name=run._project, name=run_fleet.name)
        if fleet.spec.autocreated:
            warn(
                f"\nNo existing fleet matched, so the run created a new fleet [code]{fleet.name}[/code].\n"
                "Future dstack versions won't create fleets automatically.\n"
                "Create a fleet explicitly: https://dstack.ai/docs/concepts/fleets/"
            )
@@ -1,5 +1,5 @@
1
1
  import argparse
2
- from typing import Literal, Optional, Union, overload
2
+ from typing import Literal, Union, overload
3
3
 
4
4
  import git
5
5
 
@@ -8,7 +8,6 @@ from dstack._internal.core.errors import CLIError
8
8
  from dstack._internal.core.models.repos.local import LocalRepo
9
9
  from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError
10
10
  from dstack._internal.core.models.repos.virtual import VirtualRepo
11
- from dstack._internal.core.services.repos import get_default_branch
12
11
  from dstack._internal.utils.path import PathLike
13
12
  from dstack.api._public import Client
14
13
 
@@ -43,22 +42,6 @@ def init_default_virtual_repo(api: Client) -> VirtualRepo:
43
42
  return repo
44
43
 
45
44
 
46
- def get_repo_from_url(
47
- repo_url: str, repo_branch: Optional[str] = None, repo_hash: Optional[str] = None
48
- ) -> RemoteRepo:
49
- if repo_branch is None and repo_hash is None:
50
- repo_branch = get_default_branch(repo_url)
51
- if repo_branch is None:
52
- raise CLIError(
53
- "Failed to automatically detect remote repo branch. Specify branch or hash."
54
- )
55
- return RemoteRepo.from_url(
56
- repo_url=repo_url,
57
- repo_branch=repo_branch,
58
- repo_hash=repo_hash,
59
- )
60
-
61
-
62
45
  @overload
63
46
  def get_repo_from_dir(repo_dir: PathLike, local: Literal[False] = False) -> RemoteRepo: ...
64
47
 
@@ -0,0 +1 @@
1
+ # This package contains the implementation for the AMDDevCloud backend.
@@ -0,0 +1,16 @@
1
+ from dstack._internal.core.backends.amddevcloud.compute import AMDDevCloudCompute
2
+ from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
3
+ from dstack._internal.core.backends.digitalocean_base.models import BaseDigitalOceanConfig
4
+ from dstack._internal.core.models.backends.base import BackendType
5
+
6
+
7
class AMDDevCloudBackend(BaseDigitalOceanBackend):
    """Backend for AMD Developer Cloud, built on the shared DigitalOcean base classes."""

    TYPE = BackendType.AMDDEVCLOUD
    COMPUTE_CLASS = AMDDevCloudCompute

    def __init__(self, config: BaseDigitalOceanConfig, api_url: str):
        # `api_url` is forwarded unchanged to the compute, tagged with this backend's TYPE.
        self.config = config
        compute = AMDDevCloudCompute(config, api_url=api_url, type=self.TYPE)
        self._compute = compute

    def compute(self) -> AMDDevCloudCompute:
        """Return the compute instance created in ``__init__``."""
        return self._compute
@@ -0,0 +1,5 @@
1
+ from dstack._internal.core.backends.digitalocean_base.compute import BaseDigitalOceanCompute
2
+
3
+
4
class AMDDevCloudCompute(BaseDigitalOceanCompute):
    """Compute for the AMD Developer Cloud backend; all behavior is inherited from BaseDigitalOceanCompute."""
@@ -0,0 +1,29 @@
1
+ from typing import Optional
2
+
3
+ from dstack._internal.core.backends.amddevcloud.backend import AMDDevCloudBackend
4
+ from dstack._internal.core.backends.base.configurator import BackendRecord
5
+ from dstack._internal.core.backends.digitalocean_base.api_client import DigitalOceanAPIClient
6
+ from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
7
+ from dstack._internal.core.backends.digitalocean_base.configurator import (
8
+ BaseDigitalOceanConfigurator,
9
+ )
10
+ from dstack._internal.core.backends.digitalocean_base.models import AnyBaseDigitalOceanCreds
11
+ from dstack._internal.core.models.backends.base import (
12
+ BackendType,
13
+ )
14
+
15
+
16
class AMDDevCloudConfigurator(BaseDigitalOceanConfigurator):
    """Configurator for the AMD Developer Cloud backend, which uses a DigitalOcean-style API."""

    TYPE = BackendType.AMDDEVCLOUD
    BACKEND_CLASS = AMDDevCloudBackend
    API_URL = "https://api-amd.digitalocean.com"

    def get_backend(self, record: BackendRecord) -> BaseDigitalOceanBackend:
        """Build an AMDDevCloudBackend from a stored backend record."""
        return AMDDevCloudBackend(config=self._get_config(record), api_url=self.API_URL)

    def _validate_creds(self, creds: AnyBaseDigitalOceanCreds, project_name: Optional[str] = None):
        """Validate the API key and, when a project name is given, that the project exists.

        Any failure is raised by the API client's validation methods.
        """
        client = DigitalOceanAPIClient(creds.api_key, self.API_URL)
        client.validate_api_key()
        if not project_name:
            return
        client.validate_project_name(project_name)
@@ -292,7 +292,12 @@ class AWSCompute(
292
292
  image_id=image_id,
293
293
  instance_type=instance_offer.instance.name,
294
294
  iam_instance_profile=self.config.iam_instance_profile,
295
- user_data=get_user_data(authorized_keys=instance_config.get_public_keys()),
295
+ user_data=get_user_data(
296
+ authorized_keys=instance_config.get_public_keys(),
297
+ # Custom OS images may lack ufw, so don't attempt to set up the firewall.
298
+ # Rely on security groups and the image's built-in firewall rules instead.
299
+ skip_firewall_setup=self.config.os_images is not None,
300
+ ),
296
301
  tags=aws_resources.make_tags(tags),
297
302
  security_group_id=security_group_id,
298
303
  spot=instance_offer.instance.resources.spot,
@@ -4,6 +4,7 @@ import re
4
4
  import string
5
5
  import threading
6
6
  from abc import ABC, abstractmethod
7
+ from collections.abc import Iterable
7
8
  from functools import lru_cache
8
9
  from pathlib import Path
9
10
  from typing import Dict, List, Literal, Optional
@@ -19,7 +20,7 @@ from dstack._internal.core.consts import (
19
20
  DSTACK_RUNNER_SSH_PORT,
20
21
  DSTACK_SHIM_HTTP_PORT,
21
22
  )
22
- from dstack._internal.core.models.configurations import DEFAULT_REPO_DIR
23
+ from dstack._internal.core.models.configurations import LEGACY_REPO_DIR
23
24
  from dstack._internal.core.models.gateways import (
24
25
  GatewayComputeConfiguration,
25
26
  GatewayProvisioningData,
@@ -45,6 +46,7 @@ logger = get_logger(__name__)
45
46
 
46
47
  DSTACK_SHIM_BINARY_NAME = "dstack-shim"
47
48
  DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
49
+ DEFAULT_PRIVATE_SUBNETS = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")
48
50
 
49
51
  GoArchType = Literal["amd64", "arm64"]
50
52
 
@@ -507,12 +509,16 @@ def get_user_data(
507
509
  base_path: Optional[PathLike] = None,
508
510
  bin_path: Optional[PathLike] = None,
509
511
  backend_shim_env: Optional[Dict[str, str]] = None,
512
+ skip_firewall_setup: bool = False,
513
+ firewall_allow_from_subnets: Iterable[str] = DEFAULT_PRIVATE_SUBNETS,
510
514
  ) -> str:
511
515
  shim_commands = get_shim_commands(
512
516
  authorized_keys=authorized_keys,
513
517
  base_path=base_path,
514
518
  bin_path=bin_path,
515
519
  backend_shim_env=backend_shim_env,
520
+ skip_firewall_setup=skip_firewall_setup,
521
+ firewall_allow_from_subnets=firewall_allow_from_subnets,
516
522
  )
517
523
  commands = (backend_specific_commands or []) + shim_commands
518
524
  return get_cloud_config(
@@ -554,8 +560,13 @@ def get_shim_commands(
554
560
  bin_path: Optional[PathLike] = None,
555
561
  backend_shim_env: Optional[Dict[str, str]] = None,
556
562
  arch: Optional[str] = None,
563
+ skip_firewall_setup: bool = False,
564
+ firewall_allow_from_subnets: Iterable[str] = DEFAULT_PRIVATE_SUBNETS,
557
565
  ) -> List[str]:
558
- commands = get_setup_cloud_instance_commands()
566
+ commands = get_setup_cloud_instance_commands(
567
+ skip_firewall_setup=skip_firewall_setup,
568
+ firewall_allow_from_subnets=firewall_allow_from_subnets,
569
+ )
559
570
  commands += get_shim_pre_start_commands(
560
571
  base_path=base_path,
561
572
  bin_path=bin_path,
@@ -638,8 +649,11 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str:
638
649
  return url_template.format(version=version, arch=arch)
639
650
 
640
651
 
641
- def get_setup_cloud_instance_commands() -> list[str]:
642
- return [
652
+ def get_setup_cloud_instance_commands(
653
+ skip_firewall_setup: bool,
654
+ firewall_allow_from_subnets: Iterable[str],
655
+ ) -> list[str]:
656
+ commands = [
643
657
  # Workaround for https://github.com/NVIDIA/nvidia-container-toolkit/issues/48
644
658
  # Attempts to patch /etc/docker/daemon.json while keeping any custom settings it may have.
645
659
  (
@@ -653,6 +667,19 @@ def get_setup_cloud_instance_commands() -> list[str]:
653
667
  "'"
654
668
  ),
655
669
  ]
670
+ if not skip_firewall_setup:
671
+ commands += [
672
+ "ufw --force reset", # Some OS images have default rules like `allow 80`. Delete them
673
+ "ufw default deny incoming",
674
+ "ufw default allow outgoing",
675
+ "ufw allow ssh",
676
+ ]
677
+ for subnet in firewall_allow_from_subnets:
678
+ commands.append(f"ufw allow from {subnet}")
679
+ commands += [
680
+ "ufw --force enable",
681
+ ]
682
+ return commands
656
683
 
657
684
 
658
685
  def get_shim_pre_start_commands(
@@ -773,7 +800,8 @@ def get_docker_commands(
773
800
  f" --ssh-port {DSTACK_RUNNER_SSH_PORT}"
774
801
  " --temp-dir /tmp/runner"
775
802
  " --home-dir /root"
776
- f" --working-dir {DEFAULT_REPO_DIR}"
803
+ # TODO: Not used, left for compatibility with old runners. Remove eventually.
804
+ f" --working-dir {LEGACY_REPO_DIR}"
777
805
  ),
778
806
  ]
779
807
 
@@ -34,6 +34,8 @@ def get_catalog_offers(
34
34
  provider = backend.value
35
35
  if backend == BackendType.LAMBDA:
36
36
  provider = "lambdalabs"
37
+ if backend == BackendType.AMDDEVCLOUD:
38
+ provider = "digitalocean"
37
39
  q = requirements_to_query_filter(requirements)
38
40
  q.provider = [provider]
39
41
  offers = []
@@ -5,6 +5,12 @@ from dstack._internal.core.models.backends.base import BackendType
5
5
 
6
6
  _CONFIGURATOR_CLASSES: List[Type[Configurator]] = []
7
7
 
8
+ try:
9
+ from dstack._internal.core.backends.amddevcloud.configurator import AMDDevCloudConfigurator
10
+
11
+ _CONFIGURATOR_CLASSES.append(AMDDevCloudConfigurator)
12
+ except ImportError:
13
+ pass
8
14
 
9
15
  try:
10
16
  from dstack._internal.core.backends.aws.configurator import AWSConfigurator
@@ -47,6 +53,15 @@ try:
47
53
  except ImportError:
48
54
  pass
49
55
 
56
+ try:
57
+ from dstack._internal.core.backends.digitalocean.configurator import (
58
+ DigitalOceanConfigurator,
59
+ )
60
+
61
+ _CONFIGURATOR_CLASSES.append(DigitalOceanConfigurator)
62
+ except ImportError:
63
+ pass
64
+
50
65
  try:
51
66
  from dstack._internal.core.backends.gcp.configurator import GCPConfigurator
52
67
 
@@ -0,0 +1 @@
1
+ # DigitalOcean backend for dstack
@@ -0,0 +1,16 @@
1
+ from dstack._internal.core.backends.digitalocean.compute import DigitalOceanCompute
2
+ from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
3
+ from dstack._internal.core.backends.digitalocean_base.models import BaseDigitalOceanConfig
4
+ from dstack._internal.core.models.backends.base import BackendType
5
+
6
+
7
class DigitalOceanBackend(BaseDigitalOceanBackend):
    """Backend for DigitalOcean, built on the shared DigitalOcean base classes."""

    TYPE = BackendType.DIGITALOCEAN
    COMPUTE_CLASS = DigitalOceanCompute

    def __init__(self, config: BaseDigitalOceanConfig, api_url: str):
        # `api_url` is forwarded unchanged to the compute, tagged with this backend's TYPE.
        self.config = config
        compute = DigitalOceanCompute(config, api_url=api_url, type=self.TYPE)
        self._compute = compute

    def compute(self) -> DigitalOceanCompute:
        """Return the compute instance created in ``__init__``."""
        return self._compute
@@ -0,0 +1,5 @@
1
+ from ..digitalocean_base.compute import BaseDigitalOceanCompute
2
+
3
+
4
class DigitalOceanCompute(BaseDigitalOceanCompute):
    """Compute for the DigitalOcean backend; all behavior is inherited from BaseDigitalOceanCompute."""
@@ -0,0 +1,31 @@
1
+ from typing import Optional
2
+
3
+ from dstack._internal.core.backends.base.configurator import BackendRecord
4
+ from dstack._internal.core.backends.digitalocean.backend import DigitalOceanBackend
5
+ from dstack._internal.core.backends.digitalocean_base.api_client import DigitalOceanAPIClient
6
+ from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
7
+ from dstack._internal.core.backends.digitalocean_base.configurator import (
8
+ BaseDigitalOceanConfigurator,
9
+ )
10
+ from dstack._internal.core.backends.digitalocean_base.models import (
11
+ AnyBaseDigitalOceanCreds,
12
+ )
13
+ from dstack._internal.core.models.backends.base import (
14
+ BackendType,
15
+ )
16
+
17
+
18
class DigitalOceanConfigurator(BaseDigitalOceanConfigurator):
    """Configurator for the DigitalOcean backend."""

    TYPE = BackendType.DIGITALOCEAN
    BACKEND_CLASS = DigitalOceanBackend
    API_URL = "https://api.digitalocean.com"

    def get_backend(self, record: BackendRecord) -> BaseDigitalOceanBackend:
        """Build a DigitalOceanBackend from a stored backend record."""
        return DigitalOceanBackend(config=self._get_config(record), api_url=self.API_URL)

    def _validate_creds(self, creds: AnyBaseDigitalOceanCreds, project_name: Optional[str] = None):
        """Validate the API key and, when a project name is given, that the project exists.

        Any failure is raised by the API client's validation methods.
        """
        client = DigitalOceanAPIClient(creds.api_key, self.API_URL)
        client.validate_api_key()
        if not project_name:
            return
        client.validate_project_name(project_name)
@@ -0,0 +1 @@
1
+ # This package contains the base classes for DigitalOcean and AMDDevCloud backends.
@@ -0,0 +1,104 @@
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ import requests
4
+
5
+ from dstack._internal.core.backends.base.configurator import raise_invalid_credentials_error
6
+ from dstack._internal.core.errors import NoCapacityError
7
+ from dstack._internal.utils.logging import get_logger
8
+
9
logger = get_logger(__name__)


class DigitalOceanAPIClient:
    """Minimal client for the DigitalOcean v2 REST API.

    The same client is used for DigitalOcean and DigitalOcean-compatible endpoints
    (e.g. AMD Developer Cloud); the endpoint is chosen by ``api_url``.
    """

    # Page size requested from paginated list endpoints. The DigitalOcean API
    # defaults to 20 items per page (max 200), so reading only the first page
    # silently drops items on larger accounts.
    _PAGE_SIZE = 200

    def __init__(self, api_key: str, api_url: str):
        self.api_key = api_key
        self.base_url = api_url

    def validate_api_key(self) -> bool:
        """Return True if the API key is accepted by ``GET /v2/account``.

        On HTTP 401 calls ``raise_invalid_credentials_error``; any other HTTP
        error is re-raised as ``requests.HTTPError``.
        """
        try:
            response = self._make_request("GET", "/v2/account")
            response.raise_for_status()
            return True
        except requests.HTTPError as e:
            if e.response is not None and e.response.status_code == 401:
                raise_invalid_credentials_error(
                    fields=[["creds", "api_key"]], details="Invalid API key"
                )
            raise

    def validate_project_name(self, project_name: str) -> bool:
        """Return True if a project named ``project_name`` exists.

        Otherwise calls ``raise_invalid_credentials_error`` for the ``project_name`` field.
        """
        if self.get_project_id(project_name) is None:
            raise_invalid_credentials_error(
                fields=[["project_name"]],
                details=f"Project with name '{project_name}' does not exist",
            )
        return True

    def list_ssh_keys(self) -> List[Dict[str, Any]]:
        """Return all SSH keys on the account, across every result page."""
        return self._list_all("/v2/account/keys", "ssh_keys")

    def list_projects(self) -> List[Dict[str, Any]]:
        """Return all projects on the account, across every result page."""
        return self._list_all("/v2/projects", "projects")

    def get_project_id(self, project_name: str) -> Optional[str]:
        """Return the id of the first project whose name matches, or None."""
        for project in self.list_projects():
            if project["name"] == project_name:
                return project["id"]
        return None

    def create_ssh_key(self, name: str, public_key: str) -> Dict[str, Any]:
        """Register ``public_key`` under ``name`` and return the created key object."""
        payload = {"name": name, "public_key": public_key}
        response = self._make_request("POST", "/v2/account/keys", json=payload)
        response.raise_for_status()
        return response.json()["ssh_key"]

    def get_or_create_ssh_key(self, name: str, public_key: str) -> int:
        """Return the id of an existing key with the same public key, creating it if absent.

        Keys are compared after stripping surrounding whitespace. All pages are
        searched so that an existing key is not re-created.
        """
        wanted = public_key.strip()
        for ssh_key in self.list_ssh_keys():
            if ssh_key["public_key"].strip() == wanted:
                return ssh_key["id"]
        return self.create_ssh_key(name, public_key)["id"]

    def create_droplet(self, droplet_config: Dict[str, Any]) -> Dict[str, Any]:
        """Create a droplet and return its object.

        HTTP 422 is mapped to ``NoCapacityError`` carrying the API's ``message``.
        """
        response = self._make_request("POST", "/v2/droplets", json=droplet_config)
        if response.status_code == 422:
            raise NoCapacityError(response.json()["message"])
        response.raise_for_status()
        return response.json()["droplet"]

    def get_droplet(self, droplet_id: str) -> Dict[str, Any]:
        """Return the droplet object for ``droplet_id``."""
        response = self._make_request("GET", f"/v2/droplets/{droplet_id}")
        response.raise_for_status()
        return response.json()["droplet"]

    def delete_droplet(self, droplet_id: str) -> None:
        """Delete a droplet; a 404 (already gone) is logged at debug level and ignored."""
        response = self._make_request("DELETE", f"/v2/droplets/{droplet_id}")
        if response.status_code == 404:
            logger.debug("DigitalOcean droplet %s not found", droplet_id)
            return
        response.raise_for_status()

    def _list_all(self, endpoint: str, key: str) -> List[Dict[str, Any]]:
        """Collect the ``key`` items from every page of a paginated GET endpoint.

        Stops when a response has no ``links.pages.next`` URL or returns an empty
        page, so an API that does not paginate behaves like a single request.
        """
        items: List[Dict[str, Any]] = []
        page = 1
        while True:
            response = self._make_request(
                "GET", endpoint, params={"page": page, "per_page": self._PAGE_SIZE}
            )
            response.raise_for_status()
            data = response.json()
            page_items = data[key]
            items.extend(page_items)
            links = data.get("links") or {}
            pages = links.get("pages") or {}
            if not page_items or not pages.get("next"):
                return items
            page += 1

    def _make_request(
        self,
        method: str,
        endpoint: str,
        json: Optional[Dict[str, Any]] = None,
        timeout: int = 30,
        params: Optional[Dict[str, Any]] = None,
    ) -> requests.Response:
        """Send an authenticated request to ``base_url + endpoint`` and return the raw response.

        ``params`` is an optional query-string mapping; it defaults to None (no query).
        """
        url = f"{self.base_url}{endpoint}"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        return requests.request(
            method=method,
            url=url,
            headers=headers,
            json=json,
            params=params,
            timeout=timeout,
        )
@@ -0,0 +1,5 @@
1
+ from dstack._internal.core.backends.base.backend import Backend
2
+
3
+
4
+ class BaseDigitalOceanBackend(Backend):
5
+ pass