dstack 0.18.40rc1__py3-none-any.whl → 0.18.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +21 -9
  11. dstack/_internal/core/backends/aws/compute.py +92 -52
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +30 -23
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +20 -0
  28. dstack/_internal/core/models/volumes.py +3 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +72 -33
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +73 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +109 -42
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +1 -1
  42. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  43. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  44. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  45. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  46. dstack/_internal/server/models.py +10 -4
  47. dstack/_internal/server/routers/runs.py +1 -0
  48. dstack/_internal/server/schemas/runner.py +1 -0
  49. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  50. dstack/_internal/server/services/config.py +9 -0
  51. dstack/_internal/server/services/fleets.py +27 -2
  52. dstack/_internal/server/services/gateways/client.py +9 -1
  53. dstack/_internal/server/services/jobs/__init__.py +215 -43
  54. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  55. dstack/_internal/server/services/offers.py +91 -5
  56. dstack/_internal/server/services/pools.py +95 -11
  57. dstack/_internal/server/services/proxy/repo.py +17 -3
  58. dstack/_internal/server/services/runner/client.py +1 -1
  59. dstack/_internal/server/services/runner/ssh.py +33 -5
  60. dstack/_internal/server/services/runs.py +48 -179
  61. dstack/_internal/server/services/services/__init__.py +9 -1
  62. dstack/_internal/server/statics/index.html +1 -1
  63. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  64. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  65. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  66. dstack/_internal/server/testing/common.py +117 -52
  67. dstack/_internal/utils/common.py +22 -8
  68. dstack/_internal/utils/env.py +14 -0
  69. dstack/_internal/utils/ssh.py +1 -1
  70. dstack/api/server/_fleets.py +25 -1
  71. dstack/api/server/_runs.py +23 -2
  72. dstack/api/server/_volumes.py +12 -1
  73. dstack/version.py +1 -1
  74. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/METADATA +1 -1
  75. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/RECORD +98 -89
  76. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  77. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  78. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  79. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  80. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  81. tests/_internal/server/background/tasks/test_process_running_jobs.py +192 -0
  82. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  83. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +48 -3
  84. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +126 -13
  85. tests/_internal/server/routers/test_fleets.py +15 -2
  86. tests/_internal/server/routers/test_pools.py +6 -0
  87. tests/_internal/server/routers/test_runs.py +27 -0
  88. tests/_internal/server/services/jobs/__init__.py +0 -0
  89. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  90. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  91. tests/_internal/server/services/test_pools.py +4 -0
  92. tests/_internal/server/services/test_runs.py +5 -41
  93. tests/_internal/utils/test_common.py +21 -0
  94. tests/_internal/utils/test_env.py +38 -0
  95. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/LICENSE.md +0 -0
  96. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/WHEEL +0 -0
  97. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/entry_points.txt +0 -0
  98. {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  from datetime import datetime, timezone
2
2
  from pathlib import Path
3
+ from typing import Optional
3
4
  from unittest.mock import MagicMock, Mock, patch
4
5
 
5
6
  import pytest
@@ -8,11 +9,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
8
9
  from dstack._internal.core.errors import SSHError
9
10
  from dstack._internal.core.models.backends.base import BackendType
10
11
  from dstack._internal.core.models.common import NetworkMode
12
+ from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
11
13
  from dstack._internal.core.models.instances import InstanceStatus
12
14
  from dstack._internal.core.models.runs import (
13
15
  JobRuntimeData,
14
16
  JobStatus,
15
17
  JobTerminationReason,
18
+ RunStatus,
16
19
  )
17
20
  from dstack._internal.core.models.volumes import (
18
21
  InstanceMountPoint,
@@ -100,6 +103,13 @@ class TestProcessRunningJobs:
100
103
  repo=repo,
101
104
  user=user,
102
105
  )
106
+ pool = await create_pool(session=session, project=project)
107
+ instance = await create_instance(
108
+ session=session,
109
+ project=project,
110
+ pool=pool,
111
+ status=InstanceStatus.BUSY,
112
+ )
103
113
  job_provisioning_data = get_job_provisioning_data(dockerized=False)
104
114
  job = await create_job(
105
115
  session=session,
@@ -107,6 +117,8 @@ class TestProcessRunningJobs:
107
117
  status=JobStatus.PROVISIONING,
108
118
  submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc),
109
119
  job_provisioning_data=job_provisioning_data,
120
+ instance=instance,
121
+ instance_assigned=True,
110
122
  )
