dstack-0.18.40rc1-py3-none-any.whl → dstack-0.18.42-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104):
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +30 -9
  11. dstack/_internal/core/backends/aws/compute.py +94 -53
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +32 -24
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +23 -3
  28. dstack/_internal/core/models/volumes.py +26 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +73 -35
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +77 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +121 -49
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +14 -3
  42. dstack/_internal/server/background/tasks/process_volumes.py +11 -1
  43. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  44. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  45. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  46. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  47. dstack/_internal/server/migrations/versions/a751ef183f27_move_attachment_data_to_volumes_.py +34 -0
  48. dstack/_internal/server/models.py +27 -23
  49. dstack/_internal/server/routers/runs.py +1 -0
  50. dstack/_internal/server/schemas/runner.py +1 -0
  51. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  52. dstack/_internal/server/services/config.py +9 -0
  53. dstack/_internal/server/services/fleets.py +32 -3
  54. dstack/_internal/server/services/gateways/client.py +9 -1
  55. dstack/_internal/server/services/jobs/__init__.py +217 -45
  56. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  57. dstack/_internal/server/services/offers.py +96 -10
  58. dstack/_internal/server/services/pools.py +98 -14
  59. dstack/_internal/server/services/proxy/repo.py +17 -3
  60. dstack/_internal/server/services/runner/client.py +9 -6
  61. dstack/_internal/server/services/runner/ssh.py +33 -5
  62. dstack/_internal/server/services/runs.py +48 -179
  63. dstack/_internal/server/services/services/__init__.py +9 -1
  64. dstack/_internal/server/services/volumes.py +68 -9
  65. dstack/_internal/server/statics/index.html +1 -1
  66. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  67. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  68. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  69. dstack/_internal/server/testing/common.py +130 -61
  70. dstack/_internal/utils/common.py +22 -8
  71. dstack/_internal/utils/env.py +14 -0
  72. dstack/_internal/utils/ssh.py +1 -1
  73. dstack/api/server/_fleets.py +25 -1
  74. dstack/api/server/_runs.py +23 -2
  75. dstack/api/server/_volumes.py +12 -1
  76. dstack/version.py +1 -1
  77. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/METADATA +1 -1
  78. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/RECORD +104 -93
  79. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  80. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  81. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  82. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  83. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  84. tests/_internal/server/background/tasks/test_process_running_jobs.py +193 -0
  85. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  86. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +53 -6
  87. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +135 -17
  88. tests/_internal/server/routers/test_fleets.py +15 -2
  89. tests/_internal/server/routers/test_pools.py +6 -0
  90. tests/_internal/server/routers/test_runs.py +27 -0
  91. tests/_internal/server/routers/test_volumes.py +9 -2
  92. tests/_internal/server/services/jobs/__init__.py +0 -0
  93. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  94. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  95. tests/_internal/server/services/runner/test_client.py +22 -3
  96. tests/_internal/server/services/test_offers.py +167 -0
  97. tests/_internal/server/services/test_pools.py +109 -1
  98. tests/_internal/server/services/test_runs.py +5 -41
  99. tests/_internal/utils/test_common.py +21 -0
  100. tests/_internal/utils/test_env.py +38 -0
  101. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/LICENSE.md +0 -0
  102. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/WHEEL +0 -0
  103. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/entry_points.txt +0 -0
  104. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/top_level.txt +0 -0
@@ -340,6 +340,7 @@ class TestCreateFleet:
340
340
  },
341
341
  "backends": None,
342
342
  "regions": None,
343
+ "availability_zones": None,
343
344
  "instance_types": None,
344
345
  "spot_policy": None,
345
346
  "retry": None,
@@ -350,10 +351,12 @@ class TestCreateFleet:
350
351
  "type": "fleet",
351
352
  "name": "test-fleet",
352
353
  "reservation": None,
354
+ "blocks": 1,
353
355
  },
354
356
  "profile": {
355
357
  "backends": None,
356
358
  "regions": None,
359
+ "availability_zones": None,
357
360
  "instance_types": None,
358
361
  "spot_policy": None,
359
362
  "retry": None,
@@ -393,15 +396,18 @@ class TestCreateFleet:
393
396
  "pool_name": None,
394
397
  "backend": None,
395
398
  "region": None,
399
+ "availability_zone": None,
396
400
  "instance_type": None,
397
401
  "price": None,
402
+ "total_blocks": 1,
403
+ "busy_blocks": 0,
398
404
  }
