dstack 0.18.40rc1__py3-none-any.whl → 0.18.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +30 -9
  11. dstack/_internal/core/backends/aws/compute.py +94 -53
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +32 -24
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +23 -3
  28. dstack/_internal/core/models/volumes.py +26 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +73 -35
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +77 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +121 -49
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +14 -3
  42. dstack/_internal/server/background/tasks/process_volumes.py +11 -1
  43. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  44. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  45. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  46. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  47. dstack/_internal/server/migrations/versions/a751ef183f27_move_attachment_data_to_volumes_.py +34 -0
  48. dstack/_internal/server/models.py +27 -23
  49. dstack/_internal/server/routers/runs.py +1 -0
  50. dstack/_internal/server/schemas/runner.py +1 -0
  51. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  52. dstack/_internal/server/services/config.py +9 -0
  53. dstack/_internal/server/services/fleets.py +32 -3
  54. dstack/_internal/server/services/gateways/client.py +9 -1
  55. dstack/_internal/server/services/jobs/__init__.py +217 -45
  56. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  57. dstack/_internal/server/services/offers.py +96 -10
  58. dstack/_internal/server/services/pools.py +98 -14
  59. dstack/_internal/server/services/proxy/repo.py +17 -3
  60. dstack/_internal/server/services/runner/client.py +9 -6
  61. dstack/_internal/server/services/runner/ssh.py +33 -5
  62. dstack/_internal/server/services/runs.py +48 -179
  63. dstack/_internal/server/services/services/__init__.py +9 -1
  64. dstack/_internal/server/services/volumes.py +68 -9
  65. dstack/_internal/server/statics/index.html +1 -1
  66. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  67. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  68. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  69. dstack/_internal/server/testing/common.py +130 -61
  70. dstack/_internal/utils/common.py +22 -8
  71. dstack/_internal/utils/env.py +14 -0
  72. dstack/_internal/utils/ssh.py +1 -1
  73. dstack/api/server/_fleets.py +25 -1
  74. dstack/api/server/_runs.py +23 -2
  75. dstack/api/server/_volumes.py +12 -1
  76. dstack/version.py +1 -1
  77. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/METADATA +1 -1
  78. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/RECORD +104 -93
  79. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  80. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  81. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  82. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  83. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  84. tests/_internal/server/background/tasks/test_process_running_jobs.py +193 -0
  85. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  86. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +53 -6
  87. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +135 -17
  88. tests/_internal/server/routers/test_fleets.py +15 -2
  89. tests/_internal/server/routers/test_pools.py +6 -0
  90. tests/_internal/server/routers/test_runs.py +27 -0
  91. tests/_internal/server/routers/test_volumes.py +9 -2
  92. tests/_internal/server/services/jobs/__init__.py +0 -0
  93. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  94. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  95. tests/_internal/server/services/runner/test_client.py +22 -3
  96. tests/_internal/server/services/test_offers.py +167 -0
  97. tests/_internal/server/services/test_pools.py +109 -1
  98. tests/_internal/server/services/test_runs.py +5 -41
  99. tests/_internal/utils/test_common.py +21 -0
  100. tests/_internal/utils/test_env.py +38 -0
  101. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/LICENSE.md +0 -0
  102. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/WHEEL +0 -0
  103. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/entry_points.txt +0 -0
  104. {dstack-0.18.40rc1.dist-info → dstack-0.18.42.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import atexit
2
2
  import re
3
3
  import time
4
4
  from pathlib import Path
5
- from typing import Optional
5
+ from typing import Optional, Union
6
6
 
7
7
  import psutil
8
8
 
@@ -14,6 +14,8 @@ from dstack._internal.core.services.ssh.ports import PortsLock
14
14
  from dstack._internal.core.services.ssh.tunnel import SSHTunnel, ports_to_forwarded_sockets
15
15
  from dstack._internal.utils.path import FilePath, PathLike
16
16
  from dstack._internal.utils.ssh import (
17
+ default_ssh_config_path,
18
+ get_host_config,
17
19
  include_ssh_config,
18
20
  normalize_path,
19
21
  update_ssh_config,
@@ -88,28 +90,63 @@ class SSHAttach:
88
90
  },
89
91
  )
90
92
  self.ssh_proxy = ssh_proxy
