dstack 0.18.40__py3-none-any.whl → 0.18.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +21 -9
  11. dstack/_internal/core/backends/aws/compute.py +92 -52
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +30 -23
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +20 -0
  28. dstack/_internal/core/models/volumes.py +3 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +72 -33
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +73 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +109 -42
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +1 -1
  42. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  43. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  44. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  45. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  46. dstack/_internal/server/models.py +10 -4
  47. dstack/_internal/server/routers/runs.py +1 -0
  48. dstack/_internal/server/schemas/runner.py +1 -0
  49. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  50. dstack/_internal/server/services/config.py +9 -0
  51. dstack/_internal/server/services/fleets.py +27 -2
  52. dstack/_internal/server/services/gateways/client.py +9 -1
  53. dstack/_internal/server/services/jobs/__init__.py +215 -43
  54. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  55. dstack/_internal/server/services/offers.py +91 -5
  56. dstack/_internal/server/services/pools.py +95 -11
  57. dstack/_internal/server/services/proxy/repo.py +17 -3
  58. dstack/_internal/server/services/runner/client.py +1 -1
  59. dstack/_internal/server/services/runner/ssh.py +33 -5
  60. dstack/_internal/server/services/runs.py +48 -179
  61. dstack/_internal/server/services/services/__init__.py +9 -1
  62. dstack/_internal/server/statics/index.html +1 -1
  63. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  64. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  65. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  66. dstack/_internal/server/testing/common.py +117 -52
  67. dstack/_internal/utils/common.py +22 -8
  68. dstack/_internal/utils/env.py +14 -0
  69. dstack/_internal/utils/ssh.py +1 -1
  70. dstack/api/server/_fleets.py +25 -1
  71. dstack/api/server/_runs.py +23 -2
  72. dstack/api/server/_volumes.py +12 -1
  73. dstack/version.py +1 -1
  74. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/METADATA +1 -1
  75. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/RECORD +98 -89
  76. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  77. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  78. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  79. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  80. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  81. tests/_internal/server/background/tasks/test_process_running_jobs.py +192 -0
  82. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  83. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +48 -3
  84. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +126 -13
  85. tests/_internal/server/routers/test_fleets.py +15 -2
  86. tests/_internal/server/routers/test_pools.py +6 -0
  87. tests/_internal/server/routers/test_runs.py +27 -0
  88. tests/_internal/server/services/jobs/__init__.py +0 -0
  89. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  90. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  91. tests/_internal/server/services/test_pools.py +4 -0
  92. tests/_internal/server/services/test_runs.py +5 -41
  93. tests/_internal/utils/test_common.py +21 -0
  94. tests/_internal/utils/test_env.py +38 -0
  95. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/LICENSE.md +0 -0
  96. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/WHEEL +0 -0
  97. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/entry_points.txt +0 -0
  98. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/top_level.txt +0 -0

dstack/_internal/server/background/tasks/process_running_jobs.py
@@ -10,7 +10,12 @@ from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HT
 from dstack._internal.core.errors import GatewayError
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.common import NetworkMode, RegistryAuth, is_core_model_instance
-from dstack._internal.core.models.instances import InstanceStatus, RemoteConnectionInfo
+from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
+from dstack._internal.core.models.instances import (
+    InstanceStatus,
+    RemoteConnectionInfo,
+    SSHConnectionParams,
+)
 from dstack._internal.core.models.repos import RemoteRepoCreds
 from dstack._internal.core.models.runs import (
     ClusterInfo,
@@ -20,10 +25,12 @@ from dstack._internal.core.models.runs import (
     JobStatus,
     JobTerminationReason,
     Run,
+    RunSpec,
 )
 from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint
 from dstack._internal.server.db import get_session_ctx
 from dstack._internal.server.models import (
+    InstanceModel,
     JobModel,
     ProjectModel,
     RepoModel,
@@ -34,11 +41,13 @@ from dstack._internal.server.services import logs as logs_services
 from dstack._internal.server.services import services
 from dstack._internal.server.services.jobs import (
     find_job,
+    get_job_attached_volumes,
     get_job_runtime_data,
     job_model_to_job_submission,
 )
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.logging import fmt
+from dstack._internal.server.services.pools import get_instance_ssh_private_keys
 from dstack._internal.server.services.repos import (
     get_code_model,
     get_repo_creds,
@@ -47,7 +56,6 @@ from dstack._internal.server.services.repos import (
 from dstack._internal.server.services.runner import client
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
 from dstack._internal.server.services.runs import (
-    get_job_volumes,
     run_model_to_run,
 )
 from dstack._internal.server.services.storage import get_default_storage
@@ -81,7 +89,7 @@ async def _process_next_running_job():
                 .limit(1)
                 .with_for_update(skip_locked=True)
             )
-            job_model = res.scalar()
+            job_model = res.unique().scalar()
             if job_model is None:
                 return
             lockset.add(job_model.id)
@@ -99,10 +107,10 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     res = await session.execute(
         select(JobModel)
         .where(JobModel.id == job_model.id)
-        .options(joinedload(JobModel.instance))
+        .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
         .execution_options(populate_existing=True)
     )
-    job_model = res.scalar_one()
+    job_model = res.unique().scalar_one()
     res = await session.execute(
         select(RunModel)
         .where(RunModel.id == job_model.run_id)
@@ -142,25 +150,17 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
         job_provisioning_data=job_provisioning_data,
     )
 
-    volumes = await get_job_volumes(
+    volumes = await get_job_attached_volumes(
         session=session,
         project=project,
         run_spec=run.run_spec,
+        job_num=job.job_spec.job_num,
         job_provisioning_data=job_provisioning_data,
     )
 
-    server_ssh_private_key = project.ssh_private_key
-    # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem
-    # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716
-    if (
-        job_model.instance is not None
-        and job_model.instance.remote_connection_info is not None
-        and job_provisioning_data.dockerized
-    ):
-        remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
-            job_model.instance.remote_connection_info
-        )
-        server_ssh_private_key = remote_conn_info.ssh_keys[0].private
+    server_ssh_private_keys = get_instance_ssh_private_keys(
+        common_utils.get_or_error(job_model.instance)
+    )
 
     secrets = {}  # TODO secrets
@@ -200,7 +200,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
                 user_ssh_key = ""
             success = await common_utils.run_async(
                 _process_provisioning_with_shim,
-                server_ssh_private_key,
+                server_ssh_private_keys,
                 job_provisioning_data,
                 None,
                 run,
@@ -226,7 +226,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
             )
             success = await common_utils.run_async(
                 _submit_job_to_runner,
-                server_ssh_private_key,
+                server_ssh_private_keys,
                 job_provisioning_data,
                 None,
                 run,
@@ -269,7 +269,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
         )
         success = await common_utils.run_async(
             _process_pulling_with_shim,
-            server_ssh_private_key,
+            server_ssh_private_keys,
             job_provisioning_data,
             None,
             run,
@@ -279,14 +279,14 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
             code,
             secrets,
             repo_creds,
-            server_ssh_private_key,
+            server_ssh_private_keys,
             job_provisioning_data,
         )
     elif initial_status == JobStatus.RUNNING:
         logger.debug("%s: process running job, age=%s", fmt(job_model), job_submission.age)
         success = await common_utils.run_async(
             _process_running,
-            server_ssh_private_key,
+            server_ssh_private_keys,
             job_provisioning_data,
             job_submission.job_runtime_data,
             run_model,
@@ -312,8 +312,24 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
         and job_model.job_num == 0  # gateway connects only to the first node
         and run.run_spec.configuration.type == "service"
     ):
+        ssh_head_proxy: Optional[SSHConnectionParams] = None
+        ssh_head_proxy_private_key: Optional[str] = None
+        instance = common_utils.get_or_error(job_model.instance)
+        if instance.remote_connection_info is not None:
+            rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
+            if rci.ssh_proxy is not None:
+                ssh_head_proxy = rci.ssh_proxy
+                ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
+                ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
         try:
-            await services.register_replica(session, run_model.gateway_id, run, job_model)
+            await services.register_replica(
+                session,
+                run_model.gateway_id,
+                run,
+                job_model,
+                ssh_head_proxy,
+                ssh_head_proxy_private_key,
+            )
         except GatewayError as e:
             logger.warning(
                 "%s: failed to register service replica: %s, age=%s",
@@ -490,7 +506,7 @@ def _process_pulling_with_shim(
     code: bytes,
     secrets: Dict[str, str],
     repo_credentials: Optional[RemoteRepoCreds],
-    server_ssh_private_key: str,
+    server_ssh_private_keys: tuple[str, Optional[str]],
     job_provisioning_data: JobProvisioningData,
 ) -> bool:
     """
@@ -555,7 +571,7 @@ def _process_pulling_with_shim(
         return True
 
     return _submit_job_to_runner(
-        server_ssh_private_key,
+        server_ssh_private_keys,
         job_provisioning_data,
         job_runtime_data,
         run=run,
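
Throughout this file the single `server_ssh_private_key: str` argument becomes `server_ssh_private_keys: tuple[str, Optional[str]]`, produced by the new `get_instance_ssh_private_keys` helper from `dstack._internal.server.services.pools`. Judging by the fallback logic removed in the `@@ -142,25 +150,17 @@` hunk, the pair carries a required key plus an optional second key to try; here is a minimal sketch of fanning the pair out into candidates (the helper's exact semantics are an assumption, only the tuple shape is taken from the diff):

    from typing import Optional

    def candidate_keys(keys: tuple[str, Optional[str]]) -> list[str]:
        # The first slot is always present; the second is optional
        # and is simply skipped when it is None.
        primary, secondary = keys
        return [primary] + ([secondary] if secondary is not None else [])

    assert candidate_keys(("KEY-A", None)) == ["KEY-A"]
    assert candidate_keys(("KEY-A", "KEY-B")) == ["KEY-A", "KEY-B"]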
@@ -597,6 +613,7 @@ def _process_running(
         runner_logs=resp.runner_logs,
         job_logs=resp.job_logs,
     )
+    previous_status = job_model.status
     if len(resp.job_states) > 0:
         latest_state_event = resp.job_states[-1]
         latest_status = latest_state_event.state
@@ -612,10 +629,40 @@
             )
         if latest_state_event.termination_message:
             job_model.termination_reason_message = latest_state_event.termination_message
+    else:
+        _terminate_if_inactivity_duration_exceeded(run_model, job_model, resp.no_connections_secs)
+    if job_model.status != previous_status:
         logger.info("%s: now is %s", fmt(job_model), job_model.status.name)
     return True
 
 
+def _terminate_if_inactivity_duration_exceeded(
+    run_model: RunModel, job_model: JobModel, no_connections_secs: Optional[int]
+) -> None:
+    conf = RunSpec.__response__.parse_raw(run_model.run_spec).configuration
+    if is_core_model_instance(conf, DevEnvironmentConfiguration) and isinstance(
+        conf.inactivity_duration, int
+    ):
+        logger.debug("%s: no SSH connections for %s seconds", fmt(job_model), no_connections_secs)
+        job_model.inactivity_secs = no_connections_secs
+        if no_connections_secs is None:
+            # TODO(0.19 or earlier): make no_connections_secs required
+            job_model.status = JobStatus.TERMINATING
+            job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
+            job_model.termination_reason_message = (
+                "The selected instance was created before dstack 0.18.41"
+                " and does not support inactivity_duration"
+            )
+        elif no_connections_secs >= conf.inactivity_duration:
+            job_model.status = JobStatus.TERMINATING
+            # TODO(0.19 or earlier): set JobTerminationReason.INACTIVITY_DURATION_EXCEEDED
+            job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
+            job_model.termination_reason_message = (
+                f"The job was inactive for {no_connections_secs} seconds,"
+                f" exceeding the inactivity_duration of {conf.inactivity_duration} seconds"
+            )
+
+
 def _get_cluster_info(
     jobs: List[Job],
     replica_num: int,
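
The new `_terminate_if_inactivity_duration_exceeded` only applies to dev environments whose `inactivity_duration` resolved to an integer number of seconds, and it distinguishes a missing counter (instances provisioned before 0.18.41 never report `no_connections_secs`) from a genuinely exceeded limit. A standalone sketch of the decision table, with a stand-in threshold instead of dstack's real models:

    from typing import Optional

    # Stand-in for conf.inactivity_duration, in seconds.
    INACTIVITY_DURATION = 3600

    def should_terminate(no_connections_secs: Optional[int]) -> tuple[bool, str]:
        # Mirrors the three branches of _terminate_if_inactivity_duration_exceeded.
        if no_connections_secs is None:
            # Pre-0.18.41 instances never report the counter at all.
            return True, "instance does not support inactivity_duration"
        if no_connections_secs >= INACTIVITY_DURATION:
            return True, f"inactive for {no_connections_secs}s (limit {INACTIVITY_DURATION}s)"
        return False, "still within the inactivity limit"

    assert should_terminate(None)[0]
    assert should_terminate(7200)[0]
    assert not should_terminate(60)[0]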

dstack/_internal/server/background/tasks/process_runs.py
@@ -230,7 +230,8 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             # the job is submitted
             replica_statuses.add(RunStatus.SUBMITTED)
         elif job_model.status == JobStatus.FAILED or (
-            job_model.status == JobStatus.TERMINATING
+            job_model.status
+            in [JobStatus.TERMINATING, JobStatus.TERMINATED, JobStatus.ABORTED]
             and job_model.termination_reason
             not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN}
         ):
@@ -244,17 +245,6 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
                 run_termination_reasons.add(RunTerminationReason.RETRY_LIMIT_EXCEEDED)
             else:
                 replica_needs_retry = True
-        elif job_model.status in {
-            JobStatus.TERMINATING,
-            JobStatus.TERMINATED,
-            JobStatus.ABORTED,
-        }:
-            # FIXME: This code does not expect JobStatus.TERMINATED status,
-            # so if a job transitions from RUNNING to TERMINATED,
-            # the run will transition to PENDING instead of TERMINATING.
-            # This may not be observed because process_runs is invoked more frequently
-            # than process_terminating_jobs and because most jobs usually transition to FAILED.
-            pass  # unexpected, but let's ignore it
         else:
            raise ValueError(f"Unexpected job status {job_model.status}")
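
This folds the previously separate TERMINATING/TERMINATED/ABORTED branch, along with its FIXME about terminated jobs sending the run back to PENDING, into the failure check: a job that jumps straight from RUNNING to TERMINATED is now treated like FAILED unless it finished via DONE_BY_RUNNER or SCALED_DOWN. A simplified sketch of the resulting predicate, using trimmed-down enums:

    from enum import Enum, auto
    from typing import Optional

    class JobStatus(Enum):
        FAILED = auto()
        TERMINATING = auto()
        TERMINATED = auto()
        ABORTED = auto()

    class JobTerminationReason(Enum):
        DONE_BY_RUNNER = auto()
        SCALED_DOWN = auto()
        TERMINATED_BY_SERVER = auto()

    def counts_as_failure(status: JobStatus, reason: Optional[JobTerminationReason]) -> bool:
        # Previously only TERMINATING was checked here, so RUNNING -> TERMINATED
        # could slip through to the "unexpected" branch.
        return status == JobStatus.FAILED or (
            status in (JobStatus.TERMINATING, JobStatus.TERMINATED, JobStatus.ABORTED)
            and reason
            not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN}
        )

    assert counts_as_failure(JobStatus.TERMINATED, JobTerminationReason.TERMINATED_BY_SERVER)
    assert not counts_as_failure(JobStatus.TERMINATED, JobTerminationReason.DONE_BY_RUNNER)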

dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -15,10 +15,7 @@ from dstack._internal.core.models.fleets import (
     FleetStatus,
     InstanceGroupPlacement,
 )
-from dstack._internal.core.models.instances import (
-    InstanceOfferWithAvailability,
-    InstanceStatus,
-)
+from dstack._internal.core.models.instances import InstanceOfferWithAvailability, InstanceStatus
 from dstack._internal.core.models.profiles import (
     DEFAULT_POOL_NAME,
     DEFAULT_RUN_TERMINATION_IDLE_TIME,
@@ -26,6 +23,7 @@ from dstack._internal.core.models.profiles import (
     Profile,
     TerminationPolicy,
 )
+from dstack._internal.core.models.resources import Memory
 from dstack._internal.core.models.runs import (
     Job,
     JobProvisioningData,
@@ -52,28 +50,31 @@ from dstack._internal.server.services.fleets import (
     fleet_model_to_fleet,
 )
 from dstack._internal.server.services.jobs import (
+    check_can_attach_job_volumes,
     find_job,
     get_instances_ids_with_detaching_volumes,
+    get_job_configured_volume_models,
+    get_job_configured_volumes,
+    get_job_runtime_data,
 )
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.logging import fmt
 from dstack._internal.server.services.offers import get_offers_by_requirements
 from dstack._internal.server.services.pools import (
     filter_pool_instances,
+    get_instance_offer,
     get_instance_provisioning_data,
+    get_shared_pool_instances_with_offers,
 )
 from dstack._internal.server.services.runs import (
-    check_can_attach_run_volumes,
     check_run_spec_requires_instance_mounts,
-    get_offer_volumes,
-    get_run_volume_models,
-    get_run_volumes,
     run_model_to_run,
 )
 from dstack._internal.server.services.volumes import (
     volume_model_to_volume,
 )
 from dstack._internal.utils import common as common_utils
+from dstack._internal.utils import env as env_utils
 from dstack._internal.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -152,17 +153,21 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
         await session.commit()
         return
     try:
-        volume_models = await get_run_volume_models(
+        volume_models = await get_job_configured_volume_models(
             session=session,
             project=project,
             run_spec=run_spec,
+            job_num=job.job_spec.job_num,
+            job_spec=job.job_spec,
         )
-        volumes = await get_run_volumes(
+        volumes = await get_job_configured_volumes(
             session=session,
             project=project,
             run_spec=run_spec,
+            job_num=job.job_spec.job_num,
+            job_spec=job.job_spec,
         )
-        check_can_attach_run_volumes(run_spec=run_spec, volumes=volumes)
+        check_can_attach_job_volumes(volumes)
     except ServerClientError as e:
         logger.warning("%s: failed to prepare run volumes: %s", fmt(job_model), repr(e))
         job_model.status = JobStatus.TERMINATING
@@ -186,12 +191,12 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
         .where(
             InstanceModel.pool_id == pool.id,
             InstanceModel.deleted == False,
-            InstanceModel.job_id.is_(None),
+            InstanceModel.total_blocks > InstanceModel.busy_blocks,
         )
-        .options(lazyload(InstanceModel.job))
+        .options(lazyload(InstanceModel.jobs))
         .with_for_update()
     )
-    pool_instances = list(res.scalars().all())
+    pool_instances = list(res.unique().scalars().all())
     instances_ids = sorted([i.id for i in pool_instances])
     if get_db().dialect_name == "sqlite":
         # Start new transaction to see commited changes after lock
@@ -202,14 +207,16 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
         detaching_instances_ids = await get_instances_ids_with_detaching_volumes(session)
         # Refetch after lock
         res = await session.execute(
-            select(InstanceModel).where(
+            select(InstanceModel)
+            .where(
                 InstanceModel.id.not_in(detaching_instances_ids),
                 InstanceModel.id.in_(instances_ids),
                 InstanceModel.deleted == False,
-                InstanceModel.job_id.is_(None),
+                InstanceModel.total_blocks > InstanceModel.busy_blocks,
             )
+            .execution_options(populate_existing=True)
         )
-        pool_instances = list(res.scalars().all())
+        pool_instances = list(res.unique().scalars().all())
         instance = await _assign_job_to_pool_instance(
             session=session,
             pool_instances=pool_instances,
@@ -221,8 +228,6 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
             volumes=volumes,
         )
         job_model.instance_assigned = True
-        if instance is not None:
-            job_model.job_runtime_data = _prepare_job_runtime_data(job, instance).json()
         job_model.last_processed_at = common_utils.get_current_datetime()
         await session.commit()
         return
@@ -234,7 +239,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
             .options(selectinload(InstanceModel.volumes))
             .execution_options(populate_existing=True)
         )
-        instance = res.scalar_one()
+        instance = res.unique().scalar_one()
         job_model.status = JobStatus.PROVISIONING
     else:
         # Assigned no instance, create a new one
@@ -290,7 +295,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
             offer=offer,
             instance_num=instance_num,
         )
-        job_model.job_runtime_data = _prepare_job_runtime_data(job, instance).json()
+        job_model.job_runtime_data = _prepare_job_runtime_data(offer).json()
         instance.fleet_id = fleet_model.id
         logger.info(
             "The job %s created the new instance %s",
@@ -351,21 +356,40 @@ async def _assign_job_to_pool_instance(
     master_job_provisioning_data: Optional[JobProvisioningData] = None,
     volumes: Optional[List[List[Volume]]] = None,
 ) -> Optional[InstanceModel]:
+    instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]]
     profile = run_spec.merged_profile
-    relevant_instances = filter_pool_instances(
+    multinode = job.job_spec.jobs_per_replica > 1
+    nonshared_instances = filter_pool_instances(
         pool_instances=pool_instances,
         profile=profile,
         requirements=job.job_spec.requirements,
         status=InstanceStatus.IDLE,
         fleet_model=fleet_model,
-        multinode=job.job_spec.jobs_per_replica > 1,
+        multinode=multinode,
         master_job_provisioning_data=master_job_provisioning_data,
         volumes=volumes,
+        shared=False,
     )
-    if len(relevant_instances) == 0:
+    instances_with_offers = [
+        (instance, common_utils.get_or_error(get_instance_offer(instance)))
+        for instance in nonshared_instances
+    ]
+    if not multinode:
+        shared_instances_with_offers = get_shared_pool_instances_with_offers(
+            pool_instances=pool_instances,
+            profile=profile,
+            requirements=job.job_spec.requirements,
+            idle_only=True,
+            fleet_model=fleet_model,
+            volumes=volumes,
+        )
+        instances_with_offers.extend(shared_instances_with_offers)
+
+    if len(instances_with_offers) == 0:
         return None
-    sorted_instances = sorted(relevant_instances, key=lambda instance: instance.price)
-    instance = sorted_instances[0]
+
+    instances_with_offers.sort(key=lambda instance_with_offer: instance_with_offer[0].price or 0)
+    instance, offer = instances_with_offers[0]
     # Reload InstanceModel with volumes
     res = await session.execute(
         select(InstanceModel)
@@ -374,7 +398,8 @@ async def _assign_job_to_pool_instance(
     )
     instance = res.unique().scalar_one()
     instance.status = InstanceStatus.BUSY
-    instance.job = job_model
+    instance.busy_blocks += offer.blocks
+
     logger.info(
         "The job %s switched instance %s status to BUSY",
         job_model.job_name,
@@ -385,8 +410,10 @@ async def _assign_job_to_pool_instance(
         },
     )
     logger.info("%s: now is provisioning on '%s'", fmt(job_model), instance.name)
-    job_model.job_provisioning_data = instance.job_provisioning_data
+    job_model.instance = instance
     job_model.used_instance_id = instance.id
+    job_model.job_provisioning_data = instance.job_provisioning_data
+    job_model.job_runtime_data = _prepare_job_runtime_data(offer).json()
     return instance
 
 
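Assignment no longer flips an instance between free and busy: an instance advertises `total_blocks` of capacity, each assigned offer consumes `offer.blocks`, and the pool queries above keep an instance schedulable while `total_blocks > busy_blocks`. A simplified sketch of the accounting (the field names mirror `InstanceModel`; everything else is illustrative):

    from dataclasses import dataclass

    @dataclass
    class Instance:
        name: str
        total_blocks: int  # capacity of the whole host, in blocks
        busy_blocks: int = 0

        @property
        def schedulable(self) -> bool:
            # Mirrors the SQL filter: InstanceModel.total_blocks > InstanceModel.busy_blocks
            return self.total_blocks > self.busy_blocks

        def assign(self, offer_blocks: int) -> None:
            # Mirrors: instance.busy_blocks += offer.blocks
            assert self.busy_blocks + offer_blocks <= self.total_blocks
            self.busy_blocks += offer_blocks

    host = Instance("gpu-host", total_blocks=4)
    host.assign(1)           # one job takes a quarter of the host
    assert host.schedulable  # 4 > 1, so more jobs can still be packed on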
@@ -431,7 +458,7 @@ async def _run_job_on_new_instance(
         offer.region,
         offer.price,
     )
-    offer_volumes = get_offer_volumes(volumes, offer)
+    offer_volumes = _get_offer_volumes(volumes, offer)
     try:
         job_provisioning_data = await common_utils.run_async(
             backend.compute().run_job,
@@ -549,29 +576,64 @@ def _create_instance_model_for_job(
         offer=offer.json(),
         termination_policy=termination_policy,
         termination_idle_time=termination_idle_time,
-        job=job_model,
+        jobs=[job_model],
         backend=offer.backend,
         price=offer.price,
         region=offer.region,
         volumes=[],
+        total_blocks=1,
+        busy_blocks=1,
     )
     return instance
 
 
-def _prepare_job_runtime_data(job: Job, instance: InstanceModel) -> JobRuntimeData:
-    if job.job_spec.jobs_per_replica > 1:
-        # multi-node runs require host network mode for inter-node communication and occupy
-        # the entire instance
-        return JobRuntimeData(network_mode=NetworkMode.HOST)
-
-    # TODO: replace with a real computed value depending on the instance
-    is_shared_instance = True
+def _prepare_job_runtime_data(offer: InstanceOfferWithAvailability) -> JobRuntimeData:
+    if offer.total_blocks == 1:
+        if env_utils.get_bool("DSTACK_FORCE_BRIDGE_NETWORK"):
+            network_mode = NetworkMode.BRIDGE
+        else:
+            network_mode = NetworkMode.HOST
+        return JobRuntimeData(
+            network_mode=network_mode,
+            offer=offer,
+        )
+    return JobRuntimeData(
+        network_mode=NetworkMode.BRIDGE,
+        offer=offer,
+        cpu=offer.instance.resources.cpus,
+        gpu=len(offer.instance.resources.gpus),
+        memory=Memory(offer.instance.resources.memory_mib / 1024),
+    )
 
-    if not is_shared_instance:
-        return JobRuntimeData(network_mode=NetworkMode.HOST)
 
-    # TODO: slice CPU/GPU/Memory resources depending on the instance
-    return JobRuntimeData(network_mode=NetworkMode.BRIDGE)
+def _get_offer_volumes(
+    volumes: List[List[Volume]],
+    offer: InstanceOfferWithAvailability,
+) -> List[Volume]:
+    """
+    Returns volumes suitable for the offer for each mount point.
+    """
+    offer_volumes = []
+    for mount_point_volumes in volumes:
+        offer_volumes.append(_get_offer_mount_point_volume(mount_point_volumes, offer))
+    return offer_volumes
+
+
+def _get_offer_mount_point_volume(
+    volumes: List[Volume],
+    offer: InstanceOfferWithAvailability,
+) -> Volume:
+    """
+    Returns the first suitable volume for the offer among possible mount point volumes.
+    """
+    for volume in volumes:
+        if (
+            volume.configuration.backend != offer.backend
+            or volume.configuration.region != offer.region
+        ):
+            continue
+        return volume
+    raise ServerClientError("Failed to find an eligible volume for the mount point")
 
 
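`_prepare_job_runtime_data` now derives everything from the offer: whole-host offers (`total_blocks == 1`) keep host networking unless the new `DSTACK_FORCE_BRIDGE_NETWORK` environment variable is set, while block-sliced offers always run in bridge mode with explicit cpu/gpu/memory limits. The `env_utils.get_bool` helper is introduced by this release in dstack/_internal/utils/env.py (+14 lines); below is a hypothetical re-implementation of such a helper to show the intended contract (the accepted spellings are an assumption, not the actual parsing rules):

    import os

    def get_bool(name: str, default: bool = False) -> bool:
        # Hypothetical sketch; the real helper lives in dstack/_internal/utils/env.py
        # and its accepted spellings may differ.
        value = os.environ.get(name)
        if value is None:
            return default
        if value.lower() in ("1", "true", "yes", "on"):
            return True
        if value.lower() in ("0", "false", "no", "off"):
            return False
        raise ValueError(f"{name} is not a boolean: {value!r}")

With that flag, bridge networking can be forced for whole-host offers at server startup, e.g. `DSTACK_FORCE_BRIDGE_NETWORK=1 dstack server`.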
@@ -586,6 +648,8 @@ async def _attach_volumes(
         project=project,
         backend_type=job_provisioning_data.backend,
     )
+    job_runtime_data = common_utils.get_or_error(get_job_runtime_data(job_model))
+    job_runtime_data.volume_names = []
     logger.info("Attaching volumes: %s", [[v.name for v in vs] for vs in volume_models])
     for mount_point_volume_models in volume_models:
         for volume_model in mount_point_volume_models:
@@ -604,6 +668,7 @@ async def _attach_volumes(
                     instance=instance,
                     instance_id=job_provisioning_data.instance_id,
                 )
+                job_runtime_data.volume_names.append(volume.name)
                 break  # attach next mount point
             except (ServerClientError, BackendError) as e:
                 logger.warning("%s: failed to attached volume: %s", fmt(job_model), repr(e))
@@ -620,6 +685,8 @@ async def _attach_volumes(
         # TODO: Replace with JobTerminationReason.VOLUME_ERROR in 0.19
         job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER
         job_model.termination_reason_message = "Failed to attach volume"
+    finally:
+        job_model.job_runtime_data = job_runtime_data.json()
 
 
 async def _attach_volume(

dstack/_internal/server/background/tasks/process_terminating_jobs.py
@@ -52,7 +52,7 @@ async def _process_next_terminating_job():
                    InstanceModel.id == job_model.used_instance_id,
                    InstanceModel.id.not_in(instance_lockset),
                )
-                .options(lazyload(InstanceModel.job))
+                .options(lazyload(InstanceModel.jobs))
                .with_for_update(skip_locked=True)
            )
            instance_model = res.scalar()

dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py (new file)
@@ -0,0 +1,71 @@
+"""Reverse Job-Instance relationship
+
+Revision ID: 1338b788b612
+Revises: 51d45659d574
+Create Date: 2025-01-16 14:59:19.113534
+
+"""
+
+import sqlalchemy as sa
+import sqlalchemy_utils
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "1338b788b612"
+down_revision = "51d45659d574"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.add_column(
+            sa.Column(
+                "instance_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True
+            )
+        )
+        batch_op.create_foreign_key(
+            batch_op.f("fk_jobs_instance_id_instances"),
+            "instances",
+            ["instance_id"],
+            ["id"],
+            ondelete="CASCADE",
+        )
+
+    op.execute("""
+        UPDATE jobs AS j
+        SET instance_id = (
+            SELECT i.id
+            FROM instances AS i
+            WHERE i.job_id = j.id
+        )
+    """)
+
+    with op.batch_alter_table("instances", schema=None) as batch_op:
+        batch_op.drop_constraint("fk_instances_job_id_jobs", type_="foreignkey")
+        batch_op.drop_column("job_id")
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("instances", schema=None) as batch_op:
+        batch_op.add_column(
+            sa.Column("job_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True)
+        )
+        batch_op.create_foreign_key("fk_instances_job_id_jobs", "jobs", ["job_id"], ["id"])
+
+    # This migration is not fully reversible - we cannot assign multiple jobs to a single instance,
+    # thus LIMIT 1
+    op.execute("""
+        UPDATE instances AS i
+        SET job_id = (
+            SELECT j.id
+            FROM jobs j
+            WHERE j.instance_id = i.id
+            ORDER by j.submitted_at DESC
+            LIMIT 1
+        )
+    """)
+
+    with op.batch_alter_table("jobs", schema=None) as batch_op:
+        batch_op.drop_constraint(batch_op.f("fk_jobs_instance_id_instances"), type_="foreignkey")
+        batch_op.drop_column("instance_id")
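
The migration moves the foreign key from instances.job_id to jobs.instance_id, turning the one-job-per-instance relationship into one-instance-many-jobs; that is why `lazyload(InstanceModel.job)` becomes `lazyload(InstanceModel.jobs)` elsewhere in this diff, and why the downgrade has to pick a single surviving job with LIMIT 1. A rough SQLAlchemy 2.0-style sketch of the reshaped relationship (illustrative only, not the actual models.py definitions):

    import uuid
    from typing import Optional

    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class InstanceModel(Base):
        __tablename__ = "instances"
        id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
        # One instance now carries many jobs.
        jobs: Mapped[list["JobModel"]] = relationship(back_populates="instance")

    class JobModel(Base):
        __tablename__ = "jobs"
        id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
        # The FK moved from instances.job_id to jobs.instance_id.
        instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(
            ForeignKey("instances.id", ondelete="CASCADE")
        )
        instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="jobs")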