dstack-0.18.40rc1-py3-none-any.whl → dstack-0.18.42-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +30 -9
  11. dstack/_internal/core/backends/aws/compute.py +94 -53
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +32 -24
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +23 -3
  28. dstack/_internal/core/models/volumes.py +26 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +73 -35
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +77 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +121 -49
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +14 -3
  42. dstack/_internal/server/background/tasks/process_volumes.py +11 -1
  43. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  44. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  45. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  46. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  47. dstack/_internal/server/migrations/versions/a751ef183f27_move_attachment_data_to_volumes_.py +34 -0
  48. dstack/_internal/server/models.py +27 -23
  49. dstack/_internal/server/routers/runs.py +1 -0
  50. dstack/_internal/server/schemas/runner.py +1 -0
  51. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  52. dstack/_internal/server/services/config.py +9 -0
  53. dstack/_internal/server/services/fleets.py +32 -3
  54. dstack/_internal/server/services/gateways/client.py +9 -1
  55. dstack/_internal/server/services/jobs/__init__.py +217 -45
  56. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  57. dstack/_internal/server/services/offers.py +96 -10
  58. dstack/_internal/server/services/pools.py +98 -14
  59. dstack/_internal/server/services/proxy/repo.py +17 -3
  60. dstack/_internal/server/services/runner/client.py +9 -6
  61. dstack/_internal/server/services/runner/ssh.py +33 -5
  62. dstack/_internal/server/services/runs.py +48 -179
  63. dstack/_internal/server/services/services/__init__.py +9 -1
  64. dstack/_internal/server/services/volumes.py +68 -9
  65. dstack/_internal/server/statics/index.html +1 -1
  66. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  67. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  68. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  69. dstack/_internal/server/testing/common.py +130 -61
  70. dstack/_internal/utils/common.py +22 -8
  71. dstack/_internal/utils/env.py +14 -0
  72. dstack/_internal/utils/ssh.py +1 -1
  73. dstack/api/server/_fleets.py +25 -1
  74. dstack/api/server/_runs.py +23 -2
  75. dstack/api/server/_volumes.py +12 -1
  76. dstack/version.py +1 -1
  77. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/METADATA +1 -1
  78. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/RECORD +104 -93
  79. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  80. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  81. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  82. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  83. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  84. tests/_internal/server/background/tasks/test_process_running_jobs.py +193 -0
  85. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  86. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +53 -6
  87. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +135 -17
  88. tests/_internal/server/routers/test_fleets.py +15 -2
  89. tests/_internal/server/routers/test_pools.py +6 -0
  90. tests/_internal/server/routers/test_runs.py +27 -0
  91. tests/_internal/server/routers/test_volumes.py +9 -2
  92. tests/_internal/server/services/jobs/__init__.py +0 -0
  93. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  94. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  95. tests/_internal/server/services/runner/test_client.py +22 -3
  96. tests/_internal/server/services/test_offers.py +167 -0
  97. tests/_internal/server/services/test_pools.py +109 -1
  98. tests/_internal/server/services/test_runs.py +5 -41
  99. tests/_internal/utils/test_common.py +21 -0
  100. tests/_internal/utils/test_env.py +38 -0
  101. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/LICENSE.md +0 -0
  102. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/WHEEL +0 -0
  103. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/entry_points.txt +0 -0
  104. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/top_level.txt +0 -0
tests/_internal/server/background/tasks/test_process_running_jobs.py

@@ -1,5 +1,6 @@
 from datetime import datetime, timezone
 from pathlib import Path
+from typing import Optional
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
@@ -8,11 +9,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from dstack._internal.core.errors import SSHError
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.common import NetworkMode
+from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
 from dstack._internal.core.models.instances import InstanceStatus
 from dstack._internal.core.models.runs import (
     JobRuntimeData,
     JobStatus,
     JobTerminationReason,
+    RunStatus,
 )
 from dstack._internal.core.models.volumes import (
     InstanceMountPoint,
@@ -100,6 +103,13 @@ class TestProcessRunningJobs:
             repo=repo,
             user=user,
         )
