dstack 0.19.30rc1__py3-none-any.whl → 0.19.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- dstack/_internal/cli/commands/__init__.py +8 -0
- dstack/_internal/cli/commands/project.py +27 -20
- dstack/_internal/cli/commands/server.py +5 -0
- dstack/_internal/cli/services/configurators/fleet.py +20 -6
- dstack/_internal/cli/utils/gpu.py +2 -2
- dstack/_internal/core/backends/aws/compute.py +13 -5
- dstack/_internal/core/backends/aws/resources.py +11 -6
- dstack/_internal/core/backends/azure/compute.py +17 -6
- dstack/_internal/core/backends/base/compute.py +57 -9
- dstack/_internal/core/backends/base/offers.py +1 -0
- dstack/_internal/core/backends/cloudrift/compute.py +2 -0
- dstack/_internal/core/backends/cudo/compute.py +2 -0
- dstack/_internal/core/backends/datacrunch/compute.py +2 -0
- dstack/_internal/core/backends/digitalocean_base/compute.py +2 -0
- dstack/_internal/core/backends/features.py +5 -0
- dstack/_internal/core/backends/gcp/compute.py +87 -38
- dstack/_internal/core/backends/gcp/configurator.py +1 -1
- dstack/_internal/core/backends/gcp/models.py +14 -1
- dstack/_internal/core/backends/gcp/resources.py +35 -12
- dstack/_internal/core/backends/hotaisle/compute.py +22 -0
- dstack/_internal/core/backends/kubernetes/compute.py +531 -215
- dstack/_internal/core/backends/kubernetes/models.py +13 -16
- dstack/_internal/core/backends/kubernetes/utils.py +145 -8
- dstack/_internal/core/backends/lambdalabs/compute.py +2 -0
- dstack/_internal/core/backends/local/compute.py +2 -0
- dstack/_internal/core/backends/nebius/compute.py +17 -0
- dstack/_internal/core/backends/nebius/configurator.py +15 -0
- dstack/_internal/core/backends/nebius/models.py +57 -5
- dstack/_internal/core/backends/nebius/resources.py +45 -2
- dstack/_internal/core/backends/oci/compute.py +7 -1
- dstack/_internal/core/backends/oci/resources.py +8 -3
- dstack/_internal/core/backends/template/compute.py.jinja +2 -0
- dstack/_internal/core/backends/tensordock/compute.py +2 -0
- dstack/_internal/core/backends/vultr/compute.py +2 -0
- dstack/_internal/core/compatibility/runs.py +8 -0
- dstack/_internal/core/consts.py +2 -0
- dstack/_internal/core/models/profiles.py +11 -4
- dstack/_internal/core/services/repos.py +101 -11
- dstack/_internal/server/background/tasks/common.py +2 -0
- dstack/_internal/server/background/tasks/process_fleets.py +75 -17
- dstack/_internal/server/background/tasks/process_instances.py +3 -5
- dstack/_internal/server/background/tasks/process_running_jobs.py +1 -1
- dstack/_internal/server/background/tasks/process_runs.py +27 -23
- dstack/_internal/server/background/tasks/process_submitted_jobs.py +107 -54
- dstack/_internal/server/services/offers.py +7 -1
- dstack/_internal/server/testing/common.py +2 -0
- dstack/_internal/server/utils/provisioning.py +3 -10
- dstack/_internal/utils/ssh.py +22 -2
- dstack/version.py +2 -2
- {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/METADATA +20 -18
- {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/RECORD +54 -54
- {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/WHEEL +0 -0
- {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/entry_points.txt +0 -0
- {dstack-0.19.30rc1.dist-info → dstack-0.19.32.dist-info}/licenses/LICENSE.md +0 -0
@@ -80,14 +80,21 @@ def parse_stop_duration(
 def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[Literal["off"], int]]:
     if v == "off" or v is False:
         return "off"
-    if v is True:
+    if v is True or v is None:
         return None
-
+    duration = parse_duration(v)
+    if duration < 0:
+        raise ValueError("Duration cannot be negative")
+    return duration
 
 
-def parse_idle_duration(v: Optional[Union[int, str]]) -> Optional[int]:
-
+def parse_idle_duration(v: Optional[Union[int, str, bool]]) -> Optional[int]:
+    # Differs from `parse_off_duration` to accept negative durations as `off`
+    # for backward compatibility.
+    if v == "off" or v is False or v == -1:
         return -1
+    if v is True:
+        return None
     return parse_duration(v)
 
 
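For reference, a hedged usage sketch of the updated parsers above (the hunk matches dstack/_internal/core/models/profiles.py in the file list; the import path is assumed, and only behaviors visible in the diff are asserted):

    # Hypothetical usage; import path assumed from the file list above.
    from dstack._internal.core.models.profiles import parse_idle_duration, parse_off_duration

    assert parse_off_duration("off") == "off"  # "off" and False disable the feature
    assert parse_off_duration(None) is None    # None is now treated like True (nothing configured)
    # Any other value goes through parse_duration() and must be non-negative,
    # otherwise ValueError("Duration cannot be negative") is raised.

    assert parse_idle_duration(False) == -1    # bool input is now accepted
    assert parse_idle_duration(-1) == -1       # -1 still means "off" for backward compatibility
    assert parse_idle_duration(True) is None   # True means no idle duration configured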
@@ -36,24 +36,59 @@ def get_repo_creds_and_default_branch(
 
     # no auth
     with suppress(InvalidRepoCredentialsError):
-
+        creds, default_branch = _get_repo_creds_and_default_branch_https(url)
+        logger.debug(
+            "Git repo %s is public. Using no auth. Default branch: %s", repo_url, default_branch
+        )
+        return creds, default_branch
 
     # ssh key provided by the user or pulled from the server
     if identity_file is not None or private_key is not None:
         if identity_file is not None:
             private_key = _read_private_key(identity_file)
-
+            creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                url, identity_file, private_key
+            )
+            logger.debug(
+                "Git repo %s is private. Using identity file: %s. Default branch: %s",
+                repo_url,
+                identity_file,
+                default_branch,
+            )
+            return creds, default_branch
         elif private_key is not None:
             with NamedTemporaryFile("w+", 0o600) as f:
                 f.write(private_key)
                 f.flush()
-
+                creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                    url, f.name, private_key
+                )
+                masked_key = "***" + private_key[-10:] if len(private_key) > 10 else "***MASKED***"
+                logger.debug(
+                    "Git repo %s is private. Using private key: %s. Default branch: %s",
+                    repo_url,
+                    masked_key,
+                    default_branch,
+                )
+                return creds, default_branch
         else:
             assert False, "should not reach here"
 
     # oauth token provided by the user or pulled from the server
     if oauth_token is not None:
-
+        creds, default_branch = _get_repo_creds_and_default_branch_https(url, oauth_token)
+        masked_token = (
+            len(oauth_token[:-4]) * "*" + oauth_token[-4:]
+            if len(oauth_token) > 4
+            else "***MASKED***"
+        )
+        logger.debug(
+            "Git repo %s is private. Using provided OAuth token: %s. Default branch: %s",
+            repo_url,
+            masked_token,
+            default_branch,
+        )
+        return creds, default_branch
 
     # key from ssh config
     identities = get_host_config(url.original_host).get("identityfile")
@@ -61,7 +96,16 @@ def get_repo_creds_and_default_branch(
         _identity_file = identities[0]
         with suppress(InvalidRepoCredentialsError):
             _private_key = _read_private_key(_identity_file)
-
+            creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                url, _identity_file, _private_key
+            )
+            logger.debug(
+                "Git repo %s is private. Using SSH config identity file: %s. Default branch: %s",
+                repo_url,
+                _identity_file,
+                default_branch,
+            )
+            return creds, default_branch
 
     # token from gh config
     if os.path.exists(gh_config_path):
@@ -70,13 +114,35 @@ def get_repo_creds_and_default_branch(
         _oauth_token = gh_hosts.get(url.host, {}).get("oauth_token")
         if _oauth_token is not None:
             with suppress(InvalidRepoCredentialsError):
-
+                creds, default_branch = _get_repo_creds_and_default_branch_https(url, _oauth_token)
+                masked_token = (
+                    len(_oauth_token[:-4]) * "*" + _oauth_token[-4:]
+                    if len(_oauth_token) > 4
+                    else "***MASKED***"
+                )
+                logger.debug(
+                    "Git repo %s is private. Using GitHub config token: %s from %s. Default branch: %s",
+                    repo_url,
+                    masked_token,
+                    gh_config_path,
+                    default_branch,
+                )
+                return creds, default_branch
 
     # default user key
     if os.path.exists(default_ssh_key):
         with suppress(InvalidRepoCredentialsError):
             _private_key = _read_private_key(default_ssh_key)
-
+            creds, default_branch = _get_repo_creds_and_default_branch_ssh(
+                url, default_ssh_key, _private_key
+            )
+            logger.debug(
+                "Git repo %s is private. Using default identity file: %s. Default branch: %s",
+                repo_url,
+                default_ssh_key,
+                default_branch,
+            )
+            return creds, default_branch
 
     raise InvalidRepoCredentialsError(
         "No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
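The credential lookup above also introduces a small masking expression for debug logs. A standalone illustration of how that expression behaves (plain Python, copied from the diff):

    def mask_token(token: str) -> str:
        # Same expression as in the diff: star everything except the last 4 characters.
        return len(token[:-4]) * "*" + token[-4:] if len(token) > 4 else "***MASKED***"

    assert mask_token("ghp_abcdef123456") == "************3456"  # 12 stars + last 4 chars
    assert mask_token("abcd") == "***MASKED***"                   # short tokens are fully masked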
@@ -87,8 +153,9 @@ def _get_repo_creds_and_default_branch_ssh(
     url: GitRepoURL, identity_file: PathLike, private_key: str
 ) -> tuple[RemoteRepoCreds, Optional[str]]:
     _url = url.as_ssh()
+    env = _make_git_env_for_creds_check(identity_file=identity_file)
     try:
-        default_branch = _get_repo_default_branch(_url,
+        default_branch = _get_repo_default_branch(_url, env)
     except GitCommandError as e:
         message = f"Cannot access `{_url}` using the `{identity_file}` private SSH key"
         raise InvalidRepoCredentialsError(message) from e
@@ -104,8 +171,9 @@ def _get_repo_creds_and_default_branch_https(
     url: GitRepoURL, oauth_token: Optional[str] = None
 ) -> tuple[RemoteRepoCreds, Optional[str]]:
     _url = url.as_https()
+    env = _make_git_env_for_creds_check()
     try:
-        default_branch = _get_repo_default_branch(url.as_https(oauth_token),
+        default_branch = _get_repo_default_branch(url.as_https(oauth_token), env)
     except GitCommandError as e:
         message = f"Cannot access `{_url}`"
         if oauth_token is not None:
@@ -120,10 +188,32 @@ def _get_repo_creds_and_default_branch_https(
     return creds, default_branch
 
 
+def _make_git_env_for_creds_check(identity_file: Optional[PathLike] = None) -> dict[str, str]:
+    # Our goal is to check if _provided_ creds (if any) are correct, so we need to be sure that
+    # only the provided creds are used, without falling back to any additional mechanisms.
+    # To do this, we:
+    # 1. Disable all configs to ignore any stored creds
+    # 2. Disable askpass to avoid asking for creds interactively or fetching stored creds from
+    # a non-interactive askpass helper (for example, VS Code sets GIT_ASKPASS to its own helper,
+    # which silently provides creds to Git).
+    return make_git_env(disable_config=True, disable_askpass=True, identity_file=identity_file)
+
+
 def _get_repo_default_branch(url: str, env: dict[str, str]) -> Optional[str]:
+    # Git shipped by Apple with XCode is patched to support an additional config scope
+    # above "system" called "xcode". There is no option in `git config list` to show this config,
+    # but you can list the merged config (`git config list` without options) and then exclude
+    # all settings listed in `git config list --{system,global,local,worktree}`.
+    # As of time of writing, there are only two settings in the "xcode" config, one of which breaks
+    # our "is repo public?" check, namely "credential.helper=osxkeychain".
+    # As there is no way to disable "xcode" config (no env variable, no CLI option, etc.),
+    # the only way to disable credential helper is to override this specific setting with an empty
+    # string via command line argument: `git -c credential.helper= COMMAND [ARGS ...]`.
+    # See: https://github.com/git/git/commit/3d4355712b9fe77a96ad4ad877d92dc7ff6e0874
+    # See: https://gist.github.com/ChrisTollefson/ab9c0a5d1dd4dd615217345c6936a307
+    _git = git.cmd.Git()(c="credential.helper=")
     # output example: "ref: refs/heads/dev\tHEAD\n545344f77c0df78367085952a97fc3a058eb4c65\tHEAD"
-
-    output: str = git.cmd.Git()(c="credential.helper=").ls_remote("--symref", url, "HEAD", env=env)
+    output: str = _git.ls_remote("--symref", url, "HEAD", env=env)
     for line in output.splitlines():
         # line format: `<oid> TAB <ref> LF`
         oid, _, ref = line.partition("\t")
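The new `_make_git_env_for_creds_check` delegates to `make_git_env(disable_config=True, disable_askpass=True, ...)`, whose implementation is not shown in this diff. A rough, hypothetical sketch of what such an isolated environment can look like using stock Git environment variables (not dstack's actual helper):

    import os
    from typing import Optional

    def make_isolated_git_env(identity_file: Optional[str] = None) -> dict[str, str]:
        # Hypothetical stand-in for dstack's make_git_env(); shown only to illustrate
        # the intent described in the comments above.
        env = dict(os.environ)
        env["GIT_CONFIG_GLOBAL"] = os.devnull  # ignore ~/.gitconfig (Git >= 2.32)
        env["GIT_CONFIG_SYSTEM"] = os.devnull  # ignore /etc/gitconfig
        env["GIT_TERMINAL_PROMPT"] = "0"       # never prompt for credentials on the terminal
        env["GIT_ASKPASS"] = "echo"            # neutralize askpass helpers (e.g. the one VS Code sets)
        if identity_file is not None:
            env["GIT_SSH_COMMAND"] = f"ssh -F /dev/null -o IdentitiesOnly=yes -i {identity_file}"
        return env

The `git -c credential.helper=` override in `_get_repo_default_branch` complements this: per the comments in the diff, the XCode-injected osxkeychain helper cannot be disabled through the environment, only via that command-line override.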
@@ -19,4 +19,6 @@ def get_provisioning_timeout(backend_type: BackendType, instance_type_name: str)
         return timedelta(minutes=20)
     if backend_type == BackendType.VULTR and instance_type_name.startswith("vbm"):
         return timedelta(minutes=55)
+    if backend_type == BackendType.GCP and instance_type_name == "a4-highgpu-8g":
+        return timedelta(minutes=16)
     return timedelta(minutes=10)
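A quick check of the new GCP branch (the hunk matches dstack/_internal/server/background/tasks/common.py in the file list; both import paths below are assumptions):

    from datetime import timedelta

    from dstack._internal.core.models.backends.base import BackendType  # path assumed
    from dstack._internal.server.background.tasks.common import get_provisioning_timeout  # path assumed

    # GCP a4-highgpu-8g instances now get a 16-minute provisioning window
    # instead of the generic 10-minute default.
    assert get_provisioning_timeout(BackendType.GCP, "a4-highgpu-8g") == timedelta(minutes=16)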
@@ -1,10 +1,11 @@
+from collections import defaultdict
 from datetime import timedelta
 from typing import List
 from uuid import UUID
 
 from sqlalchemy import select, update
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload, load_only
+from sqlalchemy.orm import joinedload, load_only, selectinload
 
 from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
 from dstack._internal.core.models.instances import InstanceStatus
@@ -37,30 +38,68 @@ MIN_PROCESSING_INTERVAL = timedelta(seconds=30)
 
 @sentry_utils.instrument_background_task
 async def process_fleets():
-
+    fleet_lock, fleet_lockset = get_locker(get_db().dialect_name).get_lockset(
+        FleetModel.__tablename__
+    )
+    instance_lock, instance_lockset = get_locker(get_db().dialect_name).get_lockset(
+        InstanceModel.__tablename__
+    )
     async with get_session_ctx() as session:
-        async with
+        async with fleet_lock, instance_lock:
             res = await session.execute(
                 select(FleetModel)
                 .where(
                     FleetModel.deleted == False,
-                    FleetModel.id.not_in(
+                    FleetModel.id.not_in(fleet_lockset),
                     FleetModel.last_processed_at
                     < get_current_datetime() - MIN_PROCESSING_INTERVAL,
                 )
-                .options(
+                .options(
+                    load_only(FleetModel.id, FleetModel.name),
+                    selectinload(FleetModel.instances).load_only(InstanceModel.id),
+                )
                 .order_by(FleetModel.last_processed_at.asc())
                 .limit(BATCH_SIZE)
                 .with_for_update(skip_locked=True, key_share=True)
             )
-            fleet_models = list(res.scalars().all())
+            fleet_models = list(res.scalars().unique().all())
             fleet_ids = [fm.id for fm in fleet_models]
+            res = await session.execute(
+                select(InstanceModel)
+                .where(
+                    InstanceModel.id.not_in(instance_lockset),
+                    InstanceModel.fleet_id.in_(fleet_ids),
+                )
+                .options(load_only(InstanceModel.id, InstanceModel.fleet_id))
+                .order_by(InstanceModel.id)
+                .with_for_update(skip_locked=True, key_share=True)
+            )
+            instance_models = list(res.scalars().all())
+            fleet_id_to_locked_instances = defaultdict(list)
+            for instance_model in instance_models:
+                fleet_id_to_locked_instances[instance_model.fleet_id].append(instance_model)
+            # Process only fleets with all instances locked.
+            # Other fleets won't be processed but will still be locked to avoid new transaction.
+            # This should not be problematic as long as process_fleets is quick.
+            fleet_models_to_process = []
+            for fleet_model in fleet_models:
+                if len(fleet_model.instances) == len(fleet_id_to_locked_instances[fleet_model.id]):
+                    fleet_models_to_process.append(fleet_model)
+                else:
+                    logger.debug(
+                        "Fleet %s processing will be skipped: some instance were not locked",
+                        fleet_model.name,
+                    )
             for fleet_id in fleet_ids:
-
+                fleet_lockset.add(fleet_id)
+            instance_ids = [im.id for im in instance_models]
+            for instance_id in instance_ids:
+                instance_lockset.add(instance_id)
         try:
-            await _process_fleets(session=session, fleet_models=
+            await _process_fleets(session=session, fleet_models=fleet_models_to_process)
         finally:
-
+            fleet_lockset.difference_update(fleet_ids)
+            instance_lockset.difference_update(instance_ids)
 
 
 async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel]):
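The reworked `process_fleets` follows the locking pattern visible in the hunk: take an in-memory lockset per table, select rows with `FOR UPDATE SKIP LOCKED` while excluding already-locked ids, register the selected ids, process, and always release in `finally`. A simplified, generic sketch of that pattern (not dstack's actual locker API):

    import asyncio
    from typing import Awaitable, Callable, Iterable, Sequence
    from uuid import UUID

    class Lockset:
        """An asyncio lock guarding an in-memory set of ids currently being processed."""

        def __init__(self) -> None:
            self.lock = asyncio.Lock()
            self.ids: set[UUID] = set()

    async def process_batch(
        lockset: Lockset,
        fetch_unlocked: Callable[[Iterable[UUID]], Awaitable[Sequence]],
        process: Callable[[Sequence], Awaitable[None]],
    ) -> None:
        async with lockset.lock:
            # e.g. SELECT ... WHERE id NOT IN (:exclude) ... FOR UPDATE SKIP LOCKED
            rows = await fetch_unlocked(lockset.ids)
            ids = {row.id for row in rows}
            lockset.ids |= ids  # mark as in-flight before releasing the in-memory lock
        try:
            await process(rows)
        finally:
            lockset.ids -= ids  # always release, mirroring difference_update() above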
@@ -99,8 +138,8 @@ def _consolidate_fleet_state_with_spec(session: AsyncSession, fleet_model: Fleet
         return
     if not _is_fleet_ready_for_consolidation(fleet_model):
         return
-
-    if
+    changed_instances = _maintain_fleet_nodes_in_min_max_range(session, fleet_model, fleet_spec)
+    if changed_instances:
         fleet_model.consolidation_attempt += 1
     else:
         # The fleet is already consolidated or consolidation is in progress.
@@ -138,28 +177,47 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
     return _CONSOLIDATION_RETRY_DELAYS[-1]
 
 
-def
+def _maintain_fleet_nodes_in_min_max_range(
     session: AsyncSession,
     fleet_model: FleetModel,
     fleet_spec: FleetSpec,
 ) -> bool:
     """
-    Ensures the fleet has at least `nodes.min` instances.
-    Returns `True` if retried
+    Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances.
+    Returns `True` if retried, added new instances, or terminated redundant instances and `False` otherwise.
     """
     assert fleet_spec.configuration.nodes is not None
     for instance in fleet_model.instances:
         # Delete terminated but not deleted instances since
         # they are going to be replaced with new pending instances.
         if instance.status == InstanceStatus.TERMINATED and not instance.deleted:
-            # It's safe to modify instances without instance lock since
-            # no other task modifies already terminated instances.
            instance.deleted = True
            instance.deleted_at = get_current_datetime()
     active_instances = [i for i in fleet_model.instances if not i.deleted]
     active_instances_num = len(active_instances)
     if active_instances_num >= fleet_spec.configuration.nodes.min:
-
+        if (
+            fleet_spec.configuration.nodes.max is None
+            or active_instances_num <= fleet_spec.configuration.nodes.max
+        ):
+            return False
+        # Fleet has more instances than allowed by nodes.max.
+        # This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently)
+        # or if nodes.max is updated.
+        nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max
+        for instance in fleet_model.instances:
+            if nodes_redundant == 0:
+                break
+            if instance.status in [InstanceStatus.IDLE]:
+                instance.status = InstanceStatus.TERMINATING
+                instance.termination_reason = "Fleet has too many instances"
+                nodes_redundant -= 1
+                logger.info(
+                    "Terminating instance %s: %s",
+                    instance.name,
+                    instance.termination_reason,
+                )
+        return True
     nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num
     for i in range(nodes_missing):
         instance_model = create_fleet_instance_model(
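`_maintain_fleet_nodes_in_min_max_range` now reconciles in both directions. The decision itself reduces to simple arithmetic; an illustrative standalone version follows (dstack applies this to InstanceModel rows and only terminates idle instances):

    from typing import Optional

    def nodes_delta(active: int, nodes_min: int, nodes_max: Optional[int]) -> int:
        """Positive: instances to add; negative: redundant instances to terminate; 0: consolidated."""
        if active < nodes_min:
            return nodes_min - active
        if nodes_max is not None and active > nodes_max:
            return nodes_max - active
        return 0

    assert nodes_delta(1, 2, None) == 1   # below nodes.min -> create one more instance
    assert nodes_delta(5, 2, 4) == -1     # above nodes.max -> terminate one (idle instances only)
    assert nodes_delta(3, 2, 4) == 0      # within range -> fleet is consolidated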
@@ -259,9 +259,7 @@ async def _add_remote(instance: InstanceModel) -> None:
     if instance.status == InstanceStatus.PENDING:
         instance.status = InstanceStatus.PROVISIONING
 
-    retry_duration_deadline = instance.created_at
-        tzinfo=datetime.timezone.utc
-    ) + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
+    retry_duration_deadline = instance.created_at + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
     if retry_duration_deadline < get_current_datetime():
         instance.status = InstanceStatus.TERMINATED
         instance.termination_reason = "Provisioning timeout expired"
@@ -307,7 +305,7 @@ async def _add_remote(instance: InstanceModel) -> None:
             )
             deploy_timeout = 20 * 60  # 20 minutes
             result = await asyncio.wait_for(future, timeout=deploy_timeout)
-            health, host_info,
+            health, host_info, arch = result
         except (asyncio.TimeoutError, TimeoutError) as e:
             raise ProvisioningError(f"Deploy timeout: {e}") from e
         except Exception as e:
|
|
|
327
325
|
instance.status = InstanceStatus.PENDING
|
|
328
326
|
return
|
|
329
327
|
|
|
330
|
-
instance_type = host_info_to_instance_type(host_info,
|
|
328
|
+
instance_type = host_info_to_instance_type(host_info, arch)
|
|
331
329
|
instance_network = None
|
|
332
330
|
internal_ip = None
|
|
333
331
|
try:
|
|
@@ -1139,7 +1139,7 @@ def _patch_base_image_for_aws_efa(
     efa_enabled_patterns = [
         # TODO: p6-b200 isn't supported yet in gpuhunt
         r"^p6-b200\.(48xlarge)$",
-        r"^p5\.(48xlarge)$",
+        r"^p5\.(4xlarge|48xlarge)$",
         r"^p5e\.(48xlarge)$",
         r"^p5en\.(48xlarge)$",
         r"^p4d\.(24xlarge)$",
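The widened p5 pattern means the smaller p5.4xlarge size is now treated as EFA-capable alongside p5.48xlarge when choosing the base image. A minimal illustration (the pattern is copied from the diff; the helper below is not dstack's actual function):

    import re

    efa_enabled_patterns = [r"^p5\.(4xlarge|48xlarge)$", r"^p4d\.(24xlarge)$"]

    def is_efa_enabled(instance_type: str) -> bool:
        return any(re.match(p, instance_type) for p in efa_enabled_patterns)

    assert is_efa_enabled("p5.4xlarge")     # newly matched
    assert is_efa_enabled("p5.48xlarge")
    assert not is_efa_enabled("g5.xlarge")  # non-matching types are not patched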
@@ -256,8 +256,8 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
     for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
         replica_statuses: Set[RunStatus] = set()
         replica_needs_retry = False
-
         replica_active = True
+        jobs_done_num = 0
         for job_model in job_models:
             job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
             if (
@@ -272,8 +272,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
             ):
                 # the job is done or going to be done
                 replica_statuses.add(RunStatus.DONE)
-
-                replica_active = False
+                jobs_done_num += 1
             elif job_model.termination_reason == JobTerminationReason.SCALED_DOWN:
                 # the job was scaled down
                 replica_active = False
|
|
|
313
312
|
if not replica_needs_retry or retry_single_job:
|
|
314
313
|
run_statuses.update(replica_statuses)
|
|
315
314
|
|
|
316
|
-
if
|
|
317
|
-
#
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
)
|
|
325
|
-
)
|
|
326
|
-
else:
|
|
327
|
-
# last_processed_at = replica scaled down
|
|
328
|
-
replicas_info.append(
|
|
329
|
-
autoscalers.ReplicaInfo(
|
|
330
|
-
active=False,
|
|
331
|
-
timestamp=max(job.last_processed_at for job in job_models).replace(
|
|
332
|
-
tzinfo=datetime.timezone.utc
|
|
333
|
-
),
|
|
334
|
-
)
|
|
335
|
-
)
|
|
315
|
+
if jobs_done_num == len(job_models):
|
|
316
|
+
# Consider replica inactive if all its jobs are done for some reason.
|
|
317
|
+
# If only some jobs are done, replica is considered active to avoid
|
|
318
|
+
# provisioning new replicas for partially done multi-node tasks.
|
|
319
|
+
replica_active = False
|
|
320
|
+
|
|
321
|
+
replica_info = _get_replica_info(job_models, replica_active)
|
|
322
|
+
replicas_info.append(replica_info)
|
|
336
323
|
|
|
337
324
|
termination_reason: Optional[RunTerminationReason] = None
|
|
338
325
|
if RunStatus.FAILED in run_statuses:
|
|
@@ -410,6 +397,23 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
         run_model.resubmission_attempt += 1
 
 
+def _get_replica_info(
+    replica_job_models: list[JobModel],
+    replica_active: bool,
+) -> autoscalers.ReplicaInfo:
+    if replica_active:
+        # submitted_at = replica created
+        return autoscalers.ReplicaInfo(
+            active=True,
+            timestamp=min(job.submitted_at for job in replica_job_models),
+        )
+    # last_processed_at = replica scaled down
+    return autoscalers.ReplicaInfo(
+        active=False,
+        timestamp=max(job.last_processed_at for job in replica_job_models),
+    )
+
+
 async def _handle_run_replicas(
     session: AsyncSession,
     run_model: RunModel,
|