dstack 0.18.41__py3-none-any.whl → 0.18.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. dstack/_internal/cli/commands/__init__.py +2 -1
  2. dstack/_internal/cli/commands/apply.py +4 -2
  3. dstack/_internal/cli/commands/attach.py +21 -1
  4. dstack/_internal/cli/commands/completion.py +20 -0
  5. dstack/_internal/cli/commands/delete.py +3 -1
  6. dstack/_internal/cli/commands/fleet.py +2 -1
  7. dstack/_internal/cli/commands/gateway.py +7 -2
  8. dstack/_internal/cli/commands/logs.py +3 -2
  9. dstack/_internal/cli/commands/stats.py +2 -1
  10. dstack/_internal/cli/commands/stop.py +2 -1
  11. dstack/_internal/cli/commands/volume.py +2 -1
  12. dstack/_internal/cli/main.py +6 -0
  13. dstack/_internal/cli/services/completion.py +86 -0
  14. dstack/_internal/cli/services/configurators/run.py +10 -17
  15. dstack/_internal/cli/utils/fleet.py +5 -1
  16. dstack/_internal/cli/utils/volume.py +9 -0
  17. dstack/_internal/core/backends/aws/compute.py +24 -11
  18. dstack/_internal/core/backends/aws/resources.py +3 -3
  19. dstack/_internal/core/backends/azure/compute.py +14 -8
  20. dstack/_internal/core/backends/azure/resources.py +2 -0
  21. dstack/_internal/core/backends/base/compute.py +102 -2
  22. dstack/_internal/core/backends/base/offers.py +7 -1
  23. dstack/_internal/core/backends/cudo/compute.py +8 -4
  24. dstack/_internal/core/backends/datacrunch/compute.py +10 -4
  25. dstack/_internal/core/backends/gcp/auth.py +19 -13
  26. dstack/_internal/core/backends/gcp/compute.py +27 -20
  27. dstack/_internal/core/backends/gcp/resources.py +3 -10
  28. dstack/_internal/core/backends/kubernetes/compute.py +4 -3
  29. dstack/_internal/core/backends/lambdalabs/compute.py +9 -3
  30. dstack/_internal/core/backends/nebius/compute.py +2 -2
  31. dstack/_internal/core/backends/oci/compute.py +10 -4
  32. dstack/_internal/core/backends/runpod/compute.py +11 -4
  33. dstack/_internal/core/backends/tensordock/compute.py +14 -3
  34. dstack/_internal/core/backends/vastai/compute.py +12 -2
  35. dstack/_internal/core/backends/vultr/api_client.py +3 -3
  36. dstack/_internal/core/backends/vultr/compute.py +9 -3
  37. dstack/_internal/core/models/backends/aws.py +2 -0
  38. dstack/_internal/core/models/backends/base.py +1 -0
  39. dstack/_internal/core/models/configurations.py +0 -1
  40. dstack/_internal/core/models/runs.py +3 -3
  41. dstack/_internal/core/models/volumes.py +23 -0
  42. dstack/_internal/core/services/__init__.py +5 -1
  43. dstack/_internal/core/services/configs/__init__.py +3 -0
  44. dstack/_internal/server/background/tasks/common.py +22 -0
  45. dstack/_internal/server/background/tasks/process_instances.py +13 -21
  46. dstack/_internal/server/background/tasks/process_running_jobs.py +13 -16
  47. dstack/_internal/server/background/tasks/process_submitted_jobs.py +12 -7
  48. dstack/_internal/server/background/tasks/process_terminating_jobs.py +7 -2
  49. dstack/_internal/server/background/tasks/process_volumes.py +11 -1
  50. dstack/_internal/server/migrations/versions/a751ef183f27_move_attachment_data_to_volumes_.py +34 -0
  51. dstack/_internal/server/models.py +17 -19
  52. dstack/_internal/server/routers/logs.py +3 -0
  53. dstack/_internal/server/services/backends/configurators/aws.py +31 -1
  54. dstack/_internal/server/services/backends/configurators/gcp.py +8 -15
  55. dstack/_internal/server/services/config.py +11 -1
  56. dstack/_internal/server/services/fleets.py +5 -1
  57. dstack/_internal/server/services/jobs/__init__.py +14 -11
  58. dstack/_internal/server/services/jobs/configurators/dev.py +1 -3
  59. dstack/_internal/server/services/jobs/configurators/task.py +1 -3
  60. dstack/_internal/server/services/logs/__init__.py +78 -0
  61. dstack/_internal/server/services/{logs.py → logs/aws.py} +12 -207
  62. dstack/_internal/server/services/logs/base.py +47 -0
  63. dstack/_internal/server/services/logs/filelog.py +110 -0
  64. dstack/_internal/server/services/logs/gcp.py +165 -0
  65. dstack/_internal/server/services/offers.py +7 -7
  66. dstack/_internal/server/services/pools.py +19 -20
  67. dstack/_internal/server/services/proxy/routers/service_proxy.py +14 -7
  68. dstack/_internal/server/services/runner/client.py +8 -5
  69. dstack/_internal/server/services/volumes.py +68 -9
  70. dstack/_internal/server/settings.py +3 -0
  71. dstack/_internal/server/statics/index.html +1 -1
  72. dstack/_internal/server/statics/{main-ad5150a441de98cd8987.css → main-7510e71dfa9749a4e70e.css} +1 -1
  73. dstack/_internal/server/statics/{main-2ac66bfcbd2e39830b88.js → main-fe8fd9db55df8d10e648.js} +66 -66
  74. dstack/_internal/server/statics/{main-2ac66bfcbd2e39830b88.js.map → main-fe8fd9db55df8d10e648.js.map} +1 -1
  75. dstack/_internal/server/testing/common.py +46 -17
  76. dstack/api/_public/runs.py +1 -1
  77. dstack/version.py +2 -2
  78. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/METADATA +4 -3
  79. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/RECORD +97 -86
  80. tests/_internal/core/backends/base/__init__.py +0 -0
  81. tests/_internal/core/backends/base/test_compute.py +56 -0
  82. tests/_internal/server/background/tasks/test_process_running_jobs.py +2 -1
  83. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +5 -3
  84. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +11 -6
  85. tests/_internal/server/conftest.py +4 -5
  86. tests/_internal/server/routers/test_backends.py +1 -0
  87. tests/_internal/server/routers/test_logs.py +1 -1
  88. tests/_internal/server/routers/test_runs.py +2 -2
  89. tests/_internal/server/routers/test_volumes.py +9 -2
  90. tests/_internal/server/services/runner/test_client.py +22 -3
  91. tests/_internal/server/services/test_logs.py +3 -3
  92. tests/_internal/server/services/test_offers.py +167 -0
  93. tests/_internal/server/services/test_pools.py +105 -1
  94. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/LICENSE.md +0 -0
  95. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/WHEEL +0 -0
  96. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/entry_points.txt +0 -0
  97. {dstack-0.18.41.dist-info → dstack-0.18.43.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@ from dstack._internal.core.models.volumes import VolumeStatus
13
13
  from dstack._internal.server.background.tasks.process_terminating_jobs import (
14
14
  process_terminating_jobs,
15
15
  )
16
- from dstack._internal.server.models import InstanceModel, JobModel
16
+ from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel
17
17
  from dstack._internal.server.services.volumes import volume_model_to_volume
18
18
  from dstack._internal.server.testing.common import (
19
19
  create_instance,
@@ -176,7 +176,7 @@ class TestProcessTerminatingJobs:
176
176
  # so that stuck volumes don't prevent the instance from terminating.
177
177
  assert job.instance is None
178
178
  assert job.volumes_detached_at is not None
179
- assert len(instance.volumes) == 1
179
+ assert len(instance.volume_attachments) == 1
180
180
 
181
181
  # Force detach called
182
182
  with (
@@ -214,11 +214,11 @@ class TestProcessTerminatingJobs:
214
214
  res = await session.execute(select(JobModel).options(joinedload(JobModel.instance)))
215
215
  job = res.unique().scalar_one()
216
216
  res = await session.execute(
217
- select(InstanceModel).options(joinedload(InstanceModel.volumes))
217
+ select(InstanceModel).options(joinedload(InstanceModel.volume_attachments))
218
218
  )
219
219
  instance = res.unique().scalar_one()
220
220
  assert job.status == JobStatus.TERMINATED
221
- assert len(instance.volumes) == 0
221
+ assert len(instance.volume_attachments) == 0
222
222
 
223
223
  async def test_terminates_job_on_shared_instance(self, session: AsyncSession):
224
224
  project = await create_project(session)
@@ -330,7 +330,12 @@ class TestProcessTerminatingJobs:
330
330
  await session.refresh(instance)
331
331
  assert job.status == JobStatus.TERMINATED
332
332
  res = await session.execute(
333
- select(InstanceModel).options(joinedload(InstanceModel.volumes))
333
+ select(InstanceModel).options(
334
+ joinedload(InstanceModel.volume_attachments).joinedload(
335
+ VolumeAttachmentModel.volume
336
+ )
337
+ )
334
338
  )
335
339
  instance = res.unique().scalar_one()
336
- assert instance.volumes == [volume_2]
340
+ assert len(instance.volume_attachments) == 1
341
+ assert instance.volume_attachments[0].volume == volume_2
@@ -8,6 +8,7 @@ from dstack._internal.server.main import app
8
8
  from dstack._internal.server.services import encryption as encryption # import for side-effect
9
9
  from dstack._internal.server.services import logs as logs_services
10
10
  from dstack._internal.server.services.docker import ImageConfig, ImageConfigObject
11
+ from dstack._internal.server.services.logs.filelog import FileLogStorage
11
12
  from dstack._internal.server.testing.conf import postgres_container, session, test_db # noqa: F401
12
13
 
13
14
 
@@ -18,13 +19,11 @@ def client(event_loop):
18
19
 
19
20
 
20
21
  @pytest.fixture
21
- def test_log_storage(
22
- tmp_path: Path, monkeypatch: pytest.MonkeyPatch
23
- ) -> logs_services.FileLogStorage:
22
+ def test_log_storage(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> FileLogStorage:
24
23
  root = tmp_path / "test_logs"
25
24
  root.mkdir()
26
- storage = logs_services.FileLogStorage(root)
27
- monkeypatch.setattr(logs_services, "_default_log_storage", storage)
25
+ storage = FileLogStorage(root)
26
+ monkeypatch.setattr(logs_services, "_log_storage", storage)
28
27
  return storage
29
28
 
30
29
 
@@ -1335,6 +1335,7 @@ class TestGetConfigInfo:
1335
1335
  "vpc_ids": None,
1336
1336
  "default_vpcs": None,
1337
1337
  "public_ips": None,
1338
+ "iam_instance_profile": None,
1338
1339
  "tags": None,
1339
1340
  "os_images": None,
1340
1341
  "creds": json.loads(backend.auth.plaintext),
@@ -3,7 +3,7 @@ from httpx import AsyncClient
3
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
4
 
5
5
  from dstack._internal.core.models.users import GlobalRole, ProjectRole
6
- from dstack._internal.server.services.logs import FileLogStorage
6
+ from dstack._internal.server.services.logs.filelog import FileLogStorage
7
7
  from dstack._internal.server.services.projects import add_project_member
8
8
  from dstack._internal.server.testing.common import create_project, create_user, get_auth_headers
9
9
 
@@ -180,7 +180,7 @@ def get_dev_env_run_plan_dict(
180
180
  ],
181
181
  "env": {},
182
182
  "home_dir": "/root",
183
- "image_name": "dstackai/base:py3.13-0.6-cuda-12.1",
183
+ "image_name": "dstackai/base:py3.13-0.7-cuda-12.1",
184
184
  "user": None,
185
185
  "privileged": privileged,
186
186
  "job_name": f"{run_name}-0-0",
@@ -337,7 +337,7 @@ def get_dev_env_run_dict(
337
337
  ],
338
338
  "env": {},
339
339
  "home_dir": "/root",
340
- "image_name": "dstackai/base:py3.13-0.6-cuda-12.1",
340
+ "image_name": "dstackai/base:py3.13-0.7-cuda-12.1",
341
341
  "user": None,
342
342
  "privileged": privileged,
343
343
  "job_name": f"{run_name}-0-0",
@@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
11
11
 
12
12
  from dstack._internal.core.models.backends.base import BackendType
13
13
  from dstack._internal.core.models.users import GlobalRole, ProjectRole
14
- from dstack._internal.server.models import VolumeModel
14
+ from dstack._internal.server.models import VolumeAttachmentModel, VolumeModel
15
15
  from dstack._internal.server.services.projects import add_project_member
16
16
  from dstack._internal.server.testing.common import (
17
17
  create_instance,
@@ -76,6 +76,7 @@ class TestListVolumes:
76
76
  "deleted": False,
77
77
  "volume_id": None,
78
78
  "provisioning_data": None,
79
+ "attachments": [],
79
80
  "attachment_data": None,
80
81
  },
81
82
  {
@@ -91,6 +92,7 @@ class TestListVolumes:
91
92
  "deleted": False,
92
93
  "volume_id": None,
93
94
  "provisioning_data": None,
95
+ "attachments": [],
94
96
  "attachment_data": None,
95
97
  },
96
98
  ]
@@ -117,6 +119,7 @@ class TestListVolumes:
117
119
  "deleted": False,
118
120
  "volume_id": None,
119
121
  "provisioning_data": None,
122
+ "attachments": [],
120
123
  "attachment_data": None,
121
124
  },
122
125
  ]
@@ -170,6 +173,7 @@ class TestListVolumes:
170
173
  "deleted": False,
171
174
  "volume_id": None,
172
175
  "provisioning_data": None,
176
+ "attachments": [],
173
177
  "attachment_data": None,
174
178
  },
