dstack 0.19.16__py3-none-any.whl → 0.19.18__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

Files changed (80)
  1. dstack/_internal/cli/commands/secrets.py +92 -0
  2. dstack/_internal/cli/main.py +2 -0
  3. dstack/_internal/cli/services/completion.py +5 -0
  4. dstack/_internal/cli/services/configurators/fleet.py +13 -1
  5. dstack/_internal/cli/services/configurators/run.py +59 -17
  6. dstack/_internal/cli/utils/secrets.py +25 -0
  7. dstack/_internal/core/backends/__init__.py +10 -4
  8. dstack/_internal/core/backends/aws/compute.py +237 -18
  9. dstack/_internal/core/backends/base/compute.py +20 -2
  10. dstack/_internal/core/backends/cudo/compute.py +23 -9
  11. dstack/_internal/core/backends/gcp/compute.py +13 -7
  12. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  13. dstack/_internal/core/compatibility/fleets.py +12 -11
  14. dstack/_internal/core/compatibility/gateways.py +9 -8
  15. dstack/_internal/core/compatibility/logs.py +4 -3
  16. dstack/_internal/core/compatibility/runs.py +41 -17
  17. dstack/_internal/core/compatibility/volumes.py +9 -8
  18. dstack/_internal/core/errors.py +4 -0
  19. dstack/_internal/core/models/common.py +7 -0
  20. dstack/_internal/core/models/configurations.py +11 -0
  21. dstack/_internal/core/models/files.py +67 -0
  22. dstack/_internal/core/models/runs.py +14 -0
  23. dstack/_internal/core/models/secrets.py +9 -2
  24. dstack/_internal/core/services/diff.py +36 -3
  25. dstack/_internal/server/app.py +22 -0
  26. dstack/_internal/server/background/__init__.py +61 -37
  27. dstack/_internal/server/background/tasks/process_fleets.py +19 -3
  28. dstack/_internal/server/background/tasks/process_gateways.py +1 -1
  29. dstack/_internal/server/background/tasks/process_instances.py +13 -2
  30. dstack/_internal/server/background/tasks/process_placement_groups.py +4 -2
  31. dstack/_internal/server/background/tasks/process_running_jobs.py +123 -15
  32. dstack/_internal/server/background/tasks/process_runs.py +23 -7
  33. dstack/_internal/server/background/tasks/process_submitted_jobs.py +36 -7
  34. dstack/_internal/server/background/tasks/process_terminating_jobs.py +5 -3
  35. dstack/_internal/server/background/tasks/process_volumes.py +2 -2
  36. dstack/_internal/server/migrations/versions/5f1707c525d2_add_filearchivemodel.py +39 -0
  37. dstack/_internal/server/migrations/versions/644b8a114187_add_secretmodel.py +49 -0
  38. dstack/_internal/server/models.py +33 -0
  39. dstack/_internal/server/routers/files.py +67 -0
  40. dstack/_internal/server/routers/secrets.py +57 -15
  41. dstack/_internal/server/schemas/files.py +5 -0
  42. dstack/_internal/server/schemas/runner.py +2 -0
  43. dstack/_internal/server/schemas/secrets.py +7 -11
  44. dstack/_internal/server/services/backends/__init__.py +1 -1
  45. dstack/_internal/server/services/files.py +91 -0
  46. dstack/_internal/server/services/fleets.py +5 -4
  47. dstack/_internal/server/services/gateways/__init__.py +4 -2
  48. dstack/_internal/server/services/jobs/__init__.py +19 -8
  49. dstack/_internal/server/services/jobs/configurators/base.py +25 -3
  50. dstack/_internal/server/services/jobs/configurators/dev.py +3 -3
  51. dstack/_internal/server/services/locking.py +101 -12
  52. dstack/_internal/server/services/proxy/repo.py +3 -0
  53. dstack/_internal/server/services/runner/client.py +8 -0
  54. dstack/_internal/server/services/runs.py +76 -47
  55. dstack/_internal/server/services/secrets.py +204 -0
  56. dstack/_internal/server/services/storage/base.py +21 -0
  57. dstack/_internal/server/services/storage/gcs.py +28 -6
  58. dstack/_internal/server/services/storage/s3.py +27 -9
  59. dstack/_internal/server/services/volumes.py +2 -2
  60. dstack/_internal/server/settings.py +19 -5
  61. dstack/_internal/server/statics/index.html +1 -1
  62. dstack/_internal/server/statics/{main-a4eafa74304e587d037c.js → main-d1ac2e8c38ed5f08a114.js} +86 -34
  63. dstack/_internal/server/statics/{main-a4eafa74304e587d037c.js.map → main-d1ac2e8c38ed5f08a114.js.map} +1 -1
  64. dstack/_internal/server/statics/{main-f53d6d0d42f8d61df1de.css → main-d58fc0460cb0eae7cb5c.css} +1 -1
  65. dstack/_internal/server/statics/static/media/google.b194b06fafd0a52aeb566922160ea514.svg +1 -0
  66. dstack/_internal/server/testing/common.py +50 -8
  67. dstack/_internal/settings.py +4 -0
  68. dstack/_internal/utils/files.py +69 -0
  69. dstack/_internal/utils/nested_list.py +47 -0
  70. dstack/_internal/utils/path.py +12 -4
  71. dstack/api/_public/runs.py +67 -7
  72. dstack/api/server/__init__.py +6 -0
  73. dstack/api/server/_files.py +18 -0
  74. dstack/api/server/_secrets.py +15 -15
  75. dstack/version.py +1 -1
  76. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/METADATA +13 -13
  77. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/RECORD +80 -67
  78. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/WHEEL +0 -0
  79. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/entry_points.txt +0 -0
  80. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/server/services/locking.py

@@ -1,8 +1,10 @@
 import asyncio
+import collections.abc
 import hashlib
+from abc import abstractmethod
 from asyncio import Lock
 from contextlib import asynccontextmanager
-from typing import AsyncGenerator, Dict, List, Set, Tuple, TypeVar, Union
+from typing import AsyncGenerator, Iterable, Iterator, Protocol, TypeVar, Union

 from sqlalchemy import func, select
 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
@@ -10,23 +12,54 @@ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
 KeyT = TypeVar("KeyT")


-class ResourceLocker:
-    def __init__(self):
-        self.namespace_to_locks_map: Dict[str, Tuple[Lock, set]] = {}
+class LocksetLock(Protocol):
+    async def acquire(self) -> bool: ...
+    def release(self) -> None: ...
+    async def __aenter__(self): ...
+    async def __aexit__(self, exc_type, exc, tb): ...
+
+
+T = TypeVar("T")
+

-    def get_lockset(self, namespace: str) -> Tuple[Lock, set]:
+class Lockset(Protocol[T]):
+    def __contains__(self, item: T) -> bool: ...
+    def __iter__(self) -> Iterator[T]: ...
+    def __len__(self) -> int: ...
+    def add(self, item: T) -> None: ...
+    def discard(self, item: T) -> None: ...
+    def update(self, other: Iterable[T]) -> None: ...
+    def difference_update(self, other: Iterable[T]) -> None: ...
+
+
+class ResourceLocker:
+    @abstractmethod
+    def get_lockset(self, namespace: str) -> tuple[LocksetLock, Lockset]:
         """
         Returns a lockset containing locked resources for in-memory locking.
         Also returns a lock that guards the lockset.
         """
-        return self.namespace_to_locks_map.setdefault(namespace, (Lock(), set()))
+        pass

+    @abstractmethod
     @asynccontextmanager
-    async def lock_ctx(self, namespace: str, keys: List[KeyT]):
+    async def lock_ctx(self, namespace: str, keys: list[KeyT]):
         """
         Acquires locks for all keys in namespace.
         The keys must be sorted to prevent deadlock.
         """
+        yield
+
+
+class InMemoryResourceLocker(ResourceLocker):
+    def __init__(self):
+        self.namespace_to_locks_map: dict[str, tuple[Lock, set]] = {}
+
+    def get_lockset(self, namespace: str) -> tuple[Lock, set]:
+        return self.namespace_to_locks_map.setdefault(namespace, (Lock(), set()))
+
+    @asynccontextmanager
+    async def lock_ctx(self, namespace: str, keys: list[KeyT]):
         lock, lockset = self.get_lockset(namespace)
         try:
             await _wait_to_lock_many(lock, lockset, keys)
@@ -35,6 +68,56 @@ class ResourceLocker:
         lockset.difference_update(keys)


+class DummyAsyncLock:
+    async def __aenter__(self):
+        pass
+
+    async def __aexit__(self, exc_type, exc, tb):
+        pass
+
+    async def acquire(self):
+        return True
+
+    def release(self):
+        pass
+
+
+class DummySet(collections.abc.MutableSet):
+    def __contains__(self, item):
+        return False
+
+    def __iter__(self):
+        return iter(())
+
+    def __len__(self):
+        return 0
+
+    def add(self, value):
+        pass
+
+    def discard(self, value):
+        pass
+
+    def update(self, other):
+        pass
+
+    def difference_update(self, other):
+        pass
+
+
+class DummyResourceLocker(ResourceLocker):
+    def __init__(self):
+        self.lock = DummyAsyncLock()
+        self.lockset = DummySet()
+
+    def get_lockset(self, namespace: str) -> tuple[DummyAsyncLock, DummySet]:
+        return self.lock, self.lockset
+
+    @asynccontextmanager
+    async def lock_ctx(self, namespace: str, keys: list[KeyT]):
+        yield
+
+
 def string_to_lock_id(s: str) -> int:
     return int(hashlib.sha256(s.encode()).hexdigest(), 16) % (2**63)

@@ -67,15 +150,21 @@ async def try_advisory_lock_ctx(
     await bind.execute(select(func.pg_advisory_unlock(string_to_lock_id(resource))))


-_locker = ResourceLocker()
+_in_memory_locker = InMemoryResourceLocker()
+_dummy_locker = DummyResourceLocker()


-def get_locker() -> ResourceLocker:
-    return _locker
+def get_locker(dialect_name: str) -> ResourceLocker:
+    if dialect_name == "sqlite":
+        return _in_memory_locker
+    # We could use an in-memory locker on Postgres
+    # but it can lead to unnecessary lock contention,
+    # so we use a dummy locker that does not take any locks.
+    return _dummy_locker


 async def _wait_to_lock_many(
-    lock: asyncio.Lock, locked: Set[KeyT], keys: List[KeyT], *, delay: float = 0.1
+    lock: asyncio.Lock, locked: set[KeyT], keys: list[KeyT], *, delay: float = 0.1
 ):
     """
     Retry locking until all the keys are locked.
@@ -88,7 +177,7 @@ async def _wait_to_lock_many(
         locked_now_num = 0
         for key in left_to_lock:
             if key in locked:
-                # Someone already aquired the lock, wait
+                # Someone already acquired the lock, wait
                 break
             locked.add(key)
             locked_now_num += 1
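
The refactor above splits the old ResourceLocker into a Protocol-based interface with an in-memory implementation for SQLite and a no-op implementation for Postgres, where callers rely on database row locks instead. A minimal usage sketch (simplified from the server code; the "runs" namespace string and the surrounding coroutine are illustrative):

```python
from dstack._internal.server.services.locking import get_locker

async def update_runs(dialect_name: str, run_ids: list) -> None:
    # On SQLite, lock_ctx blocks until every key is locked in-process.
    # On Postgres, get_locker returns DummyResourceLocker, whose lock_ctx
    # yields immediately; SELECT ... FOR UPDATE takes the real locks.
    locker = get_locker(dialect_name)
    async with locker.lock_ctx("runs", sorted(run_ids)):  # keys sorted to avoid deadlock
        ...  # read and mutate the locked resources
```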
dstack/_internal/server/services/proxy/repo.py

@@ -7,6 +7,7 @@ from sqlalchemy.orm import joinedload

 import dstack._internal.server.services.jobs as jobs_services
 from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
+from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.configurations import ServiceConfiguration
 from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
 from dstack._internal.core.models.runs import (
@@ -86,6 +87,8 @@ class ServerProxyRepo(BaseProxyRepo):
                username=jpd.username,
                port=jpd.ssh_port,
            )
+        if jpd.backend == BackendType.LOCAL:
+            ssh_proxy = None
         ssh_head_proxy: Optional[SSHConnectionParams] = None
         ssh_head_proxy_private_key: Optional[str] = None
         instance = get_or_error(job.instance)
dstack/_internal/server/services/runner/client.py

@@ -109,6 +109,14 @@ class RunnerClient:
         )
         resp.raise_for_status()

+    def upload_archive(self, id: uuid.UUID, file: Union[BinaryIO, bytes]):
+        resp = requests.post(
+            self._url("/api/upload_archive"),
+            files={"archive": (str(id), file)},
+            timeout=UPLOAD_CODE_REQUEST_TIMEOUT,
+        )
+        resp.raise_for_status()
+
     def upload_code(self, file: Union[BinaryIO, bytes]):
         resp = requests.post(
             self._url("/api/upload_code"), data=file, timeout=UPLOAD_CODE_REQUEST_TIMEOUT
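
Note that the new upload_archive method sends the archive as a multipart form field named "archive", using the archive ID as the filename, whereas upload_code posts the raw body. A hypothetical call, assuming a runner_client constructed elsewhere and an illustrative local path:

```python
import uuid

archive_id = uuid.uuid4()  # in practice the server supplies the file archive's ID
with open("files.tar", "rb") as f:
    # POSTs multipart/form-data to /api/upload_archive on the runner
    runner_client.upload_archive(archive_id, f)
```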
dstack/_internal/server/services/runs.py

@@ -82,6 +82,7 @@ from dstack._internal.server.services.offers import get_offers_by_requirements
 from dstack._internal.server.services.plugins import apply_plugin_policies
 from dstack._internal.server.services.projects import list_project_models, list_user_project_models
 from dstack._internal.server.services.resources import set_resources_defaults
+from dstack._internal.server.services.secrets import get_project_secrets_mapping
 from dstack._internal.server.services.users import get_user_model_by_name
 from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.random_names import generate_name
@@ -311,7 +312,12 @@ async def get_plan(
     ):
         action = ApplyAction.UPDATE

-    jobs = await get_jobs_from_run_spec(effective_run_spec, replica_num=0)
+    secrets = await get_project_secrets_mapping(session=session, project=project)
+    jobs = await get_jobs_from_run_spec(
+        run_spec=effective_run_spec,
+        secrets=secrets,
+        replica_num=0,
+    )

     volumes = await get_job_configured_volumes(
         session=session,
@@ -462,6 +468,10 @@ async def submit_run(
         project=project,
         run_spec=run_spec,
     )
+    secrets = await get_project_secrets_mapping(
+        session=session,
+        project=project,
+    )

     lock_namespace = f"run_names_{project.name}"
     if get_db().dialect_name == "sqlite":
@@ -472,8 +482,9 @@ async def submit_run(
            select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
        )

-    lock, _ = get_locker().get_lockset(lock_namespace)
+    lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
     async with lock:
+        # FIXME: delete_runs commits, so Postgres lock is released too early.
         if run_spec.run_name is None:
             run_spec.run_name = await _generate_run_name(
                 session=session,
@@ -513,7 +524,11 @@ async def submit_run(
         await services.register_service(session, run_model, run_spec)

     for replica_num in range(replicas):
-        jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num)
+        jobs = await get_jobs_from_run_spec(
+            run_spec=run_spec,
+            secrets=secrets,
+            replica_num=replica_num,
+        )
         for job in jobs:
             job_model = create_job_model_for_new_submission(
                 run_model=run_model,
@@ -572,46 +587,29 @@ async def stop_runs(
     )
     run_models = res.scalars().all()
     run_ids = sorted([r.id for r in run_models])
-    res = await session.execute(select(JobModel).where(JobModel.run_id.in_(run_ids)))
-    job_models = res.scalars().all()
-    job_ids = sorted([j.id for j in job_models])
     await session.commit()
-    async with (
-        get_locker().lock_ctx(RunModel.__tablename__, run_ids),
-        get_locker().lock_ctx(JobModel.__tablename__, job_ids),
-    ):
+    async with get_locker(get_db().dialect_name).lock_ctx(RunModel.__tablename__, run_ids):
+        res = await session.execute(
+            select(RunModel)
+            .where(RunModel.id.in_(run_ids))
+            .order_by(RunModel.id)  # take locks in order
+            .with_for_update(key_share=True)
+            .execution_options(populate_existing=True)
+        )
+        run_models = res.scalars().all()
+        now = common_utils.get_current_datetime()
         for run_model in run_models:
-            await stop_run(session=session, run_model=run_model, abort=abort)
-
-
-async def stop_run(session: AsyncSession, run_model: RunModel, abort: bool):
-    res = await session.execute(
-        select(RunModel)
-        .where(RunModel.id == run_model.id)
-        .order_by(RunModel.id)  # take locks in order
-        .with_for_update(key_share=True)
-        .execution_options(populate_existing=True)
-    )
-    run_model = res.scalar_one()
-    await session.execute(
-        select(JobModel)
-        .where(JobModel.run_id == run_model.id)
-        .order_by(JobModel.id)  # take locks in order
-        .with_for_update(key_share=True)
-        .execution_options(populate_existing=True)
-    )
-    if run_model.status.is_finished():
-        return
-    run_model.status = RunStatus.TERMINATING
-    if abort:
-        run_model.termination_reason = RunTerminationReason.ABORTED_BY_USER
-    else:
-        run_model.termination_reason = RunTerminationReason.STOPPED_BY_USER
-    # process the run out of turn
-    logger.debug("%s: terminating because %s", fmt(run_model), run_model.termination_reason.name)
-    await process_terminating_run(session, run_model)
-    run_model.last_processed_at = common_utils.get_current_datetime()
-    await session.commit()
+            if run_model.status.is_finished():
+                continue
+            run_model.status = RunStatus.TERMINATING
+            if abort:
+                run_model.termination_reason = RunTerminationReason.ABORTED_BY_USER
+            else:
+                run_model.termination_reason = RunTerminationReason.STOPPED_BY_USER
+            run_model.last_processed_at = now
+        # The run will be terminated by process_runs.
+        # Terminating synchronously is problematic since it may take a long time.
+        await session.commit()


@@ -628,7 +626,7 @@ async def delete_runs(
     run_models = res.scalars().all()
     run_ids = sorted([r.id for r in run_models])
     await session.commit()
-    async with get_locker().lock_ctx(RunModel.__tablename__, run_ids):
+    async with get_locker(get_db().dialect_name).lock_ctx(RunModel.__tablename__, run_ids):
         res = await session.execute(
             select(RunModel)
             .where(RunModel.id.in_(run_ids))
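
stop_runs no longer locks jobs in memory or terminates runs inline; it takes row locks on the runs themselves and defers the actual termination to the process_runs background task. A sketch of the row-locking idiom used in both hunks above (model and session are stand-ins; with_for_update(key_share=True) emits FOR NO KEY UPDATE on Postgres, while SQLAlchemy's SQLite dialect omits the clause):

```python
from sqlalchemy import select

async def lock_rows(session, model, ids):
    res = await session.execute(
        select(model)
        .where(model.id.in_(ids))
        .order_by(model.id)  # acquire row locks in a stable order to avoid deadlock
        .with_for_update(key_share=True)
        .execution_options(populate_existing=True)  # refresh rows already in the session
    )
    return res.scalars().all()
```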
@@ -898,7 +896,16 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
     set_resources_defaults(run_spec.configuration.resources)


-_UPDATABLE_SPEC_FIELDS = ["repo_code_hash", "configuration"]
+_UPDATABLE_SPEC_FIELDS = ["configuration_path", "configuration"]
+_TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS = {
+    "service": [
+        # rolling deployment
+        "repo_data",
+        "repo_code_hash",
+        "file_archives",
+        "working_dir",
+    ],
+}
 _CONF_UPDATABLE_FIELDS = ["priority"]
 _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS = {
     "dev-environment": ["inactivity_duration"],
@@ -909,10 +916,13 @@ _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS = {
         # rolling deployment
         "resources",
         "volumes",
+        "docker",
+        "files",
         "image",
         "user",
         "privileged",
         "entrypoint",
+        "working_dir",
         "python",
         "nvcc",
         "single_branch",
@@ -935,11 +945,14 @@ def _can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> bo
 def _check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec):
     spec_diff = diff_models(current_run_spec, new_run_spec)
     changed_spec_fields = list(spec_diff.keys())
+    updatable_spec_fields = _UPDATABLE_SPEC_FIELDS + _TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS.get(
+        new_run_spec.configuration.type, []
+    )
     for key in changed_spec_fields:
-        if key not in _UPDATABLE_SPEC_FIELDS:
+        if key not in updatable_spec_fields:
             raise ServerClientError(
                 f"Failed to update fields {changed_spec_fields}."
-                f" Can only update {_UPDATABLE_SPEC_FIELDS}."
+                f" Can only update {updatable_spec_fields}."
             )
     _check_can_update_configuration(current_run_spec.configuration, new_run_spec.configuration)

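The update gate above composes a base allowlist with per-configuration-type extensions, so services may now change repo data, code hash, file archives, and working directory in place for rolling deployments. A self-contained sketch of the same gating logic (plain dicts stand in for the RunSpec models and dstack's diff_models helper):

```python
BASE_FIELDS = ["configuration_path", "configuration"]
TYPE_SPECIFIC_FIELDS = {
    "service": ["repo_data", "repo_code_hash", "file_archives", "working_dir"],
}

def check_can_update(conf_type: str, old: dict, new: dict) -> None:
    changed = [k for k in old if old.get(k) != new.get(k)]
    allowed = BASE_FIELDS + TYPE_SPECIFIC_FIELDS.get(conf_type, [])
    blocked = [k for k in changed if k not in allowed]
    if blocked:
        raise ValueError(f"Failed to update fields {blocked}. Can only update {allowed}.")

check_can_update("service", {"repo_code_hash": "a"}, {"repo_code_hash": "b"})  # allowed
# check_can_update("task", {"repo_code_hash": "a"}, {"repo_code_hash": "b"})  # would raise
```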
@@ -1068,10 +1081,20 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
             await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False)
             scheduled_replicas += 1

+    secrets = await get_project_secrets_mapping(
+        session=session,
+        project=run_model.project,
+    )
+
     for replica_num in range(
         len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff
     ):
-        jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num)
+        # FIXME: Handle getting image configuration errors or skip it.
+        jobs = await get_jobs_from_run_spec(
+            run_spec=run_spec,
+            secrets=secrets,
+            replica_num=replica_num,
+        )
         for job in jobs:
             job_model = create_job_model_for_new_submission(
                 run_model=run_model,
@@ -1084,8 +1107,14 @@
 async def retry_run_replica_jobs(
     session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool
 ):
+    # FIXME: Handle getting image configuration errors or skip it.
+    secrets = await get_project_secrets_mapping(
+        session=session,
+        project=run_model.project,
+    )
     new_jobs = await get_jobs_from_run_spec(
-        RunSpec.__response__.parse_raw(run_model.run_spec),
+        run_spec=RunSpec.__response__.parse_raw(run_model.run_spec),
+        secrets=secrets,
         replica_num=latest_jobs[0].replica_num,
     )
     assert len(new_jobs) == len(latest_jobs), (
dstack/_internal/server/services/secrets.py (new file)

@@ -0,0 +1,204 @@
+import re
+from typing import Dict, List, Optional
+
+import sqlalchemy.exc
+from sqlalchemy import delete, select, update
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.core.errors import (
+    ResourceExistsError,
+    ResourceNotExistsError,
+    ServerClientError,
+)
+from dstack._internal.core.models.secrets import Secret
+from dstack._internal.server.models import DecryptedString, ProjectModel, SecretModel
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+_SECRET_NAME_REGEX = "^[A-Za-z0-9-_]{1,200}$"
+_SECRET_VALUE_MAX_LENGTH = 2000
+
+
+async def list_secrets(
+    session: AsyncSession,
+    project: ProjectModel,
+) -> List[Secret]:
+    secret_models = await list_project_secret_models(session=session, project=project)
+    return [secret_model_to_secret(s, include_value=False) for s in secret_models]
+
+
+async def get_project_secrets_mapping(
+    session: AsyncSession,
+    project: ProjectModel,
+) -> Dict[str, str]:
+    secret_models = await list_project_secret_models(session=session, project=project)
+    return {s.name: s.value.get_plaintext_or_error() for s in secret_models}
+
+
+async def get_secret(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+) -> Optional[Secret]:
+    secret_model = await get_project_secret_model_by_name(
+        session=session,
+        project=project,
+        name=name,
+    )
+    if secret_model is None:
+        return None
+    return secret_model_to_secret(secret_model, include_value=True)
+
+
+async def create_or_update_secret(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+    value: str,
+) -> Secret:
+    _validate_secret(name=name, value=value)
+    try:
+        secret_model = await create_secret(
+            session=session,
+            project=project,
+            name=name,
+            value=value,
+        )
+    except ResourceExistsError:
+        secret_model = await update_secret(
+            session=session,
+            project=project,
+            name=name,
+            value=value,
+        )
+    return secret_model_to_secret(secret_model, include_value=True)
+
+
+async def delete_secrets(
+    session: AsyncSession,
+    project: ProjectModel,
+    names: List[str],
+):
+    existing_secrets_query = await session.execute(
+        select(SecretModel).where(
+            SecretModel.project_id == project.id,
+            SecretModel.name.in_(names),
+        )
+    )
+    existing_names = [s.name for s in existing_secrets_query.scalars().all()]
+    missing_names = set(names) - set(existing_names)
+    if missing_names:
+        raise ResourceNotExistsError(f"Secrets not found: {', '.join(missing_names)}")
+
+    await session.execute(
+        delete(SecretModel).where(
+            SecretModel.project_id == project.id,
+            SecretModel.name.in_(names),
+        )
+    )
+    await session.commit()
+    logger.info("Deleted secrets %s in project %s", names, project.name)
+
+
+def secret_model_to_secret(secret_model: SecretModel, include_value: bool = False) -> Secret:
+    value = None
+    if include_value:
+        value = secret_model.value.get_plaintext_or_error()
+    return Secret(
+        id=secret_model.id,
+        name=secret_model.name,
+        value=value,
+    )
+
+
+async def list_project_secret_models(
+    session: AsyncSession,
+    project: ProjectModel,
+) -> List[SecretModel]:
+    res = await session.execute(
+        select(SecretModel)
+        .where(
+            SecretModel.project_id == project.id,
+        )
+        .order_by(SecretModel.created_at.desc())
+    )
+    secret_models = list(res.scalars().all())
+    return secret_models
+
+
+async def get_project_secret_model_by_name(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+) -> Optional[SecretModel]:
+    res = await session.execute(
+        select(SecretModel).where(
+            SecretModel.project_id == project.id,
+            SecretModel.name == name,
+        )
+    )
+    return res.scalar_one_or_none()
+
+
+async def create_secret(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+    value: str,
+) -> SecretModel:
+    secret_model = SecretModel(
+        project_id=project.id,
+        name=name,
+        value=DecryptedString(plaintext=value),
+    )
+    try:
+        async with session.begin_nested():
+            session.add(secret_model)
+    except sqlalchemy.exc.IntegrityError:
+        raise ResourceExistsError()
+    await session.commit()
+    return secret_model
+
+
+async def update_secret(
+    session: AsyncSession,
+    project: ProjectModel,
+    name: str,
+    value: str,
+) -> SecretModel:
+    await session.execute(
+        update(SecretModel)
+        .where(
+            SecretModel.project_id == project.id,
+            SecretModel.name == name,
+        )
+        .values(
+            value=DecryptedString(plaintext=value),
+        )
+    )
+    await session.commit()
+    secret_model = await get_project_secret_model_by_name(
+        session=session,
+        project=project,
+        name=name,
+    )
+    if secret_model is None:
+        raise ResourceNotExistsError()
+    return secret_model
+
+
+def _validate_secret(name: str, value: str):
+    _validate_secret_name(name)
+    _validate_secret_value(value)
+
+
+def _validate_secret_name(name: str):
+    if re.match(_SECRET_NAME_REGEX, name) is None:
+        raise ServerClientError(f"Secret name should match regex '{_SECRET_NAME_REGEX}")
+
+
+def _validate_secret_value(value: str):
+    if len(value) > _SECRET_VALUE_MAX_LENGTH:
+        raise ServerClientError(f"Secret value length must not exceed {_SECRET_VALUE_MAX_LENGTH}")
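
The validation rules in the new module are simple: names must match ^[A-Za-z0-9-_]{1,200}$ and values are capped at 2000 characters. A quick standalone check mirroring them:

```python
import re

def is_valid_secret(name: str, value: str) -> bool:
    # Same constraints as _validate_secret_name / _validate_secret_value above.
    return re.match(r"^[A-Za-z0-9-_]{1,200}$", name) is not None and len(value) <= 2000

assert is_valid_secret("HF_TOKEN", "hf_xxx")
assert not is_valid_secret("bad name!", "x")  # spaces and '!' are rejected
```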
dstack/_internal/server/services/storage/base.py

@@ -22,6 +22,27 @@ class BaseStorage(ABC):
     ) -> Optional[bytes]:
         pass

+    @abstractmethod
+    def upload_archive(
+        self,
+        user_id: str,
+        archive_hash: str,
+        blob: bytes,
+    ):
+        pass
+
+    @abstractmethod
+    def get_archive(
+        self,
+        user_id: str,
+        archive_hash: str,
+    ) -> Optional[bytes]:
+        pass
+
     @staticmethod
     def _get_code_key(project_id: str, repo_id: str, code_hash: str) -> str:
         return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}"
+
+    @staticmethod
+    def _get_archive_key(user_id: str, archive_hash: str) -> str:
+        return f"data/users/{user_id}/file_archives/{archive_hash}"
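
The new file archive blobs are keyed per user, parallel to the existing per-project code blobs. The resulting object layout, illustrated with hypothetical IDs:

```python
def code_key(project_id: str, repo_id: str, code_hash: str) -> str:
    return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}"

def archive_key(user_id: str, archive_hash: str) -> str:
    return f"data/users/{user_id}/file_archives/{archive_hash}"

# archive_key("9f6e", "5d41402a") -> "data/users/9f6e/file_archives/5d41402a"
```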
dstack/_internal/server/services/storage/gcs.py

@@ -25,9 +25,8 @@ class GCSStorage(BaseStorage):
         code_hash: str,
         blob: bytes,
     ):
-        blob_name = self._get_code_key(project_id, repo_id, code_hash)
-        blob_obj = self._bucket.blob(blob_name)
-        blob_obj.upload_from_string(blob)
+        key = self._get_code_key(project_id, repo_id, code_hash)
+        self._upload(key, blob)

     def get_code(
         self,
@@ -35,10 +34,33 @@ class GCSStorage(BaseStorage):
         repo_id: str,
         code_hash: str,
     ) -> Optional[bytes]:
+        key = self._get_code_key(project_id, repo_id, code_hash)
+        return self._get(key)
+
+    def upload_archive(
+        self,
+        user_id: str,
+        archive_hash: str,
+        blob: bytes,
+    ):
+        key = self._get_archive_key(user_id, archive_hash)
+        self._upload(key, blob)
+
+    def get_archive(
+        self,
+        user_id: str,
+        archive_hash: str,
+    ) -> Optional[bytes]:
+        key = self._get_archive_key(user_id, archive_hash)
+        return self._get(key)
+
+    def _upload(self, key: str, blob: bytes):
+        blob_obj = self._bucket.blob(key)
+        blob_obj.upload_from_string(blob)
+
+    def _get(self, key: str) -> Optional[bytes]:
         try:
-            blob_name = self._get_code_key(project_id, repo_id, code_hash)
-            blob = self._bucket.blob(blob_name)
+            blob = self._bucket.blob(key)
         except NotFound:
             return None
-
         return blob.download_as_bytes()