dstack 0.18.40rc1__py3-none-any.whl → 0.18.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dstack/_internal/cli/commands/apply.py +8 -5
- dstack/_internal/cli/services/configurators/base.py +4 -2
- dstack/_internal/cli/services/configurators/fleet.py +21 -9
- dstack/_internal/cli/services/configurators/gateway.py +15 -0
- dstack/_internal/cli/services/configurators/run.py +6 -5
- dstack/_internal/cli/services/configurators/volume.py +15 -0
- dstack/_internal/cli/services/repos.py +3 -3
- dstack/_internal/cli/utils/fleet.py +44 -33
- dstack/_internal/cli/utils/run.py +27 -7
- dstack/_internal/cli/utils/volume.py +21 -9
- dstack/_internal/core/backends/aws/compute.py +92 -52
- dstack/_internal/core/backends/aws/resources.py +22 -12
- dstack/_internal/core/backends/azure/compute.py +2 -0
- dstack/_internal/core/backends/base/compute.py +20 -2
- dstack/_internal/core/backends/gcp/compute.py +30 -23
- dstack/_internal/core/backends/gcp/resources.py +0 -15
- dstack/_internal/core/backends/oci/compute.py +10 -5
- dstack/_internal/core/backends/oci/resources.py +23 -26
- dstack/_internal/core/backends/remote/provisioning.py +65 -27
- dstack/_internal/core/backends/runpod/compute.py +1 -0
- dstack/_internal/core/models/backends/azure.py +3 -1
- dstack/_internal/core/models/configurations.py +24 -1
- dstack/_internal/core/models/fleets.py +46 -0
- dstack/_internal/core/models/instances.py +5 -1
- dstack/_internal/core/models/pools.py +4 -1
- dstack/_internal/core/models/profiles.py +10 -4
- dstack/_internal/core/models/runs.py +20 -0
- dstack/_internal/core/models/volumes.py +3 -0
- dstack/_internal/core/services/ssh/attach.py +92 -53
- dstack/_internal/core/services/ssh/tunnel.py +58 -31
- dstack/_internal/proxy/gateway/routers/registry.py +2 -0
- dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
- dstack/_internal/proxy/gateway/services/registry.py +4 -0
- dstack/_internal/proxy/lib/models.py +3 -0
- dstack/_internal/proxy/lib/services/service_connection.py +8 -1
- dstack/_internal/server/background/tasks/process_instances.py +72 -33
- dstack/_internal/server/background/tasks/process_metrics.py +9 -9
- dstack/_internal/server/background/tasks/process_running_jobs.py +73 -26
- dstack/_internal/server/background/tasks/process_runs.py +2 -12
- dstack/_internal/server/background/tasks/process_submitted_jobs.py +109 -42
- dstack/_internal/server/background/tasks/process_terminating_jobs.py +1 -1
- dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
- dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
- dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
- dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
- dstack/_internal/server/models.py +10 -4
- dstack/_internal/server/routers/runs.py +1 -0
- dstack/_internal/server/schemas/runner.py +1 -0
- dstack/_internal/server/services/backends/configurators/azure.py +34 -8
- dstack/_internal/server/services/config.py +9 -0
- dstack/_internal/server/services/fleets.py +27 -2
- dstack/_internal/server/services/gateways/client.py +9 -1
- dstack/_internal/server/services/jobs/__init__.py +215 -43
- dstack/_internal/server/services/jobs/configurators/base.py +47 -2
- dstack/_internal/server/services/offers.py +91 -5
- dstack/_internal/server/services/pools.py +95 -11
- dstack/_internal/server/services/proxy/repo.py +17 -3
- dstack/_internal/server/services/runner/client.py +1 -1
- dstack/_internal/server/services/runner/ssh.py +33 -5
- dstack/_internal/server/services/runs.py +48 -179
- dstack/_internal/server/services/services/__init__.py +9 -1
- dstack/_internal/server/statics/index.html +1 -1
- dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
- dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
- dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
- dstack/_internal/server/testing/common.py +117 -52
- dstack/_internal/utils/common.py +22 -8
- dstack/_internal/utils/env.py +14 -0
- dstack/_internal/utils/ssh.py +1 -1
- dstack/api/server/_fleets.py +25 -1
- dstack/api/server/_runs.py +23 -2
- dstack/api/server/_volumes.py +12 -1
- dstack/version.py +1 -1
- {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/METADATA +1 -1
- {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/RECORD +98 -89
- tests/_internal/cli/services/configurators/test_profile.py +3 -3
- tests/_internal/core/services/ssh/test_tunnel.py +56 -4
- tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
- tests/_internal/server/background/tasks/test_process_instances.py +138 -20
- tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
- tests/_internal/server/background/tasks/test_process_running_jobs.py +192 -0
- tests/_internal/server/background/tasks/test_process_runs.py +27 -3
- tests/_internal/server/background/tasks/test_process_submitted_jobs.py +48 -3
- tests/_internal/server/background/tasks/test_process_terminating_jobs.py +126 -13
- tests/_internal/server/routers/test_fleets.py +15 -2
- tests/_internal/server/routers/test_pools.py +6 -0
- tests/_internal/server/routers/test_runs.py +27 -0
- tests/_internal/server/services/jobs/__init__.py +0 -0
- tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
- tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
- tests/_internal/server/services/test_pools.py +4 -0
- tests/_internal/server/services/test_runs.py +5 -41
- tests/_internal/utils/test_common.py +21 -0
- tests/_internal/utils/test_env.py +38 -0
- {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/LICENSE.md +0 -0
- {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/WHEEL +0 -0
- {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/entry_points.txt +0 -0
- {dstack-0.18.40rc1.dist-info → dstack-0.18.41.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,7 @@ import ipaddress
|
|
|
2
2
|
import uuid
|
|
3
3
|
from collections.abc import Container, Iterable
|
|
4
4
|
from datetime import datetime, timezone
|
|
5
|
-
from typing import List, Optional
|
|
5
|
+
from typing import List, Literal, Optional, Union
|
|
6
6
|
|
|
7
7
|
import gpuhunt
|
|
8
8
|
from sqlalchemy import and_, or_, select
|
|
@@ -30,6 +30,7 @@ from dstack._internal.core.models.instances import (
|
|
|
30
30
|
InstanceType,
|
|
31
31
|
RemoteConnectionInfo,
|
|
32
32
|
Resources,
|
|
33
|
+
SSHConnectionParams,
|
|
33
34
|
SSHKey,
|
|
34
35
|
)
|
|
35
36
|
from dstack._internal.core.models.pools import Instance, Pool, PoolInstances
|
|
@@ -51,6 +52,7 @@ from dstack._internal.server.models import (
|
|
|
51
52
|
UserModel,
|
|
52
53
|
)
|
|
53
54
|
from dstack._internal.server.services.locking import get_locker
|
|
55
|
+
from dstack._internal.server.services.offers import generate_shared_offer
|
|
54
56
|
from dstack._internal.server.services.projects import list_project_models, list_user_project_models
|
|
55
57
|
from dstack._internal.utils import common as common_utils
|
|
56
58
|
from dstack._internal.utils import random_names
|
|
@@ -206,7 +208,7 @@ async def remove_instance(
|
|
|
206
208
|
terminated = False
|
|
207
209
|
for instance in pool.instances:
|
|
208
210
|
if instance.name == instance_name:
|
|
209
|
-
if force or instance.
|
|
211
|
+
if force or not instance.jobs:
|
|
210
212
|
instance.status = InstanceStatus.TERMINATING
|
|
211
213
|
terminated = True
|
|
212
214
|
await session.commit()
|
|
@@ -249,6 +251,8 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance:
|
|
|
249
251
|
unreachable=instance_model.unreachable,
|
|
250
252
|
termination_reason=instance_model.termination_reason,
|
|
251
253
|
created=instance_model.created_at.replace(tzinfo=timezone.utc),
|
|
254
|
+
total_blocks=instance_model.total_blocks,
|
|
255
|
+
busy_blocks=instance_model.busy_blocks,
|
|
252
256
|
)
|
|
253
257
|
|
|
254
258
|
offer = get_instance_offer(instance_model)
|
|
@@ -261,9 +265,7 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance:
|
|
|
261
265
|
if jpd is not None:
|
|
262
266
|
instance.instance_type = jpd.instance_type
|
|
263
267
|
instance.hostname = jpd.hostname
|
|
264
|
-
|
|
265
|
-
if instance_model.job is not None:
|
|
266
|
-
instance.job_name = instance_model.job.job_name
|
|
268
|
+
instance.availability_zone = jpd.availability_zone
|
|
267
269
|
|
|
268
270
|
return instance
|
|
269
271
|
|
|
@@ -292,6 +294,27 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements:
|
|
|
292
294
|
return Requirements.__response__.parse_raw(instance_model.requirements)
|
|
293
295
|
|
|
294
296
|
|
|
297
|
+
def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, Optional[str]]:
|
|
298
|
+
"""
|
|
299
|
+
Returns a pair of SSH private keys: host key and optional proxy jump key.
|
|
300
|
+
"""
|
|
301
|
+
host_private_key = instance_model.project.ssh_private_key
|
|
302
|
+
if instance_model.remote_connection_info is None:
|
|
303
|
+
# Cloud instance
|
|
304
|
+
return host_private_key, None
|
|
305
|
+
# SSH instance
|
|
306
|
+
rci = RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info)
|
|
307
|
+
if rci.ssh_proxy is None:
|
|
308
|
+
return host_private_key, None
|
|
309
|
+
if rci.ssh_proxy_keys is None:
|
|
310
|
+
# Inconsistent RemoteConnectionInfo structure - proxy without keys
|
|
311
|
+
raise ValueError("Missing instance SSH proxy private keys")
|
|
312
|
+
proxy_private_keys = [key.private for key in rci.ssh_proxy_keys if key.private is not None]
|
|
313
|
+
if not proxy_private_keys:
|
|
314
|
+
raise ValueError("No instance SSH proxy private key found")
|
|
315
|
+
return host_private_key, proxy_private_keys[0]
|
|
316
|
+
|
|
317
|
+
|
|
295
318
|
async def generate_instance_name(
|
|
296
319
|
session: AsyncSession,
|
|
297
320
|
project: ProjectModel,
|
|
@@ -409,20 +432,21 @@ async def add_remote(
|
|
|
409
432
|
def filter_pool_instances(
|
|
410
433
|
pool_instances: List[InstanceModel],
|
|
411
434
|
profile: Profile,
|
|
412
|
-
requirements: Requirements,
|
|
413
435
|
*,
|
|
436
|
+
requirements: Optional[Requirements] = None,
|
|
414
437
|
status: Optional[InstanceStatus] = None,
|
|
415
438
|
fleet_model: Optional[FleetModel] = None,
|
|
416
439
|
multinode: bool = False,
|
|
417
440
|
master_job_provisioning_data: Optional[JobProvisioningData] = None,
|
|
418
441
|
volumes: Optional[List[List[Volume]]] = None,
|
|
442
|
+
shared: bool = False,
|
|
419
443
|
) -> List[InstanceModel]:
|
|
420
444
|
instances: List[InstanceModel] = []
|
|
421
445
|
candidates: List[InstanceModel] = []
|
|
422
446
|
|
|
423
447
|
backend_types = profile.backends
|
|
424
448
|
regions = profile.regions
|
|
425
|
-
zones =
|
|
449
|
+
zones = profile.availability_zones
|
|
426
450
|
|
|
427
451
|
if volumes:
|
|
428
452
|
mount_point_volumes = volumes[0]
|
|
@@ -433,8 +457,9 @@ def filter_pool_instances(
|
|
|
433
457
|
for v in mount_point_volumes
|
|
434
458
|
if v.provisioning_data is not None
|
|
435
459
|
]
|
|
436
|
-
if
|
|
460
|
+
if zones is None:
|
|
437
461
|
zones = volume_zones
|
|
462
|
+
zones = [z for z in zones if z in volume_zones]
|
|
438
463
|
|
|
439
464
|
if multinode:
|
|
440
465
|
if not backend_types:
|
|
@@ -480,9 +505,17 @@ def filter_pool_instances(
|
|
|
480
505
|
and jpd.availability_zone not in zones
|
|
481
506
|
):
|
|
482
507
|
continue
|
|
508
|
+
if instance.total_blocks is None:
|
|
509
|
+
# Still provisioning, we don't know yet if it shared or not
|
|
510
|
+
continue
|
|
511
|
+
if (instance.total_blocks > 1) != shared:
|
|
512
|
+
continue
|
|
483
513
|
|
|
484
514
|
candidates.append(instance)
|
|
485
515
|
|
|
516
|
+
if requirements is None:
|
|
517
|
+
return candidates
|
|
518
|
+
|
|
486
519
|
query_filter = requirements_to_query_filter(requirements)
|
|
487
520
|
for instance in candidates:
|
|
488
521
|
if instance.offer is None:
|
|
@@ -494,6 +527,47 @@ def filter_pool_instances(
|
|
|
494
527
|
return instances
|
|
495
528
|
|
|
496
529
|
|
|
530
|
+
def get_shared_pool_instances_with_offers(
|
|
531
|
+
pool_instances: List[InstanceModel],
|
|
532
|
+
profile: Profile,
|
|
533
|
+
requirements: Requirements,
|
|
534
|
+
*,
|
|
535
|
+
idle_only: bool = False,
|
|
536
|
+
fleet_model: Optional[FleetModel] = None,
|
|
537
|
+
volumes: Optional[List[List[Volume]]] = None,
|
|
538
|
+
) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]:
|
|
539
|
+
instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]] = []
|
|
540
|
+
query_filter = requirements_to_query_filter(requirements)
|
|
541
|
+
filtered_instances = filter_pool_instances(
|
|
542
|
+
pool_instances=pool_instances,
|
|
543
|
+
profile=profile,
|
|
544
|
+
fleet_model=fleet_model,
|
|
545
|
+
multinode=False,
|
|
546
|
+
volumes=volumes,
|
|
547
|
+
shared=True,
|
|
548
|
+
)
|
|
549
|
+
for instance in filtered_instances:
|
|
550
|
+
if idle_only and instance.status not in [InstanceStatus.IDLE, InstanceStatus.BUSY]:
|
|
551
|
+
continue
|
|
552
|
+
offer = get_instance_offer(instance)
|
|
553
|
+
if offer is None:
|
|
554
|
+
continue
|
|
555
|
+
total_blocks = common_utils.get_or_error(instance.total_blocks)
|
|
556
|
+
idle_blocks = total_blocks - instance.busy_blocks
|
|
557
|
+
for blocks in range(1, total_blocks + 1):
|
|
558
|
+
shared_offer = generate_shared_offer(offer, blocks, total_blocks)
|
|
559
|
+
catalog_item = offer_to_catalog_item(shared_offer)
|
|
560
|
+
if gpuhunt.matches(catalog_item, query_filter):
|
|
561
|
+
if blocks <= idle_blocks:
|
|
562
|
+
shared_offer.availability = InstanceAvailability.IDLE
|
|
563
|
+
else:
|
|
564
|
+
shared_offer.availability = InstanceAvailability.BUSY
|
|
565
|
+
if shared_offer.availability == InstanceAvailability.IDLE or not idle_only:
|
|
566
|
+
instances_with_offers.append((instance, shared_offer))
|
|
567
|
+
break
|
|
568
|
+
return instances_with_offers
|
|
569
|
+
|
|
570
|
+
|
|
497
571
|
async def list_pools_instance_models(
|
|
498
572
|
session: AsyncSession,
|
|
499
573
|
projects: List[ProjectModel],
|
|
@@ -557,7 +631,7 @@ async def list_pools_instance_models(
|
|
|
557
631
|
.limit(limit)
|
|
558
632
|
.options(joinedload(InstanceModel.pool), joinedload(InstanceModel.fleet))
|
|
559
633
|
)
|
|
560
|
-
instance_models = list(res.scalars().all())
|
|
634
|
+
instance_models = list(res.unique().scalars().all())
|
|
561
635
|
return instance_models
|
|
562
636
|
|
|
563
637
|
|
|
@@ -618,7 +692,7 @@ async def list_active_remote_instances(
|
|
|
618
692
|
res = await session.execute(
|
|
619
693
|
select(InstanceModel).where(*filters).order_by(InstanceModel.created_at.asc())
|
|
620
694
|
)
|
|
621
|
-
instance_models = list(res.scalars().all())
|
|
695
|
+
instance_models = list(res.unique().scalars().all())
|
|
622
696
|
return instance_models
|
|
623
697
|
|
|
624
698
|
|
|
@@ -633,6 +707,7 @@ async def create_instance_model(
|
|
|
633
707
|
instance_num: int,
|
|
634
708
|
placement_group_name: Optional[str],
|
|
635
709
|
reservation: Optional[str],
|
|
710
|
+
blocks: Union[Literal["auto"], int],
|
|
636
711
|
) -> InstanceModel:
|
|
637
712
|
termination_policy, termination_idle_time = get_termination(
|
|
638
713
|
profile, DEFAULT_POOL_TERMINATION_IDLE_TIME
|
|
@@ -646,8 +721,8 @@ async def create_instance_model(
|
|
|
646
721
|
project_name=project.name,
|
|
647
722
|
instance_name=instance_name,
|
|
648
723
|
user=user.name,
|
|
649
|
-
instance_id=str(instance_id),
|
|
650
724
|
ssh_keys=[project_ssh_key],
|
|
725
|
+
instance_id=str(instance_id),
|
|
651
726
|
placement_group_name=placement_group_name,
|
|
652
727
|
reservation=reservation,
|
|
653
728
|
)
|
|
@@ -665,6 +740,8 @@ async def create_instance_model(
|
|
|
665
740
|
instance_configuration=instance_config.json(),
|
|
666
741
|
termination_policy=termination_policy,
|
|
667
742
|
termination_idle_time=termination_idle_time,
|
|
743
|
+
total_blocks=None if blocks == "auto" else blocks,
|
|
744
|
+
busy_blocks=0,
|
|
668
745
|
)
|
|
669
746
|
session.add(instance)
|
|
670
747
|
return instance
|
|
@@ -682,7 +759,10 @@ async def create_ssh_instance_model(
|
|
|
682
759
|
port: int,
|
|
683
760
|
ssh_user: str,
|
|
684
761
|
ssh_keys: List[SSHKey],
|
|
762
|
+
ssh_proxy: Optional[SSHConnectionParams],
|
|
763
|
+
ssh_proxy_keys: Optional[list[SSHKey]],
|
|
685
764
|
env: Env,
|
|
765
|
+
blocks: Union[Literal["auto"], int],
|
|
686
766
|
) -> InstanceModel:
|
|
687
767
|
# TODO: doc - will overwrite after remote connected
|
|
688
768
|
instance_resource = Resources(cpus=2, memory_mib=8, gpus=[], spot=False)
|
|
@@ -717,6 +797,8 @@ async def create_ssh_instance_model(
|
|
|
717
797
|
port=port,
|
|
718
798
|
ssh_user=ssh_user,
|
|
719
799
|
ssh_keys=ssh_keys,
|
|
800
|
+
ssh_proxy=ssh_proxy,
|
|
801
|
+
ssh_proxy_keys=ssh_proxy_keys,
|
|
720
802
|
env=env,
|
|
721
803
|
)
|
|
722
804
|
im = InstanceModel(
|
|
@@ -737,5 +819,7 @@ async def create_ssh_instance_model(
|
|
|
737
819
|
price=offer.price,
|
|
738
820
|
termination_policy=TerminationPolicy.DONT_DESTROY,
|
|
739
821
|
termination_idle_time=0,
|
|
822
|
+
total_blocks=None if blocks == "auto" else blocks,
|
|
823
|
+
busy_blocks=0,
|
|
740
824
|
)
|
|
741
825
|
return im
|
|
@@ -9,7 +9,7 @@ import dstack._internal.server.services.jobs as jobs_services
|
|
|
9
9
|
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
|
|
10
10
|
from dstack._internal.core.models.common import is_core_model_instance
|
|
11
11
|
from dstack._internal.core.models.configurations import ServiceConfiguration
|
|
12
|
-
from dstack._internal.core.models.instances import SSHConnectionParams
|
|
12
|
+
from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
|
|
13
13
|
from dstack._internal.core.models.runs import (
|
|
14
14
|
JobProvisioningData,
|
|
15
15
|
JobStatus,
|
|
@@ -30,6 +30,7 @@ from dstack._internal.proxy.lib.models import (
|
|
|
30
30
|
from dstack._internal.proxy.lib.repo import BaseProxyRepo
|
|
31
31
|
from dstack._internal.server.models import JobModel, ProjectModel, RunModel
|
|
32
32
|
from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
|
|
33
|
+
from dstack._internal.utils.common import get_or_error
|
|
33
34
|
|
|
34
35
|
|
|
35
36
|
class ServerProxyRepo(BaseProxyRepo):
|
|
@@ -53,9 +54,12 @@ class ServerProxyRepo(BaseProxyRepo):
|
|
|
53
54
|
JobModel.status == JobStatus.RUNNING,
|
|
54
55
|
JobModel.job_num == 0,
|
|
55
56
|
)
|
|
56
|
-
.options(
|
|
57
|
+
.options(
|
|
58
|
+
joinedload(JobModel.run),
|
|
59
|
+
joinedload(JobModel.instance),
|
|
60
|
+
)
|
|
57
61
|
)
|
|
58
|
-
jobs = res.scalars().all()
|
|
62
|
+
jobs = res.unique().scalars().all()
|
|
59
63
|
if not len(jobs):
|
|
60
64
|
return None
|
|
61
65
|
run = jobs[0].run
|
|
@@ -83,12 +87,22 @@ class ServerProxyRepo(BaseProxyRepo):
|
|
|
83
87
|
username=jpd.username,
|
|
84
88
|
port=jpd.ssh_port,
|
|
85
89
|
)
|
|
90
|
+
ssh_head_proxy: Optional[SSHConnectionParams] = None
|
|
91
|
+
ssh_head_proxy_private_key: Optional[str] = None
|
|
92
|
+
instance = get_or_error(job.instance)
|
|
93
|
+
if instance.remote_connection_info is not None:
|
|
94
|
+
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
|
|
95
|
+
if rci.ssh_proxy is not None:
|
|
96
|
+
ssh_head_proxy = rci.ssh_proxy
|
|
97
|
+
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
|
|
86
98
|
replica = Replica(
|
|
87
99
|
id=job.id.hex,
|
|
88
100
|
app_port=run_spec.configuration.port.container_port,
|
|
89
101
|
ssh_destination=ssh_destination,
|
|
90
102
|
ssh_port=ssh_port,
|
|
91
103
|
ssh_proxy=ssh_proxy,
|
|
104
|
+
ssh_head_proxy=ssh_head_proxy,
|
|
105
|
+
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
|
|
92
106
|
)
|
|
93
107
|
replicas.append(replica)
|
|
94
108
|
return Service(
|
|
@@ -17,22 +17,36 @@ from dstack._internal.utils.path import FileContent
|
|
|
17
17
|
logger = get_logger(__name__)
|
|
18
18
|
P = ParamSpec("P")
|
|
19
19
|
R = TypeVar("R")
|
|
20
|
+
# A host private key or pair of (host private key, optional proxy jump private key)
|
|
21
|
+
PrivateKeyOrPair = Union[str, tuple[str, Optional[str]]]
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
def runner_ssh_tunnel(
|
|
23
25
|
ports: List[int], retries: int = 3, retry_interval: float = 1
|
|
24
26
|
) -> Callable[
|
|
25
27
|
[Callable[Concatenate[Dict[int, int], P], R]],
|
|
26
|
-
Callable[
|
|
28
|
+
Callable[
|
|
29
|
+
Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P],
|
|
30
|
+
Union[bool, R],
|
|
31
|
+
],
|
|
27
32
|
]:
|
|
33
|
+
"""
|
|
34
|
+
A decorator that opens an SSH tunnel to the runner.
|
|
35
|
+
|
|
36
|
+
NOTE: connections from dstack-server to running jobs are expected to be short.
|
|
37
|
+
The runner uses a heuristic to differentiate dstack-server connections from
|
|
38
|
+
client connections based on their duration. See `ConnectionTracker` for details.
|
|
39
|
+
"""
|
|
40
|
+
|
|
28
41
|
def decorator(
|
|
29
42
|
func: Callable[Concatenate[Dict[int, int], P], R],
|
|
30
43
|
) -> Callable[
|
|
31
|
-
Concatenate[
|
|
44
|
+
Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P],
|
|
45
|
+
Union[bool, R],
|
|
32
46
|
]:
|
|
33
47
|
@functools.wraps(func)
|
|
34
48
|
def wrapper(
|
|
35
|
-
ssh_private_key:
|
|
49
|
+
ssh_private_key: PrivateKeyOrPair,
|
|
36
50
|
job_provisioning_data: JobProvisioningData,
|
|
37
51
|
job_runtime_data: Optional[JobRuntimeData],
|
|
38
52
|
*args: P.args,
|
|
@@ -51,6 +65,20 @@ def runner_ssh_tunnel(
|
|
|
51
65
|
# without SSH
|
|
52
66
|
return func(container_ports_map, *args, **kwargs)
|
|
53
67
|
|
|
68
|
+
if isinstance(ssh_private_key, str):
|
|
69
|
+
ssh_proxy_private_key = None
|
|
70
|
+
else:
|
|
71
|
+
ssh_private_key, ssh_proxy_private_key = ssh_private_key
|
|
72
|
+
identity = FileContent(ssh_private_key)
|
|
73
|
+
if ssh_proxy_private_key is not None:
|
|
74
|
+
proxy_identity = FileContent(ssh_proxy_private_key)
|
|
75
|
+
else:
|
|
76
|
+
proxy_identity = None
|
|
77
|
+
|
|
78
|
+
ssh_proxies = []
|
|
79
|
+
if job_provisioning_data.ssh_proxy is not None:
|
|
80
|
+
ssh_proxies.append((job_provisioning_data.ssh_proxy, proxy_identity))
|
|
81
|
+
|
|
54
82
|
for attempt in range(retries):
|
|
55
83
|
last = attempt == retries - 1
|
|
56
84
|
# remote_host:local mapping
|
|
@@ -66,8 +94,8 @@ def runner_ssh_tunnel(
|
|
|
66
94
|
),
|
|
67
95
|
port=job_provisioning_data.ssh_port,
|
|
68
96
|
forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map),
|
|
69
|
-
identity=
|
|
70
|
-
|
|
97
|
+
identity=identity,
|
|
98
|
+
ssh_proxies=ssh_proxies,
|
|
71
99
|
):
|
|
72
100
|
return func(runner_ports_map, *args, **kwargs)
|
|
73
101
|
except SSHError:
|