dstack 0.18.40rc1__py3-none-any.whl → 0.18.42__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (104)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +30 -9
  11. dstack/_internal/core/backends/aws/compute.py +94 -53
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +32 -24
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +23 -3
  28. dstack/_internal/core/models/volumes.py +26 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +73 -35
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +77 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +121 -49
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +14 -3
  42. dstack/_internal/server/background/tasks/process_volumes.py +11 -1
  43. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  44. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  45. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  46. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  47. dstack/_internal/server/migrations/versions/a751ef183f27_move_attachment_data_to_volumes_.py +34 -0
  48. dstack/_internal/server/models.py +27 -23
  49. dstack/_internal/server/routers/runs.py +1 -0
  50. dstack/_internal/server/schemas/runner.py +1 -0
  51. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  52. dstack/_internal/server/services/config.py +9 -0
  53. dstack/_internal/server/services/fleets.py +32 -3
  54. dstack/_internal/server/services/gateways/client.py +9 -1
  55. dstack/_internal/server/services/jobs/__init__.py +217 -45
  56. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  57. dstack/_internal/server/services/offers.py +96 -10
  58. dstack/_internal/server/services/pools.py +98 -14
  59. dstack/_internal/server/services/proxy/repo.py +17 -3
  60. dstack/_internal/server/services/runner/client.py +9 -6
  61. dstack/_internal/server/services/runner/ssh.py +33 -5
  62. dstack/_internal/server/services/runs.py +48 -179
  63. dstack/_internal/server/services/services/__init__.py +9 -1
  64. dstack/_internal/server/services/volumes.py +68 -9
  65. dstack/_internal/server/statics/index.html +1 -1
  66. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  67. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  68. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  69. dstack/_internal/server/testing/common.py +130 -61
  70. dstack/_internal/utils/common.py +22 -8
  71. dstack/_internal/utils/env.py +14 -0
  72. dstack/_internal/utils/ssh.py +1 -1
  73. dstack/api/server/_fleets.py +25 -1
  74. dstack/api/server/_runs.py +23 -2
  75. dstack/api/server/_volumes.py +12 -1
  76. dstack/version.py +1 -1
  77. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/METADATA +1 -1
  78. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/RECORD +104 -93
  79. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  80. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  81. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  82. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  83. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  84. tests/_internal/server/background/tasks/test_process_running_jobs.py +193 -0
  85. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  86. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +53 -6
  87. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +135 -17
  88. tests/_internal/server/routers/test_fleets.py +15 -2
  89. tests/_internal/server/routers/test_pools.py +6 -0
  90. tests/_internal/server/routers/test_runs.py +27 -0
  91. tests/_internal/server/routers/test_volumes.py +9 -2
  92. tests/_internal/server/services/jobs/__init__.py +0 -0
  93. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  94. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  95. tests/_internal/server/services/runner/test_client.py +22 -3
  96. tests/_internal/server/services/test_offers.py +167 -0
  97. tests/_internal/server/services/test_pools.py +109 -1
  98. tests/_internal/server/services/test_runs.py +5 -41
  99. tests/_internal/utils/test_common.py +21 -0
  100. tests/_internal/utils/test_env.py +38 -0
  101. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/LICENSE.md +0 -0
  102. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/WHEEL +0 -0
  103. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/entry_points.txt +0 -0
  104. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/top_level.txt +0 -0
dstack/_internal/server/services/jobs/__init__.py

@@ -7,14 +7,21 @@ from uuid import UUID
  import requests
  from sqlalchemy import select
  from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.orm import joinedload

  import dstack._internal.server.services.backends as backends_services
  from dstack._internal.core.backends.base import Backend
  from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
- from dstack._internal.core.errors import BackendError, ResourceNotExistsError, SSHError
+ from dstack._internal.core.errors import (
+     BackendError,
+     ResourceNotExistsError,
+     ServerClientError,
+     SSHError,
+ )
  from dstack._internal.core.models.backends.base import BackendType
+ from dstack._internal.core.models.common import is_core_model_instance
  from dstack._internal.core.models.configurations import RunConfigurationType
- from dstack._internal.core.models.instances import InstanceStatus, RemoteConnectionInfo
+ from dstack._internal.core.models.instances import InstanceStatus
  from dstack._internal.core.models.runs import (
      Job,
      JobProvisioningData,
@@ -25,6 +32,7 @@ from dstack._internal.core.models.runs import (
      JobTerminationReason,
      RunSpec,
  )
+ from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus
  from dstack._internal.server.models import (
      InstanceModel,
      JobModel,
@@ -33,14 +41,22 @@ from dstack._internal.server.models import (
      VolumeModel,
  )
  from dstack._internal.server.services import services
- from dstack._internal.server.services.jobs.configurators.base import JobConfigurator
+ from dstack._internal.server.services import volumes as volumes_services
+ from dstack._internal.server.services.jobs.configurators.base import (
+     JobConfigurator,
+     interpolate_job_volumes,
+ )
  from dstack._internal.server.services.jobs.configurators.dev import DevEnvironmentJobConfigurator
  from dstack._internal.server.services.jobs.configurators.service import ServiceJobConfigurator
  from dstack._internal.server.services.jobs.configurators.task import TaskJobConfigurator
  from dstack._internal.server.services.logging import fmt
+ from dstack._internal.server.services.pools import get_instance_ssh_private_keys
  from dstack._internal.server.services.runner import client
  from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
- from dstack._internal.server.services.volumes import volume_model_to_volume
+ from dstack._internal.server.services.volumes import (
+     list_project_volume_models,
+     volume_model_to_volume,
+ )
  from dstack._internal.utils import common
  from dstack._internal.utils.common import get_or_error, run_async
  from dstack._internal.utils.logging import get_logger
@@ -115,6 +131,7 @@ def job_model_to_job_submission(job_model: JobModel) -> JobSubmission:
          submitted_at=job_model.submitted_at.replace(tzinfo=timezone.utc),
          last_processed_at=last_processed_at,
          finished_at=finished_at,
+         inactivity_secs=job_model.inactivity_secs,
          status=job_model.status,
          termination_reason=job_model.termination_reason,
          termination_reason_message=job_model.termination_reason_message,
@@ -155,29 +172,22 @@ _configuration_type_to_configurator_class_map = {c.TYPE: c for c in _job_configu


  async def stop_runner(session: AsyncSession, job_model: JobModel):
-     project = await session.get(ProjectModel, job_model.project_id)
-     ssh_private_key = project.ssh_private_key
-
      res = await session.execute(
-         select(InstanceModel).where(
+         select(InstanceModel)
+         .where(
              InstanceModel.project_id == job_model.project_id,
-             InstanceModel.job_id == job_model.id,
+             InstanceModel.id == job_model.instance_id,
          )
+         .options(joinedload(InstanceModel.project))
      )
      instance: Optional[InstanceModel] = res.scalar()

-     # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem
-     # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716
-     if instance and instance.remote_connection_info is not None:
-         remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
-             instance.remote_connection_info
-         )
-         ssh_private_key = remote_conn_info.ssh_keys[0].private
+     ssh_private_keys = get_instance_ssh_private_keys(common.get_or_error(instance))
      try:
          jpd = get_job_provisioning_data(job_model)
          if jpd is not None:
              jrd = get_job_runtime_data(job_model)
-             await run_async(_stop_runner, ssh_private_key, jpd, jrd, job_model)
+             await run_async(_stop_runner, ssh_private_keys, jpd, jrd, job_model)
      except SSHError:
          logger.debug("%s: failed to stop runner", fmt(job_model))

@@ -219,30 +229,41 @@ async def process_terminating_job(
          _set_job_termination_status(job_model)
          return

+     all_volumes_detached: bool = True
+     jrd = get_job_runtime_data(job_model)
      jpd = get_job_provisioning_data(job_model)
      if jpd is not None:
          logger.debug("%s: stopping container", fmt(job_model))
-         ssh_private_key = instance_model.project.ssh_private_key
-         # TODO: Drop this logic and always use project key once it's safe to assume that
-         # most on-prem fleets are (re)created after this change:
-         # https://github.com/dstackai/dstack/pull/1716
-         if instance_model and instance_model.remote_connection_info is not None:
-             remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
-                 instance_model.remote_connection_info
+         ssh_private_keys = get_instance_ssh_private_keys(instance_model)
+         await stop_container(job_model, jpd, ssh_private_keys)
+         volume_models: list[VolumeModel]
+         if jrd is not None and jrd.volume_names is not None:
+             volume_models = await list_project_volume_models(
+                 session=session, project=instance_model.project, names=jrd.volume_names
              )
-             ssh_private_key = remote_conn_info.ssh_keys[0].private
-         await stop_container(job_model, jpd, ssh_private_key)
-         if len(instance_model.volumes) > 0:
-             logger.info("Detaching volumes: %s", [v.name for v in instance_model.volumes])
-             await _detach_volumes_from_job_instance(
+         else:
+             volume_models = [va.volume for va in instance_model.volume_attachments]
+         if len(volume_models) > 0:
+             logger.info("Detaching volumes: %s", [v.name for v in volume_models])
+             all_volumes_detached = await _detach_volumes_from_job_instance(
                  project=instance_model.project,
                  job_model=job_model,
                  jpd=jpd,
                  instance_model=instance_model,
+                 volume_models=volume_models,
              )

+     if jrd is not None and jrd.offer is not None:
+         blocks = jrd.offer.blocks
+     else:
+         # Old job submitted before jrd or blocks were introduced
+         blocks = 1
+     instance_model.busy_blocks -= blocks
+
      if instance_model.status == InstanceStatus.BUSY:
-         instance_model.status = InstanceStatus.IDLE
+         # no other jobs besides this one
+         if not [j for j in instance_model.jobs if j.id != job_model.id]:
+             instance_model.status = InstanceStatus.IDLE
      elif instance_model.status != InstanceStatus.TERMINATED:
          # instance was PROVISIONING (specially for the job)
          # schedule for termination
@@ -254,7 +275,7 @@

      # The instance should be released even if detach fails
      # so that stuck volumes don't prevent the instance from terminating.
-     instance_model.job_id = None
+     job_model.instance_id = None
      instance_model.last_job_processed_at = common.get_current_datetime()
      logger.info(
          "%s: instance '%s' has been released, new status is %s",
@@ -263,9 +284,8 @@
          instance_model.status.name,
      )
      await services.unregister_replica(session, job_model)
-     if len(instance_model.volumes) == 0:
+     if all_volumes_detached:
          # Do not terminate while some volumes are not detached.
-         # TODO: In case of multiple jobs per instance, don't consider volumes from other jobs.
          _set_job_termination_status(job_model)

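A note on the block accounting introduced above: when a job terminates, the server subtracts the job's block count from `instance_model.busy_blocks` (falling back to 1 for jobs submitted before blocks existed) and returns the instance to IDLE only if no other jobs remain on it. A minimal sketch of that release rule, using illustrative dataclasses rather than dstack's actual InstanceModel/JobModel:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class FakeJob:            # stand-in for a job row; "blocks" mirrors jrd.offer.blocks
        id: int
        blocks: int = 1       # old jobs without runtime data count as one block

    @dataclass
    class FakeInstance:       # stand-in for an instance row
        busy_blocks: int = 0
        status: str = "busy"
        jobs: List[FakeJob] = field(default_factory=list)

    def release_job(instance: FakeInstance, job: FakeJob) -> None:
        instance.busy_blocks -= job.blocks
        instance.jobs = [j for j in instance.jobs if j.id != job.id]
        # The instance goes idle only when no other jobs are still using it.
        if instance.status == "busy" and not instance.jobs:
            instance.status = "idle"

    inst = FakeInstance(busy_blocks=3, jobs=[FakeJob(1, blocks=2), FakeJob(2, blocks=1)])
    release_job(inst, inst.jobs[0])
    assert inst.busy_blocks == 1 and inst.status == "busy"  # job 2 still holds a block
    release_job(inst, inst.jobs[0])
    assert inst.busy_blocks == 0 and inst.status == "idle"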
@@ -280,18 +300,25 @@ async def process_volumes_detaching(
      If the volumes fail to detach, force detaches them.
      """
      jpd = get_or_error(get_job_provisioning_data(job_model))
-     logger.info("Detaching volumes: %s", [v.name for v in instance_model.volumes])
-     await _detach_volumes_from_job_instance(
+     jrd = get_job_runtime_data(job_model)
+     if jrd is not None and jrd.volume_names is not None:
+         volume_models = await list_project_volume_models(
+             session=session, project=instance_model.project, names=jrd.volume_names
+         )
+     else:
+         volume_models = [va.volume for va in instance_model.volume_attachments]
+     logger.info("Detaching volumes: %s", [v.name for v in volume_models])
+     all_volumes_detached = await _detach_volumes_from_job_instance(
          project=instance_model.project,
          job_model=job_model,
          jpd=jpd,
          instance_model=instance_model,
+         volume_models=volume_models,
      )
-     if len(instance_model.volumes) == 0:
+     if all_volumes_detached:
          # Do not terminate the job while some volumes are not detached.
          # If force detach never succeeds, the job will be stuck terminating.
          # The job releases the instance when soft detaching, so the instance won't be stuck.
-         # TODO: In case of multiple jobs per instance, don't consider volumes from other jobs.
          _set_job_termination_status(job_model)


@@ -311,14 +338,16 @@ def _set_job_termination_status(job_model: JobModel):


  async def stop_container(
-     job_model: JobModel, job_provisioning_data: JobProvisioningData, ssh_private_key: str
+     job_model: JobModel,
+     job_provisioning_data: JobProvisioningData,
+     ssh_private_keys: tuple[str, Optional[str]],
  ):
      if job_provisioning_data.dockerized:
          # send a request to the shim to terminate the docker container
          # SSHError and RequestException are caught in the `runner_ssh_tunner` decorator
          await run_async(
              _shim_submit_stop,
-             ssh_private_key,
+             ssh_private_keys,
              job_provisioning_data,
              None,
              job_model,
@@ -378,7 +407,8 @@ async def _detach_volumes_from_job_instance(
      job_model: JobModel,
      jpd: JobProvisioningData,
      instance_model: InstanceModel,
- ):
+     volume_models: list[VolumeModel],
+ ) -> bool:
      job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
      backend = await backends_services.get_project_backend_by_type(
          project=project,
@@ -388,11 +418,11 @@
          logger.error(
              "Failed to detach volumes from %s. Backend not available.", instance_model.name
          )
-         return
+         return False

-     # TODO: In case of multiple jobs per instance, detach only volumes used by this job
+     all_detached = True
      detached_volumes = []
-     for volume_model in instance_model.volumes:
+     for volume_model in volume_models:
          detached = await _detach_volume_from_job_instance(
              backend=backend,
              job_model=job_model,
@@ -403,13 +433,16 @@
          )
          if detached:
              detached_volumes.append(volume_model)
+         else:
+             all_detached = False

      if job_model.volumes_detached_at is None:
          job_model.volumes_detached_at = common.get_current_datetime()
      detached_volumes_ids = {v.id for v in detached_volumes}
-     instance_model.volumes = [
-         v for v in instance_model.volumes if v.id not in detached_volumes_ids
+     instance_model.volume_attachments = [
+         va for va in instance_model.volume_attachments if va.volume_id not in detached_volumes_ids
      ]
+     return all_detached


  async def _detach_volume_from_job_instance(
@@ -503,3 +536,142 @@ async def get_instances_ids_with_detaching_volumes(session: AsyncSession) -> Lis
      )
      job_models = res.scalars().all()
      return [jm.used_instance_id for jm in job_models if jm.used_instance_id]
+
+
+ async def get_job_configured_volumes(
+     session: AsyncSession,
+     project: ProjectModel,
+     run_spec: RunSpec,
+     job_num: int,
+     job_spec: Optional[JobSpec] = None,
+ ) -> List[List[Volume]]:
+     """
+     Returns a list of job volumes grouped by mount points.
+     """
+     volume_models = await get_job_configured_volume_models(
+         session=session,
+         project=project,
+         run_spec=run_spec,
+         job_num=job_num,
+         job_spec=job_spec,
+     )
+     return [
+         [volumes_services.volume_model_to_volume(v) for v in mount_point_volume_models]
+         for mount_point_volume_models in volume_models
+     ]
+
+
+ async def get_job_configured_volume_models(
+     session: AsyncSession,
+     project: ProjectModel,
+     run_spec: RunSpec,
+     job_num: int,
+     job_spec: Optional[JobSpec] = None,
+ ) -> List[List[VolumeModel]]:
+     """
+     Returns a list of job volume models grouped by mount points.
+     """
+     job_volumes = None
+     if job_spec is not None:
+         job_volumes = job_spec.volumes
+     if job_volumes is None:
+         # job_spec not provided or a legacy job_spec without volumes
+         job_volumes = interpolate_job_volumes(run_spec.configuration.volumes, job_num)
+     volume_models = []
+     for mount_point in job_volumes:
+         if not is_core_model_instance(mount_point, VolumeMountPoint):
+             continue
+         if isinstance(mount_point.name, str):
+             names = [mount_point.name]
+         else:
+             names = mount_point.name
+         mount_point_volume_models = []
+         for name in names:
+             volume_model = await volumes_services.get_project_volume_model_by_name(
+                 session=session,
+                 project=project,
+                 name=name,
+             )
+             if volume_model is None:
+                 raise ResourceNotExistsError(f"Volume {mount_point.name} not found")
+             mount_point_volume_models.append(volume_model)
+         volume_models.append(mount_point_volume_models)
+     return volume_models
+
+
+ def check_can_attach_job_volumes(volumes: List[List[Volume]]):
+     """
+     Performs basic checks if volumes can be attached.
+     This is useful to show error ASAP (when user submits the run).
+     If the attachment is to fail anyway, the error will be handled when proccessing submitted jobs.
+     """
+     if len(volumes) == 0:
+         return
+     expected_backends = {v.configuration.backend for v in volumes[0]}
+     expected_regions = {v.configuration.region for v in volumes[0]}
+     for mount_point_volumes in volumes:
+         backends = {v.configuration.backend for v in mount_point_volumes}
+         regions = {v.configuration.region for v in mount_point_volumes}
+         if backends != expected_backends:
+             raise ServerClientError(
+                 "Volumes from different backends specified for different mount points"
+             )
+         if regions != expected_regions:
+             raise ServerClientError(
+                 "Volumes from different regions specified for different mount points"
+             )
+         for volume in mount_point_volumes:
+             if volume.status != VolumeStatus.ACTIVE:
+                 raise ServerClientError(f"Cannot mount volumes that are not active: {volume.name}")
+     volumes_names = [v.name for vs in volumes for v in vs]
+     if len(volumes_names) != len(set(volumes_names)):
+         raise ServerClientError("Cannot attach the same volume at different mount points")
+
+
+ async def get_job_attached_volumes(
+     session: AsyncSession,
+     project: ProjectModel,
+     run_spec: RunSpec,
+     job_num: int,
+     job_provisioning_data: JobProvisioningData,
+ ) -> List[Volume]:
+     """
+     Returns volumes attached to the job.
+     """
+     job_configured_volumes = await get_job_configured_volumes(
+         session=session,
+         project=project,
+         run_spec=run_spec,
+         job_num=job_num,
+     )
+     job_volumes = []
+     for mount_point_volumes in job_configured_volumes:
+         job_volumes.append(
+             _get_job_mount_point_attached_volume(mount_point_volumes, job_provisioning_data)
+         )
+     return job_volumes
+
+
+ def _get_job_mount_point_attached_volume(
+     volumes: List[Volume],
+     job_provisioning_data: JobProvisioningData,
+ ) -> Volume:
+     """
+     Returns the volume attached to the job among the list of possible mount point volumes.
+     """
+     for volume in volumes:
+         if (
+             volume.configuration.backend != job_provisioning_data.get_base_backend()
+             or volume.configuration.region != job_provisioning_data.region
+         ):
+             continue
+         if (
+             volume.provisioning_data is not None
+             and volume.provisioning_data.availability_zone is not None
+             and job_provisioning_data.availability_zone is not None
+             and volume.provisioning_data.availability_zone
+             != job_provisioning_data.availability_zone
+         ):
+             continue
+         return volume
+     raise ServerClientError("Failed to find an eligible volume for the mount point")
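The selection rule in `_get_job_mount_point_attached_volume` boils down to matching the volume's backend and region against the job's provisioning data, and comparing availability zones only when both sides know theirs. A standalone sketch of that rule with simplified stand-in types (not dstack's actual Volume/JobProvisioningData models):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class FakeVolume:
        name: str
        backend: str
        region: str
        zone: Optional[str] = None  # known only after the volume is provisioned

    @dataclass
    class FakeProvisioning:
        backend: str
        region: str
        zone: Optional[str] = None

    def pick_volume(volumes: List[FakeVolume], jpd: FakeProvisioning) -> FakeVolume:
        for v in volumes:
            if v.backend != jpd.backend or v.region != jpd.region:
                continue
            # Zones are compared only when both the volume and the instance report one.
            if v.zone is not None and jpd.zone is not None and v.zone != jpd.zone:
                continue
            return v
        raise ValueError("no eligible volume for this mount point")

    vols = [FakeVolume("vol-eu", "aws", "eu-west-1", "eu-west-1a"),
            FakeVolume("vol-us", "aws", "us-east-1", "us-east-1b")]
    print(pick_volume(vols, FakeProvisioning("aws", "us-east-1", "us-east-1b")).name)  # vol-us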
dstack/_internal/server/services/jobs/configurators/base.py

@@ -1,13 +1,13 @@
  import shlex
  import sys
  from abc import ABC, abstractmethod
- from typing import Dict, List, Optional
+ from typing import Dict, List, Optional, Union

  from cachetools import TTLCache, cached

  import dstack.version as version
  from dstack._internal.core.errors import DockerRegistryError, ServerClientError
- from dstack._internal.core.models.common import RegistryAuth
+ from dstack._internal.core.models.common import RegistryAuth, is_core_model_instance
  from dstack._internal.core.models.configurations import (
      PortMapping,
      PythonVersion,
@@ -22,10 +22,12 @@ from dstack._internal.core.models.runs import (
      RunSpec,
  )
  from dstack._internal.core.models.unix import UnixUser
+ from dstack._internal.core.models.volumes import MountPoint, VolumeMountPoint
  from dstack._internal.core.services.profiles import get_retry
  from dstack._internal.core.services.ssh.ports import filter_reserved_ports
  from dstack._internal.server.services.docker import ImageConfig, get_image_config
  from dstack._internal.utils.common import run_async
+ from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator


  def get_default_python_verison() -> str:
@@ -115,6 +117,7 @@ class JobConfigurator(ABC):
              requirements=self._requirements(),
              retry=self._retry(),
              working_dir=self._working_dir(),
+             volumes=self._volumes(job_num),
          )
          return job_spec

@@ -224,6 +227,48 @@ class JobConfigurator(ABC):
              return self.run_spec.configuration.python.value
          return get_default_python_verison()

+     def _volumes(self, job_num: int) -> List[MountPoint]:
+         return interpolate_job_volumes(self.run_spec.configuration.volumes, job_num)
+
+
+ def interpolate_job_volumes(
+     run_volumes: List[Union[MountPoint, str]],
+     job_num: int,
+ ) -> List[MountPoint]:
+     if len(run_volumes) == 0:
+         return []
+     interpolator = VariablesInterpolator(
+         namespaces={
+             "dstack": {
+                 "job_num": str(job_num),
+                 "node_rank": str(job_num),  # an alias for job_num
+             }
+         }
+     )
+     job_volumes = []
+     for mount_point in run_volumes:
+         if isinstance(mount_point, str):
+             # pydantic validator ensures strings are converted to MountPoint
+             continue
+         if not is_core_model_instance(mount_point, VolumeMountPoint):
+             job_volumes.append(mount_point.copy())
+             continue
+         if isinstance(mount_point.name, str):
+             names = [mount_point.name]
+         else:
+             names = mount_point.name
+         try:
+             interpolated_names = [interpolator.interpolate_or_error(n) for n in names]
+         except InterpolatorError as e:
+             raise ServerClientError(e.args[0])
+         job_volumes.append(
+             VolumeMountPoint(
+                 name=interpolated_names,
+                 path=mount_point.path,
+             )
+         )
+     return job_volumes
+

  def _join_shell_commands(commands: List[str]) -> str:
      for i, cmd in enumerate(commands):
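The interpolation above is what lets a replicated run mount a different volume per job: `${{ dstack.job_num }}` (or its alias `${{ dstack.node_rank }}`) inside a volume name is resolved per job before the volumes are looked up. A rough standalone illustration that substitutes a simple regex for dstack's VariablesInterpolator; the placeholder syntax is assumed from the configuration reference:

    import re

    def interpolate_names(names, job_num):
        # Stand-in for VariablesInterpolator: resolves only dstack.job_num / dstack.node_rank.
        values = {"dstack.job_num": str(job_num), "dstack.node_rank": str(job_num)}
        def repl(match):
            key = match.group(1).strip()
            if key not in values:
                raise ValueError(f"unknown variable: {key}")
            return values[key]
        return [re.sub(r"\$\{\{([^}]+)\}\}", repl, n) for n in names]

    # A mount point named "data-${{ dstack.node_rank }}" resolves differently per job:
    print(interpolate_names(["data-${{ dstack.node_rank }}"], job_num=0))  # ['data-0']
    print(interpolate_names(["data-${{ dstack.node_rank }}"], job_num=3))  # ['data-3']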
dstack/_internal/server/services/offers.py

@@ -1,4 +1,6 @@
- from typing import List, Optional, Tuple
+ from typing import List, Literal, Optional, Tuple, Union
+
+ import gpuhunt

  from dstack._internal.core.backends import (
      BACKENDS_WITH_CREATE_INSTANCE_SUPPORT,
@@ -7,7 +9,11 @@ from dstack._internal.core.backends import (
  )
  from dstack._internal.core.backends.base import Backend
  from dstack._internal.core.models.backends.base import BackendType
- from dstack._internal.core.models.instances import InstanceOfferWithAvailability
+ from dstack._internal.core.models.instances import (
+     InstanceOfferWithAvailability,
+     InstanceType,
+     Resources,
+ )
  from dstack._internal.core.models.profiles import Profile
  from dstack._internal.core.models.runs import JobProvisioningData, Requirements
  from dstack._internal.core.models.volumes import Volume
@@ -25,6 +31,7 @@ async def get_offers_by_requirements(
      volumes: Optional[List[List[Volume]]] = None,
      privileged: bool = False,
      instance_mounts: bool = False,
+     blocks: Union[int, Literal["auto"]] = 1,
  ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
      backends: List[Backend] = await backends_services.get_project_backends(project=project)

@@ -38,33 +45,40 @@

      backend_types = profile.backends
      regions = profile.regions
+     availability_zones = profile.availability_zones

      if volumes:
          mount_point_volumes = volumes[0]
-         backend_types = [v.configuration.backend for v in mount_point_volumes]
-         regions = [v.configuration.region for v in mount_point_volumes]
+         volumes_backend_types = [v.configuration.backend for v in mount_point_volumes]
+         if backend_types is None:
+             backend_types = volumes_backend_types
+         backend_types = [b for b in backend_types if b in volumes_backend_types]
+         volumes_regions = [v.configuration.region for v in mount_point_volumes]
+         if regions is None:
+             regions = volumes_regions
+         regions = [r for r in regions if r in volumes_regions]

      if multinode:
-         if not backend_types:
+         if backend_types is None:
              backend_types = BACKENDS_WITH_MULTINODE_SUPPORT
          backend_types = [b for b in backend_types if b in BACKENDS_WITH_MULTINODE_SUPPORT]

      if privileged or instance_mounts:
-         if not backend_types:
+         if backend_types is None:
              backend_types = BACKENDS_WITH_CREATE_INSTANCE_SUPPORT
          backend_types = [b for b in backend_types if b in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT]

      if profile.reservation is not None:
-         if not backend_types:
+         if backend_types is None:
              backend_types = BACKENDS_WITH_RESERVATION_SUPPORT
          backend_types = [b for b in backend_types if b in BACKENDS_WITH_RESERVATION_SUPPORT]

      # For multi-node, restrict backend and region.
      # The default behavior is to provision all nodes in the same backend and region.
      if master_job_provisioning_data is not None:
-         if not backend_types:
+         if backend_types is None:
              backend_types = [master_job_provisioning_data.get_base_backend()]
-         if not regions:
+         if regions is None:
              regions = [master_job_provisioning_data.region]
          backend_types = [
              b for b in backend_types if b == master_job_provisioning_data.get_base_backend()
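The repeated `backend_types is None` checks above all follow one pattern: None means the profile left the field unset, so the constraint at hand (volumes, multinode, reservations, the master job's backend) supplies the candidates; an explicit list is instead narrowed to the values the constraint also allows. A condensed restatement with toy values (the names here are illustrative, not dstack's API):

    from typing import List, Optional

    def narrow(requested: Optional[List[str]], allowed: List[str]) -> List[str]:
        # None = not specified in the profile, so the constraint decides;
        # otherwise keep only the values that satisfy both sides.
        if requested is None:
            requested = allowed
        return [x for x in requested if x in allowed]

    print(narrow(None, ["aws", "gcp"]))              # ['aws', 'gcp']
    print(narrow(["gcp", "azure"], ["aws", "gcp"]))  # ['gcp']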
@@ -89,7 +103,79 @@ async def get_offers_by_requirements(
      if regions is not None:
          offers = [(b, o) for b, o in offers if o.region in regions]

+     if availability_zones is not None:
+         new_offers = []
+         for b, o in offers:
+             if o.availability_zones is not None:
+                 new_offer = o.copy()
+                 new_offer.availability_zones = [
+                     z for z in o.availability_zones if z in availability_zones
+                 ]
+                 if new_offer.availability_zones:
+                     new_offers.append((b, new_offer))
+         offers = new_offers
+
      if profile.instance_types is not None:
          offers = [(b, o) for b, o in offers if o.instance.name in profile.instance_types]

-     return offers
+     if blocks == 1:
+         return offers
+
+     shareable_offers = []
+     for backend, offer in offers:
+         resources = offer.instance.resources
+         cpu_count = resources.cpus
+         gpu_count = len(resources.gpus)
+         if gpu_count > 0 and resources.gpus[0].vendor == gpuhunt.AcceleratorVendor.GOOGLE:
+             # TPUs cannot be shared
+             gpu_count = 1
+         divisible, _blocks = is_divisible_into_blocks(cpu_count, gpu_count, blocks)
+         if not divisible:
+             continue
+         offer.total_blocks = _blocks
+         shareable_offers.append((backend, offer))
+     return shareable_offers
+
+
+ def is_divisible_into_blocks(
+     cpu_count: int, gpu_count: int, blocks: Union[int, Literal["auto"]]
+ ) -> tuple[bool, int]:
+     """
+     Returns `True` and number of blocks the instance can be split into or `False` and `0` if
+     is not divisible.
+     Requested number of blocks can be `auto`, which means as many as possible.
+     """
+     if blocks == "auto":
+         if gpu_count == 0:
+             blocks = cpu_count
+         else:
+             blocks = min(cpu_count, gpu_count)
+     if blocks < 1 or cpu_count % blocks or gpu_count % blocks:
+         return False, 0
+     return True, blocks
+
+
+ def generate_shared_offer(
+     offer: InstanceOfferWithAvailability, blocks: int, total_blocks: int
+ ) -> InstanceOfferWithAvailability:
+     full_resources = offer.instance.resources
+     resources = Resources(
+         cpus=full_resources.cpus // total_blocks * blocks,
+         memory_mib=full_resources.memory_mib // total_blocks * blocks,
+         gpus=full_resources.gpus[: len(full_resources.gpus) // total_blocks * blocks],
+         spot=full_resources.spot,
+         disk=full_resources.disk,
+         description=full_resources.description,
+     )
+     return InstanceOfferWithAvailability(
+         backend=offer.backend,
+         instance=InstanceType(
+             name=offer.instance.name,
+             resources=resources,
+         ),
+         region=offer.region,
+         price=offer.price,
+         availability=offer.availability,
+         blocks=blocks,
+         total_blocks=total_blocks,
+     )
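As a quick sanity check of the block math introduced above: an offer is shareable only if the requested number of blocks evenly divides both its CPU count and its GPU count, and `blocks: auto` asks for as many blocks as the resources allow. The snippet below mirrors the logic of `is_divisible_into_blocks` from the hunk so the examples can run standalone:

    from typing import Literal, Tuple, Union

    def is_divisible_into_blocks(
        cpu_count: int, gpu_count: int, blocks: Union[int, Literal["auto"]]
    ) -> Tuple[bool, int]:
        # "auto" means as many blocks as the CPU/GPU counts allow.
        if blocks == "auto":
            blocks = cpu_count if gpu_count == 0 else min(cpu_count, gpu_count)
        if blocks < 1 or cpu_count % blocks or gpu_count % blocks:
            return False, 0
        return True, blocks

    print(is_divisible_into_blocks(8, 4, "auto"))   # (True, 4): four 2-CPU/1-GPU blocks
    print(is_divisible_into_blocks(8, 4, 3))        # (False, 0): 3 divides neither 8 nor 4
    print(is_divisible_into_blocks(16, 0, "auto"))  # (True, 16): CPU-only splits per CPU

generate_shared_offer then scales each resource by blocks/total_blocks, so a single block of the 8-CPU/4-GPU offer above would advertise 2 CPUs, a quarter of the memory, and 1 GPU.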