399
405
  ],
400
406
  }
401
407
  res = await session.execute(select(FleetModel))
402
408
  assert res.scalar_one()
403
409
  res = await session.execute(select(InstanceModel))
404
- assert res.scalar_one()
410
+ assert res.unique().scalar_one()
405
411
 
406
412
  @pytest.mark.asyncio
407
413
  @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@@ -444,6 +450,7 @@ class TestCreateFleet:
444
450
  "port": None,
445
451
  "identity_file": None,
446
452
  "ssh_key": None, # should not return ssh_key
453
+ "proxy_jump": None,
447
454
  "hosts": ["1.1.1.1"],
448
455
  "network": None,
449
456
  },
@@ -458,6 +465,7 @@ class TestCreateFleet:
458
465
  },
459
466
  "backends": None,
460
467
  "regions": None,
468
+ "availability_zones": None,
461
469
  "instance_types": None,
462
470
  "spot_policy": None,
463
471
  "retry": None,
@@ -468,10 +476,12 @@ class TestCreateFleet:
468
476
  "type": "fleet",
469
477
  "name": spec.configuration.name,
470
478
  "reservation": None,
479
+ "blocks": 1,
471
480
  },
472
481
  "profile": {
473
482
  "backends": None,
474
483
  "regions": None,
484
+ "availability_zones": None,
475
485
  "instance_types": None,
476
486
  "spot_policy": None,
477
487
  "retry": None,
@@ -522,14 +532,17 @@ class TestCreateFleet:
522
532
  "termination_reason": None,
523
533
  "created": "2023-01-02T03:04:00+00:00",
524
534
  "region": "remote",
535
+ "availability_zone": None,
525
536
  "price": 0.0,
537
+ "total_blocks": 1,
538
+ "busy_blocks": 0,
526
539
  }
527
540
  ],
528
541
  }
529
542
  res = await session.execute(select(FleetModel))
530
543
  assert res.scalar_one()
531
544
  res = await session.execute(select(InstanceModel))
532
- instance = res.scalar_one()
545
+ instance = res.unique().scalar_one()
533
546
  assert instance.remote_connection_info is not None
534
547
 
535
548
  @pytest.mark.asyncio
@@ -332,7 +332,10 @@ class TestShowPool:
332
332
  "created": "2023-01-02T03:04:00+00:00",
333
333
  "pool_name": None,
334
334
  "region": "en",
335
+ "availability_zone": None,
335
336
  "price": 1,
337
+ "total_blocks": 1,
338
+ "busy_blocks": 0,
336
339
  }
337
340
  ],
338
341
  }
@@ -503,7 +506,10 @@ class TestRemoveInstance:
503
506
  "created": "2023-01-02T03:04:00+00:00",
504
507
  "pool_name": None,
505
508
  "region": "en",
509
+ "availability_zone": None,
506
510
  "price": 1,
511
+ "total_blocks": 1,
512
+ "busy_blocks": 0,
507
513
  }
508
514
  ],
509
515
  }
@@ -18,6 +18,7 @@ from dstack._internal.core.models.gateways import GatewayStatus
18
18
  from dstack._internal.core.models.instances import (
19
19
  InstanceAvailability,
20
20
  InstanceOfferWithAvailability,
21
+ InstanceStatus,
21
22
  InstanceType,
22
23
  Resources,
23
24
  )