+        pool = await create_pool(session=session, project=project)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+        )
         job_provisioning_data = get_job_provisioning_data(dockerized=False)
         job = await create_job(
             session=session,
@@ -107,6 +117,8 @@ class TestProcessRunningJobs:
             status=JobStatus.PROVISIONING,
             submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc),
             job_provisioning_data=job_provisioning_data,
+            instance=instance,
+            instance_assigned=True,
         )
         with (
             patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
@@ -141,12 +153,21 @@ class TestProcessRunningJobs:
             repo=repo,
             user=user,
         )
+        pool = await create_pool(session=session, project=project)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+        )
         job_provisioning_data = get_job_provisioning_data(dockerized=False)
         job = await create_job(
             session=session,
             run=run,
             status=JobStatus.PROVISIONING,
             job_provisioning_data=job_provisioning_data,
+            instance=instance,
+            instance_assigned=True,
         )
         with (
             patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
@@ -183,12 +204,21 @@ class TestProcessRunningJobs:
             repo=repo,
             user=user,
         )
+        pool = await create_pool(session=session, project=project)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+        )
         job_provisioning_data = get_job_provisioning_data(dockerized=False)
         job = await create_job(
             session=session,
             run=run,
             status=JobStatus.RUNNING,
             job_provisioning_data=job_provisioning_data,
+            instance=instance,
+            instance_assigned=True,
         )
         with (
             patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
@@ -274,6 +304,13 @@ class TestProcessRunningJobs:
             run_name="test-run",
             run_spec=run_spec,
         )
+        pool = await create_pool(session=session, project=project)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+        )
         job_provisioning_data = get_job_provisioning_data(dockerized=True)
 
         with patch(
@@ -285,6 +322,8 @@ class TestProcessRunningJobs:
             run=run,
             status=JobStatus.PROVISIONING,
             job_provisioning_data=job_provisioning_data,
+            instance=instance,
+            instance_assigned=True,
         )
 
         await process_running_jobs()
@@ -310,6 +349,7 @@ class TestProcessRunningJobs:
             host_ssh_user="ubuntu",
             host_ssh_keys=["user_ssh_key"],
             container_ssh_keys=[project_ssh_pub_key, "user_ssh_key"],
+            instance_id=job_provisioning_data.instance_id,
         )
         await session.refresh(job)
         assert job is not None
@@ -334,12 +374,21 @@ class TestProcessRunningJobs:
             repo=repo,
             user=user,
         )
+        pool = await create_pool(session=session, project=project)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+        )
         job = await create_job(
             session=session,
             run=run,
             status=JobStatus.PULLING,
             job_provisioning_data=get_job_provisioning_data(dockerized=True),
             job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None),
+            instance=instance,
+            instance_assigned=True,
         )
         shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING
         shim_client_mock.get_task.return_value.ports = [
@@ -382,6 +431,13 @@ class TestProcessRunningJobs:
             repo=repo,
             user=user,
         )
+        pool = await create_pool(session=session, project=project)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+        )
         job_provisioning_data = get_job_provisioning_data(dockerized=True)
         job = await create_job(
             session=session,
@@ -389,6 +445,8 @@ class TestProcessRunningJobs:
             status=JobStatus.PULLING,
             job_provisioning_data=job_provisioning_data,
             job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None),
+            instance=instance,
+            instance_assigned=True,
         )
         shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING
         shim_client_mock.get_task.return_value.ports = None
@@ -467,12 +525,21 @@ class TestProcessRunningJobs:
             run_name="test-run",
             run_spec=run_spec,
         )
