nemo-evaluator-launcher 0.1.19__py3-none-any.whl → 0.1.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. nemo_evaluator_launcher/api/functional.py +105 -1
  2. nemo_evaluator_launcher/cli/logs.py +102 -0
  3. nemo_evaluator_launcher/cli/main.py +12 -0
  4. nemo_evaluator_launcher/cli/run.py +73 -15
  5. nemo_evaluator_launcher/cli/version.py +26 -23
  6. nemo_evaluator_launcher/common/helpers.py +176 -43
  7. nemo_evaluator_launcher/common/logging_utils.py +16 -5
  8. nemo_evaluator_launcher/common/printing_utils.py +7 -0
  9. nemo_evaluator_launcher/configs/deployment/sglang.yaml +4 -2
  10. nemo_evaluator_launcher/configs/deployment/trtllm.yaml +2 -3
  11. nemo_evaluator_launcher/configs/deployment/vllm.yaml +0 -1
  12. nemo_evaluator_launcher/configs/execution/slurm/default.yaml +14 -0
  13. nemo_evaluator_launcher/executors/base.py +31 -1
  14. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +36 -1
  15. nemo_evaluator_launcher/executors/lepton/executor.py +81 -1
  16. nemo_evaluator_launcher/executors/local/executor.py +377 -22
  17. nemo_evaluator_launcher/executors/local/run.template.sh +54 -2
  18. nemo_evaluator_launcher/executors/slurm/executor.py +422 -59
  19. nemo_evaluator_launcher/executors/slurm/proxy.cfg.template +26 -0
  20. nemo_evaluator_launcher/exporters/utils.py +32 -46
  21. nemo_evaluator_launcher/package_info.py +1 -1
  22. nemo_evaluator_launcher/resources/mapping.toml +56 -15
  23. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/METADATA +3 -3
  24. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/RECORD +28 -26
  25. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/entry_points.txt +1 -0
  26. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/WHEEL +0 -0
  27. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/licenses/LICENSE +0 -0
  28. {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/top_level.txt +0 -0
nemo_evaluator_launcher/executors/slurm/executor.py
@@ -30,6 +30,7 @@ from pathlib import Path
 from typing import Dict, List, Optional
 
 import yaml
+from jinja2 import Environment, FileSystemLoader
 from omegaconf import DictConfig, OmegaConf
 
 from nemo_evaluator_launcher.common.execdb import (
@@ -39,18 +40,19 @@ from nemo_evaluator_launcher.common.execdb import (
     generate_job_id,
 )
 from nemo_evaluator_launcher.common.helpers import (
+    CmdAndReadableComment,
+    _str_to_echo_command,
     get_api_key_name,
-    get_endpoint_url,
     get_eval_factory_command,
     get_eval_factory_dataset_size_from_run_config,
-    get_health_url,
     get_timestamp_string,
 )
+from nemo_evaluator_launcher.common.logging_utils import logger
 from nemo_evaluator_launcher.common.mapping import (
     get_task_from_mapping,
     load_tasks_mapping,
 )
-from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey
+from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey, red
 from nemo_evaluator_launcher.executors.base import (
     BaseExecutor,
     ExecutionState,
@@ -94,6 +96,7 @@ class SlurmExecutor(BaseExecutor):
         tasks_mapping = load_tasks_mapping()
         eval_images: list[str] = []
 
+        is_potentially_unsafe = False
         for idx, task in enumerate(cfg.evaluation.tasks):
             # calculate job_id
             job_id = f"{invocation_id}.{idx}"
@@ -114,7 +117,7 @@ class SlurmExecutor(BaseExecutor):
             eval_images.append(eval_image)
 
             # generate and write down sbatch script
-            sbatch_script_content_str = _create_slurm_sbatch_script(
+            sbatch_script_content_struct = _create_slurm_sbatch_script(
                 cfg=cfg,
                 task=task,
                 eval_image=eval_image,
@@ -122,6 +125,32 @@ class SlurmExecutor(BaseExecutor):
                 invocation_id=invocation_id,
                 job_id=job_id,
             )
+
+            # Create proxy config file with placeholder IPs for multi-instance deployments
+            if cfg.deployment.get("multiple_instances", False):
+                proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+                if proxy_type == "haproxy":
+                    proxy_config = _generate_haproxy_config_with_placeholders(cfg)
+                else:
+                    raise ValueError(
+                        f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+                    )
+
+                # Save both template and working config
+                proxy_template_path = local_task_subdir / "proxy.cfg.template"
+                proxy_config_path = local_task_subdir / "proxy.cfg"
+                with open(proxy_template_path, "w") as f:
+                    f.write(proxy_config)
+                with open(proxy_config_path, "w") as f:
+                    f.write(proxy_config)
+
+            sbatch_script_content_str = sbatch_script_content_struct.cmd
+
+            # We accumulate if any task contains unsafe commands
+            is_potentially_unsafe = (
+                is_potentially_unsafe
+                or sbatch_script_content_struct.is_potentially_unsafe
+            )
             local_runsub_path = local_task_subdir / "run.sub"
             remote_runsub_path = remote_task_subdir / "run.sub"
             with open(local_runsub_path, "w") as f:
@@ -138,8 +167,34 @@ class SlurmExecutor(BaseExecutor):
             with open(local_runsub_path, "r") as f:
                 print(grey(f.read()))
             print(bold("To submit jobs") + ", run the executor without --dry-run")
+            if is_potentially_unsafe:
+                print(
+                    red(
+                        "\nFound `pre_cmd` (evaluation or deployment) which carries security risk. When running without --dry-run "
+                        "make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1"
+                    )
+                )
+
             return invocation_id
 
+        if is_potentially_unsafe:
+            if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
+                logger.warning(
+                    "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+                    "is set, proceeding with caution."
+                )
+
+            else:
+                logger.error(
+                    "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+                    "is not set. This might carry security risk and unstable environments. "
+                    "To continue, make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1.",
+                )
+                raise AttributeError(
+                    "Untrusted command found in config, make sure you trust and "
+                    "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
+                )
+
         socket = str(Path(tmpdirname) / "socket")
         socket_or_none = _open_master_connection(
             username=cfg.execution.username,
@@ -437,7 +492,7 @@ def _create_slurm_sbatch_script(
     remote_task_subdir: Path,
     invocation_id: str,
     job_id: str,
-) -> str:
+) -> CmdAndReadableComment:
     """Generate the contents of a SLURM sbatch script for a given evaluation task.
 
     Args:
@@ -453,7 +508,6 @@ def _create_slurm_sbatch_script(
     # get task from mapping, overrides, urls
     tasks_mapping = load_tasks_mapping()
     task_definition = get_task_from_mapping(task.name, tasks_mapping)
-    health_url = get_health_url(cfg, get_endpoint_url(cfg, task, task_definition))
 
     # TODO(public release): convert to template
     s = "#!/bin/bash\n"
@@ -468,6 +522,8 @@ def _create_slurm_sbatch_script(
     s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
     if hasattr(cfg.execution, "gres"):
         s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
+    if cfg.execution.get("sbatch_comment"):
+        s += "#SBATCH --comment='{}'\n".format(cfg.execution.sbatch_comment)
     job_name = "{account}-{subproject}.{details}".format(
         account=cfg.execution.account,
         subproject=cfg.execution.subproject,
@@ -493,8 +549,11 @@ def _create_slurm_sbatch_script(
         if os.getenv(env_var) is None:
             raise ValueError(f"Trying to pass an unset environment variable {env_var}.")
 
-    # check if required env vars are defined:
+    # check if required env vars are defined (excluding NEMO_EVALUATOR_DATASET_DIR which is handled separately):
     for required_env_var in task_definition.get("required_env_vars", []):
+        # Skip NEMO_EVALUATOR_DATASET_DIR as it's handled by dataset mounting logic below
+        if required_env_var == "NEMO_EVALUATOR_DATASET_DIR":
+            continue
         if required_env_var not in env_vars.keys():
             raise ValueError(
                 f"{task.name} task requires environment variable {required_env_var}."
@@ -540,6 +599,7 @@ def _create_slurm_sbatch_script(
 
     # prepare deployment mounts
     deployment_mounts_list = []
+    deployment_is_unsafe = False
    if cfg.deployment.type != "none":
         if checkpoint_path := cfg.deployment.get("checkpoint_path"):
             deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
@@ -551,36 +611,33 @@ def _create_slurm_sbatch_script(
             deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")
 
         # add deployment srun command
-        s += "# deployment server\n"
-        s += "srun --mpi pmix --overlap "
-        s += "--container-image {} ".format(cfg.deployment.image)
-        if deployment_mounts_list:
-            s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
-        if not cfg.execution.get("mounts", {}).get("mount_home", True):
-            s += "--no-container-mount-home "
-        s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
-        deployment_env_var_names = list(
-            cfg.execution.get("env_vars", {}).get("deployment", {})
-        )
-        if cfg.deployment.get("env_vars"):
-            warnings.warn(
-                "cfg.deployment.env_vars will be deprecated in future versions. "
-                "Use cfg.execution.env_vars.deployment instead.",
-                category=DeprecationWarning,
-                stacklevel=2,
+        deployment_srun_cmd, deployment_is_unsafe, deployment_debug = (
+            _generate_deployment_srun_command(
+                cfg, deployment_mounts_list, remote_task_subdir
             )
-            deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
-        if deployment_env_var_names:
-            s += f"--container-env {','.join(deployment_env_var_names)} "
-        s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
-        s += (
-            "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
         )
+        s += deployment_srun_cmd
 
         # wait for the server to initialize
-        s += _WAIT_FOR_SERVER_HANDLER.format(health_url=health_url)
+        health_path = cfg.deployment.get("health_check_path", "/health")
+        # For multi-instance check all node IPs, for single instance check localhost
+        if cfg.deployment.get("multiple_instances", False):
+            ip_list = '"${NODES_IPS_ARRAY[@]}"'
+        else:
+            ip_list = '"127.0.0.1"'
+        s += _get_wait_for_server_handler(
+            ip_list,
+            cfg.deployment.port,
+            health_path,
+            "server",
+            check_pid=True,
+        )
         s += "\n\n"
 
+        # add proxy load balancer for multi-instance deployments
+        if cfg.deployment.get("multiple_instances", False):
+            s += _get_proxy_server_srun_command(cfg, remote_task_subdir)
+
     # prepare evaluation mounts
     evaluation_mounts_list = [
         "{}:/results".format(remote_task_subdir / "artifacts"),
@@ -590,7 +647,29 @@ def _create_slurm_sbatch_script(
     ):
         evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")
 
-    eval_factory_command_struct = get_eval_factory_command(cfg, task, task_definition)
+    # Handle dataset directory mounting if NEMO_EVALUATOR_DATASET_DIR is required
+    if "NEMO_EVALUATOR_DATASET_DIR" in task_definition.get("required_env_vars", []):
+        # Get dataset directory from task config
+        if "dataset_dir" in task:
+            dataset_mount_host = task["dataset_dir"]
+        else:
+            raise ValueError(
+                f"{task.name} task requires a dataset_dir to be specified. "
+                f"Add 'dataset_dir: /path/to/your/dataset' under the task configuration."
+            )
+        # Get container mount path (default to /datasets if not specified)
+        dataset_mount_container = task.get("dataset_mount_path", "/datasets")
+        # Add dataset mount to evaluation mounts list
+        evaluation_mounts_list.append(f"{dataset_mount_host}:{dataset_mount_container}")
+        # Export NEMO_EVALUATOR_DATASET_DIR environment variable
+        s += f"export NEMO_EVALUATOR_DATASET_DIR={dataset_mount_container}\n\n"
+
+    eval_factory_command_struct = get_eval_factory_command(
+        cfg,
+        task,
+        task_definition,
+    )
+
     eval_factory_command = eval_factory_command_struct.cmd
     # The debug comment for placing into the script and easy debug. Reason
     # (see `CmdAndReadableComment`) is the current way of passing the command
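The dataset branch above expects the location to come from the task entry itself; a hedged sketch of a task entry that would satisfy it (task name and paths are placeholders):

from omegaconf import OmegaConf

# Illustrative only: a task whose mapping lists NEMO_EVALUATOR_DATASET_DIR in
# required_env_vars. The host path is mounted into the evaluation container and
# NEMO_EVALUATOR_DATASET_DIR is exported to point at the container-side path.
task = OmegaConf.create(
    {
        "name": "my_offline_task",  # placeholder
        "dataset_dir": "/lustre/datasets/my_offline_task",  # host path, required
        "dataset_mount_path": "/datasets",  # container path, optional (default /datasets)
    }
)

assert "dataset_dir" in task  # the check the new branch performs before mounting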
@@ -606,6 +685,7 @@ def _create_slurm_sbatch_script(
 
     s += "# evaluation client\n"
     s += "srun --mpi pmix --overlap "
+    s += "--nodes 1 --ntasks 1 " # Client always runs on single node
     s += "--container-image {} ".format(eval_image)
     evaluation_env_var_names = list(
         cfg.execution.get("env_vars", {}).get("evaluation", {})
@@ -623,7 +703,10 @@ def _create_slurm_sbatch_script(
 
     # terminate the server after all evaluation clients finish
     if cfg.deployment.type != "none":
-        s += "kill $SERVER_PID # terminate the server to finish gracefully\n\n"
+        s += "kill $SERVER_PID # terminate the server to finish gracefully\n"
+        if cfg.deployment.get("multiple_instances", False):
+            s += "kill $PROXY_PID # terminate proxy to finish gracefully\n"
+        s += "\n"
 
     # auto-export
     ae_cfg = cfg.execution.get("auto_export")
@@ -635,9 +718,22 @@ def _create_slurm_sbatch_script(
 
     if destinations:
         export_env = dict(cfg.execution.get("env_vars", {}).get("export", {}) or {})
-        s += _generate_auto_export_section(cfg, job_id, destinations, export_env)
+        s += _generate_auto_export_section(
+            cfg, job_id, destinations, export_env, remote_task_subdir
+        )
 
-    return s
+    debug_str = "\n".join(["# " + line for line in s.splitlines()])
+
+    # Combine unsafe flags from both deployment and evaluation
+    is_potentially_unsafe = (
+        eval_factory_command_struct.is_potentially_unsafe or deployment_is_unsafe
+    )
+
+    return CmdAndReadableComment(
+        cmd=s,
+        debug=debug_str,
+        is_potentially_unsafe=is_potentially_unsafe,
+    )
 
 
 def _generate_auto_export_section(
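CmdAndReadableComment comes from common/helpers.py and is not shown in this diff; judging only from how it is constructed and consumed here, it presumably looks roughly like this dataclass (an assumption, not the actual definition):

from dataclasses import dataclass


@dataclass
class CmdAndReadableComment:
    # Assumed shape, inferred from usage in this file only.
    cmd: str  # the runnable sbatch script content
    debug: str  # the same content rendered as '#'-prefixed comment lines
    is_potentially_unsafe: bool = False  # True when a pre_cmd is present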
@@ -645,6 +741,8 @@ def _generate_auto_export_section(
     job_id: str,
     destinations: list,
     export_env: dict,
+    remote_task_subdir: Path,
+    export_image: str = "python:3.12.7-slim",
 ) -> str:
     """Generate simple auto-export section for sbatch script."""
     if not destinations:
@@ -654,10 +752,7 @@ def _generate_auto_export_section(
     s += "EVAL_EXIT_CODE=$?\n"
     s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
     s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
-    s += " set +e\n"
-    s += " set +x\n"
-    s += " set +u\n"
-    s += ' cd "$TASK_DIR/artifacts"\n'
+    s += f' cd "{remote_task_subdir}/artifacts"\n'
 
     # Work with DictConfig; convert only for YAML at the end
     exec_type = (
@@ -713,10 +808,25 @@ def _generate_auto_export_section(
         esc = str(v).replace('"', '\\"')
         s += f' export {k}="{esc}"\n'
 
-    for dest in destinations:
-        s += f" echo 'Exporting to {dest}...'\n"
-        s += f" nemo-evaluator-launcher export {job_id} --dest {dest} || echo 'Export to {dest} failed'\n"
+    s += " # export\n"
+    s += " srun --mpi pmix --overlap "
+    s += "--nodes 1 --ntasks 1 " # Client always runs on single node
+    s += "--container-image {} ".format(export_image)
+    if export_env:
+        s += "--container-env {} ".format(",".join(export_env))
+    if not cfg.execution.get("mounts", {}).get("mount_home", True):
+        s += "--no-container-mount-home "
 
+    s += f"--container-mounts {remote_task_subdir}/artifacts:{remote_task_subdir}/artifacts "
+    s += "--output {} ".format(remote_task_subdir / "logs" / "export-%A.out")
+    s += " bash -c '\n"
+    # FIXME(martas): would be good to install specific version
+    s += " pip install nemo-evaluator-launcher[all]\n"
+    s += f" cd {remote_task_subdir}/artifacts\n"
+    for dest in destinations:
+        s += f' echo "Exporting to {dest}..."\n'
+        s += f' nemo-evaluator-launcher export {job_id} --dest {dest} || echo "Export to {dest} failed"\n'
+    s += "'\n"
     s += " echo 'Auto-export completed.'\n"
     s += "else\n"
     s += " echo 'Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export.'\n"
@@ -731,11 +841,12 @@ def _open_master_connection(
     socket: str,
 ) -> str | None:
     ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
-    completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
-    )
+    logger.info("Opening master connection", cmd=ssh_command)
+    completed_process = subprocess.run(args=shlex.split(ssh_command))
     if completed_process.returncode == 0:
+        logger.info("Opened master connection successfully", cmd=ssh_command)
         return socket
+    logger.error("Failed to open master connection", code=completed_process.returncode)
     return None
 
 
@@ -747,9 +858,7 @@ def _close_master_connection(
     if socket is None:
         return
     ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
-    completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
-    )
+    completed_process = subprocess.run(args=shlex.split(ssh_command))
     if completed_process.returncode != 0:
         raise RuntimeError(
             "failed to close the master connection\n{}".format(
@@ -771,8 +880,9 @@ def _make_remote_execution_output_dir(
     ssh_command.append(f"{username}@{hostname}")
     ssh_command.append(mkdir_command)
     ssh_command = " ".join(ssh_command)
+    logger.info("Creating remote dir", cmd=ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
+        args=shlex.split(ssh_command), stderr=subprocess.PIPE
     )
     if completed_process.returncode != 0:
         error_msg = (
@@ -780,6 +890,11 @@ def _make_remote_execution_output_dir(
             if completed_process.stderr
             else "Unknown error"
         )
+        logger.error(
+            "Erorr creating remote dir",
+            code=completed_process.returncode,
+            msg=error_msg,
+        )
         raise RuntimeError(
             "failed to make a remote execution output dir\n{}".format(error_msg)
         )
@@ -807,8 +922,10 @@ def _rsync_upload_rundirs(
     remote_destination_str = f"{username}@{hostname}:{remote_target}"
     local_sources_str = " ".join(map(str, local_sources))
     rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
+    logger.info("Rsyncing to remote dir", cmd=rsync_upload_command)
     completed_process = subprocess.run(
-        args=shlex.split(rsync_upload_command), capture_output=True
+        args=shlex.split(rsync_upload_command),
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         error_msg = (
@@ -816,6 +933,12 @@ def _rsync_upload_rundirs(
             if completed_process.stderr
             else "Unknown error"
         )
+
+        logger.error(
+            "Erorr rsyncing to remote dir",
+            code=completed_process.returncode,
+            msg=error_msg,
+        )
         raise RuntimeError("failed to upload local sources\n{}".format(error_msg))
 
 
@@ -837,9 +960,12 @@ def _sbatch_remote_runsubs(
     ssh_command.append(f"{username}@{hostname}")
     ssh_command.append(sbatch_commands)
     ssh_command = " ".join(ssh_command)
-
+    logger.info("Running sbatch", cmd=ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         error_msg = completed_process.stderr.decode("utf-8")
@@ -849,6 +975,7 @@ def _sbatch_remote_runsubs(
 
     sbatch_output = completed_process.stdout.decode("utf-8")
     slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
+    logger.info("Started sbatch successfully", slurm_job_ids=slurm_job_ids)
     return slurm_job_ids
 
 
@@ -881,7 +1008,10 @@ def _query_slurm_jobs_status(
     ssh_command.append(sacct_command)
     ssh_command = " ".join(ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         raise RuntimeError(
@@ -930,7 +1060,10 @@ def _kill_slurm_job(
     ssh_command = " ".join(ssh_command)
 
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
 
     # Parse the sacct output (before scancel runs)
@@ -1008,7 +1141,10 @@ def _read_files_from_remote(
     ssh_command.append(cat_commands)
     ssh_command = " ".join(ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         raise RuntimeError(
@@ -1085,9 +1221,236 @@ sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
 """.strip()
 
 
-_WAIT_FOR_SERVER_HANDLER = """
-date
-# wait for the server to initialize
-bash -c 'while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" {health_url})" != "200" ]]; do kill -0 '"$SERVER_PID"' 2>/dev/null || {{ echo "Server process '"$SERVER_PID"' died"; exit 1; }}; sleep 5; done'
+def _generate_haproxy_config_with_placeholders(cfg):
+    """Generate HAProxy configuration with placeholder IPs using Jinja template."""
+    # Set up Jinja environment
+    template_dir = Path(__file__).parent
+    template_path = template_dir / "proxy.cfg.template"
+
+    if not template_path.exists():
+        raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+    env = Environment(loader=FileSystemLoader(template_dir))
+    template = env.get_template("proxy.cfg.template")
+
+    # Prepare template data with placeholder IPs - use actual number of nodes
+    num_nodes = cfg.execution.num_nodes
+    nodes = []
+    for i in range(num_nodes):
+        nodes.append({"ip": f"{{IP_{i}}}", "port": cfg.deployment.port})
+
+    # Get health check parameters from execution config
+    proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+    health_check_path = proxy_config.get("health_check_path", "/health")
+    health_check_status = proxy_config.get("health_check_status", 200)
+    haproxy_port = proxy_config.get("haproxy_port", 5009)
+
+    # Render template
+    config = template.render(
+        haproxy_port=haproxy_port,
+        health_check_path=health_check_path,
+        health_check_status=health_check_status,
+        nodes=nodes,
+    )
+
+    return config
+
+
+def _generate_haproxy_config(cfg, nodes_ips):
+    """Generate HAProxy configuration using Jinja template."""
+    # Set up Jinja environment
+    template_dir = Path(__file__).parent
+    template_path = template_dir / "proxy.cfg.template"
+
+    if not template_path.exists():
+        raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+    env = Environment(loader=FileSystemLoader(template_dir))
+    template = env.get_template("proxy.cfg.template")
+
+    # Prepare template data
+    nodes = []
+    for i, ip in enumerate(nodes_ips, 1):
+        nodes.append(
+            {"ip": ip, "port": cfg.deployment.port}  # All nodes use the same port
+        )
+
+    # Get health check parameters from deployment config
+    health_check_path = cfg.deployment.get("health_check_path", "/health")
+    health_check_status = cfg.deployment.get("health_check_status", 200)
+    haproxy_port = cfg.deployment.get("haproxy_port", 5009)
+
+    # Render template
+    config = template.render(
+        haproxy_port=haproxy_port,
+        health_check_path=health_check_path,
+        health_check_status=health_check_status,
+        nodes=nodes,
+    )
+
+    return config
+
+
+def _generate_deployment_srun_command(
+    cfg, deployment_mounts_list, remote_task_subdir, instance_id: int = 0
+):
+    """Generate the deployment srun command with proper node/ntask configuration.
+
+    Returns:
+        tuple: (script_string, is_potentially_unsafe, debug_comment)
+    """
+    s = ""
+    debug_comment = ""
+    is_potentially_unsafe = False
+
+    s += "# deployment server\n"
+
+    # Extract pre_cmd for later use inside container
+    pre_cmd: str = cfg.deployment.get("pre_cmd") or ""
+    if pre_cmd:
+        is_potentially_unsafe = True
+        create_pre_script_cmd = _str_to_echo_command(
+            pre_cmd, filename="deployment_pre_cmd.sh"
+        )
+        debug_comment += create_pre_script_cmd.debug + "\n\n"
+
+    s += "# Get node IPs\n"
+    s += "nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )\n"
+    s += 'nodes_array=("${nodes[@]}") # Ensure nodes are stored properly\n'
+    s += 'export NODES_IPS_ARRAY=($(for node in "${nodes_array[@]}"; do srun --nodelist=$node --ntasks=1 --nodes=1 hostname --ip-address; done))\n'
+    s += 'echo "Node IPs: ${NODES_IPS_ARRAY[@]}"\n'
+    s += "# Export MASTER_IP as the first node IP\n"
+    s += "export MASTER_IP=${NODES_IPS_ARRAY[0]}\n"
+    s += 'echo "MASTER_IP: $MASTER_IP"\n'
+
+    # Add debug comment for deployment pre_cmd before srun command
+    if debug_comment:
+        s += "# Debug contents of deployment pre_cmd\n"
+        s += debug_comment
+        s += "\n"
+
+    s += "srun --mpi pmix --overlap "
+    s += f"--nodes {cfg.execution.num_nodes} --ntasks {cfg.execution.get('deployment', {}).get('n_tasks', 1)} "
+    s += "--container-image {} ".format(cfg.deployment.image)
+    if deployment_mounts_list:
+        s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
+    if not cfg.execution.get("mounts", {}).get("mount_home", True):
+        s += "--no-container-mount-home "
+    s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A-%t.out")
+
+    deployment_env_var_names = list(
+        cfg.execution.get("env_vars", {}).get("deployment", {})
+    )
+    if cfg.deployment.get("env_vars"):
+        warnings.warn(
+            "cfg.deployment.env_vars will be deprecated in future versions. "
+            "Use cfg.execution.env_vars.deployment instead.",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
+        deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
+
+    # Always add MASTER_IP to the environment variables
+    if "MASTER_IP" not in deployment_env_var_names:
+        deployment_env_var_names.append("MASTER_IP")
+
+    if deployment_env_var_names:
+        s += f"--container-env {','.join(deployment_env_var_names)} "
+
+    # Wrap deployment command to execute pre_cmd inside container if needed
+    if pre_cmd:
+        # Create a wrapper command that runs inside the container:
+        # 1. Create deployment_pre_cmd.sh file
+        # 2. Source it
+        # 3. Execute the original deployment command
+        create_pre_script_cmd = _str_to_echo_command(
+            pre_cmd, filename="deployment_pre_cmd.sh"
+        )
+        # Escape single quotes in the deployment command for bash -c
+        escaped_deployment_cmd = cfg.deployment.command.replace("'", "'\"'\"'")
+        wrapped_command = (
+            f"bash -c '{create_pre_script_cmd.cmd} && "
+            f"source deployment_pre_cmd.sh && "
+            f"{escaped_deployment_cmd}'"
+        )
+        s += "{} &\n\n".format(wrapped_command)
+    else:
+        s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
+
+    s += "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
+
+    return s, is_potentially_unsafe, debug_comment
+
+
+def _get_wait_for_server_handler(
+    ip_list: str,
+    port: int,
+    health_check_path: str,
+    service_name: str = "server",
+    check_pid: bool = False,
+):
+    """Generate wait for server handler that takes a list of IPs."""
+    pid_check = ""
+    if check_pid:
+        pid_check = 'kill -0 "$SERVER_PID" 2>/dev/null || { echo "Server process $SERVER_PID died"; exit 1; }'
+
+    handler = f"""date
+# wait for the {service_name} to initialize
+for ip in {ip_list}; do
+    echo "Waiting for {service_name} on $ip..."
+    while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" http://$ip:{port}{health_check_path})" != "200" ]]; do
+        {pid_check}
+        sleep 5
+    done
+    echo "{service_name} ready on $ip!"
+done
 date
 """.strip()
+
+    return handler
+
+
+def _get_proxy_server_srun_command(cfg, remote_task_subdir):
+    """Generate proxy server srun command based on proxy type."""
+    proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+
+    if proxy_type == "haproxy":
+        return _generate_haproxy_srun_command(cfg, remote_task_subdir)
+    else:
+        raise ValueError(
+            f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+        )
+
+
+def _generate_haproxy_srun_command(cfg, remote_task_subdir):
+    """Generate HAProxy-specific srun command using template-based config."""
+    s = ""
+    s += "# Proxy load balancer\n"
+    s += "# Copy template to config file (important for restarts)\n"
+    s += f"cp {remote_task_subdir}/proxy.cfg.template {remote_task_subdir}/proxy.cfg\n"
+    s += "# Replace placeholder IPs with actual node IPs\n"
+    s += f"proxy_config_file={remote_task_subdir}/proxy.cfg\n"
+    s += 'for i in "${!NODES_IPS_ARRAY[@]}"; do\n'
+    s += ' ip="${NODES_IPS_ARRAY[$i]}"\n'
+    s += ' sed -i "s/{IP_$i}/$ip/g" "$proxy_config_file"\n'
+    s += "done\n"
+    s += "\n"
+    s += "srun --mpi pmix --overlap "
+    s += "--nodes 1 --ntasks 1 "
+    s += f"--container-image {cfg.execution.get('proxy', {}).get('image', 'haproxy:latest')} "
+    s += f"--container-mounts {remote_task_subdir}/proxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro "
+    s += f"--output {remote_task_subdir}/logs/proxy-%A.out "
+    s += "haproxy -f /usr/local/etc/haproxy/haproxy.cfg &\n"
+    s += "PROXY_PID=$! # capture the PID of the proxy background srun process\n"
+    s += 'echo "Proxy started with PID: $PROXY_PID"\n\n'
+
+    # Wait for proxy to be ready on localhost
+    proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+    haproxy_port = proxy_config.get("haproxy_port", 5009)
+    health_path = proxy_config.get("health_check_path", "/health")
+    s += _get_wait_for_server_handler(
+        "127.0.0.1", haproxy_port, health_path, "Proxy", check_pid=False
+    )
+    s += "\n"
+
+    return s
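As a rough illustration of the two-step proxy setup (render proxy.cfg with {IP_i} placeholders at submit time, then let the sbatch script substitute real node IPs via sed), here is a minimal Jinja sketch; the inline template is a made-up stand-in, not the proxy.cfg.template shipped in the package:

from jinja2 import Template

# Deliberately tiny stand-in for proxy.cfg.template.
TEMPLATE = """\
frontend api
    bind *:{{ haproxy_port }}
    default_backend servers
backend servers
    option httpchk GET {{ health_check_path }}
    http-check expect status {{ health_check_status }}
{% for node in nodes %}    server node{{ loop.index }} {{ node.ip }}:{{ node.port }} check
{% endfor %}"""

# Placeholder IPs, as in _generate_haproxy_config_with_placeholders; the job
# later rewrites {IP_0}, {IP_1}, ... once the Slurm node IPs are known.
nodes = [{"ip": f"{{IP_{i}}}", "port": 8000} for i in range(2)]
print(
    Template(TEMPLATE).render(
        haproxy_port=5009,
        health_check_path="/health",
        health_check_status=200,
        nodes=nodes,
    )
)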