@@ -47,7 +48,9 @@ from dstack._internal.server.testing.common import (
47
48
  create_backend,
48
49
  create_gateway,
49
50
  create_gateway_compute,
51
+ create_instance,
50
52
  create_job,
53
+ create_pool,
51
54
  create_project,
52
55
  create_repo,
53
56
  create_run,
@@ -85,6 +88,7 @@ def get_dev_env_run_plan_dict(
85
88
  "working_dir": None,
86
89
  "home_dir": "/root",
87
90
  "ide": "vscode",
91
+ "inactivity_duration": None,
88
92
  "version": None,
89
93
  "image": None,
90
94
  "user": None,
@@ -107,6 +111,7 @@ def get_dev_env_run_plan_dict(
107
111
  "volumes": [json.loads(v.json()) for v in volumes],
108
112
  "backends": ["local", "aws", "azure", "gcp", "lambda", "runpod"],
109
113
  "regions": ["us"],
114
+ "availability_zones": None,
110
115
  "instance_types": None,
111
116
  "creation_policy": None,
112
117
  "instance_name": None,
@@ -127,6 +132,7 @@ def get_dev_env_run_plan_dict(
127
132
  "profile": {
128
133
  "backends": ["local", "aws", "azure", "gcp", "lambda", "runpod"],
129
134
  "regions": ["us"],
135
+ "availability_zones": None,
130
136
  "instance_types": None,
131
137
  "creation_policy": None,
132
138
  "default": False,
@@ -198,6 +204,7 @@ def get_dev_env_run_plan_dict(
198
204
  "reservation": None,
199
205
  },
200
206
  "retry": None,
207
+ "volumes": volumes,
201
208
  "retry_policy": {"retry": False, "duration": None},
202
209
  "working_dir": ".",
203
210
  },
@@ -238,6 +245,7 @@ def get_dev_env_run_dict(
238
245
  "home_dir": "/root",
239
246
  "working_dir": None,
240
247
  "ide": "vscode",
248
+ "inactivity_duration": None,
241
249
  "version": None,
242
250
  "image": None,
243
251
  "user": None,
@@ -260,6 +268,7 @@ def get_dev_env_run_dict(
260
268
  "volumes": [],
261
269
  "backends": ["local", "aws", "azure", "gcp", "lambda"],
262
270
  "regions": ["us"],
271
+ "availability_zones": None,
263
272
  "instance_types": None,
264
273
  "creation_policy": None,
265
274
  "instance_name": None,
@@ -280,6 +289,7 @@ def get_dev_env_run_dict(
280
289
  "profile": {
281
290
  "backends": ["local", "aws", "azure", "gcp", "lambda"],
282
291
  "regions": ["us"],
292
+ "availability_zones": None,
283
293
  "instance_types": None,
284
294
  "creation_policy": None,
285
295
  "default": False,
@@ -351,6 +361,7 @@ def get_dev_env_run_dict(
351
361
  "reservation": None,
352
362
  },
353
363
  "retry": None,
364
+ "volumes": [],
354
365
  "retry_policy": {"retry": False, "duration": None},
355
366
  "working_dir": ".",
356
367
  },
@@ -361,6 +372,7 @@ def get_dev_env_run_dict(
361
372
  "submitted_at": submitted_at,
362
373
  "last_processed_at": last_processed_at,
363
374
  "finished_at": finished_at,
375
+ "inactivity_secs": None,
364
376
  "status": "submitted",
365
377
  "termination_reason": None,
366
378
  "termination_reason_message": None,
@@ -375,6 +387,7 @@ def get_dev_env_run_dict(
375
387
  "submission_num": 0,
376
388
  "submitted_at": submitted_at,
377
389
  "last_processed_at": last_processed_at,
390
+ "inactivity_secs": None,
378
391
  "finished_at": finished_at,
379
392
  "status": "submitted",
380
393
  "termination_reason": None,
@@ -487,6 +500,7 @@ class TestListRuns:
487
500
  "submitted_at": run1_submitted_at.isoformat(),
488
501
  "last_processed_at": run1_submitted_at.isoformat(),
489
502
  "finished_at": None,
503
+ "inactivity_secs": None,
490
504
  "status": "submitted",
491
505
  "termination_reason": None,
492
506
  "termination_reason_message": None,
@@ -502,6 +516,7 @@ class TestListRuns:
502
516
  "submitted_at": run1_submitted_at.isoformat(),
503
517
  "last_processed_at": run1_submitted_at.isoformat(),
504
518
  "finished_at": None,
519
+ "inactivity_secs": None,
505
520
  "status": "submitted",
506
521
  "termination_reason_message": None,
507
522
  "termination_reason": None,
@@ -1303,11 +1318,20 @@ class TestStopRuns:
1303
1318
  user=user,
1304
1319
  status=RunStatus.RUNNING,
1305
1320
  )
1321
+ pool = await create_pool(session=session, project=project)
1322
+ instance = await create_instance(
1323
+ session=session,
1324
+ project=project,
1325
+ pool=pool,
1326
+ status=InstanceStatus.BUSY,
1327
+ )
1306
1328
  job = await create_job(
1307
1329
  session=session,
1308
1330
  run=run,
1309
1331
  job_provisioning_data=get_job_provisioning_data(),
1310
1332
  status=JobStatus.RUNNING,
1333
+ instance=instance,
1334
+ instance_assigned=True,
1311
1335
  )
1312
1336
  with patch("dstack._internal.server.services.jobs._stop_runner") as stop_runner:
1313
1337
  response = await client.post(
@@ -1533,7 +1557,10 @@ class TestCreateInstance:
1533
1557
  "created": result["created"],
1534
1558
  "pool_name": None,
1535
1559
  "region": None,
1560
+ "availability_zone": None,
1536
1561
  "price": None,
1562
+ "total_blocks": 1,
1563
+ "busy_blocks": 0,
1537
1564
  }
1538
1565
  assert result == expected
1539
1566
 
@@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
11
11
 
12
12
  from dstack._internal.core.models.backends.base import BackendType
13
13
  from dstack._internal.core.models.users import GlobalRole, ProjectRole
14
- from dstack._internal.server.models import VolumeModel
14
+ from dstack._internal.server.models import VolumeAttachmentModel, VolumeModel
15
15
  from dstack._internal.server.services.projects import add_project_member
16
16
  from dstack._internal.server.testing.common import (
17
17
  create_instance,
@@ -76,6 +76,7 @@ class TestListVolumes:
76
76
  "deleted": False,
77
77
  "volume_id": None,
78
78
  "provisioning_data": None,
79
+ "attachments": [],
79
80
  "attachment_data": None,
80
81
  },
81
82
  {
@@ -91,6 +92,7 @@ class TestListVolumes:
91
92
  "deleted": False,
92
93
  "volume_id": None,
93
94
  "provisioning_data": None,
95
+ "attachments": [],
94
96
  "attachment_data": None,
95
97
  },
96
98
  ]
@@ -117,6 +119,7 @@ class TestListVolumes:
117
119
  "deleted": False,
118
120
  "volume_id": None,
119
121
  "provisioning_data": None,
122
+ "attachments": [],
120
123
  "attachment_data": None,
121
124
  },
122
125
  ]
@@ -170,6 +173,7 @@ class TestListVolumes:
170
173
  "deleted": False,
171
174
  "volume_id": None,
172
175
  "provisioning_data": None,
176
+ "attachments": [],
173
177
  "attachment_data": None,
174
178
  },
175
179
  ]
@@ -217,6 +221,7 @@ class TestListProjectVolumes:
217
221
  "deleted": False,
218
222
  "volume_id": None,
219
223
  "provisioning_data": None,
224
+ "attachments": [],
220
225
  "attachment_data": None,
221
226
  }
222
227
  ]
@@ -264,6 +269,7 @@ class TestGetVolume:
264
269
  "deleted": False,
265
270
  "volume_id": None,
266
271
  "provisioning_data": None,
272
+ "attachments": [],
267
273
  "attachment_data": None,
268
274
  }