175
179
  ]
@@ -217,6 +221,7 @@ class TestListProjectVolumes:
217
221
  "deleted": False,
218
222
  "volume_id": None,
219
223
  "provisioning_data": None,
224
+ "attachments": [],
220
225
  "attachment_data": None,
221
226
  }
222
227
  ]
@@ -264,6 +269,7 @@ class TestGetVolume:
264
269
  "deleted": False,
265
270
  "volume_id": None,
266
271
  "provisioning_data": None,
272
+ "attachments": [],
267
273
  "attachment_data": None,
268
274
  }
269
275
 
@@ -325,6 +331,7 @@ class TestCreateVolume:
325
331
  "deleted": False,
326
332
  "volume_id": None,
327
333
  "provisioning_data": None,
334
+ "attachments": [],
328
335
  "attachment_data": None,
329
336
  }
330
337
  res = await session.execute(select(VolumeModel))
@@ -391,7 +398,7 @@ class TestDeleteVolumes:
391
398
  project=project,
392
399
  pool=pool,
393
400
  )
394
- volume.instances.append(instance)
401
+ volume.attachments.append(VolumeAttachmentModel(instance=instance))
395
402
  await session.commit()
396
403
  response = await client.post(
397
404
  f"/api/project/{project.name}/volumes/delete",
@@ -9,7 +9,13 @@ from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT
9
9
  from dstack._internal.core.models.backends.base import BackendType
10
10
  from dstack._internal.core.models.common import NetworkMode
11
11
  from dstack._internal.core.models.resources import Memory
12
- from dstack._internal.core.models.volumes import InstanceMountPoint, VolumeMountPoint
12
+ from dstack._internal.core.models.volumes import (
13
+ InstanceMountPoint,
14
+ VolumeAttachment,
15
+ VolumeAttachmentData,
16
+ VolumeInstance,
17
+ VolumeMountPoint,
18
+ )
13
19
  from dstack._internal.server.schemas.runner import (
14
20
  HealthcheckResponse,
15
21
  JobResult,
@@ -133,7 +139,12 @@ class TestShimClientV1(BaseShimClientTest):
133
139
  volume_id="vol-id",
134
140
  configuration=get_volume_configuration(backend=BackendType.GCP),
135
141
  external=False,
136
- device_name="/dev/sdv",
142
+ attachments=[
143
+ VolumeAttachment(
144
+ instance=VolumeInstance(name="instance", instance_num=0, instance_id="i-1"),
145
+ attachment_data=VolumeAttachmentData(device_name="/dev/sdv"),
146
+ )
147
+ ],
137
148
  )
138
149
 
139
150
  submitted = client.submit(
@@ -150,6 +161,7 @@ class TestShimClientV1(BaseShimClientTest):
150
161
  mounts=[VolumeMountPoint(name="vol", path="/vol")],
151
162
  volumes=[volume],
152
163
  instance_mounts=[InstanceMountPoint(instance_path="/mnt/nfs/home", path="/home")],
164
+ instance_id="i-1",
153
165
  )
154
166
 
155
167
  assert submitted is True
@@ -198,6 +210,7 @@ class TestShimClientV1(BaseShimClientTest):
198
210
  mounts=[],
199
211
  volumes=[],
200
212
  instance_mounts=[],
213
+ instance_id="",
201
214
  )
