dstack 0.18.40rc1__py3-none-any.whl → 0.18.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +30 -9
  11. dstack/_internal/core/backends/aws/compute.py +94 -53
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +32 -24
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +23 -3
  28. dstack/_internal/core/models/volumes.py +26 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +73 -35
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +77 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +121 -49
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +14 -3
  42. dstack/_internal/server/background/tasks/process_volumes.py +11 -1
  43. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  44. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  45. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  46. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  47. dstack/_internal/server/migrations/versions/a751ef183f27_move_attachment_data_to_volumes_.py +34 -0
  48. dstack/_internal/server/models.py +27 -23
  49. dstack/_internal/server/routers/runs.py +1 -0
  50. dstack/_internal/server/schemas/runner.py +1 -0
  51. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  52. dstack/_internal/server/services/config.py +9 -0
  53. dstack/_internal/server/services/fleets.py +32 -3
  54. dstack/_internal/server/services/gateways/client.py +9 -1
  55. dstack/_internal/server/services/jobs/__init__.py +217 -45
  56. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  57. dstack/_internal/server/services/offers.py +96 -10
  58. dstack/_internal/server/services/pools.py +98 -14
  59. dstack/_internal/server/services/proxy/repo.py +17 -3
  60. dstack/_internal/server/services/runner/client.py +9 -6
  61. dstack/_internal/server/services/runner/ssh.py +33 -5
  62. dstack/_internal/server/services/runs.py +48 -179
  63. dstack/_internal/server/services/services/__init__.py +9 -1
  64. dstack/_internal/server/services/volumes.py +68 -9
  65. dstack/_internal/server/statics/index.html +1 -1
  66. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  67. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  68. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  69. dstack/_internal/server/testing/common.py +130 -61
  70. dstack/_internal/utils/common.py +22 -8
  71. dstack/_internal/utils/env.py +14 -0
  72. dstack/_internal/utils/ssh.py +1 -1
  73. dstack/api/server/_fleets.py +25 -1
  74. dstack/api/server/_runs.py +23 -2
  75. dstack/api/server/_volumes.py +12 -1
  76. dstack/version.py +1 -1
  77. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/METADATA +1 -1
  78. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/RECORD +104 -93
  79. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  80. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  81. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  82. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  83. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  84. tests/_internal/server/background/tasks/test_process_running_jobs.py +193 -0
  85. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  86. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +53 -6
  87. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +135 -17
  88. tests/_internal/server/routers/test_fleets.py +15 -2
  89. tests/_internal/server/routers/test_pools.py +6 -0
  90. tests/_internal/server/routers/test_runs.py +27 -0
  91. tests/_internal/server/routers/test_volumes.py +9 -2
  92. tests/_internal/server/services/jobs/__init__.py +0 -0
  93. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  94. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  95. tests/_internal/server/services/runner/test_client.py +22 -3
  96. tests/_internal/server/services/test_offers.py +167 -0
  97. tests/_internal/server/services/test_pools.py +109 -1
  98. tests/_internal/server/services/test_runs.py +5 -41
  99. tests/_internal/utils/test_common.py +21 -0
  100. tests/_internal/utils/test_env.py +38 -0
  101. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/LICENSE.md +0 -0
  102. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/WHEEL +0 -0
  103. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/entry_points.txt +0 -0
  104. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/top_level.txt +0 -0
dstack/_internal/server/services/pools.py

@@ -2,7 +2,7 @@ import ipaddress
 import uuid
 from collections.abc import Container, Iterable
 from datetime import datetime, timezone
-from typing import List, Optional
+from typing import List, Literal, Optional, Union
 
 import gpuhunt
 from sqlalchemy import and_, or_, select
@@ -30,6 +30,7 @@ from dstack._internal.core.models.instances import (
     InstanceType,
     RemoteConnectionInfo,
     Resources,
+    SSHConnectionParams,
     SSHKey,
 )
 from dstack._internal.core.models.pools import Instance, Pool, PoolInstances
@@ -51,6 +52,7 @@ from dstack._internal.server.models import (
     UserModel,
 )
 from dstack._internal.server.services.locking import get_locker
+from dstack._internal.server.services.offers import generate_shared_offer
 from dstack._internal.server.services.projects import list_project_models, list_user_project_models
 from dstack._internal.utils import common as common_utils
 from dstack._internal.utils import random_names
@@ -206,7 +208,7 @@ async def remove_instance(
     terminated = False
     for instance in pool.instances:
         if instance.name == instance_name:
-            if force or instance.job_id is None:
+            if force or not instance.jobs:
                 instance.status = InstanceStatus.TERMINATING
                 terminated = True
     await session.commit()
@@ -249,6 +251,8 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance:
         unreachable=instance_model.unreachable,
         termination_reason=instance_model.termination_reason,
         created=instance_model.created_at.replace(tzinfo=timezone.utc),
+        total_blocks=instance_model.total_blocks,
+        busy_blocks=instance_model.busy_blocks,
     )
 
     offer = get_instance_offer(instance_model)
@@ -261,9 +265,7 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance:
     if jpd is not None:
         instance.instance_type = jpd.instance_type
         instance.hostname = jpd.hostname
-
-    if instance_model.job is not None:
-        instance.job_name = instance_model.job.job_name
+        instance.availability_zone = jpd.availability_zone
 
     return instance
 
@@ -292,6 +294,27 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements:
     return Requirements.__response__.parse_raw(instance_model.requirements)
 
 
+def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, Optional[str]]:
+    """
+    Returns a pair of SSH private keys: host key and optional proxy jump key.
+    """
+    host_private_key = instance_model.project.ssh_private_key
+    if instance_model.remote_connection_info is None:
+        # Cloud instance
+        return host_private_key, None
+    # SSH instance
+    rci = RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info)
+    if rci.ssh_proxy is None:
+        return host_private_key, None
+    if rci.ssh_proxy_keys is None:
+        # Inconsistent RemoteConnectionInfo structure - proxy without keys
+        raise ValueError("Missing instance SSH proxy private keys")
+    proxy_private_keys = [key.private for key in rci.ssh_proxy_keys if key.private is not None]
+    if not proxy_private_keys:
+        raise ValueError("No instance SSH proxy private key found")
+    return host_private_key, proxy_private_keys[0]
+
+
 async def generate_instance_name(
     session: AsyncSession,
     project: ProjectModel,
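
The new get_instance_ssh_private_keys helper returns a (host key, proxy key) pair so callers can thread an optional proxy-jump key to the tunnel layer. A minimal, self-contained sketch of the same selection logic, using hypothetical stand-in types rather than dstack's InstanceModel/RemoteConnectionInfo (illustrative only, not part of the diff):

from dataclasses import dataclass
from typing import Optional

@dataclass
class StubKey:                       # stand-in for SSHKey
    private: Optional[str]

@dataclass
class StubRci:                       # stand-in for RemoteConnectionInfo
    ssh_proxy: Optional[str]         # jump-host params, simplified to a string
    ssh_proxy_keys: Optional[list]   # list[StubKey]

def pick_keys(host_key: str, rci: Optional[StubRci]) -> tuple[str, Optional[str]]:
    if rci is None or rci.ssh_proxy is None:
        # Cloud instance, or SSH instance without a jump host
        return host_key, None
    keys = [k.private for k in (rci.ssh_proxy_keys or []) if k.private is not None]
    if not keys:
        raise ValueError("No instance SSH proxy private key found")
    return host_key, keys[0]         # host key plus the first usable proxy key

assert pick_keys("HOST", None) == ("HOST", None)
assert pick_keys("HOST", StubRci("jump@10.0.0.1", [StubKey("PROXY")])) == ("HOST", "PROXY")

The resulting pair matches the PrivateKeyOrPair shape that the updated runner_ssh_tunnel decorator in runner/ssh.py (further below) accepts.
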
@@ -409,20 +432,21 @@ async def add_remote(
 def filter_pool_instances(
     pool_instances: List[InstanceModel],
     profile: Profile,
-    requirements: Requirements,
     *,
+    requirements: Optional[Requirements] = None,
     status: Optional[InstanceStatus] = None,
     fleet_model: Optional[FleetModel] = None,
     multinode: bool = False,
     master_job_provisioning_data: Optional[JobProvisioningData] = None,
     volumes: Optional[List[List[Volume]]] = None,
+    shared: bool = False,
 ) -> List[InstanceModel]:
     instances: List[InstanceModel] = []
     candidates: List[InstanceModel] = []
 
     backend_types = profile.backends
     regions = profile.regions
-    zones = None
+    zones = profile.availability_zones
 
     if volumes:
         mount_point_volumes = volumes[0]
@@ -433,23 +457,24 @@
             for v in mount_point_volumes
             if v.provisioning_data is not None
         ]
-        if volume_zones:
+        if zones is None:
             zones = volume_zones
+        zones = [z for z in zones if z in volume_zones]
 
     if multinode:
-        if not backend_types:
+        if backend_types is None:
             backend_types = BACKENDS_WITH_MULTINODE_SUPPORT
         backend_types = [b for b in backend_types if b in BACKENDS_WITH_MULTINODE_SUPPORT]
 
     # For multi-node, restrict backend and region.
     # The default behavior is to provision all nodes in the same backend and region.
     if master_job_provisioning_data is not None:
-        if not backend_types:
+        if backend_types is None:
             backend_types = [master_job_provisioning_data.get_base_backend()]
         backend_types = [
             b for b in backend_types if b == master_job_provisioning_data.get_base_backend()
         ]
-        if not regions:
+        if regions is None:
             regions = [master_job_provisioning_data.region]
         regions = [r for r in regions if r == master_job_provisioning_data.region]
 
@@ -480,9 +505,17 @@
             and jpd.availability_zone not in zones
         ):
             continue
+        if instance.total_blocks is None:
+            # Still provisioning, we don't know yet if it shared or not
+            continue
+        if (instance.total_blocks > 1) != shared:
+            continue
 
         candidates.append(instance)
 
+    if requirements is None:
+        return candidates
+
    query_filter = requirements_to_query_filter(requirements)
    for instance in candidates:
        if instance.offer is None:
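
Three behavior changes above are easy to miss: zones now start from profile.availability_zones and are intersected with the zones of the run's volumes; the `if not x:` guards became `if x is None:`, so an explicitly empty list is respected instead of being treated as "unset"; and instances only become candidates once total_blocks is known, with multi-block ("shared") instances matched only when shared=True is requested. A self-contained sketch of the zone-narrowing rule (names are illustrative, not part of the diff):

from typing import Optional

def narrow_zones(profile_zones: Optional[list], volume_zones: list) -> list:
    # None means "no preference": the volumes decide; otherwise intersect.
    zones = profile_zones
    if zones is None:
        zones = volume_zones
    return [z for z in zones if z in volume_zones]

assert narrow_zones(None, ["us-east-1a"]) == ["us-east-1a"]
assert narrow_zones(["us-east-1a", "us-east-1b"], ["us-east-1b"]) == ["us-east-1b"]
assert narrow_zones(["us-east-1a"], ["us-east-1c"]) == []  # disjoint -> no candidates
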
@@ -494,6 +527,47 @@
     return instances
 
 
+def get_shared_pool_instances_with_offers(
+    pool_instances: List[InstanceModel],
+    profile: Profile,
+    requirements: Requirements,
+    *,
+    idle_only: bool = False,
+    fleet_model: Optional[FleetModel] = None,
+    volumes: Optional[List[List[Volume]]] = None,
+) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]:
+    instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]] = []
+    query_filter = requirements_to_query_filter(requirements)
+    filtered_instances = filter_pool_instances(
+        pool_instances=pool_instances,
+        profile=profile,
+        fleet_model=fleet_model,
+        multinode=False,
+        volumes=volumes,
+        shared=True,
+    )
+    for instance in filtered_instances:
+        if idle_only and instance.status not in [InstanceStatus.IDLE, InstanceStatus.BUSY]:
+            continue
+        offer = get_instance_offer(instance)
+        if offer is None:
+            continue
+        total_blocks = common_utils.get_or_error(instance.total_blocks)
+        idle_blocks = total_blocks - instance.busy_blocks
+        for blocks in range(1, total_blocks + 1):
+            shared_offer = generate_shared_offer(offer, blocks, total_blocks)
+            catalog_item = offer_to_catalog_item(shared_offer)
+            if gpuhunt.matches(catalog_item, query_filter):
+                if blocks <= idle_blocks:
+                    shared_offer.availability = InstanceAvailability.IDLE
+                else:
+                    shared_offer.availability = InstanceAvailability.BUSY
+                if shared_offer.availability == InstanceAvailability.IDLE or not idle_only:
+                    instances_with_offers.append((instance, shared_offer))
+                break
+    return instances_with_offers
+
+
 async def list_pools_instance_models(
     session: AsyncSession,
     projects: List[ProjectModel],
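
The loop above tries block counts from smallest to largest and keeps the first shared offer that matches the requirements, marking it IDLE only if that many blocks are currently free. A rough, self-contained sketch of that selection order, assuming (as the name suggests, though not shown in this diff) that generate_shared_offer scales the instance's resources proportionally to blocks / total_blocks:

from typing import Optional

def pick_blocks(
    total_blocks: int, busy_blocks: int, needed_gpus: int, gpus_per_block: int
) -> Optional[tuple]:
    # Smallest block count whose share of the instance satisfies the request.
    idle_blocks = total_blocks - busy_blocks
    for blocks in range(1, total_blocks + 1):
        if blocks * gpus_per_block >= needed_gpus:
            return blocks, "idle" if blocks <= idle_blocks else "busy"
    return None

# An 8-GPU instance split into 4 blocks of 2 GPUs each:
assert pick_blocks(4, busy_blocks=1, needed_gpus=4, gpus_per_block=2) == (2, "idle")
assert pick_blocks(4, busy_blocks=3, needed_gpus=4, gpus_per_block=2) == (2, "busy")
assert pick_blocks(4, busy_blocks=0, needed_gpus=9, gpus_per_block=2) is None
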
@@ -557,7 +631,7 @@ async def list_pools_instance_models(
         .limit(limit)
         .options(joinedload(InstanceModel.pool), joinedload(InstanceModel.fleet))
     )
-    instance_models = list(res.scalars().all())
+    instance_models = list(res.unique().scalars().all())
     return instance_models
 
 
@@ -618,7 +692,7 @@ async def list_active_remote_instances(
     res = await session.execute(
         select(InstanceModel).where(*filters).order_by(InstanceModel.created_at.asc())
     )
-    instance_models = list(res.scalars().all())
+    instance_models = list(res.unique().scalars().all())
     return instance_models
 
 
@@ -633,6 +707,7 @@ async def create_instance_model(
     instance_num: int,
     placement_group_name: Optional[str],
     reservation: Optional[str],
+    blocks: Union[Literal["auto"], int],
 ) -> InstanceModel:
     termination_policy, termination_idle_time = get_termination(
         profile, DEFAULT_POOL_TERMINATION_IDLE_TIME
@@ -646,8 +721,8 @@
         project_name=project.name,
         instance_name=instance_name,
         user=user.name,
-        instance_id=str(instance_id),
         ssh_keys=[project_ssh_key],
+        instance_id=str(instance_id),
         placement_group_name=placement_group_name,
         reservation=reservation,
     )
@@ -665,6 +740,8 @@
         instance_configuration=instance_config.json(),
         termination_policy=termination_policy,
         termination_idle_time=termination_idle_time,
+        total_blocks=None if blocks == "auto" else blocks,
+        busy_blocks=0,
     )
     session.add(instance)
     return instance
@@ -682,7 +759,10 @@ async def create_ssh_instance_model(
     port: int,
     ssh_user: str,
     ssh_keys: List[SSHKey],
+    ssh_proxy: Optional[SSHConnectionParams],
+    ssh_proxy_keys: Optional[list[SSHKey]],
     env: Env,
+    blocks: Union[Literal["auto"], int],
 ) -> InstanceModel:
     # TODO: doc - will overwrite after remote connected
     instance_resource = Resources(cpus=2, memory_mib=8, gpus=[], spot=False)
@@ -717,6 +797,8 @@
         port=port,
         ssh_user=ssh_user,
         ssh_keys=ssh_keys,
+        ssh_proxy=ssh_proxy,
+        ssh_proxy_keys=ssh_proxy_keys,
         env=env,
     )
     im = InstanceModel(
@@ -737,5 +819,7 @@
         price=offer.price,
         termination_policy=TerminationPolicy.DONT_DESTROY,
         termination_idle_time=0,
+        total_blocks=None if blocks == "auto" else blocks,
+        busy_blocks=0,
     )
     return im
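
Both constructors above normalize the new blocks argument the same way: "auto" leaves total_blocks unset (None) until provisioning reports the instance's real resources, while an integer fixes the block count up front; busy_blocks always starts at 0. A trivial sketch of that mapping (illustrative only, not part of the diff):

from typing import Literal, Optional, Union

def to_total_blocks(blocks: Union[Literal["auto"], int]) -> Optional[int]:
    # "auto": defer the block count until the instance's resources are known.
    return None if blocks == "auto" else blocks

assert to_total_blocks("auto") is None
assert to_total_blocks(4) == 4
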
dstack/_internal/server/services/proxy/repo.py

@@ -9,7 +9,7 @@ import dstack._internal.server.services.jobs as jobs_services
 from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
 from dstack._internal.core.models.common import is_core_model_instance
 from dstack._internal.core.models.configurations import ServiceConfiguration
-from dstack._internal.core.models.instances import SSHConnectionParams
+from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
 from dstack._internal.core.models.runs import (
     JobProvisioningData,
     JobStatus,
@@ -30,6 +30,7 @@ from dstack._internal.proxy.lib.models import (
 from dstack._internal.proxy.lib.repo import BaseProxyRepo
 from dstack._internal.server.models import JobModel, ProjectModel, RunModel
 from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
+from dstack._internal.utils.common import get_or_error
 
 
 class ServerProxyRepo(BaseProxyRepo):
@@ -53,9 +54,12 @@ class ServerProxyRepo(BaseProxyRepo):
                 JobModel.status == JobStatus.RUNNING,
                 JobModel.job_num == 0,
             )
-            .options(joinedload(JobModel.run))
+            .options(
+                joinedload(JobModel.run),
+                joinedload(JobModel.instance),
+            )
         )
-        jobs = res.scalars().all()
+        jobs = res.unique().scalars().all()
         if not len(jobs):
             return None
         run = jobs[0].run
@@ -83,12 +87,22 @@
                 username=jpd.username,
                 port=jpd.ssh_port,
             )
+            ssh_head_proxy: Optional[SSHConnectionParams] = None
+            ssh_head_proxy_private_key: Optional[str] = None
+            instance = get_or_error(job.instance)
+            if instance.remote_connection_info is not None:
+                rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
+                if rci.ssh_proxy is not None:
+                    ssh_head_proxy = rci.ssh_proxy
+                    ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
             replica = Replica(
                 id=job.id.hex,
                 app_port=run_spec.configuration.port.container_port,
                 ssh_destination=ssh_destination,
                 ssh_port=ssh_port,
                 ssh_proxy=ssh_proxy,
+                ssh_head_proxy=ssh_head_proxy,
+                ssh_head_proxy_private_key=ssh_head_proxy_private_key,
             )
             replicas.append(replica)
         return Service(
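
The replica-building code above now resolves an optional "head proxy" per replica: only jobs placed on SSH-fleet instances (those carrying remote_connection_info) that sit behind a jump host get ssh_head_proxy and ssh_head_proxy_private_key; cloud-backed replicas leave both None. A self-contained sketch of that derivation with stand-in types (not dstack's models):

from dataclasses import dataclass
from typing import Optional

@dataclass
class StubRci:                        # stand-in for RemoteConnectionInfo
    ssh_proxy: Optional[str]          # jump-host params, simplified
    ssh_proxy_keys: Optional[list]    # private keys, simplified to strings

def head_proxy(rci: Optional[StubRci]) -> tuple:
    # Cloud instance (no rci) or SSH instance without a jump host: nothing to add.
    if rci is None or rci.ssh_proxy is None:
        return None, None
    return rci.ssh_proxy, rci.ssh_proxy_keys[0]

assert head_proxy(None) == (None, None)
assert head_proxy(StubRci(None, None)) == (None, None)
assert head_proxy(StubRci("jump@gw", ["KEY"])) == ("jump@gw", "KEY")
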
dstack/_internal/server/services/runner/client.py

@@ -30,7 +30,7 @@ from dstack._internal.server.schemas.runner import (
 from dstack._internal.utils.common import get_or_error
 from dstack._internal.utils.logging import get_logger
 
-REQUEST_TIMEOUT = 15
+REQUEST_TIMEOUT = 9
 
 logger = get_logger(__name__)
 
@@ -239,6 +239,7 @@ class ShimClient:
         host_ssh_user: str,
         host_ssh_keys: list[str],
         container_ssh_keys: list[str],
+        instance_id: str,
     ) -> None:
         if not self.is_api_v2_supported():
             raise ShimAPIVersionError()
@@ -255,7 +256,7 @@
             memory=_memory_to_bytes(memory),  # None = 0 = "all available"
             shm_size=_memory_to_bytes(shm_size),  # None = 0 = "use default value"
             network_mode=network_mode,
-            volumes=[_volume_to_shim_volume_info(v) for v in volumes],
+            volumes=[_volume_to_shim_volume_info(v, instance_id) for v in volumes],
             volume_mounts=volume_mounts,
             instance_mounts=instance_mounts,
             host_ssh_user=host_ssh_user,
@@ -303,6 +304,7 @@
         mounts: List[VolumeMountPoint],
         volumes: List[Volume],
         instance_mounts: List[InstanceMountPoint],
+        instance_id: str,
     ) -> bool:
         """
         Returns `True` if submitted and `False` if the shim already has a job (`409 Conflict`).
@@ -320,7 +322,7 @@
             ssh_user=ssh_user,
             ssh_key=ssh_key,
             mounts=mounts,
-            volumes=[_volume_to_shim_volume_info(v) for v in volumes],
+            volumes=[_volume_to_shim_volume_info(v, instance_id) for v in volumes],
             instance_mounts=instance_mounts,
         )
         resp = self._request("POST", "/api/submit", body)
@@ -398,10 +400,11 @@ def health_response_to_health_status(data: HealthcheckResponse) -> HealthStatus:
     )
 
 
-def _volume_to_shim_volume_info(volume: Volume) -> ShimVolumeInfo:
+def _volume_to_shim_volume_info(volume: Volume, instance_id: str) -> ShimVolumeInfo:
     device_name = None
-    if volume.attachment_data is not None:
-        device_name = volume.attachment_data.device_name
+    attachment_data = volume.get_attachment_data_for_instance(instance_id)
+    if attachment_data is not None:
+        device_name = attachment_data.device_name
     return ShimVolumeInfo(
         backend=volume.configuration.backend.value,
         name=volume.name,
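
The _volume_to_shim_volume_info change reflects a broader model change in this release (see the a751ef183f27_move_attachment_data_to_volumes_ migration and the volumes.py updates in the file list): a volume no longer carries a single attachment_data but per-instance attachments, so the device name must be looked up for the specific instance. A self-contained sketch of that lookup with hypothetical stand-in types (not part of the diff):

from dataclasses import dataclass
from typing import Optional

@dataclass
class StubAttachment:         # stand-in for a per-instance volume attachment
    instance_id: str
    device_name: str

def device_name_for(attachments: list, instance_id: str) -> Optional[str]:
    # One volume, many instances: pick the attachment for this instance.
    for att in attachments:
        if att.instance_id == instance_id:
            return att.device_name
    return None

attachments = [StubAttachment("i-1", "/dev/sdf"), StubAttachment("i-2", "/dev/sdg")]
assert device_name_for(attachments, "i-2") == "/dev/sdg"
assert device_name_for(attachments, "i-3") is None  # not attached to this instance
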
dstack/_internal/server/services/runner/ssh.py

@@ -17,22 +17,36 @@ from dstack._internal.utils.path import FileContent
 logger = get_logger(__name__)
 P = ParamSpec("P")
 R = TypeVar("R")
+# A host private key or pair of (host private key, optional proxy jump private key)
+PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]]
 
 
 def runner_ssh_tunnel(
     ports: List[int], retries: int = 3, retry_interval: float = 1
 ) -> Callable[
     [Callable[Concatenate[Dict[int, int], P], R]],
-    Callable[Concatenate[str, JobProvisioningData, Optional[JobRuntimeData], P], Union[bool, R]],
+    Callable[
+        Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P],
+        Union[bool, R],
+    ],
 ]:
+    """
+    A decorator that opens an SSH tunnel to the runner.
+
+    NOTE: connections from dstack-server to running jobs are expected to be short.
+    The runner uses a heuristic to differentiate dstack-server connections from
+    client connections based on their duration. See `ConnectionTracker` for details.
+    """
+
     def decorator(
         func: Callable[Concatenate[Dict[int, int], P], R],
     ) -> Callable[
-        Concatenate[str, JobProvisioningData, Optional[JobRuntimeData], P], Union[bool, R]
+        Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P],
+        Union[bool, R],
     ]:
         @functools.wraps(func)
         def wrapper(
-            ssh_private_key: str,
+            ssh_private_key: PrivateKeyOrPair,
             job_provisioning_data: JobProvisioningData,
             job_runtime_data: Optional[JobRuntimeData],
             *args: P.args,
@@ -51,6 +65,20 @@ def runner_ssh_tunnel(
                 # without SSH
                 return func(container_ports_map, *args, **kwargs)
 
+            if isinstance(ssh_private_key, str):
+                ssh_proxy_private_key = None
+            else:
+                ssh_private_key, ssh_proxy_private_key = ssh_private_key
+            identity = FileContent(ssh_private_key)
+            if ssh_proxy_private_key is not None:
+                proxy_identity = FileContent(ssh_proxy_private_key)
+            else:
+                proxy_identity = None
+
+            ssh_proxies = []
+            if job_provisioning_data.ssh_proxy is not None:
+                ssh_proxies.append((job_provisioning_data.ssh_proxy, proxy_identity))
+
             for attempt in range(retries):
                 last = attempt == retries - 1
                 # remote_host:local mapping
@@ -66,8 +94,8 @@
                     ),
                     port=job_provisioning_data.ssh_port,
                    forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map),
-                    identity=FileContent(ssh_private_key),
-                    ssh_proxy=job_provisioning_data.ssh_proxy,
+                    identity=identity,
+                    ssh_proxies=ssh_proxies,
                 ):
                     return func(runner_ports_map, *args, **kwargs)
             except SSHError:
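
Callers of functions decorated with runner_ssh_tunnel can now pass either a bare host key (cloud instances) or a (host key, proxy key) pair, e.g. the pair returned by get_instance_ssh_private_keys in pools.py above. A self-contained sketch of the unpacking rule the wrapper applies (names illustrative, not part of the diff):

from typing import Optional, Union

PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]]

def unpack_keys(key: PrivateKeyOrPair) -> tuple[str, Optional[str]]:
    # A bare string means "host key only"; a pair may carry a proxy jump key.
    if isinstance(key, str):
        return key, None
    return key

assert unpack_keys("HOST") == ("HOST", None)
assert unpack_keys(("HOST", "PROXY")) == ("HOST", "PROXY")
assert unpack_keys(("HOST", None)) == ("HOST", None)

When present, the proxy identity is paired with job_provisioning_data.ssh_proxy, which is how jobs on SSH-fleet instances behind a jump host stay reachable.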