dstack 0.19.12rc1__py3-none-any.whl → 0.19.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.

Files changed (44)
  1. dstack/_internal/cli/services/configurators/run.py +43 -47
  2. dstack/_internal/cli/utils/run.py +15 -27
  3. dstack/_internal/core/backends/aws/compute.py +22 -9
  4. dstack/_internal/core/backends/aws/resources.py +26 -0
  5. dstack/_internal/core/backends/base/offers.py +0 -1
  6. dstack/_internal/core/backends/template/configurator.py.jinja +1 -6
  7. dstack/_internal/core/backends/template/models.py.jinja +4 -0
  8. dstack/_internal/core/compatibility/__init__.py +0 -0
  9. dstack/_internal/core/compatibility/fleets.py +72 -0
  10. dstack/_internal/core/compatibility/gateways.py +34 -0
  11. dstack/_internal/core/compatibility/runs.py +125 -0
  12. dstack/_internal/core/compatibility/volumes.py +32 -0
  13. dstack/_internal/core/models/configurations.py +1 -1
  14. dstack/_internal/core/models/fleets.py +6 -1
  15. dstack/_internal/core/models/instances.py +51 -12
  16. dstack/_internal/core/models/profiles.py +43 -3
  17. dstack/_internal/core/models/repos/local.py +3 -3
  18. dstack/_internal/core/models/runs.py +118 -44
  19. dstack/_internal/server/app.py +1 -1
  20. dstack/_internal/server/background/tasks/process_running_jobs.py +47 -12
  21. dstack/_internal/server/background/tasks/process_runs.py +14 -1
  22. dstack/_internal/server/services/runner/client.py +4 -1
  23. dstack/_internal/server/services/storage/__init__.py +38 -0
  24. dstack/_internal/server/services/storage/base.py +27 -0
  25. dstack/_internal/server/services/storage/gcs.py +44 -0
  26. dstack/_internal/server/services/{storage.py → storage/s3.py} +4 -27
  27. dstack/_internal/server/settings.py +7 -3
  28. dstack/_internal/server/statics/index.html +1 -1
  29. dstack/_internal/server/statics/{main-5b9786c955b42bf93581.js → main-2066f1f22ddb4557bcde.js} +1677 -46
  30. dstack/_internal/server/statics/{main-5b9786c955b42bf93581.js.map → main-2066f1f22ddb4557bcde.js.map} +1 -1
  31. dstack/_internal/server/statics/{main-8f9c66f404e9c7e7e020.css → main-f39c418b05fe14772dd8.css} +1 -1
  32. dstack/_internal/server/testing/common.py +2 -1
  33. dstack/_internal/utils/common.py +4 -0
  34. dstack/api/server/_fleets.py +9 -69
  35. dstack/api/server/_gateways.py +3 -14
  36. dstack/api/server/_runs.py +4 -116
  37. dstack/api/server/_volumes.py +3 -14
  38. dstack/plugins/builtin/rest_plugin/_plugin.py +24 -5
  39. dstack/version.py +2 -2
  40. {dstack-0.19.12rc1.dist-info → dstack-0.19.13.dist-info}/METADATA +1 -1
  41. {dstack-0.19.12rc1.dist-info → dstack-0.19.13.dist-info}/RECORD +44 -36
  42. {dstack-0.19.12rc1.dist-info → dstack-0.19.13.dist-info}/WHEEL +0 -0
  43. {dstack-0.19.12rc1.dist-info → dstack-0.19.13.dist-info}/entry_points.txt +0 -0
  44. {dstack-0.19.12rc1.dist-info → dstack-0.19.13.dist-info}/licenses/LICENSE.md +0 -0

dstack/_internal/core/models/configurations.py

@@ -440,7 +440,7 @@ class ServiceConfigurationParams(CoreModel):
             raise ValueError("The minimum number of replicas must be greater than or equal to 0")
         if v.max < v.min:
             raise ValueError(
-                "The maximum number of replicas must be greater than or equal to the minium number of replicas"
+                "The maximum number of replicas must be greater than or equal to the minimum number of replicas"
             )
         return v
 

dstack/_internal/core/models/fleets.py

@@ -20,6 +20,7 @@ from dstack._internal.core.models.profiles import (
     parse_idle_duration,
 )
 from dstack._internal.core.models.resources import Range, ResourcesSpec
+from dstack._internal.utils.common import list_enum_values_for_annotation
 from dstack._internal.utils.json_schema import add_extra_schema_types
 from dstack._internal.utils.tags import tags_validator
 

@@ -207,7 +208,11 @@ class InstanceGroupParams(CoreModel):
     spot_policy: Annotated[
         Optional[SpotPolicy],
         Field(
-            description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`"
+            description=(
+                "The policy for provisioning spot or on-demand instances:"
+                f" {list_enum_values_for_annotation(SpotPolicy)}."
+                f" Defaults to `{SpotPolicy.ONDEMAND.value}`"
+            )
         ),
     ] = None
     retry: Annotated[
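
The list_enum_values_for_annotation helper comes from dstack/_internal/utils/common.py (+4 -0 in the file list above); its body is not shown in this diff, so the following is a plausible sketch, assuming it renders the enum's values as backticked, comma-separated strings:

    from enum import Enum
    from typing import Type

    def list_enum_values_for_annotation(enum_cls: Type[Enum]) -> str:
        # Hypothetical reconstruction: SpotPolicy -> "`spot`, `on-demand`, `auto`"
        return ", ".join(f"`{item.value}`" for item in enum_cls)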

dstack/_internal/core/models/instances.py

@@ -48,29 +48,68 @@ class Resources(CoreModel):
     gpus: List[Gpu]
     spot: bool
     disk: Disk = Disk(size_mib=102400)  # the default value (100GB) for backward compatibility
+    # TODO: make description a computed field after migrating to pydanticV2
     description: str = ""
     cpu_arch: Optional[gpuhunt.CPUArchitecture] = None
 
-    def pretty_format(self, include_spot: bool = False) -> str:
+    @root_validator
+    def _description(cls, values) -> Dict:
+        try:
+            description = values["description"]
+            if not description:
+                cpus = values["cpus"]
+                memory_mib = values["memory_mib"]
+                gpus = values["gpus"]
+                disk_size_mib = values["disk"].size_mib
+                spot = values["spot"]
+                cpu_arch = values["cpu_arch"]
+                values["description"] = Resources._pretty_format(
+                    cpus, cpu_arch, memory_mib, disk_size_mib, gpus, spot, include_spot=True
+                )
+        except KeyError:
+            return values
+        return values
+
+    @staticmethod
+    def _pretty_format(
+        cpus: int,
+        cpu_arch: Optional[gpuhunt.CPUArchitecture],
+        memory_mib: int,
+        disk_size_mib: int,
+        gpus: List[Gpu],
+        spot: bool,
+        include_spot: bool = False,
+    ) -> str:
         resources = {}
-        if self.cpus > 0:
-            resources["cpus"] = self.cpus
-            resources["cpu_arch"] = self.cpu_arch
-        if self.memory_mib > 0:
-            resources["memory"] = f"{self.memory_mib / 1024:.0f}GB"
-        if self.disk.size_mib > 0:
-            resources["disk_size"] = f"{self.disk.size_mib / 1024:.0f}GB"
-        if self.gpus:
-            gpu = self.gpus[0]
+        if cpus > 0:
+            resources["cpus"] = cpus
+            resources["cpu_arch"] = cpu_arch
+        if memory_mib > 0:
+            resources["memory"] = f"{memory_mib / 1024:.0f}GB"
+        if disk_size_mib > 0:
+            resources["disk_size"] = f"{disk_size_mib / 1024:.0f}GB"
+        if gpus:
+            gpu = gpus[0]
             resources["gpu_name"] = gpu.name
-            resources["gpu_count"] = len(self.gpus)
+            resources["gpu_count"] = len(gpus)
             if gpu.memory_mib > 0:
                 resources["gpu_memory"] = f"{gpu.memory_mib / 1024:.0f}GB"
         output = pretty_resources(**resources)
-        if include_spot and self.spot:
+        if include_spot and spot:
             output += " (spot)"
         return output
 
+    def pretty_format(self, include_spot: bool = False) -> str:
+        return Resources._pretty_format(
+            self.cpus,
+            self.cpu_arch,
+            self.memory_mib,
+            self.disk.size_mib,
+            self.gpus,
+            self.spot,
+            include_spot,
+        )
+
 
 class InstanceType(CoreModel):
     name: str
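
With the validator in place, description is derived automatically at construction time instead of on demand; a minimal sketch of the new behavior, assuming pydantic v1 root_validator semantics (the exact string depends on pretty_resources):

    res = Resources(cpus=8, memory_mib=32768, gpus=[], spot=True)
    # The validator fills the empty description from the other fields,
    # e.g. something like "8xCPU, 32GB, disk=100GB (spot)"
    print(res.description)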

dstack/_internal/core/models/profiles.py

@@ -6,6 +6,7 @@ from typing_extensions import Annotated, Literal
 
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.common import CoreModel, Duration
+from dstack._internal.utils.common import list_enum_values_for_annotation
 from dstack._internal.utils.tags import tags_validator
 
 DEFAULT_RETRY_DURATION = 3600

@@ -32,6 +33,17 @@ class TerminationPolicy(str, Enum):
     DESTROY_AFTER_IDLE = "destroy-after-idle"
 
 
+class StartupOrder(str, Enum):
+    ANY = "any"
+    MASTER_FIRST = "master-first"
+    WORKERS_FIRST = "workers-first"
+
+
+class StopCriteria(str, Enum):
+    ALL_DONE = "all-done"
+    MASTER_DONE = "master-done"
+
+
 @overload
 def parse_duration(v: None) -> None: ...
 

@@ -102,7 +114,7 @@ class ProfileRetry(CoreModel):
         Field(
             description=(
                 "The list of events that should be handled with retry."
-                " Supported events are `no-capacity`, `interruption`, and `error`."
+                f" Supported events are {list_enum_values_for_annotation(RetryEvent)}."
                 " Omit to retry on all events"
             )
         ),

@@ -190,7 +202,11 @@ class ProfileParams(CoreModel):
     spot_policy: Annotated[
         Optional[SpotPolicy],
         Field(
-            description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`. Defaults to `on-demand`"
+            description=(
+                "The policy for provisioning spot or on-demand instances:"
+                f" {list_enum_values_for_annotation(SpotPolicy)}."
+                f" Defaults to `{SpotPolicy.ONDEMAND.value}`"
+            )
         ),
     ] = None
     retry: Annotated[

@@ -225,7 +241,11 @@ class ProfileParams(CoreModel):
     creation_policy: Annotated[
         Optional[CreationPolicy],
         Field(
-            description="The policy for using instances from fleets. Defaults to `reuse-or-create`"
+            description=(
+                "The policy for using instances from fleets:"
+                f" {list_enum_values_for_annotation(CreationPolicy)}."
+                f" Defaults to `{CreationPolicy.REUSE_OR_CREATE.value}`"
+            )
         ),
     ] = None
     idle_duration: Annotated[

@@ -241,6 +261,26 @@ class ProfileParams(CoreModel):
         Optional[UtilizationPolicy],
         Field(description="Run termination policy based on utilization"),
     ] = None
+    startup_order: Annotated[
+        Optional[StartupOrder],
+        Field(
+            description=(
+                f"The order in which master and workers jobs are started:"
+                f" {list_enum_values_for_annotation(StartupOrder)}."
+                f" Defaults to `{StartupOrder.ANY.value}`"
+            )
+        ),
+    ] = None
+    stop_criteria: Annotated[
+        Optional[StopCriteria],
+        Field(
+            description=(
+                "The criteria determining when a multi-node run should be considered finished:"
+                f" {list_enum_values_for_annotation(StopCriteria)}."
+                f" Defaults to `{StopCriteria.ALL_DONE.value}`"
+            )
+        ),
+    ] = None
     fleets: Annotated[
         Optional[list[str]], Field(description="The fleets considered for reuse")
     ] = None
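
A minimal sketch of the new profile fields in use (values hypothetical; both fields default to None):

    profile = ProfileParams(
        startup_order=StartupOrder.MASTER_FIRST,  # workers wait until the master job is running
        stop_criteria=StopCriteria.MASTER_DONE,   # the run finishes once job_num 0 is done
    )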

dstack/_internal/core/models/repos/local.py

@@ -84,9 +84,9 @@ class LocalRepo(Repo):
                 .add_custom_ignore_filename(".dstackignore")
                 .build()
             ):
-                path = entry.path().relative_to(repo_path.absolute())
-                if path != Path("."):
-                    t.add(path, recursive=False)
+                entry_path_within_repo = entry.path().relative_to(repo_path)
+                if entry_path_within_repo != Path("."):
+                    t.add(entry.path(), arcname=entry_path_within_repo, recursive=False)
         logger.debug("Code file size: %s", sizeof_fmt(fp.tell()))
         return get_sha256(fp)
 
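
The fix adds each tar member by its real filesystem path while recording it under a repo-relative arcname; previously the relative path was passed to t.add() directly, so it was resolved against the current working directory rather than the repo root. The pattern in isolation (paths hypothetical):

    import tarfile

    with tarfile.open("code.tar", "w") as t:
        # Read from the absolute path; store the entry as "src/main.py" in the archive
        t.add("/abs/repo/src/main.py", arcname="src/main.py", recursive=False)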

dstack/_internal/core/models/runs.py

@@ -148,9 +148,6 @@ class JobTerminationReason(str, Enum):
         }
         return mapping[self]
 
-    def pretty_repr(self) -> str:
-        return " ".join(self.value.split("_")).capitalize()
-
 
 class Requirements(CoreModel):
     # TODO: Make requirements' fields required

@@ -289,6 +286,9 @@ class JobSubmission(CoreModel):
     exit_status: Optional[int]
     job_provisioning_data: Optional[JobProvisioningData]
     job_runtime_data: Optional[JobRuntimeData]
+    # TODO: make status_message and error a computed field after migrating to pydanticV2
+    status_message: Optional[str]
+    error: Optional[str] = None
 
     @property
     def age(self) -> timedelta:

@@ -301,6 +301,71 @@ class JobSubmission(CoreModel):
             end_time = self.finished_at
         return end_time - self.submitted_at
 
+    @root_validator
+    def _status_message(cls, values) -> Dict:
+        try:
+            status = values["status"]
+            termination_reason = values["termination_reason"]
+            exit_code = values["exit_status"]
+        except KeyError:
+            return values
+        values["status_message"] = JobSubmission._get_status_message(
+            status=status,
+            termination_reason=termination_reason,
+            exit_status=exit_code,
+        )
+        return values
+
+    @staticmethod
+    def _get_status_message(
+        status: JobStatus,
+        termination_reason: Optional[JobTerminationReason],
+        exit_status: Optional[int],
+    ) -> str:
+        if status == JobStatus.DONE:
+            return "exited (0)"
+        elif status == JobStatus.FAILED:
+            if termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR:
+                return f"exited ({exit_status})"
+            elif termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY:
+                return "no offers"
+            elif termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY:
+                return "interrupted"
+            else:
+                return "error"
+        elif status == JobStatus.TERMINATED:
+            if termination_reason == JobTerminationReason.TERMINATED_BY_USER:
+                return "stopped"
+            elif termination_reason == JobTerminationReason.ABORTED_BY_USER:
+                return "aborted"
+        return status.value
+
+    @root_validator
+    def _error(cls, values) -> Dict:
+        try:
+            termination_reason = values["termination_reason"]
+        except KeyError:
+            return values
+        values["error"] = JobSubmission._get_error(termination_reason=termination_reason)
+        return values
+
+    @staticmethod
+    def _get_error(termination_reason: Optional[JobTerminationReason]) -> Optional[str]:
+        error_mapping = {
+            JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable",
+            JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded",
+            JobTerminationReason.VOLUME_ERROR: "waiting runner limit exceeded",
+            JobTerminationReason.GATEWAY_ERROR: "gateway error",
+            JobTerminationReason.SCALED_DOWN: "scaled down",
+            JobTerminationReason.INACTIVITY_DURATION_EXCEEDED: "inactivity duration exceeded",
+            JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY: "utilization policy",
+            JobTerminationReason.PORTS_BINDING_FAILED: "ports binding failed",
+            JobTerminationReason.CREATING_CONTAINER_ERROR: "runner error",
+            JobTerminationReason.EXECUTOR_ERROR: "executor error",
+            JobTerminationReason.MAX_DURATION_EXCEEDED: "max duration exceeded",
+        }
+        return error_mapping.get(termination_reason)
+
 
 class Job(CoreModel):
     job_spec: JobSpec
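
Because the new fields are filled by root validators at parse time, clients can read them directly off the model; a minimal sketch (pydantic v1 parse_obj; payload hypothetical):

    submission = JobSubmission.parse_obj(payload)
    print(submission.status_message)  # e.g. "exited (0)", "no offers", "stopped"
    print(submission.error)           # e.g. "gateway error", or None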

@@ -431,6 +496,7 @@ class Run(CoreModel):
     submitted_at: datetime
     last_processed_at: datetime
     status: RunStatus
+    status_message: Optional[str] = None
     termination_reason: Optional[RunTerminationReason]
     run_spec: RunSpec
     jobs: List[Job]

@@ -445,15 +511,60 @@ class Run(CoreModel):
     def _error(cls, values) -> Dict:
         try:
             termination_reason = values["termination_reason"]
-            jobs = values["jobs"]
         except KeyError:
             return values
-        values["error"] = _get_run_error(
-            run_termination_reason=termination_reason,
-            run_jobs=jobs,
+        values["error"] = Run._get_error(termination_reason=termination_reason)
+        return values
+
+    @staticmethod
+    def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[str]:
+        if termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED:
+            return "retry limit exceeded"
+        elif termination_reason == RunTerminationReason.SERVER_ERROR:
+            return "server error"
+        else:
+            return None
+
+    @root_validator
+    def _status_message(cls, values) -> Dict:
+        try:
+            status = values["status"]
+            jobs: List[Job] = values["jobs"]
+            retry_on_events = (
+                jobs[0].job_spec.retry.on_events if jobs and jobs[0].job_spec.retry else []
+            )
+            termination_reason = Run.get_last_termination_reason(jobs[0]) if jobs else None
+        except KeyError:
+            return values
+        values["status_message"] = Run._get_status_message(
+            status=status,
+            retry_on_events=retry_on_events,
+            termination_reason=termination_reason,
         )
         return values
 
+    @staticmethod
+    def get_last_termination_reason(job: "Job") -> Optional[JobTerminationReason]:
+        for submission in reversed(job.job_submissions):
+            if submission.termination_reason is not None:
+                return submission.termination_reason
+        return None
+
+    @staticmethod
+    def _get_status_message(
+        status: RunStatus,
+        retry_on_events: List[RetryEvent],
+        termination_reason: Optional[JobTerminationReason],
+    ) -> str:
+        # Currently, `retrying` is shown only for `no-capacity` events
+        if (
+            status in [RunStatus.SUBMITTED, RunStatus.PENDING]
+            and termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
+            and RetryEvent.NO_CAPACITY in retry_on_events
+        ):
+            return "retrying"
+        return status.value
+
 
 class JobPlan(CoreModel):
     job_spec: JobSpec

@@ -502,40 +613,3 @@ def get_policy_map(spot_policy: Optional[SpotPolicy], default: SpotPolicy) -> Op
         SpotPolicy.ONDEMAND: False,
     }
     return policy_map[spot_policy]
-
-
-def _get_run_error(
-    run_termination_reason: Optional[RunTerminationReason],
-    run_jobs: List[Job],
-) -> str:
-    if run_termination_reason is None:
-        return ""
-    if len(run_jobs) > 1:
-        return run_termination_reason.name
-    run_job_termination_reason, exit_status = _get_run_job_termination_reason_and_exit_status(
-        run_jobs
-    )
-    # For failed runs, also show termination reason to provide more context.
-    # For other run statuses, the job termination reason will duplicate run status.
-    if run_job_termination_reason is not None and run_termination_reason in [
-        RunTerminationReason.JOB_FAILED,
-        RunTerminationReason.SERVER_ERROR,
-        RunTerminationReason.RETRY_LIMIT_EXCEEDED,
-    ]:
-        if exit_status:
-            return (
-                f"{run_termination_reason.name}\n({run_job_termination_reason.name} {exit_status})"
-            )
-        return f"{run_termination_reason.name}\n({run_job_termination_reason.name})"
-    return run_termination_reason.name
-
-
-def _get_run_job_termination_reason_and_exit_status(
-    run_jobs: List[Job],
-) -> tuple[Optional[JobTerminationReason], Optional[int]]:
-    for job in run_jobs:
-        if len(job.job_submissions) > 0:
-            job_submission = job.job_submissions[-1]
-            if job_submission.termination_reason is not None:
-                return job_submission.termination_reason, job_submission.exit_status
-    return None, None

dstack/_internal/server/app.py

@@ -128,7 +128,7 @@ async def lifespan(app: FastAPI):
             yes=UPDATE_DEFAULT_PROJECT,
             no=DO_NOT_UPDATE_DEFAULT_PROJECT,
         )
-    if settings.SERVER_BUCKET is not None:
+    if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None:
         init_default_storage()
     scheduler = start_background_tasks()
     dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
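
The renamed settings live in dstack/_internal/server/settings.py (+7 -3, not shown in full here); a plausible sketch, assuming the usual DSTACK_-prefixed environment variables (names hypothetical):

    import os

    SERVER_S3_BUCKET = os.getenv("DSTACK_SERVER_S3_BUCKET")
    SERVER_S3_BUCKET_REGION = os.getenv("DSTACK_SERVER_S3_BUCKET_REGION")
    SERVER_GCS_BUCKET = os.getenv("DSTACK_SERVER_GCS_BUCKET")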

dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -18,6 +18,7 @@ from dstack._internal.core.models.instances import (
     SSHConnectionParams,
 )
 from dstack._internal.core.models.metrics import Metric
+from dstack._internal.core.models.profiles import StartupOrder
 from dstack._internal.core.models.repos import RemoteRepoCreds
 from dstack._internal.core.models.runs import (
     ClusterInfo,

@@ -184,18 +185,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     if job_provisioning_data.hostname is None:
         await _wait_for_instance_provisioning_data(job_model=job_model)
     else:
-        # Wait until all other jobs in the replica have IPs assigned.
-        # This is needed to ensure cluster_info has all IPs set.
-        for other_job in run.jobs:
-            if (
-                other_job.job_spec.replica_num == job.job_spec.replica_num
-                and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
-                and other_job.job_submissions[-1].job_provisioning_data is not None
-                and other_job.job_submissions[-1].job_provisioning_data.hostname is None
-            ):
-                job_model.last_processed_at = common_utils.get_current_datetime()
-                await session.commit()
-                return
+        if _should_wait_for_other_nodes(run, job, job_model):
+            job_model.last_processed_at = common_utils.get_current_datetime()
+            await session.commit()
+            return
 
     # fails are acceptable until timeout is exceeded
     if job_provisioning_data.dockerized:

@@ -406,6 +399,48 @@ async def _wait_for_instance_provisioning_data(job_model: JobModel):
     job_model.job_provisioning_data = job_model.instance.job_provisioning_data
 
 
+def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool:
+    for other_job in run.jobs:
+        if (
+            other_job.job_spec.replica_num == job.job_spec.replica_num
+            and other_job.job_submissions[-1].status == JobStatus.PROVISIONING
+            and other_job.job_submissions[-1].job_provisioning_data is not None
+            and other_job.job_submissions[-1].job_provisioning_data.hostname is None
+        ):
+            logger.debug(
+                "%s: waiting for other job to have IP assigned",
+                fmt(job_model),
+            )
+            return True
+    master_job = find_job(run.jobs, job.job_spec.replica_num, 0)
+    if (
+        job.job_spec.job_num != 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.MASTER_FIRST
+        and master_job.job_submissions[-1].status != JobStatus.RUNNING
+    ):
+        logger.debug(
+            "%s: waiting for master job to become running",
+            fmt(job_model),
+        )
+        return True
+    if (
+        job.job_spec.job_num == 0
+        and run.run_spec.merged_profile.startup_order == StartupOrder.WORKERS_FIRST
+    ):
+        for other_job in run.jobs:
+            if (
+                other_job.job_spec.replica_num == job.job_spec.replica_num
+                and other_job.job_spec.job_num != job.job_spec.job_num
+                and other_job.job_submissions[-1].status != JobStatus.RUNNING
+            ):
+                logger.debug(
+                    "%s: waiting for worker job to become running",
+                    fmt(job_model),
+                )
+                return True
+    return False
+
+
 @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
 def _process_provisioning_with_shim(
     ports: Dict[int, int],

dstack/_internal/server/background/tasks/process_runs.py

@@ -10,7 +10,7 @@ from sqlalchemy.orm import joinedload, selectinload
 import dstack._internal.server.services.gateways as gateways
 import dstack._internal.server.services.services.autoscalers as autoscalers
 from dstack._internal.core.errors import ServerError
-from dstack._internal.core.models.profiles import RetryEvent
+from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
 from dstack._internal.core.models.runs import (
     Job,
     JobStatus,

@@ -313,6 +313,10 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             termination_reason = RunTerminationReason.RETRY_LIMIT_EXCEEDED
         else:
             raise ValueError(f"Unexpected termination reason {run_termination_reasons}")
+    elif _should_stop_on_master_done(run):
+        new_status = RunStatus.TERMINATING
+        # ALL_JOBS_DONE is used for all DONE reasons including master-done
+        termination_reason = RunTerminationReason.ALL_JOBS_DONE
     elif RunStatus.RUNNING in run_statuses:
         new_status = RunStatus.RUNNING
     elif RunStatus.PROVISIONING in run_statuses:

@@ -434,3 +438,12 @@ def _can_retry_single_job(run_spec: RunSpec) -> bool:
     # We could make partial retry in some multi-node cases.
     # E.g. restarting a worker node, independent jobs.
     return False
+
+
+def _should_stop_on_master_done(run: Run) -> bool:
+    if run.run_spec.merged_profile.stop_criteria != StopCriteria.MASTER_DONE:
+        return False
+    for job in run.jobs:
+        if job.job_spec.job_num == 0 and job.job_submissions[-1].status == JobStatus.DONE:
+            return True
+    return False

dstack/_internal/server/services/runner/client.py

@@ -32,6 +32,7 @@ from dstack._internal.utils.common import get_or_error
 from dstack._internal.utils.logging import get_logger
 
 REQUEST_TIMEOUT = 9
+UPLOAD_CODE_REQUEST_TIMEOUT = 60
 
 logger = get_logger(__name__)
 

@@ -109,7 +110,9 @@ class RunnerClient:
         resp.raise_for_status()
 
     def upload_code(self, file: Union[BinaryIO, bytes]):
-        resp = requests.post(self._url("/api/upload_code"), data=file, timeout=REQUEST_TIMEOUT)
+        resp = requests.post(
+            self._url("/api/upload_code"), data=file, timeout=UPLOAD_CODE_REQUEST_TIMEOUT
+        )
         resp.raise_for_status()
 
     def run_job(self):

dstack/_internal/server/services/storage/__init__.py (new file)

@@ -0,0 +1,38 @@
+from typing import Optional
+
+from dstack._internal.server import settings
+from dstack._internal.server.services.storage.base import BaseStorage
+from dstack._internal.server.services.storage.gcs import GCS_AVAILABLE, GCSStorage
+from dstack._internal.server.services.storage.s3 import BOTO_AVAILABLE, S3Storage
+
+_default_storage = None
+
+
+def init_default_storage():
+    global _default_storage
+    if settings.SERVER_S3_BUCKET is None and settings.SERVER_GCS_BUCKET is None:
+        raise ValueError(
+            "Either settings.SERVER_S3_BUCKET or settings.SERVER_GCS_BUCKET must be set"
+        )
+    if settings.SERVER_S3_BUCKET and settings.SERVER_GCS_BUCKET:
+        raise ValueError(
+            "Only one of settings.SERVER_S3_BUCKET or settings.SERVER_GCS_BUCKET can be set"
+        )
+
+    if settings.SERVER_S3_BUCKET:
+        if not BOTO_AVAILABLE:
+            raise ValueError("AWS dependencies are not installed")
+        _default_storage = S3Storage(
+            bucket=settings.SERVER_S3_BUCKET,
+            region=settings.SERVER_S3_BUCKET_REGION,
+        )
+    elif settings.SERVER_GCS_BUCKET:
+        if not GCS_AVAILABLE:
+            raise ValueError("GCS dependencies are not installed")
+        _default_storage = GCSStorage(
+            bucket=settings.SERVER_GCS_BUCKET,
+        )
+
+
+def get_default_storage() -> Optional[BaseStorage]:
+    return _default_storage
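
Putting the new module together with the lifespan change above, a minimal sketch of the startup wiring (IDs and payload hypothetical):

    from dstack._internal.server.services.storage import (
        get_default_storage,
        init_default_storage,
    )

    init_default_storage()  # raises ValueError unless exactly one bucket setting is set
    storage = get_default_storage()
    if storage is not None:
        storage.upload_code("project-id", "repo-id", "abc123", b"<tar bytes>")
        code = storage.get_code("project-id", "repo-id", "abc123")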

dstack/_internal/server/services/storage/base.py (new file)

@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+
+class BaseStorage(ABC):
+    @abstractmethod
+    def upload_code(
+        self,
+        project_id: str,
+        repo_id: str,
+        code_hash: str,
+        blob: bytes,
+    ):
+        pass
+
+    @abstractmethod
+    def get_code(
+        self,
+        project_id: str,
+        repo_id: str,
+        code_hash: str,
+    ) -> Optional[bytes]:
+        pass
+
+    @staticmethod
+    def _get_code_key(project_id: str, repo_id: str, code_hash: str) -> str:
+        return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}"

dstack/_internal/server/services/storage/gcs.py (new file)

@@ -0,0 +1,44 @@
+from typing import Optional
+
+from dstack._internal.server.services.storage.base import BaseStorage
+
+GCS_AVAILABLE = True
+try:
+    from google.cloud import storage
+    from google.cloud.exceptions import NotFound
+except ImportError:
+    GCS_AVAILABLE = False
+
+
+class GCSStorage(BaseStorage):
+    def __init__(
+        self,
+        bucket: str,
+    ):
+        self._client = storage.Client()
+        self._bucket = self._client.bucket(bucket)
+
+    def upload_code(
+        self,
+        project_id: str,
+        repo_id: str,
+        code_hash: str,
+        blob: bytes,
+    ):
+        blob_name = self._get_code_key(project_id, repo_id, code_hash)
+        blob_obj = self._bucket.blob(blob_name)
+        blob_obj.upload_from_string(blob)
+
+    def get_code(
+        self,
+        project_id: str,
+        repo_id: str,
+        code_hash: str,
+    ) -> Optional[bytes]:
+        try:
+            blob_name = self._get_code_key(project_id, repo_id, code_hash)
+            blob = self._bucket.blob(blob_name)
+        except NotFound:
+            return None
+
+        return blob.download_as_bytes()
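
One caveat worth noting: in the google-cloud-storage client, Bucket.blob() only builds a local handle and makes no API call, so the try/except NotFound above is unlikely to ever trigger; it is download_as_bytes() that raises NotFound for a missing object, and here it sits outside the try. A sketch of a variant that returns None in that case:

    def get_code(self, project_id: str, repo_id: str, code_hash: str) -> Optional[bytes]:
        blob = self._bucket.blob(self._get_code_key(project_id, repo_id, code_hash))
        try:
            # download_as_bytes() performs the GET and raises NotFound for missing objects
            return blob.download_as_bytes()
        except NotFound:
            return None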