202
215
 
203
216
  assert submitted is False
@@ -294,7 +307,12 @@ class TestShimClientV2(BaseShimClientTest):
294
307
  volume_id="vol-id",
295
308
  configuration=get_volume_configuration(backend=BackendType.GCP),
296
309
  external=False,
297
- device_name="/dev/sdv",
310
+ attachments=[
311
+ VolumeAttachment(
312
+ instance=VolumeInstance(name="instance", instance_num=0, instance_id="i-1"),
313
+ attachment_data=VolumeAttachmentData(device_name="/dev/sdv"),
314
+ )
315
+ ],
298
316
  )
299
317
 
300
318
  client.submit_task(
@@ -316,6 +334,7 @@ class TestShimClientV2(BaseShimClientTest):
316
334
  host_ssh_user="dstack",
317
335
  host_ssh_keys=["host_key"],
318
336
  container_ssh_keys=["project_key", "user_key"],
337
+ instance_id="i-1",
319
338
  )
320
339
 
321
340
  assert adapter.call_count == 2
@@ -17,11 +17,11 @@ from dstack._internal.core.models.logs import LogEvent, LogEventSource
17
17
  from dstack._internal.server.models import ProjectModel
18
18
  from dstack._internal.server.schemas.logs import PollLogsRequest
19
19
  from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
20
- from dstack._internal.server.services.logs import (
20
+ from dstack._internal.server.services.logs.aws import (
21
21
  CloudWatchLogStorage,
22
- FileLogStorage,
23
- LogStorageError,
24
22
  )
23
+ from dstack._internal.server.services.logs.base import LogStorageError
24
+ from dstack._internal.server.services.logs.filelog import FileLogStorage
25
25
  from dstack._internal.server.testing.common import create_project
26
26
 
27
27
 
@@ -0,0 +1,167 @@
1
+ from unittest.mock import Mock, patch
2
+
3
+ import pytest
4
+
5
+ from dstack._internal.core.models.backends.base import BackendType
6
+ from dstack._internal.core.models.profiles import Profile
7
+ from dstack._internal.core.models.resources import ResourcesSpec
8
+ from dstack._internal.core.models.runs import Requirements
9
+ from dstack._internal.server.services.offers import get_offers_by_requirements
10
+ from dstack._internal.server.testing.common import (
11
+ get_instance_offer_with_availability,
12
+ get_volume,
13
+ get_volume_configuration,
14
+ )
15
+
16
+
17
+ class TestGetOffersByRequirements:
18
+ @pytest.mark.asyncio
19
+ async def test_returns_all_offers(self):
20
+ profile = Profile(name="test")
21
+ requirements = Requirements(resources=ResourcesSpec())
22
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
23
+ aws_backend_mock = Mock()
24
+ aws_backend_mock.TYPE = BackendType.AWS
25
+ aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS)
26
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer]
27
+ runpod_backend_mock = Mock()
28
+ runpod_backend_mock.TYPE = BackendType.RUNPOD
29
+ runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD)
30
+ runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [
31
+ runpod_offer
32
+ ]
33
+ m.return_value = [aws_backend_mock, runpod_backend_mock]
34
+ res = await get_offers_by_requirements(
35
+ project=Mock(),
36
+ profile=profile,
37
+ requirements=requirements,
38
+ )
39
+ m.assert_awaited_once()
40
+ assert res == [(aws_backend_mock, aws_offer), (runpod_backend_mock, runpod_offer)]
41
+
42
+ @pytest.mark.asyncio
43
+ async def test_returns_multinode_offers(self):
44
+ profile = Profile(name="test")
45
+ requirements = Requirements(resources=ResourcesSpec())
46
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
47
+ aws_backend_mock = Mock()
48
+ aws_backend_mock.TYPE = BackendType.AWS
49
+ aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS)
50
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer]
51
+ runpod_backend_mock = Mock()
52
+ runpod_backend_mock.TYPE = BackendType.RUNPOD
53
+ runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD)
54
+ runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [
55
+ runpod_offer
56
+ ]
57
+ m.return_value = [aws_backend_mock, runpod_backend_mock]
58
+ res = await get_offers_by_requirements(
59
+ project=Mock(),
60
+ profile=profile,
61
+ requirements=requirements,
62
+ multinode=True,
63
+ )
64
+ m.assert_awaited_once()
65
+ assert res == [(aws_backend_mock, aws_offer)]
66
+
67
+ @pytest.mark.asyncio
68
+ async def test_returns_volume_offers(self):
69
+ profile = Profile(name="test")
70
+ requirements = Requirements(resources=ResourcesSpec())
71
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
72
+ aws_backend_mock = Mock()
73
+ aws_backend_mock.TYPE = BackendType.AWS
74
+ aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS)
75
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer]
76
+ runpod_backend_mock = Mock()
77
+ runpod_backend_mock.TYPE = BackendType.RUNPOD
78
+ runpod_offer1 = get_instance_offer_with_availability(
79
+ backend=BackendType.RUNPOD, region="eu"
80
+ )
81
+ runpod_offer2 = get_instance_offer_with_availability(
82
+ backend=BackendType.RUNPOD, region="us"
83
+ )
84
+ runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [
85
+ runpod_offer1,
86
+ runpod_offer2,
87
+ ]
88
+ m.return_value = [aws_backend_mock, runpod_backend_mock]
89
+ res = await get_offers_by_requirements(
90
+ project=Mock(),
91
+ profile=profile,
92
+ requirements=requirements,
93
+ volumes=[
94
+ [
95
+ get_volume(
96
+ configuration=get_volume_configuration(
97
+ backend=BackendType.RUNPOD, region="us"
98
+ )
99
+ )
100
+ ]
101
+ ],
102
+ )
103
+ m.assert_awaited_once()
104
+ assert res == [(runpod_backend_mock, runpod_offer2)]
105
+
106
+ @pytest.mark.asyncio
107
+ async def test_returns_az_offers(self):
108
+ profile = Profile(name="test", availability_zones=["az1", "az3"])
109
+ requirements = Requirements(resources=ResourcesSpec())
110
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
111
+ aws_backend_mock = Mock()
112
+ aws_backend_mock.TYPE = BackendType.AWS
113
+ aws_offer1 = get_instance_offer_with_availability(
114
+ backend=BackendType.AWS, availability_zones=["az1"]
115
+ )
116
+ aws_offer2 = get_instance_offer_with_availability(
117
+ backend=BackendType.AWS, availability_zones=["az2"]
118
+ )
119
+ aws_offer3 = get_instance_offer_with_availability(
120
+ backend=BackendType.AWS, availability_zones=["az2", "az3"]
121
+ )
122
+ expected_aws_offer3 = aws_offer3.copy()
123
+ expected_aws_offer3.availability_zones = ["az3"]
124
+ aws_offer4 = get_instance_offer_with_availability(
125
+ backend=BackendType.AWS, availability_zones=None
126
+ )
127
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [
128
+ aws_offer1,
129
+ aws_offer2,
130
+ aws_offer3,
131
+ aws_offer4,
132
+ ]
133
+ m.return_value = [aws_backend_mock]
134
+ res = await get_offers_by_requirements(
135
+ project=Mock(),
136
+ profile=profile,
137
+ requirements=requirements,
138
+ )
139
+ m.assert_awaited_once()
140
+ assert res == [(aws_backend_mock, aws_offer1), (aws_backend_mock, expected_aws_offer3)]
141
+
142
+ @pytest.mark.asyncio
143
+ async def test_returns_no_offers_for_multinode_instance_mounts_and_non_multinode_backend(self):
144
+ # Regression test for https://github.com/dstackai/dstack/issues/2211
145
+ profile = Profile(name="test", backends=[BackendType.RUNPOD])
146
+ requirements = Requirements(resources=ResourcesSpec())
147
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
148
+ aws_backend_mock = Mock()
149
+ aws_backend_mock.TYPE = BackendType.AWS
150
+ aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS)
151
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer]
152
+ runpod_backend_mock = Mock()
153
+ runpod_backend_mock.TYPE = BackendType.RUNPOD
154
+ runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD)
155
+ runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [
156
+ runpod_offer
157
+ ]
158
+ m.return_value = [aws_backend_mock, runpod_backend_mock]
159
+ res = await get_offers_by_requirements(
160
+ project=Mock(),
161
+ profile=profile,
162
+ requirements=requirements,
163
+ multinode=True,
164
+ instance_mounts=True,
165
+ )
166
+ m.assert_awaited_once()
167
+ assert res == []
@@ -7,11 +7,115 @@ import dstack._internal.server.services.pools as services_pools
7
7
  from dstack._internal.core.models.backends.base import BackendType
8
8
  from dstack._internal.core.models.instances import InstanceStatus, InstanceType, Resources
9
9
  from dstack._internal.core.models.pools import Instance
10
+ from dstack._internal.core.models.profiles import Profile
10
11
  from dstack._internal.server.models import InstanceModel
11
- from dstack._internal.server.testing.common import create_project, create_user
12
+ from dstack._internal.server.testing.common import (
13
+ create_instance,
14
+ create_pool,
15
+ create_project,
16
+ create_user,
17
+ get_volume,
18
+ get_volume_configuration,
19
+ )
12
20
  from dstack._internal.utils.common import get_current_datetime
13
21
 
14
22
 
23
+ class TestFilterPoolInstances:
24
+ # TODO: Refactor filter_pool_instances to not depend on InstanceModel and simplify tests
25
+ @pytest.mark.asyncio
26
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
27
+ async def test_returns_all_instances(self, test_db, session: AsyncSession):
28
+ user = await create_user(session=session)
29
+ project = await create_project(session=session, owner=user)
30
+ pool = await create_pool(session=session, project=project)
31
+ aws_instance = await create_instance(
32
+ session=session,
33
+ project=project,
34
+ pool=pool,
35
+ backend=BackendType.AWS,
36
+ )
37
+ runpod_instance = await create_instance(
38
+ session=session,
39
+ project=project,
40
+ pool=pool,
41
+ backend=BackendType.RUNPOD,
42
+ )
43
+ instances = [aws_instance, runpod_instance]
44
+ res = services_pools.filter_pool_instances(
45
+ pool_instances=instances,
46
+ profile=Profile(name="test"),
47
+ )
48
+ assert res == instances
49
+
50
+ @pytest.mark.asyncio
51
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
52
+ async def test_returns_multinode_instances(self, test_db, session: AsyncSession):
53
+ user = await create_user(session=session)
54
+ project = await create_project(session=session, owner=user)
55
+ pool = await create_pool(session=session, project=project)
56
+ aws_instance = await create_instance(
57
+ session=session,
58
+ project=project,
59
+ pool=pool,
60
+ backend=BackendType.AWS,
61
+ )
62
+ runpod_instance = await create_instance(
63
+ session=session,
64
+ project=project,
65
+ pool=pool,
66
+ backend=BackendType.RUNPOD,
67
+ )
68
+ instances = [aws_instance, runpod_instance]
69
+ res = services_pools.filter_pool_instances(
70
+ pool_instances=instances,
71
+ profile=Profile(name="test"),
72
+ multinode=True,
73
+ )
74
+ assert res == [aws_instance]
75
+
76
+ @pytest.mark.asyncio
77
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
78
+ async def test_returns_volume_instances(self, test_db, session: AsyncSession):
79
+ user = await create_user(session=session)
80
+ project = await create_project(session=session, owner=user)
81
+ pool = await create_pool(session=session, project=project)
82
+ aws_instance = await create_instance(
83
+ session=session,
84
+ project=project,
85
+ pool=pool,
86
+ backend=BackendType.AWS,
87
+ )
88
+ runpod_instance1 = await create_instance(
89
+ session=session,
90
+ project=project,
91
+ pool=pool,
92
+ backend=BackendType.RUNPOD,
93
+ region="eu",
94
+ )
95
+ runpod_instance2 = await create_instance(
96
+ session=session,
97
+ project=project,
98
+ pool=pool,
99
+ backend=BackendType.RUNPOD,
100
+ region="us",
101
+ )
102
+ instances = [aws_instance, runpod_instance1, runpod_instance2]
103
+ res = services_pools.filter_pool_instances(
104
+ pool_instances=instances,
105
+ profile=Profile(name="test"),
106
+ volumes=[
107
+ [
108
+ get_volume(
109
+ configuration=get_volume_configuration(
110
+ backend=BackendType.RUNPOD, region="us"
111
+ )
112
+ )
113
+ ]
114
+ ],
115
+ )
116
+ assert res == [runpod_instance2]
117
+
118
+
15
119
  class TestGenerateInstanceName:
16
120
  @pytest.mark.asyncio
17
121
  @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)