111
123
  with (
112
124
  patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
@@ -141,12 +153,21 @@ class TestProcessRunningJobs:
141
153
  repo=repo,
142
154
  user=user,
143
155
  )
156
+ pool = await create_pool(session=session, project=project)
157
+ instance = await create_instance(
158
+ session=session,
159
+ project=project,
160
+ pool=pool,
161
+ status=InstanceStatus.BUSY,
162
+ )
144
163
  job_provisioning_data = get_job_provisioning_data(dockerized=False)
145
164
  job = await create_job(
146
165
  session=session,
147
166
  run=run,
148
167
  status=JobStatus.PROVISIONING,
149
168
  job_provisioning_data=job_provisioning_data,
169
+ instance=instance,
170
+ instance_assigned=True,
150
171
  )
151
172
  with (
152
173
  patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
@@ -183,12 +204,21 @@ class TestProcessRunningJobs:
183
204
  repo=repo,
184
205
  user=user,
185
206
  )
207
+ pool = await create_pool(session=session, project=project)
208
+ instance = await create_instance(
209
+ session=session,
210
+ project=project,
211
+ pool=pool,
212
+ status=InstanceStatus.BUSY,
213
+ )
186
214
  job_provisioning_data = get_job_provisioning_data(dockerized=False)
187
215
  job = await create_job(
188
216
  session=session,
189
217
  run=run,
190
218
  status=JobStatus.RUNNING,
191
219
  job_provisioning_data=job_provisioning_data,
220
+ instance=instance,
221
+ instance_assigned=True,
192
222
  )
193
223
  with (
194
224
  patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
@@ -274,6 +304,13 @@ class TestProcessRunningJobs:
274
304
  run_name="test-run",
275
305
  run_spec=run_spec,
276
306
  )
307
+ pool = await create_pool(session=session, project=project)
308
+ instance = await create_instance(
309
+ session=session,
310
+ project=project,
311
+ pool=pool,
312
+ status=InstanceStatus.BUSY,
313
+ )
277
314
  job_provisioning_data = get_job_provisioning_data(dockerized=True)
278
315
 
279
316
  with patch(
@@ -285,6 +322,8 @@ class TestProcessRunningJobs:
285
322
  run=run,
286
323
  status=JobStatus.PROVISIONING,
287
324
  job_provisioning_data=job_provisioning_data,
325
+ instance=instance,
326
+ instance_assigned=True,
288
327
  )
289
328
 
290
329
  await process_running_jobs()
@@ -334,12 +373,21 @@ class TestProcessRunningJobs:
334
373
  repo=repo,
335
374
  user=user,
336
375
  )
376
+ pool = await create_pool(session=session, project=project)
377
+ instance = await create_instance(
378
+ session=session,
379
+ project=project,
380
+ pool=pool,
381
+ status=InstanceStatus.BUSY,
382
+ )
337
383
  job = await create_job(
338
384
  session=session,
339
385
  run=run,
340
386
  status=JobStatus.PULLING,
341
387
  job_provisioning_data=get_job_provisioning_data(dockerized=True),
342
388
  job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None),
389
+ instance=instance,
390
+ instance_assigned=True,
343
391
  )
344
392
  shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING
345
393
  shim_client_mock.get_task.return_value.ports = [
@@ -382,6 +430,13 @@ class TestProcessRunningJobs:
382
430
  repo=repo,
383
431
  user=user,
384
432
  )
433
+ pool = await create_pool(session=session, project=project)
434
+ instance = await create_instance(
435
+ session=session,
436
+ project=project,
437
+ pool=pool,
438
+ status=InstanceStatus.BUSY,
439
+ )
385
440
  job_provisioning_data = get_job_provisioning_data(dockerized=True)
386
441
  job = await create_job(
387
442
  session=session,
@@ -389,6 +444,8 @@ class TestProcessRunningJobs:
389
444
  status=JobStatus.PULLING,
390
445
  job_provisioning_data=job_provisioning_data,
391
446
  job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None),
447
+ instance=instance,
448
+ instance_assigned=True,
392
449
  )
393
450
  shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING
394
451
  shim_client_mock.get_task.return_value.ports = None
@@ -467,12 +524,21 @@ class TestProcessRunningJobs:
467
524
  run_name="test-run",
468
525
  run_spec=run_spec,
469
526
  )
527
+ pool = await create_pool(session=session, project=project)
528
+ instance = await create_instance(
529
+ session=session,
530
+ project=project,
531
+ pool=pool,
532
+ status=InstanceStatus.BUSY,
533
+ )
470
534
  job = await create_job(
471
535
  session=session,
472
536
  run=run,
473
537
  status=JobStatus.PROVISIONING,
474
538
  job_provisioning_data=get_job_provisioning_data(dockerized=True),
475
539
  submitted_at=get_current_datetime(),
540
+ instance=instance,
541
+ instance_assigned=True,
476
542
  )
