dstack 0.19.12rc1__py3-none-any.whl → 0.19.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic. Click here for more details.

Files changed (62)
  1. dstack/_internal/cli/commands/attach.py +4 -4
  2. dstack/_internal/cli/services/configurators/run.py +44 -47
  3. dstack/_internal/cli/utils/run.py +31 -31
  4. dstack/_internal/core/backends/aws/compute.py +22 -9
  5. dstack/_internal/core/backends/aws/resources.py +26 -0
  6. dstack/_internal/core/backends/base/offers.py +0 -1
  7. dstack/_internal/core/backends/template/configurator.py.jinja +1 -6
  8. dstack/_internal/core/backends/template/models.py.jinja +4 -0
  9. dstack/_internal/core/compatibility/__init__.py +0 -0
  10. dstack/_internal/core/compatibility/fleets.py +72 -0
  11. dstack/_internal/core/compatibility/gateways.py +34 -0
  12. dstack/_internal/core/compatibility/runs.py +131 -0
  13. dstack/_internal/core/compatibility/volumes.py +32 -0
  14. dstack/_internal/core/models/configurations.py +1 -1
  15. dstack/_internal/core/models/fleets.py +6 -1
  16. dstack/_internal/core/models/instances.py +51 -12
  17. dstack/_internal/core/models/profiles.py +43 -3
  18. dstack/_internal/core/models/projects.py +1 -0
  19. dstack/_internal/core/models/repos/local.py +3 -3
  20. dstack/_internal/core/models/runs.py +139 -43
  21. dstack/_internal/server/app.py +46 -1
  22. dstack/_internal/server/background/tasks/process_running_jobs.py +92 -15
  23. dstack/_internal/server/background/tasks/process_runs.py +163 -80
  24. dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py +42 -0
  25. dstack/_internal/server/migrations/versions/35f732ee4cf5_add_projectmodel_is_public.py +39 -0
  26. dstack/_internal/server/models.py +4 -0
  27. dstack/_internal/server/routers/projects.py +4 -3
  28. dstack/_internal/server/routers/prometheus.py +4 -1
  29. dstack/_internal/server/schemas/projects.py +1 -0
  30. dstack/_internal/server/security/permissions.py +36 -0
  31. dstack/_internal/server/services/jobs/__init__.py +1 -0
  32. dstack/_internal/server/services/jobs/configurators/base.py +11 -7
  33. dstack/_internal/server/services/projects.py +54 -1
  34. dstack/_internal/server/services/runner/client.py +4 -1
  35. dstack/_internal/server/services/runs.py +49 -29
  36. dstack/_internal/server/services/services/__init__.py +19 -0
  37. dstack/_internal/server/services/services/autoscalers.py +37 -26
  38. dstack/_internal/server/services/storage/__init__.py +38 -0
  39. dstack/_internal/server/services/storage/base.py +27 -0
  40. dstack/_internal/server/services/storage/gcs.py +44 -0
  41. dstack/_internal/server/services/{storage.py → storage/s3.py} +4 -27
  42. dstack/_internal/server/settings.py +7 -3
  43. dstack/_internal/server/statics/index.html +1 -1
  44. dstack/_internal/server/statics/{main-5b9786c955b42bf93581.js → main-0ac1e1583684417ae4d1.js} +1695 -62
  45. dstack/_internal/server/statics/{main-5b9786c955b42bf93581.js.map → main-0ac1e1583684417ae4d1.js.map} +1 -1
  46. dstack/_internal/server/statics/{main-8f9c66f404e9c7e7e020.css → main-f39c418b05fe14772dd8.css} +1 -1
  47. dstack/_internal/server/testing/common.py +11 -1
  48. dstack/_internal/settings.py +3 -0
  49. dstack/_internal/utils/common.py +4 -0
  50. dstack/api/_public/runs.py +14 -5
  51. dstack/api/server/_fleets.py +9 -69
  52. dstack/api/server/_gateways.py +3 -14
  53. dstack/api/server/_projects.py +2 -2
  54. dstack/api/server/_runs.py +4 -116
  55. dstack/api/server/_volumes.py +3 -14
  56. dstack/plugins/builtin/rest_plugin/_plugin.py +24 -5
  57. dstack/version.py +2 -2
  58. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/METADATA +1 -1
  59. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/RECORD +62 -52
  60. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/WHEEL +0 -0
  61. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/entry_points.txt +0 -0
  62. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/licenses/LICENSE.md +0 -0
