dstack 0.19.11rc1__py3-none-any.whl → 0.19.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of dstack has been flagged as potentially problematic.
- dstack/_internal/cli/commands/offer.py +2 -0
- dstack/_internal/cli/services/configurators/run.py +43 -42
- dstack/_internal/cli/utils/run.py +10 -26
- dstack/_internal/cli/utils/updates.py +13 -1
- dstack/_internal/core/backends/aws/compute.py +21 -9
- dstack/_internal/core/backends/base/compute.py +7 -3
- dstack/_internal/core/backends/gcp/compute.py +43 -20
- dstack/_internal/core/backends/gcp/resources.py +18 -2
- dstack/_internal/core/backends/local/compute.py +4 -2
- dstack/_internal/core/backends/template/configurator.py.jinja +1 -6
- dstack/_internal/core/backends/template/models.py.jinja +4 -0
- dstack/_internal/core/models/configurations.py +1 -1
- dstack/_internal/core/models/fleets.py +6 -1
- dstack/_internal/core/models/profiles.py +43 -3
- dstack/_internal/core/models/repos/local.py +19 -13
- dstack/_internal/core/models/runs.py +78 -45
- dstack/_internal/server/background/tasks/process_running_jobs.py +47 -12
- dstack/_internal/server/background/tasks/process_runs.py +14 -1
- dstack/_internal/server/background/tasks/process_submitted_jobs.py +3 -3
- dstack/_internal/server/routers/repos.py +9 -4
- dstack/_internal/server/services/fleets.py +2 -2
- dstack/_internal/server/services/gateways/__init__.py +1 -1
- dstack/_internal/server/services/jobs/__init__.py +4 -4
- dstack/_internal/server/services/plugins.py +64 -32
- dstack/_internal/server/services/runner/client.py +4 -1
- dstack/_internal/server/services/runs.py +2 -2
- dstack/_internal/server/services/volumes.py +1 -1
- dstack/_internal/server/statics/index.html +1 -1
- dstack/_internal/server/statics/{main-b4803049eac16aea9a49.js → main-b0e80f8e26a168c129e9.js} +72 -25
- dstack/_internal/server/statics/{main-b4803049eac16aea9a49.js.map → main-b0e80f8e26a168c129e9.js.map} +1 -1
- dstack/_internal/server/testing/common.py +2 -1
- dstack/_internal/utils/common.py +4 -0
- dstack/api/server/_fleets.py +5 -1
- dstack/api/server/_runs.py +8 -0
- dstack/plugins/builtin/__init__.py +0 -0
- dstack/plugins/builtin/rest_plugin/__init__.py +18 -0
- dstack/plugins/builtin/rest_plugin/_models.py +48 -0
- dstack/plugins/builtin/rest_plugin/_plugin.py +127 -0
- dstack/version.py +1 -1
- {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/METADATA +2 -2
- {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/RECORD +44 -41
- dstack/_internal/utils/ignore.py +0 -92
- {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/WHEEL +0 -0
- {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/entry_points.txt +0 -0
- {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/licenses/LICENSE.md +0 -0

dstack/_internal/core/models/profiles.py

@@ -6,6 +6,7 @@ from typing_extensions import Annotated, Literal
 
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.common import CoreModel, Duration
+from dstack._internal.utils.common import list_enum_values_for_annotation
 from dstack._internal.utils.tags import tags_validator
 
 DEFAULT_RETRY_DURATION = 3600
@@ -32,6 +33,17 @@ class TerminationPolicy(str, Enum):
     DESTROY_AFTER_IDLE = "destroy-after-idle"
 
 
+class StartupOrder(str, Enum):
+    ANY = "any"
+    MASTER_FIRST = "master-first"
+    WORKERS_FIRST = "workers-first"
+
+
+class StopCriteria(str, Enum):
+    ALL_DONE = "all-done"
+    MASTER_DONE = "master-done"
+
+
 @overload
 def parse_duration(v: None) -> None: ...
 
@@ -102,7 +114,7 @@ class ProfileRetry(CoreModel):
         Field(
             description=(
                 "The list of events that should be handled with retry."
-                " Supported events are
+                f" Supported events are {list_enum_values_for_annotation(RetryEvent)}."
                 " Omit to retry on all events"
             )
         ),
@@ -190,7 +202,11 @@ class ProfileParams(CoreModel):
     spot_policy: Annotated[
         Optional[SpotPolicy],
         Field(
-            description=
+            description=(
+                "The policy for provisioning spot or on-demand instances:"
+                f" {list_enum_values_for_annotation(SpotPolicy)}."
+                f" Defaults to `{SpotPolicy.ONDEMAND.value}`"
+            )
         ),
     ] = None
     retry: Annotated[
@@ -225,7 +241,11 @@ class ProfileParams(CoreModel):
     creation_policy: Annotated[
         Optional[CreationPolicy],
         Field(
-            description=
+            description=(
+                "The policy for using instances from fleets:"
+                f" {list_enum_values_for_annotation(CreationPolicy)}."
+                f" Defaults to `{CreationPolicy.REUSE_OR_CREATE.value}`"
+            )
        ),
     ] = None
     idle_duration: Annotated[
@@ -241,6 +261,26 @@ class ProfileParams(CoreModel):
         Optional[UtilizationPolicy],
         Field(description="Run termination policy based on utilization"),
     ] = None
+    startup_order: Annotated[
+        Optional[StartupOrder],
+        Field(
+            description=(
+                f"The order in which master and workers jobs are started:"
+                f" {list_enum_values_for_annotation(StartupOrder)}."
+                f" Defaults to `{StartupOrder.ANY.value}`"
+            )
+        ),
+    ] = None
+    stop_criteria: Annotated[
+        Optional[StopCriteria],
+        Field(
+            description=(
+                "The criteria determining when a multi-node run should be considered finished:"
+                f" {list_enum_values_for_annotation(StopCriteria)}."
+                f" Defaults to `{StopCriteria.ALL_DONE.value}`"
+            )
+        ),
+    ] = None
     fleets: Annotated[
         Optional[list[str]], Field(description="The fleets considered for reuse")
     ] = None
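
The new field descriptions are built from the enum values themselves via list_enum_values_for_annotation, added to dstack/_internal/utils/common.py in this release (+4 lines). Below is a minimal sketch of what such a helper presumably produces; the helper body and its exact output format are assumptions, only the call sites above come from the diff.

# A minimal sketch (not the dstack implementation) of a helper like
# list_enum_values_for_annotation: turning an Enum's values into a
# human-readable list for pydantic Field descriptions.
from enum import Enum


class StartupOrder(str, Enum):
    ANY = "any"
    MASTER_FIRST = "master-first"
    WORKERS_FIRST = "workers-first"


def list_enum_values_for_annotation(enum_cls: type[Enum]) -> str:
    # Hypothetical formatting; the real helper lives in dstack/_internal/utils/common.py.
    return ", ".join(f"`{member.value}`" for member in enum_cls)


description = (
    "The order in which master and workers jobs are started:"
    f" {list_enum_values_for_annotation(StartupOrder)}."
    f" Defaults to `{StartupOrder.ANY.value}`"
)
print(description)
# The order in which master and workers jobs are started: `any`, `master-first`, `workers-first`. Defaults to `any`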

dstack/_internal/core/models/repos/local.py

@@ -2,13 +2,18 @@ import tarfile
 from pathlib import Path
 from typing import BinaryIO, Optional
 
+import ignore
+import ignore.overrides
 from typing_extensions import Literal
 
 from dstack._internal.core.models.repos.base import BaseRepoInfo, Repo
+from dstack._internal.utils.common import sizeof_fmt
 from dstack._internal.utils.hash import get_sha256, slugify
-from dstack._internal.utils.
+from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.path import PathLike
 
+logger = get_logger(__name__)
+
 
 class LocalRepoInfo(BaseRepoInfo):
     repo_type: Literal["local"] = "local"
@@ -69,22 +74,23 @@ class LocalRepo(Repo):
         self.run_repo_data = repo_data
 
     def write_code_file(self, fp: BinaryIO) -> str:
+        repo_path = Path(self.run_repo_data.repo_dir)
         with tarfile.TarFile(mode="w", fileobj=fp) as t:
-
-
-
-
-
+            for entry in (
+                ignore.WalkBuilder(repo_path)
+                .overrides(ignore.overrides.OverrideBuilder(repo_path).add("!/.git/").build())
+                .hidden(False)  # do not ignore files that start with a dot
+                .require_git(False)  # respect git ignore rules even if not a git repo
+                .add_custom_ignore_filename(".dstackignore")
+                .build()
+            ):
+                entry_path_within_repo = entry.path().relative_to(repo_path)
+                if entry_path_within_repo != Path("."):
+                    t.add(entry.path(), arcname=entry_path_within_repo, recursive=False)
+        logger.debug("Code file size: %s", sizeof_fmt(fp.tell()))
         return get_sha256(fp)
 
     def get_repo_info(self) -> LocalRepoInfo:
         return LocalRepoInfo(
             repo_dir=self.run_repo_data.repo_dir,
         )
-
-
-class TarIgnore(GitIgnore):
-    def __call__(self, tarinfo: tarfile.TarInfo) -> Optional[tarfile.TarInfo]:
-        if self.ignore(tarinfo.path):
-            return None
-        return tarinfo
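
write_code_file now walks the repo with the ignore package instead of the removed TarIgnore/GitIgnore helpers (dstack/_internal/utils/ignore.py is deleted, -92 lines), so .gitignore rules and a custom .dstackignore file are honored even outside a git repo. Below is a standalone sketch that mirrors the WalkBuilder chain above; it assumes the ignore package exposes exactly the API used in the hunk (WalkBuilder, OverrideBuilder, entry.path()), which is not verified beyond this diff.

# Sketch: pack a directory into a tar archive while honoring ignore rules,
# mirroring the chain used in LocalRepo.write_code_file above.
import tarfile
from pathlib import Path

import ignore
import ignore.overrides


def pack_repo(repo_dir: str, out_path: str) -> None:
    repo_path = Path(repo_dir)
    with tarfile.TarFile(out_path, mode="w") as t:
        for entry in (
            ignore.WalkBuilder(repo_path)
            # keep .git out even though hidden files are otherwise included
            .overrides(ignore.overrides.OverrideBuilder(repo_path).add("!/.git/").build())
            .hidden(False)  # do not skip dotfiles
            .require_git(False)  # honor .gitignore even outside a git repo
            .add_custom_ignore_filename(".dstackignore")
            .build()
        ):
            rel = entry.path().relative_to(repo_path)
            if rel != Path("."):
                t.add(entry.path(), arcname=rel, recursive=False)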

dstack/_internal/core/models/runs.py

@@ -148,9 +148,6 @@ class JobTerminationReason(str, Enum):
         }
         return mapping[self]
 
-    def pretty_repr(self) -> str:
-        return " ".join(self.value.split("_")).capitalize()
-
 
 class Requirements(CoreModel):
     # TODO: Make requirements' fields required
@@ -289,6 +286,9 @@ class JobSubmission(CoreModel):
     exit_status: Optional[int]
     job_provisioning_data: Optional[JobProvisioningData]
     job_runtime_data: Optional[JobRuntimeData]
+    # TODO: make status_message and error a computed field after migrating to pydanticV2
+    status_message: Optional[str]
+    error: Optional[str] = None
 
     @property
     def age(self) -> timedelta:
@@ -301,6 +301,71 @@ class JobSubmission(CoreModel):
             end_time = self.finished_at
         return end_time - self.submitted_at
 
+    @root_validator
+    def _status_message(cls, values) -> Dict:
+        try:
+            status = values["status"]
+            termination_reason = values["termination_reason"]
+            exit_code = values["exit_status"]
+        except KeyError:
+            return values
+        values["status_message"] = JobSubmission._get_status_message(
+            status=status,
+            termination_reason=termination_reason,
+            exit_status=exit_code,
+        )
+        return values
+
+    @staticmethod
+    def _get_status_message(
+        status: JobStatus,
+        termination_reason: Optional[JobTerminationReason],
+        exit_status: Optional[int],
+    ) -> str:
+        if status == JobStatus.DONE:
+            return "exited (0)"
+        elif status == JobStatus.FAILED:
+            if termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR:
+                return f"exited ({exit_status})"
+            elif termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY:
+                return "no offers"
+            elif termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY:
+                return "interrupted"
+            else:
+                return "error"
+        elif status == JobStatus.TERMINATED:
+            if termination_reason == JobTerminationReason.TERMINATED_BY_USER:
+                return "stopped"
+            elif termination_reason == JobTerminationReason.ABORTED_BY_USER:
+                return "aborted"
+        return status.value
+
+    @root_validator
+    def _error(cls, values) -> Dict:
+        try:
+            termination_reason = values["termination_reason"]
+        except KeyError:
+            return values
+        values["error"] = JobSubmission._get_error(termination_reason=termination_reason)
+        return values
+
+    @staticmethod
+    def _get_error(termination_reason: Optional[JobTerminationReason]) -> Optional[str]:
+        error_mapping = {
+            JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable",
+            JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded",
+            JobTerminationReason.VOLUME_ERROR: "waiting runner limit exceeded",
+            JobTerminationReason.GATEWAY_ERROR: "gateway error",
+            JobTerminationReason.SCALED_DOWN: "scaled down",
+            JobTerminationReason.INACTIVITY_DURATION_EXCEEDED: "inactivity duration exceeded",
+            JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY: "utilization policy",
+            JobTerminationReason.PORTS_BINDING_FAILED: "ports binding failed",
+            JobTerminationReason.CREATING_CONTAINER_ERROR: "runner error",
+            JobTerminationReason.EXECUTOR_ERROR: "executor error",
+            JobTerminationReason.MAX_DURATION_EXCEEDED: "max duration exceeded",
+        }
+        return error_mapping.get(termination_reason)
+
 
 class Job(CoreModel):
     job_spec: JobSpec
@@ -445,15 +510,20 @@ class Run(CoreModel):
     def _error(cls, values) -> Dict:
         try:
             termination_reason = values["termination_reason"]
-            jobs = values["jobs"]
         except KeyError:
            return values
-        values["error"] =
-            run_termination_reason=termination_reason,
-            run_jobs=jobs,
-        )
+        values["error"] = Run._get_error(termination_reason=termination_reason)
         return values
 
+    @staticmethod
+    def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[str]:
+        if termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED:
+            return "retry limit exceeded"
+        elif termination_reason == RunTerminationReason.SERVER_ERROR:
+            return "server error"
+        else:
+            return None
+
 
 class JobPlan(CoreModel):
     job_spec: JobSpec
@@ -502,40 +572,3 @@ def get_policy_map(spot_policy: Optional[SpotPolicy], default: SpotPolicy) -> Op
         SpotPolicy.ONDEMAND: False,
     }
     return policy_map[spot_policy]
-
-
-def _get_run_error(
-    run_termination_reason: Optional[RunTerminationReason],
-    run_jobs: List[Job],
-) -> str:
-    if run_termination_reason is None:
-        return ""
-    if len(run_jobs) > 1:
-        return run_termination_reason.name
-    run_job_termination_reason, exit_status = _get_run_job_termination_reason_and_exit_status(
-        run_jobs
-    )
-    # For failed runs, also show termination reason to provide more context.
-    # For other run statuses, the job termination reason will duplicate run status.
-    if run_job_termination_reason is not None and run_termination_reason in [
-        RunTerminationReason.JOB_FAILED,
-        RunTerminationReason.SERVER_ERROR,
-        RunTerminationReason.RETRY_LIMIT_EXCEEDED,
-    ]:
-        if exit_status:
-            return (
-                f"{run_termination_reason.name}\n({run_job_termination_reason.name} {exit_status})"
-            )
-        return f"{run_termination_reason.name}\n({run_job_termination_reason.name})"
-    return run_termination_reason.name
-
-
-def _get_run_job_termination_reason_and_exit_status(
-    run_jobs: List[Job],
-) -> tuple[Optional[JobTerminationReason], Optional[int]]:
-    for job in run_jobs:
-        if len(job.job_submissions) > 0:
-            job_submission = job.job_submissions[-1]
-            if job_submission.termination_reason is not None:
-                return job_submission.termination_reason, job_submission.exit_status
-    return None, None
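
JobSubmission now precomputes status_message and error with pydantic v1 root_validators, replacing the removed pretty_repr and module-level _get_run_error helpers; the TODO above notes these become computed fields after the pydantic v2 migration. Below is a simplified illustration of the pattern, with made-up field names rather than dstack's models.

# Pydantic v1-style root_validator deriving a display field from other fields.
from typing import Optional

from pydantic import BaseModel, root_validator


class SubmissionView(BaseModel):
    status: str
    exit_status: Optional[int] = None
    status_message: Optional[str] = None

    @root_validator
    def _fill_status_message(cls, values):
        # values holds all parsed fields; write the derived field back into it
        status, exit_status = values.get("status"), values.get("exit_status")
        if status == "done":
            values["status_message"] = "exited (0)"
        elif status == "failed" and exit_status is not None:
            values["status_message"] = f"exited ({exit_status})"
        else:
            values["status_message"] = status
        return values


print(SubmissionView(status="failed", exit_status=137).status_message)  # exited (137)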

dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -18,6 +18,7 @@ from dstack._internal.core.models.instances (
     SSHConnectionParams,
 )
 from dstack._internal.core.models.metrics import Metric
+from dstack._internal.core.models.profiles import StartupOrder
 from dstack._internal.core.models.repos import RemoteRepoCreds
 from dstack._internal.core.models.runs import (
     ClusterInfo,
@@ -184,18 +185,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     if job_provisioning_data.hostname is None:
         await _wait_for_instance_provisioning_data(job_model=job_model)
     else:
-
-
-
-
-                    other_job.job_spec.replica_num == job.job_spec.replica_num
-                    and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
-                    and other_job.job_submissions[-1].job_provisioning_data is not None
-                    and other_job.job_submissions[-1].job_provisioning_data.hostname is None
-                ):
-                    job_model.last_processed_at = common_utils.get_current_datetime()
-                    await session.commit()
-                    return
+        if _should_wait_for_other_nodes(run, job, job_model):
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return
 
     # fails are acceptable until timeout is exceeded
     if job_provisioning_data.dockerized:
@@ -406,6 +399,48 @@ async def _wait_for_instance_provisioning_data(job_model: JobModel):
     job_model.job_provisioning_data = job_model.instance.job_provisioning_data
 
 
+def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool:
+    for other_job in run.jobs:
+        if (
+            other_job.job_spec.replica_num == job.job_spec.replica_num
+            and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
+            and other_job.job_submissions[-1].job_provisioning_data is not None
+            and other_job.job_submissions[-1].job_provisioning_data.hostname is None
+        ):
+            logger.debug(
+                "%s: waiting for other job to have IP assigned",
+                fmt(job_model),
+            )
+            return True
+    master_job = find_job(run.jobs, job.job_spec.replica_num, 0)
+    if (
+        job.job_spec.job_num != 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.MASTER_FIRST
+        and master_job.job_submissions[-1].status != JobStatus.RUNNING
+    ):
+        logger.debug(
+            "%s: waiting for master job to become running",
+            fmt(job_model),
+        )
+        return True
+    if (
+        job.job_spec.job_num == 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.WORKERS_FIRST
+    ):
+        for other_job in run.jobs:
+            if (
+                other_job.job_spec.replica_num == job.job_spec.replica_num
+                and other_job.job_spec.job_num != job.job_spec.job_num
+                and other_job.job_submissions[-1].status != JobStatus.RUNNING
+            ):
+                logger.debug(
+                    "%s: waiting for worker job to become running",
+                    fmt(job_model),
+                )
+                return True
+    return False
+
+
 @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
 def _process_provisioning_with_shim(
     ports: Dict[int, int],
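
The inline wait loop was extracted into _should_wait_for_other_nodes, which also gates job startup on the new startup_order profile setting. Below is a toy model of that ordering decision, using simplified stand-ins for dstack's job and status objects rather than the real models.

from dataclasses import dataclass
from enum import Enum


class StartupOrder(str, Enum):
    ANY = "any"
    MASTER_FIRST = "master-first"
    WORKERS_FIRST = "workers-first"


@dataclass
class JobState:
    job_num: int  # 0 is the master job of the replica
    running: bool


def should_wait(job: JobState, peers: list[JobState], order: StartupOrder) -> bool:
    if order == StartupOrder.MASTER_FIRST and job.job_num != 0:
        # workers hold back until the master job is running
        return not any(p.job_num == 0 and p.running for p in peers)
    if order == StartupOrder.WORKERS_FIRST and job.job_num == 0:
        # the master holds back until every worker job is running
        return any(p.job_num != 0 and not p.running for p in peers)
    return False


peers = [JobState(0, running=False), JobState(1, running=True)]
print(should_wait(JobState(1, running=True), peers, StartupOrder.MASTER_FIRST))  # True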

dstack/_internal/server/background/tasks/process_runs.py

@@ -10,7 +10,7 @@ from sqlalchemy.orm import joinedload, selectinload
 import dstack._internal.server.services.gateways as gateways
 import dstack._internal.server.services.services.autoscalers as autoscalers
 from dstack._internal.core.errors import ServerError
-from dstack._internal.core.models.profiles import RetryEvent
+from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
 from dstack._internal.core.models.runs import (
     Job,
     JobStatus,
@@ -313,6 +313,10 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED
         else:
             raise ValueError(f"Unexpected termination reason {run_termination_reasons}")
+    elif _should_stop_on_master_done(run):
+        new_status = RunStatus.TERMINATING
+        # ALL_JOBS_DONE is used for all DONE reasons including master-done
+        termination_reason = RunTerminationReason.ALL_JOBS_DONE
     elif RunStatus.RUNNING in run_statuses:
         new_status = RunStatus.RUNNING
     elif RunStatus.PROVISIONING in run_statuses:
@@ -434,3 +438,12 @@ def _can_retry_single_job(run_spec: RunSpec) -> bool:
     # We could make partial retry in some multi-node cases.
     # E.g. restarting a worker node, independent jobs.
     return False
+
+
+def _should_stop_on_master_done(run: Run) -> bool:
+    if run.run_spec.merged_profile.stop_criteria != StopCriteria.MASTER_DONE:
+        return False
+    for job in run.jobs:
+        if job.job_spec.job_num == 0 and job.job_submissions[-1].status == JobStatus.DONE:
+            return True
+    return False

dstack/_internal/server/background/tasks/process_submitted_jobs.py

@@ -659,7 +659,7 @@ async def _attach_volumes(
                 backend=backend,
                 volume_model=volume_model,
                 instance=instance,
-
+                jpd=job_provisioning_data,
             )
             job_runtime_data.volume_names.append(volume.name)
             break  # attach next mount point
@@ -685,7 +685,7 @@ async def _attach_volume(
     backend: Backend,
     volume_model: VolumeModel,
     instance: InstanceModel,
-
+    jpd: JobProvisioningData,
 ):
     compute = backend.compute()
     assert isinstance(compute, ComputeWithVolumeSupport)
@@ -697,7 +697,7 @@ async def _attach_volume(
     attachment_data = await common_utils.run_async(
         compute.attach_volume,
         volume=volume,
-
+        provisioning_data=jpd,
     )
     volume_attachment_model = VolumeAttachmentModel(
         volume=volume_model,

dstack/_internal/server/routers/repos.py

@@ -1,7 +1,6 @@
 from typing import List, Tuple
 
 from fastapi import APIRouter, Depends, Request, UploadFile
-from humanize import naturalsize
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from dstack._internal.core.errors import ResourceNotExistsError, ServerClientError
@@ -20,6 +19,7 @@ from dstack._internal.server.utils.routers import (
     get_base_api_additional_responses,
     get_request_size,
 )
+from dstack._internal.utils.common import sizeof_fmt
 
 router = APIRouter(
     prefix="/api/project/{project_name}/repos",
@@ -98,10 +98,15 @@ async def upload_code(
 ):
     request_size = get_request_size(request)
     if SERVER_CODE_UPLOAD_LIMIT > 0 and request_size > SERVER_CODE_UPLOAD_LIMIT:
+        diff_size_fmt = sizeof_fmt(request_size)
+        limit_fmt = sizeof_fmt(SERVER_CODE_UPLOAD_LIMIT)
+        if diff_size_fmt == limit_fmt:
+            diff_size_fmt = f"{request_size}B"
+            limit_fmt = f"{SERVER_CODE_UPLOAD_LIMIT}B"
         raise ServerClientError(
-            f"Repo diff size is {
-
-
+            f"Repo diff size is {diff_size_fmt}, which exceeds the limit of {limit_fmt}."
+            " Use .gitignore to exclude large files from the repo."
+            " This limit can be modified by setting the DSTACK_SERVER_CODE_UPLOAD_LIMIT environment variable."
         )
     _, project = user_project
     await repos.upload_code(
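
The upload-limit error now formats both sizes with sizeof_fmt and falls back to raw byte counts when the two strings collide, which avoids messages like "2.0MiB exceeds the limit of 2.0MiB". The sizeof_fmt below is a guess at the helper added to dstack/_internal/utils/common.py, based on the well-known recipe; the real implementation may differ.

# Why the fallback exists: nearby sizes can render to the same human-readable string.
def sizeof_fmt(num: float, suffix: str = "B") -> str:
    for unit in ("", "Ki", "Mi", "Gi", "Ti"):
        if abs(num) < 1024.0:
            return f"{num:3.1f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:.1f}Pi{suffix}"


request_size, limit = 2_097_153, 2_097_152
size_fmt, limit_fmt = sizeof_fmt(request_size), sizeof_fmt(limit)
if size_fmt == limit_fmt:  # same string -> fall back to exact bytes
    size_fmt, limit_fmt = f"{request_size}B", f"{limit}B"
print(f"Repo diff size is {size_fmt}, which exceeds the limit of {limit_fmt}.")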

dstack/_internal/server/services/fleets.py

@@ -237,7 +237,7 @@ async def get_plan(
 ) -> FleetPlan:
     # Spec must be copied by parsing to calculate merged_profile
     effective_spec = FleetSpec.parse_obj(spec.dict())
-    effective_spec = apply_plugin_policies(
+    effective_spec = await apply_plugin_policies(
         user=user.name,
         project=project.name,
         spec=effective_spec,
@@ -342,7 +342,7 @@ async def create_fleet(
     spec: FleetSpec,
 ) -> Fleet:
     # Spec must be copied by parsing to calculate merged_profile
-    spec = apply_plugin_policies(
+    spec = await apply_plugin_policies(
         user=user.name,
         project=project.name,
         spec=spec,

dstack/_internal/server/services/gateways/__init__.py

@@ -140,7 +140,7 @@ async def create_gateway(
     project: ProjectModel,
     configuration: GatewayConfiguration,
 ) -> Gateway:
-    spec = apply_plugin_policies(
+    spec = await apply_plugin_policies(
         user=user.name,
         project=project.name,
         # Create pseudo spec until the gateway API is updated to accept spec

dstack/_internal/server/services/jobs/__init__.py

@@ -470,20 +470,20 @@ async def _detach_volume_from_job_instance(
         await run_async(
             compute.detach_volume,
             volume=volume,
-
+            provisioning_data=jpd,
             force=False,
         )
         # For some backends, the volume may be detached immediately
         detached = await run_async(
             compute.is_volume_detached,
             volume=volume,
-
+            provisioning_data=jpd,
         )
     else:
         detached = await run_async(
             compute.is_volume_detached,
             volume=volume,
-
+            provisioning_data=jpd,
         )
     if not detached and _should_force_detach_volume(job_model, job_spec.stop_duration):
         logger.info(
@@ -494,7 +494,7 @@ async def _detach_volume_from_job_instance(
         await run_async(
             compute.detach_volume,
             volume=volume,
-
+            provisioning_data=jpd,
             force=True,
         )
         # Let the next iteration check if force detach worked

dstack/_internal/server/services/plugins.py

@@ -1,9 +1,11 @@
 import itertools
 from importlib import import_module
+from typing import Dict
 
 from backports.entry_points_selectable import entry_points  # backport for Python 3.9
 
 from dstack._internal.core.errors import ServerClientError
+from dstack._internal.utils.common import run_async
 from dstack._internal.utils.logging import get_logger
 from dstack.plugins import ApplyPolicy, ApplySpec, Plugin
 
@@ -12,59 +14,89 @@ logger = get_logger(__name__)
 
 _PLUGINS: list[Plugin] = []
 
+_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"}
 
-
-
-
-
-
-
-
-
-
-            )
-            continue
+
+class PluginEntrypoint:
+    def __init__(self, name: str, import_path: str, is_builtin: bool = False):
+        self.name = name
+        self.import_path = import_path
+        self.is_builtin = is_builtin
+
+    def load(self):
+        module_path, _, class_name = self.import_path.partition(":")
         try:
-            module_path, _, class_name = entrypoint.value.partition(":")
             module = import_module(module_path)
+            plugin_class = getattr(module, class_name, None)
+            if plugin_class is None:
+                logger.warning(
+                    ("Failed to load plugin %s: plugin class %s not found in module %s."),
+                    self.name,
+                    class_name,
+                    module_path,
+                )
+                return None
+            if not issubclass(plugin_class, Plugin):
+                logger.warning(
+                    ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
+                    self.name,
+                    class_name,
+                )
+                return None
+            return plugin_class()
         except ImportError:
             logger.warning(
                 (
                     "Failed to load plugin %s when importing %s."
                     " Ensure the module is on the import path."
                 ),
-
-
+                self.name,
+                self.import_path,
             )
-
-
-
-
-
+            return None
+
+
+def load_plugins(enabled_plugins: list[str]):
+    _PLUGINS.clear()
+    entrypoints: dict[str, PluginEntrypoint] = {}
+    plugins_to_load = enabled_plugins.copy()
+    for entrypoint in entry_points(group="dstack.plugins"):
+        if entrypoint.name not in enabled_plugins:
+            logger.info(
+                ("Found not enabled plugin %s. Plugin will not be loaded."),
                 entrypoint.name,
-                class_name,
-                module_path,
             )
             continue
-
-
-
-                entrypoint.name,
-                class_name,
+        else:
+            entrypoints[entrypoint.name] = PluginEntrypoint(
+                entrypoint.name, entrypoint.value, is_builtin=False
             )
-
-
-
-
+
+    for name, import_path in _BUILTIN_PLUGINS.items():
+        if name not in enabled_plugins:
+            logger.info(
+                ("Found not enabled builtin plugin %s. Plugin will not be loaded."),
+                name,
+            )
+        else:
+            entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True)
+
+    for plugin_name, plugin_entrypoint in entrypoints.items():
+        plugin_instance = plugin_entrypoint.load()
+        if plugin_instance is not None:
+            _PLUGINS.append(plugin_instance)
+            plugins_to_load.remove(plugin_name)
+            logger.info("Loaded plugin %s", plugin_name)
+
     if plugins_to_load:
         logger.warning("Enabled plugins not found: %s", plugins_to_load)
 
 
-def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec:
+async def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec:
     policies = _get_apply_policies()
     for policy in policies:
         try:
-            spec = policy.on_apply
+            spec = await run_async(policy.on_apply, user=user, project=project, spec=spec)
         except ValueError as e:
             msg = None
             if len(e.args) > 0: