dstack 0.18.40__py3-none-any.whl → 0.18.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. dstack/_internal/cli/commands/apply.py +8 -5
  2. dstack/_internal/cli/services/configurators/base.py +4 -2
  3. dstack/_internal/cli/services/configurators/fleet.py +21 -9
  4. dstack/_internal/cli/services/configurators/gateway.py +15 -0
  5. dstack/_internal/cli/services/configurators/run.py +6 -5
  6. dstack/_internal/cli/services/configurators/volume.py +15 -0
  7. dstack/_internal/cli/services/repos.py +3 -3
  8. dstack/_internal/cli/utils/fleet.py +44 -33
  9. dstack/_internal/cli/utils/run.py +27 -7
  10. dstack/_internal/cli/utils/volume.py +21 -9
  11. dstack/_internal/core/backends/aws/compute.py +92 -52
  12. dstack/_internal/core/backends/aws/resources.py +22 -12
  13. dstack/_internal/core/backends/azure/compute.py +2 -0
  14. dstack/_internal/core/backends/base/compute.py +20 -2
  15. dstack/_internal/core/backends/gcp/compute.py +30 -23
  16. dstack/_internal/core/backends/gcp/resources.py +0 -15
  17. dstack/_internal/core/backends/oci/compute.py +10 -5
  18. dstack/_internal/core/backends/oci/resources.py +23 -26
  19. dstack/_internal/core/backends/remote/provisioning.py +65 -27
  20. dstack/_internal/core/backends/runpod/compute.py +1 -0
  21. dstack/_internal/core/models/backends/azure.py +3 -1
  22. dstack/_internal/core/models/configurations.py +24 -1
  23. dstack/_internal/core/models/fleets.py +46 -0
  24. dstack/_internal/core/models/instances.py +5 -1
  25. dstack/_internal/core/models/pools.py +4 -1
  26. dstack/_internal/core/models/profiles.py +10 -4
  27. dstack/_internal/core/models/runs.py +20 -0
  28. dstack/_internal/core/models/volumes.py +3 -0
  29. dstack/_internal/core/services/ssh/attach.py +92 -53
  30. dstack/_internal/core/services/ssh/tunnel.py +58 -31
  31. dstack/_internal/proxy/gateway/routers/registry.py +2 -0
  32. dstack/_internal/proxy/gateway/schemas/registry.py +2 -0
  33. dstack/_internal/proxy/gateway/services/registry.py +4 -0
  34. dstack/_internal/proxy/lib/models.py +3 -0
  35. dstack/_internal/proxy/lib/services/service_connection.py +8 -1
  36. dstack/_internal/server/background/tasks/process_instances.py +72 -33
  37. dstack/_internal/server/background/tasks/process_metrics.py +9 -9
  38. dstack/_internal/server/background/tasks/process_running_jobs.py +73 -26
  39. dstack/_internal/server/background/tasks/process_runs.py +2 -12
  40. dstack/_internal/server/background/tasks/process_submitted_jobs.py +109 -42
  41. dstack/_internal/server/background/tasks/process_terminating_jobs.py +1 -1
  42. dstack/_internal/server/migrations/versions/1338b788b612_reverse_job_instance_relationship.py +71 -0
  43. dstack/_internal/server/migrations/versions/1e76fb0dde87_add_jobmodel_inactivity_secs.py +32 -0
  44. dstack/_internal/server/migrations/versions/51d45659d574_add_instancemodel_blocks_fields.py +43 -0
  45. dstack/_internal/server/migrations/versions/63c3f19cb184_add_jobterminationreason_inactivity_.py +83 -0
  46. dstack/_internal/server/models.py +10 -4
  47. dstack/_internal/server/routers/runs.py +1 -0
  48. dstack/_internal/server/schemas/runner.py +1 -0
  49. dstack/_internal/server/services/backends/configurators/azure.py +34 -8
  50. dstack/_internal/server/services/config.py +9 -0
  51. dstack/_internal/server/services/fleets.py +27 -2
  52. dstack/_internal/server/services/gateways/client.py +9 -1
  53. dstack/_internal/server/services/jobs/__init__.py +215 -43
  54. dstack/_internal/server/services/jobs/configurators/base.py +47 -2
  55. dstack/_internal/server/services/offers.py +91 -5
  56. dstack/_internal/server/services/pools.py +95 -11
  57. dstack/_internal/server/services/proxy/repo.py +17 -3
  58. dstack/_internal/server/services/runner/client.py +1 -1
  59. dstack/_internal/server/services/runner/ssh.py +33 -5
  60. dstack/_internal/server/services/runs.py +48 -179
  61. dstack/_internal/server/services/services/__init__.py +9 -1
  62. dstack/_internal/server/statics/index.html +1 -1
  63. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js → main-2ac66bfcbd2e39830b88.js} +30 -31
  64. dstack/_internal/server/statics/{main-11ec5e4a00ea6ec833e3.js.map → main-2ac66bfcbd2e39830b88.js.map} +1 -1
  65. dstack/_internal/server/statics/{main-fc56d1f4af8e57522a1c.css → main-ad5150a441de98cd8987.css} +1 -1
  66. dstack/_internal/server/testing/common.py +117 -52
  67. dstack/_internal/utils/common.py +22 -8
  68. dstack/_internal/utils/env.py +14 -0
  69. dstack/_internal/utils/ssh.py +1 -1
  70. dstack/api/server/_fleets.py +25 -1
  71. dstack/api/server/_runs.py +23 -2
  72. dstack/api/server/_volumes.py +12 -1
  73. dstack/version.py +1 -1
  74. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/METADATA +1 -1
  75. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/RECORD +98 -89
  76. tests/_internal/cli/services/configurators/test_profile.py +3 -3
  77. tests/_internal/core/services/ssh/test_tunnel.py +56 -4
  78. tests/_internal/proxy/gateway/routers/test_registry.py +30 -7
  79. tests/_internal/server/background/tasks/test_process_instances.py +138 -20
  80. tests/_internal/server/background/tasks/test_process_metrics.py +12 -0
  81. tests/_internal/server/background/tasks/test_process_running_jobs.py +192 -0
  82. tests/_internal/server/background/tasks/test_process_runs.py +27 -3
  83. tests/_internal/server/background/tasks/test_process_submitted_jobs.py +48 -3
  84. tests/_internal/server/background/tasks/test_process_terminating_jobs.py +126 -13
  85. tests/_internal/server/routers/test_fleets.py +15 -2
  86. tests/_internal/server/routers/test_pools.py +6 -0
  87. tests/_internal/server/routers/test_runs.py +27 -0
  88. tests/_internal/server/services/jobs/__init__.py +0 -0
  89. tests/_internal/server/services/jobs/configurators/__init__.py +0 -0
  90. tests/_internal/server/services/jobs/configurators/test_base.py +72 -0
  91. tests/_internal/server/services/test_pools.py +4 -0
  92. tests/_internal/server/services/test_runs.py +5 -41
  93. tests/_internal/utils/test_common.py +21 -0
  94. tests/_internal/utils/test_env.py +38 -0
  95. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/LICENSE.md +0 -0
  96. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/WHEEL +0 -0
  97. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/entry_points.txt +0 -0
  98. {dstack-0.18.40.dist-info → dstack-0.18.41.dist-info}/top_level.txt +0 -0
