nemo_evaluator_launcher-0.1.28-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nemo-evaluator-launcher might be problematic.
- nemo_evaluator_launcher/__init__.py +79 -0
- nemo_evaluator_launcher/api/__init__.py +24 -0
- nemo_evaluator_launcher/api/functional.py +698 -0
- nemo_evaluator_launcher/api/types.py +98 -0
- nemo_evaluator_launcher/api/utils.py +19 -0
- nemo_evaluator_launcher/cli/__init__.py +15 -0
- nemo_evaluator_launcher/cli/export.py +267 -0
- nemo_evaluator_launcher/cli/info.py +512 -0
- nemo_evaluator_launcher/cli/kill.py +41 -0
- nemo_evaluator_launcher/cli/ls_runs.py +134 -0
- nemo_evaluator_launcher/cli/ls_tasks.py +136 -0
- nemo_evaluator_launcher/cli/main.py +226 -0
- nemo_evaluator_launcher/cli/run.py +200 -0
- nemo_evaluator_launcher/cli/status.py +164 -0
- nemo_evaluator_launcher/cli/version.py +55 -0
- nemo_evaluator_launcher/common/__init__.py +16 -0
- nemo_evaluator_launcher/common/execdb.py +283 -0
- nemo_evaluator_launcher/common/helpers.py +366 -0
- nemo_evaluator_launcher/common/logging_utils.py +357 -0
- nemo_evaluator_launcher/common/mapping.py +295 -0
- nemo_evaluator_launcher/common/printing_utils.py +93 -0
- nemo_evaluator_launcher/configs/__init__.py +15 -0
- nemo_evaluator_launcher/configs/default.yaml +28 -0
- nemo_evaluator_launcher/configs/deployment/generic.yaml +33 -0
- nemo_evaluator_launcher/configs/deployment/nim.yaml +32 -0
- nemo_evaluator_launcher/configs/deployment/none.yaml +16 -0
- nemo_evaluator_launcher/configs/deployment/sglang.yaml +38 -0
- nemo_evaluator_launcher/configs/deployment/trtllm.yaml +24 -0
- nemo_evaluator_launcher/configs/deployment/vllm.yaml +42 -0
- nemo_evaluator_launcher/configs/execution/lepton/default.yaml +92 -0
- nemo_evaluator_launcher/configs/execution/local.yaml +19 -0
- nemo_evaluator_launcher/configs/execution/slurm/default.yaml +34 -0
- nemo_evaluator_launcher/executors/__init__.py +22 -0
- nemo_evaluator_launcher/executors/base.py +120 -0
- nemo_evaluator_launcher/executors/lepton/__init__.py +16 -0
- nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +609 -0
- nemo_evaluator_launcher/executors/lepton/executor.py +1004 -0
- nemo_evaluator_launcher/executors/lepton/job_helpers.py +398 -0
- nemo_evaluator_launcher/executors/local/__init__.py +15 -0
- nemo_evaluator_launcher/executors/local/executor.py +605 -0
- nemo_evaluator_launcher/executors/local/run.template.sh +103 -0
- nemo_evaluator_launcher/executors/registry.py +38 -0
- nemo_evaluator_launcher/executors/slurm/__init__.py +15 -0
- nemo_evaluator_launcher/executors/slurm/executor.py +1147 -0
- nemo_evaluator_launcher/exporters/__init__.py +36 -0
- nemo_evaluator_launcher/exporters/base.py +121 -0
- nemo_evaluator_launcher/exporters/gsheets.py +409 -0
- nemo_evaluator_launcher/exporters/local.py +502 -0
- nemo_evaluator_launcher/exporters/mlflow.py +619 -0
- nemo_evaluator_launcher/exporters/registry.py +40 -0
- nemo_evaluator_launcher/exporters/utils.py +624 -0
- nemo_evaluator_launcher/exporters/wandb.py +490 -0
- nemo_evaluator_launcher/package_info.py +38 -0
- nemo_evaluator_launcher/resources/mapping.toml +380 -0
- nemo_evaluator_launcher-0.1.28.dist-info/METADATA +494 -0
- nemo_evaluator_launcher-0.1.28.dist-info/RECORD +60 -0
- nemo_evaluator_launcher-0.1.28.dist-info/WHEEL +5 -0
- nemo_evaluator_launcher-0.1.28.dist-info/entry_points.txt +3 -0
- nemo_evaluator_launcher-0.1.28.dist-info/licenses/LICENSE +451 -0
- nemo_evaluator_launcher-0.1.28.dist-info/top_level.txt +1 -0
nemo_evaluator_launcher/executors/slurm/executor.py
@@ -0,0 +1,1147 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""SLURM executor implementation for nemo-evaluator-launcher.

