nemo-evaluator-launcher 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of nemo-evaluator-launcher was flagged as potentially problematic.

Files changed (60)
  1. nemo_evaluator_launcher/__init__.py +79 -0
  2. nemo_evaluator_launcher/api/__init__.py +24 -0
  3. nemo_evaluator_launcher/api/functional.py +698 -0
  4. nemo_evaluator_launcher/api/types.py +98 -0
  5. nemo_evaluator_launcher/api/utils.py +19 -0
  6. nemo_evaluator_launcher/cli/__init__.py +15 -0
  7. nemo_evaluator_launcher/cli/export.py +267 -0
  8. nemo_evaluator_launcher/cli/info.py +512 -0
  9. nemo_evaluator_launcher/cli/kill.py +41 -0
  10. nemo_evaluator_launcher/cli/ls_runs.py +134 -0
  11. nemo_evaluator_launcher/cli/ls_tasks.py +136 -0
  12. nemo_evaluator_launcher/cli/main.py +226 -0
  13. nemo_evaluator_launcher/cli/run.py +200 -0
  14. nemo_evaluator_launcher/cli/status.py +164 -0
  15. nemo_evaluator_launcher/cli/version.py +55 -0
  16. nemo_evaluator_launcher/common/__init__.py +16 -0
  17. nemo_evaluator_launcher/common/execdb.py +283 -0
  18. nemo_evaluator_launcher/common/helpers.py +366 -0
  19. nemo_evaluator_launcher/common/logging_utils.py +357 -0
  20. nemo_evaluator_launcher/common/mapping.py +295 -0
  21. nemo_evaluator_launcher/common/printing_utils.py +93 -0
  22. nemo_evaluator_launcher/configs/__init__.py +15 -0
  23. nemo_evaluator_launcher/configs/default.yaml +28 -0
  24. nemo_evaluator_launcher/configs/deployment/generic.yaml +33 -0
  25. nemo_evaluator_launcher/configs/deployment/nim.yaml +32 -0
  26. nemo_evaluator_launcher/configs/deployment/none.yaml +16 -0
  27. nemo_evaluator_launcher/configs/deployment/sglang.yaml +38 -0
  28. nemo_evaluator_launcher/configs/deployment/trtllm.yaml +24 -0
  29. nemo_evaluator_launcher/configs/deployment/vllm.yaml +42 -0
  30. nemo_evaluator_launcher/configs/execution/lepton/default.yaml +92 -0
  31. nemo_evaluator_launcher/configs/execution/local.yaml +19 -0
  32. nemo_evaluator_launcher/configs/execution/slurm/default.yaml +34 -0
  33. nemo_evaluator_launcher/executors/__init__.py +22 -0
  34. nemo_evaluator_launcher/executors/base.py +120 -0
  35. nemo_evaluator_launcher/executors/lepton/__init__.py +16 -0
  36. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +609 -0
  37. nemo_evaluator_launcher/executors/lepton/executor.py +1004 -0
  38. nemo_evaluator_launcher/executors/lepton/job_helpers.py +398 -0
  39. nemo_evaluator_launcher/executors/local/__init__.py +15 -0
  40. nemo_evaluator_launcher/executors/local/executor.py +605 -0
  41. nemo_evaluator_launcher/executors/local/run.template.sh +103 -0
  42. nemo_evaluator_launcher/executors/registry.py +38 -0
  43. nemo_evaluator_launcher/executors/slurm/__init__.py +15 -0
  44. nemo_evaluator_launcher/executors/slurm/executor.py +1147 -0
  45. nemo_evaluator_launcher/exporters/__init__.py +36 -0
  46. nemo_evaluator_launcher/exporters/base.py +121 -0
  47. nemo_evaluator_launcher/exporters/gsheets.py +409 -0
  48. nemo_evaluator_launcher/exporters/local.py +502 -0
  49. nemo_evaluator_launcher/exporters/mlflow.py +619 -0
  50. nemo_evaluator_launcher/exporters/registry.py +40 -0
  51. nemo_evaluator_launcher/exporters/utils.py +624 -0
  52. nemo_evaluator_launcher/exporters/wandb.py +490 -0
  53. nemo_evaluator_launcher/package_info.py +38 -0
  54. nemo_evaluator_launcher/resources/mapping.toml +380 -0
  55. nemo_evaluator_launcher-0.1.28.dist-info/METADATA +494 -0
  56. nemo_evaluator_launcher-0.1.28.dist-info/RECORD +60 -0
  57. nemo_evaluator_launcher-0.1.28.dist-info/WHEEL +5 -0
  58. nemo_evaluator_launcher-0.1.28.dist-info/entry_points.txt +3 -0
  59. nemo_evaluator_launcher-0.1.28.dist-info/licenses/LICENSE +451 -0
  60. nemo_evaluator_launcher-0.1.28.dist-info/top_level.txt +1 -0
nemo_evaluator_launcher/executors/slurm/executor.py
@@ -0,0 +1,1147 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ """SLURM executor implementation for nemo-evaluator-launcher.
17
+
18
+ Handles submitting evaluation jobs to a SLURM cluster via SSH and sbatch scripts.
19
+ """
20
+
21
+ import copy
22
+ import os
23
+ import re
24
+ import shlex
25
+ import subprocess
26
+ import tempfile
27
+ import time
28
+ import warnings
29
+ from pathlib import Path
30
+ from typing import Dict, List, Optional
31
+
32
+ import yaml
33
+ from omegaconf import DictConfig, OmegaConf
34
+
35
+ from nemo_evaluator_launcher.common.execdb import (
36
+ ExecutionDB,
37
+ JobData,
38
+ generate_invocation_id,
39
+ generate_job_id,
40
+ )
41
+ from nemo_evaluator_launcher.common.helpers import (
42
+ CmdAndReadableComment,
43
+ get_api_key_name,
44
+ get_endpoint_url,
45
+ get_eval_factory_command,
46
+ get_eval_factory_config,
47
+ get_eval_factory_dataset_size_from_run_config,
48
+ get_health_url,
49
+ get_timestamp_string,
50
+ )
51
+ from nemo_evaluator_launcher.common.logging_utils import logger
52
+ from nemo_evaluator_launcher.common.mapping import (
53
+ get_task_from_mapping,
54
+ load_tasks_mapping,
55
+ )
56
+ from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey, red
57
+ from nemo_evaluator_launcher.executors.base import (
58
+ BaseExecutor,
59
+ ExecutionState,
60
+ ExecutionStatus,
61
+ )
62
+ from nemo_evaluator_launcher.executors.registry import register_executor
63
+
64
+
65
+ @register_executor("slurm")
66
+ class SlurmExecutor(BaseExecutor):
67
+ @staticmethod
68
+ def execute_eval(cfg: DictConfig, dry_run: bool = False) -> str:
69
+ """Submit evaluation jobs to a SLURM cluster using the provided configuration.
70
+
71
+ Args:
72
+ cfg: The configuration object for the evaluation run.
73
+ dry_run: If True, prepare scripts and save them without submission.
74
+
75
+ Returns:
76
+ str: The invocation ID for the evaluation run.
77
+
78
+ Raises:
79
+ AssertionError: If deployment type is 'none'.
80
+ RuntimeError: If remote directory creation or sbatch submission fails.
81
+ """
82
+
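+ # Overall flow: render one sbatch script per task in a local temp dir, rsync the
+ # run directory to the cluster, submit each run.sub via SSH (multiplexed through a
+ # control socket when available), and record the returned SLURM job ids in the
+ # local ExecutionDB.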
83
+ # Generate invocation ID
84
+ invocation_id = generate_invocation_id()
85
+
86
+ local_runsub_paths = []
87
+ remote_runsub_paths = []
88
+
89
+ with tempfile.TemporaryDirectory() as tmpdirname:
90
+ timestamp = get_timestamp_string(include_microseconds=False)
91
+ rundir_name = timestamp + "-" + invocation_id
92
+ remote_rundir = Path(cfg.execution.output_dir) / rundir_name
93
+ local_rundir = Path(tmpdirname) / rundir_name
94
+ local_rundir.mkdir()
95
+
96
+ # Preload mapping for image resolution
97
+ tasks_mapping = load_tasks_mapping()
98
+ eval_images: list[str] = []
99
+
100
+ is_potentially_unsafe = False
101
+ for idx, task in enumerate(cfg.evaluation.tasks):
102
+ # calculate job_id
103
+ job_id = f"{invocation_id}.{idx}"
104
+
105
+ # prepare locally
106
+ remote_task_subdir = remote_rundir / task.name
107
+ local_task_subdir = local_rundir / task.name
108
+ local_task_subdir.mkdir() # this ensures the task.name hasn't been used
109
+ (local_task_subdir / "logs").mkdir()
110
+ (local_task_subdir / "artifacts").mkdir()
111
+
112
+ # resolve eval image and pass directly via task override
113
+ task_definition = get_task_from_mapping(task.name, tasks_mapping)
114
+ eval_image = task_definition["container"]
115
+ if "container" in task:
116
+ eval_image = task["container"]
117
+
118
+ eval_images.append(eval_image)
119
+
120
+ # generate and write the sbatch script
121
+ sbatch_script_content_struct = _create_slurm_sbatch_script(
122
+ cfg=cfg,
123
+ task=task,
124
+ eval_image=eval_image,
125
+ remote_task_subdir=remote_task_subdir,
126
+ invocation_id=invocation_id,
127
+ job_id=job_id,
128
+ )
129
+ sbatch_script_content_str = sbatch_script_content_struct.cmd
130
+
131
+ # Track whether any task contains unsafe commands
132
+ is_potentially_unsafe = (
133
+ is_potentially_unsafe
134
+ or sbatch_script_content_struct.is_potentially_unsafe
135
+ )
136
+ local_runsub_path = local_task_subdir / "run.sub"
137
+ remote_runsub_path = remote_task_subdir / "run.sub"
138
+ with open(local_runsub_path, "w") as f:
139
+ f.write(sbatch_script_content_str.rstrip("\n") + "\n")
140
+
141
+ local_runsub_paths.append(local_runsub_path)
142
+ remote_runsub_paths.append(remote_runsub_path)
143
+
144
+ if dry_run:
145
+ print(bold("\n\n=============================================\n\n"))
146
+ print(bold(cyan("DRY RUN: SLURM scripts prepared")))
147
+ for idx, local_runsub_path in enumerate(local_runsub_paths):
148
+ print(cyan(f"\n\n=========== Task {idx} =====================\n\n"))
149
+ with open(local_runsub_path, "r") as f:
150
+ print(grey(f.read()))
151
+ print(bold("To submit jobs") + ", run the executor without --dry-run")
152
+ if is_potentially_unsafe:
153
+ print(
154
+ red(
155
+ "\nFound `pre_cmd` which carries security risk. When running without --dry-run "
156
+ "make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1"
157
+ )
158
+ )
159
+
160
+ return invocation_id
161
+
162
+ if is_potentially_unsafe:
163
+ if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
164
+ logger.warning(
165
+ "Found non-empty task commands (e.g. `pre_cmd`) and NEMO_EVALUATOR_TRUST_PRE_CMD "
166
+ "is set, proceeding with caution."
167
+ )
168
+
169
+ else:
170
+ logger.error(
171
+ "Found non-empty task commands (e.g. `pre_cmd`) and NEMO_EVALUATOR_TRUST_PRE_CMD "
172
+ "is not set. This might carry security risk and unstable environments. "
173
+ "To continue, make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1.",
174
+ )
175
+ raise AttributeError(
176
+ "Untrusted command found in config, make sure you trust and "
177
+ "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
178
+ )
179
+
180
+ socket = str(Path(tmpdirname) / "socket")
181
+ socket_or_none = _open_master_connection(
182
+ username=cfg.execution.username,
183
+ hostname=cfg.execution.hostname,
184
+ socket=socket,
185
+ )
186
+ _make_remote_execution_output_dir(
187
+ dirpath=cfg.execution.output_dir,
188
+ username=cfg.execution.username,
189
+ hostname=cfg.execution.hostname,
190
+ socket=socket_or_none,
191
+ )
192
+ _rsync_upload_rundirs(
193
+ local_sources=[local_rundir],
194
+ remote_target=cfg.execution.output_dir,
195
+ username=cfg.execution.username,
196
+ hostname=cfg.execution.hostname,
197
+ )
198
+ slurm_job_ids = _sbatch_remote_runsubs(
199
+ remote_runsub_paths=remote_runsub_paths,
200
+ username=cfg.execution.username,
201
+ hostname=cfg.execution.hostname,
202
+ socket=socket_or_none,
203
+ )
204
+ _close_master_connection(
205
+ username=cfg.execution.username,
206
+ hostname=cfg.execution.hostname,
207
+ socket=socket_or_none,
208
+ )
209
+
210
+ # save launched jobs metadata
211
+ db = ExecutionDB()
212
+ for idx, (slurm_job_id, remote_runsub_path) in enumerate(
213
+ zip(slurm_job_ids, remote_runsub_paths)
214
+ ):
215
+ job_id = generate_job_id(invocation_id, idx)
216
+ db.write_job(
217
+ job=JobData(
218
+ invocation_id=invocation_id,
219
+ job_id=job_id,
220
+ timestamp=time.time(),
221
+ executor="slurm",
222
+ data={
223
+ "slurm_job_id": slurm_job_id,
224
+ "remote_rundir_path": str(remote_runsub_path.parent),
225
+ "hostname": cfg.execution.hostname,
226
+ "username": cfg.execution.username,
227
+ "eval_image": eval_images[idx],
228
+ },
229
+ config=OmegaConf.to_object(cfg),
230
+ )
231
+ )
232
+ return invocation_id
233
+
234
+ @staticmethod
235
+ def get_status(id: str) -> List[ExecutionStatus]:
236
+ """Get the status of a specific SLURM job or all jobs in an invocation group.
237
+
238
+ Args:
239
+ id: Unique job identifier or invocation identifier.
240
+
241
+ Returns:
242
+ List containing the execution status for the job(s).
243
+ """
244
+ db = ExecutionDB()
245
+
246
+ # If id looks like an invocation_id
247
+ if "." not in id:
248
+ jobs = db.get_jobs(id)
249
+ if not jobs:
250
+ return []
251
+ return SlurmExecutor._get_status_for_invocation(jobs)
252
+
253
+ # Otherwise, treat as job_id
254
+ else:
255
+ job_data = db.get_job(id)
256
+ if job_data is None or job_data.executor != "slurm":
257
+ return []
258
+ return [SlurmExecutor._get_status_for_job(id, job_data)]
259
+
260
+ @staticmethod
261
+ def _get_status_for_job(id: str, job_data: JobData) -> ExecutionStatus:
262
+ slurm_job_id = job_data.data.get("slurm_job_id")
263
+ if not slurm_job_id:
264
+ return ExecutionStatus(id=id, state=ExecutionState.FAILED)
265
+
266
+ try:
267
+ return SlurmExecutor._query_slurm_for_status_and_progress(
268
+ slurm_job_ids=[slurm_job_id],
269
+ remote_rundir_paths=[Path(job_data.data.get("remote_rundir_path"))],
270
+ username=job_data.data["username"],
271
+ hostname=job_data.data["hostname"],
272
+ job_id_to_execdb_id={slurm_job_id: id},
273
+ )[0]
274
+ except Exception:
275
+ return ExecutionStatus(id=id, state=ExecutionState.FAILED)
276
+
277
+ @staticmethod
278
+ def _get_status_for_invocation(jobs: dict) -> List[ExecutionStatus]:
279
+ slurm_job_ids = []
280
+ remote_rundir_paths = []
281
+ job_id_to_execdb_id = {}
282
+ username = None
283
+ hostname = None
284
+
285
+ for job_id, job_data in jobs.items():
286
+ if job_data.executor != "slurm":
287
+ continue
288
+ slurm_job_id = job_data.data.get("slurm_job_id")
289
+ if slurm_job_id:
290
+ slurm_job_ids.append(slurm_job_id)
291
+ remote_rundir_paths.append(
292
+ Path(job_data.data.get("remote_rundir_path"))
293
+ )
294
+ job_id_to_execdb_id[slurm_job_id] = job_id
295
+ username = job_data.data.get("username")
296
+ hostname = job_data.data.get("hostname")
297
+
298
+ if not slurm_job_ids or not remote_rundir_paths or not username or not hostname:
299
+ return [
300
+ ExecutionStatus(id=job_id, state=ExecutionState.FAILED)
301
+ for job_id in jobs.keys()
302
+ ]
303
+
304
+ try:
305
+ return SlurmExecutor._query_slurm_for_status_and_progress(
306
+ slurm_job_ids=slurm_job_ids,
307
+ remote_rundir_paths=remote_rundir_paths,
308
+ username=username,
309
+ hostname=hostname,
310
+ job_id_to_execdb_id=job_id_to_execdb_id,
311
+ )
312
+ except Exception:
313
+ return [
314
+ ExecutionStatus(id=job_id, state=ExecutionState.FAILED)
315
+ for job_id in jobs.keys()
316
+ ]
317
+
318
+ @staticmethod
319
+ def _query_slurm_for_status_and_progress(
320
+ slurm_job_ids: List[str],
321
+ remote_rundir_paths: List[Path],
322
+ username: str,
323
+ hostname: str,
324
+ job_id_to_execdb_id: dict,
325
+ ) -> List[ExecutionStatus]:
326
+ with tempfile.TemporaryDirectory() as tmpdirname:
327
+ socket = str(Path(tmpdirname) / "socket")
328
+ socket_or_none = _open_master_connection(
329
+ username=username,
330
+ hostname=hostname,
331
+ socket=socket,
332
+ )
333
+ # get slurm job status for initial jobs:
334
+ slurm_jobs_status = _query_slurm_jobs_status(
335
+ slurm_job_ids=slurm_job_ids,
336
+ username=username,
337
+ hostname=hostname,
338
+ socket=socket_or_none,
339
+ )
340
+ # handle slurm status for autoresumed jobs:
341
+ autoresumed_slurm_job_ids = _read_autoresumed_slurm_job_ids(
342
+ slurm_job_ids=slurm_job_ids,
343
+ remote_rundir_paths=remote_rundir_paths,
344
+ username=username,
345
+ hostname=hostname,
346
+ socket=socket_or_none,
347
+ )
348
+ latest_slurm_job_ids = {
349
+ slurm_job_id: slurm_job_id_list[-1]
350
+ for slurm_job_id, slurm_job_id_list in autoresumed_slurm_job_ids.items()
351
+ if len(slurm_job_id_list) > 0 and slurm_job_id_list[-1] != slurm_job_id
352
+ }
353
+ latest_slurm_jobs_status = _query_slurm_jobs_status(
354
+ slurm_job_ids=list(latest_slurm_job_ids.values()),
355
+ username=username,
356
+ hostname=hostname,
357
+ socket=socket_or_none,
358
+ )
359
+ # get progress:
360
+ progress_list = _get_progress(
361
+ remote_rundir_paths=remote_rundir_paths,
362
+ username=username,
363
+ hostname=hostname,
364
+ socket=socket_or_none,
365
+ )
366
+ _close_master_connection(
367
+ username=username,
368
+ hostname=hostname,
369
+ socket=socket_or_none,
370
+ )
371
+ statuses = []
372
+ for i, slurm_job_id in enumerate(slurm_job_ids):
373
+ slurm_status = slurm_jobs_status[slurm_job_id]
374
+ if slurm_job_id in latest_slurm_job_ids:
375
+ latest_slurm_job_id = latest_slurm_job_ids[slurm_job_id]
376
+ slurm_status = latest_slurm_jobs_status[latest_slurm_job_id]
377
+ progress = progress_list[i]
378
+ progress = progress if progress is not None else "unknown"
379
+ execution_state = SlurmExecutor._map_slurm_state_to_execution_state(
380
+ slurm_status
381
+ )
382
+ execdb_job_id = job_id_to_execdb_id.get(slurm_job_id)
383
+ if execdb_job_id:
384
+ statuses.append(
385
+ ExecutionStatus(
386
+ id=execdb_job_id,
387
+ state=execution_state,
388
+ progress=progress,
389
+ )
390
+ )
391
+ return statuses
392
+
393
+ @staticmethod
394
+ def _map_slurm_state_to_execution_state(slurm_status: str) -> ExecutionState:
395
+ """Map SLURM state to ExecutionState.
396
+
397
+ Args:
398
+ slurm_status: SLURM status string.
399
+
400
+ Returns:
401
+ Corresponding ExecutionState.
402
+ """
403
+ if slurm_status in ["COMPLETED"]:
404
+ return ExecutionState.SUCCESS
405
+ elif slurm_status in [
406
+ "PENDING",
407
+ "RESV_DEL_HOLD",
408
+ "REQUEUE_FED",
409
+ "REQUEUE_HOLD",
410
+ "REQUEUED",
411
+ "REVOKED",
412
+ ]:
413
+ return ExecutionState.PENDING
414
+ elif slurm_status in ["RUNNING", "CONFIGURING", "SUSPENDED", "COMPLETING"]:
415
+ return ExecutionState.RUNNING
416
+ elif slurm_status in ["PREEMPTED", "TIMEOUT", "NODE_FAIL"]:
417
+ return ExecutionState.PENDING # autoresume
418
+ elif slurm_status in ["CANCELLED"]:
419
+ return ExecutionState.KILLED
420
+ elif slurm_status in ["FAILED"]:
421
+ return ExecutionState.FAILED
422
+ else:
423
+ return ExecutionState.FAILED
424
+
425
+ @staticmethod
426
+ def kill_job(job_id: str) -> None:
427
+ """Kill a SLURM job.
428
+
429
+ Args:
430
+ job_id: The job ID (e.g., abc123.0) to kill.
431
+ """
432
+ db = ExecutionDB()
433
+ job_data = db.get_job(job_id)
434
+
435
+ if job_data is None:
436
+ raise ValueError(f"Job {job_id} not found")
437
+
438
+ if job_data.executor != "slurm":
439
+ raise ValueError(
440
+ f"Job {job_id} is not a slurm job (executor: {job_data.executor})"
441
+ )
442
+
443
+ # OPTIMIZATION: Query status AND kill in ONE SSH call
444
+ slurm_status, result = _kill_slurm_job(
445
+ slurm_job_ids=[job_data.data.get("slurm_job_id")],
446
+ username=job_data.data.get("username"),
447
+ hostname=job_data.data.get("hostname"),
448
+ socket=job_data.data.get("socket"),
449
+ )
450
+
451
+ # Mark job as killed in database if kill succeeded
452
+ if result.returncode == 0:
453
+ job_data.data["killed"] = True
454
+ db.write_job(job_data)
455
+ else:
456
+ # Use the pre-fetched status for better error message
457
+ current_status = None
458
+ if slurm_status:
459
+ current_status = SlurmExecutor._map_slurm_state_to_execution_state(
460
+ slurm_status
461
+ )
462
+ error_msg = SlurmExecutor.get_kill_failure_message(
463
+ job_id,
464
+ f"slurm_job_id: {job_data.data.get('slurm_job_id')}",
465
+ current_status,
466
+ )
467
+ raise RuntimeError(error_msg)
468
+
469
+
470
+ def _create_slurm_sbatch_script(
471
+ cfg: DictConfig,
472
+ task: DictConfig,
473
+ eval_image: str,
474
+ remote_task_subdir: Path,
475
+ invocation_id: str,
476
+ job_id: str,
477
+ ) -> CmdAndReadableComment:
478
+ """Generate the contents of a SLURM sbatch script for a given evaluation task.
479
+
480
+ Args:
481
+ cfg: The configuration object for the evaluation run.
482
+ task: The evaluation task configuration.
+ eval_image: The container image used for the evaluation client.
483
+ remote_task_subdir: The remote directory path for the `run.sub` file.
484
+ invocation_id: The invocation ID for this evaluation run.
485
+ job_id: The complete job ID string.
486
+
487
+ Returns:
488
+ CmdAndReadableComment: The sbatch script contents, a commented debug copy, and a flag marking potentially unsafe commands.
489
+ """
490
+ # get task from mapping, overrides, urls
491
+ tasks_mapping = load_tasks_mapping()
492
+ task_definition = get_task_from_mapping(task.name, tasks_mapping)
493
+
494
+ # Create merged config for get_endpoint_url
495
+ merged_nemo_evaluator_config = get_eval_factory_config(cfg, task)
496
+ health_url = get_health_url(
497
+ cfg,
498
+ get_endpoint_url(
499
+ cfg, merged_nemo_evaluator_config, task_definition["endpoint_type"]
500
+ ),
501
+ )
502
+
503
+ # TODO(public release): convert to template
504
+ s = "#!/bin/bash\n"
505
+
506
+ # SBATCH headers
507
+ s += "#SBATCH --time {}\n".format(cfg.execution.walltime)
508
+ s += "#SBATCH --account {}\n".format(cfg.execution.account)
509
+ s += "#SBATCH --partition {}\n".format(cfg.execution.partition)
510
+ s += "#SBATCH --nodes {}\n".format(cfg.execution.num_nodes)
511
+ s += "#SBATCH --ntasks-per-node {}\n".format(cfg.execution.ntasks_per_node)
512
+ if cfg.execution.get("gpus_per_node", None) is not None:
513
+ s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
514
+ if hasattr(cfg.execution, "gres"):
515
+ s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
516
+ job_name = "{account}-{subproject}.{details}".format(
517
+ account=cfg.execution.account,
518
+ subproject=cfg.execution.subproject,
519
+ details=remote_task_subdir.name,
520
+ )
521
+ s += "#SBATCH --job-name {}\n".format(job_name)
522
+ s += "#SBATCH --exclusive\n"
523
+ s += "#SBATCH --output {}\n".format(remote_task_subdir / "logs" / "slurm-%A.out")
524
+ s += "\n"
525
+ s += f'TASK_DIR="{str(remote_task_subdir)}"\n'
526
+ s += "\n"
527
+
528
+ # collect all env vars
529
+ env_vars = copy.deepcopy(dict(cfg.evaluation.get("env_vars", {})))
530
+ env_vars.update(task.get("env_vars", {}))
531
+ api_key_name = get_api_key_name(cfg)
532
+ if api_key_name:
533
+ assert "API_KEY" not in env_vars
534
+ env_vars["API_KEY"] = api_key_name
535
+
536
+ # check if the environment variables are set
537
+ for env_var in env_vars.values():
538
+ if os.getenv(env_var) is None:
539
+ raise ValueError(f"Trying to pass an unset environment variable {env_var}.")
540
+
541
+ # check if required env vars are defined:
542
+ for required_env_var in task_definition.get("required_env_vars", []):
543
+ if required_env_var not in env_vars.keys():
544
+ raise ValueError(
545
+ f"{task.name} task requires environment variable {required_env_var}."
546
+ " Specify it in the task subconfig in the 'env_vars' dict as the following"
547
+ f" pair {required_env_var}: YOUR_ENV_VAR_NAME"
548
+ )
549
+
550
+ # save env vars:
551
+ for env_var_dst, env_var_src in env_vars.items():
552
+ s += f"export {env_var_dst}={os.getenv(env_var_src)}\n"
553
+ all_env_vars = {
554
+ **cfg.execution.get("env_vars", {}).get("deployment", {}),
555
+ **cfg.execution.get("env_vars", {}).get("evaluation", {}),
556
+ }
557
+ if cfg.deployment.get("env_vars"):
558
+ warnings.warn(
559
+ "cfg.deployment.env_vars will be deprecated in future versions. "
560
+ "Use cfg.execution.env_vars.deployment instead.",
561
+ category=DeprecationWarning,
562
+ stacklevel=2,
563
+ )
564
+ all_env_vars.update(cfg.deployment["env_vars"])
565
+ for env_var_dst, env_var_value in all_env_vars.items():
566
+ s += f"export {env_var_dst}={env_var_value}\n"
567
+ if env_vars:
568
+ s += "\n"
569
+
570
+ # auto resume after timeout
571
+ s += _AUTORESUME_HANDLER
572
+ s += "\n\n"
573
+
574
+ # echo the current SLURM_JOB_ID
575
+ s += "# save the current job id\n"
576
+ s += "echo $SLURM_JOB_ID >> {}\n\n".format(
577
+ remote_task_subdir / ".slurm_job_id.list"
578
+ )
579
+
580
+ # shell options
581
+ s += "set -e # exit immediately if any command exits with a non-zero status\n"
582
+ s += "set -u # treat unset variables as an error when substituting\n"
583
+ s += "set -x # print commands and their arguments as they are executed\n"
584
+ s += "\n"
585
+
586
+ # prepare deployment mounts
587
+ deployment_mounts_list = []
588
+ if cfg.deployment.type != "none":
589
+ if checkpoint_path := cfg.deployment.get("checkpoint_path"):
590
+ deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
591
+ if cache_path := cfg.deployment.get("cache_path"):
592
+ deployment_mounts_list.append(f"{cache_path}:/cache")
593
+ for source_mnt, target_mnt in (
594
+ cfg.execution.get("mounts", {}).get("deployment", {}).items()
595
+ ):
596
+ deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")
597
+
598
+ # add deployment srun command
599
+ s += "# deployment server\n"
600
+ s += "srun --mpi pmix --overlap "
601
+ s += "--container-image {} ".format(cfg.deployment.image)
602
+ if deployment_mounts_list:
603
+ s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
604
+ if not cfg.execution.get("mounts", {}).get("mount_home", True):
605
+ s += "--no-container-mount-home "
606
+ s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
607
+ deployment_env_var_names = list(
608
+ cfg.execution.get("env_vars", {}).get("deployment", {})
609
+ )
610
+ if cfg.deployment.get("env_vars"):
611
+ warnings.warn(
612
+ "cfg.deployment.env_vars will be deprecated in future versions. "
613
+ "Use cfg.execution.env_vars.deployment instead.",
614
+ category=DeprecationWarning,
615
+ stacklevel=2,
616
+ )
617
+ deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
618
+ if deployment_env_var_names:
619
+ s += f"--container-env {','.join(deployment_env_var_names)} "
620
+ s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
621
+ s += (
622
+ "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
623
+ )
624
+
625
+ # wait for the server to initialize
626
+ s += _WAIT_FOR_SERVER_HANDLER.format(health_url=health_url)
627
+ s += "\n\n"
628
+
629
+ # prepare evaluation mounts
630
+ evaluation_mounts_list = [
631
+ "{}:/results".format(remote_task_subdir / "artifacts"),
632
+ ]
633
+ for source_mnt, target_mnt in (
634
+ cfg.execution.get("mounts", {}).get("evaluation", {}).items()
635
+ ):
636
+ evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")
637
+
638
+ eval_factory_command_struct = get_eval_factory_command(
639
+ cfg,
640
+ task,
641
+ task_definition,
642
+ )
643
+ eval_factory_command = eval_factory_command_struct.cmd
644
+ # The debug comment for placing into the script and easy debug. Reason
645
+ # (see `CmdAndReadableComment`) is the current way of passing the command
646
+ # is base64-encoded config `echo`-ed into file.
647
+ # TODO(agronskiy): cleaner way is to encode everything with base64, not
648
+ # some parts (like ef_config.yaml) and just output as logs somewhere.
649
+ eval_factory_command_debug_comment = eval_factory_command_struct.debug
650
+
651
+ # add evaluation srun command
652
+ s += "# Debug contents of the eval factory command's config\n"
653
+ s += eval_factory_command_debug_comment
654
+ s += "\n\n"
655
+
656
+ s += "# evaluation client\n"
657
+ s += "srun --mpi pmix --overlap "
658
+ s += "--container-image {} ".format(eval_image)
659
+ evaluation_env_var_names = list(
660
+ cfg.execution.get("env_vars", {}).get("evaluation", {})
661
+ )
662
+ if evaluation_env_var_names:
663
+ s += "--container-env {} ".format(",".join(evaluation_env_var_names))
664
+ if not cfg.execution.get("mounts", {}).get("mount_home", True):
665
+ s += "--no-container-mount-home "
666
+
667
+ s += "--container-mounts {} ".format(",".join(evaluation_mounts_list))
668
+ s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.out")
669
+ s += "bash -c '\n"
670
+ s += eval_factory_command
671
+ s += "'\n\n"
672
+
673
+ # terminate the server after all evaluation clients finish
674
+ if cfg.deployment.type != "none":
675
+ s += "kill $SERVER_PID # terminate the server to finish gracefully\n\n"
676
+
677
+ # auto-export
678
+ ae_cfg = cfg.execution.get("auto_export")
679
+ destinations: list = []
680
+ if isinstance(ae_cfg, list):
681
+ destinations = list(ae_cfg)
682
+ elif isinstance(ae_cfg, dict) or isinstance(ae_cfg, DictConfig):
683
+ destinations = list(ae_cfg.get("destinations", []) or [])
684
+
685
+ if destinations:
686
+ export_env = dict(cfg.execution.get("env_vars", {}).get("export", {}) or {})
687
+ s += _generate_auto_export_section(cfg, job_id, destinations, export_env)
688
+
689
+ debug_str = "\n".join(["# " + line for line in s.splitlines()])
690
+ return CmdAndReadableComment(
691
+ cmd=s,
692
+ debug=debug_str,
693
+ is_potentially_unsafe=eval_factory_command_struct.is_potentially_unsafe,
694
+ )
695
+
696
+
697
+ def _generate_auto_export_section(
698
+ cfg: DictConfig,
699
+ job_id: str,
700
+ destinations: list,
701
+ export_env: dict,
702
+ ) -> str:
703
+ """Generate simple auto-export section for sbatch script."""
704
+ if not destinations:
705
+ return ""
706
+
707
+ s = "\n# Auto-export on success\n"
708
+ s += "EVAL_EXIT_CODE=$?\n"
709
+ s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
710
+ s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
711
+ s += " set +e\n"
712
+ s += " set +x\n"
713
+ s += " set +u\n"
714
+ s += ' cd "$TASK_DIR/artifacts"\n'
715
+
716
+ # Work with DictConfig; convert only for YAML at the end
717
+ exec_type = (
718
+ cfg.execution.type
719
+ if hasattr(cfg.execution, "type")
720
+ else cfg.execution.get("type", "slurm")
721
+ )
722
+ eval_tasks = (
723
+ list(cfg.evaluation.tasks)
724
+ if hasattr(cfg, "evaluation") and hasattr(cfg.evaluation, "tasks")
725
+ else list((cfg.get("evaluation", {}) or {}).get("tasks", []) or [])
726
+ )
727
+ export_block = cfg.get("export", {}) or {}
728
+
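+ # Illustrative shape of the resulting export_config.yml (actual values come from cfg;
+ # destination names are examples):
+ #   execution:
+ #     auto_export:
+ #       destinations: [wandb, mlflow]
+ #     type: slurm
+ #   evaluation:
+ #     tasks: [...]
+ #   export: {...}   # only present when an export block is configured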
729
+ payload = {
730
+ "execution": {
731
+ "auto_export": {
732
+ "destinations": list(destinations),
733
+ **({"env_vars": dict(export_env)} if export_env else {}),
734
+ },
735
+ "type": exec_type,
736
+ },
737
+ "evaluation": {"tasks": eval_tasks},
738
+ }
739
+ if export_block:
740
+ # Convert just this block to plain for YAML
741
+ payload["export"] = (
742
+ OmegaConf.to_object(export_block)
743
+ if OmegaConf.is_config(export_block)
744
+ else dict(export_block)
745
+ )
746
+
747
+ # Final YAML (single conversion at the end)
748
+ payload_clean = OmegaConf.to_container(OmegaConf.create(payload), resolve=True)
749
+ yaml_str = yaml.safe_dump(payload_clean, sort_keys=False)
750
+ s += " cat > export_config.yml << 'EOF'\n"
751
+ s += yaml_str
752
+ s += "EOF\n"
753
+
754
+ # write launcher config as config.yml for exporters (no core command)
755
+ submitted_yaml = yaml.safe_dump(
756
+ OmegaConf.to_container(cfg, resolve=True), sort_keys=False
757
+ )
758
+ s += " cat > config.yml << 'EOF'\n"
759
+ s += submitted_yaml
760
+ s += "EOF\n"
761
+
762
+ # Export host-provided env vars (or literal values) before running auto-export
763
+ for k, v in (export_env or {}).items():
764
+ if isinstance(v, str) and re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", v):
765
+ s += f' export {k}="${{{v}}}"\n'
766
+ else:
767
+ esc = str(v).replace('"', '\\"')
768
+ s += f' export {k}="{esc}"\n'
769
+
770
+ for dest in destinations:
771
+ s += f" echo 'Exporting to {dest}...'\n"
772
+ s += f" nemo-evaluator-launcher export {job_id} --dest {dest} || echo 'Export to {dest} failed'\n"
773
+
774
+ s += " echo 'Auto-export completed.'\n"
775
+ s += "else\n"
776
+ s += " echo 'Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export.'\n"
777
+ s += "fi\n"
778
+
779
+ return s
780
+
781
+
782
+ def _open_master_connection(
783
+ username: str,
784
+ hostname: str,
785
+ socket: str,
786
+ ) -> str | None:
787
+ ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
788
+ completed_process = subprocess.run(
789
+ args=shlex.split(ssh_command), capture_output=True
790
+ )
791
+ if completed_process.returncode == 0:
792
+ return socket
793
+ return None
794
+
795
+
796
+ def _close_master_connection(
797
+ username: str,
798
+ hostname: str,
799
+ socket: str | None,
800
+ ) -> None:
801
+ if socket is None:
802
+ return
803
+ ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
804
+ completed_process = subprocess.run(
805
+ args=shlex.split(ssh_command), capture_output=True
806
+ )
807
+ if completed_process.returncode != 0:
808
+ raise RuntimeError(
809
+ "failed to close the master connection\n{}".format(
810
+ completed_process.stderr.decode("utf-8")
811
+ )
812
+ )
813
+
814
+
815
+ def _make_remote_execution_output_dir(
816
+ dirpath: str,
817
+ username: str,
818
+ hostname: str,
819
+ socket: str | None,
820
+ ) -> None:
821
+ mkdir_command = f"mkdir -p {dirpath}"
822
+ ssh_command = ["ssh"]
823
+ if socket is not None:
824
+ ssh_command.append(f"-S {socket}")
825
+ ssh_command.append(f"{username}@{hostname}")
826
+ ssh_command.append(mkdir_command)
827
+ ssh_command = " ".join(ssh_command)
828
+ completed_process = subprocess.run(
829
+ args=shlex.split(ssh_command), capture_output=True
830
+ )
831
+ if completed_process.returncode != 0:
832
+ error_msg = (
833
+ completed_process.stderr.decode("utf-8")
834
+ if completed_process.stderr
835
+ else "Unknown error"
836
+ )
837
+ raise RuntimeError(
838
+ "failed to make a remote execution output dir\n{}".format(error_msg)
839
+ )
840
+
841
+
842
+ def _rsync_upload_rundirs(
843
+ local_sources: List[Path],
844
+ remote_target: str,
845
+ username: str,
846
+ hostname: str,
847
+ ) -> None:
848
+ """Upload local run directories to a remote host using rsync over SSH.
849
+
850
+ Args:
851
+ local_sources: List of local Path objects to upload.
852
+ remote_target: Remote directory path as a string.
853
+ username: SSH username.
854
+ hostname: SSH hostname.
855
+
856
+ Raises:
857
+ RuntimeError: If rsync fails.
858
+ """
859
+ for local_source in local_sources:
860
+ assert local_source.is_dir()
861
+ remote_destination_str = f"{username}@{hostname}:{remote_target}"
862
+ local_sources_str = " ".join(map(str, local_sources))
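+ # rsync flags: -q quiet, -c compare by checksum instead of mtime/size,
+ # -a archive mode (recursive, preserves permissions and times), -z compress in transit.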
863
+ rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
864
+ completed_process = subprocess.run(
865
+ args=shlex.split(rsync_upload_command), capture_output=True
866
+ )
867
+ if completed_process.returncode != 0:
868
+ error_msg = (
869
+ completed_process.stderr.decode("utf-8")
870
+ if completed_process.stderr
871
+ else "Unknown error"
872
+ )
873
+ raise RuntimeError("failed to upload local sources\n{}".format(error_msg))
874
+
875
+
876
+ def _sbatch_remote_runsubs(
877
+ remote_runsub_paths: List[Path],
878
+ username: str,
879
+ hostname: str,
880
+ socket: str | None,
881
+ ) -> List[str]:
882
+ sbatch_commands = [
883
+ "sbatch {}".format(remote_runsub_path)
884
+ for remote_runsub_path in remote_runsub_paths
885
+ ]
886
+ sbatch_commands = " ; ".join(sbatch_commands)
887
+
888
+ ssh_command = ["ssh"]
889
+ if socket is not None:
890
+ ssh_command.append(f"-S {socket}")
891
+ ssh_command.append(f"{username}@{hostname}")
892
+ ssh_command.append(sbatch_commands)
893
+ ssh_command = " ".join(ssh_command)
894
+
895
+ completed_process = subprocess.run(
896
+ args=shlex.split(ssh_command), capture_output=True
897
+ )
898
+ if completed_process.returncode != 0:
899
+ error_msg = completed_process.stderr.decode("utf-8")
900
+ raise RuntimeError(
901
+ "failed to submit sbatch scripts for execution\n{}".format(error_msg)
902
+ )
903
+
904
+ sbatch_output = completed_process.stdout.decode("utf-8")
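+ # sbatch prints "Submitted batch job <id>" for each script; collect the ids in order.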
905
+ slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
906
+ return slurm_job_ids
907
+
908
+
909
+ def _query_slurm_jobs_status(
910
+ slurm_job_ids: List[str],
911
+ username: str,
912
+ hostname: str,
913
+ socket: str | None,
914
+ ) -> Dict[str, str]:
915
+ """Query SLURM for job statuses using sacct command.
916
+
917
+ Args:
918
+ slurm_job_ids: List of SLURM job IDs to query.
919
+ username: SSH username.
920
+ hostname: SSH hostname.
921
+ socket: Control socket location, or None.
922
+
923
+ Returns:
924
+ Dict mapping from slurm_job_id to returned slurm status.
925
+ """
926
+ if len(slurm_job_ids) == 0:
927
+ return {}
928
+ sacct_command = "sacct -j {} --format='JobID,State%32' --noheader -P".format(
929
+ ",".join(slurm_job_ids)
930
+ )
931
+ ssh_command = ["ssh"]
932
+ if socket is not None:
933
+ ssh_command.append(f"-S {socket}")
934
+ ssh_command.append(f"{username}@{hostname}")
935
+ ssh_command.append(sacct_command)
936
+ ssh_command = " ".join(ssh_command)
937
+ completed_process = subprocess.run(
938
+ args=shlex.split(ssh_command), capture_output=True
939
+ )
940
+ if completed_process.returncode != 0:
941
+ raise RuntimeError(
942
+ "failed to query slurm job status\n{}".format(
943
+ completed_process.stderr.decode("utf-8")
944
+ )
945
+ )
946
+ sacct_output = completed_process.stdout.decode("utf-8")
947
+ sacct_output_lines = sacct_output.strip().split("\n")
948
+ slurm_jobs_status = {}
949
+ for slurm_job_id in slurm_job_ids:
950
+ slurm_job_status = _parse_slurm_job_status(slurm_job_id, sacct_output_lines)
951
+ slurm_jobs_status[slurm_job_id] = slurm_job_status
952
+ return slurm_jobs_status
953
+
954
+
955
+ def _kill_slurm_job(
956
+ slurm_job_ids: List[str], username: str, hostname: str, socket: str | None
957
+ ) -> tuple[str | None, subprocess.CompletedProcess]:
958
+ """Kill a SLURM job, querying status first in one SSH call for efficiency.
959
+
960
+ Args:
961
+ slurm_job_ids: List of SLURM job IDs to kill.
962
+ username: SSH username.
963
+ hostname: SSH hostname.
964
+ socket: Control socket location, or None.
965
+
966
+ Returns:
967
+ Tuple of (status_string, completed_process), where status_string is the SLURM status or None.
968
+ """
969
+ if len(slurm_job_ids) == 0:
970
+ return None, subprocess.CompletedProcess(args=[], returncode=0)
971
+
972
+ jobs_str = ",".join(slurm_job_ids)
973
+ # Combine both commands in one SSH call: query THEN kill
974
+ combined_command = (
975
+ f"sacct -j {jobs_str} --format='JobID,State%32' --noheader -P 2>/dev/null; "
976
+ f"scancel {jobs_str}"
977
+ )
978
+
979
+ ssh_command = ["ssh"]
980
+ if socket is not None:
981
+ ssh_command.append(f"-S {socket}")
982
+ ssh_command.append(f"{username}@{hostname}")
983
+ ssh_command.append(combined_command)
984
+ ssh_command = " ".join(ssh_command)
985
+
986
+ completed_process = subprocess.run(
987
+ args=shlex.split(ssh_command), capture_output=True
988
+ )
989
+
990
+ # Parse the sacct output (before scancel runs)
991
+ sacct_output = completed_process.stdout.decode("utf-8")
992
+ sacct_output_lines = sacct_output.strip().split("\n")
993
+ slurm_status = None
994
+ if sacct_output_lines and len(slurm_job_ids) == 1:
995
+ slurm_status = _parse_slurm_job_status(slurm_job_ids[0], sacct_output_lines)
996
+
997
+ return slurm_status, completed_process
998
+
999
+
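+ # Illustrative sacct -P output for job 12345 (pipe-delimited, one row per job/step):
+ #   12345|CANCELLED by 1001
+ #   12345.batch|CANCELLED
+ # The parser below matches only the "12345|" row and returns the first token of its
+ # State column, e.g. "CANCELLED".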
1000
+ def _parse_slurm_job_status(slurm_job_id: str, sacct_output_lines: List[str]) -> str:
1001
+ """Parse SLURM job status from sacct output for a specific job.
1002
+
1003
+ Args:
1004
+ slurm_job_id: The SLURM job ID to parse.
1005
+ sacct_output_lines: Lines from sacct output.
1006
+
1007
+ Returns:
1008
+ SLURM status string.
1009
+ """
1010
+ for line in sacct_output_lines:
1011
+ if line.startswith(f"{slurm_job_id}|"):
1012
+ state = line.split("|")[1]
1013
+ state = state.strip()
1014
+ if state:
1015
+ state_split = state.split()
1016
+ if len(state_split) > 0:
1017
+ return state_split[0]
1018
+ return "UNKNOWN"
1019
+
1020
+
1021
+ def _read_autoresumed_slurm_job_ids(
1022
+ slurm_job_ids: List[str],
1023
+ remote_rundir_paths: List[Path],
1024
+ username: str,
1025
+ hostname: str,
1026
+ socket: str | None,
1027
+ ) -> Dict[str, List[str]]:
1028
+ assert len(slurm_job_ids) == len(remote_rundir_paths)
1029
+ slurm_job_id_list_paths = [
1030
+ str(remote_rundir_path / ".slurm_job_id.list")
1031
+ for remote_rundir_path in remote_rundir_paths
1032
+ ]
1033
+ slurm_job_id_list_strs = _read_files_from_remote(
1034
+ slurm_job_id_list_paths, username, hostname, socket
1035
+ )
1036
+ assert len(slurm_job_id_list_strs) == len(slurm_job_ids)
1037
+ autoresumed_slurm_job_ids = {}
1038
+ for i, slurm_job_id_list_str in enumerate(slurm_job_id_list_strs):
1039
+ slurm_job_id = slurm_job_ids[i]
1040
+ slurm_job_id_list = slurm_job_id_list_str.split()
1041
+ autoresumed_slurm_job_ids[slurm_job_id] = slurm_job_id_list
1042
+ return autoresumed_slurm_job_ids
1043
+
1044
+
1045
+ def _read_files_from_remote(
1046
+ filepaths: List[Path],
1047
+ username: str,
1048
+ hostname: str,
1049
+ socket: str | None,
1050
+ ) -> List[str]:
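+ # Wrap every file in _START_OF_FILE_/_END_OF_FILE_ sentinels so all files can be
+ # read in a single SSH round-trip; a missing file (cat errors are suppressed)
+ # yields an empty string at its position.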
1051
+ cat_commands = [
1052
+ "echo _START_OF_FILE_ ; cat {} 2>/dev/null ; echo _END_OF_FILE_ ".format(
1053
+ filepath
1054
+ )
1055
+ for filepath in filepaths
1056
+ ]
1057
+ cat_commands = " ; ".join(cat_commands)
1058
+ ssh_command = ["ssh"]
1059
+ if socket is not None:
1060
+ ssh_command.append(f"-S {socket}")
1061
+ ssh_command.append(f"{username}@{hostname}")
1062
+ ssh_command.append(cat_commands)
1063
+ ssh_command = " ".join(ssh_command)
1064
+ completed_process = subprocess.run(
1065
+ args=shlex.split(ssh_command), capture_output=True
1066
+ )
1067
+ if completed_process.returncode != 0:
1068
+ raise RuntimeError(
1069
+ "failed to read files from remote\n{}".format(
1070
+ completed_process.stderr.decode("utf-8")
1071
+ )
1072
+ )
1073
+ cat_outputs = completed_process.stdout.decode("utf-8")
1074
+ cat_outputs = cat_outputs.replace("\n", " ")
1075
+ matches = re.findall(r"(?<=_START_OF_FILE_)(.*?)(?=_END_OF_FILE_)", cat_outputs)
1076
+ outputs = [match.strip() for match in matches]
1077
+ return outputs
1078
+
1079
+
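+ # Progress is read from each run's artifacts/progress counter and, when a dataset
+ # size can be derived from artifacts/run_config.yml, normalized to a fraction;
+ # otherwise the raw sample count is returned.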
1080
+ def _get_progress(
1081
+ remote_rundir_paths: List[Path],
1082
+ username: str,
1083
+ hostname: str,
1084
+ socket: str | None,
1085
+ ) -> List[Optional[float]]:
1086
+ remote_progress_paths = [
1087
+ remote_rundir_path / "artifacts" / "progress"
1088
+ for remote_rundir_path in remote_rundir_paths
1089
+ ]
1090
+ remote_run_config_paths = [
1091
+ remote_rundir_path / "artifacts" / "run_config.yml"
1092
+ for remote_rundir_path in remote_rundir_paths
1093
+ ]
1094
+ progress_strs = _read_files_from_remote(
1095
+ remote_progress_paths, username, hostname, socket
1096
+ )
1097
+ if any(map(bool, progress_strs)):
1098
+ run_config_strs = _read_files_from_remote(
1099
+ remote_run_config_paths, username, hostname, socket
1100
+ )
1101
+ else:
1102
+ run_config_strs = [""] * len(progress_strs)
1103
+ progress_list = []
1104
+ for progress_str, run_config_str in zip(progress_strs, run_config_strs):
1105
+ if not progress_str or not run_config_str:
1106
+ progress_list.append(None)
1107
+ continue
1108
+ run_config = yaml.safe_load(run_config_str)
1109
+ dataset_size = get_eval_factory_dataset_size_from_run_config(run_config)
1110
+ if dataset_size is not None:
1111
+ progress = int(progress_str) / dataset_size
1112
+ else:
1113
+ progress = int(progress_str)
1114
+ progress_list.append(progress)
1115
+ return progress_list
1116
+
1117
+
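+ # Autoresume: this snippet is embedded near the top of every sbatch script. Each run
+ # passes its own $SLURM_JOB_ID to a follow-up copy of itself submitted with
+ # --dependency=afternotok:$SLURM_JOB_ID, so the follow-up is released only if the
+ # current run ends in a failed state; the follow-up then resumes only after
+ # TIMEOUT, PREEMPTED, or NODE_FAIL and exits immediately otherwise.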
1118
+ _AUTORESUME_HANDLER = """
1119
+ _this_script=$0
1120
+ _prev_slurm_job_id=$1
1121
+ # Handle automatic resumption after some failed state.
1122
+ if [[ "$_prev_slurm_job_id" != "" ]]; then
1123
+ _prev_state=`sacct -j $_prev_slurm_job_id -P -n -o State | head -n 1`
1124
+ _prev_info="previous SLURM_JOB_ID $_prev_slurm_job_id finished with '$_prev_state' state."
1125
+ if [[ $_prev_state == 'TIMEOUT' || $_prev_state == 'PREEMPTED' || $_prev_state == 'NODE_FAIL' ]]; then
1126
+ echo "$_prev_info RESUMING..."
1127
+ else
1128
+ echo "$_prev_info EXIT!"
1129
+ if [[ $_prev_state == 'COMPLETED' ]]; then
1130
+ exit 0
1131
+ else
1132
+ exit 1
1133
+ fi
1134
+ fi
1135
+ fi
1136
+ # Schedule next execution of this script with the current $SLURM_JOB_ID as an argument.
1137
+ # "afternotok" means next execution will be invoked only if the current execution terminates in some failed state.
1138
+ sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
1139
+ """.strip()
1140
+
1141
+
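+ # The doubled braces in {{http_code}} survive str.format() as a literal {http_code}
+ # for curl's -w template, and the '"$SERVER_PID"' quoting splices the sbatch
+ # script's SERVER_PID variable into the single-quoted bash -c loop.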
1142
+ _WAIT_FOR_SERVER_HANDLER = """
1143
+ date
1144
+ # wait for the server to initialize
1145
+ bash -c 'while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" {health_url})" != "200" ]]; do kill -0 '"$SERVER_PID"' 2>/dev/null || {{ echo "Server process '"$SERVER_PID"' died"; exit 1; }}; sleep 5; done'
1146
+ date
1147
+ """.strip()