477
543
  monkeypatch.setattr(
478
544
  "dstack._internal.server.services.runner.ssh.SSHTunnel", Mock(return_value=MagicMock())
@@ -495,3 +561,129 @@ class TestProcessRunningJobs:
495
561
  shim_client_mock.stop.assert_called_once_with(force=True)
496
562
  await session.refresh(job)
497
563
  assert job.status == JobStatus.PROVISIONING
564
+
565
+ @pytest.mark.asyncio
566
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
567
+ @pytest.mark.parametrize(
568
+ (
569
+ "inactivity_duration",
570
+ "no_connections_secs",
571
+ "expected_status",
572
+ "expected_termination_reason",
573
+ "expected_inactivity_secs",
574
+ ),
575
+ [
576
+ pytest.param(
577
+ "1h",
578
+ 60 * 60 - 1,
579
+ JobStatus.RUNNING,
580
+ None,
581
+ 60 * 60 - 1,
582
+ id="duration-not-exceeded",
583
+ ),
584
+ pytest.param(
585
+ "1h",
586
+ 60 * 60,
587
+ JobStatus.TERMINATING,
588
+ JobTerminationReason.TERMINATED_BY_SERVER,
589
+ 60 * 60,
590
+ id="duration-exceeded-exactly",
591
+ ),
592
+ pytest.param(
593
+ "1h",
594
+ 60 * 60 + 1,
595
+ JobStatus.TERMINATING,
596
+ JobTerminationReason.TERMINATED_BY_SERVER,
597
+ 60 * 60 + 1,
598
+ id="duration-exceeded",
599
+ ),
600
+ pytest.param("off", 60 * 60, JobStatus.RUNNING, None, None, id="duration-off"),
601
+ pytest.param(False, 60 * 60, JobStatus.RUNNING, None, None, id="duration-false"),
602
+ pytest.param(None, 60 * 60, JobStatus.RUNNING, None, None, id="duration-none"),
603
+ pytest.param(
604
+ "1h",
605
+ None,
606
+ JobStatus.TERMINATING,
607
+ JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY,
608
+ None,
609
+ id="legacy-runner",
610
+ ),
611
+ pytest.param(
612
+ None,
613
+ None,
614
+ JobStatus.RUNNING,
615
+ None,
616
+ None,
617
+ id="legacy-runner-without-duration",
618
+ ),
619
+ ],
620
+ )
621
+ async def test_inactivity_duration(
622
+ self,
623
+ test_db,
624
+ session: AsyncSession,
625
+ inactivity_duration,
626
+ no_connections_secs: Optional[int],
627
+ expected_status: JobStatus,
628
+ expected_termination_reason: Optional[JobTerminationReason],
629
+ expected_inactivity_secs: Optional[int],
630
+ ) -> None:
631
+ project = await create_project(session=session)
632
+ user = await create_user(session=session)
633
+ repo = await create_repo(
634
+ session=session,
635
+ project_id=project.id,
636
+ )
637
+ run = await create_run(
638
+ session=session,
639
+ project=project,
640
+ repo=repo,
641
+ user=user,
642
+ status=RunStatus.RUNNING,
643
+ run_name="test-run",
644
+ run_spec=get_run_spec(
645
+ run_name="test-run",
646
+ repo_id=repo.name,
647
+ configuration=DevEnvironmentConfiguration(
648
+ name="test-run",
649
+ ide="vscode",
650
+ inactivity_duration=inactivity_duration,
651
+ ),
652
+ ),
653
+ )
654
+ pool = await create_pool(session=session, project=project)
655
+ instance = await create_instance(
656
+ session=session,
657
+ project=project,
658
+ pool=pool,
659
+ status=InstanceStatus.BUSY,
660
+ )
661
+ job = await create_job(
662
+ session=session,
663
+ run=run,
664
+ status=JobStatus.RUNNING,
665
+ job_provisioning_data=get_job_provisioning_data(),
666
+ instance=instance,
667
+ instance_assigned=True,
668
+ )
669
+ with (
670
+ patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
671
+ patch(
672
+ "dstack._internal.server.services.runner.client.RunnerClient"
673
+ ) as RunnerClientMock,
674
+ ):
675
+ runner_client_mock = RunnerClientMock.return_value
676
+ runner_client_mock.pull.return_value = PullResponse(
677
+ job_states=[],
678
+ job_logs=[],
679
+ runner_logs=[],
680
+ last_updated=0,
681
+ no_connections_secs=no_connections_secs,
682
+ )
683
+ await process_running_jobs()
684
+ SSHTunnelMock.assert_called_once()
685
+ runner_client_mock.pull.assert_called_once()
686
+ await session.refresh(job)
687
+ assert job.status == expected_status
688
+ assert job.termination_reason == expected_termination_reason
689
+ assert job.inactivity_secs == expected_inactivity_secs
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
8
8
 
9
9
  import dstack._internal.server.background.tasks.process_runs as process_runs
10
10
  from dstack._internal.core.models.configurations import ServiceConfiguration
11
+ from dstack._internal.core.models.instances import InstanceStatus
11
12
  from dstack._internal.core.models.profiles import Profile
12
13
  from dstack._internal.core.models.resources import Range
13
14
  from dstack._internal.core.models.runs import (
@@ -116,11 +117,20 @@ class TestProcessRuns:
116
117
  async def test_terminate_run_jobs(self, test_db, session: AsyncSession):
117
118
  run = await make_run(session, status=RunStatus.TERMINATING)
118
119
  run.termination_reason = RunTerminationReason.JOB_FAILED
120
+ pool = await create_pool(session=session, project=run.project)
121
+ instance = await create_instance(
122
+ session=session,
123
+ project=run.project,
124
+ pool=pool,
125
+ status=InstanceStatus.BUSY,
126
+ )
119
127
  job = await create_job(
120
128
  session=session,
121
129
  run=run,
122
130
  job_provisioning_data=get_job_provisioning_data(),
123
131
  status=JobStatus.RUNNING,
132
+ instance=instance,
133
+ instance_assigned=True,
124
134
  )
125
135
 
126
136
  with patch("dstack._internal.server.services.jobs._stop_runner") as stop_runner:
@@ -288,13 +298,27 @@ class TestProcessRunsReplicas:
288
298
 
289
299
  @pytest.mark.asyncio
290
300
  @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
291
- async def test_some_failed_to_terminating(self, test_db, session: AsyncSession):
301
+ @pytest.mark.parametrize(
302
+ ("job_status", "job_termination_reason"),
303
+ [
304
+ (JobStatus.FAILED, JobTerminationReason.CONTAINER_EXITED_WITH_ERROR),
305
+ (JobStatus.TERMINATING, JobTerminationReason.TERMINATED_BY_SERVER),
306
+ (JobStatus.TERMINATED, JobTerminationReason.TERMINATED_BY_SERVER),
307
+ ],
308
+ )
309
+ async def test_some_failed_to_terminating(
310
+ self,
311
+ test_db,
312
+ session: AsyncSession,
313
+ job_status: JobStatus,
314
+ job_termination_reason: JobTerminationReason,
315
+ ) -> None:
292
316
  run = await make_run(session, status=RunStatus.RUNNING, replicas=2)
293
317
  await create_job(
294
318
  session=session,
295
319
  run=run,
296
- status=JobStatus.FAILED,
297
- termination_reason=JobTerminationReason.CONTAINER_EXITED_WITH_ERROR,
320
+ status=job_status,
321
+ termination_reason=job_termination_reason,
298
322
  replica_num=0,
299
323
  )
300
324
  await create_job(session=session, run=run, status=JobStatus.RUNNING, replica_num=1)
@@ -38,6 +38,7 @@ from dstack._internal.server.testing.common import (
38
38
  create_run,
39
39
  create_user,
40
40
  create_volume,
41
+ get_instance_offer_with_availability,
41
42
  get_run_spec,
42
43
  get_volume_provisioning_data,
43
44
  )
@@ -447,7 +448,7 @@ class TestProcessSubmittedJobs:
447
448
  await process_submitted_jobs()
448
449
  await session.refresh(job)
449
450
  res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
450
- job = res.scalar_one()
451
+ job = res.unique().scalar_one()
451
452
  assert job.status == JobStatus.SUBMITTED
452
453
  assert (
453
454
  job.instance_assigned and job.instance is not None and job.instance.id == instance.id
@@ -517,13 +518,57 @@ class TestProcessSubmittedJobs:
517
518
  joinedload(JobModel.instance).selectinload(InstanceModel.volumes)
518
519
  )
519
520
  )
520
- job = res.scalar_one()
521
+ job = res.unique().scalar_one()
521
522
  assert job.status == JobStatus.PROVISIONING
522
523
  assert (
523
524
  job.instance_assigned and job.instance is not None and job.instance.id == instance.id
524
525
  )
525
526
  assert job.instance.volumes == [volume]
526
527
 
528
+ @pytest.mark.asyncio
529
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
530
+ async def test_assigns_job_to_shared_instance(self, test_db, session: AsyncSession):
531
+ project = await create_project(session)
532
+ user = await create_user(session)
533
+ pool = await create_pool(session=session, project=project)
534
+ repo = await create_repo(
535
+ session=session,
536
+ project_id=project.id,
537
+ )
538
+ offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
539
+ instance = await create_instance(
540
+ session=session,
541
+ project=project,
542
+ pool=pool,
543
+ status=InstanceStatus.IDLE,
544
+ offer=offer,
545
+ total_blocks=4,
546
+ busy_blocks=1,
547
+ )
548
+ await session.refresh(pool)
549
+ run = await create_run(
550
+ session=session,
551
+ project=project,
552
+ repo=repo,
553
+ user=user,
554
+ )
555
+ job = await create_job(
556
+ session=session,
557
+ run=run,
558
+ instance_assigned=False,
559
+ )
560
+ await process_submitted_jobs()
561
+ await session.refresh(job)
562
+ await session.refresh(instance)
563
+ res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
564
+ job = res.unique().scalar_one()
565
+ assert job.status == JobStatus.SUBMITTED
566
+ assert (
567
+ job.instance_assigned and job.instance is not None and job.instance.id == instance.id
568
+ )
569
+ assert instance.total_blocks == 4
570
+ assert instance.busy_blocks == 2
571
+
527
572
  @pytest.mark.asyncio
528
573
  @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
529
574
  async def test_creates_new_instance_in_existing_fleet(self, test_db, session: AsyncSession):
@@ -590,7 +635,7 @@ class TestProcessSubmittedJobs:
590
635
 
591
636
  await session.refresh(job)
592
637
  res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
593
- job = res.scalar_one()
638
+ job = res.unique().scalar_one()
594
639
  assert job.status == JobStatus.PROVISIONING
595
640
  assert job.instance is not None
596
641
  assert job.instance.instance_num == 1
@@ -24,17 +24,19 @@ from dstack._internal.server.testing.common import (
24
24
  create_run,
25
25
  create_user,
26
26
  create_volume,
27
+ get_instance_offer_with_availability,
27
28
  get_job_provisioning_data,
29
+ get_job_runtime_data,
30
+ get_volume_configuration,
28
31
  get_volume_provisioning_data,
29
32
  )
30
33
 
31
- pytestmark = pytest.mark.usefixtures("image_config_mock")
32
-
33
34
 
35
+ @pytest.mark.asyncio
36
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
37
+ @pytest.mark.usefixtures("test_db", "image_config_mock")
34
38
  class TestProcessTerminatingJobs:
35
- @pytest.mark.asyncio
36
- @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
37
- async def test_terminates_job(self, test_db, session: AsyncSession):
39
+ async def test_terminates_job(self, session: AsyncSession):
38
40
  project = await create_project(session=session)
39
41
  user = await create_user(session=session)
40
42
  pool = await create_pool(session=session, project=project)
@@ -73,9 +75,7 @@ class TestProcessTerminatingJobs:
73
75
  assert job is not None
74
76
  assert job.status == JobStatus.TERMINATED
75
77
 
76
- @pytest.mark.asyncio
77
- @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
78
- async def test_detaches_job_volumes(self, test_db, session: AsyncSession):
78
+ async def test_detaches_job_volumes(self, session: AsyncSession):
79
79
  project = await create_project(session=session)
80
80
  user = await create_user(session=session)
81
81
  pool = await create_pool(session=session, project=project)
@@ -122,9 +122,7 @@ class TestProcessTerminatingJobs:
122
122
  await session.refresh(job)
123
123
  assert job.status == JobStatus.TERMINATED
124
124
 
125
- @pytest.mark.asyncio
126
- @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
127
- async def test_force_detaches_job_volumes(self, test_db, session: AsyncSession):
125
+ async def test_force_detaches_job_volumes(self, session: AsyncSession):
128
126
  project = await create_project(session=session)
129
127
  user = await create_user(session=session)
130
128
  pool = await create_pool(session=session, project=project)
@@ -172,7 +170,7 @@ class TestProcessTerminatingJobs:
172
170
  backend_mock.compute.return_value.is_volume_detached.assert_called_once()
173
171
  await session.refresh(job)
174
172
  res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
175
- job = res.scalar_one()
173
+ job = res.unique().scalar_one()
176
174
  assert job.status == JobStatus.TERMINATING
177
175
  # The instance should be released even if detach fails
178
176
  # so that stuck volumes don't prevent the instance from terminating.
@@ -214,10 +212,125 @@ class TestProcessTerminatingJobs:
214
212
  await session.refresh(job)
215
213
  await session.refresh(instance)
216
214
  res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
217
- job = res.scalar_one()
215
+ job = res.unique().scalar_one()
218
216
  res = await session.execute(
219
217
  select(InstanceModel).options(joinedload(InstanceModel.volumes))
220
218
  )
221
219
  instance = res.unique().scalar_one()
222
220
  assert job.status == JobStatus.TERMINATED
223
221
  assert len(instance.volumes) == 0
222
+
223
+ async def test_terminates_job_on_shared_instance(self, session: AsyncSession):
224
+ project = await create_project(session)
225
+ user = await create_user(session)
226
+ pool = await create_pool(session=session, project=project)
227
+ repo = await create_repo(
228
+ session=session,
229
+ project_id=project.id,
230
+ )
231
+ instance = await create_instance(
232
+ session=session,
233
+ project=project,
234
+ pool=pool,
235
+ status=InstanceStatus.BUSY,
236
+ total_blocks=4,
237
+ busy_blocks=3,
238
+ )
239
+ await session.refresh(pool)
240
+ run = await create_run(
241
+ session=session,
242
+ project=project,
243
+ repo=repo,
244
+ user=user,
245
+ )
246
+ shared_offer = get_instance_offer_with_availability(blocks=2, total_blocks=4)
247
+ jrd = get_job_runtime_data(offer=shared_offer)
248
+ job = await create_job(
249
+ session=session,
250
+ run=run,
251
+ instance_assigned=True,
252
+ instance=instance,
253
+ job_runtime_data=jrd,
254
+ status=JobStatus.TERMINATING,
255
+ termination_reason=JobTerminationReason.TERMINATED_BY_USER,
256
+ )
257
+
258
+ await process_terminating_jobs()
259
+
260
+ await session.refresh(job)
261
+ await session.refresh(instance)
262
+ res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
263
+ job = res.unique().scalar_one()
264
+ assert job.status == JobStatus.TERMINATED
265
+ assert job.instance_assigned
266
+ assert job.instance is None
267
+ assert instance.total_blocks == 4
268
+ assert instance.busy_blocks == 1
269
+
270
+ async def test_detaches_job_volumes_on_shared_instance(self, session: AsyncSession):
271
+ project = await create_project(session=session)
272
+ user = await create_user(session=session)
273
+ pool = await create_pool(session=session, project=project)
274
+ volume_conf_1 = get_volume_configuration(name="vol-1")
275
+ volume_1 = await create_volume(
276
+ session=session,
277
+ project=project,
278
+ user=user,
279
+ status=VolumeStatus.ACTIVE,
280
+ backend=BackendType.AWS,
281
+ configuration=volume_conf_1,
282
+ volume_provisioning_data=get_volume_provisioning_data(),
283
+ )
284
+ volume_conf_2 = get_volume_configuration(name="vol-2")
285
+ volume_2 = await create_volume(
286
+ session=session,
287
+ project=project,
288
+ user=user,
289
+ status=VolumeStatus.ACTIVE,
290
+ backend=BackendType.AWS,
291
+ configuration=volume_conf_2,
292
+ volume_provisioning_data=get_volume_provisioning_data(),
293
+ )
294
+ instance = await create_instance(
295
+ session=session,
296
+ project=project,
297
+ pool=pool,
298
+ status=InstanceStatus.BUSY,
299
+ volumes=[volume_1, volume_2],
300
+ )
301
+ repo = await create_repo(session=session, project_id=project.id)
302
+ run = await create_run(
303
+ session=session,
304
+ project=project,
305
+ repo=repo,
306
+ user=user,
307
+ )
308
+ job_provisioning_data = get_job_provisioning_data(dockerized=False)
309
+ job = await create_job(
310
+ session=session,
311
+ run=run,
312
+ status=JobStatus.TERMINATING,
313
+ termination_reason=JobTerminationReason.TERMINATED_BY_USER,
314
+ submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc),
315
+ job_provisioning_data=job_provisioning_data,
316
+ job_runtime_data=get_job_runtime_data(volume_names=["vol-1"]),
317
+ instance=instance,
318
+ )
319
+ with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m:
320
+ backend_mock = Mock()
321
+ m.return_value = backend_mock
322
+ backend_mock.compute.return_value.is_volume_detached.return_value = True
323
+
324
+ await process_terminating_jobs()
325
+
326
+ m.assert_awaited_once()
327
+ backend_mock.compute.return_value.detach_volume.assert_called_once()
328
+ backend_mock.compute.return_value.is_volume_detached.assert_called_once()
329
+ await session.refresh(job)
330
+ await session.refresh(instance)
331
+ assert job.status == JobStatus.TERMINATED
332
+ res = await session.execute(
333
+ select(InstanceModel).options(joinedload(InstanceModel.volumes))
334
+ )
335
+ instance = res.unique().scalar_one()
336
+ assert instance.volumes == [volume_2]