@@ -69,13 +69,16 @@ class SSHTunnel:
69
69
  options: Dict[str, str] = SSH_DEFAULT_OPTIONS,
70
70
  ssh_config_path: Union[PathLike, Literal["none"]] = "none",
71
71
  port: Optional[int] = None,
72
- ssh_proxy: Optional[SSHConnectionParams] = None,
72
+ ssh_proxies: Iterable[tuple[SSHConnectionParams, Optional[FilePathOrContent]]] = (),
73
73
  ):
74
74
  """
75
75
  :param forwarded_sockets: Connections to the specified local sockets will be
76
76
  forwarded to their corresponding remote sockets
77
77
  :param reverse_forwarded_sockets: Connections to the specified remote sockets
78
78
  will be forwarded to their corresponding local sockets
79
+ :param ssh_proxies: pairs of SSH connections params and optional identities,
80
+ in order from outer to inner. If an identity is `None`, the `identity` param
81
+ is used instead.
79
82
  """
80
83
  self.destination = destination
81
84
  self.forwarded_sockets = list(forwarded_sockets)
@@ -83,21 +86,21 @@ class SSHTunnel:
83
86
  self.options = options
84
87
  self.port = port
85
88
  self.ssh_config_path = normalize_path(ssh_config_path)
86
- self.ssh_proxy = ssh_proxy
87
89
  temp_dir = tempfile.TemporaryDirectory()
