dstack 0.19.31__py3-none-any.whl → 0.19.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dstack might be problematic.

Files changed (53)
  1. dstack/_internal/cli/commands/offer.py +1 -1
  2. dstack/_internal/cli/services/configurators/run.py +1 -5
  3. dstack/_internal/core/backends/aws/compute.py +8 -5
  4. dstack/_internal/core/backends/azure/compute.py +9 -6
  5. dstack/_internal/core/backends/base/compute.py +40 -17
  6. dstack/_internal/core/backends/base/offers.py +5 -1
  7. dstack/_internal/core/backends/datacrunch/compute.py +9 -6
  8. dstack/_internal/core/backends/gcp/compute.py +137 -7
  9. dstack/_internal/core/backends/gcp/models.py +7 -0
  10. dstack/_internal/core/backends/gcp/resources.py +87 -5
  11. dstack/_internal/core/backends/hotaisle/compute.py +30 -0
  12. dstack/_internal/core/backends/kubernetes/compute.py +218 -77
  13. dstack/_internal/core/backends/kubernetes/models.py +4 -2
  14. dstack/_internal/core/backends/nebius/compute.py +24 -6
  15. dstack/_internal/core/backends/nebius/configurator.py +15 -0
  16. dstack/_internal/core/backends/nebius/models.py +57 -5
  17. dstack/_internal/core/backends/nebius/resources.py +45 -2
  18. dstack/_internal/core/backends/oci/compute.py +9 -6
  19. dstack/_internal/core/backends/runpod/compute.py +10 -6
  20. dstack/_internal/core/backends/vastai/compute.py +3 -1
  21. dstack/_internal/core/backends/vastai/configurator.py +0 -1
  22. dstack/_internal/core/compatibility/runs.py +8 -0
  23. dstack/_internal/core/models/fleets.py +1 -1
  24. dstack/_internal/core/models/profiles.py +12 -5
  25. dstack/_internal/core/models/runs.py +3 -2
  26. dstack/_internal/core/models/users.py +10 -0
  27. dstack/_internal/core/services/configs/__init__.py +1 -0
  28. dstack/_internal/server/background/tasks/process_fleets.py +75 -17
  29. dstack/_internal/server/background/tasks/process_instances.py +6 -4
  30. dstack/_internal/server/background/tasks/process_running_jobs.py +1 -0
  31. dstack/_internal/server/background/tasks/process_runs.py +27 -23
  32. dstack/_internal/server/background/tasks/process_submitted_jobs.py +63 -20
  33. dstack/_internal/server/migrations/versions/ff1d94f65b08_user_ssh_key.py +34 -0
  34. dstack/_internal/server/models.py +3 -0
  35. dstack/_internal/server/routers/runs.py +5 -1
  36. dstack/_internal/server/routers/users.py +14 -2
  37. dstack/_internal/server/services/runs.py +9 -4
  38. dstack/_internal/server/services/users.py +35 -2
  39. dstack/_internal/server/statics/index.html +1 -1
  40. dstack/_internal/server/statics/main-720ce3a11140daa480cc.css +3 -0
  41. dstack/_internal/server/statics/{main-c51afa7f243e24d3e446.js → main-97c7e184573ca23f9fe4.js} +12218 -7625
  42. dstack/_internal/server/statics/{main-c51afa7f243e24d3e446.js.map → main-97c7e184573ca23f9fe4.js.map} +1 -1
  43. dstack/api/_public/__init__.py +9 -12
  44. dstack/api/_public/repos.py +0 -21
  45. dstack/api/_public/runs.py +64 -9
  46. dstack/api/server/_users.py +17 -2
  47. dstack/version.py +2 -2
  48. {dstack-0.19.31.dist-info → dstack-0.19.33.dist-info}/METADATA +12 -14
  49. {dstack-0.19.31.dist-info → dstack-0.19.33.dist-info}/RECORD +52 -51
  50. dstack/_internal/server/statics/main-56191fbfe77f49b251de.css +0 -3
  51. {dstack-0.19.31.dist-info → dstack-0.19.33.dist-info}/WHEEL +0 -0
  52. {dstack-0.19.31.dist-info → dstack-0.19.33.dist-info}/entry_points.txt +0 -0
  53. {dstack-0.19.31.dist-info → dstack-0.19.33.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,11 +1,12 @@
  import logging
+ import re
  import time
  from collections import defaultdict
  from collections.abc import Container as ContainerT
  from collections.abc import Generator, Iterable, Sequence
  from contextlib import contextmanager
  from tempfile import NamedTemporaryFile
- from typing import Optional
+ from typing import Dict, Optional

  from nebius.aio.authorization.options import options_to_metadata
  from nebius.aio.operation import Operation as SDKOperation
@@ -249,13 +250,14 @@ def get_default_subnet(sdk: SDK, project_id: str) -> Subnet:


  def create_disk(
-     sdk: SDK, name: str, project_id: str, size_mib: int, image_family: str
+     sdk: SDK, name: str, project_id: str, size_mib: int, image_family: str, labels: Dict[str, str]
  ) -> SDKOperation[Operation]:
      client = DiskServiceClient(sdk)
      request = CreateDiskRequest(
          metadata=ResourceMetadata(
              name=name,
              parent_id=project_id,
+             labels=labels,
          ),
          spec=DiskSpec(
              size_mebibytes=size_mib,
@@ -288,12 +290,14 @@ def create_instance(
      disk_id: str,
      subnet_id: str,
      preemptible: bool,
+     labels: Dict[str, str],
  ) -> SDKOperation[Operation]:
      client = InstanceServiceClient(sdk)
      request = CreateInstanceRequest(
          metadata=ResourceMetadata(
              name=name,
              parent_id=project_id,
+             labels=labels,
          ),
          spec=InstanceSpec(
              cloud_init_user_data=user_data,
@@ -367,3 +371,42 @@ def delete_cluster(sdk: SDK, cluster_id: str) -> None:
          metadata=REQUEST_MD,
      )
  )
+
+
+ def filter_invalid_labels(labels: Dict[str, str]) -> Dict[str, str]:
+     filtered_labels = {}
+     for k, v in labels.items():
+         if not _is_valid_label(k, v):
+             logger.warning("Skipping invalid label '%s: %s'", k, v)
+             continue
+         filtered_labels[k] = v
+     return filtered_labels
+
+
+ def validate_labels(labels: Dict[str, str]):
+     for k, v in labels.items():
+         if not _is_valid_label(k, v):
+             raise BackendError("Invalid resource labels")
+
+
+ def _is_valid_label(key: str, value: str) -> bool:
+     # TODO: [Nebius] current validation logic reuses GCP's approach.
+     # There is no public information on Nebius labels restrictions.
+     return is_valid_resource_name(key) and is_valid_label_value(value)
+
+
+ MAX_RESOURCE_NAME_LEN = 63
+ NAME_PATTERN = re.compile(r"^[a-z][_\-a-z0-9]{0,62}$")
+ LABEL_VALUE_PATTERN = re.compile(r"^[_\-a-z0-9]{0,63}$")
+
+
+ def is_valid_resource_name(name: str) -> bool:
+     if len(name) < 1 or len(name) > MAX_RESOURCE_NAME_LEN:
+         return False
+     match = re.match(NAME_PATTERN, name)
+     return match is not None
+
+
+ def is_valid_label_value(value: str) -> bool:
+     match = re.match(LABEL_VALUE_PATTERN, value)
+     return match is not None
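The two helpers added above differ only in failure mode: filter_invalid_labels drops offending entries with a warning, while validate_labels raises BackendError. A standalone sketch of the key/value rules they encode, reusing the exact patterns from the hunk (the sample labels are illustrative):

    import re

    NAME_PATTERN = re.compile(r"^[a-z][_\-a-z0-9]{0,62}$")
    LABEL_VALUE_PATTERN = re.compile(r"^[_\-a-z0-9]{0,63}$")

    labels = {"dstack-project": "main", "Bad Key": "x", "team": "ml_research"}

    # Keys must start with a lowercase letter, so "Bad Key" is rejected;
    # a filter-style helper keeps only entries matching both patterns.
    valid = {
        k: v
        for k, v in labels.items()
        if NAME_PATTERN.match(k) and LABEL_VALUE_PATTERN.match(v)
    }
    print(valid)  # {'dstack-project': 'main', 'team': 'ml_research'}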
@@ -1,6 +1,7 @@
+ from collections.abc import Iterable
  from concurrent.futures import ThreadPoolExecutor
  from functools import cached_property
- from typing import Callable, List, Optional
+ from typing import List, Optional

  import oci

@@ -13,7 +14,11 @@ from dstack._internal.core.backends.base.compute import (
      generate_unique_instance_name,
      get_user_data,
  )
- from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
+ from dstack._internal.core.backends.base.offers import (
+     OfferModifier,
+     get_catalog_offers,
+     get_offers_disk_modifier,
+ )
  from dstack._internal.core.backends.oci import resources
  from dstack._internal.core.backends.oci.models import OCIConfig
  from dstack._internal.core.backends.oci.region import make_region_clients_map
@@ -96,10 +101,8 @@ class OCICompute(

          return offers_with_availability

-     def get_offers_modifier(
-         self, requirements: Requirements
-     ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
-         return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
+     def get_offers_modifiers(self, requirements: Requirements) -> Iterable[OfferModifier]:
+         return [get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)]

      def terminate_instance(
          self, instance_id: str, region: str, backend_data: Optional[str] = None
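The switch from get_offers_modifier to get_offers_modifiers suggests the base offer machinery now applies a sequence of modifiers rather than a single callable. A rough sketch of how such a pipeline could behave, inferred from the removed signature (apply_modifiers and the dict-based offer are hypothetical; the real OfferModifier in base/offers.py may differ):

    from typing import Callable, Iterable, List, Optional

    Offer = dict  # stand-in for InstanceOfferWithAvailability
    OfferModifier = Callable[[Offer], Optional[Offer]]


    def apply_modifiers(offers: Iterable[Offer], modifiers: Iterable[OfferModifier]) -> List[Offer]:
        result = []
        for offer in offers:
            for modify in modifiers:
                offer = modify(offer)
                if offer is None:  # a modifier may drop the offer entirely
                    break
            if offer is not None:
                result.append(offer)
        return result


    def cap_disk(offer: Offer) -> Optional[Offer]:
        # Illustrative modifier: reject offers whose disk exceeds a configurable maximum.
        return offer if offer.get("disk_gb", 0) <= 1024 else None


    print(apply_modifiers([{"disk_gb": 100}, {"disk_gb": 4096}], [cap_disk]))  # [{'disk_gb': 100}]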
@@ -1,7 +1,8 @@
  import json
  import uuid
+ from collections.abc import Iterable
  from datetime import timedelta
- from typing import Callable, List, Optional
+ from typing import List, Optional

  from dstack._internal.core.backends.base.backend import Compute
  from dstack._internal.core.backends.base.compute import (
@@ -12,7 +13,11 @@ from dstack._internal.core.backends.base.compute import (
      get_docker_commands,
      get_job_instance_name,
  )
- from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
+ from dstack._internal.core.backends.base.offers import (
+     OfferModifier,
+     get_catalog_offers,
+     get_offers_disk_modifier,
+ )
  from dstack._internal.core.backends.runpod.api_client import RunpodApiClient
  from dstack._internal.core.backends.runpod.models import RunpodConfig
  from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
@@ -72,10 +77,8 @@ class RunpodCompute(
          ]
          return offers

-     def get_offers_modifier(
-         self, requirements: Requirements
-     ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
-         return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
+     def get_offers_modifiers(self, requirements: Requirements) -> Iterable[OfferModifier]:
+         return [get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)]

      def run_job(
          self,
@@ -86,6 +89,7 @@ class RunpodCompute(
          project_ssh_private_key: str,
          volumes: List[Volume],
      ) -> JobProvisioningData:
+         assert run.run_spec.ssh_key_pub is not None
          instance_config = InstanceConfiguration(
              project_name=run.project_name,
              instance_name=get_job_instance_name(run, job),
@@ -47,7 +47,7 @@ class VastAICompute(
                  "reliability2": {"gte": 0.9},
                  "inet_down": {"gt": 128},
                  "verified": {"eq": True},
-                 "cuda_max_good": {"gte": 12.1},
+                 "cuda_max_good": {"gte": 12.8},
                  "compute_cap": {"gte": 600},
              }
          )
@@ -58,6 +58,7 @@ class VastAICompute(
      ) -> List[InstanceOfferWithAvailability]:
          offers = get_catalog_offers(
              backend=BackendType.VASTAI,
+             locations=self.config.regions or None,
              requirements=requirements,
              # TODO(egor-s): spots currently not supported
              extra_filter=lambda offer: not offer.instance.resources.spot,
@@ -85,6 +86,7 @@ class VastAICompute(
          instance_name = generate_unique_instance_name_for_job(
              run, job, max_length=MAX_INSTANCE_NAME_LEN
          )
+         assert run.run_spec.ssh_key_pub is not None
          commands = get_docker_commands(
              [run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()]
          )
@@ -18,7 +18,6 @@ from dstack._internal.core.models.backends.base import (
      BackendType,
  )

- # VastAI regions are dynamic, currently we don't offer any filtering
  REGIONS = []


@@ -53,6 +53,10 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[IncludeExcludeD
          }
          if all(js.exit_status is None for js in job_submissions):
              job_submissions_excludes["exit_status"] = True
+         if all(js.status_message == "" for js in job_submissions):
+             job_submissions_excludes["status_message"] = True
+         if all(js.error is None for js in job_submissions):
+             job_submissions_excludes["error"] = True
          if all(js.deployment_num == 0 for js in job_submissions):
              job_submissions_excludes["deployment_num"] = True
          if all(not js.probes for js in job_submissions):
@@ -71,6 +75,10 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[IncludeExcludeD
          }
          if latest_job_submission.exit_status is None:
              latest_job_submission_excludes["exit_status"] = True
+         if latest_job_submission.status_message == "":
+             latest_job_submission_excludes["status_message"] = True
+         if latest_job_submission.error is None:
+             latest_job_submission_excludes["error"] = True
          if latest_job_submission.deployment_num == 0:
              latest_job_submission_excludes["deployment_num"] = True
          if not latest_job_submission.probes:
@@ -244,7 +244,7 @@ class InstanceGroupParams(CoreModel):
          Field(
              description=(
                  "The existing reservation to use for instance provisioning."
-                 " Supports AWS Capacity Reservations and Capacity Blocks"
+                 " Supports AWS Capacity Reservations, AWS Capacity Blocks, and GCP reservations"
              )
          ),
      ] = None
@@ -80,14 +80,21 @@ def parse_stop_duration(
  def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[Literal["off"], int]]:
      if v == "off" or v is False:
          return "off"
-     if v is True:
+     if v is True or v is None:
          return None
-     return parse_duration(v)
+     duration = parse_duration(v)
+     if duration < 0:
+         raise ValueError("Duration cannot be negative")
+     return duration


- def parse_idle_duration(v: Optional[Union[int, str]]) -> Optional[int]:
-     if v == "off" or v == -1:
+ def parse_idle_duration(v: Optional[Union[int, str, bool]]) -> Optional[int]:
+     # Differs from `parse_off_duration` to accept negative durations as `off`
+     # for backward compatibility.
+     if v == "off" or v is False or v == -1:
          return -1
+     if v is True:
+         return None
      return parse_duration(v)


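For reference, the reworked parsing maps a few representative inputs as follows. This is a self-contained sketch that mirrors parse_off_duration from the hunk above with a stubbed, integer-only parse_duration (the real dstack parser also accepts strings such as "1h"):

    from typing import Literal, Optional, Union


    def parse_duration(v: Union[int, str]) -> int:
        # Stub: integer seconds only; the real parser also understands "30m", "1h", etc.
        return int(v)


    def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[Literal["off"], int]]:
        if v == "off" or v is False:
            return "off"
        if v is True or v is None:
            return None
        duration = parse_duration(v)
        if duration < 0:
            raise ValueError("Duration cannot be negative")
        return duration


    print(parse_off_duration(False))  # 'off'
    print(parse_off_duration(None))   # None (no explicit duration)
    print(parse_off_duration(3600))   # 3600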
@@ -276,7 +283,7 @@ class ProfileParams(CoreModel):
          Field(
              description=(
                  "The existing reservation to use for instance provisioning."
-                 " Supports AWS Capacity Reservations and Capacity Blocks"
+                 " Supports AWS Capacity Reservations, AWS Capacity Blocks, and GCP reservations"
              )
          ),
      ] = None
@@ -462,11 +462,12 @@ class RunSpec(generate_dual_core_model(RunSpecConfig)):
      configuration: Annotated[AnyRunConfiguration, Field(discriminator="type")]
      profile: Annotated[Optional[Profile], Field(description="The profile parameters")] = None
      ssh_key_pub: Annotated[
-         str,
+         Optional[str],
          Field(
              description="The contents of the SSH public key that will be used to connect to the run."
+             " Can be empty only before the run is submitted."
          ),
-     ]
+     ] = None
      # merged_profile stores profile parameters merged from profile and configuration.
      # Read profile parameters from merged_profile instead of profile directly.
      # TODO: make merged_profile a computed field after migrating to pydanticV2
@@ -30,6 +30,7 @@ class User(CoreModel):
      email: Optional[str]
      active: bool
      permissions: UserPermissions
+     ssh_public_key: Optional[str] = None


  class UserTokenCreds(CoreModel):
@@ -38,3 +39,12 @@ class UserTokenCreds(CoreModel):

  class UserWithCreds(User):
      creds: UserTokenCreds
+     ssh_private_key: Optional[str] = None
+
+
+ class UserHookConfig(CoreModel):
+     """
+     This class can be inherited to extend the user creation configuration passed to the hooks.
+     """
+
+     pass
@@ -117,6 +117,7 @@ class ConfigManager:

      @property
      def dstack_key_path(self) -> Path:
+         # TODO: Remove since 0.19.40
          return self.dstack_ssh_dir / "id_rsa"

      @property
@@ -1,10 +1,11 @@
+ from collections import defaultdict
  from datetime import timedelta
  from typing import List
  from uuid import UUID

  from sqlalchemy import select, update
  from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.orm import joinedload, load_only
+ from sqlalchemy.orm import joinedload, load_only, selectinload

  from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
  from dstack._internal.core.models.instances import InstanceStatus
@@ -37,30 +38,68 @@ MIN_PROCESSING_INTERVAL = timedelta(seconds=30)

  @sentry_utils.instrument_background_task
  async def process_fleets():
-     lock, lockset = get_locker(get_db().dialect_name).get_lockset(FleetModel.__tablename__)
+     fleet_lock, fleet_lockset = get_locker(get_db().dialect_name).get_lockset(
+         FleetModel.__tablename__
+     )
+     instance_lock, instance_lockset = get_locker(get_db().dialect_name).get_lockset(
+         InstanceModel.__tablename__
+     )
      async with get_session_ctx() as session:
-         async with lock:
+         async with fleet_lock, instance_lock:
              res = await session.execute(
                  select(FleetModel)
                  .where(
                      FleetModel.deleted == False,
-                     FleetModel.id.not_in(lockset),
+                     FleetModel.id.not_in(fleet_lockset),
                      FleetModel.last_processed_at
                      < get_current_datetime() - MIN_PROCESSING_INTERVAL,
                  )
-                 .options(load_only(FleetModel.id))
+                 .options(
+                     load_only(FleetModel.id, FleetModel.name),
+                     selectinload(FleetModel.instances).load_only(InstanceModel.id),
+                 )
                  .order_by(FleetModel.last_processed_at.asc())
                  .limit(BATCH_SIZE)
                  .with_for_update(skip_locked=True, key_share=True)
              )
-             fleet_models = list(res.scalars().all())
+             fleet_models = list(res.scalars().unique().all())
              fleet_ids = [fm.id for fm in fleet_models]
+             res = await session.execute(
+                 select(InstanceModel)
+                 .where(
+                     InstanceModel.id.not_in(instance_lockset),
+                     InstanceModel.fleet_id.in_(fleet_ids),
+                 )
+                 .options(load_only(InstanceModel.id, InstanceModel.fleet_id))
+                 .order_by(InstanceModel.id)
+                 .with_for_update(skip_locked=True, key_share=True)
+             )
+             instance_models = list(res.scalars().all())
+             fleet_id_to_locked_instances = defaultdict(list)
+             for instance_model in instance_models:
+                 fleet_id_to_locked_instances[instance_model.fleet_id].append(instance_model)
+             # Process only fleets with all instances locked.
+             # Other fleets won't be processed but will still be locked to avoid new transaction.
+             # This should not be problematic as long as process_fleets is quick.
+             fleet_models_to_process = []
+             for fleet_model in fleet_models:
+                 if len(fleet_model.instances) == len(fleet_id_to_locked_instances[fleet_model.id]):
+                     fleet_models_to_process.append(fleet_model)
+                 else:
+                     logger.debug(
+                         "Fleet %s processing will be skipped: some instance were not locked",
+                         fleet_model.name,
+                     )
              for fleet_id in fleet_ids:
-                 lockset.add(fleet_id)
+                 fleet_lockset.add(fleet_id)
+             instance_ids = [im.id for im in instance_models]
+             for instance_id in instance_ids:
+                 instance_lockset.add(instance_id)
              try:
-                 await _process_fleets(session=session, fleet_models=fleet_models)
+                 await _process_fleets(session=session, fleet_models=fleet_models_to_process)
              finally:
-                 lockset.difference_update(fleet_ids)
+                 fleet_lockset.difference_update(fleet_ids)
+                 instance_lockset.difference_update(instance_ids)


  async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel]):
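The reworked task above locks both fleets and their instances and then processes only the fleets whose entire instance set could also be locked; fleets with partially locked instances are skipped for this cycle. A condensed in-memory sketch of that selection rule (the SQLAlchemy queries and locker are omitted; the sample data is hypothetical):

    from collections import defaultdict

    fleets = {"fleet-a": ["i1", "i2"], "fleet-b": ["i3", "i4", "i5"]}
    locked_instances = {"i1", "i2", "i3"}  # "i4"/"i5" are held by another task

    fleet_id_to_locked = defaultdict(list)
    for fleet_id, instance_ids in fleets.items():
        for instance_id in instance_ids:
            if instance_id in locked_instances:
                fleet_id_to_locked[fleet_id].append(instance_id)

    # Process only fleets with all of their instances locked, as in the hunk above.
    fleets_to_process = [
        fleet_id
        for fleet_id, instance_ids in fleets.items()
        if len(instance_ids) == len(fleet_id_to_locked[fleet_id])
    ]
    print(fleets_to_process)  # ['fleet-a']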
@@ -99,8 +138,8 @@ def _consolidate_fleet_state_with_spec(session: AsyncSession, fleet_model: Fleet
          return
      if not _is_fleet_ready_for_consolidation(fleet_model):
          return
-     added_instances = _maintain_fleet_nodes_min(session, fleet_model, fleet_spec)
-     if added_instances:
+     changed_instances = _maintain_fleet_nodes_in_min_max_range(session, fleet_model, fleet_spec)
+     if changed_instances:
          fleet_model.consolidation_attempt += 1
      else:
          # The fleet is already consolidated or consolidation is in progress.
@@ -138,28 +177,47 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
      return _CONSOLIDATION_RETRY_DELAYS[-1]


- def _maintain_fleet_nodes_min(
+ def _maintain_fleet_nodes_in_min_max_range(
      session: AsyncSession,
      fleet_model: FleetModel,
      fleet_spec: FleetSpec,
  ) -> bool:
      """
-     Ensures the fleet has at least `nodes.min` instances.
-     Returns `True` if retried or added new instances and `False` otherwise.
+     Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances.
+     Returns `True` if retried, added new instances, or terminated redundant instances and `False` otherwise.
      """
      assert fleet_spec.configuration.nodes is not None
      for instance in fleet_model.instances:
          # Delete terminated but not deleted instances since
          # they are going to be replaced with new pending instances.
          if instance.status == InstanceStatus.TERMINATED and not instance.deleted:
-             # It's safe to modify instances without instance lock since
-             # no other task modifies already terminated instances.
              instance.deleted = True
              instance.deleted_at = get_current_datetime()
      active_instances = [i for i in fleet_model.instances if not i.deleted]
      active_instances_num = len(active_instances)
      if active_instances_num >= fleet_spec.configuration.nodes.min:
-         return False
+         if (
+             fleet_spec.configuration.nodes.max is None
+             or active_instances_num <= fleet_spec.configuration.nodes.max
+         ):
+             return False
+         # Fleet has more instances than allowed by nodes.max.
+         # This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently)
+         # or if nodes.max is updated.
+         nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max
+         for instance in fleet_model.instances:
+             if nodes_redundant == 0:
+                 break
+             if instance.status in [InstanceStatus.IDLE]:
+                 instance.status = InstanceStatus.TERMINATING
+                 instance.termination_reason = "Fleet has too many instances"
+                 nodes_redundant -= 1
+                 logger.info(
+                     "Terminating instance %s: %s",
+                     instance.name,
+                     instance.termination_reason,
+                 )
+         return True
      nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num
      for i in range(nodes_missing):
          instance_model = create_fleet_instance_model(
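In effect, consolidation now reconciles the active instance count against both bounds: below nodes.min it creates instances, above nodes.max it terminates idle ones. A small arithmetic sketch of that decision (plan_consolidation and the counts are illustrative, not dstack API):

    from typing import Optional, Tuple


    def plan_consolidation(active: int, nodes_min: int, nodes_max: Optional[int]) -> Tuple[int, int]:
        """Return (instances_to_add, idle_instances_to_terminate) for a fleet."""
        if active < nodes_min:
            return nodes_min - active, 0
        if nodes_max is not None and active > nodes_max:
            return 0, active - nodes_max  # only idle instances are actually terminated
        return 0, 0


    print(plan_consolidation(active=1, nodes_min=2, nodes_max=4))     # (1, 0)
    print(plan_consolidation(active=6, nodes_min=2, nodes_max=4))     # (0, 2)
    print(plan_consolidation(active=3, nodes_min=2, nodes_max=None))  # (0, 0)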
@@ -259,9 +259,7 @@ async def _add_remote(instance: InstanceModel) -> None:
      if instance.status == InstanceStatus.PENDING:
          instance.status = InstanceStatus.PROVISIONING

-     retry_duration_deadline = instance.created_at.replace(
-         tzinfo=datetime.timezone.utc
-     ) + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
+     retry_duration_deadline = instance.created_at + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
      if retry_duration_deadline < get_current_datetime():
          instance.status = InstanceStatus.TERMINATED
          instance.termination_reason = "Provisioning timeout expired"
@@ -560,10 +558,14 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
          if (
              _is_fleet_master_instance(instance)
              and instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
+             and isinstance(compute, ComputeWithPlacementGroupSupport)
+             and (
+                 compute.are_placement_groups_compatible_with_reservations(instance_offer.backend)
+                 or instance_configuration.reservation is None
+             )
              and instance.fleet
              and _is_cloud_cluster(instance.fleet)
          ):
-             assert isinstance(compute, ComputeWithPlacementGroupSupport)
              placement_group_model = _find_suitable_placement_group(
                  placement_groups=placement_group_models,
                  instance_offer=instance_offer,
@@ -243,6 +243,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
          job_submission.age,
      )
      ssh_user = job_provisioning_data.username
+     assert run.run_spec.ssh_key_pub is not None
      user_ssh_key = run.run_spec.ssh_key_pub.strip()
      public_keys = [project.ssh_public_key.strip(), user_ssh_key]
      if job_provisioning_data.backend == BackendType.LOCAL:
@@ -256,8 +256,8 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
      for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
          replica_statuses: Set[RunStatus] = set()
          replica_needs_retry = False
-
          replica_active = True
+         jobs_done_num = 0
          for job_model in job_models:
              job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
              if (
@@ -272,8 +272,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
              ):
                  # the job is done or going to be done
                  replica_statuses.add(RunStatus.DONE)
-                 # for some reason the replica is done, it's not active
-                 replica_active = False
+                 jobs_done_num += 1
              elif job_model.termination_reason == JobTerminationReason.SCALED_DOWN:
                  # the job was scaled down
                  replica_active = False
@@ -313,26 +312,14 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
          if not replica_needs_retry or retry_single_job:
              run_statuses.update(replica_statuses)

-         if replica_active:
-             # submitted_at = replica created
-             replicas_info.append(
-                 autoscalers.ReplicaInfo(
-                     active=True,
-                     timestamp=min(job.submitted_at for job in job_models).replace(
-                         tzinfo=datetime.timezone.utc
-                     ),
-                 )
-             )
-         else:
-             # last_processed_at = replica scaled down
-             replicas_info.append(
-                 autoscalers.ReplicaInfo(
-                     active=False,
-                     timestamp=max(job.last_processed_at for job in job_models).replace(
-                         tzinfo=datetime.timezone.utc
-                     ),
-                 )
-             )
+         if jobs_done_num == len(job_models):
+             # Consider replica inactive if all its jobs are done for some reason.
+             # If only some jobs are done, replica is considered active to avoid
+             # provisioning new replicas for partially done multi-node tasks.
+             replica_active = False
+
+         replica_info = _get_replica_info(job_models, replica_active)
+         replicas_info.append(replica_info)

      termination_reason: Optional[RunTerminationReason] = None
      if RunStatus.FAILED in run_statuses:
@@ -410,6 +397,23 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
          run_model.resubmission_attempt += 1


+ def _get_replica_info(
+     replica_job_models: list[JobModel],
+     replica_active: bool,
+ ) -> autoscalers.ReplicaInfo:
+     if replica_active:
+         # submitted_at = replica created
+         return autoscalers.ReplicaInfo(
+             active=True,
+             timestamp=min(job.submitted_at for job in replica_job_models),
+         )
+     # last_processed_at = replica scaled down
+     return autoscalers.ReplicaInfo(
+         active=False,
+         timestamp=max(job.last_processed_at for job in replica_job_models),
+     )
+
+
  async def _handle_run_replicas(
      session: AsyncSession,
      run_model: RunModel,