91
- if ssh_proxy is None:
92
- self.host_config = {
93
+
94
+ hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {}
95
+ self.hosts = hosts
96
+
97
+ if local_backend:
98
+ hosts[run_name] = {
93
99
  "HostName": hostname,
94
- "Port": ssh_port,
95
- "User": user if dockerized else container_user,
96
- "IdentityFile": self.identity_file,
97
- "IdentitiesOnly": "yes",
98
- "StrictHostKeyChecking": "no",
99
- "UserKnownHostsFile": "/dev/null",
100
- }
101
- else:
102
- self.host_config = {
103
- "HostName": ssh_proxy.hostname,
104
- "Port": ssh_proxy.port,
105
- "User": ssh_proxy.username,
100
+ "Port": container_ssh_port,
101
+ "User": container_user,
106
102
  "IdentityFile": self.identity_file,
107
103
  "IdentitiesOnly": "yes",
108
104
  "StrictHostKeyChecking": "no",
109
105
  "UserKnownHostsFile": "/dev/null",
110
106
  }
111
- if dockerized and not local_backend:
112
- self.container_config = {
107
+ elif dockerized:
108
+ if ssh_proxy is not None:
109
+ # SSH instance with jump host
110
+ # dstack has no IdentityFile for jump host, it must be either preconfigured
111
+ # in the ~/.ssh/config or loaded into ssh-agent
112
+ hosts[f"{run_name}-jump-host"] = {
113
+ "HostName": ssh_proxy.hostname,
114
+ "Port": ssh_proxy.port,
115
+ "User": ssh_proxy.username,
116
+ "StrictHostKeyChecking": "no",
117
+ "UserKnownHostsFile": "/dev/null",
118
+ }
119
+ jump_host_config = get_host_config(ssh_proxy.hostname, default_ssh_config_path)
120
+ jump_host_identity_files = jump_host_config.get("identityfile")
121
+ if jump_host_identity_files:
122
+ hosts[f"{run_name}-jump-host"].update(
123
+ {
124
+ "IdentityFile": jump_host_identity_files[0],
125
+ "IdentitiesOnly": "yes",
126
+ }
127
+ )
128
+ hosts[f"{run_name}-host"] = {
129
+ "HostName": hostname,
130
+ "Port": ssh_port,
131
+ "User": user,
132
+ "IdentityFile": self.identity_file,
133
+ "IdentitiesOnly": "yes",
134
+ "StrictHostKeyChecking": "no",
135
+ "UserKnownHostsFile": "/dev/null",
136
+ "ProxyJump": f"{run_name}-jump-host",
137
+ }
138
+ else:
139
+ # Regular SSH instance or VM-based cloud instance
140
+ hosts[f"{run_name}-host"] = {
141
+ "HostName": hostname,
142
+ "Port": ssh_port,
143
+ "User": user,
144
+ "IdentityFile": self.identity_file,
145
+ "IdentitiesOnly": "yes",
146
+ "StrictHostKeyChecking": "no",
147
+ "UserKnownHostsFile": "/dev/null",
148
+ }
149
+ hosts[run_name] = {
113
150
  "HostName": "localhost",
114
151
  "Port": container_ssh_port,
115
152
  "User": container_user,
@@ -119,32 +156,41 @@ class SSHAttach:
119
156
  "UserKnownHostsFile": "/dev/null",
120
157
  "ProxyJump": f"{run_name}-host",
121
158
  }
122
- elif ssh_proxy is not None:
123
- self.container_config = {
124
- "HostName": hostname,
125
- "Port": ssh_port,
126
- "User": container_user,
127
- "IdentityFile": self.identity_file,
128
- "IdentitiesOnly": "yes",
129
- "StrictHostKeyChecking": "no",
130
- "UserKnownHostsFile": "/dev/null",
131
- "ProxyJump": f"{run_name}-jump-host",
132
- }
133
159
  else:
134
- self.container_config = None
135
- if local_backend:
136
- self.container_config = None
137
- self.host_config = {
138
- "HostName": hostname,
139
- "Port": container_ssh_port,
140
- "User": container_user,
141
- "IdentityFile": self.identity_file,
142
- "IdentitiesOnly": "yes",
143
- "StrictHostKeyChecking": "no",
144
- "UserKnownHostsFile": "/dev/null",
145
- }
146
- if self.container_config is not None and get_ssh_client_info().supports_multiplexing:
147
- self.container_config.update(
160
+ if ssh_proxy is not None:
161
+ # Kubernetes
162
+ hosts[f"{run_name}-jump-host"] = {
163
+ "HostName": ssh_proxy.hostname,
164
+ "Port": ssh_proxy.port,
165
+ "User": ssh_proxy.username,
166
+ "IdentityFile": self.identity_file,
167
+ "IdentitiesOnly": "yes",
168
+ "StrictHostKeyChecking": "no",
169
+ "UserKnownHostsFile": "/dev/null",
170
+ }
171
+ hosts[run_name] = {
172
+ "HostName": hostname,
173
+ "Port": ssh_port,
174
+ "User": container_user,
175
+ "IdentityFile": self.identity_file,
176
+ "IdentitiesOnly": "yes",
177
+ "StrictHostKeyChecking": "no",
178
+ "UserKnownHostsFile": "/dev/null",
179
+ "ProxyJump": f"{run_name}-jump-host",
180
+ }
181
+ else:
182
+ # Container-based backends
183
+ hosts[run_name] = {
184
+ "HostName": hostname,
185
+ "Port": ssh_port,
186
+ "User": container_user,
187
+ "IdentityFile": self.identity_file,
188
+ "IdentitiesOnly": "yes",
189
+ "StrictHostKeyChecking": "no",
190
+ "UserKnownHostsFile": "/dev/null",
191
+ }
192
+ if get_ssh_client_info().supports_multiplexing:
193
+ hosts[run_name].update(
148
194
  {
149
195
  "ControlMaster": "auto",
150
196
  "ControlPath": self.control_sock_path,
@@ -153,14 +199,8 @@ class SSHAttach:
153
199
 
154
200
  def attach(self):
155
201
  include_ssh_config(self.ssh_config_path)
156
- if self.container_config is None:
157
- update_ssh_config(self.ssh_config_path, self.run_name, self.host_config)
158
- elif self.ssh_proxy is not None:
159
- update_ssh_config(self.ssh_config_path, f"{self.run_name}-jump-host", self.host_config)
160
- update_ssh_config(self.ssh_config_path, self.run_name, self.container_config)
161
- else:
162
- update_ssh_config(self.ssh_config_path, f"{self.run_name}-host", self.host_config)
163
- update_ssh_config(self.ssh_config_path, self.run_name, self.container_config)
202
+ for host, options in self.hosts.items():
203
+ update_ssh_config(self.ssh_config_path, host, options)
164
204
 
165
205
  max_retries = 10
166
206
  self._ports_lock.release()
@@ -178,9 +218,8 @@ class SSHAttach:
178
218
 
179
219
  def detach(self):
180
220
  self.tunnel.close()
181
- update_ssh_config(self.ssh_config_path, f"{self.run_name}-jump-host", {})
182
- update_ssh_config(self.ssh_config_path, f"{self.run_name}-host", {})
183
- update_ssh_config(self.ssh_config_path, self.run_name, {})
221
+ for host in self.hosts:
222
+ update_ssh_config(self.ssh_config_path, host, {})
184
223
 
185
224
  def __enter__(self):
186
225
  self.attach()
@@ -69,13 +69,16 @@ class SSHTunnel:
69
69
  options: Dict[str, str] = SSH_DEFAULT_OPTIONS,
70
70
  ssh_config_path: Union[PathLike, Literal["none"]] = "none",
71
71
  port: Optional[int] = None,
72
- ssh_proxy: Optional[SSHConnectionParams] = None,
72
+ ssh_proxies: Iterable[tuple[SSHConnectionParams, Optional[FilePathOrContent]]] = (),
73
73
  ):
74
74
  """
75
75
  :param forwarded_sockets: Connections to the specified local sockets will be
76
76
  forwarded to their corresponding remote sockets
77
77
  :param reverse_forwarded_sockets: Connections to the specified remote sockets
78
78
  will be forwarded to their corresponding local sockets
79
+ :param ssh_proxies: pairs of SSH connections params and optional identities,
80
+ in order from outer to inner. If an identity is `None`, the `identity` param
81
+ is used instead.
79
82
  """
80
83
  self.destination = destination
81
84
  self.forwarded_sockets = list(forwarded_sockets)
@@ -83,21 +86,21 @@ class SSHTunnel:
83
86
  self.options = options
84
87
  self.port = port
85
88
  self.ssh_config_path = normalize_path(ssh_config_path)
86
- self.ssh_proxy = ssh_proxy
87
89
  temp_dir = tempfile.TemporaryDirectory()
88
90
  self.temp_dir = temp_dir
89
91
  if control_sock_path is None:
90
92
  control_sock_path = os.path.join(temp_dir.name, "control.sock")
91
93
  self.control_sock_path = normalize_path(control_sock_path)
92
- if isinstance(identity, FilePath):
93
- identity_path = identity.path
94
- else:
95
- identity_path = os.path.join(temp_dir.name, "identity")
96
- with open(
97
- identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
98
- ) as f:
99
- f.write(identity.content)
100
- self.identity_path = normalize_path(identity_path)
94
+ self.identity_path = normalize_path(self._get_identity_path(identity, "identity"))
95
+ self.ssh_proxies: list[tuple[SSHConnectionParams, PathLike]] = []
96
+ for proxy_index, (proxy_params, proxy_identity) in enumerate(ssh_proxies):
97
+ if proxy_identity is None:
98
+ proxy_identity_path = self.identity_path
99
+ else:
100
+ proxy_identity_path = self._get_identity_path(
101
+ proxy_identity, f"proxy_identity_{proxy_index}"
102
+ )
103
+ self.ssh_proxies.append((proxy_params, proxy_identity_path))
101
104
  self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log"))
102
105
  self.ssh_client_info = get_ssh_client_info()
103
106
  self.ssh_exec_path = str(self.ssh_client_info.path)
@@ -142,8 +145,8 @@ class SSHTunnel:
142
145
  command += ["-p", str(self.port)]
143
146
  for k, v in self.options.items():
144
147
  command += ["-o", f"{k}={v}"]
145
- if proxy_command := self.proxy_command():
146
- command += ["-o", "ProxyCommand=" + shlex.join(proxy_command)]
148
+ if proxy_command := self._get_proxy_command():
149
+ command += ["-o", proxy_command]
147
150
  for socket_pair in self.forwarded_sockets:
148
151
  command += ["-L", f"{socket_pair.local.render()}:{socket_pair.remote.render()}"]
149
152
  for socket_pair in self.reverse_forwarded_sockets:
@@ -160,24 +163,6 @@ class SSHTunnel:
160
163
  def exec_command(self) -> List[str]:
161
164
  return [self.ssh_exec_path, "-S", self.control_sock_path, self.destination]
162
165
 
163
- def proxy_command(self) -> Optional[List[str]]:
164
- if self.ssh_proxy is None:
165
- return None
166
- return [
167
- self.ssh_exec_path,
168
- "-i",
169
- self.identity_path,
170
- "-W",
171
- "%h:%p",
172
- "-o",
173
- "StrictHostKeyChecking=no",
174
- "-o",
175
- "UserKnownHostsFile=/dev/null",
176
- "-p",
177
- str(self.ssh_proxy.port),
178
- f"{self.ssh_proxy.username}@{self.ssh_proxy.hostname}",
179
- ]
180
-
181
166
  def open(self) -> None:
182
167
  # We cannot use `stderr=subprocess.PIPE` here since the forked process (daemon) does not
183
168
  # close standard streams if ProxyJump is used, therefore we will wait EOF from the pipe
@@ -251,6 +236,38 @@ class SSHTunnel:
251
236
  def __exit__(self, exc_type, exc_val, exc_tb):
252
237
  self.close()
253
238
 
239
+ def _get_proxy_command(self) -> Optional[str]:
240
+ proxy_command: Optional[str] = None
241
+ for params, identity_path in self.ssh_proxies:
242
+ proxy_command = self._build_proxy_command(params, identity_path, proxy_command)
243
+ return proxy_command
244
+
245
+ def _build_proxy_command(
246
+ self,
247
+ params: SSHConnectionParams,
248
+ identity_path: PathLike,
249
+ prev_proxy_command: Optional[str],
250
+ ) -> Optional[str]:
251
+ command = [
252
+ self.ssh_exec_path,
253
+ "-i",
254
+ identity_path,
255
+ "-W",
256
+ "%h:%p",
257
+ "-o",
258
+ "StrictHostKeyChecking=no",
259
+ "-o",
260
+ "UserKnownHostsFile=/dev/null",
261
+ ]
262
+ if prev_proxy_command is not None:
263
+ command += ["-o", prev_proxy_command.replace("%", "%%")]
264
+ command += [
265
+ "-p",
266
+ str(params.port),
267
+ f"{params.username}@{params.hostname}",
268
+ ]
269
+ return "ProxyCommand=" + shlex.join(command)
270
+
254
271
  def _read_log_file(self) -> bytes:
255
272
  with open(self.log_path, "rb") as f:
256
273
  return f.read()
@@ -263,6 +280,16 @@ class SSHTunnel:
263
280
  except OSError as e:
264
281
  logger.debug("Failed to remove SSH tunnel log file %s: %s", self.log_path, e)
265
282
 
283
+ def _get_identity_path(self, identity: FilePathOrContent, tmp_filename: str) -> PathLike:
284
+ if isinstance(identity, FilePath):
285
+ return identity.path
286
+ identity_path = os.path.join(self.temp_dir.name, tmp_filename)
287
+ with open(
288
+ identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
289
+ ) as f:
290
+ f.write(identity.content)
291
+ return identity_path
292
+
266
293
 
267
294
  def ports_to_forwarded_sockets(
268
295
  ports: Dict[int, int], bind_local: str = "localhost"
@@ -76,6 +76,8 @@ async def register_replica(
76
76
  ssh_destination=body.ssh_host,
77
77
  ssh_port=body.ssh_port,
78
78
  ssh_proxy=body.ssh_proxy,
79
+ ssh_head_proxy=body.ssh_head_proxy,
80
+ ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
79
81
  repo=repo,
80
82
  nginx=nginx,
81
83
  service_conn_pool=service_conn_pool,
@@ -50,6 +50,8 @@ class RegisterReplicaRequest(BaseModel):
50
50
  ssh_host: str
51
51
  ssh_port: int
52
52
  ssh_proxy: Optional[SSHConnectionParams]
53
+ ssh_head_proxy: Optional[SSHConnectionParams]
54
+ ssh_head_proxy_private_key: Optional[str]
53
55
 
54
56
 
55
57
  class RegisterEntrypointRequest(BaseModel):
@@ -123,6 +123,8 @@ async def register_replica(
123
123
  ssh_destination: str,
124
124
  ssh_port: int,
125
125
  ssh_proxy: Optional[SSHConnectionParams],
126
+ ssh_head_proxy: Optional[SSHConnectionParams],
127
+ ssh_head_proxy_private_key: Optional[str],
126
128
  repo: GatewayProxyRepo,
127
129
  nginx: Nginx,
128
130
  service_conn_pool: ServiceConnectionPool,
@@ -133,6 +135,8 @@ async def register_replica(
133
135
  ssh_destination=ssh_destination,
134
136
  ssh_port=ssh_port,
135
137
  ssh_proxy=ssh_proxy,
138
+ ssh_head_proxy=ssh_head_proxy,
139
+ ssh_head_proxy_private_key=ssh_head_proxy_private_key,
136
140
  )
137
141
 
138
142
  async with lock:
@@ -23,6 +23,9 @@ class Replica(ImmutableModel):
23
23
  ssh_destination: str
24
24
  ssh_port: int
25
25
  ssh_proxy: Optional[SSHConnectionParams]
26
+ # Optional outer proxy, a head node/bastion
27
+ ssh_head_proxy: Optional[SSHConnectionParams] = None
28
+ ssh_head_proxy_private_key: Optional[str] = None
26
29
 
27
30
 
28
31
  class Service(ImmutableModel):
@@ -18,6 +18,7 @@ from dstack._internal.core.services.ssh.tunnel import (
18
18
  from dstack._internal.proxy.lib.errors import UnexpectedProxyError
19
19
  from dstack._internal.proxy.lib.models import Project, Replica, Service
20
20
  from dstack._internal.proxy.lib.repo import BaseProxyRepo
21
+ from dstack._internal.utils.common import get_or_error
21
22
  from dstack._internal.utils.logging import get_logger
22
23
  from dstack._internal.utils.path import FileContent
23
24
 
@@ -45,10 +46,16 @@ class ServiceConnection:
45
46
  os.chmod(self._temp_dir.name, 0o755)
46
47
  options["StreamLocalBindMask"] = "0111"
47
48
  self._app_socket_path = (Path(self._temp_dir.name) / "replica.sock").absolute()
49
+ ssh_proxies = []
50
+ if replica.ssh_head_proxy is not None:
51
+ ssh_head_proxy_private_key = get_or_error(replica.ssh_head_proxy_private_key)
52
+ ssh_proxies.append((replica.ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
53
+ if replica.ssh_proxy is not None:
54
+ ssh_proxies.append((replica.ssh_proxy, None))
48
55
  self._tunnel = SSHTunnel(
49
56
  destination=replica.ssh_destination,
50
57
  port=replica.ssh_port,
51
- ssh_proxy=replica.ssh_proxy,
58
+ ssh_proxies=ssh_proxies,
52
59
  identity=FileContent(project.ssh_private_key),
53
60
  forwarded_sockets=[
54
61
  SocketPair(
@@ -42,12 +42,12 @@ from dstack._internal.core.models.backends.base import BackendType
42
42
  from dstack._internal.core.models.fleets import InstanceGroupPlacement
43
43
  from dstack._internal.core.models.instances import (
44
44
  InstanceAvailability,
45
- InstanceConfiguration,
46
45
  InstanceOfferWithAvailability,
47
46
  InstanceRuntime,
48
47
  InstanceStatus,
49
48
  InstanceType,
50
49
  RemoteConnectionInfo,
50
+ SSHKey,
51
51
  )
52
52
  from dstack._internal.core.models.placement import (
53
53
  PlacementGroup,
@@ -77,6 +77,7 @@ from dstack._internal.server.services.fleets import (
77
77
  get_create_instance_offers,
78
78
  )
79
79
  from dstack._internal.server.services.locking import get_locker
80
+ from dstack._internal.server.services.offers import is_divisible_into_blocks
80
81
  from dstack._internal.server.services.placement import (
81
82
  get_fleet_placement_groups,
82
83
  placement_group_model_to_placement_group,
@@ -86,6 +87,7 @@ from dstack._internal.server.services.pools import (
86
87
  get_instance_profile,
87
88
  get_instance_provisioning_data,
88
89
  get_instance_requirements,
90
+ get_instance_ssh_private_keys,
89
91
  )
90
92
  from dstack._internal.server.services.runner import client as runner_client
91
93
  from dstack._internal.server.services.runner.client import HealthStatus
@@ -133,7 +135,7 @@ async def _process_next_instance():
133
135
  ),
134
136
  InstanceModel.id.not_in(lockset),
135
137
  )
136
- .options(lazyload(InstanceModel.job))
138
+ .options(lazyload(InstanceModel.jobs))
137
139
  .order_by(InstanceModel.last_processed_at.asc())
138
140
  .limit(1)
139
141
  .with_for_update(skip_locked=True)
@@ -156,7 +158,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
156
158
  select(InstanceModel)
157
159
  .where(InstanceModel.id == instance.id)
158
160
  .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
159
- .options(joinedload(InstanceModel.job))
161
+ .options(joinedload(InstanceModel.jobs))
160
162
  .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
161
163
  .execution_options(populate_existing=True)
162
164
  )
@@ -164,7 +166,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
164
166
  if (
165
167
  instance.status == InstanceStatus.IDLE
166
168
  and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
167
- and instance.job_id is None
169
+ and not instance.jobs
168
170
  ):
169
171
  await _mark_terminating_if_idle_duration_expired(instance)
170
172
  if instance.status == InstanceStatus.PENDING:
@@ -232,11 +234,11 @@ async def _add_remote(instance: InstanceModel) -> None:
232
234
  remote_details = RemoteConnectionInfo.parse_raw(cast(str, instance.remote_connection_info))
233
235
  # Prepare connection key
234
236
  try:
235
- pkeys = [
236
- pkey_from_str(sk.private)
237
- for sk in remote_details.ssh_keys
238
- if sk.private is not None
239
- ]
237
+ pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys)
238
+ if remote_details.ssh_proxy_keys is not None:
239
+ ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys)
240
+ else:
241
+ ssh_proxy_pkeys = None
240
242
  except (ValueError, PasswordRequiredException):
241
243
  instance.status = InstanceStatus.TERMINATED
242
244
  instance.termination_reason = "Unsupported private SSH key type"
@@ -254,7 +256,9 @@ async def _add_remote(instance: InstanceModel) -> None:
254
256
  authorized_keys.append(instance.project.ssh_public_key.strip())
255
257
 
256
258
  try:
257
- future = run_async(_deploy_instance, remote_details, pkeys, authorized_keys)
259
+ future = run_async(
260
+ _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys
261
+ )
258
262
  deploy_timeout = 20 * 60 # 20 minutes
259
263
  result = await asyncio.wait_for(future, timeout=deploy_timeout)
260
264
  health, host_info = result
@@ -322,6 +326,26 @@ async def _add_remote(instance: InstanceModel) -> None:
322
326
  )
323
327
  return
324
328
 
329
+ divisible, blocks = is_divisible_into_blocks(
330
+ cpu_count=instance_type.resources.cpus,
331
+ gpu_count=len(instance_type.resources.gpus),
332
+ blocks="auto" if instance.total_blocks is None else instance.total_blocks,
333
+ )
334
+ if divisible:
335
+ instance.total_blocks = blocks
336
+ else:
337
+ instance.status = InstanceStatus.TERMINATED
338
+ instance.termination_reason = "Cannot split into blocks"
339
+ logger.warning(
340
+ "Failed to add instance %s: cannot split into blocks",
341
+ instance.name,
342
+ extra={
343
+ "instance_name": instance.name,
344
+ "instance_status": InstanceStatus.TERMINATED.value,
345
+ },
346
+ )
347
+ return
348
+
325
349
  region = instance.region
326
350
  jpd = JobProvisioningData(
327
351
  backend=BackendType.REMOTE,
@@ -336,7 +360,7 @@ async def _add_remote(instance: InstanceModel) -> None:
336
360
  ssh_port=remote_details.port,
337
361
  dockerized=True,
338
362
  backend_data=None,
339
- ssh_proxy=None,
363
+ ssh_proxy=remote_details.ssh_proxy,
340
364
  )
341
365
 
342
366
  instance.status = InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING
@@ -359,10 +383,16 @@ async def _add_remote(instance: InstanceModel) -> None:
359
383
  def _deploy_instance(
360
384
  remote_details: RemoteConnectionInfo,
361
385
  pkeys: List[PKey],
386
+ ssh_proxy_pkeys: Optional[list[PKey]],
362
387
  authorized_keys: List[str],
363
388
  ) -> Tuple[HealthStatus, Dict[str, Any]]:
364
389
  with get_paramiko_connection(
365
- remote_details.ssh_user, remote_details.host, remote_details.port, pkeys
390
+ remote_details.ssh_user,
391
+ remote_details.host,
392
+ remote_details.port,
393
+ pkeys,
394
+ remote_details.ssh_proxy,
395
+ ssh_proxy_pkeys,
366
396
  ) as client:
367
397
  logger.info(f"Connected to {remote_details.ssh_user} {remote_details.host}")
368
398
 
@@ -477,8 +507,9 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
477
507
  project=instance.project,
478
508
  profile=profile,
479
509
  requirements=requirements,
480
- exclude_not_available=True,
481
510
  fleet_model=instance.fleet,
511
+ blocks="auto" if instance.total_blocks is None else instance.total_blocks,
512
+ exclude_not_available=True,
482
513
  )
483
514
 
484
515
  if not offers and should_retry:
@@ -496,11 +527,10 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
496
527
  session=session, fleet_id=instance.fleet_id
497
528
  )
498
529
 
499
- instance_configuration = _patch_instance_configuration(instance)
500
-
501
530
  for backend, instance_offer in offers:
502
531
  if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT:
503
532
  continue
533
+ instance_offer = _get_instance_offer_for_instance(instance_offer, instance)
504
534
  if (
505
535
  instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
506
536
  and instance.fleet
@@ -554,6 +584,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
554
584
  instance.instance_configuration = instance_configuration.json()
555
585
  instance.job_provisioning_data = job_provisioning_data.json()
556
586
  instance.offer = instance_offer.json()
587
+ instance.total_blocks = instance_offer.total_blocks
557
588
  instance.started_at = get_current_datetime()
558
589
  instance.last_retry_at = get_current_datetime()
559
590
 
@@ -585,8 +616,8 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
585
616
  async def _check_instance(instance: InstanceModel) -> None:
586
617
  if (
587
618
  instance.status == InstanceStatus.BUSY
588
- and instance.job is not None
589
- and instance.job.status.is_finished()
619
+ and instance.jobs
620
+ and all(job.status.is_finished() for job in instance.jobs)
590
621
  ):
591
622
  # A busy instance could have no active jobs due to this bug: https://github.com/dstackai/dstack/issues/2068
592
623
  instance.status = InstanceStatus.TERMINATING
@@ -617,18 +648,14 @@ async def _check_instance(instance: InstanceModel) -> None:
617
648
  instance.status = InstanceStatus.BUSY
618
649
  return
619
650
 
620
- ssh_private_key = instance.project.ssh_private_key
621
- # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem
622
- # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716
623
- if instance.remote_connection_info is not None:
624
- remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
625
- instance.remote_connection_info
626
- )
627
- ssh_private_key = remote_conn_info.ssh_keys[0].private
651
+ ssh_private_keys = get_instance_ssh_private_keys(instance)
628
652
 
629
653
  # May return False if fails to establish ssh connection
630
654
  health_status_response = await run_async(
631
- _instance_healthcheck, ssh_private_key, job_provisioning_data, None
655
+ _instance_healthcheck,
656
+ ssh_private_keys,
657
+ job_provisioning_data,
658
+ None,
632
659
  )
633
660
  if isinstance(health_status_response, bool) or health_status_response is None:
634
661
  health_status = HealthStatus(healthy=False, reason="SSH or tunnel error")
@@ -648,9 +675,7 @@ async def _check_instance(instance: InstanceModel) -> None:
648
675
  instance.unreachable = False
649
676
 
650
677
  if instance.status == InstanceStatus.PROVISIONING:
651
- instance.status = (
652
- InstanceStatus.IDLE if instance.job_id is None else InstanceStatus.BUSY
653
- )
678
+ instance.status = InstanceStatus.IDLE if not instance.jobs else InstanceStatus.BUSY
654
679
  logger.info(
655
680
  "Instance %s has switched to %s status",
656
681
  instance.name,
@@ -869,21 +894,30 @@ def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
869
894
  )
870
895
 
871
896
 
872
- def _patch_instance_configuration(instance: InstanceModel) -> InstanceConfiguration:
873
- instance_configuration = get_instance_configuration(instance)
897
+ def _get_instance_offer_for_instance(
898
+ instance_offer: InstanceOfferWithAvailability,
899
+ instance: InstanceModel,
900
+ ) -> InstanceOfferWithAvailability:
874
901
  if instance.fleet is None:
875
- return instance_configuration
902
+ return instance_offer
876
903
 
877
904
  fleet = fleet_model_to_fleet(instance.fleet)
878
905
  master_instance = instance.fleet.instances[0]
879
906
  master_job_provisioning_data = get_instance_provisioning_data(master_instance)
907
+ instance_offer = instance_offer.copy()
880
908
  if (
881
909
  fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER
882
910
  and master_job_provisioning_data is not None
911
+ and master_job_provisioning_data.availability_zone is not None
883
912
  ):
884
- instance_configuration.availability_zone = master_job_provisioning_data.availability_zone
885
-
886
- return instance_configuration
913
+ if instance_offer.availability_zones is None:
914
+ instance_offer.availability_zones = [master_job_provisioning_data.availability_zone]
915
+ instance_offer.availability_zones = [
916
+ z
917
+ for z in instance_offer.availability_zones
918
+ if z == master_job_provisioning_data.availability_zone
919
+ ]
920
+ return instance_offer
887
921
 
888
922
 
889
923
  def _create_placement_group_if_does_not_exist(
@@ -942,3 +976,7 @@ def _get_instance_timeout_interval(
942
976
  if backend_type == BackendType.VULTR and instance_type_name.startswith("vbm"):
943
977
  return timedelta(seconds=3300)
944
978
  return timedelta(seconds=600)
979
+
980
+
981
+ def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]:
982
+ return [pkey_from_str(sk.private) for sk in ssh_keys if sk.private is not None]