88
90
  self.temp_dir = temp_dir
89
91
  if control_sock_path is None:
90
92
  control_sock_path = os.path.join(temp_dir.name, "control.sock")
91
93
  self.control_sock_path = normalize_path(control_sock_path)
92
- if isinstance(identity, FilePath):
93
- identity_path = identity.path
94
- else:
95
- identity_path = os.path.join(temp_dir.name, "identity")
96
- with open(
97
- identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
98
- ) as f:
99
- f.write(identity.content)
100
- self.identity_path = normalize_path(identity_path)
94
+ self.identity_path = normalize_path(self._get_identity_path(identity, "identity"))
95
+ self.ssh_proxies: list[tuple[SSHConnectionParams, PathLike]] = []
96
+ for proxy_index, (proxy_params, proxy_identity) in enumerate(ssh_proxies):
97
+ if proxy_identity is None:
98
+ proxy_identity_path = self.identity_path
99
+ else:
100
+ proxy_identity_path = self._get_identity_path(
101
+ proxy_identity, f"proxy_identity_{proxy_index}"
102
+ )
103
+ self.ssh_proxies.append((proxy_params, proxy_identity_path))
101
104
  self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log"))
102
105
  self.ssh_client_info = get_ssh_client_info()
103
106
  self.ssh_exec_path = str(self.ssh_client_info.path)
@@ -142,8 +145,8 @@ class SSHTunnel:
142
145
  command += ["-p", str(self.port)]
143
146
  for k, v in self.options.items():
144
147
  command += ["-o", f"{k}={v}"]
145
- if proxy_command := self.proxy_command():
146
- command += ["-o", "ProxyCommand=" + shlex.join(proxy_command)]
148
+ if proxy_command := self._get_proxy_command():
149
+ command += ["-o", proxy_command]
147
150
  for socket_pair in self.forwarded_sockets:
148
151
  command += ["-L", f"{socket_pair.local.render()}:{socket_pair.remote.render()}"]
149
152
  for socket_pair in self.reverse_forwarded_sockets:
@@ -160,24 +163,6 @@ class SSHTunnel:
160
163
  def exec_command(self) -> List[str]:
161
164
  return [self.ssh_exec_path, "-S", self.control_sock_path, self.destination]
162
165
 
163
- def proxy_command(self) -> Optional[List[str]]:
164
- if self.ssh_proxy is None:
165
- return None
166
- return [
167
- self.ssh_exec_path,
168
- "-i",
169
- self.identity_path,
170
- "-W",
171
- "%h:%p",
172
- "-o",
173
- "StrictHostKeyChecking=no",
174
- "-o",
175
- "UserKnownHostsFile=/dev/null",
176
- "-p",
177
- str(self.ssh_proxy.port),
178
- f"{self.ssh_proxy.username}@{self.ssh_proxy.hostname}",
179
- ]
180
-
181
166
  def open(self) -> None:
182
167
  # We cannot use `stderr=subprocess.PIPE` here since the forked process (daemon) does not
183
168
  # close standard streams if ProxyJump is used, therefore we will wait EOF from the pipe
@@ -251,6 +236,38 @@ class SSHTunnel:
251
236
  def __exit__(self, exc_type, exc_val, exc_tb):
252
237
  self.close()
253
238
 
