nemo-evaluator-launcher 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nemo-evaluator-launcher might be problematic.

Files changed (57)
  1. nemo_evaluator_launcher/__init__.py +65 -0
  2. nemo_evaluator_launcher/api/__init__.py +24 -0
  3. nemo_evaluator_launcher/api/functional.py +678 -0
  4. nemo_evaluator_launcher/api/types.py +89 -0
  5. nemo_evaluator_launcher/api/utils.py +19 -0
  6. nemo_evaluator_launcher/cli/__init__.py +15 -0
  7. nemo_evaluator_launcher/cli/export.py +148 -0
  8. nemo_evaluator_launcher/cli/info.py +117 -0
  9. nemo_evaluator_launcher/cli/kill.py +39 -0
  10. nemo_evaluator_launcher/cli/ls_runs.py +113 -0
  11. nemo_evaluator_launcher/cli/ls_tasks.py +134 -0
  12. nemo_evaluator_launcher/cli/main.py +143 -0
  13. nemo_evaluator_launcher/cli/run.py +135 -0
  14. nemo_evaluator_launcher/cli/status.py +120 -0
  15. nemo_evaluator_launcher/cli/version.py +52 -0
  16. nemo_evaluator_launcher/common/__init__.py +16 -0
  17. nemo_evaluator_launcher/common/execdb.py +189 -0
  18. nemo_evaluator_launcher/common/helpers.py +194 -0
  19. nemo_evaluator_launcher/common/logging_utils.py +349 -0
  20. nemo_evaluator_launcher/common/mapping.py +295 -0
  21. nemo_evaluator_launcher/configs/__init__.py +15 -0
  22. nemo_evaluator_launcher/configs/default.yaml +28 -0
  23. nemo_evaluator_launcher/configs/deployment/nim.yaml +32 -0
  24. nemo_evaluator_launcher/configs/deployment/none.yaml +16 -0
  25. nemo_evaluator_launcher/configs/deployment/sglang.yaml +38 -0
  26. nemo_evaluator_launcher/configs/deployment/vllm.yaml +41 -0
  27. nemo_evaluator_launcher/configs/execution/lepton/default.yaml +92 -0
  28. nemo_evaluator_launcher/configs/execution/local.yaml +17 -0
  29. nemo_evaluator_launcher/configs/execution/slurm/default.yaml +33 -0
  30. nemo_evaluator_launcher/executors/__init__.py +22 -0
  31. nemo_evaluator_launcher/executors/base.py +97 -0
  32. nemo_evaluator_launcher/executors/lepton/__init__.py +16 -0
  33. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +589 -0
  34. nemo_evaluator_launcher/executors/lepton/executor.py +905 -0
  35. nemo_evaluator_launcher/executors/lepton/job_helpers.py +394 -0
  36. nemo_evaluator_launcher/executors/local/__init__.py +15 -0
  37. nemo_evaluator_launcher/executors/local/executor.py +491 -0
  38. nemo_evaluator_launcher/executors/local/run.template.sh +88 -0
  39. nemo_evaluator_launcher/executors/registry.py +38 -0
  40. nemo_evaluator_launcher/executors/slurm/__init__.py +15 -0
  41. nemo_evaluator_launcher/executors/slurm/executor.py +996 -0
  42. nemo_evaluator_launcher/exporters/__init__.py +36 -0
  43. nemo_evaluator_launcher/exporters/base.py +112 -0
  44. nemo_evaluator_launcher/exporters/gsheets.py +391 -0
  45. nemo_evaluator_launcher/exporters/local.py +488 -0
  46. nemo_evaluator_launcher/exporters/mlflow.py +448 -0
  47. nemo_evaluator_launcher/exporters/registry.py +40 -0
  48. nemo_evaluator_launcher/exporters/utils.py +669 -0
  49. nemo_evaluator_launcher/exporters/wandb.py +376 -0
  50. nemo_evaluator_launcher/package_info.py +38 -0
  51. nemo_evaluator_launcher/resources/mapping.toml +344 -0
  52. nemo_evaluator_launcher-0.1.0.dist-info/METADATA +494 -0
  53. nemo_evaluator_launcher-0.1.0.dist-info/RECORD +57 -0
  54. nemo_evaluator_launcher-0.1.0.dist-info/WHEEL +5 -0
  55. nemo_evaluator_launcher-0.1.0.dist-info/entry_points.txt +3 -0
  56. nemo_evaluator_launcher-0.1.0.dist-info/licenses/LICENSE +451 -0
  57. nemo_evaluator_launcher-0.1.0.dist-info/top_level.txt +1 -0
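
The single hunk shown below appears to correspond to entry 41 above, `nemo_evaluator_launcher/executors/slurm/executor.py` (+996 lines), the SLURM executor added in this release.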
@@ -0,0 +1,996 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ """SLURM executor implementation for nemo-evaluator-launcher.
+
+ Handles submitting evaluation jobs to a SLURM cluster via SSH and sbatch scripts.
+ """
+
+ import copy
+ import os
+ import re
+ import shlex
+ import subprocess
+ import tempfile
+ import time
+ import warnings
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+ import yaml
+ from omegaconf import DictConfig, OmegaConf
+
+ from nemo_evaluator_launcher.common.execdb import (
+     ExecutionDB,
+     JobData,
+     generate_invocation_id,
+     generate_job_id,
+ )
+ from nemo_evaluator_launcher.common.helpers import (
+     get_api_key_name,
+     get_endpoint_url,
+     get_eval_factory_command,
+     get_eval_factory_dataset_size_from_run_config,
+     get_health_url,
+     get_timestamp_string,
+ )
+ from nemo_evaluator_launcher.common.mapping import (
+     get_task_from_mapping,
+     load_tasks_mapping,
+ )
+ from nemo_evaluator_launcher.executors.base import (
+     BaseExecutor,
+     ExecutionState,
+     ExecutionStatus,
+ )
+ from nemo_evaluator_launcher.executors.registry import register_executor
+
+
+ @register_executor("slurm")
+ class SlurmExecutor(BaseExecutor):
+     @staticmethod
+     def execute_eval(cfg: DictConfig, dry_run: bool = False) -> str:
+         """Submit evaluation jobs to a SLURM cluster using the provided configuration.
+
+         Args:
+             cfg: The configuration object for the evaluation run.
+             dry_run: If True, prepare scripts and save them without submission.
+
+         Returns:
+             str: The invocation ID for the evaluation run.
+
+         Raises:
+             AssertionError: If deployment type is 'none'.
+             RuntimeError: If remote directory creation or sbatch submission fails.
+         """
+
+         # Generate invocation ID
+         invocation_id = generate_invocation_id()
+
+         local_runsub_paths = []
+         remote_runsub_paths = []
+
+         with tempfile.TemporaryDirectory() as tmpdirname:
+             timestamp = get_timestamp_string(include_microseconds=False)
+             rundir_name = timestamp + "-" + invocation_id
+             remote_rundir = Path(cfg.execution.output_dir) / rundir_name
+             local_rundir = Path(tmpdirname) / rundir_name
+             local_rundir.mkdir()
+
+             # Preload mapping for image resolution
+             tasks_mapping = load_tasks_mapping()
+             eval_images: list[str] = []
+
+             for idx, task in enumerate(cfg.evaluation.tasks):
+                 # calculate job_id
+                 job_id = f"{invocation_id}.{idx}"
+
+                 # prepare locally
+                 remote_task_subdir = remote_rundir / task.name
+                 local_task_subdir = local_rundir / task.name
+                 local_task_subdir.mkdir()  # this ensures the task.name hasn't been used
+                 (local_task_subdir / "logs").mkdir()
+                 (local_task_subdir / "artifacts").mkdir()
+
+                 # resolve eval image and pass directly via task override
+                 task_definition = get_task_from_mapping(task.name, tasks_mapping)
+                 eval_image = task_definition["container"]
+                 if "container" in task:
+                     eval_image = task["container"]
+
+                 eval_images.append(eval_image)
+
+                 # generate and write down sbatch script
+                 sbatch_script_content_str = _create_slurm_sbatch_script(
+                     cfg=cfg,
+                     task=task,
+                     eval_image=eval_image,
+                     remote_task_subdir=remote_task_subdir,
+                     invocation_id=invocation_id,
+                     job_id=job_id,
+                 )
+                 local_runsub_path = local_task_subdir / "run.sub"
+                 remote_runsub_path = remote_task_subdir / "run.sub"
+                 with open(local_runsub_path, "w") as f:
+                     f.write(sbatch_script_content_str.rstrip("\n") + "\n")
+
+                 local_runsub_paths.append(local_runsub_path)
+                 remote_runsub_paths.append(remote_runsub_path)
+
+             if dry_run:
+                 print("\n\n=============================================\n\n")
+                 print("DRY RUN: SLURM scripts prepared")
+                 for idx, local_runsub_path in enumerate(local_runsub_paths):
+                     print(f"\n\n =========== Task {idx} ===================== \n\n")
+                     with open(local_runsub_path, "r") as f:
+                         print(f.read())
+                 print("\nTo submit jobs, run the executor without --dry-run")
+                 return invocation_id
+
+             socket = str(Path(tmpdirname) / "socket")
+             socket_or_none = _open_master_connection(
+                 username=cfg.execution.username,
+                 hostname=cfg.execution.hostname,
+                 socket=socket,
+             )
+             _make_remote_execution_output_dir(
+                 dirpath=cfg.execution.output_dir,
+                 username=cfg.execution.username,
+                 hostname=cfg.execution.hostname,
+                 socket=socket_or_none,
+             )
+             _rsync_upload_rundirs(
+                 local_sources=[local_rundir],
+                 remote_target=cfg.execution.output_dir,
+                 username=cfg.execution.username,
+                 hostname=cfg.execution.hostname,
+             )
+             slurm_job_ids = _sbatch_remote_runsubs(
+                 remote_runsub_paths=remote_runsub_paths,
+                 username=cfg.execution.username,
+                 hostname=cfg.execution.hostname,
+                 socket=socket_or_none,
+             )
+             _close_master_connection(
+                 username=cfg.execution.username,
+                 hostname=cfg.execution.hostname,
+                 socket=socket_or_none,
+             )
+
+         # save launched jobs metadata
+         db = ExecutionDB()
+         for idx, (slurm_job_id, remote_runsub_path) in enumerate(
+             zip(slurm_job_ids, remote_runsub_paths)
+         ):
+             db.write_job(
+                 job=JobData(
+                     invocation_id=invocation_id,
+                     job_id=generate_job_id(invocation_id, idx),
+                     timestamp=time.time(),
+                     executor="slurm",
+                     data={
+                         "slurm_job_id": slurm_job_id,
+                         "remote_rundir_path": str(remote_runsub_path.parent),
+                         "hostname": cfg.execution.hostname,
+                         "username": cfg.execution.username,
+                         "eval_image": eval_images[idx],
+                     },
+                     config=OmegaConf.to_object(cfg),
+                 )
+             )
+         return invocation_id
+
+     @staticmethod
+     def get_status(id: str) -> List[ExecutionStatus]:
+         """Get the status of a specific SLURM job or all jobs in an invocation group.
+
+         Args:
+             id: Unique job identifier or invocation identifier.
+
+         Returns:
+             List containing the execution status for the job(s).
+         """
+         db = ExecutionDB()
+
+         # If id looks like an invocation_id (8 hex digits, no dot), get all jobs for it
+         if len(id) == 8 and "." not in id:
+             jobs = db.get_jobs(id)
+             if not jobs:
+                 return []
+             return SlurmExecutor._get_status_for_invocation(jobs)
+
+         # Otherwise, treat as job_id
+         else:
+             job_data = db.get_job(id)
+             if job_data is None or job_data.executor != "slurm":
+                 return []
+             return [SlurmExecutor._get_status_for_job(id, job_data)]
+
+     @staticmethod
+     def _get_status_for_job(id: str, job_data: JobData) -> ExecutionStatus:
+         slurm_job_id = job_data.data.get("slurm_job_id")
+         if not slurm_job_id:
+             return ExecutionStatus(id=id, state=ExecutionState.FAILED)
+
+         try:
+             return SlurmExecutor._query_slurm_for_status_and_progress(
+                 slurm_job_ids=[slurm_job_id],
+                 remote_rundir_paths=[Path(job_data.data.get("remote_rundir_path"))],
+                 username=job_data.data["username"],
+                 hostname=job_data.data["hostname"],
+                 job_id_to_execdb_id={slurm_job_id: id},
+             )[0]
+         except Exception:
+             return ExecutionStatus(id=id, state=ExecutionState.FAILED)
+
+     @staticmethod
+     def _get_status_for_invocation(jobs: dict) -> List[ExecutionStatus]:
+         slurm_job_ids = []
+         remote_rundir_paths = []
+         job_id_to_execdb_id = {}
+         username = None
+         hostname = None
+
+         for job_id, job_data in jobs.items():
+             if job_data.executor != "slurm":
+                 continue
+             slurm_job_id = job_data.data.get("slurm_job_id")
+             if slurm_job_id:
+                 slurm_job_ids.append(slurm_job_id)
+                 remote_rundir_paths.append(
+                     Path(job_data.data.get("remote_rundir_path"))
+                 )
+                 job_id_to_execdb_id[slurm_job_id] = job_id
+                 username = job_data.data.get("username")
+                 hostname = job_data.data.get("hostname")
+
+         if not slurm_job_ids or not remote_rundir_paths or not username or not hostname:
+             return [
+                 ExecutionStatus(id=job_id, state=ExecutionState.FAILED)
+                 for job_id in jobs.keys()
+             ]
+
+         try:
+             return SlurmExecutor._query_slurm_for_status_and_progress(
+                 slurm_job_ids=slurm_job_ids,
+                 remote_rundir_paths=remote_rundir_paths,
+                 username=username,
+                 hostname=hostname,
+                 job_id_to_execdb_id=job_id_to_execdb_id,
+             )
+         except Exception:
+             return [
+                 ExecutionStatus(id=job_id, state=ExecutionState.FAILED)
+                 for job_id in jobs.keys()
+             ]
+
+     @staticmethod
+     def _query_slurm_for_status_and_progress(
+         slurm_job_ids: List[str],
+         remote_rundir_paths: List[Path],
+         username: str,
+         hostname: str,
+         job_id_to_execdb_id: dict,
+     ) -> List[ExecutionStatus]:
+         with tempfile.TemporaryDirectory() as tmpdirname:
+             socket = str(Path(tmpdirname) / "socket")
+             socket_or_none = _open_master_connection(
+                 username=username,
+                 hostname=hostname,
+                 socket=socket,
+             )
+             # get slurm job status for initial jobs:
+             slurm_jobs_status = _query_slurm_jobs_status(
+                 slurm_job_ids=slurm_job_ids,
+                 username=username,
+                 hostname=hostname,
+                 socket=socket_or_none,
+             )
+             # handle slurm status for autoresumed jobs:
+             autoresumed_slurm_job_ids = _read_autoresumed_slurm_job_ids(
+                 slurm_job_ids=slurm_job_ids,
+                 remote_rundir_paths=remote_rundir_paths,
+                 username=username,
+                 hostname=hostname,
+                 socket=socket_or_none,
+             )
+             latest_slurm_job_ids = {
+                 slurm_job_id: slurm_job_id_list[-1]
+                 for slurm_job_id, slurm_job_id_list in autoresumed_slurm_job_ids.items()
+                 if len(slurm_job_id_list) > 0 and slurm_job_id_list[-1] != slurm_job_id
+             }
+             latest_slurm_jobs_status = _query_slurm_jobs_status(
+                 slurm_job_ids=list(latest_slurm_job_ids.values()),
+                 username=username,
+                 hostname=hostname,
+                 socket=socket_or_none,
+             )
+             # get progress:
+             progress_list = _get_progress(
+                 remote_rundir_paths=remote_rundir_paths,
+                 username=username,
+                 hostname=hostname,
+                 socket=socket_or_none,
+             )
+             _close_master_connection(
+                 username=username,
+                 hostname=hostname,
+                 socket=socket_or_none,
+             )
+         statuses = []
+         for i, slurm_job_id in enumerate(slurm_job_ids):
+             slurm_status = slurm_jobs_status[slurm_job_id]
+             if slurm_job_id in latest_slurm_job_ids:
+                 latest_slurm_job_id = latest_slurm_job_ids[slurm_job_id]
+                 slurm_status = latest_slurm_jobs_status[latest_slurm_job_id]
+             progress = progress_list[i]
+             progress = progress if progress is not None else "unknown"
+             execution_state = SlurmExecutor._map_slurm_state_to_execution_state(
+                 slurm_status
+             )
+             execdb_job_id = job_id_to_execdb_id.get(slurm_job_id)
+             if execdb_job_id:
+                 statuses.append(
+                     ExecutionStatus(
+                         id=execdb_job_id,
+                         state=execution_state,
+                         progress=progress,
+                     )
+                 )
+         return statuses
+
+     @staticmethod
+     def _map_slurm_state_to_execution_state(slurm_status: str) -> ExecutionState:
+         """Map SLURM state to ExecutionState.
+
+         Args:
+             slurm_status: SLURM status string.
+
+         Returns:
+             Corresponding ExecutionState.
+         """
+         if slurm_status in ["COMPLETED"]:
+             return ExecutionState.SUCCESS
+         elif slurm_status in [
+             "PENDING",
+             "RESV_DEL_HOLD",
+             "REQUEUE_FED",
+             "REQUEUE_HOLD",
+             "REQUEUED",
+             "REVOKED",
+         ]:
+             return ExecutionState.PENDING
+         elif slurm_status in ["RUNNING", "CONFIGURING", "SUSPENDED", "COMPLETING"]:
+             return ExecutionState.RUNNING
+         elif slurm_status in ["PREEMPTED", "TIMEOUT", "NODE_FAIL"]:
+             return ExecutionState.PENDING  # autoresume
+         elif slurm_status in ["CANCELLED"]:
+             return ExecutionState.KILLED
+         elif slurm_status in ["FAILED"]:
+             return ExecutionState.FAILED
+         else:
+             return ExecutionState.FAILED
+
+     @staticmethod
+     def kill_job(job_id: str) -> None:
+         """Kill a SLURM job.
+
+         Args:
+             job_id: The job ID to kill.
+         """
+         db = ExecutionDB()
+         job_data = db.get_job(job_id)
+
+         if job_data is None:
+             raise ValueError(f"Job {job_id} not found")
+
+         if job_data.executor != "slurm":
+             raise ValueError(
+                 f"Job {job_id} is not a slurm job (executor: {job_data.executor})"
+             )
+
+         killed_something = False
+
+         result = _kill_slurm_job(
+             slurm_job_ids=[job_data.data.get("slurm_job_id")],
+             username=job_data.data.get("username"),
+             hostname=job_data.data.get("hostname"),
+             socket=job_data.data.get("socket"),
+         )
+
+         if result.returncode == 0:
+             killed_something = True
+
+         # Mark job as killed in database if we killed something
+         if killed_something:
+             job_data.data["killed"] = True
+             db.write_job(job_data)
+         else:
+             raise RuntimeError(
+                 f"Could not find or kill job {job_id} (slurm_job_id: {job_data.data.get('slurm_job_id')})"
+             )
+
+
+ def _create_slurm_sbatch_script(
+     cfg: DictConfig,
+     task: DictConfig,
+     eval_image: str,
+     remote_task_subdir: Path,
+     invocation_id: str,
+     job_id: str,
+ ) -> str:
+     """Generate the contents of a SLURM sbatch script for a given evaluation task.
+
+     Args:
+         cfg: The configuration object for the evaluation run.
+         task: The evaluation task configuration.
+         eval_image: The evaluation container image to use.
+         remote_task_subdir: The remote directory path for the `run.sub` file.
+         invocation_id: The invocation ID for this evaluation run.
+         job_id: The complete job ID string.
+
+     Returns:
+         str: The contents of the sbatch script.
+     """
+     # get task from mapping, overrides, urls
+     tasks_mapping = load_tasks_mapping()
+     task_definition = get_task_from_mapping(task.name, tasks_mapping)
+     health_url = get_health_url(cfg, get_endpoint_url(cfg, task, task_definition))
+
+     # TODO(public release): convert to template
+     s = "#!/bin/bash\n"
+
+     # SBATCH headers
+     s += "#SBATCH --time {}\n".format(cfg.execution.walltime)
+     s += "#SBATCH --account {}\n".format(cfg.execution.account)
+     s += "#SBATCH --partition {}\n".format(cfg.execution.partition)
+     s += "#SBATCH --nodes {}\n".format(cfg.execution.num_nodes)
+     s += "#SBATCH --ntasks-per-node {}\n".format(cfg.execution.ntasks_per_node)
+     if cfg.execution.get("gpus_per_node", None) is not None:
+         s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
+     if hasattr(cfg.execution, "gres"):
+         s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
+     job_name = "{account}-{subproject}.{details}".format(
+         account=cfg.execution.account,
+         subproject=cfg.execution.subproject,
+         details=remote_task_subdir.name,
+     )
+     s += "#SBATCH --job-name {}\n".format(job_name)
+     s += "#SBATCH --exclusive\n"
+     s += "#SBATCH --output {}\n".format(remote_task_subdir / "logs" / "slurm-%A.out")
+     s += "\n"
+     s += f'TASK_DIR="{str(remote_task_subdir)}"\n'
+     s += "\n"
+
+     # collect all env vars
+     env_vars = copy.deepcopy(dict(cfg.evaluation.get("env_vars", {})))
+     env_vars.update(task.get("env_vars", {}))
+     api_key_name = get_api_key_name(cfg)
+     if api_key_name:
+         assert "API_KEY" not in env_vars
+         env_vars["API_KEY"] = api_key_name
+
+     # check if the environment variables are set
+     for env_var in env_vars.values():
+         if os.getenv(env_var) is None:
+             raise ValueError(f"Trying to pass an unset environment variable {env_var}.")
+
+     # check if required env vars are defined:
+     for required_env_var in task_definition.get("required_env_vars", []):
+         if required_env_var not in env_vars.keys():
+             raise ValueError(
+                 f"{task.name} task requires environment variable {required_env_var}."
+                 " Specify it in the task subconfig in the 'env_vars' dict as the following"
+                 f" pair {required_env_var}: YOUR_ENV_VAR_NAME"
+             )
+
+     # save env vars:
+     for env_var_dst, env_var_src in env_vars.items():
+         s += f"export {env_var_dst}={os.getenv(env_var_src)}\n"
+     all_env_vars = {
+         **cfg.execution.get("env_vars", {}).get("deployment", {}),
+         **cfg.execution.get("env_vars", {}).get("evaluation", {}),
+     }
+     if cfg.deployment.get("env_vars"):
+         warnings.warn(
+             "cfg.deployment.env_vars will be deprecated in future versions. "
+             "Use cfg.execution.env_vars.deployment instead.",
+             category=DeprecationWarning,
+             stacklevel=2,
+         )
+         all_env_vars.update(cfg.deployment["env_vars"])
+     for env_var_dst, env_var_value in all_env_vars.items():
+         s += f"export {env_var_dst}={env_var_value}\n"
+     if env_vars:
+         s += "\n"
+
+     # auto resume after timeout
+     s += _AUTORESUME_HANDLER
+     s += "\n\n"
+
+     # echo the current SLURM_JOB_ID
+     s += "# save the current job id\n"
+     s += "echo $SLURM_JOB_ID >> {}\n\n".format(
+         remote_task_subdir / ".slurm_job_id.list"
+     )
+
+     # shell options
+     s += "set -e # exit immediately if any command exits with a non-zero status\n"
+     s += "set -u # treat unset variables as an error when substituting\n"
+     s += "set -x # print commands and their arguments as they are executed\n"
+     s += "\n"
+
+     # prepare deployment mounts
+     deployment_mounts_list = []
+     if cfg.deployment.type != "none":
+         if checkpoint_path := cfg.deployment.get("checkpoint_path"):
+             deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
+         if cache_path := cfg.deployment.get("cache_path"):
+             deployment_mounts_list.append(f"{cache_path}:/cache")
+         for source_mnt, target_mnt in (
+             cfg.execution.get("mounts", {}).get("deployment", {}).items()
+         ):
+             deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")
+
+         # add deployment srun command
+         s += "# deployment server\n"
+         s += "srun --mpi pmix --overlap "
+         s += "--container-image {} ".format(cfg.deployment.image)
+         if deployment_mounts_list:
+             s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
+         if not cfg.execution.get("mounts", {}).get("mount_home", True):
+             s += "--no-container-mount-home "
+         s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
+         deployment_env_var_names = list(
+             cfg.execution.get("env_vars", {}).get("deployment", {})
+         )
+         if cfg.deployment.get("env_vars"):
+             warnings.warn(
+                 "cfg.deployment.env_vars will be deprecated in future versions. "
+                 "Use cfg.execution.env_vars.deployment instead.",
+                 category=DeprecationWarning,
+                 stacklevel=2,
+             )
+             deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
+         if deployment_env_var_names:
+             s += f"--container-env {','.join(deployment_env_var_names)} "
+         s += "{} &\n\n".format(cfg.deployment.command)  # run asynchronously
+         s += (
+             "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
+         )
+
+         # wait for the server to initialize
+         s += _WAIT_FOR_SERVER_HANDLER.format(health_url=health_url)
+         s += "\n\n"
+
+     # prepare evaluation mounts
+     evaluation_mounts_list = [
+         "{}:/results".format(remote_task_subdir / "artifacts"),
+     ]
+     for source_mnt, target_mnt in (
+         cfg.execution.get("mounts", {}).get("evaluation", {}).items()
+     ):
+         evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")
+
+     # add evaluation srun command
+     s += "# evaluation client\n"
+     s += "srun --mpi pmix --overlap "
+     s += "--container-image {} ".format(eval_image)
+     evaluation_env_var_names = list(
+         cfg.execution.get("env_vars", {}).get("evaluation", {})
+     )
+     if evaluation_env_var_names:
+         s += "--container-env {} ".format(",".join(evaluation_env_var_names))
+     if not cfg.execution.get("mounts", {}).get("mount_home", True):
+         s += "--no-container-mount-home "
+     s += "--container-mounts {} ".format(",".join(evaluation_mounts_list))
+     s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.out")
+     s += "bash -c '"
+     s += get_eval_factory_command(cfg, task, task_definition)
+     s += "'\n\n"
+
+     # terminate the server after all evaluation clients finish
+     if cfg.deployment.type != "none":
+         s += "kill $SERVER_PID # terminate the server to finish gracefully\n\n"
+
+     # auto-export
+     if cfg.execution.get("auto_export", {}).get("destinations", []):
+         s += _generate_auto_export_section(cfg, job_id)
+
+     return s
+
+
+ def _generate_auto_export_section(
+     cfg: DictConfig,
+     job_id: str,  # Complete job_id string
+ ) -> str:
+     """Generate simple auto-export section for sbatch script."""
+     auto_export_config = cfg.execution.get("auto_export", {})
+     destinations = auto_export_config.get("destinations", [])
+
+     if not destinations:
+         return ""
+
+     s = "\n# Auto-export on success\n"
+     s += "EVAL_EXIT_CODE=$?\n"
+     s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
+     s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
+     s += " set +e\n"  # per exporter failure allowed
+     s += " set +x\n"
+     s += ' cd "$TASK_DIR/artifacts"\n'
+     auto_export_cfg = OmegaConf.to_container(
+         cfg.execution.get("auto_export", {}), resolve=True
+     )
+     yaml_str = yaml.safe_dump(
+         {"execution": {"auto_export": auto_export_cfg}}, sort_keys=False
+     )
+     s += " cat > export_config.yml << 'EOF'\n"
+     s += yaml_str
+     s += "EOF\n"
+     for dest in destinations:
+         s += f" echo 'Exporting to {dest}...'\n"
+         s += f" nemo-evaluator-launcher export {job_id} --dest {dest} || echo 'Export to {dest} failed'\n"
+
+     s += " echo 'Auto-export completed.'\n"
+     s += "else\n"
+     s += ' echo "Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export."\n'
+     s += "fi\n"
+
+     return s
+
+
+ def _open_master_connection(
+     username: str,
+     hostname: str,
+     socket: str,
+ ) -> str | None:
+     ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
+     completed_process = subprocess.run(args=shlex.split(ssh_command))
+     if completed_process.returncode == 0:
+         return socket
+     return None
+
+
+ def _close_master_connection(
+     username: str,
+     hostname: str,
+     socket: str | None,
+ ) -> None:
+     if socket is None:
+         return
+     ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
+     completed_process = subprocess.run(
+         args=shlex.split(ssh_command), capture_output=True
+     )
+     if completed_process.returncode != 0:
+         raise RuntimeError(
+             "failed to close the master connection\n{}".format(
+                 completed_process.stderr.decode("utf-8")
+             )
+         )
+
+
+ def _make_remote_execution_output_dir(
+     dirpath: str,
+     username: str,
+     hostname: str,
+     socket: str | None,
+ ) -> None:
+     mkdir_command = f"mkdir -p {dirpath}"
+     ssh_command = ["ssh"]
+     if socket is not None:
+         ssh_command.append(f"-S {socket}")
+     ssh_command.append(f"{username}@{hostname}")
+     ssh_command.append(mkdir_command)
+     ssh_command = " ".join(ssh_command)
+     completed_process = subprocess.run(
+         args=shlex.split(ssh_command), capture_output=True
+     )
+     if completed_process.returncode != 0:
+         raise RuntimeError(
+             "failed to make a remote execution output dir\n{}".format(
+                 completed_process.stderr.decode("utf-8")
+             )
+         )
+
+
+ def _rsync_upload_rundirs(
+     local_sources: List[Path],
+     remote_target: str,
+     username: str,
+     hostname: str,
+ ) -> None:
+     """Upload local run directories to a remote host using rsync over SSH.
+
+     Args:
+         local_sources: List of local Path objects to upload.
+         remote_target: Remote directory path as a string.
+         username: SSH username.
+         hostname: SSH hostname.
+
+     Raises:
+         RuntimeError: If rsync fails.
+     """
+     for local_source in local_sources:
+         assert local_source.is_dir()
+     remote_destination_str = f"{username}@{hostname}:{remote_target}"
+     local_sources_str = " ".join(map(str, local_sources))
+     rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
+     completed_process = subprocess.run(
+         args=shlex.split(rsync_upload_command), capture_output=True
+     )
+     if completed_process.returncode != 0:
+         raise RuntimeError(
+             "failed to upload local sources\n{}".format(
+                 completed_process.stderr.decode("utf-8")
+             )
+         )
+
+
+ def _sbatch_remote_runsubs(
+     remote_runsub_paths: List[Path],
+     username: str,
+     hostname: str,
+     socket: str | None,
+ ) -> List[str]:
+     sbatch_commands = [
+         "sbatch {}".format(remote_runsub_path)
+         for remote_runsub_path in remote_runsub_paths
+     ]
+     sbatch_commands = " ; ".join(sbatch_commands)
+
+     ssh_command = ["ssh"]
+     if socket is not None:
+         ssh_command.append(f"-S {socket}")
+     ssh_command.append(f"{username}@{hostname}")
+     ssh_command.append(sbatch_commands)
+     ssh_command = " ".join(ssh_command)
+
+     completed_process = subprocess.run(
+         args=shlex.split(ssh_command), capture_output=True
+     )
+     if completed_process.returncode != 0:
+         raise RuntimeError(
+             "failed to submit sbatch scripts for execution\n{}".format(
+                 completed_process.stderr.decode("utf-8")
+             )
+         )
+
+     sbatch_output = completed_process.stdout.decode("utf-8")
+     slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
+     return slurm_job_ids
+
+
+ def _query_slurm_jobs_status(
+     slurm_job_ids: List[str],
+     username: str,
+     hostname: str,
+     socket: str | None,
+ ) -> Dict[str, str]:
+     """Query SLURM for job statuses using sacct command.
+
+     Args:
+         slurm_job_ids: List of SLURM job IDs to query.
+         username: SSH username.
+         hostname: SSH hostname.
+         socket: control socket location or None
+
+     Returns:
+         Dict mapping from slurm_job_id to returned slurm status.
+     """
+     if len(slurm_job_ids) == 0:
+         return {}
+     sacct_command = "sacct -j {} --format='JobID,State%32' --noheader -P".format(
+         ",".join(slurm_job_ids)
+     )
+     ssh_command = ["ssh"]
+     if socket is not None:
+         ssh_command.append(f"-S {socket}")
+     ssh_command.append(f"{username}@{hostname}")
+     ssh_command.append(sacct_command)
+     ssh_command = " ".join(ssh_command)
+     completed_process = subprocess.run(
+         args=shlex.split(ssh_command), capture_output=True
+     )
+     if completed_process.returncode != 0:
+         raise RuntimeError(
+             "failed to query slurm job status\n{}".format(
+                 completed_process.stderr.decode("utf-8")
+             )
+         )
+     sacct_output = completed_process.stdout.decode("utf-8")
+     sacct_output_lines = sacct_output.strip().split("\n")
+     slurm_jobs_status = {}
+     for slurm_job_id in slurm_job_ids:
+         slurm_job_status = _parse_slurm_job_status(slurm_job_id, sacct_output_lines)
+         slurm_jobs_status[slurm_job_id] = slurm_job_status
+     return slurm_jobs_status
+
+
+ def _kill_slurm_job(
+     slurm_job_ids: List[str], username: str, hostname: str, socket: str | None
+ ) -> subprocess.CompletedProcess:
+     """Kill SLURM jobs via scancel over SSH.
+
+     Args:
+         slurm_job_ids: List of SLURM job IDs to kill.
+         username: SSH username.
+         hostname: SSH hostname.
+         socket: control socket location or None
+
+     Returns:
+         The completed scancel process.
+     """
+     if len(slurm_job_ids) == 0:
+         raise ValueError("no SLURM job ids provided to kill")
+     kill_command = "scancel {}".format(",".join(slurm_job_ids))
+     ssh_command = ["ssh"]
+     if socket is not None:
+         ssh_command.append(f"-S {socket}")
+     ssh_command.append(f"{username}@{hostname}")
+     ssh_command.append(kill_command)
+     ssh_command = " ".join(ssh_command)
+     completed_process = subprocess.run(
+         args=shlex.split(ssh_command), capture_output=True
+     )
+     if completed_process.returncode != 0:
+         raise RuntimeError(
+             "failed to kill slurm job\n{}".format(
+                 completed_process.stderr.decode("utf-8")
+             )
+         )
+     return completed_process
+
+
+ def _parse_slurm_job_status(slurm_job_id: str, sacct_output_lines: List[str]) -> str:
+     """Parse SLURM job status from sacct output for a specific job.
+
+     Args:
+         slurm_job_id: The SLURM job ID to parse.
+         sacct_output_lines: Lines from sacct output.
+
+     Returns:
+         SLURM status string.
+     """
+     for line in sacct_output_lines:
+         if line.startswith(f"{slurm_job_id}|"):
+             state = line.split("|")[1]
+             state = state.strip()
+             if state:
+                 state_split = state.split()
+                 if len(state_split) > 0:
+                     return state_split[0]
+     return "UNKNOWN"
+
+
+ def _read_autoresumed_slurm_job_ids(
+     slurm_job_ids: List[str],
+     remote_rundir_paths: List[Path],
+     username: str,
+     hostname: str,
+     socket: str | None,
+ ) -> Dict[str, List[str]]:
+     assert len(slurm_job_ids) == len(remote_rundir_paths)
+     slurm_job_id_list_paths = [
+         str(remote_rundir_path / ".slurm_job_id.list")
+         for remote_rundir_path in remote_rundir_paths
+     ]
+     slurm_job_id_list_strs = _read_files_from_remote(
+         slurm_job_id_list_paths, username, hostname, socket
+     )
+     assert len(slurm_job_id_list_strs) == len(slurm_job_ids)
+     autoresumed_slurm_job_ids = {}
+     for i, slurm_job_id_list_str in enumerate(slurm_job_id_list_strs):
+         slurm_job_id = slurm_job_ids[i]
+         slurm_job_id_list = slurm_job_id_list_str.split()
+         autoresumed_slurm_job_ids[slurm_job_id] = slurm_job_id_list
+     return autoresumed_slurm_job_ids
+
+
+ def _read_files_from_remote(
+     filepaths: List[Path],
+     username: str,
+     hostname: str,
+     socket: str | None,
+ ) -> List[str]:
+     cat_commands = [
+         "echo _START_OF_FILE_ ; cat {} 2>/dev/null ; echo _END_OF_FILE_ ".format(
+             filepath
+         )
+         for filepath in filepaths
+     ]
+     cat_commands = " ; ".join(cat_commands)
+     ssh_command = ["ssh"]
+     if socket is not None:
+         ssh_command.append(f"-S {socket}")
+     ssh_command.append(f"{username}@{hostname}")
+     ssh_command.append(cat_commands)
+     ssh_command = " ".join(ssh_command)
+     completed_process = subprocess.run(
+         args=shlex.split(ssh_command), capture_output=True
+     )
+     if completed_process.returncode != 0:
+         raise RuntimeError(
+             "failed to read files from remote\n{}".format(
+                 completed_process.stderr.decode("utf-8")
+             )
+         )
+     cat_outputs = completed_process.stdout.decode("utf-8")
+     cat_outputs = cat_outputs.replace("\n", " ")
+     matches = re.findall(r"(?<=_START_OF_FILE_)(.*?)(?=_END_OF_FILE_)", cat_outputs)
+     outputs = [match.strip() for match in matches]
+     return outputs
+
+
+ def _get_progress(
+     remote_rundir_paths: List[Path],
+     username: str,
+     hostname: str,
+     socket: str | None,
+ ) -> List[Optional[float]]:
+     remote_progress_paths = [
+         remote_rundir_path / "artifacts" / "progress"
+         for remote_rundir_path in remote_rundir_paths
+     ]
+     remote_run_config_paths = [
+         remote_rundir_path / "artifacts" / "run_config.yml"
+         for remote_rundir_path in remote_rundir_paths
+     ]
+     progress_strs = _read_files_from_remote(
+         remote_progress_paths, username, hostname, socket
+     )
+     if any(map(bool, progress_strs)):
+         run_config_strs = _read_files_from_remote(
+             remote_run_config_paths, username, hostname, socket
+         )
+     else:
+         run_config_strs = [""] * len(progress_strs)
+     progress_list = []
+     for progress_str, run_config_str in zip(progress_strs, run_config_strs):
+         if not progress_str or not run_config_str:
+             progress_list.append(None)
+             continue
+         run_config = yaml.safe_load(run_config_str)
+         dataset_size = get_eval_factory_dataset_size_from_run_config(run_config)
+         if dataset_size is not None:
+             progress = int(progress_str) / dataset_size
+         else:
+             progress = int(progress_str)
+         progress_list.append(progress)
+     return progress_list
+
+
+ _AUTORESUME_HANDLER = """
+ _this_script=$0
+ _prev_slurm_job_id=$1
+ # Handle automatic resumption after some failed state.
+ if [[ "$_prev_slurm_job_id" != "" ]]; then
+ _prev_state=`sacct -j $_prev_slurm_job_id -P -n -o State | head -n 1`
+ _prev_info="previous SLURM_JOB_ID $_prev_slurm_job_id finished with '$_prev_state' state."
+ if [[ $_prev_state == 'TIMEOUT' || $_prev_state == 'PREEMPTED' || $_prev_state == 'NODE_FAIL' ]]; then
+ echo "$_prev_info RESUMING..."
+ else
+ echo "$_prev_info EXIT!"
+ if [[ $_prev_state == 'COMPLETED' ]]; then
+ exit 0
+ else
+ exit 1
+ fi
+ fi
+ fi
+ # Schedule next execution of this script with the current $SLURM_JOB_ID as an argument.
+ # "afternotok" means next execution will be invoked only if the current execution terminates in some failed state.
+ sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
+ """.strip()
+
+
+ _WAIT_FOR_SERVER_HANDLER = """
+ date
+ # wait for the server to initialize
+ bash -c 'while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" {health_url})" != "200" ]]; do kill -0 '"$SERVER_PID"' 2>/dev/null || {{ echo "Server process '"$SERVER_PID"' died"; exit 1; }}; sleep 5; done'
+ date
+ """.strip()