dstack 0.19.12rc1__py3-none-any.whl → 0.19.14__py3-none-any.whl
This diff shows the changes between two package versions as they were publicly released to their registry. It is provided for informational purposes only.
- dstack/_internal/cli/commands/attach.py +4 -4
- dstack/_internal/cli/services/configurators/run.py +44 -47
- dstack/_internal/cli/utils/run.py +31 -31
- dstack/_internal/core/backends/aws/compute.py +22 -9
- dstack/_internal/core/backends/aws/resources.py +26 -0
- dstack/_internal/core/backends/base/offers.py +0 -1
- dstack/_internal/core/backends/template/configurator.py.jinja +1 -6
- dstack/_internal/core/backends/template/models.py.jinja +4 -0
- dstack/_internal/core/compatibility/__init__.py +0 -0
- dstack/_internal/core/compatibility/fleets.py +72 -0
- dstack/_internal/core/compatibility/gateways.py +34 -0
- dstack/_internal/core/compatibility/runs.py +131 -0
- dstack/_internal/core/compatibility/volumes.py +32 -0
- dstack/_internal/core/models/configurations.py +1 -1
- dstack/_internal/core/models/fleets.py +6 -1
- dstack/_internal/core/models/instances.py +51 -12
- dstack/_internal/core/models/profiles.py +43 -3
- dstack/_internal/core/models/projects.py +1 -0
- dstack/_internal/core/models/repos/local.py +3 -3
- dstack/_internal/core/models/runs.py +139 -43
- dstack/_internal/server/app.py +46 -1
- dstack/_internal/server/background/tasks/process_running_jobs.py +92 -15
- dstack/_internal/server/background/tasks/process_runs.py +163 -80
- dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py +42 -0
- dstack/_internal/server/migrations/versions/35f732ee4cf5_add_projectmodel_is_public.py +39 -0
- dstack/_internal/server/models.py +4 -0
- dstack/_internal/server/routers/projects.py +4 -3
- dstack/_internal/server/routers/prometheus.py +4 -1
- dstack/_internal/server/schemas/projects.py +1 -0
- dstack/_internal/server/security/permissions.py +36 -0
- dstack/_internal/server/services/jobs/__init__.py +1 -0
- dstack/_internal/server/services/jobs/configurators/base.py +11 -7
- dstack/_internal/server/services/projects.py +54 -1
- dstack/_internal/server/services/runner/client.py +4 -1
- dstack/_internal/server/services/runs.py +49 -29
- dstack/_internal/server/services/services/__init__.py +19 -0
- dstack/_internal/server/services/services/autoscalers.py +37 -26
- dstack/_internal/server/services/storage/__init__.py +38 -0
- dstack/_internal/server/services/storage/base.py +27 -0
- dstack/_internal/server/services/storage/gcs.py +44 -0
- dstack/_internal/server/services/{storage.py → storage/s3.py} +4 -27
- dstack/_internal/server/settings.py +7 -3
- dstack/_internal/server/statics/index.html +1 -1
- dstack/_internal/server/statics/{main-5b9786c955b42bf93581.js → main-0ac1e1583684417ae4d1.js} +1695 -62
- dstack/_internal/server/statics/{main-5b9786c955b42bf93581.js.map → main-0ac1e1583684417ae4d1.js.map} +1 -1
- dstack/_internal/server/statics/{main-8f9c66f404e9c7e7e020.css → main-f39c418b05fe14772dd8.css} +1 -1
- dstack/_internal/server/testing/common.py +11 -1
- dstack/_internal/settings.py +3 -0
- dstack/_internal/utils/common.py +4 -0
- dstack/api/_public/runs.py +14 -5
- dstack/api/server/_fleets.py +9 -69
- dstack/api/server/_gateways.py +3 -14
- dstack/api/server/_projects.py +2 -2
- dstack/api/server/_runs.py +4 -116
- dstack/api/server/_volumes.py +3 -14
- dstack/plugins/builtin/rest_plugin/_plugin.py +24 -5
- dstack/version.py +2 -2
- {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/METADATA +1 -1
- {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/RECORD +62 -52
- {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/WHEEL +0 -0
- {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/entry_points.txt +0 -0
- {dstack-0.19.12rc1.dist-info → dstack-0.19.14.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/server/app.py
CHANGED
```diff
@@ -11,6 +11,7 @@ from fastapi import FastAPI, Request, Response, status
 from fastapi.datastructures import URL
 from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
 from fastapi.staticfiles import StaticFiles
+from prometheus_client import Counter, Histogram
 
 from dstack._internal.cli.utils.common import console
 from dstack._internal.core.errors import ForbiddenError, ServerClientError
@@ -63,6 +64,18 @@ from dstack._internal.utils.ssh import check_required_ssh_version
 
 logger = get_logger(__name__)
 
+# Server HTTP metrics
+REQUESTS_TOTAL = Counter(
+    "dstack_server_requests_total",
+    "Total number of HTTP requests",
+    ["method", "endpoint", "http_status", "project_name"],
+)
+REQUEST_DURATION = Histogram(
+    "dstack_server_request_duration_seconds",
+    "HTTP request duration in seconds",
+    ["method", "endpoint", "http_status", "project_name"],
+)
+
 
 def create_app() -> FastAPI:
     if settings.SENTRY_DSN is not None:
@@ -128,7 +141,7 @@ async def lifespan(app: FastAPI):
         yes=UPDATE_DEFAULT_PROJECT,
         no=DO_NOT_UPDATE_DEFAULT_PROJECT,
     )
-    if settings.
+    if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None:
         init_default_storage()
     scheduler = start_background_tasks()
     dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
@@ -216,6 +229,8 @@ def register_routes(app: FastAPI, ui: bool = True):
         start_time = time.time()
         response: Response = await call_next(request)
         process_time = time.time() - start_time
+        # log process_time to be used in the log_http_metrics middleware
+        request.state.process_time = process_time
         logger.debug(
             "Processed request %s %s in %s. Status: %s",
             request.method,
@@ -225,6 +240,36 @@ def register_routes(app: FastAPI, ui: bool = True):
         )
         return response
 
+    # this middleware must be defined after the log_request middleware
+    @app.middleware("http")
+    async def log_http_metrics(request: Request, call_next):
+        def _extract_project_name(request: Request):
+            project_name = None
+            prefix = "/api/project/"
+            if request.url.path.startswith(prefix):
+                rest = request.url.path[len(prefix) :]
+                project_name = rest.split("/", 1)[0] if rest else None
+
+            return project_name
+
+        project_name = _extract_project_name(request)
+        response: Response = await call_next(request)
+
+        REQUEST_DURATION.labels(
+            method=request.method,
+            endpoint=request.url.path,
+            http_status=response.status_code,
+            project_name=project_name,
+        ).observe(request.state.process_time)
+
+        REQUESTS_TOTAL.labels(
+            method=request.method,
+            endpoint=request.url.path,
+            http_status=response.status_code,
+            project_name=project_name,
+        ).inc()
+        return response
+
     @app.middleware("http")
     async def check_client_version(request: Request, call_next):
         if (
```
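The metrics middleware derives the `project_name` label from the URL path alone and reads the duration the logging middleware stored on `request.state`. Below is a runnable sketch of that parsing plus the labeled-metric calls it feeds; the `demo_requests_total` metric name and the sample paths are made up for illustration, while the label names mirror the diff (requires `prometheus_client`):

```python
from typing import Optional

from prometheus_client import Counter, generate_latest

# Hypothetical demo_* metric; the server registers dstack_server_requests_total
# and dstack_server_request_duration_seconds with these same four labels.
DEMO_REQUESTS = Counter(
    "demo_requests_total",
    "Total number of HTTP requests",
    ["method", "endpoint", "http_status", "project_name"],
)


def extract_project_name(path: str) -> Optional[str]:
    # Same parsing as the middleware: /api/project/<name>/... -> <name>
    prefix = "/api/project/"
    if not path.startswith(prefix):
        return None
    rest = path[len(prefix) :]
    return rest.split("/", 1)[0] if rest else None


assert extract_project_name("/api/project/main/runs/list") == "main"
assert extract_project_name("/api/users/list") is None

DEMO_REQUESTS.labels(
    method="POST",
    endpoint="/api/project/main/runs/list",
    http_status=200,  # prometheus_client coerces label values to str
    project_name=extract_project_name("/api/project/main/runs/list"),
).inc()
print(generate_latest().decode())  # Prometheus text exposition format
```

One caveat worth noting: `endpoint` is labeled with the raw `request.url.path`, so routes that embed unique names or IDs in the path produce a distinct label set per value.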
dstack/_internal/server/background/tasks/process_running_jobs.py
CHANGED
```diff
@@ -1,4 +1,5 @@
 import asyncio
+import re
 from collections.abc import Iterable
 from datetime import timedelta, timezone
 from typing import Dict, List, Optional
@@ -7,6 +8,7 @@ from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload
 
+from dstack._internal import settings
 from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
 from dstack._internal.core.errors import GatewayError
 from dstack._internal.core.models.backends.base import BackendType
@@ -18,6 +20,7 @@ from dstack._internal.core.models.instances import (
     SSHConnectionParams,
 )
 from dstack._internal.core.models.metrics import Metric
+from dstack._internal.core.models.profiles import StartupOrder
 from dstack._internal.core.models.repos import RemoteRepoCreds
 from dstack._internal.core.models.runs import (
     ClusterInfo,
@@ -184,18 +187,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     if job_provisioning_data.hostname is None:
         await _wait_for_instance_provisioning_data(job_model=job_model)
     else:
-
-
-
-
-                other_job.job_spec.replica_num == job.job_spec.replica_num
-                and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
-                and other_job.job_submissions[-1].job_provisioning_data is not None
-                and other_job.job_submissions[-1].job_provisioning_data.hostname is None
-            ):
-                job_model.last_processed_at = common_utils.get_current_datetime()
-                await session.commit()
-                return
+        if _should_wait_for_other_nodes(run, job, job_model):
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return
 
     # fails are acceptable until timeout is exceeded
     if job_provisioning_data.dockerized:
@@ -406,6 +401,48 @@ async def _wait_for_instance_provisioning_data(job_model: JobModel):
         job_model.job_provisioning_data = job_model.instance.job_provisioning_data
 
 
+def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool:
+    for other_job in run.jobs:
+        if (
+            other_job.job_spec.replica_num == job.job_spec.replica_num
+            and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
+            and other_job.job_submissions[-1].job_provisioning_data is not None
+            and other_job.job_submissions[-1].job_provisioning_data.hostname is None
+        ):
+            logger.debug(
+                "%s: waiting for other job to have IP assigned",
+                fmt(job_model),
+            )
+            return True
+    master_job = find_job(run.jobs, job.job_spec.replica_num, 0)
+    if (
+        job.job_spec.job_num != 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.MASTER_FIRST
+        and master_job.job_submissions[-1].status != JobStatus.RUNNING
+    ):
+        logger.debug(
+            "%s: waiting for master job to become running",
+            fmt(job_model),
+        )
+        return True
+    if (
+        job.job_spec.job_num == 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.WORKERS_FIRST
+    ):
+        for other_job in run.jobs:
+            if (
+                other_job.job_spec.replica_num == job.job_spec.replica_num
+                and other_job.job_spec.job_num != job.job_spec.job_num
+                and other_job.job_submissions[-1].status != JobStatus.RUNNING
+            ):
+                logger.debug(
+                    "%s: waiting for worker job to become running",
+                    fmt(job_model),
+                )
+                return True
+    return False
+
+
 @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
 def _process_provisioning_with_shim(
     ports: Dict[int, int],
```
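`StartupOrder`, imported above, gates multi-node startup: under master-first, worker jobs hold until the master (job_num 0) is RUNNING; under workers-first, the master holds until every other job in the replica is RUNNING. A toy distillation of that gating with plain booleans standing in for the real Run/Job models; `MASTER_FIRST` and `WORKERS_FIRST` appear in the diff, while `ANY` and all string values are assumptions:

```python
from enum import Enum


class StartupOrder(str, Enum):
    # MASTER_FIRST and WORKERS_FIRST are from the diff; ANY and the
    # string values are assumed.
    ANY = "any"
    MASTER_FIRST = "master-first"
    WORKERS_FIRST = "workers-first"


def should_wait(
    job_num: int, order: StartupOrder, master_running: bool, all_workers_running: bool
) -> bool:
    if job_num != 0 and order == StartupOrder.MASTER_FIRST and not master_running:
        return True  # workers hold until job_num 0 (the master) is RUNNING
    if job_num == 0 and order == StartupOrder.WORKERS_FIRST and not all_workers_running:
        return True  # the master holds until every worker is RUNNING
    return False


assert should_wait(1, StartupOrder.MASTER_FIRST, master_running=False, all_workers_running=False)
assert not should_wait(1, StartupOrder.ANY, master_running=False, all_workers_running=False)
assert should_wait(0, StartupOrder.WORKERS_FIRST, master_running=True, all_workers_running=False)
```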
```diff
@@ -482,14 +519,14 @@ def _process_provisioning_with_shim(
     cpu = None
     memory = None
     network_mode = NetworkMode.HOST
-
+    image_name = _patch_base_image_for_aws_efa(job_spec, job_provisioning_data)
     if shim_client.is_api_v2_supported():
         shim_client.submit_task(
             task_id=job_model.id,
             name=job_model.job_name,
             registry_username=registry_username,
             registry_password=registry_password,
-            image_name=
+            image_name=image_name,
             container_user=container_user,
             privileged=job_spec.privileged,
             gpu=gpu,
@@ -510,7 +547,7 @@ def _process_provisioning_with_shim(
     submitted = shim_client.submit(
         username=registry_username,
         password=registry_password,
-        image_name=
+        image_name=image_name,
         privileged=job_spec.privileged,
         container_name=job_model.job_name,
         container_user=container_user,
@@ -934,3 +971,43 @@ def _get_instance_specific_gpu_devices(
         GPUDevice(path_on_host="/dev/nvidiactl", path_in_container="/dev/nvidiactl")
     )
     return gpu_devices
+
+
+def _patch_base_image_for_aws_efa(
+    job_spec: JobSpec, job_provisioning_data: JobProvisioningData
+) -> str:
+    image_name = job_spec.image_name
+
+    if job_provisioning_data.backend != BackendType.AWS:
+        return image_name
+
+    instance_type = job_provisioning_data.instance_type.name
+    efa_enabled_patterns = [
+        # TODO: p6-b200 isn't supported yet in gpuhunt
+        r"^p6-b200\.(48xlarge)$",
+        r"^p5\.(48xlarge)$",
+        r"^p5e\.(48xlarge)$",
+        r"^p5en\.(48xlarge)$",
+        r"^p4d\.(24xlarge)$",
+        r"^p4de\.(24xlarge)$",
+        r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
+        r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
+        r"^gr6\.8xlarge$",
+        r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
+        r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$",
+        r"^p3dn\.(24xlarge)$",
+    ]
+
+    is_efa_enabled = any(re.match(pattern, instance_type) for pattern in efa_enabled_patterns)
+    if not is_efa_enabled:
+        return image_name
+
+    if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"):
+        return image_name
+
+    if image_name.endswith(f"-base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"):
+        return image_name[:-17] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"
+    elif image_name.endswith(f"-devel-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"):
+        return image_name[:-18] + f"-devel-efa-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"
+
+    return image_name
```
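The `[:-17]` and `[:-18]` slices strip the `-base-ubuntu<ver>` and `-devel-ubuntu<ver>` suffixes, which only lines up when the Ubuntu version string is five characters (len("-base-ubuntu22.04") == 17, len("-devel-ubuntu22.04") == 18). A self-contained sketch of the rewrite, with an assumed base image name, tag, and version in place of the settings values:

```python
import re

# Assumed stand-ins for settings.DSTACK_BASE_IMAGE and
# settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION; the tag "0.7" is illustrative.
BASE_IMAGE = "dstackai/base"
UBUNTU_VERSION = "22.04"


def patch_for_efa(image_name: str, instance_type: str) -> str:
    # Two of the EFA-capable instance patterns from the diff
    efa_patterns = [r"^p5\.(48xlarge)$", r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$"]
    if not any(re.match(p, instance_type) for p in efa_patterns):
        return image_name
    if not image_name.startswith(f"{BASE_IMAGE}:"):
        return image_name  # user-supplied images are left untouched
    if image_name.endswith(f"-base-ubuntu{UBUNTU_VERSION}"):
        return image_name[:-17] + f"-devel-efa-ubuntu{UBUNTU_VERSION}"
    if image_name.endswith(f"-devel-ubuntu{UBUNTU_VERSION}"):
        return image_name[:-18] + f"-devel-efa-ubuntu{UBUNTU_VERSION}"
    return image_name


# Both base variants collapse to the same EFA-enabled devel image:
print(patch_for_efa(f"{BASE_IMAGE}:0.7-base-ubuntu22.04", "p5.48xlarge"))
print(patch_for_efa(f"{BASE_IMAGE}:0.7-devel-ubuntu22.04", "p5.48xlarge"))
# dstackai/base:0.7-devel-efa-ubuntu22.04 (both lines)
```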
dstack/_internal/server/background/tasks/process_runs.py
CHANGED
```diff
@@ -1,18 +1,17 @@
 import asyncio
 import datetime
-import itertools
 from typing import List, Optional, Set, Tuple
 
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload, selectinload
 
-import dstack._internal.server.services.gateways as gateways
 import dstack._internal.server.services.services.autoscalers as autoscalers
 from dstack._internal.core.errors import ServerError
-from dstack._internal.core.models.profiles import RetryEvent
+from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
 from dstack._internal.core.models.runs import (
     Job,
+    JobSpec,
     JobStatus,
     JobTerminationReason,
     Run,
@@ -24,22 +23,23 @@ from dstack._internal.server.db import get_session_ctx
 from dstack._internal.server.models import JobModel, ProjectModel, RunModel
 from dstack._internal.server.services.jobs import (
     find_job,
-
+    get_job_specs_from_run_spec,
     group_jobs_by_replica_latest,
 )
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.runs import (
-    create_job_model_for_new_submission,
     fmt,
     process_terminating_run,
     retry_run_replica_jobs,
     run_model_to_run,
     scale_run_replicas,
 )
+from dstack._internal.server.services.services import update_service_desired_replica_count
 from dstack._internal.utils import common
 from dstack._internal.utils.logging import get_logger
 
 logger = get_logger(__name__)
+ROLLING_DEPLOYMENT_MAX_SURGE = 1  # at most one extra replica during rolling deployment
 
 
 async def process_runs(batch_size: int = 1):
@@ -133,46 +133,22 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
         logger.debug("%s: pending run is not yet ready for resubmission", fmt(run_model))
         return
 
-
-    replicas = 1
+    run_model.desired_replica_count = 1
     if run.run_spec.configuration.type == "service":
-
-
-
-
-
-
-
-
-
+        run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0
+        await update_service_desired_replica_count(
+            session,
+            run_model,
+            run.run_spec.configuration,
+            # does not matter for pending services, since 0->n scaling should happen without delay
+            last_scaled_at=None,
+        )
+
+        if run_model.desired_replica_count == 0:
             # stay zero scaled
             return
 
-
-    # Resubmit existing replicas
-    for replica_num, replica_jobs in itertools.groupby(
-        run.jobs, key=lambda j: j.job_spec.replica_num
-    ):
-        if scheduled_replicas >= replicas:
-            break
-        scheduled_replicas += 1
-        for job in replica_jobs:
-            new_job_model = create_job_model_for_new_submission(
-                run_model=run_model,
-                job=job,
-                status=JobStatus.SUBMITTED,
-            )
-            session.add(new_job_model)
-    # Create missing replicas
-    for replica_num in range(scheduled_replicas, replicas):
-        jobs = await get_jobs_from_run_spec(run.run_spec, replica_num=replica_num)
-        for job in jobs:
-            job_model = create_job_model_for_new_submission(
-                run_model=run_model,
-                job=job,
-                status=JobStatus.SUBMITTED,
-            )
-            session.add(job_model)
+    await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count)
 
     run_model.status = RunStatus.SUBMITTED
     logger.info("%s: run status has changed PENDING -> SUBMITTED", fmt(run_model))
@@ -313,6 +289,10 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED
         else:
             raise ValueError(f"Unexpected termination reason {run_termination_reasons}")
+    elif _should_stop_on_master_done(run):
+        new_status = RunStatus.TERMINATING
+        # ALL_JOBS_DONE is used for all DONE reasons including master-done
+        termination_reason = RunTerminationReason.ALL_JOBS_DONE
     elif RunStatus.RUNNING in run_statuses:
         new_status = RunStatus.RUNNING
     elif RunStatus.PROVISIONING in run_statuses:
@@ -336,27 +316,11 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
                 job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
 
     if new_status not in {RunStatus.TERMINATING, RunStatus.PENDING}:
-        # No need to retry if the run is terminating,
+        # No need to retry, scale, or redeploy replicas if the run is terminating,
         # pending run will retry replicas in `process_pending_run`
-
-
-
-        )
-
-        if run_spec.configuration.type == "service":
-            scaler = autoscalers.get_service_scaler(run_spec.configuration)
-            stats = None
-            if run_model.gateway_id is not None:
-                conn = await gateways.get_or_add_gateway_connection(session, run_model.gateway_id)
-                stats = await conn.get_stats(run_model.project.name, run_model.run_name)
-            # use replicas_info from before retrying
-            replicas_diff = scaler.scale(replicas_info, stats)
-            if replicas_diff != 0:
-                # FIXME: potentially long write transaction
-                # Why do we flush here?
-                await session.flush()
-                await session.refresh(run_model)
-                await scale_run_replicas(session, run_model, replicas_diff)
+        await _handle_run_replicas(
+            session, run_model, run_spec, replicas_to_retry, retry_single_job, replicas_info
+        )
 
     if run_model.status != new_status:
         logger.info(
@@ -374,6 +338,130 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
     run_model.resubmission_attempt += 1
 
 
+async def _handle_run_replicas(
+    session: AsyncSession,
+    run_model: RunModel,
+    run_spec: RunSpec,
+    replicas_to_retry: list[tuple[int, list[JobModel]]],
+    retry_single_job: bool,
+    replicas_info: list[autoscalers.ReplicaInfo],
+) -> None:
+    """
+    Does ONE of:
+    - replica retry
+    - replica scaling
+    - replica rolling deployment
+
+    Does not do everything at once to avoid conflicts between the stages and long DB transactions.
+    """
+
+    if replicas_to_retry:
+        for _, replica_jobs in replicas_to_retry:
+            await retry_run_replica_jobs(
+                session, run_model, replica_jobs, only_failed=retry_single_job
+            )
+        return
+
+    if run_spec.configuration.type == "service":
+        await update_service_desired_replica_count(
+            session,
+            run_model,
+            run_spec.configuration,
+            # FIXME: should only include scaling events, not retries and deployments
+            last_scaled_at=max((r.timestamp for r in replicas_info), default=None),
+        )
+
+    max_replica_count = run_model.desired_replica_count
+    if _has_out_of_date_replicas(run_model):
+        # allow extra replicas when deployment is in progress
+        max_replica_count += ROLLING_DEPLOYMENT_MAX_SURGE
+
+    active_replica_count = sum(1 for r in replicas_info if r.active)
+    if active_replica_count not in range(run_model.desired_replica_count, max_replica_count + 1):
+        await scale_run_replicas(
+            session,
+            run_model,
+            replicas_diff=run_model.desired_replica_count - active_replica_count,
+        )
+        return
+
+    await _update_jobs_to_new_deployment_in_place(run_model, run_spec)
+    if _has_out_of_date_replicas(run_model):
+        non_terminated_replica_count = len(
+            {j.replica_num for j in run_model.jobs if not j.status.is_finished()}
+        )
+        # Avoid using too much hardware during a deployment - never have
+        # more than max_replica_count non-terminated replicas.
+        if non_terminated_replica_count < max_replica_count:
+            # Start more up-to-date replicas that will eventually replace out-of-date replicas.
+            await scale_run_replicas(
+                session,
+                run_model,
+                replicas_diff=max_replica_count - non_terminated_replica_count,
+            )
+
+        replicas_to_stop_count = 0
+        # stop any out-of-date replicas that are not running
+        replicas_to_stop_count += len(
+            {
+                j.replica_num
+                for j in run_model.jobs
+                if j.status
+                not in [JobStatus.RUNNING, JobStatus.TERMINATING] + JobStatus.finished_statuses()
+                and j.deployment_num < run_model.deployment_num
+            }
+        )
+        running_replica_count = len(
+            {j.replica_num for j in run_model.jobs if j.status == JobStatus.RUNNING}
+        )
+        if running_replica_count > run_model.desired_replica_count:
+            # stop excessive running out-of-date replicas
+            replicas_to_stop_count += running_replica_count - run_model.desired_replica_count
+        if replicas_to_stop_count:
+            await scale_run_replicas(
+                session,
+                run_model,
+                replicas_diff=-replicas_to_stop_count,
+            )
+
+
+async def _update_jobs_to_new_deployment_in_place(run_model: RunModel, run_spec: RunSpec) -> None:
+    """
+    Bump deployment_num for jobs that do not require redeployment.
+    """
+
+    for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
+        if all(j.status.is_finished() for j in job_models):
+            continue
+        if all(j.deployment_num == run_model.deployment_num for j in job_models):
+            continue
+        new_job_specs = await get_job_specs_from_run_spec(
+            run_spec=run_spec,
+            replica_num=replica_num,
+        )
+        assert len(new_job_specs) == len(job_models), (
+            "Changing the number of jobs within a replica is not yet supported"
+        )
+        can_update_all_jobs = True
+        for old_job_model, new_job_spec in zip(job_models, new_job_specs):
+            old_job_spec = JobSpec.__response__.parse_raw(old_job_model.job_spec_data)
+            if new_job_spec != old_job_spec:
+                can_update_all_jobs = False
+                break
+        if can_update_all_jobs:
+            for job_model in job_models:
+                job_model.deployment_num = run_model.deployment_num
+
+
+def _has_out_of_date_replicas(run: RunModel) -> bool:
+    for job in run.jobs:
+        if job.deployment_num < run.deployment_num and not (
+            job.status.is_finished() or job.termination_reason == JobTerminationReason.SCALED_DOWN
+        ):
+            return True
+    return False
+
+
 def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]:
     """
     Checks if the job should be retried.
```
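The net effect of `_handle_run_replicas`' count checks: active replicas must stay within [desired_replica_count, desired_replica_count + ROLLING_DEPLOYMENT_MAX_SURGE], with the surge slot available only while out-of-date replicas exist. A toy model of just that arithmetic (the real code operates on JobModel rows, not plain counts):

```python
ROLLING_DEPLOYMENT_MAX_SURGE = 1


def scaling_diff(desired: int, active: int, deploying: bool) -> int:
    # `deploying` stands in for _has_out_of_date_replicas(run_model)
    max_count = desired + (ROLLING_DEPLOYMENT_MAX_SURGE if deploying else 0)
    if active not in range(desired, max_count + 1):
        # outside the allowed band: scale back to exactly `desired`
        return desired - active
    return 0


assert scaling_diff(desired=3, active=1, deploying=False) == 2   # scale up
assert scaling_diff(desired=3, active=4, deploying=True) == 0    # surge replica tolerated
assert scaling_diff(desired=3, active=5, deploying=True) == -2   # over the surge limit
```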
```diff
@@ -389,7 +477,8 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datet
             break
 
     if (
-        job_model.termination_reason
+        job_model.termination_reason is not None
+        and job_model.termination_reason.to_retry_event() == RetryEvent.NO_CAPACITY
         and last_provisioned_submission is None
         and RetryEvent.NO_CAPACITY in job.job_spec.retry.on_events
     ):
@@ -399,24 +488,9 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datet
         return None
 
     if (
-        last_provisioned_submission.termination_reason
-
-
-    ):
-        return common.get_current_datetime() - last_provisioned_submission.last_processed_at
-
-    if (
-        last_provisioned_submission.termination_reason
-        in [
-            JobTerminationReason.CONTAINER_EXITED_WITH_ERROR,
-            JobTerminationReason.CREATING_CONTAINER_ERROR,
-            JobTerminationReason.EXECUTOR_ERROR,
-            JobTerminationReason.GATEWAY_ERROR,
-            JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED,
-            JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED,
-            JobTerminationReason.PORTS_BINDING_FAILED,
-        ]
-        and RetryEvent.ERROR in job.job_spec.retry.on_events
+        last_provisioned_submission.termination_reason is not None
+        and last_provisioned_submission.termination_reason.to_retry_event()
+        in job.job_spec.retry.on_events
     ):
         return common.get_current_datetime() - last_provisioned_submission.last_processed_at
 
@@ -434,3 +508,12 @@ def _can_retry_single_job(run_spec: RunSpec) -> bool:
     # We could make partial retry in some multi-node cases.
     # E.g. restarting a worker node, independent jobs.
     return False
+
+
+def _should_stop_on_master_done(run: Run) -> bool:
+    if run.run_spec.merged_profile.stop_criteria != StopCriteria.MASTER_DONE:
+        return False
+    for job in run.jobs:
+        if job.job_spec.job_num == 0 and job.job_submissions[-1].status == JobStatus.DONE:
+            return True
+    return False
```
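The retry refactor drops the hardcoded lists of termination reasons in favor of a `JobTerminationReason.to_retry_event()` mapping (added in models/runs.py, per the file list above) plus a membership check against `retry.on_events`. A sketch of the resulting logic; the reason-to-event table below is a hypothetical stand-in for the real method:

```python
from enum import Enum
from typing import Optional


class RetryEvent(str, Enum):
    # NO_CAPACITY and ERROR appear in the diff; the string values are assumed.
    NO_CAPACITY = "no-capacity"
    ERROR = "error"


# Hypothetical table standing in for JobTerminationReason.to_retry_event()
REASON_TO_EVENT: dict[str, Optional[RetryEvent]] = {
    "FAILED_TO_START_DUE_TO_NO_CAPACITY": RetryEvent.NO_CAPACITY,
    "CONTAINER_EXITED_WITH_ERROR": RetryEvent.ERROR,
    "TERMINATED_BY_USER": None,  # not retryable
}


def should_retry(reason: Optional[str], on_events: set[RetryEvent]) -> bool:
    if reason is None:
        return False
    return REASON_TO_EVENT.get(reason) in on_events


assert should_retry("CONTAINER_EXITED_WITH_ERROR", {RetryEvent.ERROR})
assert not should_retry("CONTAINER_EXITED_WITH_ERROR", {RetryEvent.NO_CAPACITY})
assert not should_retry("TERMINATED_BY_USER", {RetryEvent.ERROR})
```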
dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py
ADDED
```diff
@@ -0,0 +1,42 @@
+"""Add rolling deployment fields
+
+Revision ID: 35e90e1b0d3e
+Revises: 35f732ee4cf5
+Create Date: 2025-05-29 15:30:27.878569
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "35e90e1b0d3e"
+down_revision = "35f732ee4cf5"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True))
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.execute("UPDATE jobs SET deployment_num = 0")
+        batch_op.alter_column("deployment_num", nullable=False)
+
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True))
+        batch_op.add_column(sa.Column("desired_replica_count", sa.Integer(), nullable=True))
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.execute("UPDATE runs SET deployment_num = 0")
+        batch_op.execute("UPDATE runs SET desired_replica_count = 1")
+        batch_op.alter_column("deployment_num", nullable=False)
+        batch_op.alter_column("desired_replica_count", nullable=False)
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.drop_column("deployment_num")
+        batch_op.drop_column("desired_replica_count")
+
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.drop_column("deployment_num")
```
dstack/_internal/server/migrations/versions/35f732ee4cf5_add_projectmodel_is_public.py
ADDED
```diff
@@ -0,0 +1,39 @@
+"""Add ProjectModel.is_public
+
+Revision ID: 35f732ee4cf5
+Revises: bca2fdf130bf
+Create Date: 2025-06-06 13:04:02.912032
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "35f732ee4cf5"
+down_revision = "bca2fdf130bf"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # Add is_public column as nullable first
+    with op.batch_alter_table("projects", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("is_public", sa.Boolean(), nullable=True))
+
+    # Set is_public to False for existing projects
+    op.execute(sa.sql.text("UPDATE projects SET is_public = FALSE"))
+
+    # Make is_public non-nullable with default value
+    with op.batch_alter_table("projects", schema=None) as batch_op:
+        batch_op.alter_column("is_public", nullable=False, server_default=sa.false())
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # Remove is_public column
+    with op.batch_alter_table("projects", schema=None) as batch_op:
+        batch_op.drop_column("is_public")
+    # ### end Alembic commands ###
```
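Both migrations follow the same add-nullable, backfill, then constrain sequence, wrapped in `batch_alter_table` so the same migration works on SQLite (which rebuilds the table) as well as engines that support plain ALTER. A minimal sketch of the pattern on a hypothetical `things` table:

```python
import sqlalchemy as sa
from alembic import op

# "things" and "flag" are hypothetical names; the steps mirror the
# is_public migration above.


def upgrade() -> None:
    # 1) add the column as nullable so existing rows don't violate NOT NULL
    with op.batch_alter_table("things", schema=None) as batch_op:
        batch_op.add_column(sa.Column("flag", sa.Boolean(), nullable=True))

    # 2) backfill existing rows
    op.execute(sa.sql.text("UPDATE things SET flag = FALSE"))

    # 3) tighten to NOT NULL with a server-side default for future inserts
    with op.batch_alter_table("things", schema=None) as batch_op:
        batch_op.alter_column("flag", nullable=False, server_default=sa.false())


def downgrade() -> None:
    with op.batch_alter_table("things", schema=None) as batch_op:
        batch_op.drop_column("flag")
```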
dstack/_internal/server/models.py
CHANGED
```diff
@@ -202,6 +202,7 @@ class ProjectModel(BaseModel):
     name: Mapped[str] = mapped_column(String(50), unique=True)
     created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
     deleted: Mapped[bool] = mapped_column(Boolean, default=False)
+    is_public: Mapped[bool] = mapped_column(Boolean, default=False)
 
     owner_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
     owner: Mapped[UserModel] = relationship(lazy="joined")
@@ -349,6 +350,8 @@ class RunModel(BaseModel):
     run_spec: Mapped[str] = mapped_column(Text)
     service_spec: Mapped[Optional[str]] = mapped_column(Text)
     priority: Mapped[int] = mapped_column(Integer, default=0)
+    deployment_num: Mapped[int] = mapped_column(Integer)
+    desired_replica_count: Mapped[int] = mapped_column(Integer)
 
     jobs: Mapped[List["JobModel"]] = relationship(
         back_populates="run", lazy="selectin", order_by="[JobModel.replica_num, JobModel.job_num]"
@@ -403,6 +406,7 @@ class JobModel(BaseModel):
     instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="jobs")
     used_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False))
     replica_num: Mapped[int] = mapped_column(Integer)
+    deployment_num: Mapped[int] = mapped_column(Integer)
     job_runtime_data: Mapped[Optional[str]] = mapped_column(Text)
```
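The new columns use SQLAlchemy 2.0 typed mappings with no Python-side default; existing rows get their values from the migration backfill (deployment_num = 0, desired_replica_count = 1) instead. A self-contained sketch of the style on a stand-in model:

```python
from sqlalchemy import Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

# DemoRun is a stand-in, not dstack's RunModel.


class Base(DeclarativeBase):
    pass


class DemoRun(Base):
    __tablename__ = "demo_runs"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Non-nullable with no Python-side default: new rows must set the values
    # explicitly, while pre-existing rows were backfilled by the migration.
    deployment_num: Mapped[int] = mapped_column(Integer)
    desired_replica_count: Mapped[int] = mapped_column(Integer)
```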