nemo-evaluator-launcher 0.1.19__py3-none-any.whl → 0.1.56__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nemo_evaluator_launcher/api/functional.py +159 -5
- nemo_evaluator_launcher/cli/logs.py +102 -0
- nemo_evaluator_launcher/cli/ls_task.py +280 -0
- nemo_evaluator_launcher/cli/ls_tasks.py +208 -55
- nemo_evaluator_launcher/cli/main.py +29 -2
- nemo_evaluator_launcher/cli/run.py +114 -16
- nemo_evaluator_launcher/cli/version.py +26 -23
- nemo_evaluator_launcher/common/container_metadata/__init__.py +61 -0
- nemo_evaluator_launcher/common/container_metadata/intermediate_repr.py +530 -0
- nemo_evaluator_launcher/common/container_metadata/loading.py +1126 -0
- nemo_evaluator_launcher/common/container_metadata/registries.py +824 -0
- nemo_evaluator_launcher/common/container_metadata/utils.py +63 -0
- nemo_evaluator_launcher/common/helpers.py +200 -51
- nemo_evaluator_launcher/common/logging_utils.py +16 -5
- nemo_evaluator_launcher/common/mapping.py +341 -155
- nemo_evaluator_launcher/common/printing_utils.py +25 -12
- nemo_evaluator_launcher/configs/deployment/sglang.yaml +4 -2
- nemo_evaluator_launcher/configs/deployment/trtllm.yaml +2 -3
- nemo_evaluator_launcher/configs/deployment/vllm.yaml +0 -1
- nemo_evaluator_launcher/configs/execution/slurm/default.yaml +14 -0
- nemo_evaluator_launcher/executors/base.py +31 -1
- nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +36 -1
- nemo_evaluator_launcher/executors/lepton/executor.py +107 -9
- nemo_evaluator_launcher/executors/local/executor.py +383 -24
- nemo_evaluator_launcher/executors/local/run.template.sh +54 -2
- nemo_evaluator_launcher/executors/slurm/executor.py +559 -64
- nemo_evaluator_launcher/executors/slurm/proxy.cfg.template +26 -0
- nemo_evaluator_launcher/exporters/utils.py +32 -46
- nemo_evaluator_launcher/package_info.py +1 -1
- nemo_evaluator_launcher/resources/all_tasks_irs.yaml +17016 -0
- nemo_evaluator_launcher/resources/mapping.toml +64 -315
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/METADATA +4 -3
- nemo_evaluator_launcher-0.1.56.dist-info/RECORD +69 -0
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/entry_points.txt +1 -0
- nemo_evaluator_launcher-0.1.19.dist-info/RECORD +0 -60
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/WHEEL +0 -0
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/licenses/LICENSE +0 -0
- {nemo_evaluator_launcher-0.1.19.dist-info → nemo_evaluator_launcher-0.1.56.dist-info}/top_level.txt +0 -0
|
@@ -30,6 +30,7 @@ from pathlib import Path
|
|
|
30
30
|
from typing import Dict, List, Optional
|
|
31
31
|
|
|
32
32
|
import yaml
|
|
33
|
+
from jinja2 import Environment, FileSystemLoader
|
|
33
34
|
from omegaconf import DictConfig, OmegaConf
|
|
34
35
|
|
|
35
36
|
from nemo_evaluator_launcher.common.execdb import (
|
|
@@ -39,18 +40,19 @@ from nemo_evaluator_launcher.common.execdb import (
|
|
|
39
40
|
generate_job_id,
|
|
40
41
|
)
|
|
41
42
|
from nemo_evaluator_launcher.common.helpers import (
|
|
43
|
+
CmdAndReadableComment,
|
|
44
|
+
_str_to_echo_command,
|
|
42
45
|
get_api_key_name,
|
|
43
|
-
get_endpoint_url,
|
|
44
46
|
get_eval_factory_command,
|
|
45
47
|
get_eval_factory_dataset_size_from_run_config,
|
|
46
|
-
get_health_url,
|
|
47
48
|
get_timestamp_string,
|
|
48
49
|
)
|
|
50
|
+
from nemo_evaluator_launcher.common.logging_utils import logger
|
|
49
51
|
from nemo_evaluator_launcher.common.mapping import (
|
|
50
|
-
|
|
52
|
+
get_task_definition_for_job,
|
|
51
53
|
load_tasks_mapping,
|
|
52
54
|
)
|
|
53
|
-
from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey
|
|
55
|
+
from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey, red
|
|
54
56
|
from nemo_evaluator_launcher.executors.base import (
|
|
55
57
|
BaseExecutor,
|
|
56
58
|
ExecutionState,
|
|
@@ -94,6 +96,7 @@ class SlurmExecutor(BaseExecutor):
|
|
|
94
96
|
tasks_mapping = load_tasks_mapping()
|
|
95
97
|
eval_images: list[str] = []
|
|
96
98
|
|
|
99
|
+
is_potentially_unsafe = False
|
|
97
100
|
for idx, task in enumerate(cfg.evaluation.tasks):
|
|
98
101
|
# calculate job_id
|
|
99
102
|
job_id = f"{invocation_id}.{idx}"
|
|
@@ -106,7 +109,11 @@ class SlurmExecutor(BaseExecutor):
|
|
|
106
109
|
(local_task_subdir / "artifacts").mkdir()
|
|
107
110
|
|
|
108
111
|
# resolve eval image and pass directly via task override
|
|
109
|
-
task_definition =
|
|
112
|
+
task_definition = get_task_definition_for_job(
|
|
113
|
+
task_query=task.name,
|
|
114
|
+
base_mapping=tasks_mapping,
|
|
115
|
+
container=task.get("container"),
|
|
116
|
+
)
|
|
110
117
|
eval_image = task_definition["container"]
|
|
111
118
|
if "container" in task:
|
|
112
119
|
eval_image = task["container"]
|
|
@@ -114,7 +121,7 @@ class SlurmExecutor(BaseExecutor):
|
|
|
114
121
|
eval_images.append(eval_image)
|
|
115
122
|
|
|
116
123
|
# generate and write down sbatch script
|
|
117
|
-
|
|
124
|
+
sbatch_script_content_struct = _create_slurm_sbatch_script(
|
|
118
125
|
cfg=cfg,
|
|
119
126
|
task=task,
|
|
120
127
|
eval_image=eval_image,
|
|
@@ -122,6 +129,32 @@ class SlurmExecutor(BaseExecutor):
|
|
|
122
129
|
invocation_id=invocation_id,
|
|
123
130
|
job_id=job_id,
|
|
124
131
|
)
|
|
132
|
+
|
|
133
|
+
# Create proxy config file with placeholder IPs for multi-instance deployments
|
|
134
|
+
if cfg.deployment.get("multiple_instances", False):
|
|
135
|
+
proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
|
|
136
|
+
if proxy_type == "haproxy":
|
|
137
|
+
proxy_config = _generate_haproxy_config_with_placeholders(cfg)
|
|
138
|
+
else:
|
|
139
|
+
raise ValueError(
|
|
140
|
+
f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# Save both template and working config
|
|
144
|
+
proxy_template_path = local_task_subdir / "proxy.cfg.template"
|
|
145
|
+
proxy_config_path = local_task_subdir / "proxy.cfg"
|
|
146
|
+
with open(proxy_template_path, "w") as f:
|
|
147
|
+
f.write(proxy_config)
|
|
148
|
+
with open(proxy_config_path, "w") as f:
|
|
149
|
+
f.write(proxy_config)
|
|
150
|
+
|
|
151
|
+
sbatch_script_content_str = sbatch_script_content_struct.cmd
|
|
152
|
+
|
|
153
|
+
# We accumulate if any task contains unsafe commands
|
|
154
|
+
is_potentially_unsafe = (
|
|
155
|
+
is_potentially_unsafe
|
|
156
|
+
or sbatch_script_content_struct.is_potentially_unsafe
|
|
157
|
+
)
|
|
125
158
|
local_runsub_path = local_task_subdir / "run.sub"
|
|
126
159
|
remote_runsub_path = remote_task_subdir / "run.sub"
|
|
127
160
|
with open(local_runsub_path, "w") as f:
|
|
@@ -138,14 +171,56 @@ class SlurmExecutor(BaseExecutor):
|
|
|
138
171
|
with open(local_runsub_path, "r") as f:
|
|
139
172
|
print(grey(f.read()))
|
|
140
173
|
print(bold("To submit jobs") + ", run the executor without --dry-run")
|
|
174
|
+
if is_potentially_unsafe:
|
|
175
|
+
print(
|
|
176
|
+
red(
|
|
177
|
+
"\nFound `pre_cmd` (evaluation or deployment) which carries security risk. When running without --dry-run "
|
|
178
|
+
"make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1"
|
|
179
|
+
)
|
|
180
|
+
)
|
|
181
|
+
|
|
141
182
|
return invocation_id
|
|
142
183
|
|
|
184
|
+
if is_potentially_unsafe:
|
|
185
|
+
if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
|
|
186
|
+
logger.warning(
|
|
187
|
+
"Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
|
|
188
|
+
"is set, proceeding with caution."
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
else:
|
|
192
|
+
logger.error(
|
|
193
|
+
"Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
|
|
194
|
+
"is not set. This might carry security risk and unstable environments. "
|
|
195
|
+
"To continue, make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1.",
|
|
196
|
+
)
|
|
197
|
+
raise AttributeError(
|
|
198
|
+
"Untrusted command found in config, make sure you trust and "
|
|
199
|
+
"set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
|
|
200
|
+
)
|
|
201
|
+
|
|
143
202
|
socket = str(Path(tmpdirname) / "socket")
|
|
144
203
|
socket_or_none = _open_master_connection(
|
|
145
204
|
username=cfg.execution.username,
|
|
146
205
|
hostname=cfg.execution.hostname,
|
|
147
206
|
socket=socket,
|
|
148
207
|
)
|
|
208
|
+
|
|
209
|
+
if socket_or_none is None:
|
|
210
|
+
raise RuntimeError(
|
|
211
|
+
f"Failed to connect to the cluster {cfg.execution.hostname} as user {cfg.execution.username}. "
|
|
212
|
+
"Please check your SSH configuration."
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
# Validate that all mount paths exist on the remote host
|
|
216
|
+
mount_paths = _collect_mount_paths(cfg)
|
|
217
|
+
_validate_remote_paths_exist(
|
|
218
|
+
paths=mount_paths,
|
|
219
|
+
username=cfg.execution.username,
|
|
220
|
+
hostname=cfg.execution.hostname,
|
|
221
|
+
socket=socket_or_none,
|
|
222
|
+
)
|
|
223
|
+
|
|
149
224
|
_make_remote_execution_output_dir(
|
|
150
225
|
dirpath=cfg.execution.output_dir,
|
|
151
226
|
username=cfg.execution.username,
|
|
@@ -437,7 +512,7 @@ def _create_slurm_sbatch_script(
|
|
|
437
512
|
remote_task_subdir: Path,
|
|
438
513
|
invocation_id: str,
|
|
439
514
|
job_id: str,
|
|
440
|
-
) ->
|
|
515
|
+
) -> CmdAndReadableComment:
|
|
441
516
|
"""Generate the contents of a SLURM sbatch script for a given evaluation task.
|
|
442
517
|
|
|
443
518
|
Args:
|
|
@@ -452,8 +527,11 @@ def _create_slurm_sbatch_script(
|
|
|
452
527
|
"""
|
|
453
528
|
# get task from mapping, overrides, urls
|
|
454
529
|
tasks_mapping = load_tasks_mapping()
|
|
455
|
-
task_definition =
|
|
456
|
-
|
|
530
|
+
task_definition = get_task_definition_for_job(
|
|
531
|
+
task_query=task.name,
|
|
532
|
+
base_mapping=tasks_mapping,
|
|
533
|
+
container=task.get("container"),
|
|
534
|
+
)
|
|
457
535
|
|
|
458
536
|
# TODO(public release): convert to template
|
|
459
537
|
s = "#!/bin/bash\n"
|
|
@@ -468,6 +546,8 @@ def _create_slurm_sbatch_script(
|
|
|
468
546
|
s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
|
|
469
547
|
if hasattr(cfg.execution, "gres"):
|
|
470
548
|
s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
|
|
549
|
+
if cfg.execution.get("sbatch_comment"):
|
|
550
|
+
s += "#SBATCH --comment='{}'\n".format(cfg.execution.sbatch_comment)
|
|
471
551
|
job_name = "{account}-{subproject}.{details}".format(
|
|
472
552
|
account=cfg.execution.account,
|
|
473
553
|
subproject=cfg.execution.subproject,
|
|
@@ -475,7 +555,8 @@ def _create_slurm_sbatch_script(
|
|
|
475
555
|
)
|
|
476
556
|
s += "#SBATCH --job-name {}\n".format(job_name)
|
|
477
557
|
s += "#SBATCH --exclusive\n"
|
|
478
|
-
s += "#SBATCH --
|
|
558
|
+
s += "#SBATCH --no-requeue\n" # We have our own auto-resume logic
|
|
559
|
+
s += "#SBATCH --output {}\n".format(remote_task_subdir / "logs" / "slurm-%A.log")
|
|
479
560
|
s += "\n"
|
|
480
561
|
s += f'TASK_DIR="{str(remote_task_subdir)}"\n'
|
|
481
562
|
s += "\n"
|
|
@@ -493,8 +574,11 @@ def _create_slurm_sbatch_script(
|
|
|
493
574
|
if os.getenv(env_var) is None:
|
|
494
575
|
raise ValueError(f"Trying to pass an unset environment variable {env_var}.")
|
|
495
576
|
|
|
496
|
-
# check if required env vars are defined:
|
|
577
|
+
# check if required env vars are defined (excluding NEMO_EVALUATOR_DATASET_DIR which is handled separately):
|
|
497
578
|
for required_env_var in task_definition.get("required_env_vars", []):
|
|
579
|
+
# Skip NEMO_EVALUATOR_DATASET_DIR as it's handled by dataset mounting logic below
|
|
580
|
+
if required_env_var == "NEMO_EVALUATOR_DATASET_DIR":
|
|
581
|
+
continue
|
|
498
582
|
if required_env_var not in env_vars.keys():
|
|
499
583
|
raise ValueError(
|
|
500
584
|
f"{task.name} task requires environment variable {required_env_var}."
|
|
@@ -540,6 +624,7 @@ def _create_slurm_sbatch_script(
|
|
|
540
624
|
|
|
541
625
|
# prepare deployment mounts
|
|
542
626
|
deployment_mounts_list = []
|
|
627
|
+
deployment_is_unsafe = False
|
|
543
628
|
if cfg.deployment.type != "none":
|
|
544
629
|
if checkpoint_path := cfg.deployment.get("checkpoint_path"):
|
|
545
630
|
deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
|
|
@@ -551,36 +636,33 @@ def _create_slurm_sbatch_script(
|
|
|
551
636
|
deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")
|
|
552
637
|
|
|
553
638
|
# add deployment srun command
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
if deployment_mounts_list:
|
|
558
|
-
s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
|
|
559
|
-
if not cfg.execution.get("mounts", {}).get("mount_home", True):
|
|
560
|
-
s += "--no-container-mount-home "
|
|
561
|
-
s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
|
|
562
|
-
deployment_env_var_names = list(
|
|
563
|
-
cfg.execution.get("env_vars", {}).get("deployment", {})
|
|
564
|
-
)
|
|
565
|
-
if cfg.deployment.get("env_vars"):
|
|
566
|
-
warnings.warn(
|
|
567
|
-
"cfg.deployment.env_vars will be deprecated in future versions. "
|
|
568
|
-
"Use cfg.execution.env_vars.deployment instead.",
|
|
569
|
-
category=DeprecationWarning,
|
|
570
|
-
stacklevel=2,
|
|
639
|
+
deployment_srun_cmd, deployment_is_unsafe, deployment_debug = (
|
|
640
|
+
_generate_deployment_srun_command(
|
|
641
|
+
cfg, deployment_mounts_list, remote_task_subdir
|
|
571
642
|
)
|
|
572
|
-
deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
|
|
573
|
-
if deployment_env_var_names:
|
|
574
|
-
s += f"--container-env {','.join(deployment_env_var_names)} "
|
|
575
|
-
s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously
|
|
576
|
-
s += (
|
|
577
|
-
"SERVER_PID=$! # capture the PID of the server background srun process\n\n"
|
|
578
643
|
)
|
|
644
|
+
s += deployment_srun_cmd
|
|
579
645
|
|
|
580
646
|
# wait for the server to initialize
|
|
581
|
-
|
|
647
|
+
health_path = cfg.deployment.get("health_check_path", "/health")
|
|
648
|
+
# For multi-instance check all node IPs, for single instance check localhost
|
|
649
|
+
if cfg.deployment.get("multiple_instances", False):
|
|
650
|
+
ip_list = '"${NODES_IPS_ARRAY[@]}"'
|
|
651
|
+
else:
|
|
652
|
+
ip_list = '"127.0.0.1"'
|
|
653
|
+
s += _get_wait_for_server_handler(
|
|
654
|
+
ip_list,
|
|
655
|
+
cfg.deployment.port,
|
|
656
|
+
health_path,
|
|
657
|
+
"server",
|
|
658
|
+
check_pid=True,
|
|
659
|
+
)
|
|
582
660
|
s += "\n\n"
|
|
583
661
|
|
|
662
|
+
# add proxy load balancer for multi-instance deployments
|
|
663
|
+
if cfg.deployment.get("multiple_instances", False):
|
|
664
|
+
s += _get_proxy_server_srun_command(cfg, remote_task_subdir)
|
|
665
|
+
|
|
584
666
|
# prepare evaluation mounts
|
|
585
667
|
evaluation_mounts_list = [
|
|
586
668
|
"{}:/results".format(remote_task_subdir / "artifacts"),
|
|
@@ -590,7 +672,29 @@ def _create_slurm_sbatch_script(
|
|
|
590
672
|
):
|
|
591
673
|
evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")
|
|
592
674
|
|
|
593
|
-
|
|
675
|
+
# Handle dataset directory mounting if NEMO_EVALUATOR_DATASET_DIR is required
|
|
676
|
+
if "NEMO_EVALUATOR_DATASET_DIR" in task_definition.get("required_env_vars", []):
|
|
677
|
+
# Get dataset directory from task config
|
|
678
|
+
if "dataset_dir" in task:
|
|
679
|
+
dataset_mount_host = task["dataset_dir"]
|
|
680
|
+
else:
|
|
681
|
+
raise ValueError(
|
|
682
|
+
f"{task.name} task requires a dataset_dir to be specified. "
|
|
683
|
+
f"Add 'dataset_dir: /path/to/your/dataset' under the task configuration."
|
|
684
|
+
)
|
|
685
|
+
# Get container mount path (default to /datasets if not specified)
|
|
686
|
+
dataset_mount_container = task.get("dataset_mount_path", "/datasets")
|
|
687
|
+
# Add dataset mount to evaluation mounts list
|
|
688
|
+
evaluation_mounts_list.append(f"{dataset_mount_host}:{dataset_mount_container}")
|
|
689
|
+
# Export NEMO_EVALUATOR_DATASET_DIR environment variable
|
|
690
|
+
s += f"export NEMO_EVALUATOR_DATASET_DIR={dataset_mount_container}\n\n"
|
|
691
|
+
|
|
692
|
+
eval_factory_command_struct = get_eval_factory_command(
|
|
693
|
+
cfg,
|
|
694
|
+
task,
|
|
695
|
+
task_definition,
|
|
696
|
+
)
|
|
697
|
+
|
|
594
698
|
eval_factory_command = eval_factory_command_struct.cmd
|
|
595
699
|
# The debug comment for placing into the script and easy debug. Reason
|
|
596
700
|
# (see `CmdAndReadableComment`) is the current way of passing the command
|
|
@@ -606,6 +710,7 @@ def _create_slurm_sbatch_script(
|
|
|
606
710
|
|
|
607
711
|
s += "# evaluation client\n"
|
|
608
712
|
s += "srun --mpi pmix --overlap "
|
|
713
|
+
s += "--nodes 1 --ntasks 1 " # Client always runs on single node
|
|
609
714
|
s += "--container-image {} ".format(eval_image)
|
|
610
715
|
evaluation_env_var_names = list(
|
|
611
716
|
cfg.execution.get("env_vars", {}).get("evaluation", {})
|
|
@@ -616,14 +721,17 @@ def _create_slurm_sbatch_script(
|
|
|
616
721
|
s += "--no-container-mount-home "
|
|
617
722
|
|
|
618
723
|
s += "--container-mounts {} ".format(",".join(evaluation_mounts_list))
|
|
619
|
-
s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.
|
|
724
|
+
s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.log")
|
|
620
725
|
s += "bash -c '\n"
|
|
621
726
|
s += eval_factory_command
|
|
622
727
|
s += "'\n\n"
|
|
623
728
|
|
|
624
729
|
# terminate the server after all evaluation clients finish
|
|
625
730
|
if cfg.deployment.type != "none":
|
|
626
|
-
s += "kill $SERVER_PID # terminate the server to finish gracefully\n
|
|
731
|
+
s += "kill $SERVER_PID # terminate the server to finish gracefully\n"
|
|
732
|
+
if cfg.deployment.get("multiple_instances", False):
|
|
733
|
+
s += "kill $PROXY_PID # terminate proxy to finish gracefully\n"
|
|
734
|
+
s += "\n"
|
|
627
735
|
|
|
628
736
|
# auto-export
|
|
629
737
|
ae_cfg = cfg.execution.get("auto_export")
|
|
@@ -635,9 +743,22 @@ def _create_slurm_sbatch_script(
|
|
|
635
743
|
|
|
636
744
|
if destinations:
|
|
637
745
|
export_env = dict(cfg.execution.get("env_vars", {}).get("export", {}) or {})
|
|
638
|
-
s += _generate_auto_export_section(
|
|
746
|
+
s += _generate_auto_export_section(
|
|
747
|
+
cfg, job_id, destinations, export_env, remote_task_subdir
|
|
748
|
+
)
|
|
639
749
|
|
|
640
|
-
|
|
750
|
+
debug_str = "\n".join(["# " + line for line in s.splitlines()])
|
|
751
|
+
|
|
752
|
+
# Combine unsafe flags from both deployment and evaluation
|
|
753
|
+
is_potentially_unsafe = (
|
|
754
|
+
eval_factory_command_struct.is_potentially_unsafe or deployment_is_unsafe
|
|
755
|
+
)
|
|
756
|
+
|
|
757
|
+
return CmdAndReadableComment(
|
|
758
|
+
cmd=s,
|
|
759
|
+
debug=debug_str,
|
|
760
|
+
is_potentially_unsafe=is_potentially_unsafe,
|
|
761
|
+
)
|
|
641
762
|
|
|
642
763
|
|
|
643
764
|
def _generate_auto_export_section(
|
|
@@ -645,6 +766,8 @@ def _generate_auto_export_section(
|
|
|
645
766
|
job_id: str,
|
|
646
767
|
destinations: list,
|
|
647
768
|
export_env: dict,
|
|
769
|
+
remote_task_subdir: Path,
|
|
770
|
+
export_image: str = "python:3.12.7-slim",
|
|
648
771
|
) -> str:
|
|
649
772
|
"""Generate simple auto-export section for sbatch script."""
|
|
650
773
|
if not destinations:
|
|
@@ -654,10 +777,7 @@ def _generate_auto_export_section(
|
|
|
654
777
|
s += "EVAL_EXIT_CODE=$?\n"
|
|
655
778
|
s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
|
|
656
779
|
s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
|
|
657
|
-
s +=
|
|
658
|
-
s += " set +x\n"
|
|
659
|
-
s += " set +u\n"
|
|
660
|
-
s += ' cd "$TASK_DIR/artifacts"\n'
|
|
780
|
+
s += f' cd "{remote_task_subdir}/artifacts"\n'
|
|
661
781
|
|
|
662
782
|
# Work with DictConfig; convert only for YAML at the end
|
|
663
783
|
exec_type = (
|
|
@@ -713,10 +833,25 @@ def _generate_auto_export_section(
|
|
|
713
833
|
esc = str(v).replace('"', '\\"')
|
|
714
834
|
s += f' export {k}="{esc}"\n'
|
|
715
835
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
836
|
+
s += " # export\n"
|
|
837
|
+
s += " srun --mpi pmix --overlap "
|
|
838
|
+
s += "--nodes 1 --ntasks 1 " # Client always runs on single node
|
|
839
|
+
s += "--container-image {} ".format(export_image)
|
|
840
|
+
if export_env:
|
|
841
|
+
s += "--container-env {} ".format(",".join(export_env))
|
|
842
|
+
if not cfg.execution.get("mounts", {}).get("mount_home", True):
|
|
843
|
+
s += "--no-container-mount-home "
|
|
719
844
|
|
|
845
|
+
s += f"--container-mounts {remote_task_subdir}/artifacts:{remote_task_subdir}/artifacts,{remote_task_subdir}/logs:{remote_task_subdir}/logs "
|
|
846
|
+
s += "--output {} ".format(remote_task_subdir / "logs" / "export-%A.log")
|
|
847
|
+
s += " bash -c '\n"
|
|
848
|
+
# FIXME(martas): would be good to install specific version
|
|
849
|
+
s += " pip install nemo-evaluator-launcher[all]\n"
|
|
850
|
+
s += f" cd {remote_task_subdir}/artifacts\n"
|
|
851
|
+
for dest in destinations:
|
|
852
|
+
s += f' echo "Exporting to {dest}..."\n'
|
|
853
|
+
s += f' nemo-evaluator-launcher export {job_id} --dest {dest} || echo "Export to {dest} failed"\n'
|
|
854
|
+
s += "'\n"
|
|
720
855
|
s += " echo 'Auto-export completed.'\n"
|
|
721
856
|
s += "else\n"
|
|
722
857
|
s += " echo 'Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export.'\n"
|
|
@@ -731,11 +866,12 @@ def _open_master_connection(
|
|
|
731
866
|
socket: str,
|
|
732
867
|
) -> str | None:
|
|
733
868
|
ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
)
|
|
869
|
+
logger.info("Opening master connection", cmd=ssh_command)
|
|
870
|
+
completed_process = subprocess.run(args=shlex.split(ssh_command))
|
|
737
871
|
if completed_process.returncode == 0:
|
|
872
|
+
logger.info("Opened master connection successfully", cmd=ssh_command)
|
|
738
873
|
return socket
|
|
874
|
+
logger.error("Failed to open master connection", code=completed_process.returncode)
|
|
739
875
|
return None
|
|
740
876
|
|
|
741
877
|
|
|
@@ -747,9 +883,7 @@ def _close_master_connection(
|
|
|
747
883
|
if socket is None:
|
|
748
884
|
return
|
|
749
885
|
ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
|
|
750
|
-
completed_process = subprocess.run(
|
|
751
|
-
args=shlex.split(ssh_command), capture_output=True
|
|
752
|
-
)
|
|
886
|
+
completed_process = subprocess.run(args=shlex.split(ssh_command))
|
|
753
887
|
if completed_process.returncode != 0:
|
|
754
888
|
raise RuntimeError(
|
|
755
889
|
"failed to close the master connection\n{}".format(
|
|
@@ -771,8 +905,9 @@ def _make_remote_execution_output_dir(
|
|
|
771
905
|
ssh_command.append(f"{username}@{hostname}")
|
|
772
906
|
ssh_command.append(mkdir_command)
|
|
773
907
|
ssh_command = " ".join(ssh_command)
|
|
908
|
+
logger.info("Creating remote dir", cmd=ssh_command)
|
|
774
909
|
completed_process = subprocess.run(
|
|
775
|
-
args=shlex.split(ssh_command),
|
|
910
|
+
args=shlex.split(ssh_command), stderr=subprocess.PIPE
|
|
776
911
|
)
|
|
777
912
|
if completed_process.returncode != 0:
|
|
778
913
|
error_msg = (
|
|
@@ -780,6 +915,11 @@ def _make_remote_execution_output_dir(
|
|
|
780
915
|
if completed_process.stderr
|
|
781
916
|
else "Unknown error"
|
|
782
917
|
)
|
|
918
|
+
logger.error(
|
|
919
|
+
"Erorr creating remote dir",
|
|
920
|
+
code=completed_process.returncode,
|
|
921
|
+
msg=error_msg,
|
|
922
|
+
)
|
|
783
923
|
raise RuntimeError(
|
|
784
924
|
"failed to make a remote execution output dir\n{}".format(error_msg)
|
|
785
925
|
)
|
|
@@ -807,8 +947,10 @@ def _rsync_upload_rundirs(
|
|
|
807
947
|
remote_destination_str = f"{username}@{hostname}:{remote_target}"
|
|
808
948
|
local_sources_str = " ".join(map(str, local_sources))
|
|
809
949
|
rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
|
|
950
|
+
logger.info("Rsyncing to remote dir", cmd=rsync_upload_command)
|
|
810
951
|
completed_process = subprocess.run(
|
|
811
|
-
args=shlex.split(rsync_upload_command),
|
|
952
|
+
args=shlex.split(rsync_upload_command),
|
|
953
|
+
stderr=subprocess.PIPE,
|
|
812
954
|
)
|
|
813
955
|
if completed_process.returncode != 0:
|
|
814
956
|
error_msg = (
|
|
@@ -816,6 +958,12 @@ def _rsync_upload_rundirs(
|
|
|
816
958
|
if completed_process.stderr
|
|
817
959
|
else "Unknown error"
|
|
818
960
|
)
|
|
961
|
+
|
|
962
|
+
logger.error(
|
|
963
|
+
"Erorr rsyncing to remote dir",
|
|
964
|
+
code=completed_process.returncode,
|
|
965
|
+
msg=error_msg,
|
|
966
|
+
)
|
|
819
967
|
raise RuntimeError("failed to upload local sources\n{}".format(error_msg))
|
|
820
968
|
|
|
821
969
|
|
|
@@ -837,9 +985,12 @@ def _sbatch_remote_runsubs(
|
|
|
837
985
|
ssh_command.append(f"{username}@{hostname}")
|
|
838
986
|
ssh_command.append(sbatch_commands)
|
|
839
987
|
ssh_command = " ".join(ssh_command)
|
|
840
|
-
|
|
988
|
+
logger.info("Running sbatch", cmd=ssh_command)
|
|
841
989
|
completed_process = subprocess.run(
|
|
842
|
-
args=shlex.split(ssh_command),
|
|
990
|
+
args=shlex.split(ssh_command),
|
|
991
|
+
# NOTE(agronskiy): look out for hangs and deadlocks
|
|
992
|
+
stdout=subprocess.PIPE,
|
|
993
|
+
stderr=subprocess.PIPE,
|
|
843
994
|
)
|
|
844
995
|
if completed_process.returncode != 0:
|
|
845
996
|
error_msg = completed_process.stderr.decode("utf-8")
|
|
@@ -849,6 +1000,7 @@ def _sbatch_remote_runsubs(
|
|
|
849
1000
|
|
|
850
1001
|
sbatch_output = completed_process.stdout.decode("utf-8")
|
|
851
1002
|
slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
|
|
1003
|
+
logger.info("Started sbatch successfully", slurm_job_ids=slurm_job_ids)
|
|
852
1004
|
return slurm_job_ids
|
|
853
1005
|
|
|
854
1006
|
|
|
@@ -881,7 +1033,10 @@ def _query_slurm_jobs_status(
|
|
|
881
1033
|
ssh_command.append(sacct_command)
|
|
882
1034
|
ssh_command = " ".join(ssh_command)
|
|
883
1035
|
completed_process = subprocess.run(
|
|
884
|
-
args=shlex.split(ssh_command),
|
|
1036
|
+
args=shlex.split(ssh_command),
|
|
1037
|
+
# NOTE(agronskiy): look out for hangs and deadlocks
|
|
1038
|
+
stdout=subprocess.PIPE,
|
|
1039
|
+
stderr=subprocess.PIPE,
|
|
885
1040
|
)
|
|
886
1041
|
if completed_process.returncode != 0:
|
|
887
1042
|
raise RuntimeError(
|
|
@@ -930,7 +1085,10 @@ def _kill_slurm_job(
|
|
|
930
1085
|
ssh_command = " ".join(ssh_command)
|
|
931
1086
|
|
|
932
1087
|
completed_process = subprocess.run(
|
|
933
|
-
args=shlex.split(ssh_command),
|
|
1088
|
+
args=shlex.split(ssh_command),
|
|
1089
|
+
# NOTE(agronskiy): look out for hangs and deadlocks
|
|
1090
|
+
stdout=subprocess.PIPE,
|
|
1091
|
+
stderr=subprocess.PIPE,
|
|
934
1092
|
)
|
|
935
1093
|
|
|
936
1094
|
# Parse the sacct output (before scancel runs)
|
|
@@ -1008,7 +1166,10 @@ def _read_files_from_remote(
|
|
|
1008
1166
|
ssh_command.append(cat_commands)
|
|
1009
1167
|
ssh_command = " ".join(ssh_command)
|
|
1010
1168
|
completed_process = subprocess.run(
|
|
1011
|
-
args=shlex.split(ssh_command),
|
|
1169
|
+
args=shlex.split(ssh_command),
|
|
1170
|
+
# NOTE(agronskiy): look out for hangs and deadlocks
|
|
1171
|
+
stdout=subprocess.PIPE,
|
|
1172
|
+
stderr=subprocess.PIPE,
|
|
1012
1173
|
)
|
|
1013
1174
|
if completed_process.returncode != 0:
|
|
1014
1175
|
raise RuntimeError(
|
|
@@ -1085,9 +1246,343 @@ sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
|
|
|
1085
1246
|
""".strip()
|
|
1086
1247
|
|
|
1087
1248
|
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
#
|
|
1091
|
-
|
|
1249
|
+
def _generate_haproxy_config_with_placeholders(cfg):
|
|
1250
|
+
"""Generate HAProxy configuration with placeholder IPs using Jinja template."""
|
|
1251
|
+
# Set up Jinja environment
|
|
1252
|
+
template_dir = Path(__file__).parent
|
|
1253
|
+
template_path = template_dir / "proxy.cfg.template"
|
|
1254
|
+
|
|
1255
|
+
if not template_path.exists():
|
|
1256
|
+
raise FileNotFoundError(f"Proxy template not found: {template_path}")
|
|
1257
|
+
|
|
1258
|
+
env = Environment(loader=FileSystemLoader(template_dir))
|
|
1259
|
+
template = env.get_template("proxy.cfg.template")
|
|
1260
|
+
|
|
1261
|
+
# Prepare template data with placeholder IPs - use actual number of nodes
|
|
1262
|
+
num_nodes = cfg.execution.num_nodes
|
|
1263
|
+
nodes = []
|
|
1264
|
+
for i in range(num_nodes):
|
|
1265
|
+
nodes.append({"ip": f"{{IP_{i}}}", "port": cfg.deployment.port})
|
|
1266
|
+
|
|
1267
|
+
# Get health check parameters from execution config
|
|
1268
|
+
proxy_config = cfg.execution.get("proxy", {}).get("config", {})
|
|
1269
|
+
health_check_path = proxy_config.get("health_check_path", "/health")
|
|
1270
|
+
health_check_status = proxy_config.get("health_check_status", 200)
|
|
1271
|
+
haproxy_port = proxy_config.get("haproxy_port", 5009)
|
|
1272
|
+
|
|
1273
|
+
# Render template
|
|
1274
|
+
config = template.render(
|
|
1275
|
+
haproxy_port=haproxy_port,
|
|
1276
|
+
health_check_path=health_check_path,
|
|
1277
|
+
health_check_status=health_check_status,
|
|
1278
|
+
nodes=nodes,
|
|
1279
|
+
)
|
|
1280
|
+
|
|
1281
|
+
return config
|
|
1282
|
+
|
|
1283
|
+
|
|
1284
|
+
def _generate_haproxy_config(cfg, nodes_ips):
|
|
1285
|
+
"""Generate HAProxy configuration using Jinja template."""
|
|
1286
|
+
# Set up Jinja environment
|
|
1287
|
+
template_dir = Path(__file__).parent
|
|
1288
|
+
template_path = template_dir / "proxy.cfg.template"
|
|
1289
|
+
|
|
1290
|
+
if not template_path.exists():
|
|
1291
|
+
raise FileNotFoundError(f"Proxy template not found: {template_path}")
|
|
1292
|
+
|
|
1293
|
+
env = Environment(loader=FileSystemLoader(template_dir))
|
|
1294
|
+
template = env.get_template("proxy.cfg.template")
|
|
1295
|
+
|
|
1296
|
+
# Prepare template data
|
|
1297
|
+
nodes = []
|
|
1298
|
+
for i, ip in enumerate(nodes_ips, 1):
|
|
1299
|
+
nodes.append(
|
|
1300
|
+
{"ip": ip, "port": cfg.deployment.port} # All nodes use the same port
|
|
1301
|
+
)
|
|
1302
|
+
|
|
1303
|
+
# Get health check parameters from deployment config
|
|
1304
|
+
health_check_path = cfg.deployment.get("health_check_path", "/health")
|
|
1305
|
+
health_check_status = cfg.deployment.get("health_check_status", 200)
|
|
1306
|
+
haproxy_port = cfg.deployment.get("haproxy_port", 5009)
|
|
1307
|
+
|
|
1308
|
+
# Render template
|
|
1309
|
+
config = template.render(
|
|
1310
|
+
haproxy_port=haproxy_port,
|
|
1311
|
+
health_check_path=health_check_path,
|
|
1312
|
+
health_check_status=health_check_status,
|
|
1313
|
+
nodes=nodes,
|
|
1314
|
+
)
|
|
1315
|
+
|
|
1316
|
+
return config
|
|
1317
|
+
|
|
1318
|
+
|
|
1319
|
+
def _generate_deployment_srun_command(
    cfg, deployment_mounts_list, remote_task_subdir, instance_id: int = 0
):
    """Generate the deployment srun command with proper node/ntask configuration.

    Builds a bash fragment that (1) discovers the job's node IPs and exports
    MASTER_IP, then (2) launches the deployment server via ``srun`` in the
    background, capturing its PID in ``SERVER_PID``.

    Args:
        cfg: Run configuration; reads ``cfg.deployment`` (image, command,
            pre_cmd, env_vars, ...) and ``cfg.execution`` (num_nodes,
            deployment n_tasks, mounts, env_vars).
        deployment_mounts_list: Container mount specs for ``--container-mounts``;
            the flag is omitted when empty.
        remote_task_subdir: Task directory on the cluster; server logs go to
            ``<dir>/logs/server-%A-%t.log``.
        instance_id: NOTE(review): unused in this body — presumably reserved
            for multi-instance deployments; confirm before removing.

    Returns:
        tuple: (script_string, is_potentially_unsafe, debug_comment)
    """
    s = ""
    debug_comment = ""
    is_potentially_unsafe = False

    s += "# deployment server\n"

    # Extract pre_cmd for later use inside container
    pre_cmd: str = cfg.deployment.get("pre_cmd") or ""
    if pre_cmd:
        # Arbitrary user-provided shell runs inside the container, so the
        # generated script is flagged for the caller's safety review.
        is_potentially_unsafe = True
        create_pre_script_cmd = _str_to_echo_command(
            pre_cmd, filename="deployment_pre_cmd.sh"
        )
        debug_comment += create_pre_script_cmd.debug + "\n\n"

    # Resolve each allocated hostname to an IP; MASTER_IP is the first node
    # and is exported so distributed backends can rendezvous on it.
    s += "# Get node IPs\n"
    s += "nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )\n"
    s += 'nodes_array=("${nodes[@]}") # Ensure nodes are stored properly\n'
    s += 'export NODES_IPS_ARRAY=($(for node in "${nodes_array[@]}"; do srun --nodelist=$node --ntasks=1 --nodes=1 hostname --ip-address; done))\n'
    s += 'echo "Node IPs: ${NODES_IPS_ARRAY[@]}"\n'
    s += "# Export MASTER_IP as the first node IP\n"
    s += "export MASTER_IP=${NODES_IPS_ARRAY[0]}\n"
    s += 'echo "MASTER_IP: $MASTER_IP"\n'

    # Add debug comment for deployment pre_cmd before srun command
    if debug_comment:
        s += "# Debug contents of deployment pre_cmd\n"
        s += debug_comment
        s += "\n"

    # Assemble the srun invocation; each fragment deliberately ends with a
    # trailing space so the flags concatenate into one command line.
    s += "srun --mpi pmix --overlap "
    s += f"--nodes {cfg.execution.num_nodes} --ntasks {cfg.execution.get('deployment', {}).get('n_tasks', 1)} "
    s += "--container-image {} ".format(cfg.deployment.image)
    if deployment_mounts_list:
        s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
    if not cfg.execution.get("mounts", {}).get("mount_home", True):
        s += "--no-container-mount-home "
    s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A-%t.log")

    # Env vars to forward into the container: execution-level first, then the
    # deprecated deployment-level ones for backward compatibility.
    deployment_env_var_names = list(
        cfg.execution.get("env_vars", {}).get("deployment", {})
    )
    if cfg.deployment.get("env_vars"):
        warnings.warn(
            "cfg.deployment.env_vars will be deprecated in future versions. "
            "Use cfg.execution.env_vars.deployment instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))

    # Always add MASTER_IP to the environment variables
    if "MASTER_IP" not in deployment_env_var_names:
        deployment_env_var_names.append("MASTER_IP")

    if deployment_env_var_names:
        s += f"--container-env {','.join(deployment_env_var_names)} "

    # Wrap deployment command to execute pre_cmd inside container if needed
    if pre_cmd:
        # Create a wrapper command that runs inside the container:
        # 1. Create deployment_pre_cmd.sh file
        # 2. Source it
        # 3. Execute the original deployment command
        create_pre_script_cmd = _str_to_echo_command(
            pre_cmd, filename="deployment_pre_cmd.sh"
        )
        # Escape single quotes in the deployment command for bash -c
        escaped_deployment_cmd = cfg.deployment.command.replace("'", "'\"'\"'")
        wrapped_command = (
            f"bash -c '{create_pre_script_cmd.cmd} && "
            f"source deployment_pre_cmd.sh && "
            f"{escaped_deployment_cmd}'"
        )
        s += "{} &\n\n".format(wrapped_command)
    else:
        s += "{} &\n\n".format(cfg.deployment.command) # run asynchronously

    # The PID lets later script sections detect a crashed server.
    s += "SERVER_PID=$! # capture the PID of the server background srun process\n\n"

    return s, is_potentially_unsafe, debug_comment
|
|
1408
|
+
|
|
1409
|
+
|
|
1410
|
+
def _get_wait_for_server_handler(
|
|
1411
|
+
ip_list: str,
|
|
1412
|
+
port: int,
|
|
1413
|
+
health_check_path: str,
|
|
1414
|
+
service_name: str = "server",
|
|
1415
|
+
check_pid: bool = False,
|
|
1416
|
+
):
|
|
1417
|
+
"""Generate wait for server handler that takes a list of IPs."""
|
|
1418
|
+
pid_check = ""
|
|
1419
|
+
if check_pid:
|
|
1420
|
+
pid_check = 'kill -0 "$SERVER_PID" 2>/dev/null || { echo "Server process $SERVER_PID died"; exit 1; }'
|
|
1421
|
+
|
|
1422
|
+
handler = f"""date
|
|
1423
|
+
# wait for the {service_name} to initialize
|
|
1424
|
+
for ip in {ip_list}; do
|
|
1425
|
+
echo "Waiting for {service_name} on $ip..."
|
|
1426
|
+
while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" http://$ip:{port}{health_check_path})" != "200" ]]; do
|
|
1427
|
+
{pid_check}
|
|
1428
|
+
sleep 5
|
|
1429
|
+
done
|
|
1430
|
+
echo "{service_name} ready on $ip!"
|
|
1431
|
+
done
|
|
1092
1432
|
date
|
|
1093
1433
|
""".strip()
|
|
1434
|
+
|
|
1435
|
+
return handler
|
|
1436
|
+
|
|
1437
|
+
|
|
1438
|
+
def _get_proxy_server_srun_command(cfg, remote_task_subdir):
|
|
1439
|
+
"""Generate proxy server srun command based on proxy type."""
|
|
1440
|
+
proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
|
|
1441
|
+
|
|
1442
|
+
if proxy_type == "haproxy":
|
|
1443
|
+
return _generate_haproxy_srun_command(cfg, remote_task_subdir)
|
|
1444
|
+
else:
|
|
1445
|
+
raise ValueError(
|
|
1446
|
+
f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
|
|
1447
|
+
)
|
|
1448
|
+
|
|
1449
|
+
|
|
1450
|
+
def _generate_haproxy_srun_command(cfg, remote_task_subdir):
    """Generate HAProxy-specific srun command using template-based config."""
    proxy_cfg = cfg.execution.get("proxy", {})

    # Stage the config file: copy the template fresh (important for restarts)
    # and substitute the {IP_<i>} placeholders with the node IPs discovered
    # earlier into NODES_IPS_ARRAY.
    script_parts = [
        "# Proxy load balancer\n",
        "# Copy template to config file (important for restarts)\n",
        f"cp {remote_task_subdir}/proxy.cfg.template {remote_task_subdir}/proxy.cfg\n",
        "# Replace placeholder IPs with actual node IPs\n",
        f"proxy_config_file={remote_task_subdir}/proxy.cfg\n",
        'for i in "${!NODES_IPS_ARRAY[@]}"; do\n',
        ' ip="${NODES_IPS_ARRAY[$i]}"\n',
        ' sed -i "s/{IP_$i}/$ip/g" "$proxy_config_file"\n',
        "done\n",
        "\n",
        # Launch HAProxy on a single node in the background, capturing its PID.
        "srun --mpi pmix --overlap ",
        "--nodes 1 --ntasks 1 ",
        f"--container-image {proxy_cfg.get('image', 'haproxy:latest')} ",
        f"--container-mounts {remote_task_subdir}/proxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro ",
        f"--output {remote_task_subdir}/logs/proxy-%A.log ",
        "haproxy -f /usr/local/etc/haproxy/haproxy.cfg &\n",
        "PROXY_PID=$! # capture the PID of the proxy background srun process\n",
        'echo "Proxy started with PID: $PROXY_PID"\n\n',
    ]

    # Block until the proxy answers its health check on localhost.
    proxy_config = proxy_cfg.get("config", {})
    script_parts.append(
        _get_wait_for_server_handler(
            "127.0.0.1",
            proxy_config.get("haproxy_port", 5009),
            proxy_config.get("health_check_path", "/health"),
            "Proxy",
            check_pid=False,
        )
    )
    script_parts.append("\n")

    return "".join(script_parts)
|
|
1482
|
+
|
|
1483
|
+
|
|
1484
|
+
def _collect_mount_paths(cfg: DictConfig) -> List[str]:
|
|
1485
|
+
"""Collect all mount source paths from the configuration.
|
|
1486
|
+
|
|
1487
|
+
Args:
|
|
1488
|
+
cfg: The configuration object for the evaluation run.
|
|
1489
|
+
|
|
1490
|
+
Returns:
|
|
1491
|
+
List of source paths that need to be mounted.
|
|
1492
|
+
"""
|
|
1493
|
+
mount_paths = []
|
|
1494
|
+
|
|
1495
|
+
# Deployment mounts
|
|
1496
|
+
if cfg.deployment.type != "none":
|
|
1497
|
+
if checkpoint_path := cfg.deployment.get("checkpoint_path"):
|
|
1498
|
+
mount_paths.append(checkpoint_path)
|
|
1499
|
+
if cache_path := cfg.deployment.get("cache_path"):
|
|
1500
|
+
mount_paths.append(cache_path)
|
|
1501
|
+
for source_mnt in cfg.execution.get("mounts", {}).get("deployment", {}).keys():
|
|
1502
|
+
mount_paths.append(source_mnt)
|
|
1503
|
+
|
|
1504
|
+
# Evaluation mounts
|
|
1505
|
+
for source_mnt in cfg.execution.get("mounts", {}).get("evaluation", {}).keys():
|
|
1506
|
+
mount_paths.append(source_mnt)
|
|
1507
|
+
|
|
1508
|
+
return mount_paths
|
|
1509
|
+
|
|
1510
|
+
|
|
1511
|
+
def _validate_remote_paths_exist(
|
|
1512
|
+
paths: List[str],
|
|
1513
|
+
username: str,
|
|
1514
|
+
hostname: str,
|
|
1515
|
+
socket: str | None,
|
|
1516
|
+
) -> None:
|
|
1517
|
+
"""Validate that all specified paths exist as directories on the remote host.
|
|
1518
|
+
|
|
1519
|
+
Args:
|
|
1520
|
+
paths: List of directory paths to validate.
|
|
1521
|
+
username: SSH username.
|
|
1522
|
+
hostname: SSH hostname.
|
|
1523
|
+
socket: control socket location or None
|
|
1524
|
+
|
|
1525
|
+
Raises:
|
|
1526
|
+
ValueError: If any paths do not exist as directories on the remote host.
|
|
1527
|
+
"""
|
|
1528
|
+
if not paths:
|
|
1529
|
+
return
|
|
1530
|
+
|
|
1531
|
+
# Remove duplicates while preserving order
|
|
1532
|
+
unique_paths = list(dict.fromkeys(paths))
|
|
1533
|
+
|
|
1534
|
+
# Build a single SSH command to check all paths at once
|
|
1535
|
+
test_commands = []
|
|
1536
|
+
for path in unique_paths:
|
|
1537
|
+
# Use test -d to check if directory exists
|
|
1538
|
+
# Escape single quotes in path using POSIX-safe method: ' becomes '"'"'
|
|
1539
|
+
escaped_path = path.replace("'", "'\"'\"'")
|
|
1540
|
+
test_commands.append(
|
|
1541
|
+
f"test -d '{escaped_path}' && echo 'EXISTS:{path}' || echo 'MISSING:{path}'"
|
|
1542
|
+
)
|
|
1543
|
+
|
|
1544
|
+
combined_command = " ; ".join(test_commands)
|
|
1545
|
+
|
|
1546
|
+
ssh_command = ["ssh"]
|
|
1547
|
+
if socket is not None:
|
|
1548
|
+
ssh_command.append(f"-S {socket}")
|
|
1549
|
+
ssh_command.append(f"{username}@{hostname}")
|
|
1550
|
+
ssh_command.append(combined_command)
|
|
1551
|
+
ssh_command = " ".join(ssh_command)
|
|
1552
|
+
|
|
1553
|
+
logger.info("Validating mount directories exist on remote host", cmd=ssh_command)
|
|
1554
|
+
completed_process = subprocess.run(
|
|
1555
|
+
args=shlex.split(ssh_command),
|
|
1556
|
+
stdout=subprocess.PIPE,
|
|
1557
|
+
stderr=subprocess.PIPE,
|
|
1558
|
+
)
|
|
1559
|
+
|
|
1560
|
+
if completed_process.returncode != 0:
|
|
1561
|
+
error_msg = (
|
|
1562
|
+
completed_process.stderr.decode("utf-8")
|
|
1563
|
+
if completed_process.stderr
|
|
1564
|
+
else "Unknown error"
|
|
1565
|
+
)
|
|
1566
|
+
logger.error(
|
|
1567
|
+
"Error validating remote paths",
|
|
1568
|
+
code=completed_process.returncode,
|
|
1569
|
+
msg=error_msg,
|
|
1570
|
+
)
|
|
1571
|
+
raise RuntimeError(f"Failed to validate remote paths: {error_msg}")
|
|
1572
|
+
|
|
1573
|
+
# Parse output to find missing paths
|
|
1574
|
+
output = completed_process.stdout.decode("utf-8")
|
|
1575
|
+
missing_paths = []
|
|
1576
|
+
for line in output.strip().split("\n"):
|
|
1577
|
+
if line.startswith("MISSING:"):
|
|
1578
|
+
missing_path = line.replace("MISSING:", "")
|
|
1579
|
+
missing_paths.append(missing_path)
|
|
1580
|
+
|
|
1581
|
+
if missing_paths:
|
|
1582
|
+
error_message = (
|
|
1583
|
+
f"The following mount paths do not exist as directories on {username}@{hostname}:\n"
|
|
1584
|
+
+ "\n".join(f" - {path}" for path in missing_paths)
|
|
1585
|
+
+ "\n\nMount paths must be directories. Please create these directories on the cluster or update your configuration."
|
|
1586
|
+
)
|
|
1587
|
+
logger.error("Mount validation failed", missing_paths=missing_paths)
|
|
1588
|
+
raise ValueError(error_message)
|