dstack 0.19.30rc1__py3-none-any.whl → 0.19.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (54)
  1. dstack/_internal/cli/commands/__init__.py +8 -0
  2. dstack/_internal/cli/commands/project.py +27 -20
  3. dstack/_internal/cli/commands/server.py +5 -0
  4. dstack/_internal/cli/services/configurators/fleet.py +20 -6
  5. dstack/_internal/cli/utils/gpu.py +2 -2
  6. dstack/_internal/core/backends/aws/compute.py +13 -5
  7. dstack/_internal/core/backends/aws/resources.py +11 -6
  8. dstack/_internal/core/backends/azure/compute.py +17 -6
  9. dstack/_internal/core/backends/base/compute.py +57 -9
  10. dstack/_internal/core/backends/base/offers.py +1 -0
  11. dstack/_internal/core/backends/cloudrift/compute.py +2 -0
  12. dstack/_internal/core/backends/cudo/compute.py +2 -0
  13. dstack/_internal/core/backends/datacrunch/compute.py +2 -0
  14. dstack/_internal/core/backends/digitalocean_base/compute.py +2 -0
  15. dstack/_internal/core/backends/features.py +5 -0
  16. dstack/_internal/core/backends/gcp/compute.py +87 -38
  17. dstack/_internal/core/backends/gcp/configurator.py +1 -1
  18. dstack/_internal/core/backends/gcp/models.py +14 -1
  19. dstack/_internal/core/backends/gcp/resources.py +35 -12
  20. dstack/_internal/core/backends/hotaisle/compute.py +22 -0
  21. dstack/_internal/core/backends/kubernetes/compute.py +531 -215
  22. dstack/_internal/core/backends/kubernetes/models.py +13 -16
  23. dstack/_internal/core/backends/kubernetes/utils.py +145 -8
  24. dstack/_internal/core/backends/lambdalabs/compute.py +2 -0
  25. dstack/_internal/core/backends/local/compute.py +2 -0
  26. dstack/_internal/core/backends/nebius/compute.py +17 -0
  27. dstack/_internal/core/backends/nebius/configurator.py +15 -0
  28. dstack/_internal/core/backends/nebius/models.py +57 -5
  29. dstack/_internal/core/backends/nebius/resources.py +45 -2
  30. dstack/_internal/core/backends/oci/compute.py +7 -1
  31. dstack/_internal/core/backends/oci/resources.py +8 -3
  32. dstack/_internal/core/backends/template/compute.py.jinja +2 -0
  33. dstack/_internal/core/backends/tensordock/compute.py +2 -0
  34. dstack/_internal/core/backends/vultr/compute.py +2 -0
  35. dstack/_internal/core/compatibility/runs.py +8 -0
  36. dstack/_internal/core/consts.py +2 -0
  37. dstack/_internal/core/models/profiles.py +11 -4
  38. dstack/_internal/core/services/repos.py +101 -11
  39. dstack/_internal/server/background/tasks/common.py +2 -0
  40. dstack/_internal/server/background/tasks/process_fleets.py +75 -17
  41. dstack/_internal/server/background/tasks/process_instances.py +3 -5
  42. dstack/_internal/server/background/tasks/process_running_jobs.py +1 -1
  43. dstack/_internal/server/background/tasks/process_runs.py +27 -23
  44. dstack/_internal/server/background/tasks/process_submitted_jobs.py +107 -54
  45. dstack/_internal/server/services/offers.py +7 -1
  46. dstack/_internal/server/testing/common.py +2 -0
  47. dstack/_internal/server/utils/provisioning.py +3 -10
  48. dstack/_internal/utils/ssh.py +22 -2
  49. dstack/version.py +2 -2
  50. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/METADATA +20 -18
  51. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/RECORD +54 -54
  52. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/WHEEL +0 -0
  53. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/entry_points.txt +0 -0
  54. {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/licenses/LICENSE.md +0 -0
dstack/_internal/core/models/profiles.py

@@ -80,14 +80,21 @@ def parse_stop_duration(
 def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[Literal["off"], int]]:
     if v == "off" or v is False:
         return "off"
-    if v is True:
+    if v is True or v is None:
         return None
-    return parse_duration(v)
+    duration = parse_duration(v)
+    if duration < 0:
+        raise ValueError("Duration cannot be negative")
+    return duration


-def parse_idle_duration(v: Optional[Union[int, str]]) -> Optional[int]:
-    if v == "off" or v == -1:
+def parse_idle_duration(v: Optional[Union[int, str, bool]]) -> Optional[int]:
+    # Differs from `parse_off_duration` to accept negative durations as `off`
+    # for backward compatibility.
+    if v == "off" or v is False or v == -1:
         return -1
+    if v is True:
+        return None
     return parse_duration(v)
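For reference, a minimal standalone sketch of the updated `parse_off_duration` behavior. The `parse_duration` stand-in below accepts plain seconds only, while dstack's real helper also accepts duration strings; the asserts only restate the mapping visible in the diff above.

from typing import Literal, Optional, Union

def parse_duration(v: Union[int, str]) -> int:
    # Simplified stand-in for dstack's parse_duration: plain seconds only.
    return int(v)

def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[Literal["off"], int]]:
    if v == "off" or v is False:
        return "off"
    if v is True or v is None:
        return None
    duration = parse_duration(v)
    if duration < 0:
        # Negative durations are rejected instead of being passed through.
        raise ValueError("Duration cannot be negative")
    return duration

# "off"/False map to "off"; True/None map to None; anything else must parse
# to a non-negative duration.
assert parse_off_duration("off") == "off"
assert parse_off_duration(None) is None
assert parse_off_duration(300) == 300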
 
dstack/_internal/core/services/repos.py

@@ -36,24 +36,59 @@ def get_repo_creds_and_default_branch(

     # no auth
     with suppress(InvalidRepoCredentialsError):
-        return _get_repo_creds_and_default_branch_https(url)
+        creds, default_branch = _get_repo_creds_and_default_branch_https(url)
+        logger.debug(
+            "Git repo %s is public. Using no auth. Default branch: %s", repo_url, default_branch
+        )
+        return creds, default_branch

     # ssh key provided by the user or pulled from the server
     if identity_file is not None or private_key is not None:
         if identity_file is not None:
             private_key = _read_private_key(identity_file)
-            return _get_repo_creds_and_default_branch_ssh(url, identity_file, private_key)
+            creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                url, identity_file, private_key
+            )
+            logger.debug(
+                "Git repo %s is private. Using identity file: %s. Default branch: %s",
+                repo_url,
+                identity_file,
+                default_branch,
+            )
+            return creds, default_branch
         elif private_key is not None:
             with NamedTemporaryFile("w+", 0o600) as f:
                 f.write(private_key)
                 f.flush()
-                return _get_repo_creds_and_default_branch_ssh(url, f.name, private_key)
+                creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                    url, f.name, private_key
+                )
+                masked_key = "***" + private_key[-10:] if len(private_key) > 10 else "***MASKED***"
+                logger.debug(
+                    "Git repo %s is private. Using private key: %s. Default branch: %s",
+                    repo_url,
+                    masked_key,
+                    default_branch,
+                )
+                return creds, default_branch
         else:
             assert False, "should not reach here"

     # oauth token provided by the user or pulled from the server
     if oauth_token is not None:
-        return _get_repo_creds_and_default_branch_https(url, oauth_token)
+        creds, default_branch = _get_repo_creds_and_default_branch_https(url, oauth_token)
+        masked_token = (
+            len(oauth_token[:-4]) * "*" + oauth_token[-4:]
+            if len(oauth_token) > 4
+            else "***MASKED***"
+        )
+        logger.debug(
+            "Git repo %s is private. Using provided OAuth token: %s. Default branch: %s",
+            repo_url,
+            masked_token,
+            default_branch,
+        )
+        return creds, default_branch

     # key from ssh config
     identities = get_host_config(url.original_host).get("identityfile")
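For reference, the token-masking expression used in the debug logging above, pulled out into a small hypothetical helper (illustration only, not part of dstack's API):

def mask_token(token: str) -> str:
    # Keep only the last four characters; fully mask very short tokens so that
    # nothing useful can be recovered from debug logs.
    if len(token) > 4:
        return len(token[:-4]) * "*" + token[-4:]
    return "***MASKED***"

assert mask_token("ghp_abcdef123456") == "*" * 12 + "3456"
assert mask_token("abc") == "***MASKED***"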
@@ -61,7 +96,16 @@ def get_repo_creds_and_default_branch(
         _identity_file = identities[0]
         with suppress(InvalidRepoCredentialsError):
             _private_key = _read_private_key(_identity_file)
-            return _get_repo_creds_and_default_branch_ssh(url, _identity_file, _private_key)
+            creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                url, _identity_file, _private_key
+            )
+            logger.debug(
+                "Git repo %s is private. Using SSH config identity file: %s. Default branch: %s",
+                repo_url,
+                _identity_file,
+                default_branch,
+            )
+            return creds, default_branch

     # token from gh config
     if os.path.exists(gh_config_path):
@@ -70,13 +114,35 @@ def get_repo_creds_and_default_branch(
         _oauth_token = gh_hosts.get(url.host, {}).get("oauth_token")
         if _oauth_token is not None:
             with suppress(InvalidRepoCredentialsError):
-                return _get_repo_creds_and_default_branch_https(url, _oauth_token)
+                creds, default_branch = _get_repo_creds_and_default_branch_https(url, _oauth_token)
+                masked_token = (
+                    len(_oauth_token[:-4]) * "*" + _oauth_token[-4:]
+                    if len(_oauth_token) > 4
+                    else "***MASKED***"
+                )
+                logger.debug(
+                    "Git repo %s is private. Using GitHub config token: %s from %s. Default branch: %s",
+                    repo_url,
+                    masked_token,
+                    gh_config_path,
+                    default_branch,
+                )
+                return creds, default_branch

     # default user key
     if os.path.exists(default_ssh_key):
         with suppress(InvalidRepoCredentialsError):
             _private_key = _read_private_key(default_ssh_key)
-            return _get_repo_creds_and_default_branch_ssh(url, default_ssh_key, _private_key)
+            creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                url, default_ssh_key, _private_key
+            )
+            logger.debug(
+                "Git repo %s is private. Using default identity file: %s. Default branch: %s",
+                repo_url,
+                default_ssh_key,
+                default_branch,
+            )
+            return creds, default_branch

     raise InvalidRepoCredentialsError(
         "No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
@@ -87,8 +153,9 @@ def _get_repo_creds_and_default_branch_ssh(
     url: GitRepoURL, identity_file: PathLike, private_key: str
 ) -> tuple[RemoteRepoCreds, Optional[str]]:
     _url = url.as_ssh()
+    env = _make_git_env_for_creds_check(identity_file=identity_file)
     try:
-        default_branch = _get_repo_default_branch(_url, make_git_env(identity_file=identity_file))
+        default_branch = _get_repo_default_branch(_url, env)
     except GitCommandError as e:
         message = f"Cannot access `{_url}` using the `{identity_file}` private SSH key"
         raise InvalidRepoCredentialsError(message) from e
@@ -104,8 +171,9 @@ def _get_repo_creds_and_default_branch_https(
     url: GitRepoURL, oauth_token: Optional[str] = None
 ) -> tuple[RemoteRepoCreds, Optional[str]]:
     _url = url.as_https()
+    env = _make_git_env_for_creds_check()
     try:
-        default_branch = _get_repo_default_branch(url.as_https(oauth_token), make_git_env())
+        default_branch = _get_repo_default_branch(url.as_https(oauth_token), env)
     except GitCommandError as e:
         message = f"Cannot access `{_url}`"
         if oauth_token is not None:
@@ -120,10 +188,32 @@ def _get_repo_creds_and_default_branch_https(
     return creds, default_branch


+def _make_git_env_for_creds_check(identity_file: Optional[PathLike] = None) -> dict[str, str]:
+    # Our goal is to check if _provided_ creds (if any) are correct, so we need to be sure that
+    # only the provided creds are used, without falling back to any additional mechanisms.
+    # To do this, we:
+    # 1. Disable all configs to ignore any stored creds
+    # 2. Disable askpass to avoid asking for creds interactively or fetching stored creds from
+    #    a non-interactive askpass helper (for example, VS Code sets GIT_ASKPASS to its own helper,
+    #    which silently provides creds to Git).
+    return make_git_env(disable_config=True, disable_askpass=True, identity_file=identity_file)
+
+
 def _get_repo_default_branch(url: str, env: dict[str, str]) -> Optional[str]:
+    # Git shipped by Apple with XCode is patched to support an additional config scope
+    # above "system" called "xcode". There is no option in `git config list` to show this config,
+    # but you can list the merged config (`git config list` without options) and then exclude
+    # all settings listed in `git config list --{system,global,local,worktree}`.
+    # As of time of writing, there are only two settings in the "xcode" config, one of which breaks
+    # our "is repo public?" check, namely "credential.helper=osxkeychain".
+    # As there is no way to disable "xcode" config (no env variable, no CLI option, etc.),
+    # the only way to disable credential helper is to override this specific setting with an empty
+    # string via command line argument: `git -c credential.helper= COMMAND [ARGS ...]`.
+    # See: https://github.com/git/git/commit/3d4355712b9fe77a96ad4ad877d92dc7ff6e0874
+    # See: https://gist.github.com/ChrisTollefson/ab9c0a5d1dd4dd615217345c6936a307
+    _git = git.cmd.Git()(c="credential.helper=")
     # output example: "ref: refs/heads/dev\tHEAD\n545344f77c0df78367085952a97fc3a058eb4c65\tHEAD"
-    # Disable credential helpers to exclude any default credentials from being used
-    output: str = git.cmd.Git()(c="credential.helper=").ls_remote("--symref", url, "HEAD", env=env)
+    output: str = _git.ls_remote("--symref", url, "HEAD", env=env)
     for line in output.splitlines():
         # line format: `<oid> TAB <ref> LF`
         oid, _, ref = line.partition("\t")
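Taken together, the credentials check reduces to one `git ls-remote --symref <url> HEAD` call with credential helpers and askpass disabled. Below is a rough sketch of the equivalent lookup with GitPython; `env` is assumed to come from a helper like `_make_git_env_for_creds_check`, and the branch extraction from the `ref:` line is an assumption based on the loop shown above (the remainder of that loop is not included in this diff).

from typing import Optional
import git  # GitPython

def get_default_branch(url: str, env: dict[str, str]) -> Optional[str]:
    # Equivalent to: git -c credential.helper= ls-remote --symref <url> HEAD
    # The -c override also defeats the "xcode" config scope described above.
    _git = git.cmd.Git()(c="credential.helper=")
    output: str = _git.ls_remote("--symref", url, "HEAD", env=env)
    # Output example: "ref: refs/heads/dev\tHEAD\n<oid>\tHEAD"
    for line in output.splitlines():
        oid, _, ref = line.partition("\t")
        if ref == "HEAD" and oid.startswith("ref: refs/heads/"):
            return oid[len("ref: refs/heads/"):]
    return None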
dstack/_internal/server/background/tasks/common.py

@@ -19,4 +19,6 @@ def get_provisioning_timeout(backend_type: BackendType, instance_type_name: str)
         return timedelta(minutes=20)
     if backend_type == BackendType.VULTR and instance_type_name.startswith("vbm"):
         return timedelta(minutes=55)
+    if backend_type == BackendType.GCP and instance_type_name == "a4-highgpu-8g":
+        return timedelta(minutes=16)
     return timedelta(minutes=10)
dstack/_internal/server/background/tasks/process_fleets.py

@@ -1,10 +1,11 @@
+from collections import defaultdict
 from datetime import timedelta
 from typing import List
 from uuid import UUID

 from sqlalchemy import select, update
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload, load_only
+from sqlalchemy.orm import joinedload, load_only, selectinload

 from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
 from dstack._internal.core.models.instances import InstanceStatus
@@ -37,30 +38,68 @@ MIN_PROCESSING_INTERVAL = timedelta(seconds=30)

 @sentry_utils.instrument_background_task
 async def process_fleets():
-    lock, lockset = get_locker(get_db().dialect_name).get_lockset(FleetModel.__tablename__)
+    fleet_lock, fleet_lockset = get_locker(get_db().dialect_name).get_lockset(
+        FleetModel.__tablename__
+    )
+    instance_lock, instance_lockset = get_locker(get_db().dialect_name).get_lockset(
+        InstanceModel.__tablename__
+    )
     async with get_session_ctx() as session:
-        async with lock:
+        async with fleet_lock, instance_lock:
             res = await session.execute(
                 select(FleetModel)
                 .where(
                     FleetModel.deleted == False,
-                    FleetModel.id.not_in(lockset),
+                    FleetModel.id.not_in(fleet_lockset),
                     FleetModel.last_processed_at
                     < get_current_datetime() - MIN_PROCESSING_INTERVAL,
                 )
-                .options(load_only(FleetModel.id))
+                .options(
+                    load_only(FleetModel.id, FleetModel.name),
+                    selectinload(FleetModel.instances).load_only(InstanceModel.id),
+                )
                 .order_by(FleetModel.last_processed_at.asc())
                 .limit(BATCH_SIZE)
                 .with_for_update(skip_locked=True, key_share=True)
             )
-            fleet_models = list(res.scalars().all())
+            fleet_models = list(res.scalars().unique().all())
             fleet_ids = [fm.id for fm in fleet_models]
+            res = await session.execute(
+                select(InstanceModel)
+                .where(
+                    InstanceModel.id.not_in(instance_lockset),
+                    InstanceModel.fleet_id.in_(fleet_ids),
+                )
+                .options(load_only(InstanceModel.id, InstanceModel.fleet_id))
+                .order_by(InstanceModel.id)
+                .with_for_update(skip_locked=True, key_share=True)
+            )
+            instance_models = list(res.scalars().all())
+            fleet_id_to_locked_instances = defaultdict(list)
+            for instance_model in instance_models:
+                fleet_id_to_locked_instances[instance_model.fleet_id].append(instance_model)
+            # Process only fleets with all instances locked.
+            # Other fleets won't be processed but will still be locked to avoid new transaction.
+            # This should not be problematic as long as process_fleets is quick.
+            fleet_models_to_process = []
+            for fleet_model in fleet_models:
+                if len(fleet_model.instances) == len(fleet_id_to_locked_instances[fleet_model.id]):
+                    fleet_models_to_process.append(fleet_model)
+                else:
+                    logger.debug(
+                        "Fleet %s processing will be skipped: some instance were not locked",
+                        fleet_model.name,
+                    )
             for fleet_id in fleet_ids:
-                lockset.add(fleet_id)
+                fleet_lockset.add(fleet_id)
+            instance_ids = [im.id for im in instance_models]
+            for instance_id in instance_ids:
+                instance_lockset.add(instance_id)
             try:
-                await _process_fleets(session=session, fleet_models=fleet_models)
+                await _process_fleets(session=session, fleet_models=fleet_models_to_process)
             finally:
-                lockset.difference_update(fleet_ids)
+                fleet_lockset.difference_update(fleet_ids)
+                instance_lockset.difference_update(instance_ids)


 async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel]):
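A minimal sketch of the "process only fleets whose instances were all locked" selection, with plain data classes standing in for the SQLAlchemy models (hypothetical names, illustration only):

from collections import defaultdict
from dataclasses import dataclass
from uuid import UUID

@dataclass
class Instance:
    id: UUID
    fleet_id: UUID

@dataclass
class Fleet:
    id: UUID
    name: str
    instances: list[Instance]

def select_fleets_to_process(fleets: list[Fleet], locked_instances: list[Instance]) -> list[Fleet]:
    # Group the instances that were successfully locked by their fleet...
    fleet_id_to_locked = defaultdict(list)
    for instance in locked_instances:
        fleet_id_to_locked[instance.fleet_id].append(instance)
    # ...and keep only the fleets whose instances were all locked, so consolidation
    # never races with another task holding a lock on one of the fleet's instances.
    return [f for f in fleets if len(f.instances) == len(fleet_id_to_locked[f.id])]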
@@ -99,8 +138,8 @@ def _consolidate_fleet_state_with_spec(session: AsyncSession, fleet_model: Fleet
         return
     if not _is_fleet_ready_for_consolidation(fleet_model):
         return
-    added_instances = _maintain_fleet_nodes_min(session, fleet_model, fleet_spec)
-    if added_instances:
+    changed_instances = _maintain_fleet_nodes_in_min_max_range(session, fleet_model, fleet_spec)
+    if changed_instances:
         fleet_model.consolidation_attempt += 1
     else:
         # The fleet is already consolidated or consolidation is in progress.
@@ -138,28 +177,47 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
     return _CONSOLIDATION_RETRY_DELAYS[-1]


-def _maintain_fleet_nodes_min(
+def _maintain_fleet_nodes_in_min_max_range(
     session: AsyncSession,
     fleet_model: FleetModel,
     fleet_spec: FleetSpec,
 ) -> bool:
     """
-    Ensures the fleet has at least `nodes.min` instances.
-    Returns `True` if retried or added new instances and `False` otherwise.
+    Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances.
+    Returns `True` if retried, added new instances, or terminated redundant instances and `False` otherwise.
     """
     assert fleet_spec.configuration.nodes is not None
     for instance in fleet_model.instances:
         # Delete terminated but not deleted instances since
         # they are going to be replaced with new pending instances.
         if instance.status == InstanceStatus.TERMINATED and not instance.deleted:
-            # It's safe to modify instances without instance lock since
-            # no other task modifies already terminated instances.
             instance.deleted = True
             instance.deleted_at = get_current_datetime()
     active_instances = [i for i in fleet_model.instances if not i.deleted]
     active_instances_num = len(active_instances)
     if active_instances_num >= fleet_spec.configuration.nodes.min:
-        return False
+        if (
+            fleet_spec.configuration.nodes.max is None
+            or active_instances_num <= fleet_spec.configuration.nodes.max
+        ):
+            return False
+        # Fleet has more instances than allowed by nodes.max.
+        # This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently)
+        # or if nodes.max is updated.
+        nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max
+        for instance in fleet_model.instances:
+            if nodes_redundant == 0:
+                break
+            if instance.status in [InstanceStatus.IDLE]:
+                instance.status = InstanceStatus.TERMINATING
+                instance.termination_reason = "Fleet has too many instances"
+                nodes_redundant -= 1
+                logger.info(
+                    "Terminating instance %s: %s",
+                    instance.name,
+                    instance.termination_reason,
+                )
+        return True
     nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num
     for i in range(nodes_missing):
         instance_model = create_fleet_instance_model(
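In isolation, the min/max reconciliation decision reduces to a count comparison. A rough sketch under the simplifying assumption that we only compute how many instances to add or terminate (hypothetical helper, not dstack code; the real function also creates pending instances and terminates idle ones):

from typing import Optional

def nodes_delta(active: int, nodes_min: int, nodes_max: Optional[int]) -> int:
    # Positive: that many instances are missing and should be provisioned.
    # Negative: that many redundant (idle) instances should be terminated.
    # Zero: the fleet is already within the [min, max] range.
    if active < nodes_min:
        return nodes_min - active
    if nodes_max is not None and active > nodes_max:
        return -(active - nodes_max)
    return 0

assert nodes_delta(1, 2, None) == 1   # below min: add one
assert nodes_delta(5, 2, 3) == -2     # above max: terminate two idle instances
assert nodes_delta(2, 2, 4) == 0      # already consolidated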
dstack/_internal/server/background/tasks/process_instances.py

@@ -259,9 +259,7 @@ async def _add_remote(instance: InstanceModel) -> None:
     if instance.status == InstanceStatus.PENDING:
         instance.status = InstanceStatus.PROVISIONING

-    retry_duration_deadline = instance.created_at.replace(
-        tzinfo=datetime.timezone.utc
-    ) + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
+    retry_duration_deadline = instance.created_at + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
     if retry_duration_deadline < get_current_datetime():
         instance.status = InstanceStatus.TERMINATED
         instance.termination_reason = "Provisioning timeout expired"
@@ -307,7 +305,7 @@ async def _add_remote(instance: InstanceModel) -> None:
        )
        deploy_timeout = 20 * 60  # 20 minutes
        result = await asyncio.wait_for(future, timeout=deploy_timeout)
-       health, host_info, cpu_arch = result
+       health, host_info, arch = result
    except (asyncio.TimeoutError, TimeoutError) as e:
        raise ProvisioningError(f"Deploy timeout: {e}") from e
    except Exception as e:
@@ -327,7 +325,7 @@ async def _add_remote(instance: InstanceModel) -> None:
        instance.status = InstanceStatus.PENDING
        return

-   instance_type = host_info_to_instance_type(host_info, cpu_arch)
+   instance_type = host_info_to_instance_type(host_info, arch)
    instance_network = None
    internal_ip = None
    try:
dstack/_internal/server/background/tasks/process_running_jobs.py

@@ -1139,7 +1139,7 @@ def _patch_base_image_for_aws_efa(
     efa_enabled_patterns = [
         # TODO: p6-b200 isn't supported yet in gpuhunt
         r"^p6-b200\.(48xlarge)$",
-        r"^p5\.(48xlarge)$",
+        r"^p5\.(4xlarge|48xlarge)$",
         r"^p5e\.(48xlarge)$",
         r"^p5en\.(48xlarge)$",
         r"^p4d\.(24xlarge)$",
dstack/_internal/server/background/tasks/process_runs.py

@@ -256,8 +256,8 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
     for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
         replica_statuses: Set[RunStatus] = set()
         replica_needs_retry = False
-
         replica_active = True
+        jobs_done_num = 0
         for job_model in job_models:
             job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
             if (
@@ -272,8 +272,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             ):
                 # the job is done or going to be done
                 replica_statuses.add(RunStatus.DONE)
-                # for some reason the replica is done, it's not active
-                replica_active = False
+                jobs_done_num += 1
             elif job_model.termination_reason == JobTerminationReason.SCALED_DOWN:
                 # the job was scaled down
                 replica_active = False
@@ -313,26 +312,14 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
         if not replica_needs_retry or retry_single_job:
             run_statuses.update(replica_statuses)

-        if replica_active:
-            # submitted_at = replica created
-            replicas_info.append(
-                autoscalers.ReplicaInfo(
-                    active=True,
-                    timestamp=min(job.submitted_at for job in job_models).replace(
-                        tzinfo=datetime.timezone.utc
-                    ),
-                )
-            )
-        else:
-            # last_processed_at = replica scaled down
-            replicas_info.append(
-                autoscalers.ReplicaInfo(
-                    active=False,
-                    timestamp=max(job.last_processed_at for job in job_models).replace(
-                        tzinfo=datetime.timezone.utc
-                    ),
-                )
-            )
+        if jobs_done_num == len(job_models):
+            # Consider replica inactive if all its jobs are done for some reason.
+            # If only some jobs are done, replica is considered active to avoid
+            # provisioning new replicas for partially done multi-node tasks.
+            replica_active = False
+
+        replica_info = _get_replica_info(job_models, replica_active)
+        replicas_info.append(replica_info)

     termination_reason: Optional[RunTerminationReason] = None
     if RunStatus.FAILED in run_statuses:
@@ -410,6 +397,23 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
         run_model.resubmission_attempt += 1


+def _get_replica_info(
+    replica_job_models: list[JobModel],
+    replica_active: bool,
+) -> autoscalers.ReplicaInfo:
+    if replica_active:
+        # submitted_at = replica created
+        return autoscalers.ReplicaInfo(
+            active=True,
+            timestamp=min(job.submitted_at for job in replica_job_models),
+        )
+    # last_processed_at = replica scaled down
+    return autoscalers.ReplicaInfo(
+        active=False,
+        timestamp=max(job.last_processed_at for job in replica_job_models),
+    )
+
+
 async def _handle_run_replicas(
     session: AsyncSession,
     run_model: RunModel,
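A minimal sketch of the replica-activity rule introduced above, with a plain data class standing in for JobModel (hypothetical names; timestamps are assumed to be timezone-aware, which the removed `.replace(tzinfo=...)` calls suggest the models now store):

from dataclasses import dataclass
from datetime import datetime

@dataclass
class Job:
    done: bool
    submitted_at: datetime
    last_processed_at: datetime

def replica_info(jobs: list[Job]) -> tuple[bool, datetime]:
    # A replica counts as inactive only when *all* of its jobs are done;
    # a partially done multi-node task keeps its replica active so the
    # autoscaler does not provision a replacement replica for it.
    active = not all(job.done for job in jobs)
    if active:
        # submitted_at of the earliest job = when the replica was created
        return True, min(job.submitted_at for job in jobs)
    # last_processed_at of the latest job = when the replica wound down
    return False, max(job.last_processed_at for job in jobs)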