nemo-evaluator-launcher 0.1.0rc6__py3-none-any.whl → 0.1.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. nemo_evaluator_launcher/__init__.py +15 -1
  2. nemo_evaluator_launcher/api/functional.py +188 -27
  3. nemo_evaluator_launcher/api/types.py +9 -0
  4. nemo_evaluator_launcher/cli/export.py +131 -12
  5. nemo_evaluator_launcher/cli/info.py +477 -82
  6. nemo_evaluator_launcher/cli/kill.py +5 -3
  7. nemo_evaluator_launcher/cli/logs.py +102 -0
  8. nemo_evaluator_launcher/cli/ls_runs.py +31 -10
  9. nemo_evaluator_launcher/cli/ls_tasks.py +105 -3
  10. nemo_evaluator_launcher/cli/main.py +101 -5
  11. nemo_evaluator_launcher/cli/run.py +153 -30
  12. nemo_evaluator_launcher/cli/status.py +49 -5
  13. nemo_evaluator_launcher/cli/version.py +26 -23
  14. nemo_evaluator_launcher/common/execdb.py +121 -27
  15. nemo_evaluator_launcher/common/helpers.py +213 -33
  16. nemo_evaluator_launcher/common/logging_utils.py +16 -5
  17. nemo_evaluator_launcher/common/printing_utils.py +100 -0
  18. nemo_evaluator_launcher/configs/deployment/generic.yaml +33 -0
  19. nemo_evaluator_launcher/configs/deployment/sglang.yaml +4 -2
  20. nemo_evaluator_launcher/configs/deployment/trtllm.yaml +23 -0
  21. nemo_evaluator_launcher/configs/deployment/vllm.yaml +2 -2
  22. nemo_evaluator_launcher/configs/execution/local.yaml +2 -0
  23. nemo_evaluator_launcher/configs/execution/slurm/default.yaml +19 -4
  24. nemo_evaluator_launcher/executors/base.py +54 -1
  25. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +60 -5
  26. nemo_evaluator_launcher/executors/lepton/executor.py +240 -101
  27. nemo_evaluator_launcher/executors/lepton/job_helpers.py +15 -11
  28. nemo_evaluator_launcher/executors/local/executor.py +492 -56
  29. nemo_evaluator_launcher/executors/local/run.template.sh +76 -9
  30. nemo_evaluator_launcher/executors/slurm/executor.py +571 -98
  31. nemo_evaluator_launcher/executors/slurm/proxy.cfg.template +26 -0
  32. nemo_evaluator_launcher/exporters/base.py +9 -0
  33. nemo_evaluator_launcher/exporters/gsheets.py +27 -9
  34. nemo_evaluator_launcher/exporters/local.py +30 -16
  35. nemo_evaluator_launcher/exporters/mlflow.py +245 -74
  36. nemo_evaluator_launcher/exporters/utils.py +139 -184
  37. nemo_evaluator_launcher/exporters/wandb.py +157 -43
  38. nemo_evaluator_launcher/package_info.py +6 -3
  39. nemo_evaluator_launcher/resources/mapping.toml +56 -15
  40. nemo_evaluator_launcher-0.1.41.dist-info/METADATA +494 -0
  41. nemo_evaluator_launcher-0.1.41.dist-info/RECORD +62 -0
  42. {nemo_evaluator_launcher-0.1.0rc6.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/entry_points.txt +1 -0
  43. nemo_evaluator_launcher-0.1.0rc6.dist-info/METADATA +0 -35
  44. nemo_evaluator_launcher-0.1.0rc6.dist-info/RECORD +0 -57
  45. {nemo_evaluator_launcher-0.1.0rc6.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/WHEEL +0 -0
  46. {nemo_evaluator_launcher-0.1.0rc6.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/licenses/LICENSE +0 -0
  47. {nemo_evaluator_launcher-0.1.0rc6.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,7 @@ from pathlib import Path
  from typing import Dict, List, Optional

  import yaml
+ from jinja2 import Environment, FileSystemLoader
  from omegaconf import DictConfig, OmegaConf

  from nemo_evaluator_launcher.common.execdb import (
@@ -39,17 +40,19 @@ from nemo_evaluator_launcher.common.execdb import (
  generate_job_id,
  )
  from nemo_evaluator_launcher.common.helpers import (
+ CmdAndReadableComment,
+ _str_to_echo_command,
  get_api_key_name,
- get_endpoint_url,
  get_eval_factory_command,
  get_eval_factory_dataset_size_from_run_config,
- get_health_url,
  get_timestamp_string,
  )
+ from nemo_evaluator_launcher.common.logging_utils import logger
  from nemo_evaluator_launcher.common.mapping import (
  get_task_from_mapping,
  load_tasks_mapping,
  )
+ from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey, red
  from nemo_evaluator_launcher.executors.base import (
  BaseExecutor,
  ExecutionState,
@@ -93,6 +96,7 @@ class SlurmExecutor(BaseExecutor):
  tasks_mapping = load_tasks_mapping()
  eval_images: list[str] = []

+ is_potentially_unsafe = False
  for idx, task in enumerate(cfg.evaluation.tasks):
  # calculate job_id
  job_id = f"{invocation_id}.{idx}"
@@ -113,7 +117,7 @@ class SlurmExecutor(BaseExecutor):
  eval_images.append(eval_image)

  # generate and write down sbatch script
- sbatch_script_content_str = _create_slurm_sbatch_script(
+ sbatch_script_content_struct = _create_slurm_sbatch_script(
  cfg=cfg,
  task=task,
  eval_image=eval_image,
@@ -121,6 +125,32 @@ class SlurmExecutor(BaseExecutor):
  invocation_id=invocation_id,
  job_id=job_id,
  )
+
+ # Create proxy config file with placeholder IPs for multi-instance deployments
+ if cfg.deployment.get("multiple_instances", False):
+ proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+ if proxy_type == "haproxy":
+ proxy_config = _generate_haproxy_config_with_placeholders(cfg)
+ else:
+ raise ValueError(
+ f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+ )
+
+ # Save both template and working config
+ proxy_template_path = local_task_subdir / "proxy.cfg.template"
+ proxy_config_path = local_task_subdir / "proxy.cfg"
+ with open(proxy_template_path, "w") as f:
+ f.write(proxy_config)
+ with open(proxy_config_path, "w") as f:
+ f.write(proxy_config)
+
+ sbatch_script_content_str = sbatch_script_content_struct.cmd
+
+ # We accumulate if any task contains unsafe commands
+ is_potentially_unsafe = (
+ is_potentially_unsafe
+ or sbatch_script_content_struct.is_potentially_unsafe
+ )
  local_runsub_path = local_task_subdir / "run.sub"
  remote_runsub_path = remote_task_subdir / "run.sub"
  with open(local_runsub_path, "w") as f:
@@ -130,15 +160,41 @@ class SlurmExecutor(BaseExecutor):
  remote_runsub_paths.append(remote_runsub_path)

  if dry_run:
- print("\n\n=============================================\n\n")
- print("DRY RUN: SLURM scripts prepared")
+ print(bold("\n\n=============================================\n\n"))
+ print(bold(cyan("DRY RUN: SLURM scripts prepared")))
  for idx, local_runsub_path in enumerate(local_runsub_paths):
- print(f"\n\n =========== Task {idx} ===================== \n\n")
+ print(cyan(f"\n\n=========== Task {idx} =====================\n\n"))
  with open(local_runsub_path, "r") as f:
- print(f.read())
- print("\nTo submit jobs, run the executor without --dry-run")
+ print(grey(f.read()))
+ print(bold("To submit jobs") + ", run the executor without --dry-run")
+ if is_potentially_unsafe:
+ print(
+ red(
+ "\nFound `pre_cmd` (evaluation or deployment) which carries security risk. When running without --dry-run "
+ "make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1"
+ )
+ )
+
  return invocation_id

+ if is_potentially_unsafe:
+ if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
+ logger.warning(
+ "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+ "is set, proceeding with caution."
+ )
+
+ else:
+ logger.error(
+ "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+ "is not set. This might carry security risk and unstable environments. "
+ "To continue, make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1.",
+ )
+ raise AttributeError(
+ "Untrusted command found in config, make sure you trust and "
+ "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
+ )
+
  socket = str(Path(tmpdirname) / "socket")
  socket_or_none = _open_master_connection(
  username=cfg.execution.username,
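
Note: the block above introduces an opt-in gate — submission aborts when any evaluation or deployment config carries a `pre_cmd`, unless NEMO_EVALUATOR_TRUST_PRE_CMD is set to "1" in the submitting environment. A minimal sketch of opting in from Python, assuming the usual CLI entry point (the subcommand and its arguments are illustrative; only the environment variable name and its expected value come from this diff):

    import os
    import subprocess

    env = dict(os.environ)
    # The executor proceeds only when this is exactly "1" (see the check above).
    env["NEMO_EVALUATOR_TRUST_PRE_CMD"] = "1"

    # Hypothetical invocation; substitute your actual run configuration/arguments.
    subprocess.run(["nemo-evaluator-launcher", "run"], env=env, check=True)
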
@@ -174,10 +230,11 @@ class SlurmExecutor(BaseExecutor):
  for idx, (slurm_job_id, remote_runsub_path) in enumerate(
  zip(slurm_job_ids, remote_runsub_paths)
  ):
+ job_id = generate_job_id(invocation_id, idx)
  db.write_job(
  job=JobData(
  invocation_id=invocation_id,
- job_id=generate_job_id(invocation_id, idx),
+ job_id=job_id,
  timestamp=time.time(),
  executor="slurm",
  data={
@@ -204,8 +261,8 @@ class SlurmExecutor(BaseExecutor):
  """
  db = ExecutionDB()

- # If id looks like an invocation_id (8 hex digits, no dot), get all jobs for it
- if len(id) == 8 and "." not in id:
+ # If id looks like an invocation_id
+ if "." not in id:
  jobs = db.get_jobs(id)
  if not jobs:
  return []
@@ -388,7 +445,7 @@ class SlurmExecutor(BaseExecutor):
  """Kill a SLURM job.

  Args:
- job_id: The job ID to kill.
+ job_id: The job ID (e.g., abc123.0) to kill.
  """
  db = ExecutionDB()
  job_data = db.get_job(job_id)
@@ -401,26 +458,31 @@ class SlurmExecutor(BaseExecutor):
  f"Job {job_id} is not a slurm job (executor: {job_data.executor})"
  )

- killed_something = False
-
- result = _kill_slurm_job(
+ # OPTIMIZATION: Query status AND kill in ONE SSH call
+ slurm_status, result = _kill_slurm_job(
  slurm_job_ids=[job_data.data.get("slurm_job_id")],
  username=job_data.data.get("username"),
  hostname=job_data.data.get("hostname"),
  socket=job_data.data.get("socket"),
  )

+ # Mark job as killed in database if kill succeeded
  if result.returncode == 0:
- killed_something = True
-
- # Mark job as killed in database if we killed something
- if killed_something:
  job_data.data["killed"] = True
  db.write_job(job_data)
  else:
- raise RuntimeError(
- f"Could not find or kill job {job_id} (slurm_job_id: {job_data.data.get('slurm_job_id')})"
+ # Use the pre-fetched status for better error message
+ current_status = None
+ if slurm_status:
+ current_status = SlurmExecutor._map_slurm_state_to_execution_state(
+ slurm_status
+ )
+ error_msg = SlurmExecutor.get_kill_failure_message(
+ job_id,
+ f"slurm_job_id: {job_data.data.get('slurm_job_id')}",
+ current_status,
  )
+ raise RuntimeError(error_msg)


  def _create_slurm_sbatch_script(
@@ -430,7 +492,7 @@ def _create_slurm_sbatch_script(
  remote_task_subdir: Path,
  invocation_id: str,
  job_id: str,
- ) -> str:
+ ) -> CmdAndReadableComment:
  """Generate the contents of a SLURM sbatch script for a given evaluation task.

  Args:
@@ -446,7 +508,6 @@ def _create_slurm_sbatch_script(
  # get task from mapping, overrides, urls
  tasks_mapping = load_tasks_mapping()
  task_definition = get_task_from_mapping(task.name, tasks_mapping)
- health_url = get_health_url(cfg, get_endpoint_url(cfg, task, task_definition))

  # TODO(public release): convert to template
  s = "#!/bin/bash\n"
@@ -461,6 +522,8 @@ def _create_slurm_sbatch_script(
  s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
  if hasattr(cfg.execution, "gres"):
  s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
+ if cfg.execution.get("sbatch_comment"):
+ s += "#SBATCH --comment='{}'\n".format(cfg.execution.sbatch_comment)
  job_name = "{account}-{subproject}.{details}".format(
  account=cfg.execution.account,
  subproject=cfg.execution.subproject,
@@ -470,6 +533,8 @@ def _create_slurm_sbatch_script(
  s += "#SBATCH --exclusive\n"
  s += "#SBATCH --output {}\n".format(remote_task_subdir / "logs" / "slurm-%A.out")
  s += "\n"
+ s += f'TASK_DIR="{str(remote_task_subdir)}"\n'
+ s += "\n"

  # collect all env vars
  env_vars = copy.deepcopy(dict(cfg.evaluation.get("env_vars", {})))
@@ -484,8 +549,11 @@ def _create_slurm_sbatch_script(
  if os.getenv(env_var) is None:
  raise ValueError(f"Trying to pass an unset environment variable {env_var}.")

- # check if required env vars are defined:
+ # check if required env vars are defined (excluding NEMO_EVALUATOR_DATASET_DIR which is handled separately):
  for required_env_var in task_definition.get("required_env_vars", []):
+ # Skip NEMO_EVALUATOR_DATASET_DIR as it's handled by dataset mounting logic below
+ if required_env_var == "NEMO_EVALUATOR_DATASET_DIR":
+ continue
  if required_env_var not in env_vars.keys():
  raise ValueError(
  f"{task.name} task requires environment variable {required_env_var}."
@@ -531,6 +599,7 @@ def _create_slurm_sbatch_script(

  # prepare deployment mounts
  deployment_mounts_list = []
+ deployment_is_unsafe = False
  if cfg.deployment.type != "none":
  if checkpoint_path := cfg.deployment.get("checkpoint_path"):
  deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
@@ -542,36 +611,33 @@ def _create_slurm_sbatch_script(
  deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")

  # add deployment srun command
- s += "# deployment server\n"
- s += "srun --mpi pmix --overlap "
- s += "--container-image {} ".format(cfg.deployment.image)
- if deployment_mounts_list:
- s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
- if not cfg.execution.get("mounts", {}).get("mount_home", True):
- s += "--no-container-mount-home "
- s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
- deployment_env_var_names = list(
- cfg.execution.get("env_vars", {}).get("deployment", {})
- )
- if cfg.deployment.get("env_vars"):
- warnings.warn(
- "cfg.deployment.env_vars will be deprecated in future versions. "
- "Use cfg.execution.env_vars.deployment instead.",
- category=DeprecationWarning,
- stacklevel=2,
+ deployment_srun_cmd, deployment_is_unsafe, deployment_debug = (
+ _generate_deployment_srun_command(
+ cfg, deployment_mounts_list, remote_task_subdir
  )
- deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
- if deployment_env_var_names:
- s += f"--container-env {','.join(deployment_env_var_names)} "
- s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
- s += (
- "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
  )
+ s += deployment_srun_cmd

  # wait for the server to initialize
- s += _WAIT_FOR_SERVER_HANDLER.format(health_url=health_url)
+ health_path = cfg.deployment.get("health_check_path", "/health")
+ # For multi-instance check all node IPs, for single instance check localhost
+ if cfg.deployment.get("multiple_instances", False):
+ ip_list = '"${NODES_IPS_ARRAY[@]}"'
+ else:
+ ip_list = '"127.0.0.1"'
+ s += _get_wait_for_server_handler(
+ ip_list,
+ cfg.deployment.port,
+ health_path,
+ "server",
+ check_pid=True,
+ )
  s += "\n\n"

+ # add proxy load balancer for multi-instance deployments
+ if cfg.deployment.get("multiple_instances", False):
+ s += _get_proxy_server_srun_command(cfg, remote_task_subdir)
+
  # prepare evaluation mounts
  evaluation_mounts_list = [
  "{}:/results".format(remote_task_subdir / "artifacts"),
@@ -581,9 +647,45 @@ def _create_slurm_sbatch_script(
  ):
  evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")

+ # Handle dataset directory mounting if NEMO_EVALUATOR_DATASET_DIR is required
+ if "NEMO_EVALUATOR_DATASET_DIR" in task_definition.get("required_env_vars", []):
+ # Get dataset directory from task config
+ if "dataset_dir" in task:
+ dataset_mount_host = task["dataset_dir"]
+ else:
+ raise ValueError(
+ f"{task.name} task requires a dataset_dir to be specified. "
+ f"Add 'dataset_dir: /path/to/your/dataset' under the task configuration."
+ )
+ # Get container mount path (default to /datasets if not specified)
+ dataset_mount_container = task.get("dataset_mount_path", "/datasets")
+ # Add dataset mount to evaluation mounts list
+ evaluation_mounts_list.append(f"{dataset_mount_host}:{dataset_mount_container}")
+ # Export NEMO_EVALUATOR_DATASET_DIR environment variable
+ s += f"export NEMO_EVALUATOR_DATASET_DIR={dataset_mount_container}\n\n"
+
+ eval_factory_command_struct = get_eval_factory_command(
+ cfg,
+ task,
+ task_definition,
+ )
+
+ eval_factory_command = eval_factory_command_struct.cmd
+ # The debug comment for placing into the script and easy debug. Reason
+ # (see `CmdAndReadableComment`) is the current way of passing the command
+ # is base64-encoded config `echo`-ed into file.
+ # TODO(agronskiy): cleaner way is to encode everything with base64, not
+ # some parts (like ef_config.yaml) and just output as logs somewhere.
+ eval_factory_command_debug_comment = eval_factory_command_struct.debug
+
  # add evaluation srun command
+ s += "# Debug contents of the eval factory command's config\n"
+ s += eval_factory_command_debug_comment
+ s += "\n\n"
+
  s += "# evaluation client\n"
  s += "srun --mpi pmix --overlap "
+ s += "--nodes 1 --ntasks 1 " # Client always runs on single node
  s += "--container-image {} ".format(eval_image)
  evaluation_env_var_names = list(
  cfg.execution.get("env_vars", {}).get("evaluation", {})
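
Note: `get_eval_factory_command` now returns a CmdAndReadableComment struct, and the comment above explains why — the command's config is base64-encoded and `echo`-ed into a file inside the container, with a human-readable copy kept as shell comments for debugging. The actual helper (`_str_to_echo_command` in common/helpers.py) is not part of this diff, so the following is only a hedged sketch of the general idea, with illustrative names:

    import base64
    from dataclasses import dataclass

    @dataclass
    class CmdAndComment:
        cmd: str    # shell snippet that recreates the file inside the container
        debug: str  # human-readable copy, usable as a commented-out block

    def str_to_echo_command(content: str, filename: str) -> CmdAndComment:
        encoded = base64.b64encode(content.encode()).decode()
        cmd = f"echo {encoded} | base64 -d > {filename}"
        debug = "\n".join("# " + line for line in content.splitlines())
        return CmdAndComment(cmd=cmd, debug=debug)

    print(str_to_echo_command("pip install extra-dep\n", "pre_cmd.sh").cmd)
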
@@ -592,43 +694,139 @@ def _create_slurm_sbatch_script(
  s += "--container-env {} ".format(",".join(evaluation_env_var_names))
  if not cfg.execution.get("mounts", {}).get("mount_home", True):
  s += "--no-container-mount-home "
+
  s += "--container-mounts {} ".format(",".join(evaluation_mounts_list))
  s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.out")
- s += "bash -c '"
- s += get_eval_factory_command(cfg, task, task_definition)
+ s += "bash -c '\n"
+ s += eval_factory_command
  s += "'\n\n"

  # terminate the server after all evaluation clients finish
  if cfg.deployment.type != "none":
- s += "kill $SERVER_PID # terminate the server to finish gracefully\n\n"
+ s += "kill $SERVER_PID # terminate the server to finish gracefully\n"
+ if cfg.deployment.get("multiple_instances", False):
+ s += "kill $PROXY_PID # terminate proxy to finish gracefully\n"
+ s += "\n"

  # auto-export
- if cfg.execution.get("auto_export", {}).get("destinations", []):
- s += _generate_auto_export_section(cfg, job_id)
+ ae_cfg = cfg.execution.get("auto_export")
+ destinations: list = []
+ if isinstance(ae_cfg, list):
+ destinations = list(ae_cfg)
+ elif isinstance(ae_cfg, dict) or isinstance(ae_cfg, DictConfig):
+ destinations = list(ae_cfg.get("destinations", []) or [])
+
+ if destinations:
+ export_env = dict(cfg.execution.get("env_vars", {}).get("export", {}) or {})
+ s += _generate_auto_export_section(
+ cfg, job_id, destinations, export_env, remote_task_subdir
+ )

- return s
+ debug_str = "\n".join(["# " + line for line in s.splitlines()])
+
+ # Combine unsafe flags from both deployment and evaluation
+ is_potentially_unsafe = (
+ eval_factory_command_struct.is_potentially_unsafe or deployment_is_unsafe
+ )
+
+ return CmdAndReadableComment(
+ cmd=s,
+ debug=debug_str,
+ is_potentially_unsafe=is_potentially_unsafe,
+ )


  def _generate_auto_export_section(
  cfg: DictConfig,
- job_id: str, # Complete job_id string
+ job_id: str,
+ destinations: list,
+ export_env: dict,
+ remote_task_subdir: Path,
+ export_image: str = "python:3.12.7-slim",
  ) -> str:
  """Generate simple auto-export section for sbatch script."""
- auto_export_config = cfg.execution.get("auto_export", {})
- destinations = auto_export_config.get("destinations", [])
-
  if not destinations:
  return ""

- s = "\n# AUTO-EXPORT ON SUCCESS\n"
+ s = "\n# Auto-export on success\n"
  s += "EVAL_EXIT_CODE=$?\n"
  s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
  s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
+ s += f' cd "{remote_task_subdir}/artifacts"\n'

- for dest in destinations:
- s += f" echo 'Exporting to {dest}...'\n"
- s += f" nemo-evaluator-launcher export {job_id} --dest {dest} || echo 'Export to {dest} failed'\n"
+ # Work with DictConfig; convert only for YAML at the end
+ exec_type = (
+ cfg.execution.type
+ if hasattr(cfg.execution, "type")
+ else cfg.execution.get("type", "slurm")
+ )
+ eval_tasks = (
+ list(cfg.evaluation.tasks)
+ if hasattr(cfg, "evaluation") and hasattr(cfg.evaluation, "tasks")
+ else list((cfg.get("evaluation", {}) or {}).get("tasks", []) or [])
+ )
+ export_block = cfg.get("export", {}) or {}
+
+ payload = {
+ "execution": {
+ "auto_export": {
+ "destinations": list(destinations),
+ **({"env_vars": dict(export_env)} if export_env else {}),
+ },
+ "type": exec_type,
+ },
+ "evaluation": {"tasks": eval_tasks},
+ }
+ if export_block:
+ # Convert just this block to plain for YAML
+ payload["export"] = (
+ OmegaConf.to_object(export_block)
+ if OmegaConf.is_config(export_block)
+ else dict(export_block)
+ )

+ # Final YAML (single conversion at the end)
+ payload_clean = OmegaConf.to_container(OmegaConf.create(payload), resolve=True)
+ yaml_str = yaml.safe_dump(payload_clean, sort_keys=False)
+ s += " cat > export_config.yml << 'EOF'\n"
+ s += yaml_str
+ s += "EOF\n"
+
+ # write launcher config as config.yml for exporters (no core command)
+ submitted_yaml = yaml.safe_dump(
+ OmegaConf.to_container(cfg, resolve=True), sort_keys=False
+ )
+ s += " cat > config.yml << 'EOF'\n"
+ s += submitted_yaml
+ s += "EOF\n"
+
+ # Export host only env before running auto export
+ for k, v in (export_env or {}).items():
+ if isinstance(v, str) and re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", v):
+ s += f' export {k}="${{{v}}}"\n'
+ else:
+ esc = str(v).replace('"', '\\"')
+ s += f' export {k}="{esc}"\n'
+
+ s += " # export\n"
+ s += " srun --mpi pmix --overlap "
+ s += "--nodes 1 --ntasks 1 " # Client always runs on single node
+ s += "--container-image {} ".format(export_image)
+ if export_env:
+ s += "--container-env {} ".format(",".join(export_env))
+ if not cfg.execution.get("mounts", {}).get("mount_home", True):
+ s += "--no-container-mount-home "
+
+ s += f"--container-mounts {remote_task_subdir}/artifacts:{remote_task_subdir}/artifacts "
+ s += "--output {} ".format(remote_task_subdir / "logs" / "export-%A.out")
+ s += " bash -c '\n"
+ # FIXME(martas): would be good to install specific version
+ s += " pip install nemo-evaluator-launcher[all]\n"
+ s += f" cd {remote_task_subdir}/artifacts\n"
+ for dest in destinations:
+ s += f' echo "Exporting to {dest}..."\n'
+ s += f' nemo-evaluator-launcher export {job_id} --dest {dest} || echo "Export to {dest} failed"\n'
+ s += "'\n"
  s += " echo 'Auto-export completed.'\n"
  s += "else:\n"
  s += " echo 'Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export.'\n"
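
Note: the export step now runs inside a python:3.12.7-slim container via srun instead of calling nemo-evaluator-launcher directly, and the generated sbatch script first writes an export_config.yml plus the full submitted config into the artifacts directory. A rough sketch of the payload shape written above, with placeholder values (the destination names and task entry are examples only, not taken from this diff):

    import yaml

    payload = {
        "execution": {
            "auto_export": {
                "destinations": ["wandb", "mlflow"],  # example destinations
                "env_vars": {"WANDB_API_KEY": "WANDB_API_KEY"},  # optional
            },
            "type": "slurm",
        },
        "evaluation": {"tasks": [{"name": "example_task"}]},  # example task entry
    }
    print(yaml.safe_dump(payload, sort_keys=False))
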
@@ -643,9 +841,12 @@ def _open_master_connection(
  socket: str,
  ) -> str | None:
  ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
+ logger.info("Opening master connection", cmd=ssh_command)
  completed_process = subprocess.run(args=shlex.split(ssh_command))
  if completed_process.returncode == 0:
+ logger.info("Opened master connection successfully", cmd=ssh_command)
  return socket
+ logger.error("Failed to open master connection", code=completed_process.returncode)
  return None


@@ -657,9 +858,7 @@ def _close_master_connection(
  if socket is None:
  return
  ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
- completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
- )
+ completed_process = subprocess.run(args=shlex.split(ssh_command))
  if completed_process.returncode != 0:
  raise RuntimeError(
  "failed to close the master connection\n{}".format(
@@ -681,12 +880,23 @@ def _make_remote_execution_output_dir(
  ssh_command.append(f"{username}@{hostname}")
  ssh_command.append(mkdir_command)
  ssh_command = " ".join(ssh_command)
- completed_process = subprocess.run(args=shlex.split(ssh_command))
+ logger.info("Creating remote dir", cmd=ssh_command)
+ completed_process = subprocess.run(
+ args=shlex.split(ssh_command), stderr=subprocess.PIPE
+ )
  if completed_process.returncode != 0:
+ error_msg = (
+ completed_process.stderr.decode("utf-8")
+ if completed_process.stderr
+ else "Unknown error"
+ )
+ logger.error(
+ "Error creating remote dir",
+ code=completed_process.returncode,
+ msg=error_msg,
+ )
  raise RuntimeError(
- "failed to make a remote execution output dir\n{}".format(
- completed_process.stderr.decode("utf-8")
- )
+ "failed to make a remote execution output dir\n{}".format(error_msg)
  )


@@ -712,14 +922,25 @@ def _rsync_upload_rundirs(
  remote_destination_str = f"{username}@{hostname}:{remote_target}"
  local_sources_str = " ".join(map(str, local_sources))
  rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
- completed_process = subprocess.run(args=shlex.split(rsync_upload_command))
+ logger.info("Rsyncing to remote dir", cmd=rsync_upload_command)
+ completed_process = subprocess.run(
+ args=shlex.split(rsync_upload_command),
+ stderr=subprocess.PIPE,
+ )
  if completed_process.returncode != 0:
- raise RuntimeError(
- "failed to upload local sources\n{}".format(
- completed_process.stderr.decode("utf-8")
- )
+ error_msg = (
+ completed_process.stderr.decode("utf-8")
+ if completed_process.stderr
+ else "Unknown error"
  )

+ logger.error(
+ "Error rsyncing to remote dir",
+ code=completed_process.returncode,
+ msg=error_msg,
+ )
+ raise RuntimeError("failed to upload local sources\n{}".format(error_msg))
+

  def _sbatch_remote_runsubs(
  remote_runsub_paths: List[Path],
@@ -739,19 +960,22 @@ def _sbatch_remote_runsubs(
  ssh_command.append(f"{username}@{hostname}")
  ssh_command.append(sbatch_commands)
  ssh_command = " ".join(ssh_command)
-
+ logger.info("Running sbatch", cmd=ssh_command)
  completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
+ args=shlex.split(ssh_command),
+ # NOTE(agronskiy): look out for hangs and deadlocks
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
  )
  if completed_process.returncode != 0:
+ error_msg = completed_process.stderr.decode("utf-8")
  raise RuntimeError(
- "failed to submit sbatch scripts for execution\n{}".format(
- completed_process.stderr.decode("utf-8")
- )
+ "failed to submit sbatch scripts for execution\n{}".format(error_msg)
  )

  sbatch_output = completed_process.stdout.decode("utf-8")
  slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
+ logger.info("Started sbatch successfully", slurm_job_ids=slurm_job_ids)
  return slurm_job_ids


@@ -784,7 +1008,10 @@ def _query_slurm_jobs_status(
  ssh_command.append(sacct_command)
  ssh_command = " ".join(ssh_command)
  completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
+ args=shlex.split(ssh_command),
+ # NOTE(agronskiy): look out for hangs and deadlocks
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
  )
  if completed_process.returncode != 0:
  raise RuntimeError(
@@ -803,34 +1030,50 @@ def _query_slurm_jobs_status(

  def _kill_slurm_job(
  slurm_job_ids: List[str], username: str, hostname: str, socket: str | None
- ) -> None:
- """Kill a SLURM job.
+ ) -> tuple[str | None, subprocess.CompletedProcess]:
+ """Kill a SLURM job, querying status first in one SSH call for efficiency.

  Args:
  slurm_job_ids: List of SLURM job IDs to kill.
  username: SSH username.
  hostname: SSH hostname.
  socket: control socket location or None
+
+ Returns:
+ Tuple of (status_string, completed_process) where status_string is the SLURM status or None
  """
  if len(slurm_job_ids) == 0:
- return {}
- kill_command = "scancel {}".format(",".join(slurm_job_ids))
+ return None, subprocess.CompletedProcess(args=[], returncode=0)
+
+ jobs_str = ",".join(slurm_job_ids)
+ # Combine both commands in one SSH call: query THEN kill
+ combined_command = (
+ f"sacct -j {jobs_str} --format='JobID,State%32' --noheader -P 2>/dev/null; "
+ f"scancel {jobs_str}"
+ )
+
  ssh_command = ["ssh"]
  if socket is not None:
  ssh_command.append(f"-S {socket}")
  ssh_command.append(f"{username}@{hostname}")
- ssh_command.append(kill_command)
+ ssh_command.append(combined_command)
  ssh_command = " ".join(ssh_command)
+
  completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
+ args=shlex.split(ssh_command),
+ # NOTE(agronskiy): look out for hangs and deadlocks
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
  )
- if completed_process.returncode != 0:
- raise RuntimeError(
- "failed to kill slurm job\n{}".format(
- completed_process.stderr.decode("utf-8")
- )
- )
- return completed_process
+
+ # Parse the sacct output (before scancel runs)
+ sacct_output = completed_process.stdout.decode("utf-8")
+ sacct_output_lines = sacct_output.strip().split("\n")
+ slurm_status = None
+ if sacct_output_lines and len(slurm_job_ids) == 1:
+ slurm_status = _parse_slurm_job_status(slurm_job_ids[0], sacct_output_lines)
+
+ return slurm_status, completed_process


  def _parse_slurm_job_status(slurm_job_id: str, sacct_output_lines: List[str]) -> str:
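
Note: `_kill_slurm_job` now issues a single SSH call that runs `sacct` (pipe-delimited, no header) followed by `scancel`, so the pre-kill state can be reported when the cancel fails. A hedged illustration of the two pieces of SLURM output this file parses — the sample strings are made up, and the launcher's own parsing may differ in detail:

    import re

    # sbatch prints one "Submitted batch job <id>" line per submitted script.
    sbatch_output = "Submitted batch job 123456\nSubmitted batch job 123457\n"
    print(re.findall(r"(?<=Submitted batch job )\d+", sbatch_output))
    # ['123456', '123457']

    # sacct --noheader -P emits pipe-separated "JobID|State" rows, one per job/step;
    # lines like these are what _kill_slurm_job hands to _parse_slurm_job_status.
    sacct_output_lines = ["123456|RUNNING", "123456.batch|RUNNING"]
    states = dict(line.split("|", 1) for line in sacct_output_lines)
    print(states.get("123456"))  # 'RUNNING'
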
@@ -898,7 +1141,10 @@ def _read_files_from_remote(
  ssh_command.append(cat_commands)
  ssh_command = " ".join(ssh_command)
  completed_process = subprocess.run(
- args=shlex.split(ssh_command), capture_output=True
+ args=shlex.split(ssh_command),
+ # NOTE(agronskiy): look out for hangs and deadlocks
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
  )
  if completed_process.returncode != 0:
  raise RuntimeError(
@@ -975,9 +1221,236 @@ sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
  """.strip()


- _WAIT_FOR_SERVER_HANDLER = """
- date
- # wait for the server to initialize
- bash -c 'while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" {health_url})" != "200" ]]; do kill -0 '"$SERVER_PID"' 2>/dev/null || {{ echo "Server process '"$SERVER_PID"' died"; exit 1; }}; sleep 5; done'
+ def _generate_haproxy_config_with_placeholders(cfg):
+ """Generate HAProxy configuration with placeholder IPs using Jinja template."""
+ # Set up Jinja environment
+ template_dir = Path(__file__).parent
+ template_path = template_dir / "proxy.cfg.template"
+
+ if not template_path.exists():
+ raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+ env = Environment(loader=FileSystemLoader(template_dir))
+ template = env.get_template("proxy.cfg.template")
+
+ # Prepare template data with placeholder IPs - use actual number of nodes
+ num_nodes = cfg.execution.num_nodes
+ nodes = []
+ for i in range(num_nodes):
+ nodes.append({"ip": f"{{IP_{i}}}", "port": cfg.deployment.port})
+
+ # Get health check parameters from execution config
+ proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+ health_check_path = proxy_config.get("health_check_path", "/health")
+ health_check_status = proxy_config.get("health_check_status", 200)
+ haproxy_port = proxy_config.get("haproxy_port", 5009)
+
+ # Render template
+ config = template.render(
+ haproxy_port=haproxy_port,
+ health_check_path=health_check_path,
+ health_check_status=health_check_status,
+ nodes=nodes,
+ )
+
+ return config
+
+
+ def _generate_haproxy_config(cfg, nodes_ips):
+ """Generate HAProxy configuration using Jinja template."""
+ # Set up Jinja environment
+ template_dir = Path(__file__).parent
+ template_path = template_dir / "proxy.cfg.template"
+
+ if not template_path.exists():
+ raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+ env = Environment(loader=FileSystemLoader(template_dir))
+ template = env.get_template("proxy.cfg.template")
+
+ # Prepare template data
+ nodes = []
+ for i, ip in enumerate(nodes_ips, 1):
+ nodes.append(
+ {"ip": ip, "port": cfg.deployment.port} # All nodes use the same port
+ )
+
+ # Get health check parameters from deployment config
+ health_check_path = cfg.deployment.get("health_check_path", "/health")
+ health_check_status = cfg.deployment.get("health_check_status", 200)
+ haproxy_port = cfg.deployment.get("haproxy_port", 5009)
+
+ # Render template
+ config = template.render(
+ haproxy_port=haproxy_port,
+ health_check_path=health_check_path,
+ health_check_status=health_check_status,
+ nodes=nodes,
+ )
+
+ return config
+
+
+ def _generate_deployment_srun_command(
+ cfg, deployment_mounts_list, remote_task_subdir, instance_id: int = 0
+ ):
+ """Generate the deployment srun command with proper node/ntask configuration.
+
+ Returns:
+ tuple: (script_string, is_potentially_unsafe, debug_comment)
+ """
+ s = ""
+ debug_comment = ""
+ is_potentially_unsafe = False
+
+ s += "# deployment server\n"
+
+ # Extract pre_cmd for later use inside container
+ pre_cmd: str = cfg.deployment.get("pre_cmd") or ""
+ if pre_cmd:
+ is_potentially_unsafe = True
+ create_pre_script_cmd = _str_to_echo_command(
+ pre_cmd, filename="deployment_pre_cmd.sh"
+ )
+ debug_comment += create_pre_script_cmd.debug + "\n\n"
+
+ s += "# Get node IPs\n"
+ s += "nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )\n"
+ s += 'nodes_array=("${nodes[@]}") # Ensure nodes are stored properly\n'
+ s += 'export NODES_IPS_ARRAY=($(for node in "${nodes_array[@]}"; do srun --nodelist=$node --ntasks=1 --nodes=1 hostname --ip-address; done))\n'
+ s += 'echo "Node IPs: ${NODES_IPS_ARRAY[@]}"\n'
+ s += "# Export MASTER_IP as the first node IP\n"
+ s += "export MASTER_IP=${NODES_IPS_ARRAY[0]}\n"
+ s += 'echo "MASTER_IP: $MASTER_IP"\n'
+
+ # Add debug comment for deployment pre_cmd before srun command
+ if debug_comment:
+ s += "# Debug contents of deployment pre_cmd\n"
+ s += debug_comment
+ s += "\n"
+
+ s += "srun --mpi pmix --overlap "
+ s += f"--nodes {cfg.execution.num_nodes} --ntasks {cfg.execution.get('deployment', {}).get('n_tasks', 1)} "
+ s += "--container-image {} ".format(cfg.deployment.image)
+ if deployment_mounts_list:
+ s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
+ if not cfg.execution.get("mounts", {}).get("mount_home", True):
+ s += "--no-container-mount-home "
+ s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A-%t.out")
+
+ deployment_env_var_names = list(
+ cfg.execution.get("env_vars", {}).get("deployment", {})
+ )
+ if cfg.deployment.get("env_vars"):
+ warnings.warn(
+ "cfg.deployment.env_vars will be deprecated in future versions. "
+ "Use cfg.execution.env_vars.deployment instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
+
+ # Always add MASTER_IP to the environment variables
+ if "MASTER_IP" not in deployment_env_var_names:
+ deployment_env_var_names.append("MASTER_IP")
+
+ if deployment_env_var_names:
+ s += f"--container-env {','.join(deployment_env_var_names)} "
+
+ # Wrap deployment command to execute pre_cmd inside container if needed
+ if pre_cmd:
+ # Create a wrapper command that runs inside the container:
+ # 1. Create deployment_pre_cmd.sh file
+ # 2. Source it
+ # 3. Execute the original deployment command
+ create_pre_script_cmd = _str_to_echo_command(
+ pre_cmd, filename="deployment_pre_cmd.sh"
+ )
+ # Escape single quotes in the deployment command for bash -c
+ escaped_deployment_cmd = cfg.deployment.command.replace("'", "'\"'\"'")
+ wrapped_command = (
+ f"bash -c '{create_pre_script_cmd.cmd} && "
+ f"source deployment_pre_cmd.sh && "
+ f"{escaped_deployment_cmd}'"
+ )
+ s += "{} &\n\n".format(wrapped_command)
+ else:
+ s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
+
+ s += "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
+
+ return s, is_potentially_unsafe, debug_comment
+
+
+ def _get_wait_for_server_handler(
+ ip_list: str,
+ port: int,
+ health_check_path: str,
+ service_name: str = "server",
+ check_pid: bool = False,
+ ):
+ """Generate wait for server handler that takes a list of IPs."""
+ pid_check = ""
+ if check_pid:
+ pid_check = 'kill -0 "$SERVER_PID" 2>/dev/null || { echo "Server process $SERVER_PID died"; exit 1; }'
+
+ handler = f"""date
+ # wait for the {service_name} to initialize
+ for ip in {ip_list}; do
+ echo "Waiting for {service_name} on $ip..."
+ while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" http://$ip:{port}{health_check_path})" != "200" ]]; do
+ {pid_check}
+ sleep 5
+ done
+ echo "{service_name} ready on $ip!"
+ done
  date
  """.strip()
+
+ return handler
+
+
+ def _get_proxy_server_srun_command(cfg, remote_task_subdir):
+ """Generate proxy server srun command based on proxy type."""
+ proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+
+ if proxy_type == "haproxy":
+ return _generate_haproxy_srun_command(cfg, remote_task_subdir)
+ else:
+ raise ValueError(
+ f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+ )
+
+
+ def _generate_haproxy_srun_command(cfg, remote_task_subdir):
+ """Generate HAProxy-specific srun command using template-based config."""
+ s = ""
+ s += "# Proxy load balancer\n"
+ s += "# Copy template to config file (important for restarts)\n"
+ s += f"cp {remote_task_subdir}/proxy.cfg.template {remote_task_subdir}/proxy.cfg\n"
+ s += "# Replace placeholder IPs with actual node IPs\n"
+ s += f"proxy_config_file={remote_task_subdir}/proxy.cfg\n"
+ s += 'for i in "${!NODES_IPS_ARRAY[@]}"; do\n'
+ s += ' ip="${NODES_IPS_ARRAY[$i]}"\n'
+ s += ' sed -i "s/{IP_$i}/$ip/g" "$proxy_config_file"\n'
+ s += "done\n"
+ s += "\n"
+ s += "srun --mpi pmix --overlap "
+ s += "--nodes 1 --ntasks 1 "
+ s += f"--container-image {cfg.execution.get('proxy', {}).get('image', 'haproxy:latest')} "
+ s += f"--container-mounts {remote_task_subdir}/proxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro "
+ s += f"--output {remote_task_subdir}/logs/proxy-%A.out "
+ s += "haproxy -f /usr/local/etc/haproxy/haproxy.cfg &\n"
+ s += "PROXY_PID=$! # capture the PID of the proxy background srun process\n"
+ s += 'echo "Proxy started with PID: $PROXY_PID"\n\n'
+
+ # Wait for proxy to be ready on localhost
+ proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+ haproxy_port = proxy_config.get("haproxy_port", 5009)
+ health_path = proxy_config.get("health_check_path", "/health")
+ s += _get_wait_for_server_handler(
+ "127.0.0.1", haproxy_port, health_path, "Proxy", check_pid=False
+ )
+ s += "\n"
+
+ return s
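
Note: the multi-instance path writes proxy.cfg with literal {IP_0}, {IP_1}, ... placeholders at submission time, and the generated sbatch script later copies the template and substitutes the node IPs gathered into NODES_IPS_ARRAY via sed. A hedged sketch of that round trip in Python (the HAProxy snippet and port are illustrative, not the real proxy.cfg.template output):

    # Config as rendered at submission time, before node IPs are known.
    rendered_at_submit = (
        "backend model_servers\n"
        "    server node0 {IP_0}:8000 check\n"
        "    server node1 {IP_1}:8000 check\n"
    )

    node_ips = ["10.0.0.11", "10.0.0.12"]  # what NODES_IPS_ARRAY would hold at runtime
    final_cfg = rendered_at_submit
    for i, ip in enumerate(node_ips):
        # Mirrors: sed -i "s/{IP_$i}/$ip/g" proxy.cfg
        final_cfg = final_cfg.replace(f"{{IP_{i}}}", ip)
    print(final_cfg)
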