239
+ def _get_proxy_command(self) -> Optional[str]:
240
+ proxy_command: Optional[str] = None
241
+ for params, identity_path in self.ssh_proxies:
242
+ proxy_command = self._build_proxy_command(params, identity_path, proxy_command)
243
+ return proxy_command
244
+
245
+ def _build_proxy_command(
246
+ self,
247
+ params: SSHConnectionParams,
248
+ identity_path: PathLike,
249
+ prev_proxy_command: Optional[str],
250
+ ) -> Optional[str]:
251
+ command = [
252
+ self.ssh_exec_path,
253
+ "-i",
254
+ identity_path,
255
+ "-W",
256
+ "%h:%p",
257
+ "-o",
258
+ "StrictHostKeyChecking=no",
259
+ "-o",
260
+ "UserKnownHostsFile=/dev/null",
261
+ ]
262
+ if prev_proxy_command is not None:
263
+ command += ["-o", prev_proxy_command.replace("%", "%%")]
264
+ command += [
265
+ "-p",
266
+ str(params.port),
267
+ f"{params.username}@{params.hostname}",
268
+ ]
269
+ return "ProxyCommand=" + shlex.join(command)
270
+
254
271
  def _read_log_file(self) -> bytes:
255
272
  with open(self.log_path, "rb") as f:
256
273
  return f.read()
@@ -263,6 +280,16 @@ class SSHTunnel:
263
280
  except OSError as e:
264
281
  logger.debug("Failed to remove SSH tunnel log file %s: %s", self.log_path, e)
265
282
 
283
+ def _get_identity_path(self, identity: FilePathOrContent, tmp_filename: str) -> PathLike:
284
+ if isinstance(identity, FilePath):
285
+ return identity.path
286
+ identity_path = os.path.join(self.temp_dir.name, tmp_filename)
287
+ with open(
288
+ identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
289
+ ) as f:
290
+ f.write(identity.content)
291
+ return identity_path
292
+
266
293
 
