dstack 0.19.12rc1__py3-none-any.whl → 0.19.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dstack might be problematic.

Files changed (62)
  1. dstack/_internal/cli/commands/attach.py +4 -4
  2. dstack/_internal/cli/services/configurators/run.py +44 -47
  3. dstack/_internal/cli/utils/run.py +31 -31
  4. dstack/_internal/core/backends/aws/compute.py +22 -9
  5. dstack/_internal/core/backends/aws/resources.py +26 -0
  6. dstack/_internal/core/backends/base/offers.py +0 -1
  7. dstack/_internal/core/backends/template/configurator.py.jinja +1 -6
  8. dstack/_internal/core/backends/template/models.py.jinja +4 -0
  9. dstack/_internal/core/compatibility/__init__.py +0 -0
  10. dstack/_internal/core/compatibility/fleets.py +72 -0
  11. dstack/_internal/core/compatibility/gateways.py +34 -0
  12. dstack/_internal/core/compatibility/runs.py +131 -0
  13. dstack/_internal/core/compatibility/volumes.py +32 -0
  14. dstack/_internal/core/models/configurations.py +1 -1
  15. dstack/_internal/core/models/fleets.py +6 -1
  16. dstack/_internal/core/models/instances.py +51 -12
  17. dstack/_internal/core/models/profiles.py +43 -3
  18. dstack/_internal/core/models/projects.py +1 -0
  19. dstack/_internal/core/models/repos/local.py +3 -3
  20. dstack/_internal/core/models/runs.py +139 -43
  21. dstack/_internal/server/app.py +46 -1
  22. dstack/_internal/server/background/tasks/process_running_jobs.py +92 -15
  23. dstack/_internal/server/background/tasks/process_runs.py +163 -80
  24. dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py +42 -0
  25. dstack/_internal/server/migrations/versions/35f732ee4cf5_add_projectmodel_is_public.py +39 -0
  26. dstack/_internal/server/models.py +4 -0
  27. dstack/_internal/server/routers/projects.py +4 -3
  28. dstack/_internal/server/routers/prometheus.py +4 -1
  29. dstack/_internal/server/schemas/projects.py +1 -0
  30. dstack/_internal/server/security/permissions.py +36 -0
  31. dstack/_internal/server/services/jobs/__init__.py +1 -0
  32. dstack/_internal/server/services/jobs/configurators/base.py +11 -7
  33. dstack/_internal/server/services/projects.py +54 -1
  34. dstack/_internal/server/services/runner/client.py +4 -1
  35. dstack/_internal/server/services/runs.py +49 -29
  36. dstack/_internal/server/services/services/__init__.py +19 -0
  37. dstack/_internal/server/services/services/autoscalers.py +37 -26
  38. dstack/_internal/server/services/storage/__init__.py +38 -0
  39. dstack/_internal/server/services/storage/base.py +27 -0
  40. dstack/_internal/server/services/storage/gcs.py +44 -0
  41. dstack/_internal/server/services/{storage.py → storage/s3.py} +4 -27
  42. dstack/_internal/server/settings.py +7 -3
  43. dstack/_internal/server/statics/index.html +1 -1
  44. dstack/_internal/server/statics/{main-5b9786c955b42bf93581.js → main-0ac1e1583684417ae4d1.js} +1695 -62
  45. dstack/_internal/server/statics/{main-5b9786c955b42bf93581.js.map → main-0ac1e1583684417ae4d1.js.map} +1 -1
  46. dstack/_internal/server/statics/{main-8f9c66f404e9c7e7e020.css → main-f39c418b05fe14772dd8.css} +1 -1
  47. dstack/_internal/server/testing/common.py +11 -1
  48. dstack/_internal/settings.py +3 -0
  49. dstack/_internal/utils/common.py +4 -0
  50. dstack/api/_public/runs.py +14 -5
  51. dstack/api/server/_fleets.py +9 -69
  52. dstack/api/server/_gateways.py +3 -14
  53. dstack/api/server/_projects.py +2 -2
  54. dstack/api/server/_runs.py +4 -116
  55. dstack/api/server/_volumes.py +3 -14
  56. dstack/plugins/builtin/rest_plugin/_plugin.py +24 -5
  57. dstack/version.py +2 -2
  58. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/METADATA +1 -1
  59. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/RECORD +62 -52
  60. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/WHEEL +0 -0
  61. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/entry_points.txt +0 -0
  62. {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/licenses/LICENSE.md +0 -0

dstack/_internal/server/app.py
@@ -11,6 +11,7 @@ from fastapi import FastAPI, Request, Response, status
 from fastapi.datastructures import URL
 from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
 from fastapi.staticfiles import StaticFiles
+from prometheus_client import Counter, Histogram

 from dstack._internal.cli.utils.common import console
 from dstack._internal.core.errors import ForbiddenError, ServerClientError
@@ -63,6 +64,18 @@ from dstack._internal.utils.ssh import check_required_ssh_version

 logger = get_logger(__name__)

+# Server HTTP metrics
+REQUESTS_TOTAL = Counter(
+    "dstack_server_requests_total",
+    "Total number of HTTP requests",
+    ["method", "endpoint", "http_status", "project_name"],
+)
+REQUEST_DURATION = Histogram(
+    "dstack_server_request_duration_seconds",
+    "HTTP request duration in seconds",
+    ["method", "endpoint", "http_status", "project_name"],
+)
+

 def create_app() -> FastAPI:
     if settings.SENTRY_DSN is not None:
@@ -128,7 +141,7 @@ async def lifespan(app: FastAPI):
         yes=UPDATE_DEFAULT_PROJECT,
         no=DO_NOT_UPDATE_DEFAULT_PROJECT,
     )
-    if settings.SERVER_BUCKET is not None:
+    if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None:
        init_default_storage()
    scheduler = start_background_tasks()
    dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
@@ -216,6 +229,8 @@ def register_routes(app: FastAPI, ui: bool = True):
         start_time = time.time()
         response: Response = await call_next(request)
         process_time = time.time() - start_time
+        # log process_time to be used in the log_http_metrics middleware
+        request.state.process_time = process_time
         logger.debug(
             "Processed request %s %s in %s. Status: %s",
             request.method,
@@ -225,6 +240,36 @@ def register_routes(app: FastAPI, ui: bool = True):
         )
         return response

+    # this middleware must be defined after the log_request middleware
+    @app.middleware("http")
+    async def log_http_metrics(request: Request, call_next):
+        def _extract_project_name(request: Request):
+            project_name = None
+            prefix = "/api/project/"
+            if request.url.path.startswith(prefix):
+                rest = request.url.path[len(prefix) :]
+                project_name = rest.split("/", 1)[0] if rest else None
+
+            return project_name
+
+        project_name = _extract_project_name(request)
+        response: Response = await call_next(request)
+
+        REQUEST_DURATION.labels(
+            method=request.method,
+            endpoint=request.url.path,
+            http_status=response.status_code,
+            project_name=project_name,
+        ).observe(request.state.process_time)
+
+        REQUESTS_TOTAL.labels(
+            method=request.method,
+            endpoint=request.url.path,
+            http_status=response.status_code,
+            project_name=project_name,
+        ).inc()
+        return response
+
     @app.middleware("http")
     async def check_client_version(request: Request, call_next):
         if (
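
Note on the middleware pair above: log_request measures process_time and stashes it on request.state, and log_http_metrics reads it after the response is produced, which is why they must be registered in the order shown. Below is a minimal, standalone sketch of the prometheus_client pattern used here; the metric names and label values are illustrative, not dstack's:

import time

from prometheus_client import Counter, Histogram, generate_latest

REQUESTS = Counter(
    "example_requests_total",
    "Total number of HTTP requests",
    ["method", "endpoint", "http_status"],
)
LATENCY = Histogram(
    "example_request_duration_seconds",
    "HTTP request duration in seconds",
    ["method", "endpoint", "http_status"],
)

start = time.time()
# ... handle the request here ...
elapsed = time.time() - start

# prometheus_client stringifies label values, so passing an int status is fine.
labels = {"method": "GET", "endpoint": "/api/project/main/runs", "http_status": 200}
LATENCY.labels(**labels).observe(elapsed)  # feeds the duration histogram buckets
REQUESTS.labels(**labels).inc()  # bumps the per-label request counter

print(generate_latest().decode())  # Prometheus text exposition format

The exposition side in dstack itself sits behind dstack/_internal/server/routers/prometheus.py, also touched in this release.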

dstack/_internal/server/background/tasks/process_running_jobs.py
@@ -1,4 +1,5 @@
 import asyncio
+import re
 from collections.abc import Iterable
 from datetime import timedelta, timezone
 from typing import Dict, List, Optional
@@ -7,6 +8,7 @@ from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload

+from dstack._internal import settings
 from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
 from dstack._internal.core.errors import GatewayError
 from dstack._internal.core.models.backends.base import BackendType
@@ -18,6 +20,7 @@ from dstack._internal.core.models.instances import (
     SSHConnectionParams,
 )
 from dstack._internal.core.models.metrics import Metric
+from dstack._internal.core.models.profiles import StartupOrder
 from dstack._internal.core.models.repos import RemoteRepoCreds
 from dstack._internal.core.models.runs import (
     ClusterInfo,
@@ -184,18 +187,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     if job_provisioning_data.hostname is None:
         await _wait_for_instance_provisioning_data(job_model=job_model)
     else:
-        # Wait until all other jobs in the replica have IPs assigned.
-        # This is needed to ensure cluster_info has all IPs set.
-        for other_job in run.jobs:
-            if (
-                other_job.job_spec.replica_num == job.job_spec.replica_num
-                and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
-                and other_job.job_submissions[-1].job_provisioning_data is not None
-                and other_job.job_submissions[-1].job_provisioning_data.hostname is None
-            ):
-                job_model.last_processed_at = common_utils.get_current_datetime()
-                await session.commit()
-                return
+        if _should_wait_for_other_nodes(run, job, job_model):
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return

     # fails are acceptable until timeout is exceeded
     if job_provisioning_data.dockerized:
@@ -406,6 +401,48 @@ async def _wait_for_instance_provisioning_data(job_model: JobModel):
     job_model.job_provisioning_data = job_model.instance.job_provisioning_data


+def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool:
+    for other_job in run.jobs:
+        if (
+            other_job.job_spec.replica_num == job.job_spec.replica_num
+            and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
+            and other_job.job_submissions[-1].job_provisioning_data is not None
+            and other_job.job_submissions[-1].job_provisioning_data.hostname is None
+        ):
+            logger.debug(
+                "%s: waiting for other job to have IP assigned",
+                fmt(job_model),
+            )
+            return True
+    master_job = find_job(run.jobs, job.job_spec.replica_num, 0)
+    if (
+        job.job_spec.job_num != 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.MASTER_FIRST
+        and master_job.job_submissions[-1].status != JobStatus.RUNNING
+    ):
+        logger.debug(
+            "%s: waiting for master job to become running",
+            fmt(job_model),
+        )
+        return True
+    if (
+        job.job_spec.job_num == 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.WORKERS_FIRST
+    ):
+        for other_job in run.jobs:
+            if (
+                other_job.job_spec.replica_num == job.job_spec.replica_num
+                and other_job.job_spec.job_num != job.job_spec.job_num
+                and other_job.job_submissions[-1].status != JobStatus.RUNNING
+            ):
+                logger.debug(
+                    "%s: waiting for worker job to become running",
+                    fmt(job_model),
+                )
+                return True
+    return False
+
+
 @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
 def _process_provisioning_with_shim(
     ports: Dict[int, int],
@@ -482,14 +519,14 @@ def _process_provisioning_with_shim(
     cpu = None
     memory = None
     network_mode = NetworkMode.HOST
-
+    image_name = _patch_base_image_for_aws_efa(job_spec, job_provisioning_data)
     if shim_client.is_api_v2_supported():
         shim_client.submit_task(
             task_id=job_model.id,
             name=job_model.job_name,
             registry_username=registry_username,
             registry_password=registry_password,
-            image_name=job_spec.image_name,
+            image_name=image_name,
             container_user=container_user,
             privileged=job_spec.privileged,
             gpu=gpu,
@@ -510,7 +547,7 @@ def _process_provisioning_with_shim(
         submitted = shim_client.submit(
             username=registry_username,
             password=registry_password,
-            image_name=job_spec.image_name,
+            image_name=image_name,
             privileged=job_spec.privileged,
             container_name=job_model.job_name,
             container_user=container_user,
@@ -934,3 +971,43 @@ def _get_instance_specific_gpu_devices(
         GPUDevice(path_on_host="/dev/nvidiactl", path_in_container="/dev/nvidiactl")
     )
     return gpu_devices
+
+
+def _patch_base_image_for_aws_efa(
+    job_spec: JobSpec, job_provisioning_data: JobProvisioningData
+) -> str:
+    image_name = job_spec.image_name
+
+    if job_provisioning_data.backend != BackendType.AWS:
+        return image_name
+
+    instance_type = job_provisioning_data.instance_type.name
+    efa_enabled_patterns = [
+        # TODO: p6-b200 isn't supported yet in gpuhunt
+        r"^p6-b200\.(48xlarge)$",
+        r"^p5\.(48xlarge)$",
+        r"^p5e\.(48xlarge)$",
+        r"^p5en\.(48xlarge)$",
+        r"^p4d\.(24xlarge)$",
+        r"^p4de\.(24xlarge)$",
+        r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
+        r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
+        r"^gr6\.8xlarge$",
+        r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
+        r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$",
+        r"^p3dn\.(24xlarge)$",
+    ]
+
+    is_efa_enabled = any(re.match(pattern, instance_type) for pattern in efa_enabled_patterns)
+    if not is_efa_enabled:
+        return image_name
+
+    if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"):
+        return image_name
+
+    if image_name.endswith(f"-base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"):
+        return image_name[:-17] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"
+    elif image_name.endswith(f"-devel-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"):
+        return image_name[:-18] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"
+
+    return image_name
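
_patch_base_image_for_aws_efa above swaps the image-tag suffix so that jobs landing on EFA-capable AWS instance types run the -devel-efa flavor of the dstack base image. The hard-coded slices match the suffix lengths for a five-character Ubuntu version such as 22.04: len("-base-ubuntu22.04") == 17 and len("-devel-ubuntu22.04") == 18. A standalone sketch of the same rewrite, with hypothetical values standing in for the real settings:

# Hypothetical stand-ins for illustration; the real values come from
# dstack._internal.settings.
BASE_IMAGE = "dstackai/base"
UBUNTU_VERSION = "22.04"


def patch_for_efa(image_name: str) -> str:
    # Only dstack base images are rewritten; user-supplied images pass through.
    if not image_name.startswith(f"{BASE_IMAGE}:"):
        return image_name
    for suffix in (f"-base-ubuntu{UBUNTU_VERSION}", f"-devel-ubuntu{UBUNTU_VERSION}"):
        if image_name.endswith(suffix):
            # Replace the flavor suffix with the EFA-enabled devel variant.
            return image_name[: -len(suffix)] + f"-devel-efa-ubuntu{UBUNTU_VERSION}"
    return image_name


# e.g. "dstackai/base:py3.12-0.7-base-ubuntu22.04"
#  ->  "dstackai/base:py3.12-0.7-devel-efa-ubuntu22.04"
print(patch_for_efa("dstackai/base:py3.12-0.7-base-ubuntu22.04"))

Computing -len(suffix), as in the sketch, would keep the slice correct if the Ubuntu version string ever changed length, whereas the -17/-18 constants assume it stays five characters.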

dstack/_internal/server/background/tasks/process_runs.py
@@ -1,18 +1,17 @@
 import asyncio
 import datetime
-import itertools
 from typing import List, Optional, Set, Tuple

 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload, selectinload

-import dstack._internal.server.services.gateways as gateways
 import dstack._internal.server.services.services.autoscalers as autoscalers
 from dstack._internal.core.errors import ServerError
-from dstack._internal.core.models.profiles import RetryEvent
+from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
 from dstack._internal.core.models.runs import (
     Job,
+    JobSpec,
     JobStatus,
     JobTerminationReason,
     Run,
@@ -24,22 +23,23 @@ from dstack._internal.server.db import get_session_ctx
 from dstack._internal.server.models import JobModel, ProjectModel, RunModel
 from dstack._internal.server.services.jobs import (
     find_job,
-    get_jobs_from_run_spec,
+    get_job_specs_from_run_spec,
     group_jobs_by_replica_latest,
 )
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.runs import (
-    create_job_model_for_new_submission,
     fmt,
     process_terminating_run,
     retry_run_replica_jobs,
     run_model_to_run,
     scale_run_replicas,
 )
+from dstack._internal.server.services.services import update_service_desired_replica_count
 from dstack._internal.utils import common
 from dstack._internal.utils.logging import get_logger

 logger = get_logger(__name__)
+ROLLING_DEPLOYMENT_MAX_SURGE = 1  # at most one extra replica during rolling deployment


 async def process_runs(batch_size: int = 1):
@@ -133,46 +133,22 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
         logger.debug("%s: pending run is not yet ready for resubmission", fmt(run_model))
         return

-    # TODO(egor-s) consolidate with `scale_run_replicas` if possible
-    replicas = 1
+    run_model.desired_replica_count = 1
     if run.run_spec.configuration.type == "service":
-        replicas = run.run_spec.configuration.replicas.min or 0  # new default
-        scaler = autoscalers.get_service_scaler(run.run_spec.configuration)
-        stats = None
-        if run_model.gateway_id is not None:
-            conn = await gateways.get_or_add_gateway_connection(session, run_model.gateway_id)
-            stats = await conn.get_stats(run_model.project.name, run_model.run_name)
-        # replicas info doesn't matter for now
-        replicas = scaler.scale([], stats)
-    if replicas == 0:
+        run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0
+        await update_service_desired_replica_count(
+            session,
+            run_model,
+            run.run_spec.configuration,
+            # does not matter for pending services, since 0->n scaling should happen without delay
+            last_scaled_at=None,
+        )
+
+    if run_model.desired_replica_count == 0:
         # stay zero scaled
         return

-    scheduled_replicas = 0
-    # Resubmit existing replicas
-    for replica_num, replica_jobs in itertools.groupby(
-        run.jobs, key=lambda j: j.job_spec.replica_num
-    ):
-        if scheduled_replicas >= replicas:
-            break
-        scheduled_replicas += 1
-        for job in replica_jobs:
-            new_job_model = create_job_model_for_new_submission(
-                run_model=run_model,
-                job=job,
-                status=JobStatus.SUBMITTED,
-            )
-            session.add(new_job_model)
-    # Create missing replicas
-    for replica_num in range(scheduled_replicas, replicas):
-        jobs = await get_jobs_from_run_spec(run.run_spec, replica_num=replica_num)
-        for job in jobs:
-            job_model = create_job_model_for_new_submission(
-                run_model=run_model,
-                job=job,
-                status=JobStatus.SUBMITTED,
-            )
-            session.add(job_model)
+    await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count)

     run_model.status = RunStatus.SUBMITTED
     logger.info("%s: run status has changed PENDING -> SUBMITTED", fmt(run_model))
@@ -313,6 +289,10 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED
         else:
             raise ValueError(f"Unexpected termination reason {run_termination_reasons}")
+    elif _should_stop_on_master_done(run):
+        new_status = RunStatus.TERMINATING
+        # ALL_JOBS_DONE is used for all DONE reasons including master-done
+        termination_reason = RunTerminationReason.ALL_JOBS_DONE
     elif RunStatus.RUNNING in run_statuses:
         new_status = RunStatus.RUNNING
     elif RunStatus.PROVISIONING in run_statuses:
@@ -336,27 +316,11 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
                 job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER

     if new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}:
-        # No need to retry if the run is terminating,
+        # No need to retry, scale, or redeploy replicas if the run is terminating,
         # pending run will retry replicas in `process_pending_run`
-        for _, replica_jobs in replicas_to_retry:
-            await retry_run_replica_jobs(
-                session, run_model, replica_jobs, only_failed=retry_single_job
-            )
-
-        if run_spec.configuration.type == "service":
-            scaler = autoscalers.get_service_scaler(run_spec.configuration)
-            stats = None
-            if run_model.gateway_id is not None:
-                conn = await gateways.get_or_add_gateway_connection(session, run_model.gateway_id)
-                stats = await conn.get_stats(run_model.project.name, run_model.run_name)
-            # use replicas_info from before retrying
-            replicas_diff = scaler.scale(replicas_info, stats)
-            if replicas_diff != 0:
-                # FIXME: potentially long write transaction
-                # Why do we flush here?
-                await session.flush()
-                await session.refresh(run_model)
-                await scale_run_replicas(session, run_model, replicas_diff)
+        await _handle_run_replicas(
+            session, run_model, run_spec, replicas_to_retry, retry_single_job, replicas_info
+        )

     if run_model.status != new_status:
         logger.info(
@@ -374,6 +338,130 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
         run_model.resubmission_attempt += 1


+async def _handle_run_replicas(
+    session: AsyncSession,
+    run_model: RunModel,
+    run_spec: RunSpec,
+    replicas_to_retry: list[tuple[int, list[JobModel]]],
+    retry_single_job: bool,
+    replicas_info: list[autoscalers.ReplicaInfo],
+) -> None:
+    """
+    Does ONE of:
+    - replica retry
+    - replica scaling
+    - replica rolling deployment
+
+    Does not do everything at once to avoid conflicts between the stages and long DB transactions.
+    """
+
+    if replicas_to_retry:
+        for _, replica_jobs in replicas_to_retry:
+            await retry_run_replica_jobs(
+                session, run_model, replica_jobs, only_failed=retry_single_job
+            )
+        return
+
+    if run_spec.configuration.type == "service":
+        await update_service_desired_replica_count(
+            session,
+            run_model,
+            run_spec.configuration,
+            # FIXME: should only include scaling events, not retries and deployments
+            last_scaled_at=max((r.timestamp for r in replicas_info), default=None),
+        )
+
+    max_replica_count = run_model.desired_replica_count
+    if _has_out_of_date_replicas(run_model):
+        # allow extra replicas when deployment is in progress
+        max_replica_count += ROLLING_DEPLOYMENT_MAX_SURGE
+
+    active_replica_count = sum(1 for r in replicas_info if r.active)
+    if active_replica_count not in range(run_model.desired_replica_count, max_replica_count + 1):
+        await scale_run_replicas(
+            session,
+            run_model,
+            replicas_diff=run_model.desired_replica_count - active_replica_count,
+        )
+        return
+
+    await _update_jobs_to_new_deployment_in_place(run_model, run_spec)
+    if _has_out_of_date_replicas(run_model):
+        non_terminated_replica_count = len(
+            {j.replica_num for j in run_model.jobs if not j.status.is_finished()}
+        )
+        # Avoid using too much hardware during a deployment - never have
+        # more than max_replica_count non-terminated replicas.
+        if non_terminated_replica_count < max_replica_count:
+            # Start more up-to-date replicas that will eventually replace out-of-date replicas.
+            await scale_run_replicas(
+                session,
+                run_model,
+                replicas_diff=max_replica_count - non_terminated_replica_count,
+            )

+        replicas_to_stop_count = 0
+        # stop any out-of-date replicas that are not running
+        replicas_to_stop_count += len(
+            {
+                j.replica_num
+                for j in run_model.jobs
+                if j.status
+                not in [JobStatus.RUNNING, JobStatus.TERMINATING] + JobStatus.finished_statuses()
+                and j.deployment_num < run_model.deployment_num
+            }
+        )
+        running_replica_count = len(
+            {j.replica_num for j in run_model.jobs if j.status == JobStatus.RUNNING}
+        )
+        if running_replica_count > run_model.desired_replica_count:
+            # stop excessive running out-of-date replicas
+            replicas_to_stop_count += running_replica_count - run_model.desired_replica_count
+        if replicas_to_stop_count:
+            await scale_run_replicas(
+                session,
+                run_model,
+                replicas_diff=-replicas_to_stop_count,
+            )
+
+
+async def _update_jobs_to_new_deployment_in_place(run_model: RunModel, run_spec: RunSpec) -> None:
+    """
+    Bump deployment_num for jobs that do not require redeployment.
+    """
+
+    for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
+        if all(j.status.is_finished() for j in job_models):
+            continue
+        if all(j.deployment_num == run_model.deployment_num for j in job_models):
+            continue
+        new_job_specs = await get_job_specs_from_run_spec(
+            run_spec=run_spec,
+            replica_num=replica_num,
+        )
+        assert len(new_job_specs) == len(job_models), (
+            "Changing the number of jobs within a replica is not yet supported"
+        )
+        can_update_all_jobs = True
+        for old_job_model, new_job_spec in zip(job_models, new_job_specs):
+            old_job_spec = JobSpec.__response__.parse_raw(old_job_model.job_spec_data)
+            if new_job_spec != old_job_spec:
+                can_update_all_jobs = False
+                break
+        if can_update_all_jobs:
+            for job_model in job_models:
+                job_model.deployment_num = run_model.deployment_num
+
+
+def _has_out_of_date_replicas(run: RunModel) -> bool:
+    for job in run.jobs:
+        if job.deployment_num < run.deployment_num and not (
+            job.status.is_finished() or job.termination_reason == JobTerminationReason.SCALED_DOWN
+        ):
+            return True
+    return False
+
+
 def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]:
     """
     Checks if the job should be retried.
@@ -389,7 +477,8 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datet
             break

     if (
-        job_model.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
+        job_model.termination_reason is not None
+        and job_model.termination_reason.to_retry_event() == RetryEvent.NO_CAPACITY
         and last_provisioned_submission is None
         and RetryEvent.NO_CAPACITY in job.job_spec.retry.on_events
     ):
@@ -399,24 +488,9 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datet
         return None

     if (
-        last_provisioned_submission.termination_reason
-        == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
-        and RetryEvent.INTERRUPTION in job.job_spec.retry.on_events
-    ):
-        return common.get_current_datetime() - last_provisioned_submission.last_processed_at
-
-    if (
-        last_provisioned_submission.termination_reason
-        in [
-            JobTerminationReason.CONTAINER_EXITED_WITH_ERROR,
-            JobTerminationReason.CREATING_CONTAINER_ERROR,
-            JobTerminationReason.EXECUTOR_ERROR,
-            JobTerminationReason.GATEWAY_ERROR,
-            JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED,
-            JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED,
-            JobTerminationReason.PORTS_BINDING_FAILED,
-        ]
-        and RetryEvent.ERROR in job.job_spec.retry.on_events
+        last_provisioned_submission.termination_reason is not None
+        and last_provisioned_submission.termination_reason.to_retry_event()
+        in job.job_spec.retry.on_events
     ):
         return common.get_current_datetime() - last_provisioned_submission.last_processed_at

@@ -434,3 +508,12 @@ def _can_retry_single_job(run_spec: RunSpec) -> bool:
     # We could make partial retry in some multi-node cases.
     # E.g. restarting a worker node, independent jobs.
     return False
+
+
+def _should_stop_on_master_done(run: Run) -> bool:
+    if run.run_spec.merged_profile.stop_criteria != StopCriteria.MASTER_DONE:
+        return False
+    for job in run.jobs:
+        if job.job_spec.job_num == 0 and job.job_submissions[-1].status == JobStatus.DONE:
+            return True
+    return False
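
The rolling-deployment bookkeeping in _handle_run_replicas reduces to small integer arithmetic: while out-of-date replicas exist, the run may briefly exceed desired_replica_count by ROLLING_DEPLOYMENT_MAX_SURGE, and excess or stale replicas are then stopped. A worked sketch of that arithmetic with illustrative counts:

ROLLING_DEPLOYMENT_MAX_SURGE = 1  # same constant as above

desired_replica_count = 3
deployment_in_progress = True  # i.e. _has_out_of_date_replicas(...) returned True

max_replica_count = desired_replica_count
if deployment_in_progress:
    # Allow one extra replica so a new one can start before an old one stops.
    max_replica_count += ROLLING_DEPLOYMENT_MAX_SURGE  # -> 4

# Case 1: a replica failed, so 2 active replicas fall outside the allowed
# band [3, 4] and the run scales by the difference.
active_replica_count = 2
if active_replica_count not in range(desired_replica_count, max_replica_count + 1):
    print("scale by", desired_replica_count - active_replica_count)  # scale by 1

# Case 2: mid-deployment, 3 up-to-date plus 1 out-of-date replicas are RUNNING;
# 4 running > 3 desired, so one out-of-date replica is stopped.
running_replica_count = 4
replicas_to_stop_count = max(0, running_replica_count - desired_replica_count)
print("stop", replicas_to_stop_count)  # stop 1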

dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py (new file)
@@ -0,0 +1,42 @@
+"""Add rolling deployment fields
+
+Revision ID: 35e90e1b0d3e
+Revises: 35f732ee4cf5
+Create Date: 2025-05-29 15:30:27.878569
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "35e90e1b0d3e"
+down_revision = "35f732ee4cf5"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True))
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.execute("UPDATE jobs SET deployment_num = 0")
+        batch_op.alter_column("deployment_num", nullable=False)
+
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True))
+        batch_op.add_column(sa.Column("desired_replica_count", sa.Integer(), nullable=True))
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.execute("UPDATE runs SET deployment_num = 0")
+        batch_op.execute("UPDATE runs SET desired_replica_count = 1")
+        batch_op.alter_column("deployment_num", nullable=False)
+        batch_op.alter_column("desired_replica_count", nullable=False)
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.drop_column("deployment_num")
+        batch_op.drop_column("desired_replica_count")
+
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.drop_column("deployment_num")

dstack/_internal/server/migrations/versions/35f732ee4cf5_add_projectmodel_is_public.py (new file)
@@ -0,0 +1,39 @@
+"""Add ProjectModel.is_public
+
+Revision ID: 35f732ee4cf5
+Revises: bca2fdf130bf
+Create Date: 2025-06-06 13:04:02.912032
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "35f732ee4cf5"
+down_revision = "bca2fdf130bf"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # Add is_public column as nullable first
+    with op.batch_alter_table("projects", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("is_public", sa.Boolean(), nullable=True))
+
+    # Set is_public to False for existing projects
+    op.execute(sa.sql.text("UPDATE projects SET is_public = FALSE"))
+
+    # Make is_public non-nullable with default value
+    with op.batch_alter_table("projects", schema=None) as batch_op:
+        batch_op.alter_column("is_public", nullable=False, server_default=sa.false())
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # Remove is_public column
+    with op.batch_alter_table("projects", schema=None) as batch_op:
+        batch_op.drop_column("is_public")
+    # ### end Alembic commands ###
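
Both migrations follow the same three-step shape: add the new column as nullable, backfill existing rows, then tighten the column to NOT NULL. The batch_alter_table context manager makes this portable to SQLite, which cannot alter columns in place and instead recreates the table. A generic sketch of the pattern (the table and column names below are made up, not dstack's):

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # 1) Add the column as nullable so existing rows stay valid.
    with op.batch_alter_table("examples", schema=None) as batch_op:
        batch_op.add_column(sa.Column("flag", sa.Boolean(), nullable=True))

    # 2) Backfill every existing row.
    op.execute(sa.sql.text("UPDATE examples SET flag = FALSE"))

    # 3) Tighten the constraint once every row has a value.
    with op.batch_alter_table("examples", schema=None) as batch_op:
        batch_op.alter_column("flag", nullable=False, server_default=sa.false())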

dstack/_internal/server/models.py
@@ -202,6 +202,7 @@ class ProjectModel(BaseModel):
     name: Mapped[str] = mapped_column(String(50), unique=True)
     created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
     deleted: Mapped[bool] = mapped_column(Boolean, default=False)
+    is_public: Mapped[bool] = mapped_column(Boolean, default=False)

     owner_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
     owner: Mapped[UserModel] = relationship(lazy="joined")
@@ -349,6 +350,8 @@ class RunModel(BaseModel):
     run_spec: Mapped[str] = mapped_column(Text)
     service_spec: Mapped[Optional[str]] = mapped_column(Text)
     priority: Mapped[int] = mapped_column(Integer, default=0)
+    deployment_num: Mapped[int] = mapped_column(Integer)
+    desired_replica_count: Mapped[int] = mapped_column(Integer)

     jobs: Mapped[List["JobModel"]] = relationship(
         back_populates="run", lazy="selectin", order_by="[JobModel.replica_num, JobModel.job_num]"
@@ -403,6 +406,7 @@ class JobModel(BaseModel):
     instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="jobs")
     used_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False))
     replica_num: Mapped[int] = mapped_column(Integer)
+    deployment_num: Mapped[int] = mapped_column(Integer)
     job_runtime_data: Mapped[Optional[str]] = mapped_column(Text)