269
275
 
@@ -325,6 +331,7 @@ class TestCreateVolume:
325
331
  "deleted": False,
326
332
  "volume_id": None,
327
333
  "provisioning_data": None,
334
+ "attachments": [],
328
335
  "attachment_data": None,
329
336
  }
330
337
  res = await session.execute(select(VolumeModel))
@@ -391,7 +398,7 @@ class TestDeleteVolumes:
391
398
  project=project,
392
399
  pool=pool,
393
400
  )
394
- volume.instances.append(instance)
401
+ volume.attachments.append(VolumeAttachmentModel(instance=instance))
395
402
  await session.commit()
396
403
  response = await client.post(
397
404
  f"/api/project/{project.name}/volumes/delete",
File without changes
@@ -0,0 +1,72 @@
1
+ from typing import Union
2
+
3
+ import pytest
4
+
5
+ from dstack._internal.core.errors import ServerClientError
6
+ from dstack._internal.core.models.volumes import InstanceMountPoint, MountPoint, VolumeMountPoint
7
+ from dstack._internal.server.services.jobs.configurators.base import interpolate_job_volumes
8
+
9
+
10
+ class TestInterpolateJobVolumes:
11
+ @pytest.mark.parametrize(
12
+ ["run_volumes", "job_num", "job_volumes"],
13
+ [
14
+ pytest.param(
15
+ [VolumeMountPoint(name="volume", path="/volume")],
16
+ 0,
17
+ [VolumeMountPoint(name=["volume"], path="/volume")],
18
+ id="no_interpolation",
19
+ ),
20
+ pytest.param(
21
+ [InstanceMountPoint(instance_path="/volume", path="/volume")],
22
+ 0,
23
+ [InstanceMountPoint(instance_path="/volume", path="/volume")],
24
+ id="instance_mount",
25
+ ),
26
+ pytest.param(
27
+ [
28
+ VolumeMountPoint(
29
+ name="job${{dstack.job_num}}-rank${{dstack.node_rank}}", path="/volume"
30
+ )
31
+ ],
32
+ 2,
33
+ [VolumeMountPoint(name=["job2-rank2"], path="/volume")],
34
+ id="job_num_and_node_rank",
35
+ ),
36
+ ],
37
+ )
38
+ def test_interpolates_volumes(
39
+ self,
40
+ run_volumes: list[Union[MountPoint, str]],
41
+ job_num: int,
42
+ job_volumes: list[MountPoint],
43
+ ):
44
+ assert interpolate_job_volumes(run_volumes, job_num) == job_volumes
45
+
46
+ @pytest.mark.parametrize(
47
+ ["run_volumes", "job_num"],
48
+ [
49
+ pytest.param(
50
+ [VolumeMountPoint(name="${{}", path="/volume")],
51
+ 0,
52
+ id="invalid_syntax",
53
+ ),
54
+ pytest.param(
55
+ [VolumeMountPoint(name="${{ unknown.namespace }}", path="/volume")],
56
+ 0,
57
+ id="unknown_namespace",
58
+ ),
59
+ pytest.param(
60
+ [VolumeMountPoint(name="${{ dstack.var }}", path="/volume")],
61
+ 0,
62
+ id="unknown_var",
63
+ ),
64
+ ],
65
+ )
66
+ def test_raises_server_client_error(
67
+ self,
68
+ run_volumes: list[Union[MountPoint, str]],
69
+ job_num: int,
70
+ ):
71
+ with pytest.raises(ServerClientError):
72
+ assert interpolate_job_volumes(run_volumes, job_num)
@@ -9,7 +9,13 @@ from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT
9
9
  from dstack._internal.core.models.backends.base import BackendType
10
10
  from dstack._internal.core.models.common import NetworkMode
11
11
  from dstack._internal.core.models.resources import Memory
12
- from dstack._internal.core.models.volumes import InstanceMountPoint, VolumeMountPoint
12
+ from dstack._internal.core.models.volumes import (
13
+ InstanceMountPoint,
14
+ VolumeAttachment,
15
+ VolumeAttachmentData,
16
+ VolumeInstance,
17
+ VolumeMountPoint,
18
+ )
13
19
  from dstack._internal.server.schemas.runner import (
14
20
  HealthcheckResponse,
15
21
  JobResult,
@@ -133,7 +139,12 @@ class TestShimClientV1(BaseShimClientTest):
133
139
  volume_id="vol-id",
134
140
  configuration=get_volume_configuration(backend=BackendType.GCP),
135
141
  external=False,
136
- device_name="/dev/sdv",
142
+ attachments=[
143
+ VolumeAttachment(
144
+ instance=VolumeInstance(name="instance", instance_num=0, instance_id="i-1"),
145
+ attachment_data=VolumeAttachmentData(device_name="/dev/sdv"),
146
+ )
147
+ ],
137
148
  )
138
149
 
139
150
  submitted = client.submit(
@@ -150,6 +161,7 @@ class TestShimClientV1(BaseShimClientTest):
150
161
  mounts=[VolumeMountPoint(name="vol", path="/vol")],
151
162
  volumes=[volume],
152
163
  instance_mounts=[InstanceMountPoint(instance_path="/mnt/nfs/home", path="/home")],
164
+ instance_id="i-1",
153
165
  )
154
166
 
155
167
  assert submitted is True
@@ -198,6 +210,7 @@ class TestShimClientV1(BaseShimClientTest):
198
210
  mounts=[],
199
211
  volumes=[],
200
212
  instance_mounts=[],
213
+ instance_id="",
201
214
  )
