dstack 0.18.40__py3-none-any.whl → 0.18.41__py3-none-any.whl
- dstack/_internal/cli/commands/apply.py +8 -5
- dstack/_internal/cli/services/configurators/base.py +4 -2
- dstack/_internal/cli/services/configurators/fleet.py +21 -9
- dstack/_internal/cli/services/configurators/gateway.py +15 -0
- dstack/_internal/cli/services/configurators/run.py +6 -5
- dstack/_internal/cli/services/configurators/volume.py +15 -0
- dstack/_internal/cli/services/repos.py +3 -3
- dstack/_internal/cli/utils/fleet.py +44 -33
- dstack/_internal/cli/utils/run.py +27 -7
- dstack/_internal/cli/utils/volume.py +21 -9
- dstack/_internal/core/backends/aws/compute.py +92 -52
- dstack/_internal/core/backends/aws/resources.py +22 -12
- dstack/_internal/core/backends/azure/compute.py +2 -0
- dstack/_internal/core/backends/base/compute.py +20 -2
- dstack/_internal/core/backends/gcp/compute.py +30 -23
- dstack/_internal/core/backends/gcp/resources.py +0 -15
- dstack/_internal/core/backends/oci/compute.py +10 -5
- dstack/_internal/core/backends/oci/resources.py +23 -26
- dstack/_internal/core/backends/remote/provisioning.py +65 -27
- dstack/_internal/core/backends/runpod/compute.py +1 -0
- dstack/_internal/core/models/backends/azure.py +3 -1
- dstack/_internal/core/models/configurations.py +24 -1
- dstack/_internal/core/models/fleets.py +46 -0
- dstack/_internal/core/models/instances.py +5 -1
- dstack/_internal/core/models/pools.py +4 -1
- dstack/_internal/core/models/profiles.py +10 -4
- dstack/_internal/core/models/runs.py +20 -0
- dstack/_internal/core/models/volumes.py +3 -0
- dstack/_internal/core/services/ssh/attach.py +92 -53
- dstack/_internal/core/services/ssh/tunnel.py +58 -31
- dstack/_internal/proxy/gateway/routers/registry.py +2 -0
- dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
- dstack/_internal/proxy/gateway/services/registry.py +4 -0
- dstack/_internal/proxy/lib/models.py +3 -0
- dstack/_internal/proxy/lib/services/service_connection.py +8 -1
- dstack/_internal/server/background/tasks/process_instances.py +72 -33
- dstack/_internal/server/background/tasks/process_metrics.py +9 -9
- dstack/_internal/server/background/tasks/process_running_jobs.py +73 -26
- dstack/_internal/server/background/tasks/process_runs.py +2 -12
- dstack/_internal/server/background/tasks/process_submitted_jobs.py +109 -42
- dstack/_internal/server/background/tasks/process_terminating_jobs.py +1 -1
- dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
- dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
- dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
- dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
- dstack/_internal/server/models.py +10 -4
- dstack/_internal/server/routers/runs.py +1 -0
- dstack/_internal/server/schemas/runner.py +1 -0
- dstack/_internal/server/services/backends/configurators/azure.py +34 -8
- dstack/_internal/server/services/config.py +9 -0
- dstack/_internal/server/services/fleets.py +27 -2
- dstack/_internal/server/services/gateways/client.py +9 -1
- dstack/_internal/server/services/jobs/__init__.py +215 -43
- dstack/_internal/server/services/jobs/configurators/base.py +47 -2
- dstack/_internal/server/services/offers.py +91 -5
- dstack/_internal/server/services/pools.py +95 -11
- dstack/_internal/server/services/proxy/repo.py +17 -3
- dstack/_internal/server/services/runner/client.py +1 -1
- dstack/_internal/server/services/runner/ssh.py +33 -5
- dstack/_internal/server/services/runs.py +48 -179
- dstack/_internal/server/services/services/__init__.py +9 -1
- dstack/_internal/server/statics/index.html +1 -1
- dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
- dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
- dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
- dstack/_internal/server/testing/common.py +117 -52
- dstack/_internal/utils/common.py +22 -8
- dstack/_internal/utils/env.py +14 -0
- dstack/_internal/utils/ssh.py +1 -1
- dstack/api/server/_fleets.py +25 -1
- dstack/api/server/_runs.py +23 -2
- dstack/api/server/_volumes.py +12 -1
- dstack/version.py +1 -1
- {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/METADATA +1 -1
- {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/RECORD +98 -89
- tests/_internal/cli/services/configurators/test_profile.py +3 -3
- tests/_internal/core/services/ssh/test_tunnel.py +56 -4
- tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
- tests/_internal/server/background/tasks/test_process_instances.py +138 -20
- tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
- tests/_internal/server/background/tasks/test_process_running_jobs.py +192 -0
- tests/_internal/server/background/tasks/test_process_runs.py +27 -3
- tests/_internal/server/background/tasks/test_process_submitted_jobs.py +48 -3
- tests/_internal/server/background/tasks/test_process_terminating_jobs.py +126 -13
- tests/_internal/server/routers/test_fleets.py +15 -2
- tests/_internal/server/routers/test_pools.py +6 -0
- tests/_internal/server/routers/test_runs.py +27 -0
- tests/_internal/server/services/jobs/__init__.py +0 -0
- tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
- tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
- tests/_internal/server/services/test_pools.py +4 -0
- tests/_internal/server/services/test_runs.py +5 -41
- tests/_internal/utils/test_common.py +21 -0
- tests/_internal/utils/test_env.py +38 -0
- {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/LICENSE.md +0 -0
- {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/WHEEL +0 -0
- {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/entry_points.txt +0 -0
- {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/top_level.txt +0 -0
dstack/_internal/core/services/ssh/tunnel.py:

@@ -69,13 +69,16 @@ class SSHTunnel:
         options: Dict[str, str] = SSH_DEFAULT_OPTIONS,
         ssh_config_path: Union[PathLike, Literal["none"]] = "none",
         port: Optional[int] = None,
-        ssh_proxy: Optional[SSHConnectionParams] = None,
+        ssh_proxies: Iterable[tuple[SSHConnectionParams, Optional[FilePathOrContent]]] = (),
     ):
         """
         :param forwarded_sockets: Connections to the specified local sockets will be
             forwarded to their corresponding remote sockets
         :param reverse_forwarded_sockets: Connections to the specified remote sockets
             will be forwarded to their corresponding local sockets
+        :param ssh_proxies: pairs of SSH connections params and optional identities,
+            in order from outer to inner. If an identity is `None`, the `identity` param
+            is used instead.
         """
         self.destination = destination
         self.forwarded_sockets = list(forwarded_sockets)
@@ -83,21 +86,21 @@ class SSHTunnel:
         self.options = options
         self.port = port
         self.ssh_config_path = normalize_path(ssh_config_path)
-        self.ssh_proxy = ssh_proxy
         temp_dir = tempfile.TemporaryDirectory()
         self.temp_dir = temp_dir
         if control_sock_path is None:
             control_sock_path = os.path.join(temp_dir.name, "control.sock")
         self.control_sock_path = normalize_path(control_sock_path)
-        if isinstance(identity, FilePath):
-            identity_path = identity.path
-        else:
-            identity_path = os.path.join(temp_dir.name, "identity")
-            with open(
-                identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
-            ) as f:
-                f.write(identity.content)
-        self.identity_path = normalize_path(identity_path)
+        self.identity_path = normalize_path(self._get_identity_path(identity, "identity"))
+        self.ssh_proxies: list[tuple[SSHConnectionParams, PathLike]] = []
+        for proxy_index, (proxy_params, proxy_identity) in enumerate(ssh_proxies):
+            if proxy_identity is None:
+                proxy_identity_path = self.identity_path
+            else:
+                proxy_identity_path = self._get_identity_path(
+                    proxy_identity, f"proxy_identity_{proxy_index}"
+                )
+            self.ssh_proxies.append((proxy_params, proxy_identity_path))
         self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log"))
         self.ssh_client_info = get_ssh_client_info()
         self.ssh_exec_path = str(self.ssh_client_info.path)
@@ -142,8 +145,8 @@ class SSHTunnel:
             command += ["-p", str(self.port)]
         for k, v in self.options.items():
             command += ["-o", f"{k}={v}"]
-        if proxy_command := self.proxy_command():
-            command += ["-o", f"ProxyCommand={shlex.join(proxy_command)}"]
+        if proxy_command := self._get_proxy_command():
+            command += ["-o", proxy_command]
         for socket_pair in self.forwarded_sockets:
             command += ["-L", f"{socket_pair.local.render()}:{socket_pair.remote.render()}"]
         for socket_pair in self.reverse_forwarded_sockets:
@@ -160,24 +163,6 @@ class SSHTunnel:
     def exec_command(self) -> List[str]:
         return [self.ssh_exec_path, "-S", self.control_sock_path, self.destination]

-    def proxy_command(self) -> Optional[List[str]]:
-        if self.ssh_proxy is None:
-            return None
-        return [
-            self.ssh_exec_path,
-            "-i",
-            self.identity_path,
-            "-W",
-            "%h:%p",
-            "-o",
-            "StrictHostKeyChecking=no",
-            "-o",
-            "UserKnownHostsFile=/dev/null",
-            "-p",
-            str(self.ssh_proxy.port),
-            f"{self.ssh_proxy.username}@{self.ssh_proxy.hostname}",
-        ]
-
     def open(self) -> None:
         # We cannot use `stderr=subprocess.PIPE` here since the forked process (daemon) does not
         # close standard streams if ProxyJump is used, therefore we will wait EOF from the pipe
@@ -251,6 +236,38 @@ class SSHTunnel:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()

+    def _get_proxy_command(self) -> Optional[str]:
+        proxy_command: Optional[str] = None
+        for params, identity_path in self.ssh_proxies:
+            proxy_command = self._build_proxy_command(params, identity_path, proxy_command)
+        return proxy_command
+
+    def _build_proxy_command(
+        self,
+        params: SSHConnectionParams,
+        identity_path: PathLike,
+        prev_proxy_command: Optional[str],
+    ) -> Optional[str]:
+        command = [
+            self.ssh_exec_path,
+            "-i",
+            identity_path,
+            "-W",
+            "%h:%p",
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "UserKnownHostsFile=/dev/null",
+        ]
+        if prev_proxy_command is not None:
+            command += ["-o", prev_proxy_command.replace("%", "%%")]
+        command += [
+            "-p",
+            str(params.port),
+            f"{params.username}@{params.hostname}",
+        ]
+        return "ProxyCommand=" + shlex.join(command)
+
     def _read_log_file(self) -> bytes:
         with open(self.log_path, "rb") as f:
             return f.read()
@@ -263,6 +280,16 @@ class SSHTunnel:
         except OSError as e:
             logger.debug("Failed to remove SSH tunnel log file %s: %s", self.log_path, e)

+    def _get_identity_path(self, identity: FilePathOrContent, tmp_filename: str) -> PathLike:
+        if isinstance(identity, FilePath):
+            return identity.path
+        identity_path = os.path.join(self.temp_dir.name, tmp_filename)
+        with open(
+            identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
+        ) as f:
+            f.write(identity.content)
+        return identity_path
+

 def ports_to_forwarded_sockets(
     ports: Dict[int, int], bind_local: str = "localhost"
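The replacement of the single `ssh_proxy` with a chain is easiest to see in isolation: each hop's `ProxyCommand` string becomes an `-o` option of the next (inner) hop's `ssh` invocation, with `%` doubled so the inner `%h:%p` tokens survive one round of expansion. Below is a minimal standalone sketch of that fold; `ProxyParams` and the bare `ssh` executable are stand-ins for `SSHConnectionParams` and the resolved client path:

```python
import shlex
from typing import NamedTuple, Optional


class ProxyParams(NamedTuple):  # stand-in for SSHConnectionParams
    hostname: str
    username: str
    port: int


def build_proxy_command(params: ProxyParams, identity_path: str, prev: Optional[str]) -> str:
    command = [
        "ssh", "-i", identity_path, "-W", "%h:%p",
        "-o", "StrictHostKeyChecking=no",
        "-o", "UserKnownHostsFile=/dev/null",
    ]
    if prev is not None:
        # Double % so the embedded ProxyCommand's own %h:%p tokens are expanded
        # by the inner ssh invocation rather than the outer one.
        command += ["-o", prev.replace("%", "%%")]
    command += ["-p", str(params.port), f"{params.username}@{params.hostname}"]
    return "ProxyCommand=" + shlex.join(command)


# Outer-to-inner, as in SSHTunnel._get_proxy_command: the head node first,
# then the proxy that is only reachable through it.
hops = [
    (ProxyParams("bastion.example.com", "ubuntu", 22), "/tmp/head_key"),
    (ProxyParams("10.0.0.5", "root", 22), "/tmp/proxy_key"),
]
proxy_command: Optional[str] = None
for params, identity in hops:
    proxy_command = build_proxy_command(params, identity, proxy_command)
print(proxy_command)  # ProxyCommand for the innermost hop, embedding the outer one
```

Passed as an `-o` option on the main connection, the resulting string makes `ssh` reach the inner proxy through the bastion, and the final destination through the inner proxy.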
dstack/_internal/proxy/gateway/routers/registry.py:

@@ -76,6 +76,8 @@ async def register_replica(
         ssh_destination=body.ssh_host,
         ssh_port=body.ssh_port,
         ssh_proxy=body.ssh_proxy,
+        ssh_head_proxy=body.ssh_head_proxy,
+        ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
         repo=repo,
         nginx=nginx,
         service_conn_pool=service_conn_pool,
dstack/_internal/proxy/gateway/schemas/registry.py:

@@ -50,6 +50,8 @@ class RegisterReplicaRequest(BaseModel):
     ssh_host: str
     ssh_port: int
     ssh_proxy: Optional[SSHConnectionParams]
+    ssh_head_proxy: Optional[SSHConnectionParams]
+    ssh_head_proxy_private_key: Optional[str]


 class RegisterEntrypointRequest(BaseModel):
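For illustration, a registration request carrying the new fields might include a fragment like the one below. This is a hypothetical payload: the remaining required fields of `RegisterReplicaRequest` are omitted, and the `SSHConnectionParams` keys are inferred from their use in `tunnel.py` above:

```python
replica_registration_fragment = {
    "ssh_host": "ubuntu@10.0.0.7",
    "ssh_port": 22,
    "ssh_proxy": None,
    "ssh_head_proxy": {
        "hostname": "bastion.example.com",
        "username": "ubuntu",
        "port": 22,
    },
    "ssh_head_proxy_private_key": "-----BEGIN OPENSSH PRIVATE KEY-----\n...",
}
```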
dstack/_internal/proxy/gateway/services/registry.py:

@@ -123,6 +123,8 @@ async def register_replica(
     ssh_destination: str,
     ssh_port: int,
     ssh_proxy: Optional[SSHConnectionParams],
+    ssh_head_proxy: Optional[SSHConnectionParams],
+    ssh_head_proxy_private_key: Optional[str],
     repo: GatewayProxyRepo,
     nginx: Nginx,
     service_conn_pool: ServiceConnectionPool,
@@ -133,6 +135,8 @@ async def register_replica(
         ssh_destination=ssh_destination,
         ssh_port=ssh_port,
         ssh_proxy=ssh_proxy,
+        ssh_head_proxy=ssh_head_proxy,
+        ssh_head_proxy_private_key=ssh_head_proxy_private_key,
     )

     async with lock:
dstack/_internal/proxy/lib/models.py:

@@ -23,6 +23,9 @@ class Replica(ImmutableModel):
     ssh_destination: str
     ssh_port: int
     ssh_proxy: Optional[SSHConnectionParams]
+    # Optional outer proxy, a head node/bastion
+    ssh_head_proxy: Optional[SSHConnectionParams] = None
+    ssh_head_proxy_private_key: Optional[str] = None


 class Service(ImmutableModel):
dstack/_internal/proxy/lib/services/service_connection.py:

@@ -18,6 +18,7 @@ from dstack._internal.core.services.ssh.tunnel import (
 from dstack._internal.proxy.lib.errors import UnexpectedProxyError
 from dstack._internal.proxy.lib.models import Project, Replica, Service
 from dstack._internal.proxy.lib.repo import BaseProxyRepo
+from dstack._internal.utils.common import get_or_error
 from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.path import FileContent

@@ -45,10 +46,16 @@ class ServiceConnection:
         os.chmod(self._temp_dir.name, 0o755)
         options["StreamLocalBindMask"] = "0111"
         self._app_socket_path = (Path(self._temp_dir.name) / "replica.sock").absolute()
+        ssh_proxies = []
+        if replica.ssh_head_proxy is not None:
+            ssh_head_proxy_private_key = get_or_error(replica.ssh_head_proxy_private_key)
+            ssh_proxies.append((replica.ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
+        if replica.ssh_proxy is not None:
+            ssh_proxies.append((replica.ssh_proxy, None))
         self._tunnel = SSHTunnel(
             destination=replica.ssh_destination,
             port=replica.ssh_port,
-            ssh_proxy=replica.ssh_proxy,
+            ssh_proxies=ssh_proxies,
             identity=FileContent(project.ssh_private_key),
             forwarded_sockets=[
                 SocketPair(
dstack/_internal/server/background/tasks/process_instances.py:

@@ -42,12 +42,12 @@ from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.fleets import InstanceGroupPlacement
 from dstack._internal.core.models.instances import (
     InstanceAvailability,
-    InstanceConfiguration,
     InstanceOfferWithAvailability,
     InstanceRuntime,
     InstanceStatus,
     InstanceType,
     RemoteConnectionInfo,
+    SSHKey,
 )
 from dstack._internal.core.models.placement import (
     PlacementGroup,
@@ -77,6 +77,7 @@ from dstack._internal.server.services.fleets import (
     get_create_instance_offers,
 )
 from dstack._internal.server.services.locking import get_locker
+from dstack._internal.server.services.offers import is_divisible_into_blocks
 from dstack._internal.server.services.placement import (
     get_fleet_placement_groups,
     placement_group_model_to_placement_group,
@@ -86,6 +87,7 @@ from dstack._internal.server.services.pools import (
     get_instance_profile,
     get_instance_provisioning_data,
     get_instance_requirements,
+    get_instance_ssh_private_keys,
 )
 from dstack._internal.server.services.runner import client as runner_client
 from dstack._internal.server.services.runner.client import HealthStatus
@@ -133,7 +135,7 @@ async def _process_next_instance():
                 ),
                 InstanceModel.id.not_in(lockset),
             )
-            .options(lazyload(InstanceModel.job))
+            .options(lazyload(InstanceModel.jobs))
            .order_by(InstanceModel.last_processed_at.asc())
            .limit(1)
            .with_for_update(skip_locked=True)
@@ -156,7 +158,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
         select(InstanceModel)
         .where(InstanceModel.id == instance.id)
         .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
-        .options(joinedload(InstanceModel.job))
+        .options(joinedload(InstanceModel.jobs))
         .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
         .execution_options(populate_existing=True)
     )
@@ -164,7 +166,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
     if (
         instance.status == InstanceStatus.IDLE
         and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
-        and instance.job_id is None
+        and not instance.jobs
     ):
         await _mark_terminating_if_idle_duration_expired(instance)
     if instance.status == InstanceStatus.PENDING:
@@ -232,11 +234,11 @@ async def _add_remote(instance: InstanceModel) -> None:
     remote_details = RemoteConnectionInfo.parse_raw(cast(str, instance.remote_connection_info))
     # Prepare connection key
     try:
-        pkeys = [
-            pkey_from_str(sk.private)
-            for sk in remote_details.ssh_keys
-            if sk.private is not None
-        ]
+        pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys)
+        if remote_details.ssh_proxy_keys is not None:
+            ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys)
+        else:
+            ssh_proxy_pkeys = None
     except (ValueError, PasswordRequiredException):
         instance.status = InstanceStatus.TERMINATED
         instance.termination_reason = "Unsupported private SSH key type"
@@ -254,7 +256,9 @@ async def _add_remote(instance: InstanceModel) -> None:
     authorized_keys.append(instance.project.ssh_public_key.strip())

     try:
-        future = run_async(_deploy_instance, remote_details, pkeys, authorized_keys)
+        future = run_async(
+            _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys
+        )
         deploy_timeout = 20 * 60  # 20 minutes
         result = await asyncio.wait_for(future, timeout=deploy_timeout)
         health, host_info = result
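These query changes pair with the `1338b788b612_reverse_job_instance_relationship` migration from the file list: the foreign key moves from the instance side to the job side, so one instance can now carry several jobs (blocks). A rough sketch of what such a reversal could look like in Alembic; the table and column names are assumptions, and the real migration must also copy existing links before dropping the old column:

```python
import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Assumed names; batch_alter_table keeps this workable on SQLite as well.
    with op.batch_alter_table("jobs") as batch_op:
        batch_op.add_column(sa.Column("instance_id", sa.Uuid(), nullable=True))
        batch_op.create_foreign_key(
            "fk_jobs_instance_id_instances", "instances", ["instance_id"], ["id"]
        )
    with op.batch_alter_table("instances") as batch_op:
        batch_op.drop_column("job_id")


def downgrade() -> None:
    with op.batch_alter_table("instances") as batch_op:
        batch_op.add_column(sa.Column("job_id", sa.Uuid(), nullable=True))
    with op.batch_alter_table("jobs") as batch_op:
        batch_op.drop_column("instance_id")
```

With the collection in place, point checks like `instance.job_id is None` naturally become `not instance.jobs`, as seen in the hunks above.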
@@ -322,6 +326,26 @@ async def _add_remote(instance: InstanceModel) -> None:
         )
         return

+    divisible, blocks = is_divisible_into_blocks(
+        cpu_count=instance_type.resources.cpus,
+        gpu_count=len(instance_type.resources.gpus),
+        blocks="auto" if instance.total_blocks is None else instance.total_blocks,
+    )
+    if divisible:
+        instance.total_blocks = blocks
+    else:
+        instance.status = InstanceStatus.TERMINATED
+        instance.termination_reason = "Cannot split into blocks"
+        logger.warning(
+            "Failed to add instance %s: cannot split into blocks",
+            instance.name,
+            extra={
+                "instance_name": instance.name,
+                "instance_status": InstanceStatus.TERMINATED.value,
+            },
+        )
+        return
+
     region = instance.region
     jpd = JobProvisioningData(
         backend=BackendType.REMOTE,
@@ -336,7 +360,7 @@ async def _add_remote(instance: InstanceModel) -> None:
         ssh_port=remote_details.port,
         dockerized=True,
         backend_data=None,
-        ssh_proxy=None,
+        ssh_proxy=remote_details.ssh_proxy,
     )

     instance.status = InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING
@@ -359,10 +383,16 @@ async def _add_remote(instance: InstanceModel) -> None:
 def _deploy_instance(
     remote_details: RemoteConnectionInfo,
     pkeys: List[PKey],
+    ssh_proxy_pkeys: Optional[list[PKey]],
     authorized_keys: List[str],
 ) -> Tuple[HealthStatus, Dict[str, Any]]:
     with get_paramiko_connection(
-        remote_details.ssh_user, remote_details.host, remote_details.port, pkeys
+        remote_details.ssh_user,
+        remote_details.host,
+        remote_details.port,
+        pkeys,
+        remote_details.ssh_proxy,
+        ssh_proxy_pkeys,
     ) as client:
         logger.info(f"Connected to {remote_details.ssh_user} {remote_details.host}")

@@ -479,6 +509,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
         requirements=requirements,
         exclude_not_available=True,
         fleet_model=instance.fleet,
+        blocks="auto" if instance.total_blocks is None else instance.total_blocks,
     )

     if not offers and should_retry:
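`is_divisible_into_blocks` comes from `dstack/_internal/server/services/offers.py`, whose body is not part of this excerpt. A hypothetical sketch of its contract, consistent with the call sites above (`blocks` is a positive int or `"auto"`; the result is a `(divisible, block_count)` pair):

```python
from typing import Literal, Tuple, Union


def is_divisible_into_blocks_sketch(
    cpu_count: int, gpu_count: int, blocks: Union[int, Literal["auto"]]
) -> Tuple[bool, int]:
    """Illustrative only: the real rule lives in dstack's services/offers.py."""
    if blocks == "auto":
        # Assumption: one block per GPU, or a single block on CPU-only instances.
        blocks = gpu_count if gpu_count > 0 else 1
    if blocks < 1 or cpu_count % blocks != 0:
        return False, 0
    if gpu_count > 0 and gpu_count % blocks != 0:
        return False, 0
    return True, blocks
```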
@@ -496,11 +527,10 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
         session=session, fleet_id=instance.fleet_id
     )

-    instance_configuration = _patch_instance_configuration(instance)
-
     for backend, instance_offer in offers:
         if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT:
             continue
+        instance_offer = _get_instance_offer_for_instance(instance_offer, instance)
         if (
             instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
             and instance.fleet
@@ -554,6 +584,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
     instance.instance_configuration = instance_configuration.json()
     instance.job_provisioning_data = job_provisioning_data.json()
     instance.offer = instance_offer.json()
+    instance.total_blocks = instance_offer.total_blocks
     instance.started_at = get_current_datetime()
     instance.last_retry_at = get_current_datetime()

@@ -585,8 +616,8 @@
 async def _check_instance(instance: InstanceModel) -> None:
     if (
         instance.status == InstanceStatus.BUSY
-        and instance.job is not None
-        and instance.job.status.is_finished()
+        and instance.jobs
+        and all(job.status.is_finished() for job in instance.jobs)
     ):
         # A busy instance could have no active jobs due to this bug: https://github.com/dstackai/dstack/issues/2068
         instance.status = InstanceStatus.TERMINATING
@@ -617,18 +648,14 @@ async def _check_instance(instance: InstanceModel) -> None:
         instance.status = InstanceStatus.BUSY
         return

-    ssh_private_key = instance.project.ssh_private_key
-    # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem
-    # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716
-    if instance.remote_connection_info is not None:
-        remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
-            instance.remote_connection_info
-        )
-        ssh_private_key = remote_conn_info.ssh_keys[0].private
+    ssh_private_keys = get_instance_ssh_private_keys(instance)

     # May return False if fails to establish ssh connection
     health_status_response = await run_async(
-        _instance_healthcheck, ssh_private_key, job_provisioning_data
+        _instance_healthcheck,
+        ssh_private_keys,
+        job_provisioning_data,
+        None,
     )
     if isinstance(health_status_response, bool) or health_status_response is None:
         health_status = HealthStatus(healthy=False, reason="SSH or tunnel error")
@@ -648,9 +675,7 @@ async def _check_instance(instance: InstanceModel) -> None:
     instance.unreachable = False

     if instance.status == InstanceStatus.PROVISIONING:
-        instance.status = (
-            InstanceStatus.IDLE if instance.job_id is None else InstanceStatus.BUSY
-        )
+        instance.status = InstanceStatus.IDLE if not instance.jobs else InstanceStatus.BUSY
         logger.info(
             "Instance %s has switched to %s status",
             instance.name,
@@ -869,21 +894,31 @@ def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
     )


-def _patch_instance_configuration(instance: InstanceModel) -> InstanceConfiguration:
-    instance_configuration = get_instance_configuration(instance)
+def _get_instance_offer_for_instance(
+    instance_offer: InstanceOfferWithAvailability,
+    instance: InstanceModel,
+) -> InstanceOfferWithAvailability:
     if instance.fleet is None:
-        return instance_configuration
+        return instance_offer

     fleet = fleet_model_to_fleet(instance.fleet)
     master_instance = instance.fleet.instances[0]
     master_job_provisioning_data = get_instance_provisioning_data(master_instance)
+    instance_offer = instance_offer.copy()
     if (
         fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER
         and master_job_provisioning_data is not None
+        and master_job_provisioning_data.availability_zone is not None
     ):
-        instance_configuration.availability_zone = master_job_provisioning_data.availability_zone
+        if instance_offer.availability_zones is None:
+            instance_offer.availability_zones = [master_job_provisioning_data.availability_zone]
+        instance_offer.availability_zones = [
+            z
+            for z in instance_offer.availability_zones
+            if z == master_job_provisioning_data.availability_zone
+        ]

-    return instance_configuration
+    return instance_offer


 def _create_placement_group_if_does_not_exist(
@@ -942,3 +977,7 @@ def _get_instance_timeout_interval(
     if backend_type == BackendType.VULTR and instance_type_name.startswith("vbm"):
         return timedelta(seconds=3300)
     return timedelta(seconds=600)
+
+
+def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]:
+    return [pkey_from_str(sk.private) for sk in ssh_keys if sk.private is not None]
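`pkey_from_str` is imported from elsewhere in dstack and not shown in this diff. A rough stand-in built on paramiko (whose `PKey` and `PasswordRequiredException` already appear in this module) might look like this; the set of attempted key classes is an assumption:

```python
from io import StringIO

from paramiko import ECDSAKey, Ed25519Key, PKey, RSAKey
from paramiko.ssh_exception import PasswordRequiredException, SSHException


def pkey_from_str_sketch(private: str) -> PKey:
    """Try common key types in turn; raise ValueError if none parse."""
    for key_cls in (Ed25519Key, ECDSAKey, RSAKey):
        try:
            return key_cls.from_private_key(StringIO(private))
        except PasswordRequiredException:
            raise  # encrypted keys are handled by the caller's except clause
        except SSHException:
            continue
    # Matches the "Unsupported private SSH key type" termination path above.
    raise ValueError("Unsupported private SSH key type")
```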
dstack/_internal/server/background/tasks/process_metrics.py:

@@ -3,18 +3,19 @@ import json
 from typing import Dict, List, Optional

 from sqlalchemy import delete, select
-from sqlalchemy.orm import selectinload
+from sqlalchemy.orm import joinedload

 from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
 from dstack._internal.core.models.runs import JobStatus
 from dstack._internal.server import settings
 from dstack._internal.server.db import get_session_ctx
-from dstack._internal.server.models import JobMetricsPoint, JobModel
+from dstack._internal.server.models import InstanceModel, JobMetricsPoint, JobModel
 from dstack._internal.server.schemas.runner import MetricsResponse
 from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
+from dstack._internal.server.services.pools import get_instance_ssh_private_keys
 from dstack._internal.server.services.runner import client
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
-from dstack._internal.utils.common import batched, get_current_datetime, run_async
+from dstack._internal.utils.common import batched, get_current_datetime, get_or_error, run_async
 from dstack._internal.utils.logging import get_logger

 logger = get_logger(__name__)
@@ -29,14 +30,12 @@ async def collect_metrics():
     async with get_session_ctx() as session:
         res = await session.execute(
             select(JobModel)
-            .where(
-                JobModel.status.in_([JobStatus.RUNNING])
-            )
-            .options(selectinload(JobModel.project))
+            .where(JobModel.status.in_([JobStatus.RUNNING]))
+            .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
             .order_by(JobModel.last_processed_at.asc())
             .limit(MAX_JOBS_FETCHED)
         )
-        job_models = res.scalars().all()
+        job_models = res.unique().scalars().all()

         for batch in batched(job_models, BATCH_SIZE):
             await _collect_jobs_metrics(batch)
@@ -87,6 +86,7 @@ def _get_recently_collected_metric_cutoff() -> int:


 async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint]:
+    ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance))
     jpd = get_job_provisioning_data(job_model)
     jrd = get_job_runtime_data(job_model)
     if jpd is None:
@@ -94,7 +94,7 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint]:
     try:
         res = await run_async(
             _pull_runner_metrics,
-            job_model.project.ssh_private_key,
+            ssh_private_keys,
             jpd,
             jrd,
         )
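The `.unique()` added above reflects a SQLAlchemy 2.x rule: results containing joined eager loads against collections must be de-duplicated before `.scalars().all()`, and the call is harmless for many-to-one loads such as `JobModel.instance`. A self-contained illustration with stand-in models:

```python
from typing import List

from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    joinedload,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Instance(Base):
    __tablename__ = "instances"
    id: Mapped[int] = mapped_column(primary_key=True)
    jobs: Mapped[List["Job"]] = relationship(back_populates="instance")


class Job(Base):
    __tablename__ = "jobs"
    id: Mapped[int] = mapped_column(primary_key=True)
    instance_id: Mapped[int] = mapped_column(ForeignKey("instances.id"))
    instance: Mapped["Instance"] = relationship(back_populates="jobs")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Instance(jobs=[Job(), Job()]))
    session.commit()
    # The JOIN against the `jobs` collection duplicates Instance rows, so the
    # result must be de-duplicated with .unique() before .scalars().all():
    stmt = select(Instance).options(joinedload(Instance.jobs))
    instances = session.execute(stmt).unique().scalars().all()
    assert len(instances) == 1 and len(instances[0].jobs) == 2
```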