+        pool = await create_pool(session=session, project=project)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+        )
         job = await create_job(
             session=session,
             run=run,
             status=JobStatus.PROVISIONING,
             job_provisioning_data=get_job_provisioning_data(dockerized=True),
             submitted_at=get_current_datetime(),
+            instance=instance,
+            instance_assigned=True,
         )
         monkeypatch.setattr(
             "dstack._internal.server.services.runner.ssh.SSHTunnel", Mock(return_value=MagicMock())
@@ -495,3 +562,129 @@ class TestProcessRunningJobs:
         shim_client_mock.stop.assert_called_once_with(force=True)
         await session.refresh(job)
         assert job.status == JobStatus.PROVISIONING
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    @pytest.mark.parametrize(
+        (
+            "inactivity_duration",
+            "no_connections_secs",
+            "expected_status",
+            "expected_termination_reason",
+            "expected_inactivity_secs",
+        ),
+        [
+            pytest.param(
+                "1h",
+                60 * 60 - 1,
+                JobStatus.RUNNING,
+                None,
+                60 * 60 - 1,
+                id="duration-not-exceeded",
+            ),
+            pytest.param(
+                "1h",
+                60 * 60,
+                JobStatus.TERMINATING,
+                JobTerminationReason.TERMINATED_BY_SERVER,
+                60 * 60,
+                id="duration-exceeded-exactly",
+            ),
+            pytest.param(
+                "1h",
+                60 * 60 + 1,
+                JobStatus.TERMINATING,
+                JobTerminationReason.TERMINATED_BY_SERVER,
+                60 * 60 + 1,
+                id="duration-exceeded",
+            ),
+            pytest.param("off", 60 * 60, JobStatus.RUNNING, None, None, id="duration-off"),
+            pytest.param(False, 60 * 60, JobStatus.RUNNING, None, None, id="duration-false"),
+            pytest.param(None, 60 * 60, JobStatus.RUNNING, None, None, id="duration-none"),
+            pytest.param(
+                "1h",
+                None,
+                JobStatus.TERMINATING,
+                JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY,
+                None,
+                id="legacy-runner",
+            ),
+            pytest.param(
+                None,
+                None,
+                JobStatus.RUNNING,
+                None,
+                None,
+                id="legacy-runner-without-duration",
+            ),
+        ],
+    )
+    async def test_inactivity_duration(
+        self,
+        test_db,
+        session: AsyncSession,
+        inactivity_duration,
+        no_connections_secs: Optional[int],
+        expected_status: JobStatus,
+        expected_termination_reason: Optional[JobTerminationReason],
+        expected_inactivity_secs: Optional[int],
+    ) -> None:
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        repo = await create_repo(
+            session=session,
+            project_id=project.id,
+        )
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+            status=RunStatus.RUNNING,
+            run_name="test-run",
+            run_spec=get_run_spec(
+                run_name="test-run",
+                repo_id=repo.name,
+                configuration=DevEnvironmentConfiguration(
+                    name="test-run",
+                    ide="vscode",
+                    inactivity_duration=inactivity_duration,
+                ),
+            ),
+        )
+        pool = await create_pool(session=session, project=project)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            status=JobStatus.RUNNING,
+            job_provisioning_data=get_job_provisioning_data(),
+            instance=instance,
+            instance_assigned=True,
+        )
+        with (
+            patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
+            patch(
+                "dstack._internal.server.services.runner.client.RunnerClient"
+            ) as RunnerClientMock,
+        ):
+            runner_client_mock = RunnerClientMock.return_value
+            runner_client_mock.pull.return_value = PullResponse(
+                job_states=[],
+                job_logs=[],
+                runner_logs=[],
+                last_updated=0,
+                no_connections_secs=no_connections_secs,
+            )
+            await process_running_jobs()
+            SSHTunnelMock.assert_called_once()
+            runner_client_mock.pull.assert_called_once()
+        await session.refresh(job)
+        assert job.status == expected_status
+        assert job.termination_reason == expected_termination_reason
+        assert job.inactivity_secs == expected_inactivity_secs

tests/_internal/server/background/tasks/test_process_runs.py