@@ -14,7 +14,7 @@ from dstack._internal.server.schemas.projects import (
14
14
  from dstack._internal.server.security.permissions import (
15
15
  Authenticated,
16
16
  ProjectManager,
17
- ProjectMember,
17
+ ProjectMemberOrPublicAccess,
18
18
  )
19
19
  from dstack._internal.server.services import projects
20
20
  from dstack._internal.server.utils.routers import get_base_api_additional_responses
@@ -36,7 +36,7 @@ async def list_projects(
36
36
 
37
37
  `members` and `backends` are always empty - call `/api/projects/{project_name}/get` to retrieve them.
38
38
  """
39
- return await projects.list_user_projects(session=session, user=user)
39
+ return await projects.list_user_accessible_projects(session=session, user=user)
40
40
 
41
41
 
42
42
  @router.post("/create")
@@ -49,6 +49,7 @@ async def create_project(
49
49
  session=session,
50
50
  user=user,
51
51
  project_name=body.project_name,
52
+ is_public=body.is_public,
52
53
  )
53
54
 
54
55
 
@@ -68,7 +69,7 @@ async def delete_projects(
68
69
  @router.post("/{project_name}/get")
69
70
  async def get_project(
70
71
  session: AsyncSession = Depends(get_session),
71
- user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
72
+ user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()),
72
73
  ) -> Project:
73
74
  _, project = user_project
74
75
  return projects.project_model_to_project(project)
@@ -3,6 +3,7 @@ from typing import Annotated
3
3
 
4
4
  from fastapi import APIRouter, Depends
5
5
  from fastapi.responses import PlainTextResponse
6
+ from prometheus_client import generate_latest
6
7
  from sqlalchemy.ext.asyncio import AsyncSession
7
8
 
8
9
  from dstack._internal.server import settings
@@ -26,4 +27,6 @@ async def get_prometheus_metrics(
26
27
  ) -> str:
27
28
  if not settings.ENABLE_PROMETHEUS_METRICS:
28
29
  raise error_not_found()
29
- return await prometheus.get_metrics(session=session)
30
+ custom_metrics = await prometheus.get_metrics(session=session)
31
+ prometheus_metrics = generate_latest()
32
+ return custom_metrics + prometheus_metrics.decode()
@@ -8,6 +8,7 @@ from dstack._internal.core.models.users import ProjectRole
8
8
 
9
9
  class CreateProjectRequest(CoreModel):
10
10
  project_name: str
11
+ is_public: bool = False
11
12
 
12
13
 
13
14
  class DeleteProjectsRequest(CoreModel):
@@ -99,6 +99,42 @@ class ProjectMember:
99
99
  return await get_project_member(session, project_name, token.credentials)
100
100
 
101
101
 
102
+ class ProjectMemberOrPublicAccess:
103
+ """
104
+ Allows access to project for:
105
+ - Global admins
106
+ - Project members
107
+ - Any authenticated user if the project is public
108
+ """
109
+
110
+ async def __call__(
111
+ self,
112
+ *,
113
+ session: AsyncSession = Depends(get_session),
114
+ project_name: str,
115
+ token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
116
+ ) -> Tuple[UserModel, ProjectModel]:
117
+ user = await log_in_with_token(session=session, token=token.credentials)
118
+ if user is None:
119
+ raise error_invalid_token()
120
+
121
+ project = await get_project_model_by_name(session=session, project_name=project_name)
122
+ if project is None:
123
+ raise error_not_found()
124
+
125
+ if user.global_role == GlobalRole.ADMIN:
126
+ return user, project
127
+
128
+ project_role = get_user_project_role(user=user, project=project)
129
+ if project_role is not None:
130
+ return user, project
131
+
132
+ if project.is_public:
133
+ return user, project
134
+
135
+ raise error_forbidden()
136
+
137
+
102
138
  class OptionalServiceAccount:
103
139
  def __init__(self, token: Optional[str]) -> None:
104
140
  self._token = token
@@ -128,6 +128,7 @@ def job_model_to_job_submission(job_model: JobModel) -> JobSubmission:
128
128
  return JobSubmission(
129
129
  id=job_model.id,
130
130
  submission_num=job_model.submission_num,
131
+ deployment_num=job_model.deployment_num,
131
132
  submitted_at=job_model.submitted_at.replace(tzinfo=timezone.utc),
132
133
  last_processed_at=last_processed_at,
133
134
  finished_at=finished_at,
@@ -50,11 +50,15 @@ def get_default_python_verison() -> str:
50
50
  )
51
51
 
52
52
 
53
- def get_default_image(python_version: str, nvcc: bool = False) -> str:
54
- suffix = ""
55
- if nvcc:
56
- suffix = "-devel"
57
- return f"{settings.DSTACK_BASE_IMAGE}:py{python_version}-{settings.DSTACK_BASE_IMAGE_VERSION}-cuda-12.1{suffix}"
53
+ def get_default_image(nvcc: bool = False) -> str:
54
+ """
55
+ Note: May be overridden by dstack (e.g., EFA-enabled version for AWS EFA-capable instances).
56
+ See `dstack._internal.server.background.tasks.process_running_jobs._patch_base_image_for_aws_efa` for details.
57
+
58
+ Args:
59
+ nvcc: If True, returns 'devel' variant, otherwise 'base'.
60
+ """
61
+ return f"{settings.DSTACK_BASE_IMAGE}:{settings.DSTACK_BASE_IMAGE_VERSION}-{'devel' if nvcc else 'base'}-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"
58
62
 
59
63
 
60
64
  class JobConfigurator(ABC):
@@ -173,7 +177,7 @@ class JobConfigurator(ABC):
173
177
  ):
174
178
  return []
175
179
  return [
176
- f"uv venv --prompt workflow --seed {DEFAULT_REPO_DIR}/.venv > /dev/null 2>&1",
180
+ f"uv venv --python {self._python()} --prompt workflow --seed {DEFAULT_REPO_DIR}/.venv > /dev/null 2>&1",
177
181
  f"echo 'source {DEFAULT_REPO_DIR}/.venv/bin/activate' >> ~/.bashrc",
178
182
  f"source {DEFAULT_REPO_DIR}/.venv/bin/activate",
179
183
  ]
@@ -199,7 +203,7 @@ class JobConfigurator(ABC):
199
203
  def _image_name(self) -> str:
200
204
  if self.run_spec.configuration.image is not None:
201
205
  return self.run_spec.configuration.image
202
- return get_default_image(self._python(), nvcc=bool(self.run_spec.configuration.nvcc))
206
+ return get_default_image(nvcc=bool(self.run_spec.configuration.nvcc))
203
207
 
204
208
  async def _user(self) -> Optional[UnixUser]:
205
209
  user = self.run_spec.configuration.user
@@ -53,10 +53,37 @@ async def list_user_projects(
53
53
  session: AsyncSession,
54
54
  user: UserModel,
55
55
  ) -> List[Project]:
56
+ """
57
+ Returns projects where the user is a member.
58
+ """
56
59
  if user.global_role == GlobalRole.ADMIN:
57
60
  projects = await list_project_models(session=session)
58
61
  else:
59
62
  projects = await list_user_project_models(session=session, user=user)
63
+
64
+ projects = sorted(projects, key=lambda p: p.created_at)
65
+ return [
66
+ project_model_to_project(p, include_backends=False, include_members=False)
67
+ for p in projects
68
+ ]
69
+
70
+
71
+ async def list_user_accessible_projects(
72
+ session: AsyncSession,
73
+ user: UserModel,
74
+ ) -> List[Project]:
75
+ """
76
+ Returns all projects accessible to the user:
77
+ - For global admins: ALL projects in the system
78
+ - For regular users: Projects where user is a member + public projects where user is NOT a member
79
+ """
80
+ if user.global_role == GlobalRole.ADMIN:
81
+ projects = await list_project_models(session=session)
82
+ else:
83
+ member_projects = await list_user_project_models(session=session, user=user)
84
+ public_projects = await list_public_non_member_project_models(session=session, user=user)
85
+ projects = member_projects + public_projects
86
+
60
87
  projects = sorted(projects, key=lambda p: p.created_at)
61
88
  return [
62
89
  project_model_to_project(p, include_backends=False, include_members=False)
@@ -86,6 +113,7 @@ async def create_project(
86
113
  session: AsyncSession,
87
114
  user: UserModel,
88
115
  project_name: str,
116
+ is_public: bool = False,
89
117
  ) -> Project:
90
118
  user_permissions = users.get_user_permissions(user)
91
119
  if not user_permissions.can_create_projects:
@@ -100,6 +128,7 @@ async def create_project(
100
128
  session=session,
101
129
  owner=user,
102
130
  project_name=project_name,
131
+ is_public=is_public,
103
132
  )
104
133
  await add_project_member(
105
134
  session=session,
@@ -233,6 +262,9 @@ async def list_user_project_models(
233
262
  user: UserModel,
234
263
  include_members: bool = False,
235
264
  ) -> List[ProjectModel]:
265
+ """
266
+ List project models for a user where they are a member.
267
+ """
236
268
  options = []
237
269
  if include_members:
238
270
  options.append(joinedload(ProjectModel.members))
@@ -248,6 +280,25 @@ async def list_user_project_models(
248
280
  return list(res.scalars().unique().all())
249
281
 
250
282
 
283
+ async def list_public_non_member_project_models(
284
+ session: AsyncSession,
285
+ user: UserModel,
286
+ ) -> List[ProjectModel]:
287
+ """
288
+ List public project models where user is NOT a member.
289
+ """
290
+ res = await session.execute(
291
+ select(ProjectModel).where(
292
+ ProjectModel.deleted == False,
293
+ ProjectModel.is_public == True,
294
+ ProjectModel.id.notin_(
295
+ select(MemberModel.project_id).where(MemberModel.user_id == user.id)
296
+ ),
297
+ )
298
+ )
299
+ return list(res.scalars().all())
300
+
301
+
251
302
  async def list_user_owned_project_models(
252
303
  session: AsyncSession, user: UserModel, include_deleted: bool = False
253
304
  ) -> List[ProjectModel]:
@@ -323,7 +374,7 @@ async def get_project_model_by_id_or_error(
323
374
 
324
375
 
325
376
  async def create_project_model(
326
- session: AsyncSession, owner: UserModel, project_name: str
377
+ session: AsyncSession, owner: UserModel, project_name: str, is_public: bool = False
327
378
  ) -> ProjectModel:
328
379
  private_bytes, public_bytes = await run_async(
329
380
  generate_rsa_key_pair_bytes, f"{project_name}@dstack"
@@ -334,6 +385,7 @@ async def create_project_model(
334
385
  name=project_name,
335
386
  ssh_private_key=private_bytes.decode(),
336
387
  ssh_public_key=public_bytes.decode(),
388
+ is_public=is_public,
337
389
  )
338
390
  session.add(project)
339
391
  await session.commit()
@@ -407,6 +459,7 @@ def project_model_to_project(
407
459
  created_at=project_model.created_at.replace(tzinfo=timezone.utc),
408
460
  backends=backends,
409
461
  members=members,
462
+ is_public=project_model.is_public,
410
463
  )
411
464
 
412
465
 
@@ -32,6 +32,7 @@ from dstack._internal.utils.common import get_or_error
32
32
  from dstack._internal.utils.logging import get_logger
33
33
 
34
34
  REQUEST_TIMEOUT = 9
35
+ UPLOAD_CODE_REQUEST_TIMEOUT = 60
35
36
 
36
37
  logger = get_logger(__name__)
37
38
 
@@ -109,7 +110,9 @@ class RunnerClient:
109
110
  resp.raise_for_status()
110
111
 
111
112
  def upload_code(self, file: Union[BinaryIO, bytes]):
112
- resp = requests.post(self._url("/api/upload_code"), data=file, timeout=REQUEST_TIMEOUT)
113
+ resp = requests.post(
114
+ self._url("/api/upload_code"), data=file, timeout=UPLOAD_CODE_REQUEST_TIMEOUT
115
+ )
113
116
  resp.raise_for_status()
114
117
 
115
118
  def run_job(self):
@@ -439,6 +439,7 @@ async def apply_plan(
439
439
  .values(
440
440
  run_spec=run_spec.json(),
441
441
  priority=run_spec.configuration.priority,
442
+ deployment_num=current_resource.deployment_num + 1,
442
443
  )
443
444
  )
444
445
  run = await get_run_by_name(
@@ -501,6 +502,8 @@ async def submit_run(
501
502
  run_spec=run_spec.json(),
502
503
  last_processed_at=submitted_at,
503
504
  priority=run_spec.configuration.priority,
505
+ deployment_num=0,
506
+ desired_replica_count=1, # a relevant value will be set in process_runs.py
504
507
  )
505
508
  session.add(run_model)
506
509
 
@@ -539,6 +542,7 @@ def create_job_model_for_new_submission(
539
542
  job_num=job.job_spec.job_num,
540
543
  job_name=f"{job.job_spec.job_name}",
541
544
  replica_num=job.job_spec.replica_num,
545
+ deployment_num=run_model.deployment_num,
542
546
  submission_num=len(job.job_submissions),
543
547
  submitted_at=now,
544
548
  last_processed_at=now,
@@ -662,13 +666,9 @@ def run_model_to_run(
662
666
  for job_num, job_submissions in itertools.groupby(
663
667
  replica_submissions, key=lambda j: j.job_num
664
668
  ):
665
- job_spec = None
666
669
  submissions = []
670
+ job_model = None
667
671
  for job_model in job_submissions:
668
- if job_spec is None:
669
- job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
670
- if not include_sensitive:
671
- _remove_job_spec_sensitive_info(job_spec)
672
672
  if include_job_submissions:
673
673
  job_submission = job_model_to_job_submission(job_model)
674
674
  if return_in_api:
@@ -680,7 +680,11 @@ def run_model_to_run(
680
680
  if job_submission.job_provisioning_data.ssh_port is None:
681
681
  job_submission.job_provisioning_data.ssh_port = 22
682
682
  submissions.append(job_submission)
683
- if job_spec is not None:
683
+ if job_model is not None:
684
+ # Use the spec from the latest submission. Submissions can have different specs
685
+ job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
686
+ if not include_sensitive:
687
+ _remove_job_spec_sensitive_info(job_spec)
684
688
  jobs.append(Job(job_spec=job_spec, job_submissions=submissions))
685
689
 
686
690
  run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
@@ -707,6 +711,7 @@ def run_model_to_run(
707
711
  jobs=jobs,
708
712
  latest_job_submission=latest_job_submission,
709
713
  service=service_spec,
714
+ deployment_num=run_model.deployment_num,
710
715
  deleted=run_model.deleted,
711
716
  )
712
717
  run.cost = _get_run_cost(run)
@@ -897,9 +902,24 @@ _UPDATABLE_SPEC_FIELDS = ["repo_code_hash", "configuration"]
897
902
  _CONF_UPDATABLE_FIELDS = ["priority"]
898
903
  _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS = {
899
904
  "dev-environment": ["inactivity_duration"],
900
- # Most service fields can be updated via replica redeployment.
901
- # TODO: Allow updating other fields when rolling deployment is supported.
902
- "service": ["replicas", "scaling", "strip_prefix"],
905
+ "service": [
906
+ # in-place
907
+ "replicas",
908
+ "scaling",
909
+ # rolling deployment
910
+ "resources",
911
+ "volumes",
912
+ "image",
913
+ "user",
914
+ "privileged",
915
+ "entrypoint",
916
+ "python",
917
+ "nvcc",
918
+ "single_branch",
919
+ "env",
920
+ "shell",
921
+ "commands",
922
+ ],
903
923
  }
904
924
 
905
925
 
@@ -1004,34 +1024,33 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
1004
1024
  abs(replicas_diff),
1005
1025
  )
1006
1026
 
1007
- # lists of (importance, replica_num, jobs)
1027
+ # lists of (importance, is_out_of_date, replica_num, jobs)
1008
1028
  active_replicas = []
1009
1029
  inactive_replicas = []
1010
1030
 
1011
1031
  for replica_num, replica_jobs in group_jobs_by_replica_latest(run_model.jobs):
1012
1032
  statuses = set(job.status for job in replica_jobs)
1033
+ deployment_num = replica_jobs[0].deployment_num # same for all jobs
1034
+ is_out_of_date = deployment_num < run_model.deployment_num
1013
1035
  if {JobStatus.TERMINATING, *JobStatus.finished_statuses()} & statuses:
1014
1036
  # if there are any terminating or finished jobs, the replica is inactive
1015
- inactive_replicas.append((0, replica_num, replica_jobs))
1037
+ inactive_replicas.append((0, is_out_of_date, replica_num, replica_jobs))
1016
1038
  elif JobStatus.SUBMITTED in statuses:
1017
1039
  # if there are any submitted jobs, the replica is active and has the importance of 0
1018
- active_replicas.append((0, replica_num, replica_jobs))
1040
+ active_replicas.append((0, is_out_of_date, replica_num, replica_jobs))
1019
1041
  elif {JobStatus.PROVISIONING, JobStatus.PULLING} & statuses:
1020
1042
  # if there are any provisioning or pulling jobs, the replica is active and has the importance of 1
1021
- active_replicas.append((1, replica_num, replica_jobs))
1043
+ active_replicas.append((1, is_out_of_date, replica_num, replica_jobs))
1022
1044
  else:
1023
1045
  # all jobs are running, the replica is active and has the importance of 2
1024
- active_replicas.append((2, replica_num, replica_jobs))
1046
+ active_replicas.append((2, is_out_of_date, replica_num, replica_jobs))
1025
1047
 
1026
- # sort by importance (desc) and replica_num (asc)
1027
- active_replicas.sort(key=lambda r: (-r[0], r[1]))
1048
+ # sort by is_out_of_date (up-to-date first), importance (desc), and replica_num (asc)
1049
+ active_replicas.sort(key=lambda r: (r[1], -r[0], r[2]))
1028
1050
  run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
1029
1051
 
1030
1052
  if replicas_diff < 0:
1031
- if len(active_replicas) + replicas_diff < run_spec.configuration.replicas.min:
1032
- raise ServerClientError("Can't scale down below the minimum number of replicas")
1033
-
1034
- for _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]):
1053
+ for _, _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]):
1035
1054
  # scale down the less important replicas first
1036
1055
  for job in replica_jobs:
1037
1056
  if job.status.is_finished() or job.status == JobStatus.TERMINATING:
@@ -1040,18 +1059,15 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
1040
1059
  job.termination_reason = JobTerminationReason.SCALED_DOWN
1041
1060
  # background task will process the job later
1042
1061
  else:
1043
- if len(active_replicas) + replicas_diff > run_spec.configuration.replicas.max:
1044
- raise ServerClientError("Can't scale up above the maximum number of replicas")
1045
1062
  scheduled_replicas = 0
1046
1063
 
1047
1064
  # rerun inactive replicas
1048
- for _, _, replica_jobs in inactive_replicas:
1065
+ for _, _, _, replica_jobs in inactive_replicas:
1049
1066
  if scheduled_replicas == replicas_diff:
1050
1067
  break
1051
1068
  await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False)
1052
1069
  scheduled_replicas += 1
1053
1070
 
1054
- # create new replicas
1055
1071
  for replica_num in range(
1056
1072
  len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff
1057
1073
  ):
@@ -1068,7 +1084,14 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
1068
1084
  async def retry_run_replica_jobs(
1069
1085
  session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool
1070
1086
  ):
1071
- for job_model in latest_jobs:
1087
+ new_jobs = await get_jobs_from_run_spec(
1088
+ RunSpec.__response__.parse_raw(run_model.run_spec),
1089
+ replica_num=latest_jobs[0].replica_num,
1090
+ )
1091
+ assert len(new_jobs) == len(latest_jobs), (
1092
+ "Changing the number of jobs within a replica is not yet supported"
1093
+ )
1094
+ for job_model, new_job in zip(latest_jobs, new_jobs):
1072
1095
  if not (job_model.status.is_finished() or job_model.status == JobStatus.TERMINATING):
1073
1096
  if only_failed:
1074
1097
  # No need to resubmit, skip
@@ -1079,10 +1102,7 @@ async def retry_run_replica_jobs(
1079
1102
 
1080
1103
  new_job_model = create_job_model_for_new_submission(
1081
1104
  run_model=run_model,
1082
- job=Job(
1083
- job_spec=JobSpec.__response__.parse_raw(job_model.job_spec_data),
1084
- job_submissions=[],
1085
- ),
1105
+ job=new_job,
1086
1106
  status=JobStatus.SUBMITTED,
1087
1107
  )
1088
1108
  # dirty hack to avoid passing all job submissions
@@ -30,6 +30,7 @@ from dstack._internal.server.services.gateways import (
30
30
  get_project_gateway_model_by_name,
31
31
  )
32
32
  from dstack._internal.server.services.logging import fmt
33
+ from dstack._internal.server.services.services.autoscalers import get_service_scaler
33
34
  from dstack._internal.server.services.services.options import get_service_options
34
35
  from dstack._internal.utils.logging import get_logger
35
36
 
@@ -258,3 +259,21 @@ def _get_gateway_https(configuration: GatewayConfiguration) -> bool:
258
259
  if configuration.certificate is not None and configuration.certificate.type == "lets-encrypt":
259
260
  return True
260
261
  return False
262
+
263
+
264
+ async def update_service_desired_replica_count(
265
+ session: AsyncSession,
266
+ run_model: RunModel,
267
+ configuration: ServiceConfiguration,
268
+ last_scaled_at: Optional[int],
269
+ ) -> None:
270
+ scaler = get_service_scaler(configuration)
271
+ stats = None
272
+ if run_model.gateway_id is not None:
273
+ conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
274
+ stats = await conn.get_stats(run_model.project.name, run_model.run_name)
275
+ run_model.desired_replica_count = scaler.get_desired_count(
276
+ current_desired_count=run_model.desired_replica_count,
277
+ stats=stats,
278
+ last_scaled_at=last_scaled_at,
279
+ )
@@ -1,7 +1,7 @@
1
1
  import datetime
2
2
  import math
3
3
  from abc import ABC, abstractmethod
4
- from typing import List, Optional
4
+ from typing import Optional
5
5
 
6
6
  from pydantic import BaseModel
7
7
 
@@ -23,14 +23,20 @@ class ReplicaInfo(BaseModel):
23
23
 
24
24
  class BaseServiceScaler(ABC):
25
25
  @abstractmethod
26
- def scale(self, replicas: List[ReplicaInfo], stats: Optional[PerWindowStats]) -> int:
26
+ def get_desired_count(
27
+ self,
28
+ current_desired_count: int,
29
+ stats: Optional[PerWindowStats],
30
+ last_scaled_at: Optional[datetime.datetime],
31
+ ) -> int:
27
32
  """
28
33
  Args:
29
- replicas: list of all replicas
30
34
  stats: service usage stats
35
+ current_desired_count: currently used desired count
36
+ last_scaled_at: last time service was scaled, None if it was never scaled yet
31
37
 
32
38
  Returns:
33
- diff: number of replicas to add or remove
39
+ desired_count: desired count of replicas
34
40
  """
35
41
  pass
36
42
 
@@ -49,12 +55,14 @@ class ManualScaler(BaseServiceScaler):
49
55
  self.min_replicas = min_replicas
50
56
  self.max_replicas = max_replicas
51
57
 
52
- def scale(self, replicas: List[ReplicaInfo], stats: Optional[PerWindowStats]) -> int:
53
- active_replicas = [r for r in replicas if r.active]
54
- target_replicas = len(active_replicas)
55
- # clip the target replicas to the min and max values
56
- target_replicas = min(max(target_replicas, self.min_replicas), self.max_replicas)
57
- return target_replicas - len(active_replicas)
58
+ def get_desired_count(
59
+ self,
60
+ current_desired_count: int,
61
+ stats: Optional[PerWindowStats],
62
+ last_scaled_at: Optional[datetime.datetime],
63
+ ) -> int:
64
+ # clip the desired count to the min and max values
65
+ return min(max(current_desired_count, self.min_replicas), self.max_replicas)
58
66
 
59
67
 
60
68
  class RPSAutoscaler(BaseServiceScaler):
@@ -72,40 +80,43 @@ class RPSAutoscaler(BaseServiceScaler):
72
80
  self.scale_up_delay = scale_up_delay
73
81
  self.scale_down_delay = scale_down_delay
74
82
 
75
- def scale(self, replicas: List[ReplicaInfo], stats: Optional[PerWindowStats]) -> int:
83
+ def get_desired_count(
84
+ self,
85
+ current_desired_count: int,
86
+ stats: Optional[PerWindowStats],
87
+ last_scaled_at: Optional[datetime.datetime],
88
+ ) -> int:
76
89
  if not stats:
77
- return 0
90
+ return current_desired_count
78
91
 
79
92
  now = common_utils.get_current_datetime()
80
- active_replicas = [r for r in replicas if r.active]
81
- last_scaled_at = max((r.timestamp for r in replicas), default=None)
82
93
 
83
94
  # calculate the average RPS over the last minute
84
95
  rps = stats[60].requests / 60
85
- target_replicas = math.ceil(rps / self.target)
86
- # clip the target replicas to the min and max values
87
- target_replicas = min(max(target_replicas, self.min_replicas), self.max_replicas)
96
+ new_desired_count = math.ceil(rps / self.target)
97
+ # clip the desired count to the min and max values
98
+ new_desired_count = min(max(new_desired_count, self.min_replicas), self.max_replicas)
88
99
 
89
- if target_replicas > len(active_replicas):
90
- if len(active_replicas) == 0:
100
+ if new_desired_count > current_desired_count:
101
+ if current_desired_count == 0:
91
102
  # no replicas, scale up immediately
92
- return target_replicas
103
+ return new_desired_count
93
104
  if (
94
105
  last_scaled_at is not None
95
106
  and (now - last_scaled_at).total_seconds() < self.scale_up_delay
96
107
  ):
97
108
  # too early to scale up, wait for the delay
98
- return 0
99
- return target_replicas - len(active_replicas)
100
- elif target_replicas < len(active_replicas):
109
+ return current_desired_count
110
+ return new_desired_count
111
+ elif new_desired_count < current_desired_count:
101
112
  if (
102
113
  last_scaled_at is not None
103
114
  and (now - last_scaled_at).total_seconds() < self.scale_down_delay
104
115
  ):
105
116
  # too early to scale down, wait for the delay
106
- return 0
107
- return target_replicas - len(active_replicas)
108
- return 0
117
+ return current_desired_count
118
+ return new_desired_count
119
+ return new_desired_count
109
120
 
110
121
 
111
122
  def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler:
@@ -0,0 +1,38 @@
1
+ from typing import Optional
2
+
3
+ from dstack._internal.server import settings
4
+ from dstack._internal.server.services.storage.base import BaseStorage
5
+ from dstack._internal.server.services.storage.gcs import GCS_AVAILABLE, GCSStorage
6
+ from dstack._internal.server.services.storage.s3 import BOTO_AVAILABLE, S3Storage
7
+
8
+ _default_storage = None
9
+
10
+
11
+ def init_default_storage():
12
+ global _default_storage
13
+ if settings.SERVER_S3_BUCKET is None and settings.SERVER_GCS_BUCKET is None:
14
+ raise ValueError(
15
+ "Either settings.SERVER_S3_BUCKET or settings.SERVER_GCS_BUCKET must be set"
16
+ )
17
+ if settings.SERVER_S3_BUCKET and settings.SERVER_GCS_BUCKET:
18
+ raise ValueError(
19
+ "Only one of settings.SERVER_S3_BUCKET or settings.SERVER_GCS_BUCKET can be set"
20
+ )
21
+
22
+ if settings.SERVER_S3_BUCKET:
23
+ if not BOTO_AVAILABLE:
24
+ raise ValueError("AWS dependencies are not installed")
25
+ _default_storage = S3Storage(
26
+ bucket=settings.SERVER_S3_BUCKET,
27
+ region=settings.SERVER_S3_BUCKET_REGION,
28
+ )
29
+ elif settings.SERVER_GCS_BUCKET:
30
+ if not GCS_AVAILABLE:
31
+ raise ValueError("GCS dependencies are not installed")
32
+ _default_storage = GCSStorage(
33
+ bucket=settings.SERVER_GCS_BUCKET,
34
+ )
35
+
36
+
37
+ def get_default_storage() -> Optional[BaseStorage]:
38
+ return _default_storage
@@ -0,0 +1,27 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Optional
3
+
4
+
5
+ class BaseStorage(ABC):
6
+ @abstractmethod
7
+ def upload_code(
8
+ self,
9
+ project_id: str,
10
+ repo_id: str,
11
+ code_hash: str,
12
+ blob: bytes,
13
+ ):
14
+ pass
15
+
16
+ @abstractmethod
17
+ def get_code(
18
+ self,
19
+ project_id: str,
20
+ repo_id: str,
21
+ code_hash: str,
22
+ ) -> Optional[bytes]:
23
+ pass
24
+
25
+ @staticmethod
26
+ def _get_code_key(project_id: str, repo_id: str, code_hash: str) -> str:
27
+ return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}"