202
215
 
203
216
  assert submitted is False
@@ -294,7 +307,12 @@ class TestShimClientV2(BaseShimClientTest):
294
307
  volume_id="vol-id",
295
308
  configuration=get_volume_configuration(backend=BackendType.GCP),
296
309
  external=False,
297
- device_name="/dev/sdv",
310
+ attachments=[
311
+ VolumeAttachment(
312
+ instance=VolumeInstance(name="instance", instance_num=0, instance_id="i-1"),
313
+ attachment_data=VolumeAttachmentData(device_name="/dev/sdv"),
314
+ )
315
+ ],
298
316
  )
299
317
 
300
318
  client.submit_task(
@@ -316,6 +334,7 @@ class TestShimClientV2(BaseShimClientTest):
316
334
  host_ssh_user="dstack",
317
335
  host_ssh_keys=["host_key"],
318
336
  container_ssh_keys=["project_key", "user_key"],
337
+ instance_id="i-1",
319
338
  )
320
339
 
321
340
  assert adapter.call_count == 2
@@ -0,0 +1,167 @@
1
+ from unittest.mock import Mock, patch
2
+
3
+ import pytest
4
+
5
+ from dstack._internal.core.models.backends.base import BackendType
6
+ from dstack._internal.core.models.profiles import Profile
7
+ from dstack._internal.core.models.resources import ResourcesSpec
8
+ from dstack._internal.core.models.runs import Requirements
9
+ from dstack._internal.server.services.offers import get_offers_by_requirements
10
+ from dstack._internal.server.testing.common import (
11
+ get_instance_offer_with_availability,
12
+ get_volume,
13
+ get_volume_configuration,
14
+ )
15
+
16
+
17
+ class TestGetOffersByRequirements:
18
+ @pytest.mark.asyncio
19
+ async def test_returns_all_offers(self):
20
+ profile = Profile(name="test")
21
+ requirements = Requirements(resources=ResourcesSpec())
22
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
23
+ aws_backend_mock = Mock()
24
+ aws_backend_mock.TYPE = BackendType.AWS
25
+ aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS)
26
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer]
27
+ runpod_backend_mock = Mock()
28
+ runpod_backend_mock.TYPE = BackendType.RUNPOD
29
+ runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD)
30
+ runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [
31
+ runpod_offer
32
+ ]
33
+ m.return_value = [aws_backend_mock, runpod_backend_mock]
34
+ res = await get_offers_by_requirements(
35
+ project=Mock(),
36
+ profile=profile,
37
+ requirements=requirements,
38
+ )
39
+ m.assert_awaited_once()
40
+ assert res == [(aws_backend_mock, aws_offer), (runpod_backend_mock, runpod_offer)]
41
+
42
+ @pytest.mark.asyncio
43
+ async def test_returns_multinode_offers(self):
44
+ profile = Profile(name="test")
45
+ requirements = Requirements(resources=ResourcesSpec())
46
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
47
+ aws_backend_mock = Mock()
48
+ aws_backend_mock.TYPE = BackendType.AWS
49
+ aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS)
50
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer]
51
+ runpod_backend_mock = Mock()
52
+ runpod_backend_mock.TYPE = BackendType.RUNPOD
53
+ runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD)
54
+ runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [
55
+ runpod_offer
56
+ ]
57
+ m.return_value = [aws_backend_mock, runpod_backend_mock]
58
+ res = await get_offers_by_requirements(
59
+ project=Mock(),
60
+ profile=profile,
61
+ requirements=requirements,
62
+ multinode=True,
63
+ )
64
+ m.assert_awaited_once()
65
+ assert res == [(aws_backend_mock, aws_offer)]
66
+
67
+ @pytest.mark.asyncio
68
+ async def test_returns_volume_offers(self):
69
+ profile = Profile(name="test")
70
+ requirements = Requirements(resources=ResourcesSpec())
71
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
72
+ aws_backend_mock = Mock()
73
+ aws_backend_mock.TYPE = BackendType.AWS
74
+ aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS)
75
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer]
76
+ runpod_backend_mock = Mock()
77
+ runpod_backend_mock.TYPE = BackendType.RUNPOD
78
+ runpod_offer1 = get_instance_offer_with_availability(
79
+ backend=BackendType.RUNPOD, region="eu"
80
+ )
81
+ runpod_offer2 = get_instance_offer_with_availability(
82
+ backend=BackendType.RUNPOD, region="us"
83
+ )
84
+ runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [
85
+ runpod_offer1,
86
+ runpod_offer2,
87
+ ]
88
+ m.return_value = [aws_backend_mock, runpod_backend_mock]
89
+ res = await get_offers_by_requirements(
90
+ project=Mock(),
91
+ profile=profile,
92
+ requirements=requirements,
93
+ volumes=[
94
+ [
95
+ get_volume(
96
+ configuration=get_volume_configuration(
97
+ backend=BackendType.RUNPOD, region="us"
98
+ )
99
+ )
100
+ ]
101
+ ],
102
+ )
103
+ m.assert_awaited_once()
104
+ assert res == [(runpod_backend_mock, runpod_offer2)]
105
+
106
+ @pytest.mark.asyncio
107
+ async def test_returns_az_offers(self):
108
+ profile = Profile(name="test", availability_zones=["az1", "az3"])
109
+ requirements = Requirements(resources=ResourcesSpec())
110
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
111
+ aws_backend_mock = Mock()
112
+ aws_backend_mock.TYPE = BackendType.AWS
113
+ aws_offer1 = get_instance_offer_with_availability(
114
+ backend=BackendType.AWS, availability_zones=["az1"]
115
+ )
116
+ aws_offer2 = get_instance_offer_with_availability(
117
+ backend=BackendType.AWS, availability_zones=["az2"]
118
+ )
119
+ aws_offer3 = get_instance_offer_with_availability(
120
+ backend=BackendType.AWS, availability_zones=["az2", "az3"]
121
+ )
122
+ expected_aws_offer3 = aws_offer3.copy()
123
+ expected_aws_offer3.availability_zones = ["az3"]
124
+ aws_offer4 = get_instance_offer_with_availability(
125
+ backend=BackendType.AWS, availability_zones=None
126
+ )
127
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [
128
+ aws_offer1,
129
+ aws_offer2,
130
+ aws_offer3,
131
+ aws_offer4,
132
+ ]
133
+ m.return_value = [aws_backend_mock]
134
+ res = await get_offers_by_requirements(
135
+ project=Mock(),
136
+ profile=profile,
137
+ requirements=requirements,
138
+ )
139
+ m.assert_awaited_once()
140
+ assert res == [(aws_backend_mock, aws_offer1), (aws_backend_mock, expected_aws_offer3)]
141
+
142
+ @pytest.mark.asyncio
143
+ async def test_returns_no_offers_for_multinode_instance_mounts_and_non_multinode_backend(self):
144
+ # Regression test for https://github.com/dstackai/dstack/issues/2211
145
+ profile = Profile(name="test", backends=[BackendType.RUNPOD])
146
+ requirements = Requirements(resources=ResourcesSpec())
147
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
148
+ aws_backend_mock = Mock()
149
+ aws_backend_mock.TYPE = BackendType.AWS
150
+ aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS)
151
+ aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer]
152
+ runpod_backend_mock = Mock()
153
+ runpod_backend_mock.TYPE = BackendType.RUNPOD
154
+ runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD)
155
+ runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [
156
+ runpod_offer
157
+ ]
158
+ m.return_value = [aws_backend_mock, runpod_backend_mock]
159
+ res = await get_offers_by_requirements(
160
+ project=Mock(),
161
+ profile=profile,
162
+ requirements=requirements,
163
+ multinode=True,
164
+ instance_mounts=True,
165
+ )
166
+ m.assert_awaited_once()
167
+ assert res == []