dstack 0.19.16__py3-none-any.whl → 0.19.17__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of dstack might be problematic.
- dstack/_internal/cli/commands/secrets.py +92 -0
- dstack/_internal/cli/main.py +2 -0
- dstack/_internal/cli/services/completion.py +5 -0
- dstack/_internal/cli/services/configurators/run.py +59 -17
- dstack/_internal/cli/utils/secrets.py +25 -0
- dstack/_internal/core/backends/__init__.py +10 -4
- dstack/_internal/core/compatibility/runs.py +29 -2
- dstack/_internal/core/models/configurations.py +11 -0
- dstack/_internal/core/models/files.py +67 -0
- dstack/_internal/core/models/runs.py +14 -0
- dstack/_internal/core/models/secrets.py +9 -2
- dstack/_internal/server/app.py +2 -0
- dstack/_internal/server/background/tasks/process_running_jobs.py +109 -12
- dstack/_internal/server/background/tasks/process_runs.py +15 -3
- dstack/_internal/server/migrations/versions/5f1707c525d2_add_filearchivemodel.py +39 -0
- dstack/_internal/server/migrations/versions/644b8a114187_add_secretmodel.py +49 -0
- dstack/_internal/server/models.py +33 -0
- dstack/_internal/server/routers/files.py +67 -0
- dstack/_internal/server/routers/secrets.py +57 -15
- dstack/_internal/server/schemas/files.py +5 -0
- dstack/_internal/server/schemas/runner.py +2 -0
- dstack/_internal/server/schemas/secrets.py +7 -11
- dstack/_internal/server/services/backends/__init__.py +1 -1
- dstack/_internal/server/services/files.py +91 -0
- dstack/_internal/server/services/jobs/__init__.py +19 -8
- dstack/_internal/server/services/jobs/configurators/base.py +20 -2
- dstack/_internal/server/services/jobs/configurators/dev.py +3 -3
- dstack/_internal/server/services/proxy/repo.py +3 -0
- dstack/_internal/server/services/runner/client.py +8 -0
- dstack/_internal/server/services/runs.py +52 -7
- dstack/_internal/server/services/secrets.py +204 -0
- dstack/_internal/server/services/storage/base.py +21 -0
- dstack/_internal/server/services/storage/gcs.py +28 -6
- dstack/_internal/server/services/storage/s3.py +27 -9
- dstack/_internal/server/settings.py +2 -2
- dstack/_internal/server/statics/index.html +1 -1
- dstack/_internal/server/statics/{main-a4eafa74304e587d037c.js → main-d151637af20f70b2e796.js} +56 -8
- dstack/_internal/server/statics/{main-a4eafa74304e587d037c.js.map → main-d151637af20f70b2e796.js.map} +1 -1
- dstack/_internal/server/statics/{main-f53d6d0d42f8d61df1de.css → main-d48635d8fe670d53961c.css} +1 -1
- dstack/_internal/server/statics/static/media/google.b194b06fafd0a52aeb566922160ea514.svg +1 -0
- dstack/_internal/server/testing/common.py +43 -5
- dstack/_internal/settings.py +4 -0
- dstack/_internal/utils/files.py +69 -0
- dstack/_internal/utils/nested_list.py +47 -0
- dstack/_internal/utils/path.py +12 -4
- dstack/api/_public/runs.py +67 -7
- dstack/api/server/__init__.py +6 -0
- dstack/api/server/_files.py +18 -0
- dstack/api/server/_secrets.py +15 -15
- dstack/version.py +1 -1
- {dstack-0.19.16.dist-info → dstack-0.19.17.dist-info}/METADATA +3 -4
- {dstack-0.19.16.dist-info → dstack-0.19.17.dist-info}/RECORD +55 -42
- {dstack-0.19.16.dist-info → dstack-0.19.17.dist-info}/WHEEL +0 -0
- {dstack-0.19.16.dist-info → dstack-0.19.17.dist-info}/entry_points.txt +0 -0
- {dstack-0.19.16.dist-info → dstack-0.19.17.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/server/services/files.py (new file)
@@ -0,0 +1,91 @@
+import uuid
+from typing import Optional
+
+from fastapi import UploadFile
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.core.errors import ServerClientError
+from dstack._internal.core.models.files import FileArchive
+from dstack._internal.server.models import FileArchiveModel, UserModel
+from dstack._internal.server.services.storage import get_default_storage
+from dstack._internal.utils.common import run_async
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+async def get_archive_model(
+    session: AsyncSession,
+    id: uuid.UUID,
+    user: Optional[UserModel] = None,
+) -> Optional[FileArchiveModel]:
+    stmt = select(FileArchiveModel).where(FileArchiveModel.id == id)
+    if user is not None:
+        stmt = stmt.where(FileArchiveModel.user_id == user.id)
+    res = await session.execute(stmt)
+    return res.scalar()
+
+
+async def get_archive_model_by_hash(
+    session: AsyncSession,
+    user: UserModel,
+    hash: str,
+) -> Optional[FileArchiveModel]:
+    res = await session.execute(
+        select(FileArchiveModel).where(
+            FileArchiveModel.user_id == user.id,
+            FileArchiveModel.blob_hash == hash,
+        )
+    )
+    return res.scalar()
+
+
+async def get_archive_by_hash(
+    session: AsyncSession,
+    user: UserModel,
+    hash: str,
+) -> Optional[FileArchive]:
+    archive_model = await get_archive_model_by_hash(
+        session=session,
+        user=user,
+        hash=hash,
+    )
+    if archive_model is None:
+        return None
+    return archive_model_to_archive(archive_model)
+
+
+async def upload_archive(
+    session: AsyncSession,
+    user: UserModel,
+    file: UploadFile,
+) -> FileArchive:
+    if file.filename is None:
+        raise ServerClientError("filename not specified")
+    archive_hash = file.filename
+    archive_model = await get_archive_model_by_hash(
+        session=session,
+        user=user,
+        hash=archive_hash,
+    )
+    if archive_model is not None:
+        logger.debug("File archive (user_id=%s, hash=%s) already uploaded", user.id, archive_hash)
+        return archive_model_to_archive(archive_model)
+    blob = await file.read()
+    storage = get_default_storage()
+    if storage is not None:
+        await run_async(storage.upload_archive, str(user.id), archive_hash, blob)
+    archive_model = FileArchiveModel(
+        user_id=user.id,
+        blob_hash=archive_hash,
+        blob=blob if storage is None else None,
+    )
+    session.add(archive_model)
+    await session.commit()
+    logger.debug("File archive (user_id=%s, hash=%s) has been uploaded", user.id, archive_hash)
+    return archive_model_to_archive(archive_model)
+
+
+def archive_model_to_archive(archive_model: FileArchiveModel) -> FileArchive:
+    return FileArchive(id=archive_model.id, hash=archive_model.blob_hash)
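The upload flow above deduplicates by content hash: the client sends the archive hash as the upload filename, and the server skips both the blob write and the database insert when a record with that (user, hash) pair already exists. A minimal standalone sketch of that flow, using an in-memory dict in place of FileArchiveModel and the storage backend (illustrative, not dstack's code; the hash algorithm below is arbitrary for the sketch):

import hashlib

_archives: dict = {}  # (user_id, archive_hash) -> blob; stand-in for the database/storage

def upload_archive(user_id: str, archive_hash: str, blob: bytes) -> bool:
    """Store the blob unless this (user, hash) pair was already uploaded."""
    key = (user_id, archive_hash)
    if key in _archives:
        return False  # deduplicated, mirrors the "already uploaded" branch above
    _archives[key] = blob
    return True

blob = b"archive bytes"
archive_hash = hashlib.sha256(blob).hexdigest()
assert upload_archive("user-1", archive_hash, blob) is True
assert upload_archive("user-1", archive_hash, blob) is False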
dstack/_internal/server/services/jobs/__init__.py
@@ -33,6 +33,7 @@ from dstack._internal.core.models.runs import (
     RunSpec,
 )
 from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus
+from dstack._internal.server import settings
 from dstack._internal.server.models import (
     InstanceModel,
     JobModel,
@@ -64,15 +65,23 @@ from dstack._internal.utils.logging import get_logger
 logger = get_logger(__name__)
 
 
-async def get_jobs_from_run_spec(
+async def get_jobs_from_run_spec(
+    run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
+) -> List[Job]:
     return [
         Job(job_spec=s, job_submissions=[])
-        for s in await get_job_specs_from_run_spec(
+        for s in await get_job_specs_from_run_spec(
+            run_spec=run_spec,
+            secrets=secrets,
+            replica_num=replica_num,
+        )
     ]
 
 
-async def get_job_specs_from_run_spec(
-
+async def get_job_specs_from_run_spec(
+    run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
+) -> List[JobSpec]:
+    job_configurator = _get_job_configurator(run_spec=run_spec, secrets=secrets)
     job_specs = await job_configurator.get_job_specs(replica_num=replica_num)
     return job_specs
 
@@ -158,10 +167,10 @@ def delay_job_instance_termination(job_model: JobModel):
     job_model.remove_at = common.get_current_datetime() + timedelta(seconds=15)
 
 
-def _get_job_configurator(run_spec: RunSpec) -> JobConfigurator:
+def _get_job_configurator(run_spec: RunSpec, secrets: Dict[str, str]) -> JobConfigurator:
     configuration_type = RunConfigurationType(run_spec.configuration.type)
     configurator_class = _configuration_type_to_configurator_class_map[configuration_type]
-    return configurator_class(run_spec)
+    return configurator_class(run_spec=run_spec, secrets=secrets)
 
 
 _job_configurator_classes = [
@@ -380,8 +389,10 @@ def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel):
             message=job_model.termination_reason_message,
             timeout=0,
         )
-        # maybe somehow postpone removing old tasks to allow inspecting failed jobs
-        shim_client.remove_task(task_id=job_model.id)
+        # maybe somehow postpone removing old tasks to allow inspecting failed jobs without
+        # the following setting?
+        if not settings.SERVER_KEEP_SHIM_TASKS:
+            shim_client.remove_task(task_id=job_model.id)
     else:
         shim_client.stop(force=True)
 
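The SERVER_KEEP_SHIM_TASKS flag comes from dstack._internal.server.settings (imported in the first hunk of this file). Its definition lives in settings.py, which this diff touches but does not show; a plausible sketch, assuming an environment-variable-backed boolean like other dstack server settings (the variable name, default, and parsing below are guesses):

import os

# Hypothetical: the actual name, default, and parsing live in
# dstack/_internal/server/settings.py, not shown in this diff.
SERVER_KEEP_SHIM_TASKS = os.getenv("DSTACK_SERVER_KEEP_SHIM_TASKS", "false").lower() in ("1", "true", "yes")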
dstack/_internal/server/services/jobs/configurators/base.py
@@ -68,8 +68,13 @@ class JobConfigurator(ABC):
     # JobSSHKey should be shared for all jobs in a replica for inter-node communication.
     _job_ssh_key: Optional[JobSSHKey] = None
 
-    def __init__(
+    def __init__(
+        self,
+        run_spec: RunSpec,
+        secrets: Optional[Dict[str, str]] = None,
+    ):
         self.run_spec = run_spec
+        self.secrets = secrets or {}
 
     async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
         job_spec = await self._get_job_spec(replica_num=replica_num, job_num=0, jobs_per_replica=1)
@@ -98,10 +103,20 @@ class JobConfigurator(ABC):
     async def _get_image_config(self) -> ImageConfig:
         if self._image_config is not None:
             return self._image_config
+        interpolate = VariablesInterpolator({"secrets": self.secrets}).interpolate_or_error
+        registry_auth = self.run_spec.configuration.registry_auth
+        if registry_auth is not None:
+            try:
+                registry_auth = RegistryAuth(
+                    username=interpolate(registry_auth.username),
+                    password=interpolate(registry_auth.password),
+                )
+            except InterpolatorError as e:
+                raise ServerClientError(e.args[0])
         image_config = await run_async(
             _get_image_config,
             self._image_name(),
-
+            registry_auth,
         )
         self._image_config = image_config
         return image_config
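With this change, registry_auth credentials are interpolated against the project's secrets through VariablesInterpolator, so a configuration can reference secrets by name instead of embedding credentials. A simplified regex-based stand-in for that substitution (dstack's interpolator also handles escaping and richer error reporting; the ${{ secrets.NAME }} syntax below is assumed from the "secrets" namespace passed above):

import re
from typing import Dict, Optional

def interpolate(template: Optional[str], secrets: Dict[str, str]) -> Optional[str]:
    # Expand ${{ secrets.NAME }} placeholders; unknown names raise, mirroring
    # interpolate_or_error above. None passes through, as username may be unset.
    if template is None:
        return None
    def repl(m: re.Match) -> str:
        name = m.group(1)
        if name not in secrets:
            raise ValueError(f"unknown secret: {name}")
        return secrets[name]
    return re.sub(r"\$\{\{\s*secrets\.([A-Za-z0-9_-]+)\s*\}\}", repl, template)

assert interpolate("${{ secrets.DOCKER_TOKEN }}", {"DOCKER_TOKEN": "t0k3n"}) == "t0k3n"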
dstack/_internal/server/services/jobs/configurators/base.py
@@ -134,6 +149,9 @@ class JobConfigurator(ABC):
             working_dir=self._working_dir(),
             volumes=self._volumes(job_num),
             ssh_key=self._ssh_key(jobs_per_replica),
+            repo_data=self.run_spec.repo_data,
+            repo_code_hash=self.run_spec.repo_code_hash,
+            file_archives=self.run_spec.file_archives,
         )
         return job_spec
 
dstack/_internal/server/services/jobs/configurators/dev.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 from dstack._internal.core.errors import ServerClientError
 from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType
@@ -17,7 +17,7 @@ INSTALL_IPYKERNEL = (
 class DevEnvironmentJobConfigurator(JobConfigurator):
     TYPE: RunConfigurationType = RunConfigurationType.DEV_ENVIRONMENT
 
-    def __init__(self, run_spec: RunSpec):
+    def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]):
         if run_spec.configuration.ide == "vscode":
             __class = VSCodeDesktop
         elif run_spec.configuration.ide == "cursor":
@@ -29,7 +29,7 @@ class DevEnvironmentJobConfigurator(JobConfigurator):
             version=run_spec.configuration.version,
             extensions=["ms-python.python", "ms-toolsai.jupyter"],
         )
-        super().__init__(run_spec)
+        super().__init__(run_spec=run_spec, secrets=secrets)
 
     def _shell_commands(self) -> List[str]:
         commands = self.ide.get_install_commands()
dstack/_internal/server/services/proxy/repo.py
@@ -7,6 +7,7 @@ from sqlalchemy.orm import joinedload
 
 import dstack._internal.server.services.jobs as jobs_services
 from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
+from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.configurations import ServiceConfiguration
 from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
 from dstack._internal.core.models.runs import (
@@ -86,6 +87,8 @@ class ServerProxyRepo(BaseProxyRepo):
                 username=jpd.username,
                 port=jpd.ssh_port,
             )
+            if jpd.backend == BackendType.LOCAL:
+                ssh_proxy = None
         ssh_head_proxy: Optional[SSHConnectionParams] = None
         ssh_head_proxy_private_key: Optional[str] = None
         instance = get_or_error(job.instance)
dstack/_internal/server/services/runner/client.py
@@ -109,6 +109,14 @@ class RunnerClient:
         )
         resp.raise_for_status()
 
+    def upload_archive(self, id: uuid.UUID, file: Union[BinaryIO, bytes]):
+        resp = requests.post(
+            self._url("/api/upload_archive"),
+            files={"archive": (str(id), file)},
+            timeout=UPLOAD_CODE_REQUEST_TIMEOUT,
+        )
+        resp.raise_for_status()
+
     def upload_code(self, file: Union[BinaryIO, bytes]):
         resp = requests.post(
             self._url("/api/upload_code"), data=file, timeout=UPLOAD_CODE_REQUEST_TIMEOUT
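Unlike upload_code, which streams the raw body, the new upload_archive sends multipart/form-data: the part is named "archive" and its filename carries the archive ID. A minimal sketch of the request shape (the runner URL and timeout below are placeholders, not dstack's values):

import io
import uuid

import requests

archive_id = uuid.uuid4()
resp = requests.post(
    "http://localhost:10999/api/upload_archive",  # placeholder runner address
    files={"archive": (str(archive_id), io.BytesIO(b"...archive bytes..."))},
    timeout=60,  # placeholder; the client above uses UPLOAD_CODE_REQUEST_TIMEOUT
)
resp.raise_for_status()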
dstack/_internal/server/services/runs.py
@@ -82,6 +82,7 @@ from dstack._internal.server.services.offers import get_offers_by_requirements
 from dstack._internal.server.services.plugins import apply_plugin_policies
 from dstack._internal.server.services.projects import list_project_models, list_user_project_models
 from dstack._internal.server.services.resources import set_resources_defaults
+from dstack._internal.server.services.secrets import get_project_secrets_mapping
 from dstack._internal.server.services.users import get_user_model_by_name
 from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.random_names import generate_name
@@ -311,7 +312,12 @@ async def get_plan(
     ):
         action = ApplyAction.UPDATE
 
-
+    secrets = await get_project_secrets_mapping(session=session, project=project)
+    jobs = await get_jobs_from_run_spec(
+        run_spec=effective_run_spec,
+        secrets=secrets,
+        replica_num=0,
+    )
 
     volumes = await get_job_configured_volumes(
         session=session,
@@ -462,6 +468,10 @@ async def submit_run(
         project=project,
         run_spec=run_spec,
     )
+    secrets = await get_project_secrets_mapping(
+        session=session,
+        project=project,
+    )
 
     lock_namespace = f"run_names_{project.name}"
     if get_db().dialect_name == "sqlite":
@@ -513,7 +523,11 @@
         await services.register_service(session, run_model, run_spec)
 
     for replica_num in range(replicas):
-        jobs = await get_jobs_from_run_spec(
+        jobs = await get_jobs_from_run_spec(
+            run_spec=run_spec,
+            secrets=secrets,
+            replica_num=replica_num,
+        )
         for job in jobs:
             job_model = create_job_model_for_new_submission(
                 run_model=run_model,
@@ -898,7 +912,16 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
     set_resources_defaults(run_spec.configuration.resources)
 
 
-_UPDATABLE_SPEC_FIELDS = ["
+_UPDATABLE_SPEC_FIELDS = ["configuration_path", "configuration"]
+_TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS = {
+    "service": [
+        # rolling deployment
+        "repo_data",
+        "repo_code_hash",
+        "file_archives",
+        "working_dir",
+    ],
+}
 _CONF_UPDATABLE_FIELDS = ["priority"]
 _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS = {
     "dev-environment": ["inactivity_duration"],
@@ -909,10 +932,13 @@ _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS = {
         # rolling deployment
         "resources",
         "volumes",
+        "docker",
+        "files",
         "image",
         "user",
         "privileged",
         "entrypoint",
+        "working_dir",
         "python",
         "nvcc",
         "single_branch",
@@ -935,11 +961,14 @@ def _can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> bo
 def _check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec):
     spec_diff = diff_models(current_run_spec, new_run_spec)
     changed_spec_fields = list(spec_diff.keys())
+    updatable_spec_fields = _UPDATABLE_SPEC_FIELDS + _TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS.get(
+        new_run_spec.configuration.type, []
+    )
     for key in changed_spec_fields:
-        if key not in
+        if key not in updatable_spec_fields:
             raise ServerClientError(
                 f"Failed to update fields {changed_spec_fields}."
-                f" Can only update {
+                f" Can only update {updatable_spec_fields}."
             )
     _check_can_update_configuration(current_run_spec.configuration, new_run_spec.configuration)
 
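The update check now merges the base updatable fields with per-configuration-type ones, so services can additionally change repo_data, repo_code_hash, file_archives, and working_dir in place during rolling deployments. A standalone restatement of the merged check:

from typing import Dict, List

_UPDATABLE_SPEC_FIELDS = ["configuration_path", "configuration"]
_TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS: Dict[str, List[str]] = {
    "service": ["repo_data", "repo_code_hash", "file_archives", "working_dir"],
}

def check_can_update(configuration_type: str, changed_fields: List[str]) -> None:
    # ValueError stands in for ServerClientError to keep the sketch dependency-free.
    updatable = _UPDATABLE_SPEC_FIELDS + _TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS.get(
        configuration_type, []
    )
    for key in changed_fields:
        if key not in updatable:
            raise ValueError(
                f"Failed to update fields {changed_fields}. Can only update {updatable}."
            )

check_can_update("service", ["working_dir"])  # allowed for services
# check_can_update("task", ["working_dir"])   # raises: not updatable for other types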
dstack/_internal/server/services/runs.py
@@ -1068,10 +1097,20 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
             await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False)
             scheduled_replicas += 1
 
+    secrets = await get_project_secrets_mapping(
+        session=session,
+        project=run_model.project,
+    )
+
     for replica_num in range(
         len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff
     ):
-
+        # FIXME: Handle getting image configuration errors or skip it.
+        jobs = await get_jobs_from_run_spec(
+            run_spec=run_spec,
+            secrets=secrets,
+            replica_num=replica_num,
+        )
         for job in jobs:
             job_model = create_job_model_for_new_submission(
                 run_model=run_model,
@@ -1084,8 +1123,14 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
 async def retry_run_replica_jobs(
     session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool
 ):
+    # FIXME: Handle getting image configuration errors or skip it.
+    secrets = await get_project_secrets_mapping(
+        session=session,
+        project=run_model.project,
+    )
     new_jobs = await get_jobs_from_run_spec(
-        RunSpec.__response__.parse_raw(run_model.run_spec),
+        run_spec=RunSpec.__response__.parse_raw(run_model.run_spec),
+        secrets=secrets,
         replica_num=latest_jobs[0].replica_num,
     )
     assert len(new_jobs) == len(latest_jobs), (
dstack/_internal/server/services/secrets.py (new file)
@@ -0,0 +1,204 @@
+import re
+from typing import Dict, List, Optional
+
+import sqlalchemy.exc
+from sqlalchemy import delete, select, update
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.core.errors import (
+    ResourceExistsError,
+    ResourceNotExistsError,
+    ServerClientError,
+)
+from dstack._internal.core.models.secrets import Secret
+from dstack._internal.server.models import DecryptedString, ProjectModel, SecretModel
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+_SECRET_NAME_REGEX = "^[A-Za-z0-9-_]{1,200}$"
+_SECRET_VALUE_MAX_LENGTH = 2000
+
+
+async def list_secrets(
+    session: AsyncSession,
+    project: ProjectModel,
+) -> List[Secret]:
+    secret_models = await list_project_secret_models(session=session, project=project)
+    return [secret_model_to_secret(s, include_value=False) for s in secret_models]
+
+
+async def get_project_secrets_mapping(
+    session: AsyncSession,
+    project: ProjectModel,
+) -> Dict[str, str]:
+    secret_models = await list_project_secret_models(session=session, project=project)
+    return {s.name: s.value.get_plaintext_or_error() for s in secret_models}
+
+
+async def get_secret(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+) -> Optional[Secret]:
+    secret_model = await get_project_secret_model_by_name(
+        session=session,
+        project=project,
+        name=name,
+    )
+    if secret_model is None:
+        return None
+    return secret_model_to_secret(secret_model, include_value=True)
+
+
+async def create_or_update_secret(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+    value: str,
+) -> Secret:
+    _validate_secret(name=name, value=value)
+    try:
+        secret_model = await create_secret(
+            session=session,
+            project=project,
+            name=name,
+            value=value,
+        )
+    except ResourceExistsError:
+        secret_model = await update_secret(
+            session=session,
+            project=project,
+            name=name,
+            value=value,
+        )
+    return secret_model_to_secret(secret_model, include_value=True)
+
+
+async def delete_secrets(
+    session: AsyncSession,
+    project: ProjectModel,
+    names: List[str],
+):
+    existing_secrets_query = await session.execute(
+        select(SecretModel).where(
+            SecretModel.project_id == project.id,
+            SecretModel.name.in_(names),
+        )
+    )
+    existing_names = [s.name for s in existing_secrets_query.scalars().all()]
+    missing_names = set(names) - set(existing_names)
+    if missing_names:
+        raise ResourceNotExistsError(f"Secrets not found: {', '.join(missing_names)}")
+
+    await session.execute(
+        delete(SecretModel).where(
+            SecretModel.project_id == project.id,
+            SecretModel.name.in_(names),
+        )
+    )
+    await session.commit()
+    logger.info("Deleted secrets %s in project %s", names, project.name)
+
+
+def secret_model_to_secret(secret_model: SecretModel, include_value: bool = False) -> Secret:
+    value = None
+    if include_value:
+        value = secret_model.value.get_plaintext_or_error()
+    return Secret(
+        id=secret_model.id,
+        name=secret_model.name,
+        value=value,
+    )
+
+
+async def list_project_secret_models(
+    session: AsyncSession,
+    project: ProjectModel,
+) -> List[SecretModel]:
+    res = await session.execute(
+        select(SecretModel)
+        .where(
+            SecretModel.project_id == project.id,
+        )
+        .order_by(SecretModel.created_at.desc())
+    )
+    secret_models = list(res.scalars().all())
+    return secret_models
+
+
+async def get_project_secret_model_by_name(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+) -> Optional[SecretModel]:
+    res = await session.execute(
+        select(SecretModel).where(
+            SecretModel.project_id == project.id,
+            SecretModel.name == name,
+        )
+    )
+    return res.scalar_one_or_none()
+
+
+async def create_secret(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+    value: str,
+) -> SecretModel:
+    secret_model = SecretModel(
+        project_id=project.id,
+        name=name,
+        value=DecryptedString(plaintext=value),
+    )
+    try:
+        async with session.begin_nested():
+            session.add(secret_model)
+    except sqlalchemy.exc.IntegrityError:
+        raise ResourceExistsError()
+    await session.commit()
+    return secret_model
+
+
+async def update_secret(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+    value: str,
+) -> SecretModel:
+    await session.execute(
+        update(SecretModel)
+        .where(
+            SecretModel.project_id == project.id,
+            SecretModel.name == name,
+        )
+        .values(
+            value=DecryptedString(plaintext=value),
+        )
+    )
+    await session.commit()
+    secret_model = await get_project_secret_model_by_name(
+        session=session,
+        project=project,
+        name=name,
+    )
+    if secret_model is None:
+        raise ResourceNotExistsError()
+    return secret_model
+
+
+def _validate_secret(name: str, value: str):
+    _validate_secret_name(name)
+    _validate_secret_value(value)
+
+
+def _validate_secret_name(name: str):
+    if re.match(_SECRET_NAME_REGEX, name) is None:
+        raise ServerClientError(f"Secret name should match regex '{_SECRET_NAME_REGEX}'")
+
+
+def _validate_secret_value(value: str):
+    if len(value) > _SECRET_VALUE_MAX_LENGTH:
+        raise ServerClientError(f"Secret value length must not exceed {_SECRET_VALUE_MAX_LENGTH}")
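Secret names are limited to 1-200 characters of [A-Za-z0-9-_] and values to 2000 characters. The validation above, extracted into a runnable standalone form (ValueError in place of ServerClientError):

import re

_SECRET_NAME_REGEX = "^[A-Za-z0-9-_]{1,200}$"
_SECRET_VALUE_MAX_LENGTH = 2000

def validate_secret(name: str, value: str) -> None:
    if re.match(_SECRET_NAME_REGEX, name) is None:
        raise ValueError(f"Secret name should match regex '{_SECRET_NAME_REGEX}'")
    if len(value) > _SECRET_VALUE_MAX_LENGTH:
        raise ValueError(f"Secret value length must not exceed {_SECRET_VALUE_MAX_LENGTH}")

validate_secret("HF_TOKEN", "hf_example_value")  # passes
# validate_secret("bad name!", "x")              # raises: space and "!" not allowed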
dstack/_internal/server/services/storage/base.py
@@ -22,6 +22,27 @@ class BaseStorage(ABC):
     ) -> Optional[bytes]:
         pass
 
+    @abstractmethod
+    def upload_archive(
+        self,
+        user_id: str,
+        archive_hash: str,
+        blob: bytes,
+    ):
+        pass
+
+    @abstractmethod
+    def get_archive(
+        self,
+        user_id: str,
+        archive_hash: str,
+    ) -> Optional[bytes]:
+        pass
+
     @staticmethod
     def _get_code_key(project_id: str, repo_id: str, code_hash: str) -> str:
         return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}"
+
+    @staticmethod
+    def _get_archive_key(user_id: str, archive_hash: str) -> str:
+        return f"data/users/{user_id}/file_archives/{archive_hash}"
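Archive blobs are stored under data/users/{user_id}/file_archives/{archive_hash}, alongside the existing per-project code keys. An in-memory implementation of the two new abstract methods, useful for seeing the interface in isolation (not part of dstack):

from typing import Dict, Optional

class InMemoryStorage:
    """Implements the archive half of the interface against a dict."""

    def __init__(self) -> None:
        self._blobs: Dict[str, bytes] = {}

    @staticmethod
    def _get_archive_key(user_id: str, archive_hash: str) -> str:
        # Same key scheme as BaseStorage._get_archive_key above.
        return f"data/users/{user_id}/file_archives/{archive_hash}"

    def upload_archive(self, user_id: str, archive_hash: str, blob: bytes) -> None:
        self._blobs[self._get_archive_key(user_id, archive_hash)] = blob

    def get_archive(self, user_id: str, archive_hash: str) -> Optional[bytes]:
        return self._blobs.get(self._get_archive_key(user_id, archive_hash))

storage = InMemoryStorage()
storage.upload_archive("u1", "deadbeef", b"payload")
assert storage.get_archive("u1", "deadbeef") == b"payload"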
dstack/_internal/server/services/storage/gcs.py
@@ -25,9 +25,8 @@ class GCSStorage(BaseStorage):
         code_hash: str,
         blob: bytes,
     ):
-
-
-        blob_obj.upload_from_string(blob)
+        key = self._get_code_key(project_id, repo_id, code_hash)
+        self._upload(key, blob)
 
     def get_code(
         self,
@@ -35,10 +34,33 @@ class GCSStorage(BaseStorage):
         repo_id: str,
         code_hash: str,
     ) -> Optional[bytes]:
+        key = self._get_code_key(project_id, repo_id, code_hash)
+        return self._get(key)
+
+    def upload_archive(
+        self,
+        user_id: str,
+        archive_hash: str,
+        blob: bytes,
+    ):
+        key = self._get_archive_key(user_id, archive_hash)
+        self._upload(key, blob)
+
+    def get_archive(
+        self,
+        user_id: str,
+        archive_hash: str,
+    ) -> Optional[bytes]:
+        key = self._get_archive_key(user_id, archive_hash)
+        return self._get(key)
+
+    def _upload(self, key: str, blob: bytes):
+        blob_obj = self._bucket.blob(key)
+        blob_obj.upload_from_string(blob)
+
+    def _get(self, key: str) -> Optional[bytes]:
         try:
-
-            blob = self._bucket.blob(blob_name)
+            blob = self._bucket.blob(key)
         except NotFound:
             return None
-
         return blob.download_as_bytes()