267
294
  def ports_to_forwarded_sockets(
268
295
  ports: Dict[int, int], bind_local: str = "localhost"
@@ -76,6 +76,8 @@ async def register_replica(
76
76
  ssh_destination=body.ssh_host,
77
77
  ssh_port=body.ssh_port,
78
78
  ssh_proxy=body.ssh_proxy,
79
+ ssh_head_proxy=body.ssh_head_proxy,
80
+ ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
79
81
  repo=repo,
80
82
  nginx=nginx,
81
83
  service_conn_pool=service_conn_pool,
@@ -50,6 +50,8 @@ class RegisterReplicaRequest(BaseModel):
50
50
  ssh_host: str
51
51
  ssh_port: int
52
52
  ssh_proxy: Optional[SSHConnectionParams]
53
+ ssh_head_proxy: Optional[SSHConnectionParams]
54
+ ssh_head_proxy_private_key: Optional[str]
53
55
 
54
56
 
55
57
  class RegisterEntrypointRequest(BaseModel):
@@ -123,6 +123,8 @@ async def register_replica(
123
123
  ssh_destination: str,
124
124
  ssh_port: int,
125
125
  ssh_proxy: Optional[SSHConnectionParams],
126
+ ssh_head_proxy: Optional[SSHConnectionParams],
127
+ ssh_head_proxy_private_key: Optional[str],
126
128
  repo: GatewayProxyRepo,
127
129
  nginx: Nginx,
128
130
  service_conn_pool: ServiceConnectionPool,
@@ -133,6 +135,8 @@ async def register_replica(
133
135
  ssh_destination=ssh_destination,
134
136
  ssh_port=ssh_port,
135
137
  ssh_proxy=ssh_proxy,
138
+ ssh_head_proxy=ssh_head_proxy,
139
+ ssh_head_proxy_private_key=ssh_head_proxy_private_key,
136
140
  )
137
141
 
138
142
  async with lock:
@@ -23,6 +23,9 @@ class Replica(ImmutableModel):
23
23
  ssh_destination: str
24
24
  ssh_port: int
25
25
  ssh_proxy: Optional[SSHConnectionParams]
26
+ # Optional outer proxy, a head node/bastion
27
+ ssh_head_proxy: Optional[SSHConnectionParams] = None
28
+ ssh_head_proxy_private_key: Optional[str] = None
26
29
 
27
30
 
28
31
  class Service(ImmutableModel):
@@ -18,6 +18,7 @@ from dstack._internal.core.services.ssh.tunnel import (
18
18
  from dstack._internal.proxy.lib.errors import UnexpectedProxyError
19
19
  from dstack._internal.proxy.lib.models import Project, Replica, Service
20
20
  from dstack._internal.proxy.lib.repo import BaseProxyRepo
21
+ from dstack._internal.utils.common import get_or_error
21
22
  from dstack._internal.utils.logging import get_logger
22
23
  from dstack._internal.utils.path import FileContent
23
24
 
@@ -45,10 +46,16 @@ class ServiceConnection:
45
46
  os.chmod(self._temp_dir.name, 0o755)
46
47
  options["StreamLocalBindMask"] = "0111"
47
48
  self._app_socket_path = (Path(self._temp_dir.name) / "replica.sock").absolute()
49
+ ssh_proxies = []
50
+ if replica.ssh_head_proxy is not None:
51
+ ssh_head_proxy_private_key = get_or_error(replica.ssh_head_proxy_private_key)
52
+ ssh_proxies.append((replica.ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
53
+ if replica.ssh_proxy is not None:
54
+ ssh_proxies.append((replica.ssh_proxy, None))
48
55
  self._tunnel = SSHTunnel(
49
56
  destination=replica.ssh_destination,
50
57
  port=replica.ssh_port,
51
- ssh_proxy=replica.ssh_proxy,
58
+ ssh_proxies=ssh_proxies,
52
59
  identity=FileContent(project.ssh_private_key),
53
60
  forwarded_sockets=[
54
61
  SocketPair(
@@ -42,12 +42,12 @@ from dstack._internal.core.models.backends.base import BackendType
42
42
  from dstack._internal.core.models.fleets import InstanceGroupPlacement
43
43
  from dstack._internal.core.models.instances import (
44
44
  InstanceAvailability,
45
- InstanceConfiguration,
46
45
  InstanceOfferWithAvailability,
47
46
  InstanceRuntime,
48
47
  InstanceStatus,
49
48
  InstanceType,
50
49
  RemoteConnectionInfo,
50
+ SSHKey,
51
51
  )
52
52
  from dstack._internal.core.models.placement import (
53
53
  PlacementGroup,
@@ -77,6 +77,7 @@ from dstack._internal.server.services.fleets import (
77
77
  get_create_instance_offers,
78
78
  )
79
79
  from dstack._internal.server.services.locking import get_locker
80
+ from dstack._internal.server.services.offers import is_divisible_into_blocks
80
81
  from dstack._internal.server.services.placement import (
81
82
  get_fleet_placement_groups,
82
83
  placement_group_model_to_placement_group,
@@ -86,6 +87,7 @@ from dstack._internal.server.services.pools import (
86
87
  get_instance_profile,
87
88
  get_instance_provisioning_data,
88
89
  get_instance_requirements,
90
+ get_instance_ssh_private_keys,
89
91
  )
90
92
  from dstack._internal.server.services.runner import client as runner_client
91
93
  from dstack._internal.server.services.runner.client import HealthStatus
@@ -133,7 +135,7 @@ async def _process_next_instance():
133
135
  ),
134
136
  InstanceModel.id.not_in(lockset),
135
137
  )
136
- .options(lazyload(InstanceModel.job))
138
+ .options(lazyload(InstanceModel.jobs))
137
139
  .order_by(InstanceModel.last_processed_at.asc())
138
140
  .limit(1)
139
141
  .with_for_update(skip_locked=True)
@@ -156,7 +158,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
156
158
  select(InstanceModel)
157
159
  .where(InstanceModel.id == instance.id)
158
160
  .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
159
- .options(joinedload(InstanceModel.job))
161
+ .options(joinedload(InstanceModel.jobs))
160
162
  .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
161
163
  .execution_options(populate_existing=True)
162
164
  )
@@ -164,7 +166,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
164
166
  if (
165
167
  instance.status == InstanceStatus.IDLE
166
168
  and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
167
- and instance.job_id is None
169
+ and not instance.jobs
168
170
  ):
169
171
  await _mark_terminating_if_idle_duration_expired(instance)
170
172
  if instance.status == InstanceStatus.PENDING:
@@ -232,11 +234,11 @@ async def _add_remote(instance: InstanceModel) -> None:
232
234
  remote_details = RemoteConnectionInfo.parse_raw(cast(str, instance.remote_connection_info))
233
235
  # Prepare connection key
234
236
  try:
235
- pkeys = [
236
- pkey_from_str(sk.private)
237
- for sk in remote_details.ssh_keys
238
- if sk.private is not None
239
- ]
237
+ pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys)
238
+ if remote_details.ssh_proxy_keys is not None:
239
+ ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys)
240
+ else:
241
+ ssh_proxy_pkeys = None
240
242
  except (ValueError, PasswordRequiredException):
241
243
  instance.status = InstanceStatus.TERMINATED
242
244
  instance.termination_reason = "Unsupported private SSH key type"
@@ -254,7 +256,9 @@ async def _add_remote(instance: InstanceModel) -> None:
254
256
  authorized_keys.append(instance.project.ssh_public_key.strip())
255
257
 
256
258
  try:
257
- future = run_async(_deploy_instance, remote_details, pkeys, authorized_keys)
259
+ future = run_async(
260
+ _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys
261
+ )
258
262
  deploy_timeout = 20 * 60 # 20 minutes
259
263
  result = await asyncio.wait_for(future, timeout=deploy_timeout)
260
264
  health, host_info = result
@@ -322,6 +326,26 @@ async def _add_remote(instance: InstanceModel) -> None:
322
326
  )
323
327
  return
324
328
 
329
+ divisible, blocks = is_divisible_into_blocks(
330
+ cpu_count=instance_type.resources.cpus,
331
+ gpu_count=len(instance_type.resources.gpus),
332
+ blocks="auto" if instance.total_blocks is None else instance.total_blocks,
333
+ )
334
+ if divisible:
335
+ instance.total_blocks = blocks
336
+ else:
337
+ instance.status = InstanceStatus.TERMINATED
338
+ instance.termination_reason = "Cannot split into blocks"
339
+ logger.warning(
340
+ "Failed to add instance %s: cannot split into blocks",
341
+ instance.name,
342
+ extra={
343
+ "instance_name": instance.name,
344
+ "instance_status": InstanceStatus.TERMINATED.value,
345
+ },
346
+ )
347
+ return
348
+
325
349
  region = instance.region
326
350
  jpd = JobProvisioningData(
327
351
  backend=BackendType.REMOTE,
@@ -336,7 +360,7 @@ async def _add_remote(instance: InstanceModel) -> None:
336
360
  ssh_port=remote_details.port,
337
361
  dockerized=True,
338
362
  backend_data=None,
339
- ssh_proxy=None,
363
+ ssh_proxy=remote_details.ssh_proxy,
340
364
  )
341
365
 
342
366
  instance.status = InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING
@@ -359,10 +383,16 @@ async def _add_remote(instance: InstanceModel) -> None:
359
383
  def _deploy_instance(
360
384
  remote_details: RemoteConnectionInfo,
361
385
  pkeys: List[PKey],
386
+ ssh_proxy_pkeys: Optional[list[PKey]],
362
387
  authorized_keys: List[str],
363
388
  ) -> Tuple[HealthStatus, Dict[str, Any]]:
364
389
  with get_paramiko_connection(
365
- remote_details.ssh_user, remote_details.host, remote_details.port, pkeys
390
+ remote_details.ssh_user,
391
+ remote_details.host,
392
+ remote_details.port,
393
+ pkeys,
394
+ remote_details.ssh_proxy,
395
+ ssh_proxy_pkeys,
366
396
  ) as client:
367
397
  logger.info(f"Connected to {remote_details.ssh_user} {remote_details.host}")
368
398
 
@@ -479,6 +509,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
479
509
  requirements=requirements,
480
510
  exclude_not_available=True,
481
511
  fleet_model=instance.fleet,
512
+ blocks="auto" if instance.total_blocks is None else instance.total_blocks,
482
513
  )
483
514
 
484
515
  if not offers and should_retry:
@@ -496,11 +527,10 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
496
527
  session=session, fleet_id=instance.fleet_id
497
528
  )
498
529
 
499
- instance_configuration = _patch_instance_configuration(instance)
500
-
501
530
  for backend, instance_offer in offers:
502
531
  if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT:
503
532
  continue
533
+ instance_offer = _get_instance_offer_for_instance(instance_offer, instance)
504
534
  if (
505
535
  instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
506
536
  and instance.fleet
@@ -554,6 +584,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
554
584
  instance.instance_configuration = instance_configuration.json()
555
585
  instance.job_provisioning_data = job_provisioning_data.json()
556
586
  instance.offer = instance_offer.json()
587
+ instance.total_blocks = instance_offer.total_blocks
557
588
  instance.started_at = get_current_datetime()
558
589
  instance.last_retry_at = get_current_datetime()
559
590
 
@@ -585,8 +616,8 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
585
616
  async def _check_instance(instance: InstanceModel) -> None:
586
617
  if (
587
618
  instance.status == InstanceStatus.BUSY
588
- and instance.job is not None
589
- and instance.job.status.is_finished()
619
+ and instance.jobs
620
+ and all(job.status.is_finished() for job in instance.jobs)
590
621
  ):
591
622
  # A busy instance could have no active jobs due to this bug: https://github.com/dstackai/dstack/issues/2068
592
623
  instance.status = InstanceStatus.TERMINATING
@@ -617,18 +648,14 @@ async def _check_instance(instance: InstanceModel) -> None:
617
648
  instance.status = InstanceStatus.BUSY
618
649
  return
619
650
 
620
- ssh_private_key = instance.project.ssh_private_key
621
- # TODO: Drop this logic and always use project key once it's safe to assume that most on-prem
622
- # fleets are (re)created after this change: https://github.com/dstackai/dstack/pull/1716
623
- if instance.remote_connection_info is not None:
624
- remote_conn_info: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
625
- instance.remote_connection_info
626
- )
627
- ssh_private_key = remote_conn_info.ssh_keys[0].private
651
+ ssh_private_keys = get_instance_ssh_private_keys(instance)
628
652
 
629
653
  # May return False if fails to establish ssh connection
630
654
  health_status_response = await run_async(
631
- _instance_healthcheck, ssh_private_key, job_provisioning_data, None
655
+ _instance_healthcheck,
656
+ ssh_private_keys,
657
+ job_provisioning_data,
658
+ None,
632
659
  )
633
660
  if isinstance(health_status_response, bool) or health_status_response is None:
634
661
  health_status = HealthStatus(healthy=False, reason="SSH or tunnel error")
@@ -648,9 +675,7 @@ async def _check_instance(instance: InstanceModel) -> None:
648
675
  instance.unreachable = False
649
676
 
650
677
  if instance.status == InstanceStatus.PROVISIONING:
651
- instance.status = (
652
- InstanceStatus.IDLE if instance.job_id is None else InstanceStatus.BUSY
653
- )
678
+ instance.status = InstanceStatus.IDLE if not instance.jobs else InstanceStatus.BUSY
654
679
  logger.info(
655
680
  "Instance %s has switched to %s status",
656
681
  instance.name,
@@ -869,21 +894,31 @@ def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
869
894
  )
870
895
 
871
896
 
872
- def _patch_instance_configuration(instance: InstanceModel) -> InstanceConfiguration:
873
- instance_configuration = get_instance_configuration(instance)
897
+ def _get_instance_offer_for_instance(
898
+ instance_offer: InstanceOfferWithAvailability,
899
+ instance: InstanceModel,
900
+ ) -> InstanceOfferWithAvailability:
874
901
  if instance.fleet is None:
875
- return instance_configuration
902
+ return instance_offer
876
903
 
877
904
  fleet = fleet_model_to_fleet(instance.fleet)
878
905
  master_instance = instance.fleet.instances[0]
879
906
  master_job_provisioning_data = get_instance_provisioning_data(master_instance)
907
+ instance_offer = instance_offer.copy()
880
908
  if (
881
909
  fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER
882
910
  and master_job_provisioning_data is not None
911
+ and master_job_provisioning_data.availability_zone is not None
883
912
  ):
884
- instance_configuration.availability_zone = master_job_provisioning_data.availability_zone
913
+ if instance_offer.availability_zones is None:
914
+ instance_offer.availability_zones = [master_job_provisioning_data.availability_zone]
915
+ instance_offer.availability_zones = [
916
+ z
917
+ for z in instance_offer.availability_zones
918
+ if z == master_job_provisioning_data.availability_zone
919
+ ]
885
920
 
886
- return instance_configuration
921
+ return instance_offer
887
922
 
888
923
 
889
924
  def _create_placement_group_if_does_not_exist(
@@ -942,3 +977,7 @@ def _get_instance_timeout_interval(
942
977
  if backend_type == BackendType.VULTR and instance_type_name.startswith("vbm"):
943
978
  return timedelta(seconds=3300)
944
979
  return timedelta(seconds=600)
980
+
981
+
982
+ def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]:
983
+ return [pkey_from_str(sk.private) for sk in ssh_keys if sk.private is not None]
@@ -3,18 +3,19 @@ import json
3
3
  from typing import Dict, List, Optional
4
4
 
5
5
  from sqlalchemy import delete, select
6
- from sqlalchemy.orm import selectinload
6
+ from sqlalchemy.orm import joinedload
7
7
 
8
8
  from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
9
9
  from dstack._internal.core.models.runs import JobStatus
10
10
  from dstack._internal.server import settings
11
11
  from dstack._internal.server.db import get_session_ctx
12
- from dstack._internal.server.models import JobMetricsPoint, JobModel
12
+ from dstack._internal.server.models import InstanceModel, JobMetricsPoint, JobModel
13
13
  from dstack._internal.server.schemas.runner import MetricsResponse
14
14
  from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
15
+ from dstack._internal.server.services.pools import get_instance_ssh_private_keys
15
16
  from dstack._internal.server.services.runner import client
16
17
  from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
17
- from dstack._internal.utils.common import batched, get_current_datetime, run_async
18
+ from dstack._internal.utils.common import batched, get_current_datetime, get_or_error, run_async
18
19
  from dstack._internal.utils.logging import get_logger
19
20
 
20
21
  logger = get_logger(__name__)
@@ -29,14 +30,12 @@ async def collect_metrics():
29
30
  async with get_session_ctx() as session:
30
31
  res = await session.execute(
31
32
  select(JobModel)
32
- .where(
33
- JobModel.status.in_([JobStatus.RUNNING]),
34
- )
35
- .options(selectinload(JobModel.project))
33
+ .where(JobModel.status.in_([JobStatus.RUNNING]))
34
+ .options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
36
35
  .order_by(JobModel.last_processed_at.asc())
37
36
  .limit(MAX_JOBS_FETCHED)
38
37
  )
39
- job_models = res.scalars().all()
38
+ job_models = res.unique().scalars().all()
40
39
 
41
40
  for batch in batched(job_models, BATCH_SIZE):
42
41
  await _collect_jobs_metrics(batch)
@@ -87,6 +86,7 @@ def _get_recently_collected_metric_cutoff() -> int:
87
86
 
88
87
 
89
88
  async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint]:
89
+ ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance))
90
90
  jpd = get_job_provisioning_data(job_model)
91
91
  jrd = get_job_runtime_data(job_model)
92
92
  if jpd is None:
@@ -94,7 +94,7 @@ async def _collect_job_metrics(job_model: JobModel) -> Optional[JobMetricsPoint]
94
94
  try:
95
95
  res = await run_async(
96
96
  _pull_runner_metrics,
97
- job_model.project.ssh_private_key,
97
+ ssh_private_keys,
98
98
  jpd,
99
99
  jrd,
100
100
  )