nemo-evaluator-launcher 0.1.19__py3-none-any.whl → 0.1.56__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registry.
Files changed (38)
  1. nemo_evaluator_launcher/api/functional.py +159 -5
  2. nemo_evaluator_launcher/cli/logs.py +102 -0
  3. nemo_evaluator_launcher/cli/ls_task.py +280 -0
  4. nemo_evaluator_launcher/cli/ls_tasks.py +208 -55
  5. nemo_evaluator_launcher/cli/main.py +29 -2
  6. nemo_evaluator_launcher/cli/run.py +114 -16
  7. nemo_evaluator_launcher/cli/version.py +26 -23
  8. nemo_evaluator_launcher/common/container_metadata/__init__.py +61 -0
  9. nemo_evaluator_launcher/common/container_metadata/intermediate_repr.py +530 -0
  10. nemo_evaluator_launcher/common/container_metadata/loading.py +1126 -0
  11. nemo_evaluator_launcher/common/container_metadata/registries.py +824 -0
  12. nemo_evaluator_launcher/common/container_metadata/utils.py +63 -0
  13. nemo_evaluator_launcher/common/helpers.py +200 -51
  14. nemo_evaluator_launcher/common/logging_utils.py +16 -5
  15. nemo_evaluator_launcher/common/mapping.py +341 -155
  16. nemo_evaluator_launcher/common/printing_utils.py +25 -12
  17. nemo_evaluator_launcher/configs/deployment/sglang.yaml +4 -2
  18. nemo_evaluator_launcher/configs/deployment/trtllm.yaml +2 -3
  19. nemo_evaluator_launcher/configs/deployment/vllm.yaml +0 -1
  20. nemo_evaluator_launcher/configs/execution/slurm/default.yaml +14 -0
  21. nemo_evaluator_launcher/executors/base.py +31 -1
  22. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +36 -1
  23. nemo_evaluator_launcher/executors/lepton/executor.py +107 -9
  24. nemo_evaluator_launcher/executors/local/executor.py +383 -24
  25. nemo_evaluator_launcher/executors/local/run.template.sh +54 -2
  26. nemo_evaluator_launcher/executors/slurm/executor.py +559 -64
  27. nemo_evaluator_launcher/executors/slurm/proxy.cfg.template +26 -0
  28. nemo_evaluator_launcher/exporters/utils.py +32 -46
  29. nemo_evaluator_launcher/package_info.py +1 -1
  30. nemo_evaluator_launcher/resources/all_tasks_irs.yaml +17016 -0
  31. nemo_evaluator_launcher/resources/mapping.toml +64 -315
  32. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/METADATA +4 -3
  33. nemo_evaluator_launcher-0.1.56.dist-info/RECORD +69 -0
  34. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/entry_points.txt +1 -0
  35. nemo_evaluator_launcher-0.1.19.dist-info/RECORD +0 -60
  36. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/WHEEL +0 -0
  37. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/licenses/LICENSE +0 -0
  38. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/top_level.txt +0 -0
nemo_evaluator_launcher/executors/slurm/executor.py +559 -64

@@ -30,6 +30,7 @@ from pathlib import Path
  from typing import Dict, List, Optional

  import yaml
+ from jinja2 import Environment, FileSystemLoader
  from omegaconf import DictConfig, OmegaConf

  from nemo_evaluator_launcher.common.execdb import (
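
The new jinja2 import feeds the proxy.cfg.template shipped alongside the executor (item 27 in the list above). As a hedged sketch only, rendering that template by hand would look roughly like this; the directory path and all values are placeholders, and the variable names follow the HAProxy helpers added later in this diff:

    from pathlib import Path

    from jinja2 import Environment, FileSystemLoader

    # Placeholder path; inside the package the template sits next to executor.py,
    # i.e. nemo_evaluator_launcher/executors/slurm/.
    template_dir = Path("nemo_evaluator_launcher/executors/slurm")
    env = Environment(loader=FileSystemLoader(template_dir))
    template = env.get_template("proxy.cfg.template")

    # Variable names mirror the helpers added later in this diff; values are invented.
    rendered = template.render(
        haproxy_port=5009,
        health_check_path="/health",
        health_check_status=200,
        nodes=[{"ip": "{IP_0}", "port": 8000}, {"ip": "{IP_1}", "port": 8000}],
    )
    print(rendered)
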
@@ -39,18 +40,19 @@ from nemo_evaluator_launcher.common.execdb import (
  generate_job_id,
  )
  from nemo_evaluator_launcher.common.helpers import (
+ CmdAndReadableComment,
+ _str_to_echo_command,
  get_api_key_name,
- get_endpoint_url,
  get_eval_factory_command,
  get_eval_factory_dataset_size_from_run_config,
- get_health_url,
  get_timestamp_string,
  )
+ from nemo_evaluator_launcher.common.logging_utils import logger
  from nemo_evaluator_launcher.common.mapping import (
- get_task_from_mapping,
+ get_task_definition_for_job,
  load_tasks_mapping,
  )
- from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey
+ from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey, red
  from nemo_evaluator_launcher.executors.base import (
  BaseExecutor,
  ExecutionState,
@@ -94,6 +96,7 @@ class SlurmExecutor(BaseExecutor):
  tasks_mapping = load_tasks_mapping()
  eval_images: list[str] = []

+ is_potentially_unsafe = False
  for idx, task in enumerate(cfg.evaluation.tasks):
  # calculate job_id
  job_id = f"{invocation_id}.{idx}"
@@ -106,7 +109,11 @@
  (local_task_subdir / "artifacts").mkdir()

  # resolve eval image and pass directly via task override
- task_definition = get_task_from_mapping(task.name, tasks_mapping)
+ task_definition = get_task_definition_for_job(
+ task_query=task.name,
+ base_mapping=tasks_mapping,
+ container=task.get("container"),
+ )
  eval_image = task_definition["container"]
  if "container" in task:
  eval_image = task["container"]
@@ -114,7 +121,7 @@
  eval_images.append(eval_image)

  # generate and write down sbatch script
- sbatch_script_content_str = _create_slurm_sbatch_script(
+ sbatch_script_content_struct = _create_slurm_sbatch_script(
  cfg=cfg,
  task=task,
  eval_image=eval_image,
@@ -122,6 +129,32 @@
  invocation_id=invocation_id,
  job_id=job_id,
  )
+
+ # Create proxy config file with placeholder IPs for multi-instance deployments
+ if cfg.deployment.get("multiple_instances", False):
+ proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+ if proxy_type == "haproxy":
+ proxy_config = _generate_haproxy_config_with_placeholders(cfg)
+ else:
+ raise ValueError(
+ f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+ )
+
+ # Save both template and working config
+ proxy_template_path = local_task_subdir / "proxy.cfg.template"
+ proxy_config_path = local_task_subdir / "proxy.cfg"
+ with open(proxy_template_path, "w") as f:
+ f.write(proxy_config)
+ with open(proxy_config_path, "w") as f:
+ f.write(proxy_config)
+
+ sbatch_script_content_str = sbatch_script_content_struct.cmd
+
+ # We accumulate if any task contains unsafe commands
+ is_potentially_unsafe = (
+ is_potentially_unsafe
+ or sbatch_script_content_struct.is_potentially_unsafe
+ )
  local_runsub_path = local_task_subdir / "run.sub"
  remote_runsub_path = remote_task_subdir / "run.sub"
  with open(local_runsub_path, "w") as f:
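
For orientation, a hedged sketch of the configuration keys this multi-instance branch reads; the key names come from this and the later proxy hunks in the diff, while the concrete values are invented:

    from omegaconf import OmegaConf

    # Illustrative values only; key names mirror what the executor reads.
    cfg = OmegaConf.create(
        {
            "deployment": {"multiple_instances": True, "port": 8000},
            "execution": {
                "num_nodes": 2,
                "proxy": {
                    "type": "haproxy",           # only supported type today
                    "image": "haproxy:latest",   # default used further below
                    "config": {"haproxy_port": 5009, "health_check_path": "/health"},
                },
            },
        }
    )
    assert cfg.execution.proxy.type == "haproxy"
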
@@ -138,14 +171,56 @@
  with open(local_runsub_path, "r") as f:
  print(grey(f.read()))
  print(bold("To submit jobs") + ", run the executor without --dry-run")
+ if is_potentially_unsafe:
+ print(
+ red(
+ "\nFound `pre_cmd` (evaluation or deployment) which carries security risk. When running without --dry-run "
+ "make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1"
+ )
+ )
+
  return invocation_id

+ if is_potentially_unsafe:
+ if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
+ logger.warning(
+ "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+ "is set, proceeding with caution."
+ )
+
+ else:
+ logger.error(
+ "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+ "is not set. This might carry security risk and unstable environments. "
+ "To continue, make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1.",
+ )
+ raise AttributeError(
+ "Untrusted command found in config, make sure you trust and "
+ "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
+ )
+
  socket = str(Path(tmpdirname) / "socket")
  socket_or_none = _open_master_connection(
  username=cfg.execution.username,
  hostname=cfg.execution.hostname,
  socket=socket,
  )
+
+ if socket_or_none is None:
+ raise RuntimeError(
+ f"Failed to connect to the cluster {cfg.execution.hostname} as user {cfg.execution.username}. "
+ "Please check your SSH configuration."
+ )
+
+ # Validate that all mount paths exist on the remote host
+ mount_paths = _collect_mount_paths(cfg)
+ _validate_remote_paths_exist(
+ paths=mount_paths,
+ username=cfg.execution.username,
+ hostname=cfg.execution.hostname,
+ socket=socket_or_none,
+ )
+
  _make_remote_execution_output_dir(
  dirpath=cfg.execution.output_dir,
  username=cfg.execution.username,
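
In short, submission now refuses to proceed when any task or deployment carries a `pre_cmd` unless the user opts in. A condensed sketch of that gate (the helper name here is hypothetical; the real check is inlined in the executor and also logs a warning when the variable is set):

    import os

    def _require_pre_cmd_trust(has_pre_cmd: bool) -> None:
        # Any non-empty pre_cmd blocks submission unless the user explicitly
        # opts in by exporting NEMO_EVALUATOR_TRUST_PRE_CMD=1.
        if has_pre_cmd and os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") != "1":
            raise AttributeError(
                "Untrusted command found in config, make sure you trust it and "
                "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
            )
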
@@ -437,7 +512,7 @@ def _create_slurm_sbatch_script(
  remote_task_subdir: Path,
  invocation_id: str,
  job_id: str,
- ) -> str:
+ ) -> CmdAndReadableComment:
  """Generate the contents of a SLURM sbatch script for a given evaluation task.

  Args:
@@ -452,8 +527,11 @@ def _create_slurm_sbatch_script(
  """
  # get task from mapping, overrides, urls
  tasks_mapping = load_tasks_mapping()
- task_definition = get_task_from_mapping(task.name, tasks_mapping)
- health_url = get_health_url(cfg, get_endpoint_url(cfg, task, task_definition))
+ task_definition = get_task_definition_for_job(
+ task_query=task.name,
+ base_mapping=tasks_mapping,
+ container=task.get("container"),
+ )

  # TODO(public release): convert to template
  s = "#!/bin/bash\n"
@@ -468,6 +546,8 @@ def _create_slurm_sbatch_script(
  s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
  if hasattr(cfg.execution, "gres"):
  s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
+ if cfg.execution.get("sbatch_comment"):
+ s += "#SBATCH --comment='{}'\n".format(cfg.execution.sbatch_comment)
  job_name = "{account}-{subproject}.{details}".format(
  account=cfg.execution.account,
  subproject=cfg.execution.subproject,
@@ -475,7 +555,8 @@ def _create_slurm_sbatch_script(
  )
  s += "#SBATCH --job-name {}\n".format(job_name)
  s += "#SBATCH --exclusive\n"
- s += "#SBATCH --output {}\n".format(remote_task_subdir / "logs" / "slurm-%A.out")
+ s += "#SBATCH --no-requeue\n" # We have our own auto-resume logic
+ s += "#SBATCH --output {}\n".format(remote_task_subdir / "logs" / "slurm-%A.log")
  s += "\n"
  s += f'TASK_DIR="{str(remote_task_subdir)}"\n'
  s += "\n"
@@ -493,8 +574,11 @@ def _create_slurm_sbatch_script(
  if os.getenv(env_var) is None:
  raise ValueError(f"Trying to pass an unset environment variable {env_var}.")

- # check if required env vars are defined:
+ # check if required env vars are defined (excluding NEMO_EVALUATOR_DATASET_DIR which is handled separately):
  for required_env_var in task_definition.get("required_env_vars", []):
+ # Skip NEMO_EVALUATOR_DATASET_DIR as it's handled by dataset mounting logic below
+ if required_env_var == "NEMO_EVALUATOR_DATASET_DIR":
+ continue
  if required_env_var not in env_vars.keys():
  raise ValueError(
  f"{task.name} task requires environment variable {required_env_var}."
@@ -540,6 +624,7 @@ def _create_slurm_sbatch_script(

  # prepare deployment mounts
  deployment_mounts_list = []
+ deployment_is_unsafe = False
  if cfg.deployment.type != "none":
  if checkpoint_path := cfg.deployment.get("checkpoint_path"):
  deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
@@ -551,36 +636,33 @@ def _create_slurm_sbatch_script(
  deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")

  # add deployment srun command
- s += "# deployment server\n"
- s += "srun --mpi pmix --overlap "
- s += "--container-image {} ".format(cfg.deployment.image)
- if deployment_mounts_list:
- s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
- if not cfg.execution.get("mounts", {}).get("mount_home", True):
- s += "--no-container-mount-home "
- s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
- deployment_env_var_names = list(
- cfg.execution.get("env_vars", {}).get("deployment", {})
- )
- if cfg.deployment.get("env_vars"):
- warnings.warn(
- "cfg.deployment.env_vars will be deprecated in future versions. "
- "Use cfg.execution.env_vars.deployment instead.",
- category=DeprecationWarning,
- stacklevel=2,
+ deployment_srun_cmd, deployment_is_unsafe, deployment_debug = (
+ _generate_deployment_srun_command(
+ cfg, deployment_mounts_list, remote_task_subdir
  )
- deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
- if deployment_env_var_names:
- s += f"--container-env {','.join(deployment_env_var_names)} "
- s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
- s += (
- "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
  )
+ s += deployment_srun_cmd

  # wait for the server to initialize
- s += _WAIT_FOR_SERVER_HANDLER.format(health_url=health_url)
+ health_path = cfg.deployment.get("health_check_path", "/health")
+ # For multi-instance check all node IPs, for single instance check localhost
+ if cfg.deployment.get("multiple_instances", False):
+ ip_list = '"${NODES_IPS_ARRAY[@]}"'
+ else:
+ ip_list = '"127.0.0.1"'
+ s += _get_wait_for_server_handler(
+ ip_list,
+ cfg.deployment.port,
+ health_path,
+ "server",
+ check_pid=True,
+ )
  s += "\n\n"

+ # add proxy load balancer for multi-instance deployments
+ if cfg.deployment.get("multiple_instances", False):
+ s += _get_proxy_server_srun_command(cfg, remote_task_subdir)
+
  # prepare evaluation mounts
  evaluation_mounts_list = [
  "{}:/results".format(remote_task_subdir / "artifacts"),
@@ -590,7 +672,29 @@ def _create_slurm_sbatch_script(
  ):
  evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")

- eval_factory_command_struct = get_eval_factory_command(cfg, task, task_definition)
+ # Handle dataset directory mounting if NEMO_EVALUATOR_DATASET_DIR is required
+ if "NEMO_EVALUATOR_DATASET_DIR" in task_definition.get("required_env_vars", []):
+ # Get dataset directory from task config
+ if "dataset_dir" in task:
+ dataset_mount_host = task["dataset_dir"]
+ else:
+ raise ValueError(
+ f"{task.name} task requires a dataset_dir to be specified. "
+ f"Add 'dataset_dir: /path/to/your/dataset' under the task configuration."
+ )
+ # Get container mount path (default to /datasets if not specified)
+ dataset_mount_container = task.get("dataset_mount_path", "/datasets")
+ # Add dataset mount to evaluation mounts list
+ evaluation_mounts_list.append(f"{dataset_mount_host}:{dataset_mount_container}")
+ # Export NEMO_EVALUATOR_DATASET_DIR environment variable
+ s += f"export NEMO_EVALUATOR_DATASET_DIR={dataset_mount_container}\n\n"
+
+ eval_factory_command_struct = get_eval_factory_command(
+ cfg,
+ task,
+ task_definition,
+ )
+
  eval_factory_command = eval_factory_command_struct.cmd
  # The debug comment for placing into the script and easy debug. Reason
  # (see `CmdAndReadableComment`) is the current way of passing the command
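
A hypothetical task entry for a benchmark whose definition lists NEMO_EVALUATOR_DATASET_DIR under required_env_vars; only dataset_dir is mandatory, and dataset_mount_path falls back to /datasets as in the hunk above (names and paths are placeholders):

    # Placeholder task configuration; the launcher turns this into a
    # host:container mount and exports NEMO_EVALUATOR_DATASET_DIR=/datasets.
    task = {
        "name": "some_offline_task",
        "dataset_dir": "/lustre/data/some_offline_task",
        "dataset_mount_path": "/datasets",  # optional, this is the default
    }
    evaluation_mount = f"{task['dataset_dir']}:{task['dataset_mount_path']}"
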
@@ -606,6 +710,7 @@ def _create_slurm_sbatch_script(

  s += "# evaluation client\n"
  s += "srun --mpi pmix --overlap "
+ s += "--nodes 1 --ntasks 1 " # Client always runs on single node
  s += "--container-image {} ".format(eval_image)
  evaluation_env_var_names = list(
  cfg.execution.get("env_vars", {}).get("evaluation", {})
@@ -616,14 +721,17 @@ def _create_slurm_sbatch_script(
  s += "--no-container-mount-home "

  s += "--container-mounts {} ".format(",".join(evaluation_mounts_list))
- s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.out")
+ s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.log")
  s += "bash -c '\n"
  s += eval_factory_command
  s += "'\n\n"

  # terminate the server after all evaluation clients finish
  if cfg.deployment.type != "none":
- s += "kill $SERVER_PID # terminate the server to finish gracefully\n\n"
+ s += "kill $SERVER_PID # terminate the server to finish gracefully\n"
+ if cfg.deployment.get("multiple_instances", False):
+ s += "kill $PROXY_PID # terminate proxy to finish gracefully\n"
+ s += "\n"

  # auto-export
  ae_cfg = cfg.execution.get("auto_export")
@@ -635,9 +743,22 @@ def _create_slurm_sbatch_script(

  if destinations:
  export_env = dict(cfg.execution.get("env_vars", {}).get("export", {}) or {})
- s += _generate_auto_export_section(cfg, job_id, destinations, export_env)
+ s += _generate_auto_export_section(
+ cfg, job_id, destinations, export_env, remote_task_subdir
+ )

- return s
+ debug_str = "\n".join(["# " + line for line in s.splitlines()])
+
+ # Combine unsafe flags from both deployment and evaluation
+ is_potentially_unsafe = (
+ eval_factory_command_struct.is_potentially_unsafe or deployment_is_unsafe
+ )
+
+ return CmdAndReadableComment(
+ cmd=s,
+ debug=debug_str,
+ is_potentially_unsafe=is_potentially_unsafe,
+ )


  def _generate_auto_export_section(
@@ -645,6 +766,8 @@ def _generate_auto_export_section(
  job_id: str,
  destinations: list,
  export_env: dict,
+ remote_task_subdir: Path,
+ export_image: str = "python:3.12.7-slim",
  ) -> str:
  """Generate simple auto-export section for sbatch script."""
  if not destinations:
@@ -654,10 +777,7 @@ def _generate_auto_export_section(
  s += "EVAL_EXIT_CODE=$?\n"
  s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
  s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
- s += " set +e\n"
- s += " set +x\n"
- s += " set +u\n"
- s += ' cd "$TASK_DIR/artifacts"\n'
+ s += f' cd "{remote_task_subdir}/artifacts"\n'

  # Work with DictConfig; convert only for YAML at the end
  exec_type = (
@@ -713,10 +833,25 @@ def _generate_auto_export_section(
  esc = str(v).replace('"', '\\"')
  s += f' export {k}="{esc}"\n'

- for dest in destinations:
- s += f" echo 'Exporting to {dest}...'\n"
- s += f" nemo-evaluator-launcher export {job_id} --dest {dest} || echo 'Export to {dest} failed'\n"
+ s += " # export\n"
+ s += " srun --mpi pmix --overlap "
+ s += "--nodes 1 --ntasks 1 " # Client always runs on single node
+ s += "--container-image {} ".format(export_image)
+ if export_env:
+ s += "--container-env {} ".format(",".join(export_env))
+ if not cfg.execution.get("mounts", {}).get("mount_home", True):
+ s += "--no-container-mount-home "

+ s += f"--container-mounts {remote_task_subdir}/artifacts:{remote_task_subdir}/artifacts,{remote_task_subdir}/logs:{remote_task_subdir}/logs "
+ s += "--output {} ".format(remote_task_subdir / "logs" / "export-%A.log")
+ s += " bash -c '\n"
+ # FIXME(martas): would be good to install specific version
+ s += " pip install nemo-evaluator-launcher[all]\n"
+ s += f" cd {remote_task_subdir}/artifacts\n"
+ for dest in destinations:
+ s += f' echo "Exporting to {dest}..."\n'
+ s += f' nemo-evaluator-launcher export {job_id} --dest {dest} || echo "Export to {dest} failed"\n'
+ s += "'\n"
  s += " echo 'Auto-export completed.'\n"
  s += "else\n"
  s += " echo 'Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export.'\n"
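
The export step now runs in its own container rather than in the login shell. Outside the generated script, the same CLI call can be reproduced manually; job id and destination below are placeholders:

    import subprocess

    # Roughly what the generated step runs inside the export container after
    # `pip install nemo-evaluator-launcher[all]`.
    job_id = "<invocation_id>.0"
    dest = "<destination>"
    subprocess.run(
        ["nemo-evaluator-launcher", "export", job_id, "--dest", dest],
        check=False,
    )
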
@@ -731,11 +866,12 @@ def _open_master_connection(
  socket: str,
  ) -> str | None:
  ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
- completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
- )
+ logger.info("Opening master connection", cmd=ssh_command)
+ completed_process = subprocess.run(args=shlex.split(ssh_command))
  if completed_process.returncode == 0:
+ logger.info("Opened master connection successfully", cmd=ssh_command)
  return socket
+ logger.error("Failed to open master connection", code=completed_process.returncode)
  return None


@@ -747,9 +883,7 @@ def _close_master_connection(
  if socket is None:
  return
  ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
- completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
- )
+ completed_process = subprocess.run(args=shlex.split(ssh_command))
  if completed_process.returncode != 0:
  raise RuntimeError(
  "failed to close the master connection\n{}".format(
@@ -771,8 +905,9 @@ def _make_remote_execution_output_dir(
  ssh_command.append(f"{username}@{hostname}")
  ssh_command.append(mkdir_command)
  ssh_command = " ".join(ssh_command)
+ logger.info("Creating remote dir", cmd=ssh_command)
  completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
+ args=shlex.split(ssh_command), stderr=subprocess.PIPE
  )
  if completed_process.returncode != 0:
  error_msg = (
@@ -780,6 +915,11 @@ def _make_remote_execution_output_dir(
  if completed_process.stderr
  else "Unknown error"
  )
+ logger.error(
+ "Error creating remote dir",
+ code=completed_process.returncode,
+ msg=error_msg,
+ )
  raise RuntimeError(
  "failed to make a remote execution output dir\n{}".format(error_msg)
  )
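
The remote helpers above and below multiplex every ssh and rsync call over a single control-master connection. Stripped of logging and error handling, the pattern reduces to the sketch below; user, host, and socket path are placeholders:

    import shlex
    import subprocess

    socket_path = "/tmp/nel-ctl-socket"
    target = "user@slurm-login.example.com"

    # Open a persistent master connection bound to a local control socket ...
    subprocess.run(shlex.split(f"ssh -MNf -S {socket_path} {target}"))
    # ... reuse it for subsequent commands without re-authenticating ...
    subprocess.run(shlex.split(f"ssh -S {socket_path} {target} hostname"))
    # ... and close it once the executor is done.
    subprocess.run(shlex.split(f"ssh -O exit -S {socket_path} {target}"))
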
@@ -807,8 +947,10 @@ def _rsync_upload_rundirs(
  remote_destination_str = f"{username}@{hostname}:{remote_target}"
  local_sources_str = " ".join(map(str, local_sources))
  rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
+ logger.info("Rsyncing to remote dir", cmd=rsync_upload_command)
  completed_process = subprocess.run(
- args=shlex.split(rsync_upload_command), capture_output=True
+ args=shlex.split(rsync_upload_command),
+ stderr=subprocess.PIPE,
  )
  if completed_process.returncode != 0:
  error_msg = (
@@ -816,6 +958,12 @@ def _rsync_upload_rundirs(
  if completed_process.stderr
  else "Unknown error"
  )
+
+ logger.error(
+ "Error rsyncing to remote dir",
+ code=completed_process.returncode,
+ msg=error_msg,
+ )
  raise RuntimeError("failed to upload local sources\n{}".format(error_msg))


@@ -837,9 +985,12 @@ def _sbatch_remote_runsubs(
  ssh_command.append(f"{username}@{hostname}")
  ssh_command.append(sbatch_commands)
  ssh_command = " ".join(ssh_command)
-
+ logger.info("Running sbatch", cmd=ssh_command)
  completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
+ args=shlex.split(ssh_command),
+ # NOTE(agronskiy): look out for hangs and deadlocks
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
  )
  if completed_process.returncode != 0:
  error_msg = completed_process.stderr.decode("utf-8")
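
The job ids returned to the launcher are scraped from sbatch's stdout (next hunk) with a simple lookbehind; a self-contained illustration with invented output:

    import re

    sbatch_output = "Submitted batch job 4242\nSubmitted batch job 4243\n"
    slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
    assert slurm_job_ids == ["4242", "4243"]
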
@@ -849,6 +1000,7 @@

  sbatch_output = completed_process.stdout.decode("utf-8")
  slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
+ logger.info("Started sbatch successfully", slurm_job_ids=slurm_job_ids)
  return slurm_job_ids


@@ -881,7 +1033,10 @@ def _query_slurm_jobs_status(
  ssh_command.append(sacct_command)
  ssh_command = " ".join(ssh_command)
  completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
+ args=shlex.split(ssh_command),
+ # NOTE(agronskiy): look out for hangs and deadlocks
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
  )
  if completed_process.returncode != 0:
  raise RuntimeError(
@@ -930,7 +1085,10 @@ def _kill_slurm_job(
  ssh_command = " ".join(ssh_command)

  completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
+ args=shlex.split(ssh_command),
+ # NOTE(agronskiy): look out for hangs and deadlocks
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
  )

  # Parse the sacct output (before scancel runs)
@@ -1008,7 +1166,10 @@ def _read_files_from_remote(
  ssh_command.append(cat_commands)
  ssh_command = " ".join(ssh_command)
  completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
+ args=shlex.split(ssh_command),
+ # NOTE(agronskiy): look out for hangs and deadlocks
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
  )
  if completed_process.returncode != 0:
  raise RuntimeError(
@@ -1085,9 +1246,343 @@ sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
  """.strip()


- _WAIT_FOR_SERVER_HANDLER = """
- date
- # wait for the server to initialize
- bash -c 'while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" {health_url})" != "200" ]]; do kill -0 '"$SERVER_PID"' 2>/dev/null || {{ echo "Server process '"$SERVER_PID"' died"; exit 1; }}; sleep 5; done'
+ def _generate_haproxy_config_with_placeholders(cfg):
+ """Generate HAProxy configuration with placeholder IPs using Jinja template."""
+ # Set up Jinja environment
+ template_dir = Path(__file__).parent
+ template_path = template_dir / "proxy.cfg.template"
+
+ if not template_path.exists():
+ raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+ env = Environment(loader=FileSystemLoader(template_dir))
+ template = env.get_template("proxy.cfg.template")
+
+ # Prepare template data with placeholder IPs - use actual number of nodes
+ num_nodes = cfg.execution.num_nodes
+ nodes = []
+ for i in range(num_nodes):
+ nodes.append({"ip": f"{{IP_{i}}}", "port": cfg.deployment.port})
+
+ # Get health check parameters from execution config
+ proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+ health_check_path = proxy_config.get("health_check_path", "/health")
+ health_check_status = proxy_config.get("health_check_status", 200)
+ haproxy_port = proxy_config.get("haproxy_port", 5009)
+
+ # Render template
+ config = template.render(
+ haproxy_port=haproxy_port,
+ health_check_path=health_check_path,
+ health_check_status=health_check_status,
+ nodes=nodes,
+ )
+
+ return config
+
+
+ def _generate_haproxy_config(cfg, nodes_ips):
+ """Generate HAProxy configuration using Jinja template."""
+ # Set up Jinja environment
+ template_dir = Path(__file__).parent
+ template_path = template_dir / "proxy.cfg.template"
+
+ if not template_path.exists():
+ raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+ env = Environment(loader=FileSystemLoader(template_dir))
+ template = env.get_template("proxy.cfg.template")
+
+ # Prepare template data
+ nodes = []
+ for i, ip in enumerate(nodes_ips, 1):
+ nodes.append(
+ {"ip": ip, "port": cfg.deployment.port} # All nodes use the same port
+ )
+
+ # Get health check parameters from deployment config
+ health_check_path = cfg.deployment.get("health_check_path", "/health")
+ health_check_status = cfg.deployment.get("health_check_status", 200)
+ haproxy_port = cfg.deployment.get("haproxy_port", 5009)
+
+ # Render template
+ config = template.render(
+ haproxy_port=haproxy_port,
+ health_check_path=health_check_path,
+ health_check_status=health_check_status,
+ nodes=nodes,
+ )
+
+ return config
+
+
+ def _generate_deployment_srun_command(
+ cfg, deployment_mounts_list, remote_task_subdir, instance_id: int = 0
+ ):
+ """Generate the deployment srun command with proper node/ntask configuration.
+
+ Returns:
+ tuple: (script_string, is_potentially_unsafe, debug_comment)
+ """
+ s = ""
+ debug_comment = ""
+ is_potentially_unsafe = False
+
+ s += "# deployment server\n"
+
+ # Extract pre_cmd for later use inside container
+ pre_cmd: str = cfg.deployment.get("pre_cmd") or ""
+ if pre_cmd:
+ is_potentially_unsafe = True
+ create_pre_script_cmd = _str_to_echo_command(
+ pre_cmd, filename="deployment_pre_cmd.sh"
+ )
+ debug_comment += create_pre_script_cmd.debug + "\n\n"
+
+ s += "# Get node IPs\n"
+ s += "nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )\n"
+ s += 'nodes_array=("${nodes[@]}") # Ensure nodes are stored properly\n'
+ s += 'export NODES_IPS_ARRAY=($(for node in "${nodes_array[@]}"; do srun --nodelist=$node --ntasks=1 --nodes=1 hostname --ip-address; done))\n'
+ s += 'echo "Node IPs: ${NODES_IPS_ARRAY[@]}"\n'
+ s += "# Export MASTER_IP as the first node IP\n"
+ s += "export MASTER_IP=${NODES_IPS_ARRAY[0]}\n"
+ s += 'echo "MASTER_IP: $MASTER_IP"\n'
+
+ # Add debug comment for deployment pre_cmd before srun command
+ if debug_comment:
+ s += "# Debug contents of deployment pre_cmd\n"
+ s += debug_comment
+ s += "\n"
+
+ s += "srun --mpi pmix --overlap "
+ s += f"--nodes {cfg.execution.num_nodes} --ntasks {cfg.execution.get('deployment', {}).get('n_tasks', 1)} "
+ s += "--container-image {} ".format(cfg.deployment.image)
+ if deployment_mounts_list:
+ s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
+ if not cfg.execution.get("mounts", {}).get("mount_home", True):
+ s += "--no-container-mount-home "
+ s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A-%t.log")
+
+ deployment_env_var_names = list(
+ cfg.execution.get("env_vars", {}).get("deployment", {})
+ )
+ if cfg.deployment.get("env_vars"):
+ warnings.warn(
+ "cfg.deployment.env_vars will be deprecated in future versions. "
+ "Use cfg.execution.env_vars.deployment instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
+
+ # Always add MASTER_IP to the environment variables
+ if "MASTER_IP" not in deployment_env_var_names:
+ deployment_env_var_names.append("MASTER_IP")
+
+ if deployment_env_var_names:
+ s += f"--container-env {','.join(deployment_env_var_names)} "
+
+ # Wrap deployment command to execute pre_cmd inside container if needed
+ if pre_cmd:
+ # Create a wrapper command that runs inside the container:
+ # 1. Create deployment_pre_cmd.sh file
+ # 2. Source it
+ # 3. Execute the original deployment command
+ create_pre_script_cmd = _str_to_echo_command(
+ pre_cmd, filename="deployment_pre_cmd.sh"
+ )
+ # Escape single quotes in the deployment command for bash -c
+ escaped_deployment_cmd = cfg.deployment.command.replace("'", "'\"'\"'")
+ wrapped_command = (
+ f"bash -c '{create_pre_script_cmd.cmd} && "
+ f"source deployment_pre_cmd.sh && "
+ f"{escaped_deployment_cmd}'"
+ )
+ s += "{} &\n\n".format(wrapped_command)
+ else:
+ s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
+
+ s += "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
+
+ return s, is_potentially_unsafe, debug_comment
+
+
+ def _get_wait_for_server_handler(
+ ip_list: str,
+ port: int,
+ health_check_path: str,
+ service_name: str = "server",
+ check_pid: bool = False,
+ ):
+ """Generate wait for server handler that takes a list of IPs."""
+ pid_check = ""
+ if check_pid:
+ pid_check = 'kill -0 "$SERVER_PID" 2>/dev/null || { echo "Server process $SERVER_PID died"; exit 1; }'
+
+ handler = f"""date
+ # wait for the {service_name} to initialize
+ for ip in {ip_list}; do
+ echo "Waiting for {service_name} on $ip..."
+ while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" http://$ip:{port}{health_check_path})" != "200" ]]; do
+ {pid_check}
+ sleep 5
+ done
+ echo "{service_name} ready on $ip!"
+ done
  date
  """.strip()
+
+ return handler
+
+
+ def _get_proxy_server_srun_command(cfg, remote_task_subdir):
+ """Generate proxy server srun command based on proxy type."""
+ proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+
+ if proxy_type == "haproxy":
+ return _generate_haproxy_srun_command(cfg, remote_task_subdir)
+ else:
+ raise ValueError(
+ f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+ )
+
+
+ def _generate_haproxy_srun_command(cfg, remote_task_subdir):
+ """Generate HAProxy-specific srun command using template-based config."""
+ s = ""
+ s += "# Proxy load balancer\n"
+ s += "# Copy template to config file (important for restarts)\n"
+ s += f"cp {remote_task_subdir}/proxy.cfg.template {remote_task_subdir}/proxy.cfg\n"
+ s += "# Replace placeholder IPs with actual node IPs\n"
+ s += f"proxy_config_file={remote_task_subdir}/proxy.cfg\n"
+ s += 'for i in "${!NODES_IPS_ARRAY[@]}"; do\n'
+ s += ' ip="${NODES_IPS_ARRAY[$i]}"\n'
+ s += ' sed -i "s/{IP_$i}/$ip/g" "$proxy_config_file"\n'
+ s += "done\n"
+ s += "\n"
+ s += "srun --mpi pmix --overlap "
+ s += "--nodes 1 --ntasks 1 "
+ s += f"--container-image {cfg.execution.get('proxy', {}).get('image', 'haproxy:latest')} "
+ s += f"--container-mounts {remote_task_subdir}/proxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro "
+ s += f"--output {remote_task_subdir}/logs/proxy-%A.log "
+ s += "haproxy -f /usr/local/etc/haproxy/haproxy.cfg &\n"
+ s += "PROXY_PID=$! # capture the PID of the proxy background srun process\n"
+ s += 'echo "Proxy started with PID: $PROXY_PID"\n\n'
+
+ # Wait for proxy to be ready on localhost
+ proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+ haproxy_port = proxy_config.get("haproxy_port", 5009)
+ health_path = proxy_config.get("health_check_path", "/health")
+ s += _get_wait_for_server_handler(
+ "127.0.0.1", haproxy_port, health_path, "Proxy", check_pid=False
+ )
+ s += "\n"
+
+ return s
+
+
+ def _collect_mount_paths(cfg: DictConfig) -> List[str]:
+ """Collect all mount source paths from the configuration.
+
+ Args:
+ cfg: The configuration object for the evaluation run.
+
+ Returns:
+ List of source paths that need to be mounted.
+ """
+ mount_paths = []
+
+ # Deployment mounts
+ if cfg.deployment.type != "none":
+ if checkpoint_path := cfg.deployment.get("checkpoint_path"):
+ mount_paths.append(checkpoint_path)
+ if cache_path := cfg.deployment.get("cache_path"):
+ mount_paths.append(cache_path)
+ for source_mnt in cfg.execution.get("mounts", {}).get("deployment", {}).keys():
+ mount_paths.append(source_mnt)
+
+ # Evaluation mounts
+ for source_mnt in cfg.execution.get("mounts", {}).get("evaluation", {}).keys():
+ mount_paths.append(source_mnt)
+
+ return mount_paths
+
+
+ def _validate_remote_paths_exist(
+ paths: List[str],
+ username: str,
+ hostname: str,
+ socket: str | None,
+ ) -> None:
+ """Validate that all specified paths exist as directories on the remote host.
+
+ Args:
+ paths: List of directory paths to validate.
+ username: SSH username.
+ hostname: SSH hostname.
+ socket: control socket location or None
+
+ Raises:
+ ValueError: If any paths do not exist as directories on the remote host.
+ """
+ if not paths:
+ return
+
+ # Remove duplicates while preserving order
+ unique_paths = list(dict.fromkeys(paths))
+
+ # Build a single SSH command to check all paths at once
+ test_commands = []
+ for path in unique_paths:
+ # Use test -d to check if directory exists
+ # Escape single quotes in path using POSIX-safe method: ' becomes '"'"'
+ escaped_path = path.replace("'", "'\"'\"'")
+ test_commands.append(
+ f"test -d '{escaped_path}' && echo 'EXISTS:{path}' || echo 'MISSING:{path}'"
+ )
+
+ combined_command = " ; ".join(test_commands)
+
+ ssh_command = ["ssh"]
+ if socket is not None:
+ ssh_command.append(f"-S {socket}")
+ ssh_command.append(f"{username}@{hostname}")
+ ssh_command.append(combined_command)
+ ssh_command = " ".join(ssh_command)
+
+ logger.info("Validating mount directories exist on remote host", cmd=ssh_command)
+ completed_process = subprocess.run(
+ args=shlex.split(ssh_command),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+
+ if completed_process.returncode != 0:
+ error_msg = (
+ completed_process.stderr.decode("utf-8")
+ if completed_process.stderr
+ else "Unknown error"
+ )
+ logger.error(
+ "Error validating remote paths",
+ code=completed_process.returncode,
+ msg=error_msg,
+ )
+ raise RuntimeError(f"Failed to validate remote paths: {error_msg}")
+
+ # Parse output to find missing paths
+ output = completed_process.stdout.decode("utf-8")
+ missing_paths = []
+ for line in output.strip().split("\n"):
+ if line.startswith("MISSING:"):
+ missing_path = line.replace("MISSING:", "")
+ missing_paths.append(missing_path)
+
+ if missing_paths:
+ error_message = (
+ f"The following mount paths do not exist as directories on {username}@{hostname}:\n"
+ + "\n".join(f" - {path}" for path in missing_paths)
+ + "\n\nMount paths must be directories. Please create these directories on the cluster or update your configuration."
+ )
+ logger.error("Mount validation failed", missing_paths=missing_paths)
+ raise ValueError(error_message)
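
The mount validation above emits one `test -d` per path over SSH and tags each result with EXISTS:/MISSING:; parsing that protocol takes only a few lines (sample output invented):

    output = "EXISTS:/lustre/checkpoints\nMISSING:/lustre/hf-cache\n"
    missing_paths = [
        line.removeprefix("MISSING:")
        for line in output.strip().split("\n")
        if line.startswith("MISSING:")
    ]
    assert missing_paths == ["/lustre/hf-cache"]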