@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
 
 import dstack._internal.server.background.tasks.process_runs as process_runs
 from dstack._internal.core.models.configurations import ServiceConfiguration
+from dstack._internal.core.models.instances import InstanceStatus
 from dstack._internal.core.models.profiles import Profile
 from dstack._internal.core.models.resources import Range
 from dstack._internal.core.models.runs import (
@@ -116,11 +117,20 @@ class TestProcessRuns:
     async def test_terminate_run_jobs(self, test_db, session: AsyncSession):
         run = await make_run(session, status=RunStatus.TERMINATING)
         run.termination_reason = RunTerminationReason.JOB_FAILED
+        pool = await create_pool(session=session, project=run.project)
+        instance = await create_instance(
+            session=session,
+            project=run.project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+        )
         job = await create_job(
             session=session,
             run=run,
             job_provisioning_data=get_job_provisioning_data(),
             status=JobStatus.RUNNING,
+            instance=instance,
+            instance_assigned=True,
         )
 
         with patch("dstack._internal.server.services.jobs._stop_runner") as stop_runner:
@@ -288,13 +298,27 @@ class TestProcessRunsReplicas:
 
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
-    async def test_some_failed_to_terminating(self, test_db, session: AsyncSession):
+    @pytest.mark.parametrize(
+        ("job_status", "job_termination_reason"),
+        [
+            (JobStatus.FAILED, JobTerminationReason.CONTAINER_EXITED_WITH_ERROR),
+            (JobStatus.TERMINATING, JobTerminationReason.TERMINATED_BY_SERVER),
+            (JobStatus.TERMINATED, JobTerminationReason.TERMINATED_BY_SERVER),
+        ],
+    )
+    async def test_some_failed_to_terminating(
+        self,
+        test_db,
+        session: AsyncSession,
+        job_status: JobStatus,
+        job_termination_reason: JobTerminationReason,
+    ) -> None:
         run = await make_run(session, status=RunStatus.RUNNING, replicas=2)
         await create_job(
             session=session,
             run=run,
-            status=JobStatus.FAILED,
-            termination_reason=JobTerminationReason.CONTAINER_EXITED_WITH_ERROR,
+            status=job_status,
+            termination_reason=job_termination_reason,
             replica_num=0,
         )
         await create_job(session=session, run=run, status=JobStatus.RUNNING, replica_num=1)

tests/_internal/server/background/tasks/test_process_submitted_jobs.py

@@ -27,7 +27,7 @@ from dstack._internal.core.models.volumes import (
     VolumeStatus,
 )
 from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs
-from dstack._internal.server.models import InstanceModel, JobModel
+from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel
 from dstack._internal.server.testing.common import (
     create_fleet,
     create_instance,
@@ -38,6 +38,7 @@ from dstack._internal.server.testing.common import (
     create_run,
     create_user,
     create_volume,
+    get_instance_offer_with_availability,
     get_run_spec,
     get_volume_provisioning_data,
 )
@@ -447,7 +448,7 @@
         await process_submitted_jobs()
         await session.refresh(job)
         res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
-        job = res.scalar_one()
+        job = res.unique().scalar_one()
         assert job.status == JobStatus.SUBMITTED
         assert (
             job.instance_assigned and job.instance is not None and job.instance.id == instance.id
@@ -514,15 +515,61 @@
         await session.refresh(instance)
         res = await session.execute(
             select(JobModel).options(
-                joinedload(JobModel.instance).selectinload(InstanceModel.volumes)
+                joinedload(JobModel.instance)
+                .joinedload(InstanceModel.volume_attachments)
+                .joinedload(VolumeAttachmentModel.volume)
             )
         )
-        job = res.scalar_one()
+        job = res.unique().scalar_one()
         assert job.status == JobStatus.PROVISIONING
         assert (
             job.instance_assigned and job.instance is not None and job.instance.id == instance.id
         )
-        assert job.instance.volumes == [volume]
+        assert job.instance.volume_attachments[0].volume == volume
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_assigns_job_to_shared_instance(self, test_db, session: AsyncSession):
+        project = await create_project(session)
+        user = await create_user(session)
+        pool = await create_pool(session=session, project=project)
+        repo = await create_repo(
+            session=session,
+            project_id=project.id,
+        )
+        offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.IDLE,
+            offer=offer,
+            total_blocks=4,
+            busy_blocks=1,
+        )
+        await session.refresh(pool)
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            instance_assigned=False,
+        )
+        await process_submitted_jobs()
+        await session.refresh(job)
+        await session.refresh(instance)
+        res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
+        job = res.unique().scalar_one()
+        assert job.status == JobStatus.SUBMITTED
+        assert (
+            job.instance_assigned and job.instance is not None and job.instance.id == instance.id
+        )
+        assert instance.total_blocks == 4
+        assert instance.busy_blocks == 2
 
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@@ -590,7 +637,7 @@
 
         await session.refresh(job)
         res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
-        job = res.scalar_one()
+        job = res.unique().scalar_one()
         assert job.status == JobStatus.PROVISIONING
         assert job.instance is not None
         assert job.instance.instance_num == 1

tests/_internal/server/background/tasks/test_process_terminating_jobs.py

@@ -13,7 +13,7 @@ from dstack._internal.core.models.volumes import VolumeStatus
 from dstack._internal.server.background.tasks.process_terminating_jobs import (
     process_terminating_jobs,
 )
-from dstack._internal.server.models import InstanceModel, JobModel
+from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel
 from dstack._internal.server.services.volumes import volume_model_to_volume
 from dstack._internal.server.testing.common import (
     create_instance,
@@ -24,17 +24,19 @@ from dstack._internal.server.testing.common import (
     create_run,
     create_user,
     create_volume,
+    get_instance_offer_with_availability,
     get_job_provisioning_data,
+    get_job_runtime_data,
+    get_volume_configuration,
     get_volume_provisioning_data,
 )
 
-pytestmark = pytest.mark.usefixtures("image_config_mock")
-
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+@pytest.mark.usefixtures("test_db", "image_config_mock")
 class TestProcessTerminatingJobs:
-    @pytest.mark.asyncio
-    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
-    async def test_terminates_job(self, test_db, session: AsyncSession):
+    async def test_terminates_job(self, session: AsyncSession):
         project = await create_project(session=session)
         user = await create_user(session=session)
         pool = await create_pool(session=session, project=project)
@@ -73,9 +75,7 @@
         assert job is not None
         assert job.status == JobStatus.TERMINATED
 
-    @pytest.mark.asyncio
-    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
-    async def test_detaches_job_volumes(self, test_db, session: AsyncSession):
+    async def test_detaches_job_volumes(self, session: AsyncSession):
         project = await create_project(session=session)
         user = await create_user(session=session)
         pool = await create_pool(session=session, project=project)
@@ -122,9 +122,7 @@
         await session.refresh(job)
         assert job.status == JobStatus.TERMINATED
 
-    @pytest.mark.asyncio
-    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
-    async def test_force_detaches_job_volumes(self, test_db, session: AsyncSession):
+    async def test_force_detaches_job_volumes(self, session: AsyncSession):
         project = await create_project(session=session)
         user = await create_user(session=session)
         pool = await create_pool(session=session, project=project)
@@ -172,13 +170,13 @@
             backend_mock.compute.return_value.is_volume_detached.assert_called_once()
         await session.refresh(job)
         res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
-        job = res.scalar_one()
+        job = res.unique().scalar_one()
         assert job.status == JobStatus.TERMINATING
         # The instance should be released even if detach fails
         # so that stuck volumes don't prevent the instance from terminating.
         assert job.instance is None
         assert job.volumes_detached_at is not None
-        assert len(instance.volumes) == 1
+        assert len(instance.volume_attachments) == 1
 
         # Force detach called
         with (
@@ -214,10 +212,130 @@
         await session.refresh(job)
         await session.refresh(instance)
         res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
-        job = res.scalar_one()
+        job = res.unique().scalar_one()
         res = await session.execute(
-            select(InstanceModel).options(joinedload(InstanceModel.volumes))
+            select(InstanceModel).options(joinedload(InstanceModel.volume_attachments))
         )
         instance = res.unique().scalar_one()
         assert job.status == JobStatus.TERMINATED
-        assert len(instance.volumes) == 0
+        assert len(instance.volume_attachments) == 0
+
+    async def test_terminates_job_on_shared_instance(self, session: AsyncSession):
+        project = await create_project(session)
+        user = await create_user(session)
+        pool = await create_pool(session=session, project=project)
+        repo = await create_repo(
+            session=session,
+            project_id=project.id,
+        )
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+            total_blocks=4,
+            busy_blocks=3,
+        )
+        await session.refresh(pool)
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+        )
+        shared_offer = get_instance_offer_with_availability(blocks=2, total_blocks=4)
+        jrd = get_job_runtime_data(offer=shared_offer)
+        job = await create_job(
+            session=session,
+            run=run,
+            instance_assigned=True,
+            instance=instance,
+            job_runtime_data=jrd,
+            status=JobStatus.TERMINATING,
+            termination_reason=JobTerminationReason.TERMINATED_BY_USER,
+        )
+
+        await process_terminating_jobs()
+
+        await session.refresh(job)
+        await session.refresh(instance)
+        res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
+        job = res.unique().scalar_one()
+        assert job.status == JobStatus.TERMINATED
+        assert job.instance_assigned
+        assert job.instance is None
+        assert instance.total_blocks == 4
+        assert instance.busy_blocks == 1
+
+    async def test_detaches_job_volumes_on_shared_instance(self, session: AsyncSession):
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        pool = await create_pool(session=session, project=project)
+        volume_conf_1 = get_volume_configuration(name="vol-1")
+        volume_1 = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            backend=BackendType.AWS,
+            configuration=volume_conf_1,
+            volume_provisioning_data=get_volume_provisioning_data(),
+        )
+        volume_conf_2 = get_volume_configuration(name="vol-2")
+        volume_2 = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            backend=BackendType.AWS,
+            configuration=volume_conf_2,
+            volume_provisioning_data=get_volume_provisioning_data(),
+        )
+        instance = await create_instance(
+            session=session,
+            project=project,
+            pool=pool,
+            status=InstanceStatus.BUSY,
+            volumes=[volume_1, volume_2],
+        )
+        repo = await create_repo(session=session, project_id=project.id)
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+        )
+        job_provisioning_data = get_job_provisioning_data(dockerized=False)
+        job = await create_job(
+            session=session,
+            run=run,
+            status=JobStatus.TERMINATING,
+            termination_reason=JobTerminationReason.TERMINATED_BY_USER,
+            submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc),
+            job_provisioning_data=job_provisioning_data,
+            job_runtime_data=get_job_runtime_data(volume_names=["vol-1"]),
+            instance=instance,
+        )
+        with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m:
+            backend_mock = Mock()
+            m.return_value = backend_mock
+            backend_mock.compute.return_value.is_volume_detached.return_value = True
+
+            await process_terminating_jobs()
+
+            m.assert_awaited_once()
+            backend_mock.compute.return_value.detach_volume.assert_called_once()
+            backend_mock.compute.return_value.is_volume_detached.assert_called_once()
+        await session.refresh(job)
+        await session.refresh(instance)
+        assert job.status == JobStatus.TERMINATED
+        res = await session.execute(
+            select(InstanceModel).options(
+                joinedload(InstanceModel.volume_attachments).joinedload(
+                    VolumeAttachmentModel.volume
+                )
+            )
+        )
+        instance = res.unique().scalar_one()
+        assert len(instance.volume_attachments) == 1
+        assert instance.volume_attachments[0].volume == volume_2