dstack 0.19.11rc1__py3-none-any.whl → 0.19.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dstack might be problematic.

Files changed (45)
  1. dstack/_internal/cli/commands/offer.py +2 -0
  2. dstack/_internal/cli/services/configurators/run.py +43 -42
  3. dstack/_internal/cli/utils/run.py +10 -26
  4. dstack/_internal/cli/utils/updates.py +13 -1
  5. dstack/_internal/core/backends/aws/compute.py +21 -9
  6. dstack/_internal/core/backends/base/compute.py +7 -3
  7. dstack/_internal/core/backends/gcp/compute.py +43 -20
  8. dstack/_internal/core/backends/gcp/resources.py +18 -2
  9. dstack/_internal/core/backends/local/compute.py +4 -2
  10. dstack/_internal/core/backends/template/configurator.py.jinja +1 -6
  11. dstack/_internal/core/backends/template/models.py.jinja +4 -0
  12. dstack/_internal/core/models/configurations.py +1 -1
  13. dstack/_internal/core/models/fleets.py +6 -1
  14. dstack/_internal/core/models/profiles.py +43 -3
  15. dstack/_internal/core/models/repos/local.py +19 -13
  16. dstack/_internal/core/models/runs.py +78 -45
  17. dstack/_internal/server/background/tasks/process_running_jobs.py +47 -12
  18. dstack/_internal/server/background/tasks/process_runs.py +14 -1
  19. dstack/_internal/server/background/tasks/process_submitted_jobs.py +3 -3
  20. dstack/_internal/server/routers/repos.py +9 -4
  21. dstack/_internal/server/services/fleets.py +2 -2
  22. dstack/_internal/server/services/gateways/__init__.py +1 -1
  23. dstack/_internal/server/services/jobs/__init__.py +4 -4
  24. dstack/_internal/server/services/plugins.py +64 -32
  25. dstack/_internal/server/services/runner/client.py +4 -1
  26. dstack/_internal/server/services/runs.py +2 -2
  27. dstack/_internal/server/services/volumes.py +1 -1
  28. dstack/_internal/server/statics/index.html +1 -1
  29. dstack/_internal/server/statics/{main-b4803049eac16aea9a49.js → main-b0e80f8e26a168c129e9.js} +72 -25
  30. dstack/_internal/server/statics/{main-b4803049eac16aea9a49.js.map → main-b0e80f8e26a168c129e9.js.map} +1 -1
  31. dstack/_internal/server/testing/common.py +2 -1
  32. dstack/_internal/utils/common.py +4 -0
  33. dstack/api/server/_fleets.py +5 -1
  34. dstack/api/server/_runs.py +8 -0
  35. dstack/plugins/builtin/__init__.py +0 -0
  36. dstack/plugins/builtin/rest_plugin/__init__.py +18 -0
  37. dstack/plugins/builtin/rest_plugin/_models.py +48 -0
  38. dstack/plugins/builtin/rest_plugin/_plugin.py +127 -0
  39. dstack/version.py +1 -1
  40. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/METADATA +2 -2
  41. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/RECORD +44 -41
  42. dstack/_internal/utils/ignore.py +0 -92
  43. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/WHEEL +0 -0
  44. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/entry_points.txt +0 -0
  45. {dstack-0.19.11rc1.dist-info → dstack-0.19.12.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/core/models/profiles.py

@@ -6,6 +6,7 @@ from typing_extensions import Annotated, Literal
 
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.common import CoreModel, Duration
+from dstack._internal.utils.common import list_enum_values_for_annotation
 from dstack._internal.utils.tags import tags_validator
 
 DEFAULT_RETRY_DURATION = 3600
@@ -32,6 +33,17 @@ class TerminationPolicy(str, Enum):
     DESTROY_AFTER_IDLE = "destroy-after-idle"
 
 
+class StartupOrder(str, Enum):
+    ANY = "any"
+    MASTER_FIRST = "master-first"
+    WORKERS_FIRST = "workers-first"
+
+
+class StopCriteria(str, Enum):
+    ALL_DONE = "all-done"
+    MASTER_DONE = "master-done"
+
+
 @overload
 def parse_duration(v: None) -> None: ...
 
@@ -102,7 +114,7 @@ class ProfileRetry(CoreModel):
         Field(
             description=(
                 "The list of events that should be handled with retry."
-                " Supported events are `no-capacity`, `interruption`, and `error`."
+                f" Supported events are {list_enum_values_for_annotation(RetryEvent)}."
                 " Omit to retry on all events"
             )
        ),
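The f-string descriptions above and in the following profiles.py hunks rely on `list_enum_values_for_annotation`, a small helper added to `dstack/_internal/utils/common.py` in this release (item 32 in the file list). Its body is not shown in this diff; a minimal sketch of what such a helper could look like, assuming it renders enum values as a backticked, comma-separated list:

```python
from enum import Enum
from typing import Type


# Hypothetical sketch; the actual helper added to dstack/_internal/utils/common.py
# may format the values differently.
def list_enum_values_for_annotation(enum_cls: Type[Enum]) -> str:
    # e.g. StartupOrder -> "`any`, `master-first`, `workers-first`"
    return ", ".join(f"`{member.value}`" for member in enum_cls)
```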
@@ -190,7 +202,11 @@ class ProfileParams(CoreModel):
     spot_policy: Annotated[
         Optional[SpotPolicy],
         Field(
-            description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`. Defaults to `on-demand`"
+            description=(
+                "The policy for provisioning spot or on-demand instances:"
+                f" {list_enum_values_for_annotation(SpotPolicy)}."
+                f" Defaults to `{SpotPolicy.ONDEMAND.value}`"
+            )
         ),
     ] = None
     retry: Annotated[
@@ -225,7 +241,11 @@ class ProfileParams(CoreModel):
     creation_policy: Annotated[
         Optional[CreationPolicy],
         Field(
-            description="The policy for using instances from fleets. Defaults to `reuse-or-create`"
+            description=(
+                "The policy for using instances from fleets:"
+                f" {list_enum_values_for_annotation(CreationPolicy)}."
+                f" Defaults to `{CreationPolicy.REUSE_OR_CREATE.value}`"
+            )
         ),
     ] = None
     idle_duration: Annotated[
@@ -241,6 +261,26 @@ class ProfileParams(CoreModel):
         Optional[UtilizationPolicy],
         Field(description="Run termination policy based on utilization"),
     ] = None
+    startup_order: Annotated[
+        Optional[StartupOrder],
+        Field(
+            description=(
+                f"The order in which master and workers jobs are started:"
+                f" {list_enum_values_for_annotation(StartupOrder)}."
+                f" Defaults to `{StartupOrder.ANY.value}`"
+            )
+        ),
+    ] = None
+    stop_criteria: Annotated[
+        Optional[StopCriteria],
+        Field(
+            description=(
+                "The criteria determining when a multi-node run should be considered finished:"
+                f" {list_enum_values_for_annotation(StopCriteria)}."
+                f" Defaults to `{StopCriteria.ALL_DONE.value}`"
+            )
+        ),
+    ] = None
     fleets: Annotated[
         Optional[list[str]], Field(description="The fleets considered for reuse")
     ] = None
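The new `startup_order` and `stop_criteria` fields wire the `StartupOrder` and `StopCriteria` enums into `ProfileParams`. A hypothetical construction illustrating their semantics, assuming the remaining `ProfileParams` fields keep their `None` defaults:

```python
from dstack._internal.core.models.profiles import (
    ProfileParams,
    StartupOrder,
    StopCriteria,
)

# Hypothetical illustration; the field names and enum values come from the hunks
# above, the rest is assumed.
params = ProfileParams(
    startup_order=StartupOrder.MASTER_FIRST,  # workers wait until the master job is running
    stop_criteria=StopCriteria.MASTER_DONE,   # the run is finished once the master job is done
)
```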
dstack/_internal/core/models/repos/local.py

@@ -2,13 +2,18 @@ import tarfile
 from pathlib import Path
 from typing import BinaryIO, Optional
 
+import ignore
+import ignore.overrides
 from typing_extensions import Literal
 
 from dstack._internal.core.models.repos.base import BaseRepoInfo, Repo
+from dstack._internal.utils.common import sizeof_fmt
 from dstack._internal.utils.hash import get_sha256, slugify
-from dstack._internal.utils.ignore import GitIgnore
+from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.path import PathLike
 
+logger = get_logger(__name__)
+
 
 class LocalRepoInfo(BaseRepoInfo):
     repo_type: Literal["local"] = "local"
@@ -69,22 +74,23 @@ class LocalRepo(Repo):
         self.run_repo_data = repo_data
 
     def write_code_file(self, fp: BinaryIO) -> str:
+        repo_path = Path(self.run_repo_data.repo_dir)
         with tarfile.TarFile(mode="w", fileobj=fp) as t:
-            t.add(
-                self.run_repo_data.repo_dir,
-                arcname="",
-                filter=TarIgnore(self.run_repo_data.repo_dir, globs=[".git"]),
-            )
+            for entry in (
+                ignore.WalkBuilder(repo_path)
+                .overrides(ignore.overrides.OverrideBuilder(repo_path).add("!/.git/").build())
+                .hidden(False)  # do not ignore files that start with a dot
+                .require_git(False)  # respect git ignore rules even if not a git repo
+                .add_custom_ignore_filename(".dstackignore")
+                .build()
+            ):
+                entry_path_within_repo = entry.path().relative_to(repo_path)
+                if entry_path_within_repo != Path("."):
+                    t.add(entry.path(), arcname=entry_path_within_repo, recursive=False)
+        logger.debug("Code file size: %s", sizeof_fmt(fp.tell()))
         return get_sha256(fp)
 
     def get_repo_info(self) -> LocalRepoInfo:
         return LocalRepoInfo(
             repo_dir=self.run_repo_data.repo_dir,
         )
-
-
-class TarIgnore(GitIgnore):
-    def __call__(self, tarinfo: tarfile.TarInfo) -> Optional[tarfile.TarInfo]:
-        if self.ignore(tarinfo.path):
-            return None
-        return tarinfo
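`write_code_file` now walks the repo with the `ignore` package instead of the removed `GitIgnore`/`TarIgnore` filter, honoring `.gitignore` rules and a custom `.dstackignore` file while keeping `.git` excluded. A hypothetical standalone helper that reuses the same builder calls, for example to preview which files would land in the uploaded code archive:

```python
from pathlib import Path

import ignore
import ignore.overrides


# Hypothetical preview helper; the builder configuration mirrors the hunk above,
# only the wrapping function is new.
def list_packaged_files(repo_dir: str) -> list[Path]:
    repo_path = Path(repo_dir)
    walker = (
        ignore.WalkBuilder(repo_path)
        .overrides(ignore.overrides.OverrideBuilder(repo_path).add("!/.git/").build())
        .hidden(False)  # do not skip files that start with a dot
        .require_git(False)  # respect ignore rules even outside a git repo
        .add_custom_ignore_filename(".dstackignore")
        .build()
    )
    return [
        entry.path().relative_to(repo_path)
        for entry in walker
        if entry.path().relative_to(repo_path) != Path(".")
    ]
```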
dstack/_internal/core/models/runs.py

@@ -148,9 +148,6 @@ class JobTerminationReason(str, Enum):
         }
         return mapping[self]
 
-    def pretty_repr(self) -> str:
-        return " ".join(self.value.split("_")).capitalize()
-
 
 class Requirements(CoreModel):
     # TODO: Make requirements' fields required
@@ -289,6 +286,9 @@ class JobSubmission(CoreModel):
     exit_status: Optional[int]
     job_provisioning_data: Optional[JobProvisioningData]
     job_runtime_data: Optional[JobRuntimeData]
+    # TODO: make status_message and error a computed field after migrating to pydanticV2
+    status_message: Optional[str]
+    error: Optional[str] = None
 
     @property
     def age(self) -> timedelta:
@@ -301,6 +301,71 @@ class JobSubmission(CoreModel):
         end_time = self.finished_at
         return end_time - self.submitted_at
 
+    @root_validator
+    def _status_message(cls, values) -> Dict:
+        try:
+            status = values["status"]
+            termination_reason = values["termination_reason"]
+            exit_code = values["exit_status"]
+        except KeyError:
+            return values
+        values["status_message"] = JobSubmission._get_status_message(
+            status=status,
+            termination_reason=termination_reason,
+            exit_status=exit_code,
+        )
+        return values
+
+    @staticmethod
+    def _get_status_message(
+        status: JobStatus,
+        termination_reason: Optional[JobTerminationReason],
+        exit_status: Optional[int],
+    ) -> str:
+        if status == JobStatus.DONE:
+            return "exited (0)"
+        elif status == JobStatus.FAILED:
+            if termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR:
+                return f"exited ({exit_status})"
+            elif termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY:
+                return "no offers"
+            elif termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY:
+                return "interrupted"
+            else:
+                return "error"
+        elif status == JobStatus.TERMINATED:
+            if termination_reason == JobTerminationReason.TERMINATED_BY_USER:
+                return "stopped"
+            elif termination_reason == JobTerminationReason.ABORTED_BY_USER:
+                return "aborted"
+        return status.value
+
+    @root_validator
+    def _error(cls, values) -> Dict:
+        try:
+            termination_reason = values["termination_reason"]
+        except KeyError:
+            return values
+        values["error"] = JobSubmission._get_error(termination_reason=termination_reason)
+        return values
+
+    @staticmethod
+    def _get_error(termination_reason: Optional[JobTerminationReason]) -> Optional[str]:
+        error_mapping = {
+            JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable",
+            JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded",
+            JobTerminationReason.VOLUME_ERROR: "waiting runner limit exceeded",
+            JobTerminationReason.GATEWAY_ERROR: "gateway error",
+            JobTerminationReason.SCALED_DOWN: "scaled down",
+            JobTerminationReason.INACTIVITY_DURATION_EXCEEDED: "inactivity duration exceeded",
+            JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY: "utilization policy",
+            JobTerminationReason.PORTS_BINDING_FAILED: "ports binding failed",
+            JobTerminationReason.CREATING_CONTAINER_ERROR: "runner error",
+            JobTerminationReason.EXECUTOR_ERROR: "executor error",
+            JobTerminationReason.MAX_DURATION_EXCEEDED: "max duration exceeded",
+        }
+        return error_mapping.get(termination_reason)
+
 
 class Job(CoreModel):
     job_spec: JobSpec
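`status_message` and `error` are declared as ordinary fields and then filled in by `root_validator`s because pydantic v1, which dstack still targets (see the TODO above), has no computed fields; the `try/except KeyError` guards keep these validators from raising when earlier field validation already failed. A minimal self-contained sketch of the same pattern, separate from the dstack models (the remaining runs.py hunks follow below):

```python
from typing import Dict, Optional

from pydantic import BaseModel, root_validator  # pydantic v1 API, as in the diff


class StatusLike(BaseModel):
    # Hypothetical stand-in model, not part of dstack.
    status: str
    # Declared as a regular field and populated by the validator below.
    status_message: Optional[str] = None

    @root_validator
    def _fill_status_message(cls, values) -> Dict:
        try:
            status = values["status"]
        except KeyError:
            # An earlier validator already failed; leave the values untouched.
            return values
        values["status_message"] = status.replace("_", " ")
        return values


print(StatusLike(status="pulling_image").status_message)  # "pulling image"
```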
@@ -445,15 +510,20 @@ class Run(CoreModel):
     def _error(cls, values) -> Dict:
         try:
             termination_reason = values["termination_reason"]
-            jobs = values["jobs"]
         except KeyError:
             return values
-        values["error"] = _get_run_error(
-            run_termination_reason=termination_reason,
-            run_jobs=jobs,
-        )
+        values["error"] = Run._get_error(termination_reason=termination_reason)
         return values
 
+    @staticmethod
+    def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[str]:
+        if termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED:
+            return "retry limit exceeded"
+        elif termination_reason == RunTerminationReason.SERVER_ERROR:
+            return "server error"
+        else:
+            return None
+
 
 class JobPlan(CoreModel):
     job_spec: JobSpec
@@ -502,40 +572,3 @@ def get_policy_map(spot_policy: Optional[SpotPolicy], default: SpotPolicy) -> Op
         SpotPolicy.ONDEMAND: False,
     }
     return policy_map[spot_policy]
-
-
-def _get_run_error(
-    run_termination_reason: Optional[RunTerminationReason],
-    run_jobs: List[Job],
-) -> str:
-    if run_termination_reason is None:
-        return ""
-    if len(run_jobs) > 1:
-        return run_termination_reason.name
-    run_job_termination_reason, exit_status = _get_run_job_termination_reason_and_exit_status(
-        run_jobs
-    )
-    # For failed runs, also show termination reason to provide more context.
-    # For other run statuses, the job termination reason will duplicate run status.
-    if run_job_termination_reason is not None and run_termination_reason in [
-        RunTerminationReason.JOB_FAILED,
-        RunTerminationReason.SERVER_ERROR,
-        RunTerminationReason.RETRY_LIMIT_EXCEEDED,
-    ]:
-        if exit_status:
-            return (
-                f"{run_termination_reason.name}\n({run_job_termination_reason.name} {exit_status})"
-            )
-        return f"{run_termination_reason.name}\n({run_job_termination_reason.name})"
-    return run_termination_reason.name
-
-
-def _get_run_job_termination_reason_and_exit_status(
-    run_jobs: List[Job],
-) -> tuple[Optional[JobTerminationReason], Optional[int]]:
-    for job in run_jobs:
-        if len(job.job_submissions) > 0:
-            job_submission = job.job_submissions[-1]
-            if job_submission.termination_reason is not None:
-                return job_submission.termination_reason, job_submission.exit_status
-    return None, None
dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -18,6 +18,7 @@ from dstack._internal.core.models.instances import (
     SSHConnectionParams,
 )
 from dstack._internal.core.models.metrics import Metric
+from dstack._internal.core.models.profiles import StartupOrder
 from dstack._internal.core.models.repos import RemoteRepoCreds
 from dstack._internal.core.models.runs import (
     ClusterInfo,
@@ -184,18 +185,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     if job_provisioning_data.hostname is None:
         await _wait_for_instance_provisioning_data(job_model=job_model)
     else:
-        # Wait until all other jobs in the replica have IPs assigned.
-        # This is needed to ensure cluster_info has all IPs set.
-        for other_job in run.jobs:
-            if (
-                other_job.job_spec.replica_num == job.job_spec.replica_num
-                and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
-                and other_job.job_submissions[-1].job_provisioning_data is not None
-                and other_job.job_submissions[-1].job_provisioning_data.hostname is None
-            ):
-                job_model.last_processed_at = common_utils.get_current_datetime()
-                await session.commit()
-                return
+        if _should_wait_for_other_nodes(run, job, job_model):
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return
 
     # fails are acceptable until timeout is exceeded
     if job_provisioning_data.dockerized:
@@ -406,6 +399,48 @@ async def _wait_for_instance_provisioning_data(job_model: JobModel):
     job_model.job_provisioning_data = job_model.instance.job_provisioning_data
 
 
+def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool:
+    for other_job in run.jobs:
+        if (
+            other_job.job_spec.replica_num == job.job_spec.replica_num
+            and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
+            and other_job.job_submissions[-1].job_provisioning_data is not None
+            and other_job.job_submissions[-1].job_provisioning_data.hostname is None
+        ):
+            logger.debug(
+                "%s: waiting for other job to have IP assigned",
+                fmt(job_model),
+            )
+            return True
+    master_job = find_job(run.jobs, job.job_spec.replica_num, 0)
+    if (
+        job.job_spec.job_num != 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.MASTER_FIRST
+        and master_job.job_submissions[-1].status != JobStatus.RUNNING
+    ):
+        logger.debug(
+            "%s: waiting for master job to become running",
+            fmt(job_model),
+        )
+        return True
+    if (
+        job.job_spec.job_num == 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.WORKERS_FIRST
+    ):
+        for other_job in run.jobs:
+            if (
+                other_job.job_spec.replica_num == job.job_spec.replica_num
+                and other_job.job_spec.job_num != job.job_spec.job_num
+                and other_job.job_submissions[-1].status != JobStatus.RUNNING
+            ):
+                logger.debug(
+                    "%s: waiting for worker job to become running",
+                    fmt(job_model),
+                )
+                return True
+    return False
+
+
 @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
 def _process_provisioning_with_shim(
     ports: Dict[int, int],
dstack/_internal/server/background/tasks/process_runs.py

@@ -10,7 +10,7 @@ from sqlalchemy.orm import joinedload, selectinload
 import dstack._internal.server.services.gateways as gateways
 import dstack._internal.server.services.services.autoscalers as autoscalers
 from dstack._internal.core.errors import ServerError
-from dstack._internal.core.models.profiles import RetryEvent
+from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
 from dstack._internal.core.models.runs import (
     Job,
     JobStatus,
@@ -313,6 +313,10 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED
         else:
             raise ValueError(f"Unexpected termination reason {run_termination_reasons}")
+    elif _should_stop_on_master_done(run):
+        new_status = RunStatus.TERMINATING
+        # ALL_JOBS_DONE is used for all DONE reasons including master-done
+        termination_reason = RunTerminationReason.ALL_JOBS_DONE
     elif RunStatus.RUNNING in run_statuses:
         new_status = RunStatus.RUNNING
     elif RunStatus.PROVISIONING in run_statuses:
@@ -434,3 +438,12 @@ def _can_retry_single_job(run_spec: RunSpec) -> bool:
     # We could make partial retry in some multi-node cases.
     # E.g. restarting a worker node, independent jobs.
     return False
+
+
+def _should_stop_on_master_done(run: Run) -> bool:
+    if run.run_spec.merged_profile.stop_criteria != StopCriteria.MASTER_DONE:
+        return False
+    for job in run.jobs:
+        if job.job_spec.job_num == 0 and job.job_submissions[-1].status == JobStatus.DONE:
+            return True
+    return False
dstack/_internal/server/background/tasks/process_submitted_jobs.py

@@ -659,7 +659,7 @@ async def _attach_volumes(
                 backend=backend,
                 volume_model=volume_model,
                 instance=instance,
-                instance_id=job_provisioning_data.instance_id,
+                jpd=job_provisioning_data,
             )
             job_runtime_data.volume_names.append(volume.name)
             break  # attach next mount point
@@ -685,7 +685,7 @@ async def _attach_volume(
     backend: Backend,
     volume_model: VolumeModel,
     instance: InstanceModel,
-    instance_id: str,
+    jpd: JobProvisioningData,
 ):
     compute = backend.compute()
     assert isinstance(compute, ComputeWithVolumeSupport)
@@ -697,7 +697,7 @@ async def _attach_volume(
     attachment_data = await common_utils.run_async(
         compute.attach_volume,
         volume=volume,
-        instance_id=instance_id,
+        provisioning_data=jpd,
     )
     volume_attachment_model = VolumeAttachmentModel(
         volume=volume_model,
dstack/_internal/server/routers/repos.py

@@ -1,7 +1,6 @@
 from typing import List, Tuple
 
 from fastapi import APIRouter, Depends, Request, UploadFile
-from humanize import naturalsize
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from dstack._internal.core.errors import ResourceNotExistsError, ServerClientError
@@ -20,6 +19,7 @@ from dstack._internal.server.utils.routers import (
     get_base_api_additional_responses,
     get_request_size,
 )
+from dstack._internal.utils.common import sizeof_fmt
 
 router = APIRouter(
     prefix="/api/project/{project_name}/repos",
@@ -98,10 +98,15 @@ async def upload_code(
 ):
     request_size = get_request_size(request)
     if SERVER_CODE_UPLOAD_LIMIT > 0 and request_size > SERVER_CODE_UPLOAD_LIMIT:
+        diff_size_fmt = sizeof_fmt(request_size)
+        limit_fmt = sizeof_fmt(SERVER_CODE_UPLOAD_LIMIT)
+        if diff_size_fmt == limit_fmt:
+            diff_size_fmt = f"{request_size}B"
+            limit_fmt = f"{SERVER_CODE_UPLOAD_LIMIT}B"
         raise ServerClientError(
-            f"Repo diff size is {naturalsize(request_size)}, which exceeds the limit of "
-            f"{naturalsize(SERVER_CODE_UPLOAD_LIMIT)}. Use .gitignore to exclude large files from the repo. This "
-            f"limit can be modified by setting the DSTACK_SERVER_CODE_UPLOAD_LIMIT_BYTES environment variable"
+            f"Repo diff size is {diff_size_fmt}, which exceeds the limit of {limit_fmt}."
+            " Use .gitignore to exclude large files from the repo."
+            " This limit can be modified by setting the DSTACK_SERVER_CODE_UPLOAD_LIMIT environment variable."
        )
     _, project = user_project
     await repos.upload_code(
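The upload-size error now uses dstack's own `sizeof_fmt` instead of `humanize.naturalsize` and falls back to exact byte counts when the request size and the limit round to the same human-readable string, so the message never reads "X exceeds the limit of X". The `sizeof_fmt` implementation itself is not part of this excerpt; a conventional sketch of such a binary-unit formatter:

```python
# Hypothetical formatter; the real helper in dstack/_internal/utils/common.py
# may differ in suffixes and precision.
def sizeof_fmt(num: float, suffix: str = "B") -> str:
    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi"):
        if abs(num) < 1024.0:
            return f"{num:.1f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:.1f}Ei{suffix}"
```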
dstack/_internal/server/services/fleets.py

@@ -237,7 +237,7 @@ async def get_plan(
 ) -> FleetPlan:
     # Spec must be copied by parsing to calculate merged_profile
     effective_spec = FleetSpec.parse_obj(spec.dict())
-    effective_spec = apply_plugin_policies(
+    effective_spec = await apply_plugin_policies(
         user=user.name,
         project=project.name,
         spec=effective_spec,
@@ -342,7 +342,7 @@ async def create_fleet(
     spec: FleetSpec,
 ) -> Fleet:
     # Spec must be copied by parsing to calculate merged_profile
-    spec = apply_plugin_policies(
+    spec = await apply_plugin_policies(
         user=user.name,
         project=project.name,
         spec=spec,

dstack/_internal/server/services/gateways/__init__.py

@@ -140,7 +140,7 @@ async def create_gateway(
     project: ProjectModel,
     configuration: GatewayConfiguration,
 ) -> Gateway:
-    spec = apply_plugin_policies(
+    spec = await apply_plugin_policies(
         user=user.name,
         project=project.name,
         # Create pseudo spec until the gateway API is updated to accept spec
dstack/_internal/server/services/jobs/__init__.py

@@ -470,20 +470,20 @@ async def _detach_volume_from_job_instance(
         await run_async(
             compute.detach_volume,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
             force=False,
         )
         # For some backends, the volume may be detached immediately
         detached = await run_async(
             compute.is_volume_detached,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
         )
     else:
         detached = await run_async(
             compute.is_volume_detached,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
         )
     if not detached and _should_force_detach_volume(job_model, job_spec.stop_duration):
         logger.info(
@@ -494,7 +494,7 @@ async def _detach_volume_from_job_instance(
         await run_async(
             compute.detach_volume,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
             force=True,
         )
         # Let the next iteration check if force detach worked
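Both the attach path in process_submitted_jobs.py and the detach path here now hand the full `JobProvisioningData` to the backend compute methods instead of a bare `instance_id`, which lines up with the compute backend files touched in this release (base, AWS, GCP, local). A hypothetical sketch of the idea with stand-in types; only the keyword names `volume`, `provisioning_data`, and `force` are taken from the call sites in this diff:

```python
from dataclasses import dataclass
from typing import Optional


# Stand-in types for illustration; dstack's real models carry more fields.
@dataclass
class JobProvisioningData:
    instance_id: str
    region: Optional[str] = None


@dataclass
class Volume:
    name: str


def attach_volume(volume: Volume, provisioning_data: JobProvisioningData) -> None:
    # With the whole provisioning data available, a backend can use more than the
    # instance id (for example the region) without a second lookup.
    print(f"attaching {volume.name} to {provisioning_data.instance_id}")
```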
dstack/_internal/server/services/plugins.py

@@ -1,9 +1,11 @@
 import itertools
 from importlib import import_module
+from typing import Dict
 
 from backports.entry_points_selectable import entry_points  # backport for Python 3.9
 
 from dstack._internal.core.errors import ServerClientError
+from dstack._internal.utils.common import run_async
 from dstack._internal.utils.logging import get_logger
 from dstack.plugins import ApplyPolicy, ApplySpec, Plugin
 
@@ -12,59 +14,89 @@ logger = get_logger(__name__)
 
 _PLUGINS: list[Plugin] = []
 
+_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"}
 
-def load_plugins(enabled_plugins: list[str]):
-    _PLUGINS.clear()
-    plugins_entrypoints = entry_points(group="dstack.plugins")
-    plugins_to_load = enabled_plugins.copy()
-    for entrypoint in plugins_entrypoints:
-        if entrypoint.name not in enabled_plugins:
-            logger.info(
-                ("Found not enabled plugin %s. Plugin will not be loaded."),
-                entrypoint.name,
-            )
-            continue
+
+class PluginEntrypoint:
+    def __init__(self, name: str, import_path: str, is_builtin: bool = False):
+        self.name = name
+        self.import_path = import_path
+        self.is_builtin = is_builtin
+
+    def load(self):
+        module_path, _, class_name = self.import_path.partition(":")
         try:
-            module_path, _, class_name = entrypoint.value.partition(":")
             module = import_module(module_path)
+            plugin_class = getattr(module, class_name, None)
+            if plugin_class is None:
+                logger.warning(
+                    ("Failed to load plugin %s: plugin class %s not found in module %s."),
+                    self.name,
+                    class_name,
+                    module_path,
+                )
+                return None
+            if not issubclass(plugin_class, Plugin):
+                logger.warning(
+                    ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
+                    self.name,
+                    class_name,
+                )
+                return None
+            return plugin_class()
         except ImportError:
             logger.warning(
                 (
                     "Failed to load plugin %s when importing %s."
                     " Ensure the module is on the import path."
                 ),
-                entrypoint.name,
-                entrypoint.value,
+                self.name,
+                self.import_path,
             )
-            continue
-        plugin_class = getattr(module, class_name, None)
-        if plugin_class is None:
-            logger.warning(
-                ("Failed to load plugin %s: plugin class %s not found in module %s."),
+        return None
+
+
+def load_plugins(enabled_plugins: list[str]):
+    _PLUGINS.clear()
+    entrypoints: dict[str, PluginEntrypoint] = {}
+    plugins_to_load = enabled_plugins.copy()
+    for entrypoint in entry_points(group="dstack.plugins"):
+        if entrypoint.name not in enabled_plugins:
+            logger.info(
+                ("Found not enabled plugin %s. Plugin will not be loaded."),
                 entrypoint.name,
-                class_name,
-                module_path,
             )
             continue
-        if not issubclass(plugin_class, Plugin):
-            logger.warning(
-                ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
-                entrypoint.name,
-                class_name,
+        else:
+            entrypoints[entrypoint.name] = PluginEntrypoint(
+                entrypoint.name, entrypoint.value, is_builtin=False
             )
-            continue
-        plugins_to_load.remove(entrypoint.name)
-        _PLUGINS.append(plugin_class())
-        logger.info("Loaded plugin %s", entrypoint.name)
+
+    for name, import_path in _BUILTIN_PLUGINS.items():
+        if name not in enabled_plugins:
+            logger.info(
+                ("Found not enabled builtin plugin %s. Plugin will not be loaded."),
+                name,
+            )
+        else:
+            entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True)
+
+    for plugin_name, plugin_entrypoint in entrypoints.items():
+        plugin_instance = plugin_entrypoint.load()
+        if plugin_instance is not None:
+            _PLUGINS.append(plugin_instance)
+            plugins_to_load.remove(plugin_name)
+            logger.info("Loaded plugin %s", plugin_name)
+
     if plugins_to_load:
         logger.warning("Enabled plugins not found: %s", plugins_to_load)
 
 
-def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec:
+async def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec:
     policies = _get_apply_policies()
     for policy in policies:
         try:
-            spec = policy.on_apply(user=user, project=project, spec=spec)
+            spec = await run_async(policy.on_apply, user=user, project=project, spec=spec)
         except ValueError as e:
             msg = None
             if len(e.args) > 0:
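`apply_plugin_policies` is now a coroutine: plugin `on_apply` hooks stay synchronous callables, but they are executed through `run_async` so they do not block the server's event loop, and the fleet and gateway call sites shown earlier were updated to `await` it. The `run_async` helper itself is not shown in this excerpt; a minimal sketch of such a wrapper, assuming it offloads the callable to the default thread-pool executor:

```python
import asyncio
import functools
from typing import Callable, TypeVar

T = TypeVar("T")


# Hypothetical run_async-style helper; dstack's real implementation in
# dstack/_internal/utils/common.py may differ.
async def run_async(func: Callable[..., T], *args, **kwargs) -> T:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
```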