nemo-evaluator-launcher 0.1.19__py3-none-any.whl → 0.1.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nemo_evaluator_launcher/api/functional.py +105 -1
- nemo_evaluator_launcher/cli/logs.py +102 -0
- nemo_evaluator_launcher/cli/main.py +12 -0
- nemo_evaluator_launcher/cli/run.py +73 -15
- nemo_evaluator_launcher/cli/version.py +26 -23
- nemo_evaluator_launcher/common/helpers.py +176 -43
- nemo_evaluator_launcher/common/logging_utils.py +16 -5
- nemo_evaluator_launcher/common/printing_utils.py +7 -0
- nemo_evaluator_launcher/configs/deployment/sglang.yaml +4 -2
- nemo_evaluator_launcher/configs/deployment/trtllm.yaml +2 -3
- nemo_evaluator_launcher/configs/deployment/vllm.yaml +0 -1
- nemo_evaluator_launcher/configs/execution/slurm/default.yaml +14 -0
- nemo_evaluator_launcher/executors/base.py +31 -1
- nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +36 -1
- nemo_evaluator_launcher/executors/lepton/executor.py +81 -1
- nemo_evaluator_launcher/executors/local/executor.py +377 -22
- nemo_evaluator_launcher/executors/local/run.template.sh +54 -2
- nemo_evaluator_launcher/executors/slurm/executor.py +422 -59
- nemo_evaluator_launcher/executors/slurm/proxy.cfg.template +26 -0
- nemo_evaluator_launcher/exporters/utils.py +32 -46
- nemo_evaluator_launcher/package_info.py +1 -1
- nemo_evaluator_launcher/resources/mapping.toml +56 -15
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/METADATA +3 -3
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/RECORD +28 -26
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/entry_points.txt +1 -0
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/WHEEL +0 -0
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/licenses/LICENSE +0 -0
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,7 @@ from pathlib import Path
 from typing import Dict, List, Optional
 
 import yaml
+from jinja2 import Environment, FileSystemLoader
 from omegaconf import DictConfig, OmegaConf
 
 from nemo_evaluator_launcher.common.execdb import (
@@ -39,18 +40,19 @@ from nemo_evaluator_launcher.common.execdb import (
     generate_job_id,
 )
 from nemo_evaluator_launcher.common.helpers import (
+    CmdAndReadableComment,
+    _str_to_echo_command,
     get_api_key_name,
-    get_endpoint_url,
     get_eval_factory_command,
     get_eval_factory_dataset_size_from_run_config,
-    get_health_url,
     get_timestamp_string,
 )
+from nemo_evaluator_launcher.common.logging_utils import logger
 from nemo_evaluator_launcher.common.mapping import (
     get_task_from_mapping,
     load_tasks_mapping,
 )
-from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey
+from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey, red
 from nemo_evaluator_launcher.executors.base import (
     BaseExecutor,
     ExecutionState,
@@ -94,6 +96,7 @@ class SlurmExecutor(BaseExecutor):
         tasks_mapping = load_tasks_mapping()
         eval_images: list[str] = []
 
+        is_potentially_unsafe = False
         for idx, task in enumerate(cfg.evaluation.tasks):
             # calculate job_id
             job_id = f"{invocation_id}.{idx}"
@@ -114,7 +117,7 @@ class SlurmExecutor(BaseExecutor):
             eval_images.append(eval_image)
 
             # generate and write down sbatch script
-            sbatch_script_content = _create_slurm_sbatch_script(
+            sbatch_script_content_struct = _create_slurm_sbatch_script(
                 cfg=cfg,
                 task=task,
                 eval_image=eval_image,
@@ -122,6 +125,32 @@ class SlurmExecutor(BaseExecutor):
                 invocation_id=invocation_id,
                 job_id=job_id,
             )
+
+            # Create proxy config file with placeholder IPs for multi-instance deployments
+            if cfg.deployment.get("multiple_instances", False):
+                proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+                if proxy_type == "haproxy":
+                    proxy_config = _generate_haproxy_config_with_placeholders(cfg)
+                else:
+                    raise ValueError(
+                        f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+                    )
+
+                # Save both template and working config
+                proxy_template_path = local_task_subdir / "proxy.cfg.template"
+                proxy_config_path = local_task_subdir / "proxy.cfg"
+                with open(proxy_template_path, "w") as f:
+                    f.write(proxy_config)
+                with open(proxy_config_path, "w") as f:
+                    f.write(proxy_config)
+
+            sbatch_script_content_str = sbatch_script_content_struct.cmd
+
+            # We accumulate if any task contains unsafe commands
+            is_potentially_unsafe = (
+                is_potentially_unsafe
+                or sbatch_script_content_struct.is_potentially_unsafe
+            )
             local_runsub_path = local_task_subdir / "run.sub"
             remote_runsub_path = remote_task_subdir / "run.sub"
             with open(local_runsub_path, "w") as f:
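
The proxy configuration is rendered at submit time with literal `{IP_0}`, `{IP_1}`, … tokens standing in for node addresses, which are only known once the job is running. A minimal standalone sketch of the same idea, with an inline Jinja2 template (illustrative body, not the shipped `proxy.cfg.template`):

```python
from jinja2 import Template

# Illustrative stand-in for proxy.cfg.template; variable names follow the diff.
template = Template(
    "frontend api\n"
    "    bind *:{{ haproxy_port }}\n"
    "    default_backend servers\n"
    "backend servers\n"
    "{% for node in nodes %}"
    "    server node{{ loop.index }} {{ node.ip }}:{{ node.port }} check\n"
    "{% endfor %}"
)

# {IP_0}, {IP_1}, ... survive rendering verbatim; the sbatch script later
# rewrites them with sed once NODES_IPS_ARRAY is populated.
nodes = [{"ip": f"{{IP_{i}}}", "port": 8000} for i in range(2)]
print(template.render(haproxy_port=5009, nodes=nodes))
```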
@@ -138,8 +167,34 @@ class SlurmExecutor(BaseExecutor):
             with open(local_runsub_path, "r") as f:
                 print(grey(f.read()))
             print(bold("To submit jobs") + ", run the executor without --dry-run")
+            if is_potentially_unsafe:
+                print(
+                    red(
+                        "\nFound `pre_cmd` (evaluation or deployment) which carries security risk. When running without --dry-run "
+                        "make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1"
+                    )
+                )
+
             return invocation_id
 
+        if is_potentially_unsafe:
+            if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
+                logger.warning(
+                    "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+                    "is set, proceeding with caution."
+                )
+
+            else:
+                logger.error(
+                    "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+                    "is not set. This might carry security risk and unstable environments. "
+                    "To continue, make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1.",
+                )
+                raise AttributeError(
+                    "Untrusted command found in config, make sure you trust and "
+                    "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
+                )
+
         socket = str(Path(tmpdirname) / "socket")
         socket_or_none = _open_master_connection(
             username=cfg.execution.username,
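
Submission now stops whenever any task or deployment carries a `pre_cmd`, unless the caller opts in via `NEMO_EVALUATOR_TRUST_PRE_CMD=1`. A condensed sketch of the gate (standalone; the function name is illustrative, not the launcher's API):

```python
import os

def _require_pre_cmd_trust(is_potentially_unsafe: bool) -> None:
    """Refuse to submit configs with pre_cmd unless explicitly trusted."""
    if not is_potentially_unsafe:
        return  # no pre_cmd anywhere in the config
    if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
        return  # user explicitly trusts the configured commands
    raise AttributeError(
        "Untrusted command found in config, make sure you trust and "
        "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
    )
```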
@@ -437,7 +492,7 @@ def _create_slurm_sbatch_script(
     remote_task_subdir: Path,
     invocation_id: str,
     job_id: str,
-) -> str:
+) -> CmdAndReadableComment:
     """Generate the contents of a SLURM sbatch script for a given evaluation task.
 
     Args:
@@ -453,7 +508,6 @@ def _create_slurm_sbatch_script(
     # get task from mapping, overrides, urls
     tasks_mapping = load_tasks_mapping()
     task_definition = get_task_from_mapping(task.name, tasks_mapping)
-    health_url = get_health_url(cfg, get_endpoint_url(cfg, task, task_definition))
 
     # TODO(public release): convert to template
     s = "#!/bin/bash\n"
@@ -468,6 +522,8 @@ def _create_slurm_sbatch_script(
     s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
     if hasattr(cfg.execution, "gres"):
         s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
+    if cfg.execution.get("sbatch_comment"):
+        s += "#SBATCH --comment='{}'\n".format(cfg.execution.sbatch_comment)
     job_name = "{account}-{subproject}.{details}".format(
         account=cfg.execution.account,
         subproject=cfg.execution.subproject,
@@ -493,8 +549,11 @@ def _create_slurm_sbatch_script(
         if os.getenv(env_var) is None:
             raise ValueError(f"Trying to pass an unset environment variable {env_var}.")
 
-    # check if required env vars are defined:
+    # check if required env vars are defined (excluding NEMO_EVALUATOR_DATASET_DIR which is handled separately):
     for required_env_var in task_definition.get("required_env_vars", []):
+        # Skip NEMO_EVALUATOR_DATASET_DIR as it's handled by dataset mounting logic below
+        if required_env_var == "NEMO_EVALUATOR_DATASET_DIR":
+            continue
         if required_env_var not in env_vars.keys():
             raise ValueError(
                 f"{task.name} task requires environment variable {required_env_var}."
@@ -540,6 +599,7 @@ def _create_slurm_sbatch_script(
 
     # prepare deployment mounts
     deployment_mounts_list = []
+    deployment_is_unsafe = False
     if cfg.deployment.type != "none":
         if checkpoint_path := cfg.deployment.get("checkpoint_path"):
             deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
@@ -551,36 +611,33 @@ def _create_slurm_sbatch_script(
             deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")
 
         # add deployment srun command
-        s += "# deployment server\n"
-        s += "srun --mpi pmix --overlap "
-        s += "--container-image {} ".format(cfg.deployment.image)
-        if deployment_mounts_list:
-            s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
-        if not cfg.execution.get("mounts", {}).get("mount_home", True):
-            s += "--no-container-mount-home "
-        s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
-        deployment_env_var_names = list(
-            cfg.execution.get("env_vars", {}).get("deployment", {})
-        )
-        if cfg.deployment.get("env_vars"):
-            warnings.warn(
-                "cfg.deployment.env_vars will be deprecated in future versions. "
-                "Use cfg.execution.env_vars.deployment instead.",
-                category=DeprecationWarning,
-                stacklevel=2,
+        deployment_srun_cmd, deployment_is_unsafe, deployment_debug = (
+            _generate_deployment_srun_command(
+                cfg, deployment_mounts_list, remote_task_subdir
             )
-            deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
-        if deployment_env_var_names:
-            s += f"--container-env {','.join(deployment_env_var_names)} "
-        s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
-        s += (
-            "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
         )
+        s += deployment_srun_cmd
 
         # wait for the server to initialize
-
+        health_path = cfg.deployment.get("health_check_path", "/health")
+        # For multi-instance check all node IPs, for single instance check localhost
+        if cfg.deployment.get("multiple_instances", False):
+            ip_list = '"${NODES_IPS_ARRAY[@]}"'
+        else:
+            ip_list = '"127.0.0.1"'
+        s += _get_wait_for_server_handler(
+            ip_list,
+            cfg.deployment.port,
+            health_path,
+            "server",
+            check_pid=True,
+        )
         s += "\n\n"
 
+        # add proxy load balancer for multi-instance deployments
+        if cfg.deployment.get("multiple_instances", False):
+            s += _get_proxy_server_srun_command(cfg, remote_task_subdir)
+
     # prepare evaluation mounts
     evaluation_mounts_list = [
         "{}:/results".format(remote_task_subdir / "artifacts"),
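
The generated script now blocks until every serving IP answers its health endpoint with HTTP 200, and for the model server it also aborts if the backgrounded srun process has died. A rough Python equivalent of the emitted curl loop, for orientation only:

```python
import time
import urllib.error
import urllib.request

def wait_for_server(ip: str, port: int, path: str = "/health") -> None:
    """Poll an endpoint until it answers 200, like the generated bash loop."""
    url = f"http://{ip}:{port}{path}"
    while True:
        try:
            if urllib.request.urlopen(url, timeout=5).status == 200:
                return
        except (urllib.error.URLError, OSError):
            pass  # server not up yet
        time.sleep(5)
```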
@@ -590,7 +647,29 @@ def _create_slurm_sbatch_script(
     ):
         evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")
 
-    eval_factory_command_struct = get_eval_factory_command(cfg, task, task_definition)
+    # Handle dataset directory mounting if NEMO_EVALUATOR_DATASET_DIR is required
+    if "NEMO_EVALUATOR_DATASET_DIR" in task_definition.get("required_env_vars", []):
+        # Get dataset directory from task config
+        if "dataset_dir" in task:
+            dataset_mount_host = task["dataset_dir"]
+        else:
+            raise ValueError(
+                f"{task.name} task requires a dataset_dir to be specified. "
+                f"Add 'dataset_dir: /path/to/your/dataset' under the task configuration."
+            )
+        # Get container mount path (default to /datasets if not specified)
+        dataset_mount_container = task.get("dataset_mount_path", "/datasets")
+        # Add dataset mount to evaluation mounts list
+        evaluation_mounts_list.append(f"{dataset_mount_host}:{dataset_mount_container}")
+        # Export NEMO_EVALUATOR_DATASET_DIR environment variable
+        s += f"export NEMO_EVALUATOR_DATASET_DIR={dataset_mount_container}\n\n"
+
+    eval_factory_command_struct = get_eval_factory_command(
+        cfg,
+        task,
+        task_definition,
+    )
+
     eval_factory_command = eval_factory_command_struct.cmd
     # The debug comment for placing into the script and easy debug. Reason
     # (see `CmdAndReadableComment`) is the current way of passing the command
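
In practice this means a task that lists `NEMO_EVALUATOR_DATASET_DIR` among its `required_env_vars` must now carry a `dataset_dir`; a sketch of the expected task shape (hypothetical task name and paths):

```python
from omegaconf import OmegaConf

task = OmegaConf.create(
    {
        "name": "my_task",                          # hypothetical task
        "dataset_dir": "/lustre/datasets/my_task",  # host path, required
        "dataset_mount_path": "/datasets",          # container path, optional
    }
)
host = task["dataset_dir"]
container = task.get("dataset_mount_path", "/datasets")
print(f"{host}:{container}")  # mount spec appended to --container-mounts
```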
@@ -606,6 +685,7 @@ def _create_slurm_sbatch_script(
 
     s += "# evaluation client\n"
     s += "srun --mpi pmix --overlap "
+    s += "--nodes 1 --ntasks 1 " # Client always runs on single node
     s += "--container-image {} ".format(eval_image)
     evaluation_env_var_names = list(
         cfg.execution.get("env_vars", {}).get("evaluation", {})
@@ -623,7 +703,10 @@ def _create_slurm_sbatch_script(
 
     # terminate the server after all evaluation clients finish
     if cfg.deployment.type != "none":
-        s += "kill $SERVER_PID # terminate the server to finish gracefully\n\n"
+        s += "kill $SERVER_PID # terminate the server to finish gracefully\n"
+        if cfg.deployment.get("multiple_instances", False):
+            s += "kill $PROXY_PID # terminate proxy to finish gracefully\n"
+        s += "\n"
 
     # auto-export
     ae_cfg = cfg.execution.get("auto_export")
@@ -635,9 +718,22 @@ def _create_slurm_sbatch_script(
 
     if destinations:
         export_env = dict(cfg.execution.get("env_vars", {}).get("export", {}) or {})
-        s += _generate_auto_export_section(cfg, job_id, destinations, export_env)
+        s += _generate_auto_export_section(
+            cfg, job_id, destinations, export_env, remote_task_subdir
+        )
 
-    return s
+    debug_str = "\n".join(["# " + line for line in s.splitlines()])
+
+    # Combine unsafe flags from both deployment and evaluation
+    is_potentially_unsafe = (
+        eval_factory_command_struct.is_potentially_unsafe or deployment_is_unsafe
+    )
+
+    return CmdAndReadableComment(
+        cmd=s,
+        debug=debug_str,
+        is_potentially_unsafe=is_potentially_unsafe,
+    )
 
 
 def _generate_auto_export_section(
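
`CmdAndReadableComment` itself lives in `common/helpers.py` and is not shown in this diff; from how it is constructed and consumed here, a plausible shape would be the following (an inference, not the actual source):

```python
from dataclasses import dataclass

@dataclass
class CmdAndReadableComment:
    """Inferred shape: a command plus a commented-out copy for debugging."""
    cmd: str      # the runnable script/command text
    debug: str    # same text with every line prefixed by "# "
    is_potentially_unsafe: bool = False  # True when a pre_cmd is embedded
```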
@@ -645,6 +741,8 @@ def _generate_auto_export_section(
     job_id: str,
     destinations: list,
     export_env: dict,
+    remote_task_subdir: Path,
+    export_image: str = "python:3.12.7-slim",
 ) -> str:
     """Generate simple auto-export section for sbatch script."""
     if not destinations:
@@ -654,10 +752,7 @@ def _generate_auto_export_section(
     s += "EVAL_EXIT_CODE=$?\n"
     s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
     s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
-    s +=
-    s += " set +x\n"
-    s += " set +u\n"
-    s += ' cd "$TASK_DIR/artifacts"\n'
+    s += f' cd "{remote_task_subdir}/artifacts"\n'
 
     # Work with DictConfig; convert only for YAML at the end
     exec_type = (
@@ -713,10 +808,25 @@ def _generate_auto_export_section(
         esc = str(v).replace('"', '\\"')
         s += f' export {k}="{esc}"\n'
 
-    for dest in destinations:
-        s += f' echo "Exporting to {dest}..."\n'
-        s += f' nemo-evaluator-launcher export {job_id} --dest {dest} || echo "Export to {dest} failed"\n'
+    s += " # export\n"
+    s += " srun --mpi pmix --overlap "
+    s += "--nodes 1 --ntasks 1 " # Client always runs on single node
+    s += "--container-image {} ".format(export_image)
+    if export_env:
+        s += "--container-env {} ".format(",".join(export_env))
+    if not cfg.execution.get("mounts", {}).get("mount_home", True):
+        s += "--no-container-mount-home "
 
+    s += f"--container-mounts {remote_task_subdir}/artifacts:{remote_task_subdir}/artifacts "
+    s += "--output {} ".format(remote_task_subdir / "logs" / "export-%A.out")
+    s += " bash -c '\n"
+    # FIXME(martas): would be good to install specific version
+    s += " pip install nemo-evaluator-launcher[all]\n"
+    s += f" cd {remote_task_subdir}/artifacts\n"
+    for dest in destinations:
+        s += f' echo "Exporting to {dest}..."\n'
+        s += f' nemo-evaluator-launcher export {job_id} --dest {dest} || echo "Export to {dest} failed"\n'
+    s += "'\n"
     s += " echo 'Auto-export completed.'\n"
     s += "else\n"
     s += " echo 'Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export.'\n"
@@ -731,11 +841,12 @@ def _open_master_connection(
     socket: str,
 ) -> str | None:
     ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
-    completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
-    )
+    logger.info("Opening master connection", cmd=ssh_command)
+    completed_process = subprocess.run(args=shlex.split(ssh_command))
     if completed_process.returncode == 0:
+        logger.info("Opened master connection successfully", cmd=ssh_command)
         return socket
+    logger.error("Failed to open master connection", code=completed_process.returncode)
     return None
 
 
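
The master connection relies on standard OpenSSH multiplexing: `-M` starts a control master, `-N` runs no remote command, `-f` backgrounds after auth, and `-S` names the control socket that later `ssh`/`rsync` calls reuse; `ssh -O exit -S <socket>` tears it down. A bare-bones sketch (placeholder host and socket path):

```python
import shlex
import subprocess

socket = "/tmp/nel-master-socket"  # illustrative ControlPath
host = "user@login-node"           # placeholder

subprocess.run(shlex.split(f"ssh -MNf -S {socket} {host}"))       # open master
subprocess.run(shlex.split(f"ssh -S {socket} {host} hostname"))   # reuse it
subprocess.run(shlex.split(f"ssh -O exit -S {socket} {host}"))    # close it
```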
@@ -747,9 +858,7 @@ def _close_master_connection(
     if socket is None:
         return
     ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
-    completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
-    )
+    completed_process = subprocess.run(args=shlex.split(ssh_command))
    if completed_process.returncode != 0:
         raise RuntimeError(
             "failed to close the master connection\n{}".format(
@@ -771,8 +880,9 @@ def _make_remote_execution_output_dir(
     ssh_command.append(f"{username}@{hostname}")
     ssh_command.append(mkdir_command)
     ssh_command = " ".join(ssh_command)
+    logger.info("Creating remote dir", cmd=ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command),
+        args=shlex.split(ssh_command), stderr=subprocess.PIPE
     )
     if completed_process.returncode != 0:
         error_msg = (
@@ -780,6 +890,11 @@ def _make_remote_execution_output_dir(
         if completed_process.stderr
         else "Unknown error"
     )
+    logger.error(
+        "Error creating remote dir",
+        code=completed_process.returncode,
+        msg=error_msg,
+    )
     raise RuntimeError(
         "failed to make a remote execution output dir\n{}".format(error_msg)
     )
@@ -807,8 +922,10 @@ def _rsync_upload_rundirs(
     remote_destination_str = f"{username}@{hostname}:{remote_target}"
     local_sources_str = " ".join(map(str, local_sources))
     rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
+    logger.info("Rsyncing to remote dir", cmd=rsync_upload_command)
     completed_process = subprocess.run(
-        args=shlex.split(rsync_upload_command),
+        args=shlex.split(rsync_upload_command),
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         error_msg = (
@@ -816,6 +933,12 @@ def _rsync_upload_rundirs(
         if completed_process.stderr
         else "Unknown error"
     )
+
+    logger.error(
+        "Error rsyncing to remote dir",
+        code=completed_process.returncode,
+        msg=error_msg,
+    )
     raise RuntimeError("failed to upload local sources\n{}".format(error_msg))
 
 
@@ -837,9 +960,12 @@ def _sbatch_remote_runsubs(
     ssh_command.append(f"{username}@{hostname}")
     ssh_command.append(sbatch_commands)
     ssh_command = " ".join(ssh_command)
-
+    logger.info("Running sbatch", cmd=ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command),
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         error_msg = completed_process.stderr.decode("utf-8")
@@ -849,6 +975,7 @@ def _sbatch_remote_runsubs(
 
     sbatch_output = completed_process.stdout.decode("utf-8")
     slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
+    logger.info("Started sbatch successfully", slurm_job_ids=slurm_job_ids)
     return slurm_job_ids
 
 
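
Note that `subprocess.run` drains `stdout` and `stderr` together via `communicate()`, so the explicit pipes here cannot deadlock by themselves; the NOTE flags the risk that appears if this were ever switched to a raw `Popen` with manual reads. A self-contained sketch of the capture-and-parse step (the `echo` stands in for the remote `sbatch`):

```python
import re
import subprocess

proc = subprocess.run(
    ["echo", "Submitted batch job 123456"],  # stand-in for ssh ... sbatch
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)
sbatch_output = proc.stdout.decode("utf-8")
slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
print(slurm_job_ids)  # ['123456']
```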
@@ -881,7 +1008,10 @@ def _query_slurm_jobs_status(
     ssh_command.append(sacct_command)
     ssh_command = " ".join(ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command),
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         raise RuntimeError(
@@ -930,7 +1060,10 @@ def _kill_slurm_job(
     ssh_command = " ".join(ssh_command)
 
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command),
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
 
     # Parse the sacct output (before scancel runs)
@@ -1008,7 +1141,10 @@ def _read_files_from_remote(
     ssh_command.append(cat_commands)
     ssh_command = " ".join(ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command),
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         raise RuntimeError(
@@ -1085,9 +1221,236 @@ sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
 """.strip()
 
 
-
-
-#
-
+def _generate_haproxy_config_with_placeholders(cfg):
+    """Generate HAProxy configuration with placeholder IPs using Jinja template."""
+    # Set up Jinja environment
+    template_dir = Path(__file__).parent
+    template_path = template_dir / "proxy.cfg.template"
+
+    if not template_path.exists():
+        raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+    env = Environment(loader=FileSystemLoader(template_dir))
+    template = env.get_template("proxy.cfg.template")
+
+    # Prepare template data with placeholder IPs - use actual number of nodes
+    num_nodes = cfg.execution.num_nodes
+    nodes = []
+    for i in range(num_nodes):
+        nodes.append({"ip": f"{{IP_{i}}}", "port": cfg.deployment.port})
+
+    # Get health check parameters from execution config
+    proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+    health_check_path = proxy_config.get("health_check_path", "/health")
+    health_check_status = proxy_config.get("health_check_status", 200)
+    haproxy_port = proxy_config.get("haproxy_port", 5009)
+
+    # Render template
+    config = template.render(
+        haproxy_port=haproxy_port,
+        health_check_path=health_check_path,
+        health_check_status=health_check_status,
+        nodes=nodes,
+    )
+
+    return config
+
+
+def _generate_haproxy_config(cfg, nodes_ips):
+    """Generate HAProxy configuration using Jinja template."""
+    # Set up Jinja environment
+    template_dir = Path(__file__).parent
+    template_path = template_dir / "proxy.cfg.template"
+
+    if not template_path.exists():
+        raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+    env = Environment(loader=FileSystemLoader(template_dir))
+    template = env.get_template("proxy.cfg.template")
+
+    # Prepare template data
+    nodes = []
+    for i, ip in enumerate(nodes_ips, 1):
+        nodes.append(
+            {"ip": ip, "port": cfg.deployment.port}  # All nodes use the same port
+        )
+
+    # Get health check parameters from deployment config
+    health_check_path = cfg.deployment.get("health_check_path", "/health")
+    health_check_status = cfg.deployment.get("health_check_status", 200)
+    haproxy_port = cfg.deployment.get("haproxy_port", 5009)
+
+    # Render template
+    config = template.render(
+        haproxy_port=haproxy_port,
+        health_check_path=health_check_path,
+        health_check_status=health_check_status,
+        nodes=nodes,
+    )
+
+    return config
+
+
+def _generate_deployment_srun_command(
+    cfg, deployment_mounts_list, remote_task_subdir, instance_id: int = 0
+):
+    """Generate the deployment srun command with proper node/ntask configuration.
+
+    Returns:
+        tuple: (script_string, is_potentially_unsafe, debug_comment)
+    """
+    s = ""
+    debug_comment = ""
+    is_potentially_unsafe = False
+
+    s += "# deployment server\n"
+
+    # Extract pre_cmd for later use inside container
+    pre_cmd: str = cfg.deployment.get("pre_cmd") or ""
+    if pre_cmd:
+        is_potentially_unsafe = True
+        create_pre_script_cmd = _str_to_echo_command(
+            pre_cmd, filename="deployment_pre_cmd.sh"
+        )
+        debug_comment += create_pre_script_cmd.debug + "\n\n"
+
+    s += "# Get node IPs\n"
+    s += "nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )\n"
+    s += 'nodes_array=("${nodes[@]}") # Ensure nodes are stored properly\n'
+    s += 'export NODES_IPS_ARRAY=($(for node in "${nodes_array[@]}"; do srun --nodelist=$node --ntasks=1 --nodes=1 hostname --ip-address; done))\n'
+    s += 'echo "Node IPs: ${NODES_IPS_ARRAY[@]}"\n'
+    s += "# Export MASTER_IP as the first node IP\n"
+    s += "export MASTER_IP=${NODES_IPS_ARRAY[0]}\n"
+    s += 'echo "MASTER_IP: $MASTER_IP"\n'
+
+    # Add debug comment for deployment pre_cmd before srun command
+    if debug_comment:
+        s += "# Debug contents of deployment pre_cmd\n"
+        s += debug_comment
+        s += "\n"
+
+    s += "srun --mpi pmix --overlap "
+    s += f"--nodes {cfg.execution.num_nodes} --ntasks {cfg.execution.get('deployment', {}).get('n_tasks', 1)} "
+    s += "--container-image {} ".format(cfg.deployment.image)
+    if deployment_mounts_list:
+        s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
+    if not cfg.execution.get("mounts", {}).get("mount_home", True):
+        s += "--no-container-mount-home "
+    s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A-%t.out")
+
+    deployment_env_var_names = list(
+        cfg.execution.get("env_vars", {}).get("deployment", {})
+    )
+    if cfg.deployment.get("env_vars"):
+        warnings.warn(
+            "cfg.deployment.env_vars will be deprecated in future versions. "
+            "Use cfg.execution.env_vars.deployment instead.",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
+        deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
+
+    # Always add MASTER_IP to the environment variables
+    if "MASTER_IP" not in deployment_env_var_names:
+        deployment_env_var_names.append("MASTER_IP")
+
+    if deployment_env_var_names:
+        s += f"--container-env {','.join(deployment_env_var_names)} "
+
+    # Wrap deployment command to execute pre_cmd inside container if needed
+    if pre_cmd:
+        # Create a wrapper command that runs inside the container:
+        # 1. Create deployment_pre_cmd.sh file
+        # 2. Source it
+        # 3. Execute the original deployment command
+        create_pre_script_cmd = _str_to_echo_command(
+            pre_cmd, filename="deployment_pre_cmd.sh"
+        )
+        # Escape single quotes in the deployment command for bash -c
+        escaped_deployment_cmd = cfg.deployment.command.replace("'", "'\"'\"'")
+        wrapped_command = (
+            f"bash -c '{create_pre_script_cmd.cmd} && "
+            f"source deployment_pre_cmd.sh && "
+            f"{escaped_deployment_cmd}'"
+        )
+        s += "{} &\n\n".format(wrapped_command)
+    else:
+        s += "{} &\n\n".format(cfg.deployment.command)  # run asynchronously
+
+    s += "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
+
+    return s, is_potentially_unsafe, debug_comment
+
+
+def _get_wait_for_server_handler(
+    ip_list: str,
+    port: int,
+    health_check_path: str,
+    service_name: str = "server",
+    check_pid: bool = False,
+):
+    """Generate wait for server handler that takes a list of IPs."""
+    pid_check = ""
+    if check_pid:
+        pid_check = 'kill -0 "$SERVER_PID" 2>/dev/null || { echo "Server process $SERVER_PID died"; exit 1; }'
+
+    handler = f"""date
+# wait for the {service_name} to initialize
+for ip in {ip_list}; do
+    echo "Waiting for {service_name} on $ip..."
+    while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" http://$ip:{port}{health_check_path})" != "200" ]]; do
+        {pid_check}
+        sleep 5
+    done
+    echo "{service_name} ready on $ip!"
+done
 date
 """.strip()
+
+    return handler
+
+
+def _get_proxy_server_srun_command(cfg, remote_task_subdir):
+    """Generate proxy server srun command based on proxy type."""
+    proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+
+    if proxy_type == "haproxy":
+        return _generate_haproxy_srun_command(cfg, remote_task_subdir)
+    else:
+        raise ValueError(
+            f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+        )
+
+
+def _generate_haproxy_srun_command(cfg, remote_task_subdir):
+    """Generate HAProxy-specific srun command using template-based config."""
+    s = ""
+    s += "# Proxy load balancer\n"
+    s += "# Copy template to config file (important for restarts)\n"
+    s += f"cp {remote_task_subdir}/proxy.cfg.template {remote_task_subdir}/proxy.cfg\n"
+    s += "# Replace placeholder IPs with actual node IPs\n"
+    s += f"proxy_config_file={remote_task_subdir}/proxy.cfg\n"
+    s += 'for i in "${!NODES_IPS_ARRAY[@]}"; do\n'
+    s += ' ip="${NODES_IPS_ARRAY[$i]}"\n'
+    s += ' sed -i "s/{IP_$i}/$ip/g" "$proxy_config_file"\n'
+    s += "done\n"
+    s += "\n"
+    s += "srun --mpi pmix --overlap "
+    s += "--nodes 1 --ntasks 1 "
+    s += f"--container-image {cfg.execution.get('proxy', {}).get('image', 'haproxy:latest')} "
+    s += f"--container-mounts {remote_task_subdir}/proxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro "
+    s += f"--output {remote_task_subdir}/logs/proxy-%A.out "
+    s += "haproxy -f /usr/local/etc/haproxy/haproxy.cfg &\n"
+    s += "PROXY_PID=$! # capture the PID of the proxy background srun process\n"
+    s += 'echo "Proxy started with PID: $PROXY_PID"\n\n'
+
+    # Wait for proxy to be ready on localhost
+    proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+    haproxy_port = proxy_config.get("haproxy_port", 5009)
+    health_path = proxy_config.get("health_check_path", "/health")
+    s += _get_wait_for_server_handler(
+        "127.0.0.1", haproxy_port, health_path, "Proxy", check_pid=False
+    )
+    s += "\n"
+
+    return s