dstack 0.19.13__py3-none-any.whl → 0.19.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dstack might be problematic.

Files changed (35)
  1. dstack/_internal/cli/commands/attach.py +4 -4
  2. dstack/_internal/cli/services/configurators/run.py +1 -0
  3. dstack/_internal/cli/utils/run.py +16 -4
  4. dstack/_internal/core/compatibility/runs.py +6 -0
  5. dstack/_internal/core/models/projects.py +1 -0
  6. dstack/_internal/core/models/runs.py +23 -1
  7. dstack/_internal/server/app.py +45 -0
  8. dstack/_internal/server/background/tasks/process_running_jobs.py +45 -3
  9. dstack/_internal/server/background/tasks/process_runs.py +149 -79
  10. dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py +42 -0
  11. dstack/_internal/server/migrations/versions/35f732ee4cf5_add_projectmodel_is_public.py +39 -0
  12. dstack/_internal/server/models.py +4 -0
  13. dstack/_internal/server/routers/projects.py +4 -3
  14. dstack/_internal/server/routers/prometheus.py +4 -1
  15. dstack/_internal/server/schemas/projects.py +1 -0
  16. dstack/_internal/server/security/permissions.py +36 -0
  17. dstack/_internal/server/services/jobs/__init__.py +1 -0
  18. dstack/_internal/server/services/jobs/configurators/base.py +11 -7
  19. dstack/_internal/server/services/projects.py +54 -1
  20. dstack/_internal/server/services/runs.py +49 -29
  21. dstack/_internal/server/services/services/__init__.py +19 -0
  22. dstack/_internal/server/services/services/autoscalers.py +37 -26
  23. dstack/_internal/server/statics/index.html +1 -1
  24. dstack/_internal/server/statics/{main-2066f1f22ddb4557bcde.js → main-0ac1e1583684417ae4d1.js} +26 -24
  25. dstack/_internal/server/statics/{main-2066f1f22ddb4557bcde.js.map → main-0ac1e1583684417ae4d1.js.map} +1 -1
  26. dstack/_internal/server/testing/common.py +9 -0
  27. dstack/_internal/settings.py +3 -0
  28. dstack/api/_public/runs.py +14 -5
  29. dstack/api/server/_projects.py +2 -2
  30. dstack/version.py +2 -2
  31. {dstack-0.19.13.dist-info → dstack-0.19.14.dist-info}/METADATA +1 -1
  32. {dstack-0.19.13.dist-info → dstack-0.19.14.dist-info}/RECORD +35 -33
  33. {dstack-0.19.13.dist-info → dstack-0.19.14.dist-info}/WHEEL +0 -0
  34. {dstack-0.19.13.dist-info → dstack-0.19.14.dist-info}/entry_points.txt +0 -0
  35. {dstack-0.19.13.dist-info → dstack-0.19.14.dist-info}/licenses/LICENSE.md +0 -0
@@ -52,9 +52,8 @@ class AttachCommand(APIBaseCommand):
         )
         self._parser.add_argument(
             "--replica",
-            help="The replica number. Defaults to 0.",
+            help="The replica number. Defaults to any running replica.",
             type=int,
-            default=0,
         )
         self._parser.add_argument(
             "--job",
@@ -129,14 +128,15 @@ _IGNORED_PORTS = [DSTACK_RUNNER_HTTP_PORT]
 def _print_attached_message(
     run: Run,
     bind_address: Optional[str],
-    replica_num: int,
+    replica_num: Optional[int],
     job_num: int,
 ):
     if bind_address is None:
         bind_address = "localhost"
 
-    output = f"Attached to run [code]{run.name}[/] (replica={replica_num} job={job_num})\n"
     job = get_or_error(run._find_job(replica_num=replica_num, job_num=job_num))
+    replica_num = job.job_spec.replica_num
+    output = f"Attached to run [code]{run.name}[/] (replica={replica_num} job={job_num})\n"
     name = run.name
     if replica_num != 0 or job_num != 0:
         name = job.job_spec.job_name
@@ -599,6 +599,7 @@ def _is_ready_to_attach(run: Run) -> bool:
         ]
         or run._run.jobs[0].job_submissions[-1].status
         in [JobStatus.SUBMITTED, JobStatus.PROVISIONING, JobStatus.PULLING]
+        or run._run.is_deployment_in_progress()
     )
 
 
@@ -162,9 +162,16 @@ def get_runs_table(
 
     for run in runs:
         run = run._run  # TODO(egor-s): make public attribute
+        show_deployment_num = (
+            verbose
+            and run.run_spec.configuration.type == "service"
+            or run.is_deployment_in_progress()
+        )
+        merge_job_rows = len(run.jobs) == 1 and not show_deployment_num
 
         run_row: Dict[Union[str, int], Any] = {
-            "NAME": run.run_spec.run_name,
+            "NAME": run.run_spec.run_name
+            + (f" [secondary]deployment={run.deployment_num}[/]" if show_deployment_num else ""),
             "SUBMITTED": format_date(run.submitted_at),
             "STATUS": (
                 run.latest_job_submission.status_message
@@ -174,7 +181,7 @@ def get_runs_table(
         }
         if run.error:
             run_row["ERROR"] = run.error
-        if len(run.jobs) != 1:
+        if not merge_job_rows:
             add_row_from_dict(table, run_row)
 
         for job in run.jobs:
@@ -184,7 +191,12 @@ def get_runs_table(
                 inactive_for = format_duration_multiunit(latest_job_submission.inactivity_secs)
                 status += f" (inactive for {inactive_for})"
             job_row: Dict[Union[str, int], Any] = {
-                "NAME": f"  replica={job.job_spec.replica_num} job={job.job_spec.job_num}",
+                "NAME": f"  replica={job.job_spec.replica_num} job={job.job_spec.job_num}"
+                + (
+                    f" deployment={latest_job_submission.deployment_num}"
+                    if show_deployment_num
+                    else ""
+                ),
                 "STATUS": latest_job_submission.status_message,
                 "SUBMITTED": format_date(latest_job_submission.submitted_at),
                 "ERROR": latest_job_submission.error,
@@ -208,7 +220,7 @@ def get_runs_table(
                     "PRICE": f"${jpd.price:.4f}".rstrip("0").rstrip("."),
                 }
             )
-            if len(run.jobs) == 1:
+            if merge_job_rows:
                # merge rows
                job_row.update(run_row)
            add_row_from_dict(table, job_row, style="secondary" if len(run.jobs) != 1 else None)
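Note: with the changes above, a service mid-rolling-deployment renders along these lines (hypothetical names and numbers; the [secondary]...[/] markup is color styling and is not printed literally):

    NAME                             STATUS   SUBMITTED
    my-service deployment=2          running  11:05
      replica=0 job=0 deployment=1   running  11:05
      replica=1 job=0 deployment=2   running  11:12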
@@ -19,6 +19,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]:
     if current_resource is not None:
         current_resource_excludes = {}
         current_resource_excludes["status_message"] = True
+        if current_resource.deployment_num == 0:
+            current_resource_excludes["deployment_num"] = True
         apply_plan_excludes["current_resource"] = current_resource_excludes
         current_resource_excludes["run_spec"] = get_run_spec_excludes(current_resource.run_spec)
         job_submissions_excludes = {}
@@ -36,6 +38,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]:
         }
         if all(js.exit_status is None for js in job_submissions):
             job_submissions_excludes["exit_status"] = True
+        if all(js.deployment_num == 0 for js in job_submissions):
+            job_submissions_excludes["deployment_num"] = True
         latest_job_submission = current_resource.latest_job_submission
         if latest_job_submission is not None:
             latest_job_submission_excludes = {}
@@ -50,6 +54,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]:
             }
             if latest_job_submission.exit_status is None:
                 latest_job_submission_excludes["exit_status"] = True
+            if latest_job_submission.deployment_num == 0:
+                latest_job_submission_excludes["deployment_num"] = True
     return {"plan": apply_plan_excludes}
 
 
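Note: the pattern above keeps new clients wire-compatible with pre-0.19.14 servers by omitting newly added fields from request payloads while they still hold their defaults. A minimal sketch of the idea (hypothetical model, plain pydantic v1, not dstack's actual CoreModel):

    from pydantic import BaseModel

    class Submission(BaseModel):
        submission_num: int
        deployment_num: int = 0  # field introduced in 0.19.14

    s = Submission(submission_num=3)
    # Exclude the new field while it still holds its default,
    # so an older server never sees an unknown key.
    exclude = {"deployment_num"} if s.deployment_num == 0 else None
    print(s.dict(exclude=exclude))  # {'submission_num': 3}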
@@ -25,3 +25,4 @@ class Project(CoreModel):
     created_at: Optional[datetime] = None
     backends: List[BackendInfo]
     members: List[Member]
+    is_public: bool = False
@@ -148,6 +148,19 @@ class JobTerminationReason(str, Enum):
         }
         return mapping[self]
 
+    def to_retry_event(self) -> Optional[RetryEvent]:
+        """
+        Returns:
+            the retry event this termination reason triggers
+            or None if this termination reason should not be retried
+        """
+        mapping = {
+            self.FAILED_TO_START_DUE_TO_NO_CAPACITY: RetryEvent.NO_CAPACITY,
+            self.INTERRUPTED_BY_NO_CAPACITY: RetryEvent.INTERRUPTION,
+        }
+        default = RetryEvent.ERROR if self.to_status() == JobStatus.FAILED else None
+        return mapping.get(self, default)
+
 
 class Requirements(CoreModel):
     # TODO: Make requirements' fields required
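Note: a quick sketch of how the new helper is consumed (illustrative values; the real call sites are in _should_retry_job in process_runs.py below):

    from dstack._internal.core.models.profiles import RetryEvent
    from dstack._internal.core.models.runs import JobTerminationReason

    # INTERRUPTED_BY_NO_CAPACITY maps to RetryEvent.INTERRUPTION explicitly;
    # other reasons that resolve to JobStatus.FAILED fall back to RetryEvent.ERROR.
    event = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY.to_retry_event()
    retry_on = [RetryEvent.INTERRUPTION, RetryEvent.NO_CAPACITY]
    should_retry = event is not None and event in retry_on  # True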
@@ -276,6 +289,7 @@ class ClusterInfo(CoreModel):
 class JobSubmission(CoreModel):
     id: UUID4
     submission_num: int
+    deployment_num: int = 0  # default for compatibility with pre-0.19.14 servers
     submitted_at: datetime
     last_processed_at: datetime
     finished_at: Optional[datetime]
@@ -354,7 +368,7 @@ class JobSubmission(CoreModel):
         error_mapping = {
             JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable",
             JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded",
-            JobTerminationReason.VOLUME_ERROR: "waiting runner limit exceeded",
+            JobTerminationReason.VOLUME_ERROR: "volume error",
             JobTerminationReason.GATEWAY_ERROR: "gateway error",
             JobTerminationReason.SCALED_DOWN: "scaled down",
             JobTerminationReason.INACTIVITY_DURATION_EXCEEDED: "inactivity duration exceeded",
@@ -503,6 +517,7 @@ class Run(CoreModel):
     latest_job_submission: Optional[JobSubmission]
     cost: float = 0
     service: Optional[ServiceSpec] = None
+    deployment_num: int = 0  # default for compatibility with pre-0.19.14 servers
     # TODO: make error a computed field after migrating to pydanticV2
     error: Optional[str] = None
     deleted: Optional[bool] = None
@@ -565,6 +580,13 @@ class Run(CoreModel):
             return "retrying"
         return status.value
 
+    def is_deployment_in_progress(self) -> bool:
+        return any(
+            not j.job_submissions[-1].status.is_finished()
+            and j.job_submissions[-1].deployment_num != self.deployment_num
+            for j in self.jobs
+        )
+
 
 class JobPlan(CoreModel):
     job_spec: JobSpec
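Note: in other words, a run is mid-deployment while any unfinished job still carries a deployment_num older than the run's own. Illustrative state (hypothetical values):

    # The run has been updated to deployment 2, but replica 0's latest
    # submission is still from deployment 1 and not yet finished:
    #   run.deployment_num == 2
    #   run.jobs[0].job_submissions[-1].deployment_num == 1  (status RUNNING)
    # -> run.is_deployment_in_progress() is True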
@@ -11,6 +11,7 @@ from fastapi import FastAPI, Request, Response, status
 from fastapi.datastructures import URL
 from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
 from fastapi.staticfiles import StaticFiles
+from prometheus_client import Counter, Histogram
 
 from dstack._internal.cli.utils.common import console
 from dstack._internal.core.errors import ForbiddenError, ServerClientError
@@ -63,6 +64,18 @@ from dstack._internal.utils.ssh import check_required_ssh_version
 
 logger = get_logger(__name__)
 
+# Server HTTP metrics
+REQUESTS_TOTAL = Counter(
+    "dstack_server_requests_total",
+    "Total number of HTTP requests",
+    ["method", "endpoint", "http_status", "project_name"],
+)
+REQUEST_DURATION = Histogram(
+    "dstack_server_request_duration_seconds",
+    "HTTP request duration in seconds",
+    ["method", "endpoint", "http_status", "project_name"],
+)
+
 
 def create_app() -> FastAPI:
     if settings.SENTRY_DSN is not None:
@@ -216,6 +229,8 @@ def register_routes(app: FastAPI, ui: bool = True):
         start_time = time.time()
         response: Response = await call_next(request)
         process_time = time.time() - start_time
+        # log process_time to be used in the log_http_metrics middleware
+        request.state.process_time = process_time
         logger.debug(
             "Processed request %s %s in %s. Status: %s",
             request.method,
@@ -225,6 +240,36 @@ def register_routes(app: FastAPI, ui: bool = True):
         )
         return response
 
+    # this middleware must be defined after the log_request middleware
+    @app.middleware("http")
+    async def log_http_metrics(request: Request, call_next):
+        def _extract_project_name(request: Request):
+            project_name = None
+            prefix = "/api/project/"
+            if request.url.path.startswith(prefix):
+                rest = request.url.path[len(prefix) :]
+                project_name = rest.split("/", 1)[0] if rest else None
+
+            return project_name
+
+        project_name = _extract_project_name(request)
+        response: Response = await call_next(request)
+
+        REQUEST_DURATION.labels(
+            method=request.method,
+            endpoint=request.url.path,
+            http_status=response.status_code,
+            project_name=project_name,
+        ).observe(request.state.process_time)
+
+        REQUESTS_TOTAL.labels(
+            method=request.method,
+            endpoint=request.url.path,
+            http_status=response.status_code,
+            project_name=project_name,
+        ).inc()
+        return response
+
     @app.middleware("http")
     async def check_client_version(request: Request, call_next):
         if (
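Note: a minimal sketch of what the new middleware records (assuming the default prometheus_client registry; the metrics endpoint itself lives in routers/prometheus.py, also touched in this release):

    from prometheus_client import generate_latest

    # After the server handles a request under /api/project/main/...,
    # the default registry contains samples along these lines:
    #   dstack_server_requests_total{endpoint="/api/project/main/runs/list",
    #       http_status="200",method="POST",project_name="main"} 1.0
    #   dstack_server_request_duration_seconds_bucket{...,le="0.1"} 1.0
    print(generate_latest().decode())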
@@ -1,4 +1,5 @@
 import asyncio
+import re
 from collections.abc import Iterable
 from datetime import timedelta, timezone
 from typing import Dict, List, Optional
@@ -7,6 +8,7 @@ from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload
 
+from dstack._internal import settings
 from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
 from dstack._internal.core.errors import GatewayError
 from dstack._internal.core.models.backends.base import BackendType
@@ -517,14 +519,14 @@ def _process_provisioning_with_shim(
     cpu = None
     memory = None
     network_mode = NetworkMode.HOST
-
+    image_name = _patch_base_image_for_aws_efa(job_spec, job_provisioning_data)
     if shim_client.is_api_v2_supported():
         shim_client.submit_task(
             task_id=job_model.id,
             name=job_model.job_name,
             registry_username=registry_username,
             registry_password=registry_password,
-            image_name=job_spec.image_name,
+            image_name=image_name,
             container_user=container_user,
             privileged=job_spec.privileged,
             gpu=gpu,
@@ -545,7 +547,7 @@ def _process_provisioning_with_shim(
         submitted = shim_client.submit(
             username=registry_username,
             password=registry_password,
-            image_name=job_spec.image_name,
+            image_name=image_name,
             privileged=job_spec.privileged,
             container_name=job_model.job_name,
             container_user=container_user,
@@ -969,3 +971,43 @@ def _get_instance_specific_gpu_devices(
         GPUDevice(path_on_host="/dev/nvidiactl", path_in_container="/dev/nvidiactl")
     )
     return gpu_devices
+
+
+def _patch_base_image_for_aws_efa(
+    job_spec: JobSpec, job_provisioning_data: JobProvisioningData
+) -> str:
+    image_name = job_spec.image_name
+
+    if job_provisioning_data.backend != BackendType.AWS:
+        return image_name
+
+    instance_type = job_provisioning_data.instance_type.name
+    efa_enabled_patterns = [
+        # TODO: p6-b200 isn't supported yet in gpuhunt
+        r"^p6-b200\.(48xlarge)$",
+        r"^p5\.(48xlarge)$",
+        r"^p5e\.(48xlarge)$",
+        r"^p5en\.(48xlarge)$",
+        r"^p4d\.(24xlarge)$",
+        r"^p4de\.(24xlarge)$",
+        r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
+        r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
+        r"^gr6\.8xlarge$",
+        r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
+        r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$",
+        r"^p3dn\.(24xlarge)$",
+    ]
+
+    is_efa_enabled = any(re.match(pattern, instance_type) for pattern in efa_enabled_patterns)
+    if not is_efa_enabled:
+        return image_name
+
+    if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"):
+        return image_name
+
+    if image_name.endswith(f"-base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"):
+        return image_name[:-17] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"
+    elif image_name.endswith(f"-devel-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"):
+        return image_name[:-18] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"
+
+    return image_name
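Note: illustrative rewrites performed by _patch_base_image_for_aws_efa (assuming settings.DSTACK_BASE_IMAGE is "dstackai/base" and DSTACK_BASE_IMAGE_UBUNTU_VERSION is a 5-character version like "22.04"; the hardcoded [:-17] and [:-18] slices are exactly len("-base-ubuntu22.04") and len("-devel-ubuntu22.04")):

    # On an EFA-capable instance type such as p5.48xlarge:
    #   "dstackai/base:0.7-base-ubuntu22.04"  -> "dstackai/base:0.7-devel-efa-ubuntu22.04"
    #   "dstackai/base:0.7-devel-ubuntu22.04" -> "dstackai/base:0.7-devel-efa-ubuntu22.04"
    # Non-AWS backends, non-EFA instance types, and custom images pass through unchanged.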
@@ -1,18 +1,17 @@
 import asyncio
 import datetime
-import itertools
 from typing import List, Optional, Set, Tuple
 
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload, selectinload
 
-import dstack._internal.server.services.gateways as gateways
 import dstack._internal.server.services.services.autoscalers as autoscalers
 from dstack._internal.core.errors import ServerError
 from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
 from dstack._internal.core.models.runs import (
     Job,
+    JobSpec,
     JobStatus,
     JobTerminationReason,
     Run,
@@ -24,22 +23,23 @@ from dstack._internal.server.db import get_session_ctx
 from dstack._internal.server.models import JobModel, ProjectModel, RunModel
 from dstack._internal.server.services.jobs import (
     find_job,
-    get_jobs_from_run_spec,
+    get_job_specs_from_run_spec,
     group_jobs_by_replica_latest,
 )
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.runs import (
-    create_job_model_for_new_submission,
     fmt,
     process_terminating_run,
     retry_run_replica_jobs,
     run_model_to_run,
     scale_run_replicas,
 )
+from dstack._internal.server.services.services import update_service_desired_replica_count
 from dstack._internal.utils import common
 from dstack._internal.utils.logging import get_logger
 
 logger = get_logger(__name__)
+ROLLING_DEPLOYMENT_MAX_SURGE = 1  # at most one extra replica during rolling deployment
 
 
 async def process_runs(batch_size: int = 1):
@@ -133,46 +133,22 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
         logger.debug("%s: pending run is not yet ready for resubmission", fmt(run_model))
         return
 
-    # TODO(egor-s) consolidate with `scale_run_replicas` if possible
-    replicas = 1
+    run_model.desired_replica_count = 1
     if run.run_spec.configuration.type == "service":
-        replicas = run.run_spec.configuration.replicas.min or 0  # new default
-        scaler = autoscalers.get_service_scaler(run.run_spec.configuration)
-        stats = None
-        if run_model.gateway_id is not None:
-            conn = await gateways.get_or_add_gateway_connection(session, run_model.gateway_id)
-            stats = await conn.get_stats(run_model.project.name, run_model.run_name)
-        # replicas info doesn't matter for now
-        replicas = scaler.scale([], stats)
-    if replicas == 0:
+        run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0
+        await update_service_desired_replica_count(
+            session,
+            run_model,
+            run.run_spec.configuration,
+            # does not matter for pending services, since 0->n scaling should happen without delay
+            last_scaled_at=None,
+        )
+
+    if run_model.desired_replica_count == 0:
         # stay zero scaled
         return
 
-    scheduled_replicas = 0
-    # Resubmit existing replicas
-    for replica_num, replica_jobs in itertools.groupby(
-        run.jobs, key=lambda j: j.job_spec.replica_num
-    ):
-        if scheduled_replicas >= replicas:
-            break
-        scheduled_replicas += 1
-        for job in replica_jobs:
-            new_job_model = create_job_model_for_new_submission(
-                run_model=run_model,
-                job=job,
-                status=JobStatus.SUBMITTED,
-            )
-            session.add(new_job_model)
-    # Create missing replicas
-    for replica_num in range(scheduled_replicas, replicas):
-        jobs = await get_jobs_from_run_spec(run.run_spec, replica_num=replica_num)
-        for job in jobs:
-            job_model = create_job_model_for_new_submission(
-                run_model=run_model,
-                job=job,
-                status=JobStatus.SUBMITTED,
-            )
-            session.add(job_model)
+    await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count)
 
     run_model.status = RunStatus.SUBMITTED
     logger.info("%s: run status has changed PENDING -> SUBMITTED", fmt(run_model))
@@ -340,27 +316,11 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
                 job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
 
     if new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}:
-        # No need to retry if the run is terminating,
+        # No need to retry, scale, or redeploy replicas if the run is terminating,
         # pending run will retry replicas in `process_pending_run`
-        for _, replica_jobs in replicas_to_retry:
-            await retry_run_replica_jobs(
-                session, run_model, replica_jobs, only_failed=retry_single_job
-            )
-
-        if run_spec.configuration.type == "service":
-            scaler = autoscalers.get_service_scaler(run_spec.configuration)
-            stats = None
-            if run_model.gateway_id is not None:
-                conn = await gateways.get_or_add_gateway_connection(session, run_model.gateway_id)
-                stats = await conn.get_stats(run_model.project.name, run_model.run_name)
-            # use replicas_info from before retrying
-            replicas_diff = scaler.scale(replicas_info, stats)
-            if replicas_diff != 0:
-                # FIXME: potentially long write transaction
-                # Why do we flush here?
-                await session.flush()
-                await session.refresh(run_model)
-                await scale_run_replicas(session, run_model, replicas_diff)
+        await _handle_run_replicas(
+            session, run_model, run_spec, replicas_to_retry, retry_single_job, replicas_info
+        )
 
     if run_model.status != new_status:
         logger.info(
@@ -378,6 +338,130 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
     run_model.resubmission_attempt += 1
 
 
+async def _handle_run_replicas(
+    session: AsyncSession,
+    run_model: RunModel,
+    run_spec: RunSpec,
+    replicas_to_retry: list[tuple[int, list[JobModel]]],
+    retry_single_job: bool,
+    replicas_info: list[autoscalers.ReplicaInfo],
+) -> None:
+    """
+    Does ONE of:
+    - replica retry
+    - replica scaling
+    - replica rolling deployment
+
+    Does not do everything at once to avoid conflicts between the stages and long DB transactions.
+    """
+
+    if replicas_to_retry:
+        for _, replica_jobs in replicas_to_retry:
+            await retry_run_replica_jobs(
+                session, run_model, replica_jobs, only_failed=retry_single_job
+            )
+        return
+
+    if run_spec.configuration.type == "service":
+        await update_service_desired_replica_count(
+            session,
+            run_model,
+            run_spec.configuration,
+            # FIXME: should only include scaling events, not retries and deployments
+            last_scaled_at=max((r.timestamp for r in replicas_info), default=None),
+        )
+
+    max_replica_count = run_model.desired_replica_count
+    if _has_out_of_date_replicas(run_model):
+        # allow extra replicas when deployment is in progress
+        max_replica_count += ROLLING_DEPLOYMENT_MAX_SURGE
+
+    active_replica_count = sum(1 for r in replicas_info if r.active)
+    if active_replica_count not in range(run_model.desired_replica_count, max_replica_count + 1):
+        await scale_run_replicas(
+            session,
+            run_model,
+            replicas_diff=run_model.desired_replica_count - active_replica_count,
+        )
+        return
+
+    await _update_jobs_to_new_deployment_in_place(run_model, run_spec)
+    if _has_out_of_date_replicas(run_model):
+        non_terminated_replica_count = len(
+            {j.replica_num for j in run_model.jobs if not j.status.is_finished()}
+        )
+        # Avoid using too much hardware during a deployment - never have
+        # more than max_replica_count non-terminated replicas.
+        if non_terminated_replica_count < max_replica_count:
+            # Start more up-to-date replicas that will eventually replace out-of-date replicas.
+            await scale_run_replicas(
+                session,
+                run_model,
+                replicas_diff=max_replica_count - non_terminated_replica_count,
+            )
+
+        replicas_to_stop_count = 0
+        # stop any out-of-date replicas that are not running
+        replicas_to_stop_count += len(
+            {
+                j.replica_num
+                for j in run_model.jobs
+                if j.status
+                not in [JobStatus.RUNNING, JobStatus.TERMINATING] + JobStatus.finished_statuses()
+                and j.deployment_num < run_model.deployment_num
+            }
+        )
+        running_replica_count = len(
+            {j.replica_num for j in run_model.jobs if j.status == JobStatus.RUNNING}
+        )
+        if running_replica_count > run_model.desired_replica_count:
+            # stop excessive running out-of-date replicas
+            replicas_to_stop_count += running_replica_count - run_model.desired_replica_count
+        if replicas_to_stop_count:
+            await scale_run_replicas(
+                session,
+                run_model,
+                replicas_diff=-replicas_to_stop_count,
+            )
+
+
+async def _update_jobs_to_new_deployment_in_place(run_model: RunModel, run_spec: RunSpec) -> None:
+    """
+    Bump deployment_num for jobs that do not require redeployment.
+    """
+
+    for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
+        if all(j.status.is_finished() for j in job_models):
+            continue
+        if all(j.deployment_num == run_model.deployment_num for j in job_models):
+            continue
+        new_job_specs = await get_job_specs_from_run_spec(
+            run_spec=run_spec,
+            replica_num=replica_num,
+        )
+        assert len(new_job_specs) == len(job_models), (
+            "Changing the number of jobs within a replica is not yet supported"
+        )
+        can_update_all_jobs = True
+        for old_job_model, new_job_spec in zip(job_models, new_job_specs):
+            old_job_spec = JobSpec.__response__.parse_raw(old_job_model.job_spec_data)
+            if new_job_spec != old_job_spec:
+                can_update_all_jobs = False
+                break
+        if can_update_all_jobs:
+            for job_model in job_models:
+                job_model.deployment_num = run_model.deployment_num
+
+
+def _has_out_of_date_replicas(run: RunModel) -> bool:
+    for job in run.jobs:
+        if job.deployment_num < run.deployment_num and not (
+            job.status.is_finished() or job.termination_reason == JobTerminationReason.SCALED_DOWN
+        ):
+            return True
+    return False
+
+
 def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]:
     """
     Checks if the job should be retried.
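Note: a worked example of the surge arithmetic in _handle_run_replicas (illustrative numbers):

    # desired_replica_count = 2, new deployment in progress:
    desired, surge = 2, 1                     # ROLLING_DEPLOYMENT_MAX_SURGE = 1
    max_replicas = desired + surge            # 3 while out-of-date replicas exist
    non_terminated = 2                        # both replicas still on the old deployment
    to_start = max_replicas - non_terminated  # start 1 up-to-date replica
    # Once the new replica is RUNNING, running (3) > desired (2), so one
    # out-of-date replica is stopped; the cycle repeats until none remain.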
@@ -393,7 +477,8 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]:
             break
 
     if (
-        job_model.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
+        job_model.termination_reason is not None
+        and job_model.termination_reason.to_retry_event() == RetryEvent.NO_CAPACITY
         and last_provisioned_submission is None
        and RetryEvent.NO_CAPACITY in job.job_spec.retry.on_events
    ):
@@ -403,24 +488,9 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]:
         return None
 
     if (
-        last_provisioned_submission.termination_reason
-        == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
-        and RetryEvent.INTERRUPTION in job.job_spec.retry.on_events
-    ):
-        return common.get_current_datetime() - last_provisioned_submission.last_processed_at
-
-    if (
-        last_provisioned_submission.termination_reason
-        in [
-            JobTerminationReason.CONTAINER_EXITED_WITH_ERROR,
-            JobTerminationReason.CREATING_CONTAINER_ERROR,
-            JobTerminationReason.EXECUTOR_ERROR,
-            JobTerminationReason.GATEWAY_ERROR,
-            JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED,
-            JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED,
-            JobTerminationReason.PORTS_BINDING_FAILED,
-        ]
-        and RetryEvent.ERROR in job.job_spec.retry.on_events
+        last_provisioned_submission.termination_reason is not None
+        and last_provisioned_submission.termination_reason.to_retry_event()
+        in job.job_spec.retry.on_events
     ):
         return common.get_current_datetime() - last_provisioned_submission.last_processed_at
 
@@ -0,0 +1,42 @@
+"""Add rolling deployment fields
+
+Revision ID: 35e90e1b0d3e
+Revises: 35f732ee4cf5
+Create Date: 2025-05-29 15:30:27.878569
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "35e90e1b0d3e"
+down_revision = "35f732ee4cf5"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True))
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.execute("UPDATE jobs SET deployment_num = 0")
+        batch_op.alter_column("deployment_num", nullable=False)
+
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True))
+        batch_op.add_column(sa.Column("desired_replica_count", sa.Integer(), nullable=True))
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.execute("UPDATE runs SET deployment_num = 0")
+        batch_op.execute("UPDATE runs SET desired_replica_count = 1")
+        batch_op.alter_column("deployment_num", nullable=False)
+        batch_op.alter_column("desired_replica_count", nullable=False)
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.drop_column("deployment_num")
+        batch_op.drop_column("desired_replica_count")
+
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.drop_column("deployment_num")
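Note: the upgrade uses the SQLite-friendly pattern of adding each column as nullable, backfilling it, then flipping it to NOT NULL inside a batch_alter_table block. The dstack server applies migrations on startup; to run them by hand, the generic Alembic invocation would be roughly (hypothetical alembic.ini path, not a dstack-specific CLI):

    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")   # assumes a config pointing at the server DB
    command.upgrade(cfg, "head")  # applies 35f732ee4cf5, then 35e90e1b0d3e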