nemo-evaluator-launcher 0.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of nemo-evaluator-launcher has been flagged as potentially problematic.

Files changed (57)
  1. nemo_evaluator_launcher/__init__.py +65 -0
  2. nemo_evaluator_launcher/api/__init__.py +24 -0
  3. nemo_evaluator_launcher/api/functional.py +641 -0
  4. nemo_evaluator_launcher/api/types.py +89 -0
  5. nemo_evaluator_launcher/api/utils.py +19 -0
  6. nemo_evaluator_launcher/cli/__init__.py +15 -0
  7. nemo_evaluator_launcher/cli/export.py +148 -0
  8. nemo_evaluator_launcher/cli/info.py +117 -0
  9. nemo_evaluator_launcher/cli/kill.py +39 -0
  10. nemo_evaluator_launcher/cli/ls_runs.py +113 -0
  11. nemo_evaluator_launcher/cli/ls_tasks.py +34 -0
  12. nemo_evaluator_launcher/cli/main.py +136 -0
  13. nemo_evaluator_launcher/cli/run.py +135 -0
  14. nemo_evaluator_launcher/cli/status.py +118 -0
  15. nemo_evaluator_launcher/cli/version.py +52 -0
  16. nemo_evaluator_launcher/common/__init__.py +16 -0
  17. nemo_evaluator_launcher/common/execdb.py +189 -0
  18. nemo_evaluator_launcher/common/helpers.py +157 -0
  19. nemo_evaluator_launcher/common/logging_utils.py +349 -0
  20. nemo_evaluator_launcher/common/mapping.py +310 -0
  21. nemo_evaluator_launcher/configs/__init__.py +15 -0
  22. nemo_evaluator_launcher/configs/default.yaml +28 -0
  23. nemo_evaluator_launcher/configs/deployment/nim.yaml +32 -0
  24. nemo_evaluator_launcher/configs/deployment/none.yaml +16 -0
  25. nemo_evaluator_launcher/configs/deployment/sglang.yaml +38 -0
  26. nemo_evaluator_launcher/configs/deployment/vllm.yaml +41 -0
  27. nemo_evaluator_launcher/configs/execution/lepton/default.yaml +92 -0
  28. nemo_evaluator_launcher/configs/execution/local.yaml +17 -0
  29. nemo_evaluator_launcher/configs/execution/slurm/default.yaml +33 -0
  30. nemo_evaluator_launcher/executors/__init__.py +22 -0
  31. nemo_evaluator_launcher/executors/base.py +97 -0
  32. nemo_evaluator_launcher/executors/lepton/__init__.py +16 -0
  33. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +589 -0
  34. nemo_evaluator_launcher/executors/lepton/executor.py +905 -0
  35. nemo_evaluator_launcher/executors/lepton/job_helpers.py +394 -0
  36. nemo_evaluator_launcher/executors/local/__init__.py +15 -0
  37. nemo_evaluator_launcher/executors/local/executor.py +491 -0
  38. nemo_evaluator_launcher/executors/local/run.template.sh +88 -0
  39. nemo_evaluator_launcher/executors/registry.py +38 -0
  40. nemo_evaluator_launcher/executors/slurm/__init__.py +15 -0
  41. nemo_evaluator_launcher/executors/slurm/executor.py +982 -0
  42. nemo_evaluator_launcher/exporters/__init__.py +36 -0
  43. nemo_evaluator_launcher/exporters/base.py +112 -0
  44. nemo_evaluator_launcher/exporters/gsheets.py +391 -0
  45. nemo_evaluator_launcher/exporters/local.py +488 -0
  46. nemo_evaluator_launcher/exporters/mlflow.py +448 -0
  47. nemo_evaluator_launcher/exporters/registry.py +40 -0
  48. nemo_evaluator_launcher/exporters/utils.py +669 -0
  49. nemo_evaluator_launcher/exporters/wandb.py +376 -0
  50. nemo_evaluator_launcher/package_info.py +35 -0
  51. nemo_evaluator_launcher/resources/mapping.toml +344 -0
  52. nemo_evaluator_launcher-0.1.0rc2.dist-info/METADATA +35 -0
  53. nemo_evaluator_launcher-0.1.0rc2.dist-info/RECORD +57 -0
  54. nemo_evaluator_launcher-0.1.0rc2.dist-info/WHEEL +5 -0
  55. nemo_evaluator_launcher-0.1.0rc2.dist-info/entry_points.txt +3 -0
  56. nemo_evaluator_launcher-0.1.0rc2.dist-info/licenses/LICENSE +451 -0
  57. nemo_evaluator_launcher-0.1.0rc2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,982 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ """SLURM executor implementation for nemo-evaluator-launcher.
+
+ Handles submitting evaluation jobs to a SLURM cluster via SSH and sbatch scripts.
+ """
+
+ import copy
+ import os
+ import re
+ import shlex
+ import subprocess
+ import tempfile
+ import time
+ import warnings
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+ import yaml
+ from omegaconf import DictConfig, OmegaConf
+
+ from nemo_evaluator_launcher.common.execdb import (
+ ExecutionDB,
+ JobData,
+ generate_invocation_id,
+ generate_job_id,
+ )
+ from nemo_evaluator_launcher.common.helpers import (
+ get_api_key_name,
+ get_endpoint_url,
+ get_eval_factory_command,
+ get_eval_factory_dataset_size_from_run_config,
+ get_health_url,
+ get_timestamp_string,
+ )
+ from nemo_evaluator_launcher.common.mapping import (
+ get_task_from_mapping,
+ load_tasks_mapping,
+ )
+ from nemo_evaluator_launcher.executors.base import (
+ BaseExecutor,
+ ExecutionState,
+ ExecutionStatus,
+ )
+ from nemo_evaluator_launcher.executors.registry import register_executor
+
+
+ @register_executor("slurm")
+ class SlurmExecutor(BaseExecutor):
+ @staticmethod
+ def execute_eval(cfg: DictConfig, dry_run: bool = False) -> str:
+ """Submit evaluation jobs to a SLURM cluster using the provided configuration.
+
+ Args:
+ cfg: The configuration object for the evaluation run.
+ dry_run: If True, prepare scripts and save them without submission.
+
+ Returns:
+ str: The invocation ID for the evaluation run.
+
+ Raises:
+ AssertionError: If deployment type is 'none'.
+ RuntimeError: If remote directory creation or sbatch submission fails.
+ """
+
+ # Generate invocation ID
+ invocation_id = generate_invocation_id()
+
+ local_runsub_paths = []
+ remote_runsub_paths = []
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ timestamp = get_timestamp_string(include_microseconds=False)
+ rundir_name = timestamp + "-" + invocation_id
+ remote_rundir = Path(cfg.execution.output_dir) / rundir_name
+ local_rundir = Path(tmpdirname) / rundir_name
+ local_rundir.mkdir()
+
+ # Preload mapping for image resolution
+ tasks_mapping = load_tasks_mapping()
+ eval_images: list[str] = []
+
+ for idx, task in enumerate(cfg.evaluation.tasks):
+ # calculate job_id
+ job_id = f"{invocation_id}.{idx}"
+
+ # prepare locally
+ remote_task_subdir = remote_rundir / task.name
+ local_task_subdir = local_rundir / task.name
+ local_task_subdir.mkdir() # this ensures the task.name hasn't been used
+ (local_task_subdir / "logs").mkdir()
+ (local_task_subdir / "artifacts").mkdir()
+
+ # resolve eval image and pass directly via task override
+ task_definition = get_task_from_mapping(task.name, tasks_mapping)
+ eval_image = task_definition["container"]
+ if "container" in task:
+ eval_image = task["container"]
+
+ eval_images.append(eval_image)
+
+ # generate and write down sbatch script
+ sbatch_script_content_str = _create_slurm_sbatch_script(
+ cfg=cfg,
+ task=task,
+ eval_image=eval_image,
+ remote_task_subdir=remote_task_subdir,
+ invocation_id=invocation_id,
+ job_id=job_id,
+ )
+ local_runsub_path = local_task_subdir / "run.sub"
+ remote_runsub_path = remote_task_subdir / "run.sub"
+ with open(local_runsub_path, "w") as f:
+ f.write(sbatch_script_content_str.rstrip("\n") + "\n")
+
+ local_runsub_paths.append(local_runsub_path)
+ remote_runsub_paths.append(remote_runsub_path)
+
+ if dry_run:
+ print("\n\n=============================================\n\n")
+ print("DRY RUN: SLURM scripts prepared")
+ for idx, local_runsub_path in enumerate(local_runsub_paths):
+ print(f"\n\n =========== Task {idx} ===================== \n\n")
+ with open(local_runsub_path, "r") as f:
+ print(f.read())
+ print("\nTo submit jobs, run the executor without --dry-run")
+ return invocation_id
+
+ socket = str(Path(tmpdirname) / "socket")
+ socket_or_none = _open_master_connection(
+ username=cfg.execution.username,
+ hostname=cfg.execution.hostname,
+ socket=socket,
+ )
+ _make_remote_execution_output_dir(
+ dirpath=cfg.execution.output_dir,
+ username=cfg.execution.username,
+ hostname=cfg.execution.hostname,
+ socket=socket_or_none,
+ )
+ _rsync_upload_rundirs(
+ local_sources=[local_rundir],
+ remote_target=cfg.execution.output_dir,
+ username=cfg.execution.username,
+ hostname=cfg.execution.hostname,
+ )
+ slurm_job_ids = _sbatch_remote_runsubs(
+ remote_runsub_paths=remote_runsub_paths,
+ username=cfg.execution.username,
+ hostname=cfg.execution.hostname,
+ socket=socket_or_none,
+ )
+ _close_master_connection(
+ username=cfg.execution.username,
+ hostname=cfg.execution.hostname,
+ socket=socket_or_none,
+ )
+
+ # save launched jobs metadata
+ db = ExecutionDB()
+ for idx, (slurm_job_id, remote_runsub_path) in enumerate(
+ zip(slurm_job_ids, remote_runsub_paths)
+ ):
+ db.write_job(
+ job=JobData(
+ invocation_id=invocation_id,
+ job_id=generate_job_id(invocation_id, idx),
+ timestamp=time.time(),
+ executor="slurm",
+ data={
+ "slurm_job_id": slurm_job_id,
+ "remote_rundir_path": str(remote_runsub_path.parent),
+ "hostname": cfg.execution.hostname,
+ "username": cfg.execution.username,
+ "eval_image": eval_images[idx],
+ },
+ config=OmegaConf.to_object(cfg),
+ )
+ )
+ return invocation_id
+
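For orientation, a minimal sketch of the staging layout that execute_eval builds in its temporary directory and then mirrors to cfg.execution.output_dir on the cluster; all names below are illustrative:

    import tempfile
    from pathlib import Path

    # <timestamp>-<invocation_id>/<task.name>/{logs, artifacts, run.sub}, one subdir per task
    with tempfile.TemporaryDirectory() as tmpdirname:
        rundir = Path(tmpdirname) / "2025-01-01-120000-abcd1234"   # illustrative rundir name
        for task_name in ["task_a", "task_b"]:                      # illustrative task names
            (rundir / task_name / "logs").mkdir(parents=True)
            (rundir / task_name / "artifacts").mkdir()
            (rundir / task_name / "run.sub").write_text("#!/bin/bash\n")  # stand-in for the generated sbatch script
        print(sorted(str(p.relative_to(tmpdirname)) for p in rundir.rglob("*")))
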
+ @staticmethod
+ def get_status(id: str) -> List[ExecutionStatus]:
+ """Get the status of a specific SLURM job or all jobs in an invocation group.
+
+ Args:
+ id: Unique job identifier or invocation identifier.
+
+ Returns:
+ List containing the execution status for the job(s).
+ """
+ db = ExecutionDB()
+
+ # If id looks like an invocation_id (8 hex digits, no dot), get all jobs for it
+ if len(id) == 8 and "." not in id:
+ jobs = db.get_jobs(id)
+ if not jobs:
+ return []
+ return SlurmExecutor._get_status_for_invocation(jobs)
+
+ # Otherwise, treat as job_id
+ else:
+ job_data = db.get_job(id)
+ if job_data is None or job_data.executor != "slurm":
+ return []
+ return [SlurmExecutor._get_status_for_job(id, job_data)]
+
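The method above dispatches on the shape of the identifier: an 8-character invocation id (no dot) fans out to every job recorded for that invocation, while a dotted id of the form <invocation_id>.<index> resolves a single job. A sketch, assuming the ids exist in the ExecutionDB (the ids are illustrative):

    # Whole invocation:
    for status in SlurmExecutor.get_status("abcd1234"):
        print(status.id, status.state)

    # Single task within that invocation:
    print(SlurmExecutor.get_status("abcd1234.0"))
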
+ @staticmethod
+ def _get_status_for_job(id: str, job_data: JobData) -> ExecutionStatus:
+ slurm_job_id = job_data.data.get("slurm_job_id")
+ if not slurm_job_id:
+ return ExecutionStatus(id=id, state=ExecutionState.FAILED)
+
+ try:
+ return SlurmExecutor._query_slurm_for_status_and_progress(
+ slurm_job_ids=[slurm_job_id],
+ remote_rundir_paths=[Path(job_data.data.get("remote_rundir_path"))],
+ username=job_data.data["username"],
+ hostname=job_data.data["hostname"],
+ job_id_to_execdb_id={slurm_job_id: id},
+ )[0]
+ except Exception:
+ return ExecutionStatus(id=id, state=ExecutionState.FAILED)
+
+ @staticmethod
+ def _get_status_for_invocation(jobs: dict) -> List[ExecutionStatus]:
+ slurm_job_ids = []
+ remote_rundir_paths = []
+ job_id_to_execdb_id = {}
+ username = None
+ hostname = None
+
+ for job_id, job_data in jobs.items():
+ if job_data.executor != "slurm":
+ continue
+ slurm_job_id = job_data.data.get("slurm_job_id")
+ if slurm_job_id:
+ slurm_job_ids.append(slurm_job_id)
+ remote_rundir_paths.append(
+ Path(job_data.data.get("remote_rundir_path"))
+ )
+ job_id_to_execdb_id[slurm_job_id] = job_id
+ username = job_data.data.get("username")
+ hostname = job_data.data.get("hostname")
+
+ if not slurm_job_ids or not remote_rundir_paths or not username or not hostname:
+ return [
+ ExecutionStatus(id=job_id, state=ExecutionState.FAILED)
+ for job_id in jobs.keys()
+ ]
+
+ try:
+ return SlurmExecutor._query_slurm_for_status_and_progress(
+ slurm_job_ids=slurm_job_ids,
+ remote_rundir_paths=remote_rundir_paths,
+ username=username,
+ hostname=hostname,
+ job_id_to_execdb_id=job_id_to_execdb_id,
+ )
+ except Exception:
+ return [
+ ExecutionStatus(id=job_id, state=ExecutionState.FAILED)
+ for job_id in jobs.keys()
+ ]
+
+ @staticmethod
+ def _query_slurm_for_status_and_progress(
+ slurm_job_ids: List[str],
+ remote_rundir_paths: List[Path],
+ username: str,
+ hostname: str,
+ job_id_to_execdb_id: dict,
+ ) -> List[ExecutionStatus]:
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ socket = str(Path(tmpdirname) / "socket")
+ socket_or_none = _open_master_connection(
+ username=username,
+ hostname=hostname,
+ socket=socket,
+ )
+ # get slurm job status for initial jobs:
+ slurm_jobs_status = _query_slurm_jobs_status(
+ slurm_job_ids=slurm_job_ids,
+ username=username,
+ hostname=hostname,
+ socket=socket_or_none,
+ )
+ # handle slurm status for autoresumed jobs:
+ autoresumed_slurm_job_ids = _read_autoresumed_slurm_job_ids(
+ slurm_job_ids=slurm_job_ids,
+ remote_rundir_paths=remote_rundir_paths,
+ username=username,
+ hostname=hostname,
+ socket=socket_or_none,
+ )
+ latest_slurm_job_ids = {
+ slurm_job_id: slurm_job_id_list[-1]
+ for slurm_job_id, slurm_job_id_list in autoresumed_slurm_job_ids.items()
+ if len(slurm_job_id_list) > 0 and slurm_job_id_list[-1] != slurm_job_id
+ }
+ latest_slurm_jobs_status = _query_slurm_jobs_status(
+ slurm_job_ids=list(latest_slurm_job_ids.values()),
+ username=username,
+ hostname=hostname,
+ socket=socket_or_none,
+ )
+ # get progress:
+ progress_list = _get_progress(
+ remote_rundir_paths=remote_rundir_paths,
+ username=username,
+ hostname=hostname,
+ socket=socket_or_none,
+ )
+ _close_master_connection(
+ username=username,
+ hostname=hostname,
+ socket=socket_or_none,
+ )
+ statuses = []
+ for i, slurm_job_id in enumerate(slurm_job_ids):
+ slurm_status = slurm_jobs_status[slurm_job_id]
+ if slurm_job_id in latest_slurm_job_ids:
+ latest_slurm_job_id = latest_slurm_job_ids[slurm_job_id]
+ slurm_status = latest_slurm_jobs_status[latest_slurm_job_id]
+ progress = progress_list[i]
+ progress = progress if progress is not None else "unknown"
+ execution_state = SlurmExecutor._map_slurm_state_to_execution_state(
+ slurm_status
+ )
+ execdb_job_id = job_id_to_execdb_id.get(slurm_job_id)
+ if execdb_job_id:
+ statuses.append(
+ ExecutionStatus(
+ id=execdb_job_id,
+ state=execution_state,
+ progress=progress,
+ )
+ )
+ return statuses
+
+ @staticmethod
+ def _map_slurm_state_to_execution_state(slurm_status: str) -> ExecutionState:
+ """Map SLURM state to ExecutionState.
+
+ Args:
+ slurm_status: SLURM status string.
+
+ Returns:
+ Corresponding ExecutionState.
+ """
+ if slurm_status in ["COMPLETED"]:
+ return ExecutionState.SUCCESS
+ elif slurm_status in [
+ "PENDING",
+ "RESV_DEL_HOLD",
+ "REQUEUE_FED",
+ "REQUEUE_HOLD",
+ "REQUEUED",
+ "REVOKED",
+ ]:
+ return ExecutionState.PENDING
+ elif slurm_status in ["RUNNING", "CONFIGURING", "SUSPENDED", "COMPLETING"]:
+ return ExecutionState.RUNNING
+ elif slurm_status in ["PREEMPTED", "TIMEOUT", "NODE_FAIL"]:
+ return ExecutionState.PENDING # autoresume
+ elif slurm_status in ["CANCELLED"]:
+ return ExecutionState.KILLED
+ elif slurm_status in ["FAILED"]:
+ return ExecutionState.FAILED
+ else:
+ return ExecutionState.FAILED
+
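A few concrete mappings implied by the branches above; note that requeue-style states and the auto-resumable failures (PREEMPTED, TIMEOUT, NODE_FAIL) both report as PENDING, and any unlisted state falls through to FAILED:

    assert SlurmExecutor._map_slurm_state_to_execution_state("COMPLETED") == ExecutionState.SUCCESS
    assert SlurmExecutor._map_slurm_state_to_execution_state("REQUEUED") == ExecutionState.PENDING
    assert SlurmExecutor._map_slurm_state_to_execution_state("TIMEOUT") == ExecutionState.PENDING    # will be autoresumed
    assert SlurmExecutor._map_slurm_state_to_execution_state("CANCELLED") == ExecutionState.KILLED
    assert SlurmExecutor._map_slurm_state_to_execution_state("OUT_OF_MEMORY") == ExecutionState.FAILED  # unlisted -> FAILED
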
+ @staticmethod
+ def kill_job(job_id: str) -> None:
+ """Kill a SLURM job.
+
+ Args:
+ job_id: The job ID to kill.
+ """
+ db = ExecutionDB()
+ job_data = db.get_job(job_id)
+
+ if job_data is None:
+ raise ValueError(f"Job {job_id} not found")
+
+ if job_data.executor != "slurm":
+ raise ValueError(
+ f"Job {job_id} is not a slurm job (executor: {job_data.executor})"
+ )
+
+ killed_something = False
+
+ result = _kill_slurm_job(
+ slurm_job_ids=[job_data.data.get("slurm_job_id")],
+ username=job_data.data.get("username"),
+ hostname=job_data.data.get("hostname"),
+ socket=job_data.data.get("socket"),
+ )
+
+ if result.returncode == 0:
+ killed_something = True
+
+ # Mark job as killed in database if we killed something
+ if killed_something:
+ job_data.data["killed"] = True
+ db.write_job(job_data)
+ else:
+ raise RuntimeError(
+ f"Could not find or kill job {job_id} (slurm_job_id: {job_data.data.get('slurm_job_id')})"
+ )
+
+
+ def _create_slurm_sbatch_script(
+ cfg: DictConfig,
+ task: DictConfig,
+ eval_image: str,
+ remote_task_subdir: Path,
+ invocation_id: str,
+ job_id: str,
+ ) -> str:
+ """Generate the contents of a SLURM sbatch script for a given evaluation task.
+
+ Args:
+ cfg: The configuration object for the evaluation run.
+ task: The evaluation task configuration.
+ eval_image: The container image to use for the evaluation client.
+ remote_task_subdir: The remote directory path for the `run.sub` file.
+ invocation_id: The invocation ID for this evaluation run.
+ job_id: The complete job ID string.
+
+ Returns:
+ str: The contents of the sbatch script.
+ """
+ # get task from mapping, overrides, urls
+ tasks_mapping = load_tasks_mapping()
+ task_definition = get_task_from_mapping(task.name, tasks_mapping)
+ health_url = get_health_url(cfg, get_endpoint_url(cfg, task, task_definition))
+
+ # TODO(public release): convert to template
+ s = "#!/bin/bash\n"
+
+ # SBATCH headers
+ s += "#SBATCH --time {}\n".format(cfg.execution.walltime)
+ s += "#SBATCH --account {}\n".format(cfg.execution.account)
+ s += "#SBATCH --partition {}\n".format(cfg.execution.partition)
+ s += "#SBATCH --nodes {}\n".format(cfg.execution.num_nodes)
+ s += "#SBATCH --ntasks-per-node {}\n".format(cfg.execution.ntasks_per_node)
+ if cfg.execution.get("gpus_per_node", None) is not None:
+ s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
+ if hasattr(cfg.execution, "gres"):
+ s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
+ job_name = "{account}-{subproject}.{details}".format(
+ account=cfg.execution.account,
+ subproject=cfg.execution.subproject,
+ details=remote_task_subdir.name,
+ )
+ s += "#SBATCH --job-name {}\n".format(job_name)
+ s += "#SBATCH --exclusive\n"
+ s += "#SBATCH --output {}\n".format(remote_task_subdir / "logs" / "slurm-%A.out")
+ s += "\n"
+
+ # collect all env vars
+ env_vars = copy.deepcopy(dict(cfg.evaluation.get("env_vars", {})))
+ env_vars.update(task.get("env_vars", {}))
+ api_key_name = get_api_key_name(cfg)
+ if api_key_name:
+ assert "API_KEY" not in env_vars
+ env_vars["API_KEY"] = api_key_name
+
+ # check if the environment variables are set
+ for env_var in env_vars.values():
+ if os.getenv(env_var) is None:
+ raise ValueError(f"Trying to pass an unset environment variable {env_var}.")
+
+ # check if required env vars are defined:
+ for required_env_var in task_definition.get("required_env_vars", []):
+ if required_env_var not in env_vars.keys():
+ raise ValueError(
+ f"{task.name} task requires environment variable {required_env_var}."
+ " Specify it in the task subconfig in the 'env_vars' dict as the following"
+ f" pair {required_env_var}: YOUR_ENV_VAR_NAME"
+ )
+
+ # save env vars:
+ for env_var_dst, env_var_src in env_vars.items():
+ s += f"export {env_var_dst}={os.getenv(env_var_src)}\n"
+ all_env_vars = {
+ **cfg.execution.get("env_vars", {}).get("deployment", {}),
+ **cfg.execution.get("env_vars", {}).get("evaluation", {}),
+ }
+ if cfg.deployment.get("env_vars"):
+ warnings.warn(
+ "cfg.deployment.env_vars will be deprecated in future versions. "
+ "Use cfg.execution.env_vars.deployment instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ all_env_vars.update(cfg.deployment["env_vars"])
+ for env_var_dst, env_var_value in all_env_vars.items():
+ s += f"export {env_var_dst}={env_var_value}\n"
+ if env_vars:
+ s += "\n"
+
+ # auto resume after timeout
+ s += _AUTORESUME_HANDLER
+ s += "\n\n"
+
+ # echo the current SLURM_JOB_ID
+ s += "# save the current job id\n"
+ s += "echo $SLURM_JOB_ID >> {}\n\n".format(
+ remote_task_subdir / ".slurm_job_id.list"
+ )
+
+ # shell options
+ s += "set -e # exit immediately if any command exits with a non-zero status\n"
+ s += "set -u # treat unset variables as an error when substituting\n"
+ s += "set -x # print commands and their arguments as they are executed\n"
+ s += "\n"
+
+ # prepare deployment mounts
+ deployment_mounts_list = []
+ if cfg.deployment.type != "none":
+ if checkpoint_path := cfg.deployment.get("checkpoint_path"):
+ deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
+ if cache_path := cfg.deployment.get("cache_path"):
+ deployment_mounts_list.append(f"{cache_path}:/cache")
+ for source_mnt, target_mnt in (
+ cfg.execution.get("mounts", {}).get("deployment", {}).items()
+ ):
+ deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")
+
+ # add deployment srun command
+ s += "# deployment server\n"
+ s += "srun --mpi pmix --overlap "
+ s += "--container-image {} ".format(cfg.deployment.image)
+ if deployment_mounts_list:
+ s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
+ if not cfg.execution.get("mounts", {}).get("mount_home", True):
+ s += "--no-container-mount-home "
+ s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
+ deployment_env_var_names = list(
+ cfg.execution.get("env_vars", {}).get("deployment", {})
+ )
+ if cfg.deployment.get("env_vars"):
+ warnings.warn(
+ "cfg.deployment.env_vars will be deprecated in future versions. "
+ "Use cfg.execution.env_vars.deployment instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
+ if deployment_env_var_names:
+ s += f"--container-env {','.join(deployment_env_var_names)} "
+ s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
+ s += (
+ "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
+ )
+
+ # wait for the server to initialize
+ s += _WAIT_FOR_SERVER_HANDLER.format(health_url=health_url)
+ s += "\n\n"
+
+ # prepare evaluation mounts
+ evaluation_mounts_list = [
+ "{}:/results".format(remote_task_subdir / "artifacts"),
+ ]
+ for source_mnt, target_mnt in (
+ cfg.execution.get("mounts", {}).get("evaluation", {}).items()
+ ):
+ evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")
+
+ # add evaluation srun command
+ s += "# evaluation client\n"
+ s += "srun --mpi pmix --overlap "
+ s += "--container-image {} ".format(eval_image)
+ evaluation_env_var_names = list(
+ cfg.execution.get("env_vars", {}).get("evaluation", {})
+ )
+ if evaluation_env_var_names:
+ s += "--container-env {} ".format(",".join(evaluation_env_var_names))
+ if not cfg.execution.get("mounts", {}).get("mount_home", True):
+ s += "--no-container-mount-home "
+ s += "--container-mounts {} ".format(",".join(evaluation_mounts_list))
+ s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.out")
+ s += get_eval_factory_command(cfg, task, task_definition)
+ s += "\n\n"
+
+ # terminate the server after all evaluation clients finish
+ if cfg.deployment.type != "none":
+ s += "kill $SERVER_PID # terminate the server to finish gracefully\n\n"
+
+ # auto-export
+ if cfg.execution.get("auto_export", {}).get("destinations", []):
+ s += _generate_auto_export_section(cfg, job_id)
+
+ return s
+
+
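One detail of the env-var handling above worth spelling out: in the evaluation/task env_vars mapping, keys are the names exported inside the job and values name source environment variables that must be set on the machine running the launcher; the generated script bakes in the resolved value at submission time. A small sketch with made-up names:

    import os

    os.environ["MY_PROVIDER_KEY"] = "secret-token"   # present on the submitting machine (illustrative)
    env_vars = {"API_KEY": "MY_PROVIDER_KEY"}        # destination name -> source variable, as in a task subconfig

    script = ""
    for env_var_dst, env_var_src in env_vars.items():   # mirrors the loop in _create_slurm_sbatch_script
        script += f"export {env_var_dst}={os.getenv(env_var_src)}\n"

    print(script, end="")                            # export API_KEY=secret-token
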
+ def _generate_auto_export_section(
+ cfg: DictConfig,
+ job_id: str, # Complete job_id string
+ ) -> str:
+ """Generate simple auto-export section for sbatch script."""
+ auto_export_config = cfg.execution.get("auto_export", {})
+ destinations = auto_export_config.get("destinations", [])
+
+ if not destinations:
+ return ""
+
+ s = "\n# AUTO-EXPORT ON SUCCESS\n"
+ s += "EVAL_EXIT_CODE=$?\n"
+ s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
+ s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
+
+ for dest in destinations:
+ s += f" echo 'Exporting to {dest}...'\n"
+ s += f" nemo-evaluator-launcher export {job_id} --dest {dest} || echo 'Export to {dest} failed'\n"
+
+ s += " echo 'Auto-export completed.'\n"
+ s += "else\n"
+ s += " echo \"Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export.\"\n"
+ s += "fi\n"
+
+ return s
+
+
+ def _open_master_connection(
+ username: str,
+ hostname: str,
+ socket: str,
+ ) -> str | None:
+ ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
+ completed_process = subprocess.run(args=shlex.split(ssh_command))
+ if completed_process.returncode == 0:
+ return socket
+ return None
+
+
+ def _close_master_connection(
+ username: str,
+ hostname: str,
+ socket: str | None,
+ ) -> None:
+ if socket is None:
+ return
+ ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
+ completed_process = subprocess.run(
+ args=shlex.split(ssh_command), capture_output=True
+ )
+ if completed_process.returncode != 0:
+ raise RuntimeError(
+ "failed to close the master connection\n{}".format(
+ completed_process.stderr.decode("utf-8")
+ )
+ )
+
+
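The two helpers above follow the standard OpenSSH ControlMaster pattern: a single backgrounded master connection bound to a control socket, reused by the subsequent ssh calls, then shut down with -O exit. A condensed sketch of the same round trip (host and user are placeholders):

    import shlex
    import subprocess
    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmpdirname:
        socket = str(Path(tmpdirname) / "socket")
        target = "user@login.example.com"   # placeholder

        # Open a multiplexed master connection: -M (master), -N (no command), -f (background).
        subprocess.run(shlex.split(f"ssh -MNf -S {socket} {target}"))

        # Follow-up commands reuse the socket and skip re-authentication.
        subprocess.run(shlex.split(f"ssh -S {socket} {target} hostname"))

        # Tear the master connection down.
        subprocess.run(shlex.split(f"ssh -O exit -S {socket} {target}"))
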
+ def _make_remote_execution_output_dir(
+ dirpath: str,
+ username: str,
+ hostname: str,
+ socket: str | None,
+ ) -> None:
+ mkdir_command = f"mkdir -p {dirpath}"
+ ssh_command = ["ssh"]
+ if socket is not None:
+ ssh_command.append(f"-S {socket}")
+ ssh_command.append(f"{username}@{hostname}")
+ ssh_command.append(mkdir_command)
+ ssh_command = " ".join(ssh_command)
+ completed_process = subprocess.run(args=shlex.split(ssh_command))
+ if completed_process.returncode != 0:
+ raise RuntimeError(
+ "failed to make a remote execution output dir\n{}".format(
+ completed_process.stderr.decode("utf-8")
+ )
+ )
+
+
+ def _rsync_upload_rundirs(
+ local_sources: List[Path],
+ remote_target: str,
+ username: str,
+ hostname: str,
+ ) -> None:
+ """Upload local run directories to a remote host using rsync over SSH.
+
+ Args:
+ local_sources: List of local Path objects to upload.
+ remote_target: Remote directory path as a string.
+ hostname: SSH hostname.
+ username: SSH username.
+
+ Raises:
+ RuntimeError: If rsync fails.
+ """
+ for local_source in local_sources:
+ assert local_source.is_dir()
+ remote_destination_str = f"{username}@{hostname}:{remote_target}"
+ local_sources_str = " ".join(map(str, local_sources))
+ rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
+ completed_process = subprocess.run(args=shlex.split(rsync_upload_command))
+ if completed_process.returncode != 0:
+ raise RuntimeError(
+ "failed to upload local sources\n{}".format(
+ completed_process.stderr.decode("utf-8")
+ )
+ )
+
+
+ def _sbatch_remote_runsubs(
+ remote_runsub_paths: List[Path],
+ username: str,
+ hostname: str,
+ socket: str | None,
+ ) -> List[str]:
+ sbatch_commands = [
+ "sbatch {}".format(remote_runsub_path)
+ for remote_runsub_path in remote_runsub_paths
+ ]
+ sbatch_commands = " ; ".join(sbatch_commands)
+
+ ssh_command = ["ssh"]
+ if socket is not None:
+ ssh_command.append(f"-S {socket}")
+ ssh_command.append(f"{username}@{hostname}")
+ ssh_command.append(sbatch_commands)
+ ssh_command = " ".join(ssh_command)
+
+ completed_process = subprocess.run(
+ args=shlex.split(ssh_command), capture_output=True
+ )
+ if completed_process.returncode != 0:
+ raise RuntimeError(
+ "failed to submit sbatch scripts for execution\n{}".format(
+ completed_process.stderr.decode("utf-8")
+ )
+ )
+
+ sbatch_output = completed_process.stdout.decode("utf-8")
+ slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
+ return slurm_job_ids
+
+
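Submission happens in a single SSH call that chains the sbatch commands with ';', and the SLURM job ids are recovered purely from sbatch's standard "Submitted batch job <id>" output, for example:

    import re

    sbatch_output = "Submitted batch job 123456\nSubmitted batch job 123457\n"   # example stdout
    print(re.findall(r"(?<=Submitted batch job )\d+", sbatch_output))            # ['123456', '123457']
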
+ def _query_slurm_jobs_status(
+ slurm_job_ids: List[str],
+ username: str,
+ hostname: str,
+ socket: str | None,
+ ) -> Dict[str, str]:
+ """Query SLURM for job statuses using sacct command.
+
+ Args:
+ slurm_job_ids: List of SLURM job IDs to query.
+ username: SSH username.
+ hostname: SSH hostname.
+ socket: control socket location or None
+
+ Returns:
+ Dict mapping from slurm_job_id to returned slurm status.
+ """
+ if len(slurm_job_ids) == 0:
+ return {}
+ sacct_command = "sacct -j {} --format='JobID,State%32' --noheader -P".format(
+ ",".join(slurm_job_ids)
+ )
+ ssh_command = ["ssh"]
+ if socket is not None:
+ ssh_command.append(f"-S {socket}")
+ ssh_command.append(f"{username}@{hostname}")
+ ssh_command.append(sacct_command)
+ ssh_command = " ".join(ssh_command)
+ completed_process = subprocess.run(
+ args=shlex.split(ssh_command), capture_output=True
+ )
+ if completed_process.returncode != 0:
+ raise RuntimeError(
+ "failed to query slurm job status\n{}".format(
+ completed_process.stderr.decode("utf-8")
+ )
+ )
+ sacct_output = completed_process.stdout.decode("utf-8")
+ sacct_output_lines = sacct_output.strip().split("\n")
+ slurm_jobs_status = {}
+ for slurm_job_id in slurm_job_ids:
+ slurm_job_status = _parse_slurm_job_status(slurm_job_id, sacct_output_lines)
+ slurm_jobs_status[slurm_job_id] = slurm_job_status
+ return slurm_jobs_status
+
+
+ def _kill_slurm_job(
+ slurm_job_ids: List[str], username: str, hostname: str, socket: str | None
+ ) -> Optional[subprocess.CompletedProcess]:
+ """Kill a SLURM job.
+
+ Args:
+ slurm_job_ids: List of SLURM job IDs to kill.
+ username: SSH username.
+ hostname: SSH hostname.
+ socket: control socket location or None
+
+ Returns:
+ The completed scancel process, or None if no job IDs were given.
+ """
+ if len(slurm_job_ids) == 0:
+ return None
+ kill_command = "scancel {}".format(",".join(slurm_job_ids))
+ ssh_command = ["ssh"]
+ if socket is not None:
+ ssh_command.append(f"-S {socket}")
+ ssh_command.append(f"{username}@{hostname}")
+ ssh_command.append(kill_command)
+ ssh_command = " ".join(ssh_command)
+ completed_process = subprocess.run(
+ args=shlex.split(ssh_command), capture_output=True
+ )
+ if completed_process.returncode != 0:
+ raise RuntimeError(
+ "failed to kill slurm job\n{}".format(
+ completed_process.stderr.decode("utf-8")
+ )
+ )
+ return completed_process
+
+
+ def _parse_slurm_job_status(slurm_job_id: str, sacct_output_lines: List[str]) -> str:
+ """Parse SLURM job status from sacct output for a specific job.
+
+ Args:
+ slurm_job_id: The SLURM job ID to parse.
+ sacct_output_lines: Lines from sacct output.
+
+ Returns:
+ SLURM status string.
+ """
+ for line in sacct_output_lines:
+ if line.startswith(f"{slurm_job_id}|"):
+ state = line.split("|")[1]
+ state = state.strip()
+ if state:
+ state_split = state.split()
+ if len(state_split) > 0:
+ return state_split[0]
+ return "UNKNOWN"
+
+
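Together, _query_slurm_jobs_status and _parse_slurm_job_status reduce the pipe-separated sacct output to one state per job: only the parent job line is matched (steps such as 123456.batch are skipped) and only the first word of the State column is kept. For example:

    sacct_output_lines = [
        "123456|COMPLETED",
        "123456.batch|COMPLETED",
        "123457|CANCELLED by 1001",
    ]
    print(_parse_slurm_job_status("123456", sacct_output_lines))   # COMPLETED
    print(_parse_slurm_job_status("123457", sacct_output_lines))   # CANCELLED
    print(_parse_slurm_job_status("999999", sacct_output_lines))   # UNKNOWN
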
+ def _read_autoresumed_slurm_job_ids(
+ slurm_job_ids: List[str],
+ remote_rundir_paths: List[Path],
+ username: str,
+ hostname: str,
+ socket: str | None,
+ ) -> Dict[str, List[str]]:
+ assert len(slurm_job_ids) == len(remote_rundir_paths)
+ slurm_job_id_list_paths = [
+ str(remote_rundir_path / ".slurm_job_id.list")
+ for remote_rundir_path in remote_rundir_paths
+ ]
+ slurm_job_id_list_strs = _read_files_from_remote(
+ slurm_job_id_list_paths, username, hostname, socket
+ )
+ assert len(slurm_job_id_list_strs) == len(slurm_job_ids)
+ autoresumed_slurm_job_ids = {}
+ for i, slurm_job_id_list_str in enumerate(slurm_job_id_list_strs):
+ slurm_job_id = slurm_job_ids[i]
+ slurm_job_id_list = slurm_job_id_list_str.split()
+ autoresumed_slurm_job_ids[slurm_job_id] = slurm_job_id_list
+ return autoresumed_slurm_job_ids
+
+
+ def _read_files_from_remote(
+ filepaths: List[Path],
+ username: str,
+ hostname: str,
+ socket: str | None,
+ ) -> List[str]:
+ cat_commands = [
+ "echo _START_OF_FILE_ ; cat {} 2>/dev/null ; echo _END_OF_FILE_ ".format(
+ filepath
+ )
+ for filepath in filepaths
+ ]
+ cat_commands = " ; ".join(cat_commands)
+ ssh_command = ["ssh"]
+ if socket is not None:
+ ssh_command.append(f"-S {socket}")
+ ssh_command.append(f"{username}@{hostname}")
+ ssh_command.append(cat_commands)
+ ssh_command = " ".join(ssh_command)
+ completed_process = subprocess.run(
+ args=shlex.split(ssh_command), capture_output=True
+ )
+ if completed_process.returncode != 0:
+ raise RuntimeError(
+ "failed to read files from remote\n{}".format(
+ completed_process.stderr.decode("utf-8")
+ )
+ )
+ cat_outputs = completed_process.stdout.decode("utf-8")
+ cat_outputs = cat_outputs.replace("\n", " ")
+ matches = re.findall(r"(?<=_START_OF_FILE_)(.*?)(?=_END_OF_FILE_)", cat_outputs)
+ outputs = [match.strip() for match in matches]
+ return outputs
+
+
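_read_files_from_remote batches many cat calls into one SSH round trip and slices the combined output on sentinel markers; a file that is missing on the remote side simply yields an empty string. The parsing step in isolation:

    import re

    cat_outputs = "_START_OF_FILE_ 123456 123457 _END_OF_FILE_ _START_OF_FILE_ _END_OF_FILE_"
    matches = re.findall(r"(?<=_START_OF_FILE_)(.*?)(?=_END_OF_FILE_)", cat_outputs)
    print([m.strip() for m in matches])   # ['123456 123457', ''] -> the second file was absent
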
+ def _get_progress(
+ remote_rundir_paths: List[Path],
+ username: str,
+ hostname: str,
+ socket: str | None,
+ ) -> List[Optional[float]]:
+ remote_progress_paths = [
+ remote_rundir_path / "artifacts" / "progress"
+ for remote_rundir_path in remote_rundir_paths
+ ]
+ remote_run_config_paths = [
+ remote_rundir_path / "artifacts" / "run_config.yml"
+ for remote_rundir_path in remote_rundir_paths
+ ]
+ progress_strs = _read_files_from_remote(
+ remote_progress_paths, username, hostname, socket
+ )
+ if any(map(bool, progress_strs)):
+ run_config_strs = _read_files_from_remote(
+ remote_run_config_paths, username, hostname, socket
+ )
+ else:
+ run_config_strs = [""] * len(progress_strs)
+ progress_list = []
+ for progress_str, run_config_str in zip(progress_strs, run_config_strs):
+ if not progress_str or not run_config_str:
+ progress_list.append(None)
+ continue
+ run_config = yaml.safe_load(run_config_str)
+ dataset_size = get_eval_factory_dataset_size_from_run_config(run_config)
+ if dataset_size is not None:
+ progress = int(progress_str) / dataset_size
+ else:
+ progress = int(progress_str)
+ progress_list.append(progress)
+ return progress_list
+
+
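Progress is derived from two remote artifacts: the number read from artifacts/progress is divided by the dataset size extracted from run_config.yml, giving a fraction in [0, 1] when the size is known and falling back to the raw count otherwise (values below are illustrative):

    progress_str, dataset_size = "532", 1064   # illustrative: progress file contents and dataset size
    progress = int(progress_str) / dataset_size if dataset_size is not None else int(progress_str)
    print(progress)                            # 0.5
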
+ _AUTORESUME_HANDLER = """
+ _this_script=$0
+ _prev_slurm_job_id=$1
+ # Handle automatic resumption after some failed state.
+ if [[ "$_prev_slurm_job_id" != "" ]]; then
+ _prev_state=`sacct -j $_prev_slurm_job_id -P -n -o State | head -n 1`
+ _prev_info="previous SLURM_JOB_ID $_prev_slurm_job_id finished with '$_prev_state' state."
+ if [[ $_prev_state == 'TIMEOUT' || $_prev_state == 'PREEMPTED' || $_prev_state == 'NODE_FAIL' ]]; then
+ echo "$_prev_info RESUMING..."
+ else
+ echo "$_prev_info EXIT!"
+ if [[ $_prev_state == 'COMPLETED' ]]; then
+ exit 0
+ else
+ exit 1
+ fi
+ fi
+ fi
+ # Schedule next execution of this script with the current $SLURM_JOB_ID as an argument.
+ # "afternotok" means next execution will be invoked only if the current execution terminates in some failed state.
+ sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
+ """.strip()
+
+
+ _WAIT_FOR_SERVER_HANDLER = """
+ date
+ # wait for the server to initialize
+ bash -c 'while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" {health_url})" != "200" ]]; do kill -0 '"$SERVER_PID"' 2>/dev/null || {{ echo "Server process '"$SERVER_PID"' died"; exit 1; }}; sleep 5; done'
+ date
+ """.strip()