Handles submitting evaluation jobs to a SLURM cluster via SSH and sbatch scripts.
"""

import copy
import os
import re
import shlex
import subprocess
import tempfile
import time
import warnings
from pathlib import Path
from typing import Dict, List, Optional

import yaml
from omegaconf import DictConfig, OmegaConf

from nemo_evaluator_launcher.common.execdb import (
    ExecutionDB,
    JobData,
    generate_invocation_id,
    generate_job_id,
)
from nemo_evaluator_launcher.common.helpers import (
    CmdAndReadableComment,
    get_api_key_name,
    get_endpoint_url,
    get_eval_factory_command,
    get_eval_factory_config,
    get_eval_factory_dataset_size_from_run_config,
    get_health_url,
    get_timestamp_string,
)
from nemo_evaluator_launcher.common.logging_utils import logger
from nemo_evaluator_launcher.common.mapping import (
    get_task_from_mapping,
    load_tasks_mapping,
)
from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey, red
from nemo_evaluator_launcher.executors.base import (
    BaseExecutor,
    ExecutionState,
    ExecutionStatus,
)
from nemo_evaluator_launcher.executors.registry import register_executor


@register_executor("slurm")
class SlurmExecutor(BaseExecutor):
    @staticmethod
    def execute_eval(cfg: DictConfig, dry_run: bool = False) -> str:
        """Submit evaluation jobs to a SLURM cluster using the provided configuration.

        Args:
            cfg: The configuration object for the evaluation run.
            dry_run: If True, prepare scripts and save them without submission.

        Returns:
            str: The invocation ID for the evaluation run.

        Raises:
            AssertionError: If deployment type is 'none'.
            RuntimeError: If remote directory creation or sbatch submission fails.
        """

        # Generate invocation ID
        invocation_id = generate_invocation_id()

        local_runsub_paths = []
        remote_runsub_paths = []

        with tempfile.TemporaryDirectory() as tmpdirname:
            timestamp = get_timestamp_string(include_microseconds=False)
            rundir_name = timestamp + "-" + invocation_id
            remote_rundir = Path(cfg.execution.output_dir) / rundir_name
            local_rundir = Path(tmpdirname) / rundir_name
            local_rundir.mkdir()

            # Preload mapping for image resolution
            tasks_mapping = load_tasks_mapping()
            eval_images: list[str] = []

            is_potentially_unsafe = False
            for idx, task in enumerate(cfg.evaluation.tasks):
                # calculate job_id
                job_id = f"{invocation_id}.{idx}"

                # preapre locally
                remote_task_subdir = remote_rundir / task.name
                local_task_subdir = local_rundir / task.name
                local_task_subdir.mkdir()  # this ensures the task.name hasn't been used
                (local_task_subdir / "logs").mkdir()
                (local_task_subdir / "artifacts").mkdir()

                # resolve eval image and pass directly via task override
                task_definition = get_task_from_mapping(task.name, tasks_mapping)
                eval_image = task_definition["container"]
                if "container" in task:
                    eval_image = task["container"]

                eval_images.append(eval_image)

                # generate and write down sbatch script
                sbatch_script_content_struct = _create_slurm_sbatch_script(
                    cfg=cfg,
                    task=task,
                    eval_image=eval_image,
                    remote_task_subdir=remote_task_subdir,
                    invocation_id=invocation_id,
                    job_id=job_id,
                )
                sbatch_script_content_str = sbatch_script_content_struct.cmd

                # We accumulate if any task contains unsafe commands
                is_potentially_unsafe = (
                    is_potentially_unsafe
                    or sbatch_script_content_struct.is_potentially_unsafe
                )
                local_runsub_path = local_task_subdir / "run.sub"
                remote_runsub_path = remote_task_subdir / "run.sub"
                with open(local_runsub_path, "w") as f:
                    f.write(sbatch_script_content_str.rstrip("\n") + "\n")

                local_runsub_paths.append(local_runsub_path)
                remote_runsub_paths.append(remote_runsub_path)

            if dry_run:
                print(bold("\n\n=============================================\n\n"))
                print(bold(cyan("DRY RUN: SLURM scripts prepared")))
                for idx, local_runsub_path in enumerate(local_runsub_paths):
                    print(cyan(f"\n\n=========== Task {idx} =====================\n\n"))
                    with open(local_runsub_path, "r") as f:
                        print(grey(f.read()))
                print(bold("To submit jobs") + ", run the executor without --dry-run")
                if is_potentially_unsafe:
                    print(
                        red(
                            "\nFound `pre_cmd` which carries security risk. When running without --dry-run "
                            "make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1"
                        )
                    )

                return invocation_id

            if is_potentially_unsafe:
                if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
                    logger.warning(
                        "Found non-empty task commands (e.g. `pre_cmd`) and NEMO_EVALUATOR_TRUST_PRE_CMD "
                        "is set, proceeding with caution."
                    )

                else:
                    logger.error(
                        "Found non-empty task commands (e.g. `pre_cmd`) and NEMO_EVALUATOR_TRUST_PRE_CMD "
                        "is not set. This might carry security risk and unstable environments. "
                        "To continue, make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1.",
                    )
                    raise AttributeError(
                        "Untrusted command found in config, make sure you trust and "
                        "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
                    )

            socket = str(Path(tmpdirname) / "socket")
            socket_or_none = _open_master_connection(
                username=cfg.execution.username,
                hostname=cfg.execution.hostname,
                socket=socket,
            )
            _make_remote_execution_output_dir(
                dirpath=cfg.execution.output_dir,
                username=cfg.execution.username,
                hostname=cfg.execution.hostname,
                socket=socket_or_none,
            )
            _rsync_upload_rundirs(
                local_sources=[local_rundir],
                remote_target=cfg.execution.output_dir,
                username=cfg.execution.username,
                hostname=cfg.execution.hostname,
            )
            slurm_job_ids = _sbatch_remote_runsubs(
                remote_runsub_paths=remote_runsub_paths,
                username=cfg.execution.username,
                hostname=cfg.execution.hostname,
                socket=socket_or_none,
            )
            _close_master_connection(
                username=cfg.execution.username,
                hostname=cfg.execution.hostname,
                socket=socket_or_none,
            )

        # save launched jobs metadata
        db = ExecutionDB()
        for idx, (slurm_job_id, remote_runsub_path) in enumerate(
            zip(slurm_job_ids, remote_runsub_paths)
        ):
            job_id = generate_job_id(invocation_id, idx)
            db.write_job(
                job=JobData(
                    invocation_id=invocation_id,
                    job_id=job_id,
                    timestamp=time.time(),
                    executor="slurm",
                    data={
                        "slurm_job_id": slurm_job_id,
                        "remote_rundir_path": str(remote_runsub_path.parent),
                        "hostname": cfg.execution.hostname,
                        "username": cfg.execution.username,
                        "eval_image": eval_images[idx],
                    },
                    config=OmegaConf.to_object(cfg),
                )
            )
        return invocation_id

    @staticmethod
    def get_status(id: str) -> List[ExecutionStatus]:
        """Get the status of a specific SLURM job or all jobs in an invocation group.

        Args:
            id: Unique job identifier or invocation identifier.

        Returns:
            List containing the execution status for the job(s).
        """
        db = ExecutionDB()

        # If id looks like an invocation_id
        if "." not in id:
            jobs = db.get_jobs(id)
            if not jobs:
                return []
            return SlurmExecutor._get_status_for_invocation(jobs)

        # Otherwise, treat as job_id
        else:
            job_data = db.get_job(id)
            if job_data is None or job_data.executor != "slurm":
                return []
            return [SlurmExecutor._get_status_for_job(id, job_data)]

    @staticmethod
    def _get_status_for_job(id: str, job_data: JobData) -> ExecutionStatus:
        slurm_job_id = job_data.data.get("slurm_job_id")
        if not slurm_job_id:
            return ExecutionStatus(id=id, state=ExecutionState.FAILED)

        try:
            return SlurmExecutor._query_slurm_for_status_and_progress(
                slurm_job_ids=[slurm_job_id],
                remote_rundir_paths=[Path(job_data.data.get("remote_rundir_path"))],
                username=job_data.data["username"],
                hostname=job_data.data["hostname"],
                job_id_to_execdb_id={slurm_job_id: id},
            )[0]
        except Exception:
            return ExecutionStatus(id=id, state=ExecutionState.FAILED)

    @staticmethod
    def _get_status_for_invocation(jobs: dict) -> List[ExecutionStatus]:
        slurm_job_ids = []
        remote_rundir_paths = []
        job_id_to_execdb_id = {}
        username = None
        hostname = None

        for job_id, job_data in jobs.items():
            if job_data.executor != "slurm":
                continue
            slurm_job_id = job_data.data.get("slurm_job_id")
            if slurm_job_id:
                slurm_job_ids.append(slurm_job_id)
                remote_rundir_paths.append(
                    Path(job_data.data.get("remote_rundir_path"))
                )
                job_id_to_execdb_id[slurm_job_id] = job_id
                username = job_data.data.get("username")
                hostname = job_data.data.get("hostname")

        if not slurm_job_ids or not remote_rundir_paths or not username or not hostname:
            return [
                ExecutionStatus(id=job_id, state=ExecutionState.FAILED)
                for job_id in jobs.keys()
            ]

        try:
            return SlurmExecutor._query_slurm_for_status_and_progress(
                slurm_job_ids=slurm_job_ids,
                remote_rundir_paths=remote_rundir_paths,
                username=username,
                hostname=hostname,
                job_id_to_execdb_id=job_id_to_execdb_id,
            )
        except Exception:
            return [
                ExecutionStatus(id=job_id, state=ExecutionState.FAILED)
                for job_id in jobs.keys()
            ]

    @staticmethod
    def _query_slurm_for_status_and_progress(
        slurm_job_ids: List[str],
        remote_rundir_paths: List[Path],
        username: str,
        hostname: str,
        job_id_to_execdb_id: dict,
    ) -> List[ExecutionStatus]:
        with tempfile.TemporaryDirectory() as tmpdirname:
            socket = str(Path(tmpdirname) / "socket")
            socket_or_none = _open_master_connection(
                username=username,
                hostname=hostname,
                socket=socket,
            )
            # get slurm job status for initial jobs:
            slurm_jobs_status = _query_slurm_jobs_status(
                slurm_job_ids=slurm_job_ids,
                username=username,
                hostname=hostname,
                socket=socket_or_none,
            )
            # handle slurm status for autoresumed jobs:
            autoresumed_slurm_job_ids = _read_autoresumed_slurm_job_ids(
                slurm_job_ids=slurm_job_ids,
                remote_rundir_paths=remote_rundir_paths,
                username=username,
                hostname=hostname,
                socket=socket_or_none,
            )
            latest_slurm_job_ids = {
                slurm_job_id: slurm_job_id_list[-1]
                for slurm_job_id, slurm_job_id_list in autoresumed_slurm_job_ids.items()
                if len(slurm_job_id_list) > 0 and slurm_job_id_list[-1] != slurm_job_id
            }
            latest_slurm_jobs_status = _query_slurm_jobs_status(
                slurm_job_ids=list(latest_slurm_job_ids.values()),
                username=username,
                hostname=hostname,
                socket=socket_or_none,
            )
            # get progress:
            progress_list = _get_progress(
                remote_rundir_paths=remote_rundir_paths,
                username=username,
                hostname=hostname,
                socket=socket_or_none,
            )
            _close_master_connection(
                username=username,
                hostname=hostname,
                socket=socket_or_none,
            )
        statuses = []
        for i, slurm_job_id in enumerate(slurm_job_ids):
            slurm_status = slurm_jobs_status[slurm_job_id]
            if slurm_job_id in latest_slurm_job_ids:
                latest_slurm_job_id = latest_slurm_job_ids[slurm_job_id]
                slurm_status = latest_slurm_jobs_status[latest_slurm_job_id]
            progress = progress_list[i]
            progress = progress if progress is not None else "unknown"
            execution_state = SlurmExecutor._map_slurm_state_to_execution_state(
                slurm_status
            )
            execdb_job_id = job_id_to_execdb_id.get(slurm_job_id)
            if execdb_job_id:
                statuses.append(
                    ExecutionStatus(
                        id=execdb_job_id,
                        state=execution_state,
                        progress=progress,
                    )
                )
        return statuses

    @staticmethod
    def _map_slurm_state_to_execution_state(slurm_status: str) -> ExecutionState:
        """Map SLURM state to ExecutionState.

        Args:
            slurm_status: SLURM status string.

        Returns:
            Corresponding ExecutionState.
        """
        if slurm_status in ["COMPLETED"]:
            return ExecutionState.SUCCESS
        elif slurm_status in [
            "PENDING",
            "RESV_DEL_HOLD",
            "REQUEUE_FED",
            "REQUEUE_HOLD",
            "REQUEUED",
            "REVOKED",
        ]:
            return ExecutionState.PENDING
        elif slurm_status in ["RUNNING", "CONFIGURING", "SUSPENDED", "COMPLETING"]:
            return ExecutionState.RUNNING
        elif slurm_status in ["PREEMPTED", "TIMEOUT", "NODE_FAIL"]:
            return ExecutionState.PENDING  # autoresume
        elif slurm_status in ["CANCELLED"]:
            return ExecutionState.KILLED
        elif slurm_status in ["FAILED"]:
            return ExecutionState.FAILED
        else:
            return ExecutionState.FAILED

    @staticmethod
    def kill_job(job_id: str) -> None:
        """Kill a SLURM job.

        Args:
            job_id: The job ID (e.g., abc123.0) to kill.
        """
        db = ExecutionDB()
        job_data = db.get_job(job_id)

        if job_data is None:
            raise ValueError(f"Job {job_id} not found")

        if job_data.executor != "slurm":
            raise ValueError(
                f"Job {job_id} is not a slurm job (executor: {job_data.executor})"
            )

        # OPTIMIZATION: Query status AND kill in ONE SSH call
        slurm_status, result = _kill_slurm_job(
            slurm_job_ids=[job_data.data.get("slurm_job_id")],
            username=job_data.data.get("username"),
            hostname=job_data.data.get("hostname"),
            socket=job_data.data.get("socket"),
        )

        # Mark job as killed in database if kill succeeded
        if result.returncode == 0:
            job_data.data["killed"] = True
            db.write_job(job_data)
        else:
            # Use the pre-fetched status for better error message
            current_status = None
            if slurm_status:
                current_status = SlurmExecutor._map_slurm_state_to_execution_state(
                    slurm_status
                )
            error_msg = SlurmExecutor.get_kill_failure_message(
                job_id,
                f"slurm_job_id: {job_data.data.get('slurm_job_id')}",
                current_status,
            )
            raise RuntimeError(error_msg)


def _create_slurm_sbatch_script(
    cfg: DictConfig,
    task: DictConfig,
    eval_image: str,
    remote_task_subdir: Path,
    invocation_id: str,
    job_id: str,
) -> CmdAndReadableComment:
    """Generate the contents of a SLURM sbatch script for a given evaluation task.

    Args:
        cfg: The configuration object for the evaluation run.
        task: The evaluation task configuration.
        remote_task_subdir: The remote directory path for the `run.sub` file.
        invocation_id: The invocation ID for this evaluation run.
        job_id: The complete job ID string.

    Returns:
        str: The contents of the sbatch script.
    """
    # get task from mapping, overrides, urls
    tasks_mapping = load_tasks_mapping()
    task_definition = get_task_from_mapping(task.name, tasks_mapping)

    # Create merged config for get_endpoint_url
    merged_nemo_evaluator_config = get_eval_factory_config(cfg, task)
    health_url = get_health_url(
        cfg,
        get_endpoint_url(
            cfg, merged_nemo_evaluator_config, task_definition["endpoint_type"]
        ),
    )

    # TODO(public release): convert to template
    s = "#!/bin/bash\n"

    # SBATCH headers
    s += "#SBATCH --time {}\n".format(cfg.execution.walltime)
    s += "#SBATCH --account {}\n".format(cfg.execution.account)
    s += "#SBATCH --partition {}\n".format(cfg.execution.partition)
    s += "#SBATCH --nodes {}\n".format(cfg.execution.num_nodes)
    s += "#SBATCH --ntasks-per-node {}\n".format(cfg.execution.ntasks_per_node)
    if cfg.execution.get("gpus_per_node", None) is not None:
        s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
    if hasattr(cfg.execution, "gres"):
        s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
    job_name = "{account}-{subproject}.{details}".format(
        account=cfg.execution.account,
        subproject=cfg.execution.subproject,
        details=remote_task_subdir.name,
    )
    s += "#SBATCH --job-name {}\n".format(job_name)
    s += "#SBATCH --exclusive\n"
    s += "#SBATCH --output {}\n".format(remote_task_subdir / "logs" / "slurm-%A.out")
    s += "\n"
    s += f'TASK_DIR="{str(remote_task_subdir)}"\n'
    s += "\n"

    # collect all env vars
    env_vars = copy.deepcopy(dict(cfg.evaluation.get("env_vars", {})))
    env_vars.update(task.get("env_vars", {}))
    api_key_name = get_api_key_name(cfg)
    if api_key_name:
        assert "API_KEY" not in env_vars
        env_vars["API_KEY"] = api_key_name

    # check if the environment variables are set
    for env_var in env_vars.values():
        if os.getenv(env_var) is None:
            raise ValueError(f"Trying to pass an unset environment variable {env_var}.")

    # check if required env vars are defined:
    for required_env_var in task_definition.get("required_env_vars", []):
        if required_env_var not in env_vars.keys():
            raise ValueError(
                f"{task.name} task requires environment variable {required_env_var}."
                " Specify it in the task subconfig in the 'env_vars' dict as the following"
                f" pair {required_env_var}: YOUR_ENV_VAR_NAME"
            )

    # save env vars:
    for env_var_dst, env_var_src in env_vars.items():
        s += f"export {env_var_dst}={os.getenv(env_var_src)}\n"
    all_env_vars = {
        **cfg.execution.get("env_vars", {}).get("deployment", {}),
        **cfg.execution.get("env_vars", {}).get("evaluation", {}),
    }
    if cfg.deployment.get("env_vars"):
        warnings.warn(
            "cfg.deployment.env_vars will be deprecated in future versions. "
            "Use cfg.execution.env_vars.deployment instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        all_env_vars.update(cfg.deployment["env_vars"])
    for env_var_dst, env_var_value in all_env_vars.items():
        s += f"export {env_var_dst}={env_var_value}\n"
    if env_vars:
        s += "\n"

    # auto resume after timeout
    s += _AUTORESUME_HANDLER
    s += "\n\n"

    # echo the current SLURM_JOB_ID
    s += "# save the current job id\n"
    s += "echo $SLURM_JOB_ID >> {}\n\n".format(
        remote_task_subdir / ".slurm_job_id.list"
    )

    # shell options
    s += "set -e # exit immediately if any command exits with a non-zero status\n"
    s += "set -u # treat unset variables as an error when substituting\n"
    s += "set -x # print commands and their arguments as they are executed\n"
    s += "\n"

    # prepare deployment mounts
    deployment_mounts_list = []
    if cfg.deployment.type != "none":
        if checkpoint_path := cfg.deployment.get("checkpoint_path"):
            deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
        if cache_path := cfg.deployment.get("cache_path"):
            deployment_mounts_list.append(f"{cache_path}:/cache")
        for source_mnt, target_mnt in (
            cfg.execution.get("mounts", {}).get("deployment", {}).items()
        ):
            deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")

        # add deployment srun command
        s += "# deployment server\n"
        s += "srun --mpi pmix --overlap "
        s += "--container-image {} ".format(cfg.deployment.image)
        if deployment_mounts_list:
            s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
        if not cfg.execution.get("mounts", {}).get("mount_home", True):
            s += "--no-container-mount-home "
        s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
        deployment_env_var_names = list(
            cfg.execution.get("env_vars", {}).get("deployment", {})
        )
        if cfg.deployment.get("env_vars"):
            warnings.warn(
                "cfg.deployment.env_vars will be deprecated in future versions. "
                "Use cfg.execution.env_vars.deployment instead.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
        if deployment_env_var_names:
            s += f"--container-env {','.join(deployment_env_var_names)} "
        s += "{} &\n\n".format(cfg.deployment.command)  # run asynchronously
        s += (
            "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
        )

        # wait for the server to initialize
        s += _WAIT_FOR_SERVER_HANDLER.format(health_url=health_url)
        s += "\n\n"

    # prepare evaluation mounts
    evaluation_mounts_list = [
        "{}:/results".format(remote_task_subdir / "artifacts"),
    ]
    for source_mnt, target_mnt in (
        cfg.execution.get("mounts", {}).get("evaluation", {}).items()
    ):
        evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")

    eval_factory_command_struct = get_eval_factory_command(
        cfg,
        task,
        task_definition,
    )
    eval_factory_command = eval_factory_command_struct.cmd
    # The debug comment for placing into the script and easy debug. Reason
    # (see `CmdAndReadableComment`) is the current way of passing the command
    # is base64-encoded config `echo`-ed into file.
    # TODO(agronskiy): cleaner way is to encode everything with base64, not
    # some parts (like ef_config.yaml) and just output as logs somewhere.
    eval_factory_command_debug_comment = eval_factory_command_struct.debug

    # add evaluation srun command
    s += "# Debug contents of the eval factory command's config\n"
    s += eval_factory_command_debug_comment
    s += "\n\n"

    s += "# evaluation client\n"
    s += "srun --mpi pmix --overlap "
    s += "--container-image {} ".format(eval_image)
    evaluation_env_var_names = list(
        cfg.execution.get("env_vars", {}).get("evaluation", {})
    )
    if evaluation_env_var_names:
        s += "--container-env {} ".format(",".join(evaluation_env_var_names))
    if not cfg.execution.get("mounts", {}).get("mount_home", True):
        s += "--no-container-mount-home "

    s += "--container-mounts {} ".format(",".join(evaluation_mounts_list))
    s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.out")
    s += "bash -c '\n"
    s += eval_factory_command
    s += "'\n\n"

    # terminate the server after all evaluation clients finish
    if cfg.deployment.type != "none":
        s += "kill $SERVER_PID # terminate the server to finish gracefully\n\n"

    # auto-export
    ae_cfg = cfg.execution.get("auto_export")
    destinations: list = []
    if isinstance(ae_cfg, list):
        destinations = list(ae_cfg)
    elif isinstance(ae_cfg, dict) or isinstance(ae_cfg, DictConfig):
        destinations = list(ae_cfg.get("destinations", []) or [])

    if destinations:
        export_env = dict(cfg.execution.get("env_vars", {}).get("export", {}) or {})
        s += _generate_auto_export_section(cfg, job_id, destinations, export_env)

    debug_str = "\n".join(["# " + line for line in s.splitlines()])
    return CmdAndReadableComment(
        cmd=s,
        debug=debug_str,
        is_potentially_unsafe=eval_factory_command_struct.is_potentially_unsafe,
    )


def _generate_auto_export_section(
    cfg: DictConfig,
    job_id: str,
    destinations: list,
    export_env: dict,
) -> str:
    """Generate simple auto-export section for sbatch script."""
    if not destinations:
        return ""

    s = "\n# Auto-export on success\n"
    s += "EVAL_EXIT_CODE=$?\n"
    s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
    s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
    s += " set +e\n"
    s += " set +x\n"
    s += " set +u\n"
    s += ' cd "$TASK_DIR/artifacts"\n'

    # Work with DictConfig; convert only for YAML at the end
    exec_type = (
        cfg.execution.type
        if hasattr(cfg.execution, "type")
        else cfg.execution.get("type", "slurm")
    )
    eval_tasks = (
        list(cfg.evaluation.tasks)
        if hasattr(cfg, "evaluation") and hasattr(cfg.evaluation, "tasks")
        else list((cfg.get("evaluation", {}) or {}).get("tasks", []) or [])
    )
    export_block = cfg.get("export", {}) or {}

    payload = {
        "execution": {
            "auto_export": {
                "destinations": list(destinations),
                **({"env_vars": dict(export_env)} if export_env else {}),
            },
            "type": exec_type,
        },
        "evaluation": {"tasks": eval_tasks},
    }
    if export_block:
        # Convert just this block to plain for YAML
        payload["export"] = (
            OmegaConf.to_object(export_block)
            if OmegaConf.is_config(export_block)
            else dict(export_block)
        )

    # Final YAML (single conversion at the end)
    payload_clean = OmegaConf.to_container(OmegaConf.create(payload), resolve=True)
    yaml_str = yaml.safe_dump(payload_clean, sort_keys=False)
    s += " cat > export_config.yml << 'EOF'\n"
    s += yaml_str
    s += "EOF\n"

    # write launcher config as config.yml for exporters (no core command)
    submitted_yaml = yaml.safe_dump(
        OmegaConf.to_container(cfg, resolve=True), sort_keys=False
    )
    s += " cat > config.yml << 'EOF'\n"
    s += submitted_yaml
    s += "EOF\n"

    # Export host only env before running auto export
    for k, v in (export_env or {}).items():
        if isinstance(v, str) and re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", v):
            s += f' export {k}="${{{v}}}"\n'
        else:
            esc = str(v).replace('"', '\\"')
            s += f' export {k}="{esc}"\n'

    for dest in destinations:
        s += f" echo 'Exporting to {dest}...'\n"
        s += f" nemo-evaluator-launcher export {job_id} --dest {dest} || echo 'Export to {dest} failed'\n"

    s += " echo 'Auto-export completed.'\n"
    s += "else\n"
    s += " echo 'Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export.'\n"
    s += "fi\n"

    return s


def _open_master_connection(
    username: str,
    hostname: str,
    socket: str,
) -> str | None:
    ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
    completed_process = subprocess.run(
        args=shlex.split(ssh_command), capture_output=True
    )
    if completed_process.returncode == 0:
        return socket
    return None


def _close_master_connection(
    username: str,
    hostname: str,
    socket: str | None,
) -> None:
    if socket is None:
        return
    ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
    completed_process = subprocess.run(
        args=shlex.split(ssh_command), capture_output=True
    )
    if completed_process.returncode != 0:
        raise RuntimeError(
            "failed to close the master connection\n{}".format(
                completed_process.stderr.decode("utf-8")
            )
        )


def _make_remote_execution_output_dir(
    dirpath: str,
    username: str,
    hostname: str,
    socket: str | None,
) -> None:
    mkdir_command = f"mkdir -p {dirpath}"
    ssh_command = ["ssh"]
    if socket is not None:
        ssh_command.append(f"-S {socket}")
    ssh_command.append(f"{username}@{hostname}")
    ssh_command.append(mkdir_command)
    ssh_command = " ".join(ssh_command)
    completed_process = subprocess.run(
        args=shlex.split(ssh_command), capture_output=True
    )
    if completed_process.returncode != 0:
        error_msg = (
            completed_process.stderr.decode("utf-8")
            if completed_process.stderr
            else "Unknown error"
        )
        raise RuntimeError(
            "failed to make a remote execution output dir\n{}".format(error_msg)
        )


def _rsync_upload_rundirs(
    local_sources: List[Path],
    remote_target: str,
    username: str,
    hostname: str,
) -> None:
    """Upload local run directories to a remote host using rsync over SSH.

    Args:
        local_sources: List of local Path objects to upload.
        remote_target: Remote directory path as a string.
        hostname: SSH hostname.
        username: SSH username.

    Raises:
        RuntimeError: If rsync fails.
    """
    for local_source in local_sources:
        assert local_source.is_dir()
    remote_destination_str = f"{username}@{hostname}:{remote_target}"
    local_sources_str = " ".join(map(str, local_sources))
    rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
    completed_process = subprocess.run(
        args=shlex.split(rsync_upload_command), capture_output=True
    )
    if completed_process.returncode != 0:
        error_msg = (
            completed_process.stderr.decode("utf-8")
            if completed_process.stderr
            else "Unknown error"
        )
        raise RuntimeError("failed to upload local sources\n{}".format(error_msg))


def _sbatch_remote_runsubs(
    remote_runsub_paths: List[Path],
    username: str,
    hostname: str,
    socket: str | None,
) -> List[str]:
    sbatch_commands = [
        "sbatch {}".format(remote_runsub_path)
        for remote_runsub_path in remote_runsub_paths
    ]
    sbatch_commands = " ; ".join(sbatch_commands)

    ssh_command = ["ssh"]
    if socket is not None:
        ssh_command.append(f"-S {socket}")
    ssh_command.append(f"{username}@{hostname}")
    ssh_command.append(sbatch_commands)
    ssh_command = " ".join(ssh_command)

    completed_process = subprocess.run(
        args=shlex.split(ssh_command), capture_output=True
    )
    if completed_process.returncode != 0:
        error_msg = completed_process.stderr.decode("utf-8")
        raise RuntimeError(
            "failed to submit sbatch scripts for execution\n{}".format(error_msg)
        )

    sbatch_output = completed_process.stdout.decode("utf-8")
    slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
    return slurm_job_ids


def _query_slurm_jobs_status(
    slurm_job_ids: List[str],
    username: str,
    hostname: str,
    socket: str | None,
) -> Dict[str, str]:
    """Query SLURM for job statuses using sacct command.

    Args:
        slurm_job_ids: List of SLURM job IDs to query.
        username: SSH username.
        hostname: SSH hostname.
        socket: control socket location or None

    Returns:
        Dict mapping from slurm_job_id to returned slurm status.
    """
    if len(slurm_job_ids) == 0:
        return {}
    sacct_command = "sacct -j {} --format='JobID,State%32' --noheader -P".format(
        ",".join(slurm_job_ids)
    )
    ssh_command = ["ssh"]
    if socket is not None:
        ssh_command.append(f"-S {socket}")
    ssh_command.append(f"{username}@{hostname}")
    ssh_command.append(sacct_command)
    ssh_command = " ".join(ssh_command)
    completed_process = subprocess.run(
        args=shlex.split(ssh_command), capture_output=True
    )
    if completed_process.returncode != 0:
        raise RuntimeError(
            "failed to query slurm job status\n{}".format(
                completed_process.stderr.decode("utf-8")
            )
        )
    sacct_output = completed_process.stdout.decode("utf-8")
    sacct_output_lines = sacct_output.strip().split("\n")
    slurm_jobs_status = {}
    for slurm_job_id in slurm_job_ids:
        slurm_job_status = _parse_slurm_job_status(slurm_job_id, sacct_output_lines)
        slurm_jobs_status[slurm_job_id] = slurm_job_status
    return slurm_jobs_status


def _kill_slurm_job(
    slurm_job_ids: List[str], username: str, hostname: str, socket: str | None
) -> tuple[str | None, subprocess.CompletedProcess]:
    """Kill a SLURM job, querying status first in one SSH call for efficiency.

    Args:
        slurm_job_ids: List of SLURM job IDs to kill.
        username: SSH username.
        hostname: SSH hostname.
        socket: control socket location or None

    Returns:
        Tuple of (status_string, completed_process) where status_string is the SLURM status or None
    """
    if len(slurm_job_ids) == 0:
        return None, subprocess.CompletedProcess(args=[], returncode=0)

    jobs_str = ",".join(slurm_job_ids)
    # Combine both commands in one SSH call: query THEN kill
    combined_command = (
        f"sacct -j {jobs_str} --format='JobID,State%32' --noheader -P 2>/dev/null; "
        f"scancel {jobs_str}"
    )

    ssh_command = ["ssh"]
    if socket is not None:
        ssh_command.append(f"-S {socket}")
    ssh_command.append(f"{username}@{hostname}")
    ssh_command.append(combined_command)
    ssh_command = " ".join(ssh_command)

    completed_process = subprocess.run(
        args=shlex.split(ssh_command), capture_output=True
    )

    # Parse the sacct output (before scancel runs)
    sacct_output = completed_process.stdout.decode("utf-8")
    sacct_output_lines = sacct_output.strip().split("\n")
    slurm_status = None
    if sacct_output_lines and len(slurm_job_ids) == 1:
        slurm_status = _parse_slurm_job_status(slurm_job_ids[0], sacct_output_lines)

    return slurm_status, completed_process


def _parse_slurm_job_status(slurm_job_id: str, sacct_output_lines: List[str]) -> str:
    """Parse SLURM job status from sacct output for a specific job.

    Args:
        slurm_job_id: The SLURM job ID to parse.
        sacct_output_lines: Lines from sacct output.

    Returns:
        SLURM status string.
    """
    for line in sacct_output_lines:
        if line.startswith(f"{slurm_job_id}|"):
            state = line.split("|")[1]
            state = state.strip()
            if state:
                state_split = state.split()
                if len(state_split) > 0:
                    return state_split[0]
    return "UNKNOWN"


def _read_autoresumed_slurm_job_ids(
    slurm_job_ids: List[str],
    remote_rundir_paths: List[Path],
    username: str,
    hostname: str,
    socket: str | None,
) -> Dict[str, List[str]]:
    assert len(slurm_job_ids) == len(remote_rundir_paths)
    slurm_job_id_list_paths = [
        str(remote_rundir_path / ".slurm_job_id.list")
        for remote_rundir_path in remote_rundir_paths
    ]
    slurm_job_id_list_strs = _read_files_from_remote(
        slurm_job_id_list_paths, username, hostname, socket
    )
    assert len(slurm_job_id_list_strs) == len(slurm_job_ids)
    autoresumed_slurm_job_ids = {}
    for i, slurm_job_id_list_str in enumerate(slurm_job_id_list_strs):
        slurm_job_id = slurm_job_ids[i]
        slurm_job_id_list = slurm_job_id_list_str.split()
        autoresumed_slurm_job_ids[slurm_job_id] = slurm_job_id_list
    return autoresumed_slurm_job_ids


def _read_files_from_remote(
    filepaths: List[Path],
    username: str,
    hostname: str,
    socket: str | None,
) -> List[str]:
    cat_commands = [
        "echo _START_OF_FILE_ ; cat {} 2>/dev/null ; echo _END_OF_FILE_ ".format(
            filepath
        )
        for filepath in filepaths
    ]
    cat_commands = " ; ".join(cat_commands)
    ssh_command = ["ssh"]
    if socket is not None:
        ssh_command.append(f"-S {socket}")
    ssh_command.append(f"{username}@{hostname}")
    ssh_command.append(cat_commands)
    ssh_command = " ".join(ssh_command)
    completed_process = subprocess.run(
        args=shlex.split(ssh_command), capture_output=True
    )
    if completed_process.returncode != 0:
        raise RuntimeError(
            "failed to read files from remote\n{}".format(
                completed_process.stderr.decode("utf-8")
            )
        )
    cat_outputs = completed_process.stdout.decode("utf-8")
    cat_outputs = cat_outputs.replace("\n", " ")
    matches = re.findall(r"(?<=_START_OF_FILE_)(.*?)(?=_END_OF_FILE_)", cat_outputs)
    outputs = [match.strip() for match in matches]
    return outputs


def _get_progress(
    remote_rundir_paths: List[Path],
    username: str,
    hostname: str,
    socket: str | None,
) -> List[Optional[float]]:
    remote_progress_paths = [
        remote_rundir_path / "artifacts" / "progress"
        for remote_rundir_path in remote_rundir_paths
    ]
    remote_run_config_paths = [
        remote_rundir_path / "artifacts" / "run_config.yml"
        for remote_rundir_path in remote_rundir_paths
    ]
    progress_strs = _read_files_from_remote(
        remote_progress_paths, username, hostname, socket
    )
    if any(map(bool, progress_strs)):
        run_config_strs = _read_files_from_remote(
            remote_run_config_paths, username, hostname, socket
        )
    else:
        run_config_strs = [""] * len(progress_strs)
    progress_list = []
    for progress_str, run_config_str in zip(progress_strs, run_config_strs):
        if not progress_str or not run_config_str:
            progress_list.append(None)
            continue
        run_config = yaml.safe_load(run_config_str)
        dataset_size = get_eval_factory_dataset_size_from_run_config(run_config)
        if dataset_size is not None:
            progress = int(progress_str) / dataset_size
        else:
            progress = int(progress_str)
        progress_list.append(progress)
    return progress_list


_AUTORESUME_HANDLER = """
_this_script=$0
_prev_slurm_job_id=$1
# Handle automatic resumption after some failed state.
if [[ "$_prev_slurm_job_id" != "" ]]; then
    _prev_state=`sacct -j $_prev_slurm_job_id -P -n -o State | head -n 1`
    _prev_info="previous SLURM_JOB_ID $_prev_slurm_job_id finished with '$_prev_state' state."
    if [[ $_prev_state == 'TIMEOUT' || $_prev_state == 'PREEMPTED' || $_prev_state == 'NODE_FAIL' ]]; then
        echo "$_prev_info RESUMING..."
    else
        echo "$_prev_info EXIT!"
        if [[ $_prev_state == 'COMPLETED' ]]; then
            exit 0
        else
            exit 1
        fi
    fi
fi
# Schedule next execution of this script with the current $SLURM_JOB_ID as an argument.
# "afternotok" means next execution will be invoked only if the current execution terminates in some failed state.
sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
""".strip()


_WAIT_FOR_SERVER_HANDLER = """
date
# wait for the server to initialize
bash -c 'while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" {health_url})" != "200" ]]; do kill -0 '"$SERVER_PID"' 2>/dev/null || {{ echo "Server process '"$SERVER_PID"' died"; exit 1; }}; sleep 5; done'
date
""".strip()