dstack 0.19.16 → 0.19.18 (py3-none-any.whl)

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release: this version of dstack might be problematic.
Files changed (80)
  1. dstack/_internal/cli/commands/secrets.py +92 -0
  2. dstack/_internal/cli/main.py +2 -0
  3. dstack/_internal/cli/services/completion.py +5 -0
  4. dstack/_internal/cli/services/configurators/fleet.py +13 -1
  5. dstack/_internal/cli/services/configurators/run.py +59 -17
  6. dstack/_internal/cli/utils/secrets.py +25 -0
  7. dstack/_internal/core/backends/__init__.py +10 -4
  8. dstack/_internal/core/backends/aws/compute.py +237 -18
  9. dstack/_internal/core/backends/base/compute.py +20 -2
  10. dstack/_internal/core/backends/cudo/compute.py +23 -9
  11. dstack/_internal/core/backends/gcp/compute.py +13 -7
  12. dstack/_internal/core/backends/lambdalabs/compute.py +2 -1
  13. dstack/_internal/core/compatibility/fleets.py +12 -11
  14. dstack/_internal/core/compatibility/gateways.py +9 -8
  15. dstack/_internal/core/compatibility/logs.py +4 -3
  16. dstack/_internal/core/compatibility/runs.py +41 -17
  17. dstack/_internal/core/compatibility/volumes.py +9 -8
  18. dstack/_internal/core/errors.py +4 -0
  19. dstack/_internal/core/models/common.py +7 -0
  20. dstack/_internal/core/models/configurations.py +11 -0
  21. dstack/_internal/core/models/files.py +67 -0
  22. dstack/_internal/core/models/runs.py +14 -0
  23. dstack/_internal/core/models/secrets.py +9 -2
  24. dstack/_internal/core/services/diff.py +36 -3
  25. dstack/_internal/server/app.py +22 -0
  26. dstack/_internal/server/background/__init__.py +61 -37
  27. dstack/_internal/server/background/tasks/process_fleets.py +19 -3
  28. dstack/_internal/server/background/tasks/process_gateways.py +1 -1
  29. dstack/_internal/server/background/tasks/process_instances.py +13 -2
  30. dstack/_internal/server/background/tasks/process_placement_groups.py +4 -2
  31. dstack/_internal/server/background/tasks/process_running_jobs.py +123 -15
  32. dstack/_internal/server/background/tasks/process_runs.py +23 -7
  33. dstack/_internal/server/background/tasks/process_submitted_jobs.py +36 -7
  34. dstack/_internal/server/background/tasks/process_terminating_jobs.py +5 -3
  35. dstack/_internal/server/background/tasks/process_volumes.py +2 -2
  36. dstack/_internal/server/migrations/versions/5f1707c525d2_add_filearchivemodel.py +39 -0
  37. dstack/_internal/server/migrations/versions/644b8a114187_add_secretmodel.py +49 -0
  38. dstack/_internal/server/models.py +33 -0
  39. dstack/_internal/server/routers/files.py +67 -0
  40. dstack/_internal/server/routers/secrets.py +57 -15
  41. dstack/_internal/server/schemas/files.py +5 -0
  42. dstack/_internal/server/schemas/runner.py +2 -0
  43. dstack/_internal/server/schemas/secrets.py +7 -11
  44. dstack/_internal/server/services/backends/__init__.py +1 -1
  45. dstack/_internal/server/services/files.py +91 -0
  46. dstack/_internal/server/services/fleets.py +5 -4
  47. dstack/_internal/server/services/gateways/__init__.py +4 -2
  48. dstack/_internal/server/services/jobs/__init__.py +19 -8
  49. dstack/_internal/server/services/jobs/configurators/base.py +25 -3
  50. dstack/_internal/server/services/jobs/configurators/dev.py +3 -3
  51. dstack/_internal/server/services/locking.py +101 -12
  52. dstack/_internal/server/services/proxy/repo.py +3 -0
  53. dstack/_internal/server/services/runner/client.py +8 -0
  54. dstack/_internal/server/services/runs.py +76 -47
  55. dstack/_internal/server/services/secrets.py +204 -0
  56. dstack/_internal/server/services/storage/base.py +21 -0
  57. dstack/_internal/server/services/storage/gcs.py +28 -6
  58. dstack/_internal/server/services/storage/s3.py +27 -9
  59. dstack/_internal/server/services/volumes.py +2 -2
  60. dstack/_internal/server/settings.py +19 -5
  61. dstack/_internal/server/statics/index.html +1 -1
  62. dstack/_internal/server/statics/{main-a4eafa74304e587d037c.js → main-d1ac2e8c38ed5f08a114.js} +86 -34
  63. dstack/_internal/server/statics/{main-a4eafa74304e587d037c.js.map → main-d1ac2e8c38ed5f08a114.js.map} +1 -1
  64. dstack/_internal/server/statics/{main-f53d6d0d42f8d61df1de.css → main-d58fc0460cb0eae7cb5c.css} +1 -1
  65. dstack/_internal/server/statics/static/media/google.b194b06fafd0a52aeb566922160ea514.svg +1 -0
  66. dstack/_internal/server/testing/common.py +50 -8
  67. dstack/_internal/settings.py +4 -0
  68. dstack/_internal/utils/files.py +69 -0
  69. dstack/_internal/utils/nested_list.py +47 -0
  70. dstack/_internal/utils/path.py +12 -4
  71. dstack/api/_public/runs.py +67 -7
  72. dstack/api/server/__init__.py +6 -0
  73. dstack/api/server/_files.py +18 -0
  74. dstack/api/server/_secrets.py +15 -15
  75. dstack/version.py +1 -1
  76. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/METADATA +13 -13
  77. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/RECORD +80 -67
  78. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/WHEEL +0 -0
  79. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/entry_points.txt +0 -0
  80. {dstack-0.19.16.dist-info → dstack-0.19.18.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/server/background/tasks/process_instances.py

@@ -45,6 +45,7 @@ from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT
 from dstack._internal.core.errors import (
     BackendError,
     NotYetTerminated,
+    PlacementGroupNotSupportedError,
     ProvisioningError,
 )
 from dstack._internal.core.models.backends.base import BackendType
@@ -73,7 +74,7 @@ from dstack._internal.core.models.runs import (
 from dstack._internal.core.services.profiles import get_retry
 from dstack._internal.server import settings as server_settings
 from dstack._internal.server.background.tasks.common import get_provisioning_timeout
-from dstack._internal.server.db import get_session_ctx
+from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import (
     FleetModel,
     InstanceModel,
@@ -110,6 +111,8 @@ from dstack._internal.utils.ssh import (
     pkey_from_str,
 )
 
+MIN_PROCESSING_INTERVAL = timedelta(seconds=10)
+
 PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60)
 
 TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20)
@@ -129,7 +132,7 @@ async def process_instances(batch_size: int = 1):
 
 
 async def _process_next_instance():
-    lock, lockset = get_locker().get_lockset(InstanceModel.__tablename__)
+    lock, lockset = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__)
     async with get_session_ctx() as session:
         async with lock:
             res = await session.execute(
@@ -145,6 +148,8 @@ async def _process_next_instance():
                         ]
                     ),
                     InstanceModel.id.not_in(lockset),
+                    InstanceModel.last_processed_at
+                    < get_current_datetime().replace(tzinfo=None) - MIN_PROCESSING_INTERVAL,
                 )
                 .options(lazyload(InstanceModel.jobs))
                 .order_by(InstanceModel.last_processed_at.asc())
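Two patterns in the hunks above recur across every background task in this release: the locker is now selected per database dialect via get_locker(get_db().dialect_name), and a row is only picked up again once MIN_PROCESSING_INTERVAL has elapsed since last_processed_at. A minimal sketch of the resulting selection query, with simplified signatures standing in for the dstack internals (the real query also filters by status):

    # Sketch only: a generic version of the pick-next-row pattern above.
    # `Model` is any ORM class with `id` and `last_processed_at` columns.
    from datetime import datetime, timedelta

    from sqlalchemy import select

    MIN_PROCESSING_INTERVAL = timedelta(seconds=10)

    async def pick_next_row(session, Model, lockset: set, now: datetime):
        res = await session.execute(
            select(Model)
            .where(
                Model.id.not_in(lockset),  # skip rows locked in-process
                # Throttle: leave rows alone if they were processed recently.
                Model.last_processed_at < now - MIN_PROCESSING_INTERVAL,
            )
            .order_by(Model.last_processed_at.asc())  # oldest first
            .limit(1)
            # Row-level lock; concurrent workers skip instead of blocking.
            .with_for_update(skip_locked=True, key_share=True)
        )
        return res.unique().scalar()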
@@ -1063,6 +1068,12 @@ async def _create_placement_group(
             placement_group_model_to_placement_group(placement_group_model),
             master_instance_offer,
         )
+    except PlacementGroupNotSupportedError:
+        logger.debug(
+            "Skipping offer %s because placement group not supported",
+            master_instance_offer.instance.name,
+        )
+        return None
     except BackendError as e:
         logger.warning(
             "Failed to create placement group %s in %s/%s: %r",
dstack/_internal/server/background/tasks/process_placement_groups.py

@@ -7,7 +7,7 @@ from sqlalchemy.orm import joinedload
 
 from dstack._internal.core.backends.base.compute import ComputeWithPlacementGroupSupport
 from dstack._internal.core.errors import PlacementGroupInUseError
-from dstack._internal.server.db import get_session_ctx
+from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import PlacementGroupModel, ProjectModel
 from dstack._internal.server.services import backends as backends_services
 from dstack._internal.server.services.locking import get_locker
@@ -19,7 +19,9 @@ logger = get_logger(__name__)
 
 
 async def process_placement_groups():
-    lock, lockset = get_locker().get_lockset(PlacementGroupModel.__tablename__)
+    lock, lockset = get_locker(get_db().dialect_name).get_lockset(
+        PlacementGroupModel.__tablename__
+    )
     async with get_session_ctx() as session:
         async with lock:
             res = await session.execute(
dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -1,5 +1,6 @@
 import asyncio
 import re
+import uuid
 from collections.abc import Iterable
 from datetime import timedelta, timezone
 from typing import Dict, List, Optional
@@ -14,6 +15,7 @@ from dstack._internal.core.errors import GatewayError
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.common import NetworkMode, RegistryAuth
 from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
+from dstack._internal.core.models.files import FileArchiveMapping
 from dstack._internal.core.models.instances import (
     InstanceStatus,
     RemoteConnectionInfo,
@@ -32,18 +34,21 @@ from dstack._internal.core.models.runs import (
     JobTerminationReason,
     Run,
     RunSpec,
+    RunStatus,
 )
 from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint
 from dstack._internal.server.background.tasks.common import get_provisioning_timeout
-from dstack._internal.server.db import get_session_ctx
+from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import (
     InstanceModel,
     JobModel,
     ProjectModel,
     RepoModel,
     RunModel,
+    UserModel,
 )
 from dstack._internal.server.schemas.runner import GPUDevice, TaskStatus
+from dstack._internal.server.services import files as files_services
 from dstack._internal.server.services import logs as logs_services
 from dstack._internal.server.services import services
 from dstack._internal.server.services.instances import get_instance_ssh_private_keys
@@ -66,14 +71,16 @@ from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
 from dstack._internal.server.services.runs import (
     run_model_to_run,
 )
+from dstack._internal.server.services.secrets import get_project_secrets_mapping
 from dstack._internal.server.services.storage import get_default_storage
 from dstack._internal.utils import common as common_utils
-from dstack._internal.utils.interpolator import VariablesInterpolator
+from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
 from dstack._internal.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
 
+MIN_PROCESSING_INTERVAL = timedelta(seconds=10)
 # Minimum time before terminating active job in case of connectivity issues.
 # Should be sufficient to survive most problems caused by
 # the server network flickering and providers' glitches.
@@ -88,20 +95,29 @@ async def process_running_jobs(batch_size: int = 1):
 
 
 async def _process_next_running_job():
-    lock, lockset = get_locker().get_lockset(JobModel.__tablename__)
+    lock, lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__)
     async with get_session_ctx() as session:
         async with lock:
             res = await session.execute(
                 select(JobModel)
+                .join(JobModel.run)
                 .where(
                     JobModel.status.in_(
                         [JobStatus.PROVISIONING, JobStatus.PULLING, JobStatus.RUNNING]
                     ),
+                    RunModel.status.not_in([RunStatus.TERMINATING]),
                     JobModel.id.not_in(lockset),
+                    JobModel.last_processed_at
+                    < common_utils.get_current_datetime().replace(tzinfo=None)
+                    - MIN_PROCESSING_INTERVAL,
                 )
                 .order_by(JobModel.last_processed_at.asc())
                 .limit(1)
-                .with_for_update(skip_locked=True, key_share=True)
+                .with_for_update(
+                    skip_locked=True,
+                    key_share=True,
+                    of=JobModel,
+                )
             )
             job_model = res.unique().scalar()
             if job_model is None:
@@ -177,7 +193,17 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
         common_utils.get_or_error(job_model.instance)
     )
 
-    secrets = {}  # TODO secrets
+    secrets = await get_project_secrets_mapping(session=session, project=project)
+
+    try:
+        _interpolate_secrets(secrets, job.job_spec)
+    except InterpolatorError as e:
+        logger.info("%s: terminating due to secrets interpolation error", fmt(job_model))
+        job_model.status = JobStatus.TERMINATING
+        job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
+        job_model.termination_reason_message = e.args[0]
+        job_model.last_processed_at = common_utils.get_current_datetime()
+        return
 
     repo_creds_model = await get_repo_creds(session=session, repo=repo_model, user=run_model.user)
     repo_creds = repo_model_to_repo_head_with_creds(repo_model, repo_creds_model).repo_creds
@@ -214,7 +240,6 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
             job_model,
             job_provisioning_data,
             volumes,
-            secrets,
             job.job_spec.registry_auth,
             public_keys,
             ssh_user,
@@ -226,12 +251,20 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
             fmt(job_model),
             job_submission.age,
         )
+        # FIXME: downloading file archives and code here is a waste of time if
+        # the runner is not ready yet
+        file_archives = await _get_job_file_archives(
+            session=session,
+            archive_mappings=job.job_spec.file_archives,
+            user=run_model.user,
+        )
         code = await _get_job_code(
             session=session,
             project=project,
             repo=repo_model,
-            code_hash=run.run_spec.repo_code_hash,
+            code_hash=_get_repo_code_hash(run, job),
         )
+
         success = await common_utils.run_async(
             _submit_job_to_runner,
             server_ssh_private_keys,
@@ -242,6 +275,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
             job,
             cluster_info,
             code,
+            file_archives,
             secrets,
             repo_creds,
             success_if_not_available=False,
@@ -269,11 +303,18 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
         logger.debug(
             "%s: process pulling job with shim, age=%s", fmt(job_model), job_submission.age
         )
+        # FIXME: downloading file archives and code here is a waste of time if
+        # the runner is not ready yet
+        file_archives = await _get_job_file_archives(
+            session=session,
+            archive_mappings=job.job_spec.file_archives,
+            user=run_model.user,
+        )
         code = await _get_job_code(
             session=session,
             project=project,
             repo=repo_model,
-            code_hash=run.run_spec.repo_code_hash,
+            code_hash=_get_repo_code_hash(run, job),
         )
         success = await common_utils.run_async(
             _process_pulling_with_shim,
@@ -285,6 +326,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
             job,
             cluster_info,
             code,
+            file_archives,
             secrets,
             repo_creds,
             server_ssh_private_keys,
@@ -306,8 +348,9 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     else:
         if job_model.termination_reason:
             logger.warning(
-                "%s: failed because shim/runner returned an error, age=%s",
+                "%s: failed due to %s, age=%s",
                 fmt(job_model),
+                job_model.termination_reason.value,
                 job_submission.age,
             )
             job_model.status = JobStatus.TERMINATING
@@ -450,7 +493,6 @@ def _process_provisioning_with_shim(
     job_model: JobModel,
     job_provisioning_data: JobProvisioningData,
    volumes: List[Volume],
-    secrets: Dict[str, str],
     registry_auth: Optional[RegistryAuth],
     public_keys: List[str],
     ssh_user: str,
@@ -476,10 +518,8 @@ def _process_provisioning_with_shim(
     registry_username = ""
     registry_password = ""
     if registry_auth is not None:
-        logger.debug("%s: authenticating to the registry...", fmt(job_model))
-        interpolate = VariablesInterpolator({"secrets": secrets}).interpolate
-        registry_username = interpolate(registry_auth.username)
-        registry_password = interpolate(registry_auth.password)
+        registry_username = registry_auth.username
+        registry_password = registry_auth.password
 
     volume_mounts: List[VolumeMountPoint] = []
     instance_mounts: List[InstanceMountPoint] = []
@@ -588,6 +628,7 @@ def _process_pulling_with_shim(
     job: Job,
     cluster_info: ClusterInfo,
     code: bytes,
+    file_archives: Iterable[tuple[uuid.UUID, bytes]],
     secrets: Dict[str, str],
     repo_credentials: Optional[RemoteRepoCreds],
     server_ssh_private_keys: tuple[str, Optional[str]],
@@ -663,6 +704,7 @@ def _process_pulling_with_shim(
         job=job,
         cluster_info=cluster_info,
         code=code,
+        file_archives=file_archives,
         secrets=secrets,
         repo_credentials=repo_credentials,
         success_if_not_available=True,
@@ -826,6 +868,19 @@ def _get_cluster_info(
     return cluster_info
 
 
+def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]:
+    # TODO: drop this function when supporting jobs submitted before 0.19.17 is no longer relevant.
+    if (
+        job.job_spec.repo_code_hash is None
+        and run.run_spec.repo_code_hash is not None
+        and job.job_submissions[-1].deployment_num == run.deployment_num
+    ):
+        # The job spec does not have `repo_code_hash`, because it was submitted before 0.19.17.
+        # Use `repo_code_hash` from the run.
+        return run.run_spec.repo_code_hash
+    return job.job_spec.repo_code_hash
+
+
 async def _get_job_code(
     session: AsyncSession, project: ProjectModel, repo: RepoModel, code_hash: Optional[str]
 ) -> bytes:
@@ -853,6 +908,43 @@ async def _get_job_code(
     return blob
 
 
+async def _get_job_file_archives(
+    session: AsyncSession,
+    archive_mappings: Iterable[FileArchiveMapping],
+    user: UserModel,
+) -> list[tuple[uuid.UUID, bytes]]:
+    archives: list[tuple[uuid.UUID, bytes]] = []
+    for archive_mapping in archive_mappings:
+        archive_id = archive_mapping.id
+        archive_blob = await _get_job_file_archive(
+            session=session, archive_id=archive_id, user=user
+        )
+        archives.append((archive_id, archive_blob))
+    return archives
+
+
+async def _get_job_file_archive(
+    session: AsyncSession, archive_id: uuid.UUID, user: UserModel
+) -> bytes:
+    archive_model = await files_services.get_archive_model(session, id=archive_id, user=user)
+    if archive_model is None:
+        return b""
+    if archive_model.blob is not None:
+        return archive_model.blob
+    storage = get_default_storage()
+    if storage is None:
+        return b""
+    blob = await common_utils.run_async(
+        storage.get_archive,
+        str(archive_model.user_id),
+        archive_model.blob_hash,
+    )
+    if blob is None:
+        logger.error("Failed to get file archive %s from storage", archive_id)
+        return b""
+    return blob
+
+
 @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
 def _submit_job_to_runner(
     ports: Dict[int, int],
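_get_job_file_archive tries the inline database blob first and falls back to the default object storage, degrading to an empty payload rather than failing the job outright. The lookup order, condensed into a synchronous sketch (the real code awaits the DB and storage calls):

    from typing import Optional

    def load_archive_blob(archive_model, storage) -> bytes:
        if archive_model is None:
            return b""  # unknown archive id for this user
        if archive_model.blob is not None:
            return archive_model.blob  # blob stored inline in the database
        if storage is None:
            return b""  # no object storage configured
        blob: Optional[bytes] = storage.get_archive(
            str(archive_model.user_id), archive_model.blob_hash
        )
        return blob if blob is not None else b""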
@@ -861,6 +953,7 @@ def _submit_job_to_runner(
     job: Job,
     cluster_info: ClusterInfo,
     code: bytes,
+    file_archives: Iterable[tuple[uuid.UUID, bytes]],
     secrets: Dict[str, str],
     repo_credentials: Optional[RemoteRepoCreds],
     success_if_not_available: bool,
@@ -896,10 +989,15 @@ def _submit_job_to_runner(
         run=run,
         job=job,
         cluster_info=cluster_info,
-        secrets=secrets,
+        # Do not send all the secrets since interpolation is already done by the server.
+        # TODO: Passing secrets may be necessary for filtering out secret values from logs.
+        secrets={},
         repo_credentials=repo_credentials,
         instance_env=instance_env,
     )
+    logger.debug("%s: uploading file archive(s)", fmt(job_model))
+    for archive_id, archive in file_archives:
+        runner_client.upload_archive(archive_id, archive)
     logger.debug("%s: uploading code", fmt(job_model))
     runner_client.upload_code(code)
     logger.debug("%s: starting job", fmt(job_model))
@@ -911,6 +1009,16 @@ def _submit_job_to_runner(
     return True
 
 
+def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec):
+    interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error
+    job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()}
+    if job_spec.registry_auth is not None:
+        job_spec.registry_auth = RegistryAuth(
+            username=interpolate(job_spec.registry_auth.username),
+            password=interpolate(job_spec.registry_auth.password),
+        )
+
+
 def _get_instance_specific_mounts(
     backend_type: BackendType, instance_type_name: str
 ) -> List[InstanceMountPoint]:
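Together with the _process_running_job changes above, this moves ${{ secrets.* }} resolution from the runner to the server: env values and registry_auth are interpolated before submission, interpolate_or_error turns an unresolvable reference into a job termination, and the runner now receives an empty secrets map. A self-contained approximation of the substitution, with a regex standing in for dstack's VariablesInterpolator:

    import re

    def interpolate(template: str, secrets: dict) -> str:
        # Replace ${{ secrets.NAME }}; raise on unknown names, mirroring
        # the interpolate_or_error behavior.
        def repl(m: re.Match) -> str:
            name = m.group(1)
            if name not in secrets:
                raise ValueError(f"unknown secret: {name}")
            return secrets[name]
        return re.sub(r"\$\{\{\s*secrets\.([\w-]+)\s*\}\}", repl, template)

    secrets = {"hf_token": "hf_abc123"}
    env = {"HF_TOKEN": "${{ secrets.hf_token }}"}
    env = {k: interpolate(v, secrets) for k, v in env.items()}
    assert env == {"HF_TOKEN": "hf_abc123"}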
dstack/_internal/server/background/tasks/process_runs.py

@@ -19,7 +19,7 @@ from dstack._internal.core.models.runs import (
     RunStatus,
     RunTerminationReason,
 )
-from dstack._internal.server.db import get_session_ctx
+from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import JobModel, ProjectModel, RunModel
 from dstack._internal.server.services.jobs import (
     find_job,
@@ -35,11 +35,14 @@ from dstack._internal.server.services.runs import (
     run_model_to_run,
     scale_run_replicas,
 )
+from dstack._internal.server.services.secrets import get_project_secrets_mapping
 from dstack._internal.server.services.services import update_service_desired_replica_count
 from dstack._internal.utils import common
 from dstack._internal.utils.logging import get_logger
 
 logger = get_logger(__name__)
+
+MIN_PROCESSING_INTERVAL = datetime.timedelta(seconds=5)
 ROLLING_DEPLOYMENT_MAX_SURGE = 1  # at most one extra replica during rolling deployment
 
 
@@ -51,8 +54,8 @@ async def process_runs(batch_size: int = 1):
 
 
 async def _process_next_run():
-    run_lock, run_lockset = get_locker().get_lockset(RunModel.__tablename__)
-    job_lock, job_lockset = get_locker().get_lockset(JobModel.__tablename__)
+    run_lock, run_lockset = get_locker(get_db().dialect_name).get_lockset(RunModel.__tablename__)
+    job_lock, job_lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__)
     async with get_session_ctx() as session:
         async with run_lock, job_lock:
             res = await session.execute(
@@ -60,6 +63,8 @@ async def _process_next_run():
                 .where(
                     RunModel.status.not_in(RunStatus.finished_statuses()),
                     RunModel.id.not_in(run_lockset),
+                    RunModel.last_processed_at
+                    < common.get_current_datetime().replace(tzinfo=None) - MIN_PROCESSING_INTERVAL,
                 )
                 .order_by(RunModel.last_processed_at.asc())
                 .limit(1)
@@ -336,7 +341,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             current_time - run_model.submitted_at.replace(tzinfo=datetime.timezone.utc)
         ).total_seconds()
         logger.info(
-            "%s: run took %.2f seconds from submision to provisioning.",
+            "%s: run took %.2f seconds from submission to provisioning.",
             fmt(run_model),
             submit_to_provision_duration,
         )
@@ -404,7 +409,11 @@ async def _handle_run_replicas(
         )
         return
 
-    await _update_jobs_to_new_deployment_in_place(run_model, run_spec)
+    await _update_jobs_to_new_deployment_in_place(
+        session=session,
+        run_model=run_model,
+        run_spec=run_spec,
+    )
     if _has_out_of_date_replicas(run_model):
         non_terminated_replica_count = len(
             {j.replica_num for j in run_model.jobs if not j.status.is_finished()}
@@ -444,18 +453,25 @@ async def _handle_run_replicas(
     )
 
 
-async def _update_jobs_to_new_deployment_in_place(run_model: RunModel, run_spec: RunSpec) -> None:
+async def _update_jobs_to_new_deployment_in_place(
+    session: AsyncSession, run_model: RunModel, run_spec: RunSpec
+) -> None:
     """
     Bump deployment_num for jobs that do not require redeployment.
     """
-
+    secrets = await get_project_secrets_mapping(
+        session=session,
+        project=run_model.project,
+    )
     for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
         if all(j.status.is_finished() for j in job_models):
             continue
         if all(j.deployment_num == run_model.deployment_num for j in job_models):
             continue
+        # FIXME: Handle getting image configuration errors or skip it.
         new_job_specs = await get_job_specs_from_run_spec(
             run_spec=run_spec,
+            secrets=secrets,
             replica_num=replica_num,
         )
         assert len(new_job_specs) == len(job_models), (
dstack/_internal/server/background/tasks/process_submitted_jobs.py

@@ -1,5 +1,6 @@
 import asyncio
 import uuid
+from datetime import datetime, timedelta
 from typing import List, Optional, Tuple
 
 from sqlalchemy import select
@@ -80,15 +81,35 @@ from dstack._internal.utils.logging import get_logger
 logger = get_logger(__name__)
 
 
+# Track when we last processed a job.
+# This is needed for a trick:
+# If no tasks were processed recently, we force batch_size 1.
+# If there are lots of runs/jobs with same offers submitted,
+# we warm up the cache instead of requesting the offers concurrently.
+# Mostly useful when runs are submitted via API without getting run plan first.
+BATCH_SIZE_RESET_TIMEOUT = timedelta(minutes=2)
+last_processed_at: Optional[datetime] = None
+
+
 async def process_submitted_jobs(batch_size: int = 1):
     tasks = []
-    for _ in range(batch_size):
+    effective_batch_size = _get_effective_batch_size(batch_size)
+    for _ in range(effective_batch_size):
         tasks.append(_process_next_submitted_job())
     await asyncio.gather(*tasks)
 
 
+def _get_effective_batch_size(batch_size: int) -> int:
+    if (
+        last_processed_at is None
+        or last_processed_at < common_utils.get_current_datetime() - BATCH_SIZE_RESET_TIMEOUT
+    ):
+        return 1
+    return batch_size
+
+
 async def _process_next_submitted_job():
-    lock, lockset = get_locker().get_lockset(JobModel.__tablename__)
+    lock, lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__)
     async with get_session_ctx() as session:
         async with lock:
             res = await session.execute(
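The module-level last_processed_at implements the warm-up described in the comment block above: the first tick after an idle period processes a single job so it can populate the offers cache, and only subsequent ticks fan out to the full batch. The effect in a self-contained sketch:

    from datetime import datetime, timedelta, timezone
    from typing import Optional

    BATCH_SIZE_RESET_TIMEOUT = timedelta(minutes=2)
    last_processed_at: Optional[datetime] = None

    def effective_batch_size(batch_size: int, now: datetime) -> int:
        # After an idle period, force a batch of 1 to warm the offers cache.
        if last_processed_at is None or last_processed_at < now - BATCH_SIZE_RESET_TIMEOUT:
            return 1
        return batch_size

    now = datetime.now(timezone.utc)
    assert effective_batch_size(5, now) == 1  # cold start: single job
    last_processed_at = now - timedelta(seconds=30)  # recent activity
    assert effective_batch_size(5, now) == 5  # warm: full batch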
@@ -125,6 +146,8 @@ async def _process_next_submitted_job():
                 await _process_submitted_job(session=session, job_model=job_model)
             finally:
                 lockset.difference_update([job_model_id])
+                global last_processed_at
+                last_processed_at = common_utils.get_current_datetime()
 
 
 async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
@@ -214,7 +237,9 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
     if get_db().dialect_name == "sqlite":
         # Start new transaction to see committed changes after lock
         await session.commit()
-    async with get_locker().lock_ctx(InstanceModel.__tablename__, instances_ids):
+    async with get_locker(get_db().dialect_name).lock_ctx(
+        InstanceModel.__tablename__, instances_ids
+    ):
         # If another job freed the instance but is still trying to detach volumes,
         # do not provision on it to prevent attaching volumes that are currently detaching.
         detaching_instances_ids = await get_instances_ids_with_detaching_volumes(session)
@@ -243,8 +268,10 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
         )
         job_model.instance_assigned = True
         job_model.last_processed_at = common_utils.get_current_datetime()
-        await session.commit()
-        return
+        if len(pool_instances) > 0:
+            await session.commit()
+            return
+        # If no instances were locked, we can proceed in the same transaction.
 
     if job_model.instance is not None:
         res = await session.execute(
@@ -334,7 +361,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
         .order_by(VolumeModel.id)  # take locks in order
         .with_for_update(key_share=True)
     )
-    async with get_locker().lock_ctx(VolumeModel.__tablename__, volumes_ids):
+    async with get_locker(get_db().dialect_name).lock_ctx(VolumeModel.__tablename__, volumes_ids):
         if len(volume_models) > 0:
             await _attach_volumes(
                 session=session,
527
554
  if len(fleet_model.instances) == 0:
528
555
  # No instances means the fleet is not in the db yet, so don't lock.
529
556
  return 0
530
- async with get_locker().lock_ctx(FleetModel.__tablename__, [fleet_model.id]):
557
+ async with get_locker(get_db().dialect_name).lock_ctx(
558
+ FleetModel.__tablename__, [fleet_model.id]
559
+ ):
531
560
  fleet_model = (
532
561
  (
533
562
  await session.execute(
@@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
5
5
  from sqlalchemy.orm import joinedload, lazyload
6
6
 
7
7
  from dstack._internal.core.models.runs import JobStatus
8
- from dstack._internal.server.db import get_session_ctx
8
+ from dstack._internal.server.db import get_db, get_session_ctx
9
9
  from dstack._internal.server.models import (
10
10
  InstanceModel,
11
11
  JobModel,
@@ -32,8 +32,10 @@ async def process_terminating_jobs(batch_size: int = 1):
32
32
 
33
33
 
34
34
  async def _process_next_terminating_job():
35
- job_lock, job_lockset = get_locker().get_lockset(JobModel.__tablename__)
36
- instance_lock, instance_lockset = get_locker().get_lockset(InstanceModel.__tablename__)
35
+ job_lock, job_lockset = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__)
36
+ instance_lock, instance_lockset = get_locker(get_db().dialect_name).get_lockset(
37
+ InstanceModel.__tablename__
38
+ )
37
39
  async with get_session_ctx() as session:
38
40
  async with job_lock, instance_lock:
39
41
  res = await session.execute(
@@ -5,7 +5,7 @@ from sqlalchemy.orm import joinedload
5
5
  from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport
6
6
  from dstack._internal.core.errors import BackendError, BackendNotAvailable
7
7
  from dstack._internal.core.models.volumes import VolumeStatus
8
- from dstack._internal.server.db import get_session_ctx
8
+ from dstack._internal.server.db import get_db, get_session_ctx
9
9
  from dstack._internal.server.models import (
10
10
  InstanceModel,
11
11
  ProjectModel,
@@ -22,7 +22,7 @@ logger = get_logger(__name__)
22
22
 
23
23
 
24
24
  async def process_submitted_volumes():
25
- lock, lockset = get_locker().get_lockset(VolumeModel.__tablename__)
25
+ lock, lockset = get_locker(get_db().dialect_name).get_lockset(VolumeModel.__tablename__)
26
26
  async with get_session_ctx() as session:
27
27
  async with lock:
28
28
  res = await session.execute(
@@ -0,0 +1,39 @@
1
+ """Add FileArchiveModel
2
+
3
+ Revision ID: 5f1707c525d2
4
+ Revises: 35e90e1b0d3e
5
+ Create Date: 2025-06-12 12:28:26.678380
6
+
7
+ """
8
+
9
+ import sqlalchemy as sa
10
+ import sqlalchemy_utils
11
+ from alembic import op
12
+
13
+ # revision identifiers, used by Alembic.
14
+ revision = "5f1707c525d2"
15
+ down_revision = "35e90e1b0d3e"
16
+ branch_labels = None
17
+ depends_on = None
18
+
19
+
20
+ def upgrade() -> None:
21
+ op.create_table(
22
+ "file_archives",
23
+ sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False),
24
+ sa.Column("user_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False),
25
+ sa.Column("blob_hash", sa.Text(), nullable=False),
26
+ sa.Column("blob", sa.LargeBinary(), nullable=True),
27
+ sa.ForeignKeyConstraint(
28
+ ["user_id"],
29
+ ["users.id"],
30
+ name=op.f("fk_file_archives_user_id_users"),
31
+ ondelete="CASCADE",
32
+ ),
33
+ sa.PrimaryKeyConstraint("id", name=op.f("pk_file_archives")),
34
+ sa.UniqueConstraint("user_id", "blob_hash", name="uq_file_archives_user_id_blob_hash"),
35
+ )
36
+
37
+
38
+ def downgrade() -> None:
39
+ op.drop_table("file_archives")
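For orientation, the table this migration creates corresponds to an ORM declaration roughly like the sketch below; the actual FileArchiveModel is declared in dstack/_internal/server/models.py, which this release also changes:

    # Sketch of a model matching the migration (assumed declaration; see
    # dstack/_internal/server/models.py for the authoritative one).
    import uuid

    import sqlalchemy as sa
    from sqlalchemy.orm import DeclarativeBase
    from sqlalchemy_utils import UUIDType

    class Base(DeclarativeBase):
        pass

    class FileArchiveModel(Base):
        __tablename__ = "file_archives"
        # Archives are deduplicated per user by content hash
        # (uq_file_archives_user_id_blob_hash).
        __table_args__ = (sa.UniqueConstraint("user_id", "blob_hash"),)

        id = sa.Column(UUIDType(binary=False), primary_key=True, default=uuid.uuid4)
        user_id = sa.Column(
            UUIDType(binary=False),
            sa.ForeignKey("users.id", ondelete="CASCADE"),
            nullable=False,
        )
        blob_hash = sa.Column(sa.Text, nullable=False)
        # Inline payload; NULL when the blob lives in external object storage.
        blob = sa.Column(sa.LargeBinary, nullable=True)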