nemo-evaluator-launcher 0.1.0rc6__py3-none-any.whl → 0.1.41__py3-none-any.whl
- nemo_evaluator_launcher/__init__.py +15 -1
- nemo_evaluator_launcher/api/functional.py +188 -27
- nemo_evaluator_launcher/api/types.py +9 -0
- nemo_evaluator_launcher/cli/export.py +131 -12
- nemo_evaluator_launcher/cli/info.py +477 -82
- nemo_evaluator_launcher/cli/kill.py +5 -3
- nemo_evaluator_launcher/cli/logs.py +102 -0
- nemo_evaluator_launcher/cli/ls_runs.py +31 -10
- nemo_evaluator_launcher/cli/ls_tasks.py +105 -3
- nemo_evaluator_launcher/cli/main.py +101 -5
- nemo_evaluator_launcher/cli/run.py +153 -30
- nemo_evaluator_launcher/cli/status.py +49 -5
- nemo_evaluator_launcher/cli/version.py +26 -23
- nemo_evaluator_launcher/common/execdb.py +121 -27
- nemo_evaluator_launcher/common/helpers.py +213 -33
- nemo_evaluator_launcher/common/logging_utils.py +16 -5
- nemo_evaluator_launcher/common/printing_utils.py +100 -0
- nemo_evaluator_launcher/configs/deployment/generic.yaml +33 -0
- nemo_evaluator_launcher/configs/deployment/sglang.yaml +4 -2
- nemo_evaluator_launcher/configs/deployment/trtllm.yaml +23 -0
- nemo_evaluator_launcher/configs/deployment/vllm.yaml +2 -2
- nemo_evaluator_launcher/configs/execution/local.yaml +2 -0
- nemo_evaluator_launcher/configs/execution/slurm/default.yaml +19 -4
- nemo_evaluator_launcher/executors/base.py +54 -1
- nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +60 -5
- nemo_evaluator_launcher/executors/lepton/executor.py +240 -101
- nemo_evaluator_launcher/executors/lepton/job_helpers.py +15 -11
- nemo_evaluator_launcher/executors/local/executor.py +492 -56
- nemo_evaluator_launcher/executors/local/run.template.sh +76 -9
- nemo_evaluator_launcher/executors/slurm/executor.py +571 -98
- nemo_evaluator_launcher/executors/slurm/proxy.cfg.template +26 -0
- nemo_evaluator_launcher/exporters/base.py +9 -0
- nemo_evaluator_launcher/exporters/gsheets.py +27 -9
- nemo_evaluator_launcher/exporters/local.py +30 -16
- nemo_evaluator_launcher/exporters/mlflow.py +245 -74
- nemo_evaluator_launcher/exporters/utils.py +139 -184
- nemo_evaluator_launcher/exporters/wandb.py +157 -43
- nemo_evaluator_launcher/package_info.py +6 -3
- nemo_evaluator_launcher/resources/mapping.toml +56 -15
- nemo_evaluator_launcher-0.1.41.dist-info/METADATA +494 -0
- nemo_evaluator_launcher-0.1.41.dist-info/RECORD +62 -0
- {nemo_evaluator_launcher-0.1.0rc6.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/entry_points.txt +1 -0
- nemo_evaluator_launcher-0.1.0rc6.dist-info/METADATA +0 -35
- nemo_evaluator_launcher-0.1.0rc6.dist-info/RECORD +0 -57
- {nemo_evaluator_launcher-0.1.0rc6.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/WHEEL +0 -0
- {nemo_evaluator_launcher-0.1.0rc6.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/licenses/LICENSE +0 -0
- {nemo_evaluator_launcher-0.1.0rc6.dist-info → nemo_evaluator_launcher-0.1.41.dist-info}/top_level.txt +0 -0
nemo_evaluator_launcher/executors/slurm/executor.py

@@ -30,6 +30,7 @@ from pathlib import Path
 from typing import Dict, List, Optional

 import yaml
+from jinja2 import Environment, FileSystemLoader
 from omegaconf import DictConfig, OmegaConf

 from nemo_evaluator_launcher.common.execdb import (
@@ -39,17 +40,19 @@ from nemo_evaluator_launcher.common.execdb import (
     generate_job_id,
 )
 from nemo_evaluator_launcher.common.helpers import (
+    CmdAndReadableComment,
+    _str_to_echo_command,
     get_api_key_name,
-    get_endpoint_url,
     get_eval_factory_command,
     get_eval_factory_dataset_size_from_run_config,
-    get_health_url,
     get_timestamp_string,
 )
+from nemo_evaluator_launcher.common.logging_utils import logger
 from nemo_evaluator_launcher.common.mapping import (
     get_task_from_mapping,
     load_tasks_mapping,
 )
+from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey, red
 from nemo_evaluator_launcher.executors.base import (
     BaseExecutor,
     ExecutionState,
@@ -93,6 +96,7 @@ class SlurmExecutor(BaseExecutor):
         tasks_mapping = load_tasks_mapping()
         eval_images: list[str] = []

+        is_potentially_unsafe = False
         for idx, task in enumerate(cfg.evaluation.tasks):
             # calculate job_id
             job_id = f"{invocation_id}.{idx}"
@@ -113,7 +117,7 @@ class SlurmExecutor(BaseExecutor):
             eval_images.append(eval_image)

             # generate and write down sbatch script
-
+            sbatch_script_content_struct = _create_slurm_sbatch_script(
                 cfg=cfg,
                 task=task,
                 eval_image=eval_image,
@@ -121,6 +125,32 @@
                 invocation_id=invocation_id,
                 job_id=job_id,
             )
+
+            # Create proxy config file with placeholder IPs for multi-instance deployments
+            if cfg.deployment.get("multiple_instances", False):
+                proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+                if proxy_type == "haproxy":
+                    proxy_config = _generate_haproxy_config_with_placeholders(cfg)
+                else:
+                    raise ValueError(
+                        f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+                    )
+
+                # Save both template and working config
+                proxy_template_path = local_task_subdir / "proxy.cfg.template"
+                proxy_config_path = local_task_subdir / "proxy.cfg"
+                with open(proxy_template_path, "w") as f:
+                    f.write(proxy_config)
+                with open(proxy_config_path, "w") as f:
+                    f.write(proxy_config)
+
+            sbatch_script_content_str = sbatch_script_content_struct.cmd
+
+            # We accumulate if any task contains unsafe commands
+            is_potentially_unsafe = (
+                is_potentially_unsafe
+                or sbatch_script_content_struct.is_potentially_unsafe
+            )
             local_runsub_path = local_task_subdir / "run.sub"
             remote_runsub_path = remote_task_subdir / "run.sub"
             with open(local_runsub_path, "w") as f:
@@ -130,15 +160,41 @@
             remote_runsub_paths.append(remote_runsub_path)

         if dry_run:
-            print("\n\n=============================================\n\n")
-            print("DRY RUN: SLURM scripts prepared")
+            print(bold("\n\n=============================================\n\n"))
+            print(bold(cyan("DRY RUN: SLURM scripts prepared")))
             for idx, local_runsub_path in enumerate(local_runsub_paths):
-                print(f"\n\n
+                print(cyan(f"\n\n=========== Task {idx} =====================\n\n"))
                 with open(local_runsub_path, "r") as f:
-                    print(f.read())
-            print("
+                    print(grey(f.read()))
+            print(bold("To submit jobs") + ", run the executor without --dry-run")
+            if is_potentially_unsafe:
+                print(
+                    red(
+                        "\nFound `pre_cmd` (evaluation or deployment) which carries security risk. When running without --dry-run "
+                        "make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1"
+                    )
+                )
+
             return invocation_id

+        if is_potentially_unsafe:
+            if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
+                logger.warning(
+                    "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+                    "is set, proceeding with caution."
+                )
+
+            else:
+                logger.error(
+                    "Found non-empty commands (e.g. `pre_cmd` in evaluation or deployment) and NEMO_EVALUATOR_TRUST_PRE_CMD "
+                    "is not set. This might carry security risk and unstable environments. "
+                    "To continue, make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1.",
+                )
+                raise AttributeError(
+                    "Untrusted command found in config, make sure you trust and "
+                    "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
+                )
+
         socket = str(Path(tmpdirname) / "socket")
         socket_or_none = _open_master_connection(
             username=cfg.execution.username,
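
Note: the gate above aborts submission whenever any task or deployment carries a `pre_cmd`, unless the caller has opted in. A minimal sketch of the opt-in from the calling side, assuming you trust every `pre_cmd` in your config (the variable name comes from the diff; the snippet itself is illustrative):

    import os

    # Opt in before invoking the executor; without this, submission raises
    # AttributeError when a pre_cmd is present (see the hunk above).
    os.environ["NEMO_EVALUATOR_TRUST_PRE_CMD"] = "1"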
@@ -174,10 +230,11 @@
         for idx, (slurm_job_id, remote_runsub_path) in enumerate(
             zip(slurm_job_ids, remote_runsub_paths)
         ):
+            job_id = generate_job_id(invocation_id, idx)
             db.write_job(
                 job=JobData(
                     invocation_id=invocation_id,
-                    job_id=
+                    job_id=job_id,
                     timestamp=time.time(),
                     executor="slurm",
                     data={
@@ -204,8 +261,8 @@
         """
         db = ExecutionDB()

-        # If id looks like an invocation_id
-        if
+        # If id looks like an invocation_id
+        if "." not in id:
             jobs = db.get_jobs(id)
             if not jobs:
                 return []
@@ -388,7 +445,7 @@
         """Kill a SLURM job.

         Args:
-            job_id: The job ID to kill.
+            job_id: The job ID (e.g., abc123.0) to kill.
         """
         db = ExecutionDB()
         job_data = db.get_job(job_id)
@@ -401,26 +458,31 @@
                 f"Job {job_id} is not a slurm job (executor: {job_data.executor})"
             )

-
-
-        result = _kill_slurm_job(
+        # OPTIMIZATION: Query status AND kill in ONE SSH call
+        slurm_status, result = _kill_slurm_job(
             slurm_job_ids=[job_data.data.get("slurm_job_id")],
             username=job_data.data.get("username"),
             hostname=job_data.data.get("hostname"),
             socket=job_data.data.get("socket"),
         )

+        # Mark job as killed in database if kill succeeded
         if result.returncode == 0:
-            killed_something = True
-
-        # Mark job as killed in database if we killed something
-        if killed_something:
             job_data.data["killed"] = True
             db.write_job(job_data)
         else:
-
-
+            # Use the pre-fetched status for better error message
+            current_status = None
+            if slurm_status:
+                current_status = SlurmExecutor._map_slurm_state_to_execution_state(
+                    slurm_status
+                )
+            error_msg = SlurmExecutor.get_kill_failure_message(
+                job_id,
+                f"slurm_job_id: {job_data.data.get('slurm_job_id')}",
+                current_status,
             )
+            raise RuntimeError(error_msg)


 def _create_slurm_sbatch_script(
@@ -430,7 +492,7 @@ def _create_slurm_sbatch_script(
     remote_task_subdir: Path,
     invocation_id: str,
     job_id: str,
-) ->
+) -> CmdAndReadableComment:
     """Generate the contents of a SLURM sbatch script for a given evaluation task.

     Args:
@@ -446,7 +508,6 @@ def _create_slurm_sbatch_script(
     # get task from mapping, overrides, urls
     tasks_mapping = load_tasks_mapping()
     task_definition = get_task_from_mapping(task.name, tasks_mapping)
-    health_url = get_health_url(cfg, get_endpoint_url(cfg, task, task_definition))

     # TODO(public release): convert to template
     s = "#!/bin/bash\n"
@@ -461,6 +522,8 @@
     s += "#SBATCH --gpus-per-node {}\n".format(cfg.execution.gpus_per_node)
     if hasattr(cfg.execution, "gres"):
         s += "#SBATCH --gres {}\n".format(cfg.execution.gres)
+    if cfg.execution.get("sbatch_comment"):
+        s += "#SBATCH --comment='{}'\n".format(cfg.execution.sbatch_comment)
     job_name = "{account}-{subproject}.{details}".format(
         account=cfg.execution.account,
         subproject=cfg.execution.subproject,
@@ -470,6 +533,8 @@
     s += "#SBATCH --exclusive\n"
     s += "#SBATCH --output {}\n".format(remote_task_subdir / "logs" / "slurm-%A.out")
     s += "\n"
+    s += f'TASK_DIR="{str(remote_task_subdir)}"\n'
+    s += "\n"

     # collect all env vars
     env_vars = copy.deepcopy(dict(cfg.evaluation.get("env_vars", {})))
@@ -484,8 +549,11 @@
         if os.getenv(env_var) is None:
             raise ValueError(f"Trying to pass an unset environment variable {env_var}.")

-    # check if required env vars are defined:
+    # check if required env vars are defined (excluding NEMO_EVALUATOR_DATASET_DIR which is handled separately):
     for required_env_var in task_definition.get("required_env_vars", []):
+        # Skip NEMO_EVALUATOR_DATASET_DIR as it's handled by dataset mounting logic below
+        if required_env_var == "NEMO_EVALUATOR_DATASET_DIR":
+            continue
         if required_env_var not in env_vars.keys():
             raise ValueError(
                 f"{task.name} task requires environment variable {required_env_var}."
@@ -531,6 +599,7 @@

     # prepare deployment mounts
     deployment_mounts_list = []
+    deployment_is_unsafe = False
     if cfg.deployment.type != "none":
         if checkpoint_path := cfg.deployment.get("checkpoint_path"):
             deployment_mounts_list.append(f"{checkpoint_path}:/checkpoint:ro")
@@ -542,36 +611,33 @@
             deployment_mounts_list.append(f"{source_mnt}:{target_mnt}")

     # add deployment srun command
-
-
-
-    if deployment_mounts_list:
-        s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
-    if not cfg.execution.get("mounts", {}).get("mount_home", True):
-        s += "--no-container-mount-home "
-    s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A.out")
-    deployment_env_var_names = list(
-        cfg.execution.get("env_vars", {}).get("deployment", {})
-    )
-    if cfg.deployment.get("env_vars"):
-        warnings.warn(
-            "cfg.deployment.env_vars will be deprecated in future versions. "
-            "Use cfg.execution.env_vars.deployment instead.",
-            category=DeprecationWarning,
-            stacklevel=2,
+    deployment_srun_cmd, deployment_is_unsafe, deployment_debug = (
+        _generate_deployment_srun_command(
+            cfg, deployment_mounts_list, remote_task_subdir
         )
-        deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
-    if deployment_env_var_names:
-        s += f"--container-env {','.join(deployment_env_var_names)} "
-    s += "{} &\n\n".format(cfg.deployment.command)  # run asynchronously
-    s += (
-        "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
     )
+    s += deployment_srun_cmd

     # wait for the server to initialize
-
+    health_path = cfg.deployment.get("health_check_path", "/health")
+    # For multi-instance check all node IPs, for single instance check localhost
+    if cfg.deployment.get("multiple_instances", False):
+        ip_list = '"${NODES_IPS_ARRAY[@]}"'
+    else:
+        ip_list = '"127.0.0.1"'
+    s += _get_wait_for_server_handler(
+        ip_list,
+        cfg.deployment.port,
+        health_path,
+        "server",
+        check_pid=True,
+    )
     s += "\n\n"

+    # add proxy load balancer for multi-instance deployments
+    if cfg.deployment.get("multiple_instances", False):
+        s += _get_proxy_server_srun_command(cfg, remote_task_subdir)
+
     # prepare evaluation mounts
     evaluation_mounts_list = [
         "{}:/results".format(remote_task_subdir / "artifacts"),
@@ -581,9 +647,45 @@
     ):
         evaluation_mounts_list.append(f"{source_mnt}:{target_mnt}")

+    # Handle dataset directory mounting if NEMO_EVALUATOR_DATASET_DIR is required
+    if "NEMO_EVALUATOR_DATASET_DIR" in task_definition.get("required_env_vars", []):
+        # Get dataset directory from task config
+        if "dataset_dir" in task:
+            dataset_mount_host = task["dataset_dir"]
+        else:
+            raise ValueError(
+                f"{task.name} task requires a dataset_dir to be specified. "
+                f"Add 'dataset_dir: /path/to/your/dataset' under the task configuration."
+            )
+        # Get container mount path (default to /datasets if not specified)
+        dataset_mount_container = task.get("dataset_mount_path", "/datasets")
+        # Add dataset mount to evaluation mounts list
+        evaluation_mounts_list.append(f"{dataset_mount_host}:{dataset_mount_container}")
+        # Export NEMO_EVALUATOR_DATASET_DIR environment variable
+        s += f"export NEMO_EVALUATOR_DATASET_DIR={dataset_mount_container}\n\n"
+
+    eval_factory_command_struct = get_eval_factory_command(
+        cfg,
+        task,
+        task_definition,
+    )
+
+    eval_factory_command = eval_factory_command_struct.cmd
+    # The debug comment for placing into the script and easy debug. Reason
+    # (see `CmdAndReadableComment`) is the current way of passing the command
+    # is base64-encoded config `echo`-ed into file.
+    # TODO(agronskiy): cleaner way is to encode everything with base64, not
+    # some parts (like ef_config.yaml) and just output as logs somewhere.
+    eval_factory_command_debug_comment = eval_factory_command_struct.debug
+
     # add evaluation srun command
+    s += "# Debug contents of the eval factory command's config\n"
+    s += eval_factory_command_debug_comment
+    s += "\n\n"
+
     s += "# evaluation client\n"
     s += "srun --mpi pmix --overlap "
+    s += "--nodes 1 --ntasks 1 "  # Client always runs on single node
     s += "--container-image {} ".format(eval_image)
     evaluation_env_var_names = list(
         cfg.execution.get("env_vars", {}).get("evaluation", {})
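
Note: tasks whose mapping lists NEMO_EVALUATOR_DATASET_DIR under required_env_vars must now declare a dataset_dir; dataset_mount_path is optional and defaults to /datasets. A hedged sketch of such a task entry (the task name and host path are placeholders, not values from this diff):

    from omegaconf import OmegaConf

    # Illustrative task entry; "my_task" and the host path are hypothetical.
    task = OmegaConf.create(
        {
            "name": "my_task",
            "dataset_dir": "/path/to/your/dataset",  # host directory to mount
            "dataset_mount_path": "/datasets",  # optional; /datasets by default
        }
    )
    # Mirrors the mount string the hunk above appends to evaluation_mounts_list:
    print(f'{task["dataset_dir"]}:{task.get("dataset_mount_path", "/datasets")}')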
@@ -592,43 +694,139 @@
     s += "--container-env {} ".format(",".join(evaluation_env_var_names))
     if not cfg.execution.get("mounts", {}).get("mount_home", True):
         s += "--no-container-mount-home "
+
     s += "--container-mounts {} ".format(",".join(evaluation_mounts_list))
     s += "--output {} ".format(remote_task_subdir / "logs" / "client-%A.out")
-    s += "bash -c '"
-    s +=
+    s += "bash -c '\n"
+    s += eval_factory_command
     s += "'\n\n"

     # terminate the server after all evaluation clients finish
     if cfg.deployment.type != "none":
-        s += "kill $SERVER_PID # terminate the server to finish gracefully\n
+        s += "kill $SERVER_PID # terminate the server to finish gracefully\n"
+        if cfg.deployment.get("multiple_instances", False):
+            s += "kill $PROXY_PID # terminate proxy to finish gracefully\n"
+        s += "\n"

     # auto-export
-
-
+    ae_cfg = cfg.execution.get("auto_export")
+    destinations: list = []
+    if isinstance(ae_cfg, list):
+        destinations = list(ae_cfg)
+    elif isinstance(ae_cfg, dict) or isinstance(ae_cfg, DictConfig):
+        destinations = list(ae_cfg.get("destinations", []) or [])
+
+    if destinations:
+        export_env = dict(cfg.execution.get("env_vars", {}).get("export", {}) or {})
+        s += _generate_auto_export_section(
+            cfg, job_id, destinations, export_env, remote_task_subdir
+        )

-
+    debug_str = "\n".join(["# " + line for line in s.splitlines()])
+
+    # Combine unsafe flags from both deployment and evaluation
+    is_potentially_unsafe = (
+        eval_factory_command_struct.is_potentially_unsafe or deployment_is_unsafe
+    )
+
+    return CmdAndReadableComment(
+        cmd=s,
+        debug=debug_str,
+        is_potentially_unsafe=is_potentially_unsafe,
+    )


 def _generate_auto_export_section(
     cfg: DictConfig,
-    job_id: str,
+    job_id: str,
+    destinations: list,
+    export_env: dict,
+    remote_task_subdir: Path,
+    export_image: str = "python:3.12.7-slim",
 ) -> str:
     """Generate simple auto-export section for sbatch script."""
-    auto_export_config = cfg.execution.get("auto_export", {})
-    destinations = auto_export_config.get("destinations", [])
-
     if not destinations:
         return ""

-    s = "\n#
+    s = "\n# Auto-export on success\n"
     s += "EVAL_EXIT_CODE=$?\n"
     s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
     s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
+    s += f' cd "{remote_task_subdir}/artifacts"\n'

-    for
-
-
+    # Work with DictConfig; convert only for YAML at the end
+    exec_type = (
+        cfg.execution.type
+        if hasattr(cfg.execution, "type")
+        else cfg.execution.get("type", "slurm")
+    )
+    eval_tasks = (
+        list(cfg.evaluation.tasks)
+        if hasattr(cfg, "evaluation") and hasattr(cfg.evaluation, "tasks")
+        else list((cfg.get("evaluation", {}) or {}).get("tasks", []) or [])
+    )
+    export_block = cfg.get("export", {}) or {}
+
+    payload = {
+        "execution": {
+            "auto_export": {
+                "destinations": list(destinations),
+                **({"env_vars": dict(export_env)} if export_env else {}),
+            },
+            "type": exec_type,
+        },
+        "evaluation": {"tasks": eval_tasks},
+    }
+    if export_block:
+        # Convert just this block to plain for YAML
+        payload["export"] = (
+            OmegaConf.to_object(export_block)
+            if OmegaConf.is_config(export_block)
+            else dict(export_block)
+        )

+    # Final YAML (single conversion at the end)
+    payload_clean = OmegaConf.to_container(OmegaConf.create(payload), resolve=True)
+    yaml_str = yaml.safe_dump(payload_clean, sort_keys=False)
+    s += " cat > export_config.yml << 'EOF'\n"
+    s += yaml_str
+    s += "EOF\n"
+
+    # write launcher config as config.yml for exporters (no core command)
+    submitted_yaml = yaml.safe_dump(
+        OmegaConf.to_container(cfg, resolve=True), sort_keys=False
+    )
+    s += " cat > config.yml << 'EOF'\n"
+    s += submitted_yaml
+    s += "EOF\n"
+
+    # Export host only env before running auto export
+    for k, v in (export_env or {}).items():
+        if isinstance(v, str) and re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", v):
+            s += f' export {k}="${{{v}}}"\n'
+        else:
+            esc = str(v).replace('"', '\\"')
+            s += f' export {k}="{esc}"\n'
+
+    s += " # export\n"
+    s += " srun --mpi pmix --overlap "
+    s += "--nodes 1 --ntasks 1 "  # Client always runs on single node
+    s += "--container-image {} ".format(export_image)
+    if export_env:
+        s += "--container-env {} ".format(",".join(export_env))
+    if not cfg.execution.get("mounts", {}).get("mount_home", True):
+        s += "--no-container-mount-home "
+
+    s += f"--container-mounts {remote_task_subdir}/artifacts:{remote_task_subdir}/artifacts "
+    s += "--output {} ".format(remote_task_subdir / "logs" / "export-%A.out")
+    s += " bash -c '\n"
+    # FIXME(martas): would be good to install specific version
+    s += " pip install nemo-evaluator-launcher[all]\n"
+    s += f" cd {remote_task_subdir}/artifacts\n"
+    for dest in destinations:
+        s += f' echo "Exporting to {dest}..."\n'
+        s += f' nemo-evaluator-launcher export {job_id} --dest {dest} || echo "Export to {dest} failed"\n'
+    s += "'\n"
     s += " echo 'Auto-export completed.'\n"
     s += "else\n"
     s += " echo 'Evaluation failed with exit code $EVAL_EXIT_CODE. Skipping auto-export.'\n"
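
Note: the generated script heredocs an export_config.yml whose shape mirrors the `payload` dict above. A minimal sketch of that shape, assuming hypothetical destinations (the wandb and mlflow exporters do ship with the package per the file list; the env var mapping is illustrative):

    import yaml

    # Illustrative payload; values are placeholders, the structure follows the
    # `payload` dict assembled in _generate_auto_export_section above.
    payload = {
        "execution": {
            "auto_export": {
                "destinations": ["wandb", "mlflow"],
                "env_vars": {"WANDB_API_KEY": "WANDB_API_KEY"},
            },
            "type": "slurm",
        },
        "evaluation": {"tasks": [{"name": "my_task"}]},
    }
    print(yaml.safe_dump(payload, sort_keys=False))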
@@ -643,9 +841,12 @@ def _open_master_connection(
     socket: str,
 ) -> str | None:
     ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
+    logger.info("Opening master connection", cmd=ssh_command)
     completed_process = subprocess.run(args=shlex.split(ssh_command))
     if completed_process.returncode == 0:
+        logger.info("Opened master connection successfully", cmd=ssh_command)
         return socket
+    logger.error("Failed to open master connection", code=completed_process.returncode)
     return None

@@ -657,9 +858,7 @@ def _close_master_connection(
     if socket is None:
         return
     ssh_command = f"ssh -O exit -S {socket} {username}@{hostname}"
-    completed_process = subprocess.run(
-        args=shlex.split(ssh_command), capture_output=True
-    )
+    completed_process = subprocess.run(args=shlex.split(ssh_command))
     if completed_process.returncode != 0:
         raise RuntimeError(
             "failed to close the master connection\n{}".format(
@@ -681,12 +880,23 @@ def _make_remote_execution_output_dir(
     ssh_command.append(f"{username}@{hostname}")
     ssh_command.append(mkdir_command)
     ssh_command = " ".join(ssh_command)
-
+    logger.info("Creating remote dir", cmd=ssh_command)
+    completed_process = subprocess.run(
+        args=shlex.split(ssh_command), stderr=subprocess.PIPE
+    )
     if completed_process.returncode != 0:
+        error_msg = (
+            completed_process.stderr.decode("utf-8")
+            if completed_process.stderr
+            else "Unknown error"
+        )
+        logger.error(
+            "Error creating remote dir",
+            code=completed_process.returncode,
+            msg=error_msg,
+        )
         raise RuntimeError(
-            "failed to make a remote execution output dir\n{}".format(
-                completed_process.stderr.decode("utf-8")
-            )
+            "failed to make a remote execution output dir\n{}".format(error_msg)
         )

@@ -712,14 +922,25 @@ def _rsync_upload_rundirs(
     remote_destination_str = f"{username}@{hostname}:{remote_target}"
     local_sources_str = " ".join(map(str, local_sources))
     rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
-
+    logger.info("Rsyncing to remote dir", cmd=rsync_upload_command)
+    completed_process = subprocess.run(
+        args=shlex.split(rsync_upload_command),
+        stderr=subprocess.PIPE,
+    )
     if completed_process.returncode != 0:
-
-        "
-
-
+        error_msg = (
+            completed_process.stderr.decode("utf-8")
+            if completed_process.stderr
+            else "Unknown error"
         )

+        logger.error(
+            "Error rsyncing to remote dir",
+            code=completed_process.returncode,
+            msg=error_msg,
+        )
+        raise RuntimeError("failed to upload local sources\n{}".format(error_msg))
+

 def _sbatch_remote_runsubs(
     remote_runsub_paths: List[Path],
@@ -739,19 +960,22 @@ def _sbatch_remote_runsubs(
     ssh_command.append(f"{username}@{hostname}")
     ssh_command.append(sbatch_commands)
     ssh_command = " ".join(ssh_command)
-
+    logger.info("Running sbatch", cmd=ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command),
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
+        error_msg = completed_process.stderr.decode("utf-8")
         raise RuntimeError(
-            "failed to submit sbatch scripts for execution\n{}".format(
-                completed_process.stderr.decode("utf-8")
-            )
+            "failed to submit sbatch scripts for execution\n{}".format(error_msg)
         )

     sbatch_output = completed_process.stdout.decode("utf-8")
     slurm_job_ids = re.findall(r"(?<=Submitted batch job )\d+", sbatch_output)
+    logger.info("Started sbatch successfully", slurm_job_ids=slurm_job_ids)
     return slurm_job_ids

@@ -784,7 +1008,10 @@ def _query_slurm_jobs_status(
     ssh_command.append(sacct_command)
     ssh_command = " ".join(ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command),
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         raise RuntimeError(
@@ -803,34 +1030,50 @@

 def _kill_slurm_job(
     slurm_job_ids: List[str], username: str, hostname: str, socket: str | None
-) -> None:
-    """Kill a SLURM job.
+) -> tuple[str | None, subprocess.CompletedProcess]:
+    """Kill a SLURM job, querying status first in one SSH call for efficiency.

     Args:
         slurm_job_ids: List of SLURM job IDs to kill.
         username: SSH username.
         hostname: SSH hostname.
         socket: control socket location or None
+
+    Returns:
+        Tuple of (status_string, completed_process) where status_string is the SLURM status or None
     """
     if len(slurm_job_ids) == 0:
-        return
-
+        return None, subprocess.CompletedProcess(args=[], returncode=0)
+
+    jobs_str = ",".join(slurm_job_ids)
+    # Combine both commands in one SSH call: query THEN kill
+    combined_command = (
+        f"sacct -j {jobs_str} --format='JobID,State%32' --noheader -P 2>/dev/null; "
+        f"scancel {jobs_str}"
+    )
+
     ssh_command = ["ssh"]
     if socket is not None:
         ssh_command.append(f"-S {socket}")
     ssh_command.append(f"{username}@{hostname}")
-    ssh_command.append(
+    ssh_command.append(combined_command)
     ssh_command = " ".join(ssh_command)
+
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command),
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
-
-
-
-
-
-
-
+
+    # Parse the sacct output (before scancel runs)
+    sacct_output = completed_process.stdout.decode("utf-8")
+    sacct_output_lines = sacct_output.strip().split("\n")
+    slurm_status = None
+    if sacct_output_lines and len(slurm_job_ids) == 1:
+        slurm_status = _parse_slurm_job_status(slurm_job_ids[0], sacct_output_lines)
+
+    return slurm_status, completed_process


 def _parse_slurm_job_status(slurm_job_id: str, sacct_output_lines: List[str]) -> str:
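
Note: `sacct -P --noheader` emits pipe-separated rows such as `123456|RUNNING`, plus step rows like `123456.batch|RUNNING`. `_parse_slurm_job_status` itself is not shown in this diff; the sketch below only illustrates parsing that output format, under that assumption:

    # Hedged sketch: pick the State column for the parent job id from
    # `sacct -P --noheader --format='JobID,State%32'` rows.
    def parse_state(slurm_job_id: str, lines: list[str]) -> str | None:
        for line in lines:
            parts = line.split("|")
            if len(parts) >= 2 and parts[0] == slurm_job_id:
                return parts[1].strip()
        return None

    assert parse_state("123456", ["123456|RUNNING", "123456.batch|RUNNING"]) == "RUNNING"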
@@ -898,7 +1141,10 @@ def _read_files_from_remote(
     ssh_command.append(cat_commands)
     ssh_command = " ".join(ssh_command)
     completed_process = subprocess.run(
-        args=shlex.split(ssh_command),
+        args=shlex.split(ssh_command),
+        # NOTE(agronskiy): look out for hangs and deadlocks
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
     )
     if completed_process.returncode != 0:
         raise RuntimeError(
@@ -975,9 +1221,236 @@ sbatch --dependency=afternotok:$SLURM_JOB_ID $_this_script $SLURM_JOB_ID
 """.strip()


-
-
-#
-
+def _generate_haproxy_config_with_placeholders(cfg):
+    """Generate HAProxy configuration with placeholder IPs using Jinja template."""
+    # Set up Jinja environment
+    template_dir = Path(__file__).parent
+    template_path = template_dir / "proxy.cfg.template"
+
+    if not template_path.exists():
+        raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+    env = Environment(loader=FileSystemLoader(template_dir))
+    template = env.get_template("proxy.cfg.template")
+
+    # Prepare template data with placeholder IPs - use actual number of nodes
+    num_nodes = cfg.execution.num_nodes
+    nodes = []
+    for i in range(num_nodes):
+        nodes.append({"ip": f"{{IP_{i}}}", "port": cfg.deployment.port})
+
+    # Get health check parameters from execution config
+    proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+    health_check_path = proxy_config.get("health_check_path", "/health")
+    health_check_status = proxy_config.get("health_check_status", 200)
+    haproxy_port = proxy_config.get("haproxy_port", 5009)
+
+    # Render template
+    config = template.render(
+        haproxy_port=haproxy_port,
+        health_check_path=health_check_path,
+        health_check_status=health_check_status,
+        nodes=nodes,
+    )
+
+    return config
+
+
+def _generate_haproxy_config(cfg, nodes_ips):
+    """Generate HAProxy configuration using Jinja template."""
+    # Set up Jinja environment
+    template_dir = Path(__file__).parent
+    template_path = template_dir / "proxy.cfg.template"
+
+    if not template_path.exists():
+        raise FileNotFoundError(f"Proxy template not found: {template_path}")
+
+    env = Environment(loader=FileSystemLoader(template_dir))
+    template = env.get_template("proxy.cfg.template")
+
+    # Prepare template data
+    nodes = []
+    for i, ip in enumerate(nodes_ips, 1):
+        nodes.append(
+            {"ip": ip, "port": cfg.deployment.port}  # All nodes use the same port
+        )
+
+    # Get health check parameters from deployment config
+    health_check_path = cfg.deployment.get("health_check_path", "/health")
+    health_check_status = cfg.deployment.get("health_check_status", 200)
+    haproxy_port = cfg.deployment.get("haproxy_port", 5009)
+
+    # Render template
+    config = template.render(
+        haproxy_port=haproxy_port,
+        health_check_path=health_check_path,
+        health_check_status=health_check_status,
+        nodes=nodes,
+    )
+
+    return config
+
+
+def _generate_deployment_srun_command(
+    cfg, deployment_mounts_list, remote_task_subdir, instance_id: int = 0
+):
+    """Generate the deployment srun command with proper node/ntask configuration.
+
+    Returns:
+        tuple: (script_string, is_potentially_unsafe, debug_comment)
+    """
+    s = ""
+    debug_comment = ""
+    is_potentially_unsafe = False
+
+    s += "# deployment server\n"
+
+    # Extract pre_cmd for later use inside container
+    pre_cmd: str = cfg.deployment.get("pre_cmd") or ""
+    if pre_cmd:
+        is_potentially_unsafe = True
+        create_pre_script_cmd = _str_to_echo_command(
+            pre_cmd, filename="deployment_pre_cmd.sh"
+        )
+        debug_comment += create_pre_script_cmd.debug + "\n\n"
+
+    s += "# Get node IPs\n"
+    s += "nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )\n"
+    s += 'nodes_array=("${nodes[@]}") # Ensure nodes are stored properly\n'
+    s += 'export NODES_IPS_ARRAY=($(for node in "${nodes_array[@]}"; do srun --nodelist=$node --ntasks=1 --nodes=1 hostname --ip-address; done))\n'
+    s += 'echo "Node IPs: ${NODES_IPS_ARRAY[@]}"\n'
+    s += "# Export MASTER_IP as the first node IP\n"
+    s += "export MASTER_IP=${NODES_IPS_ARRAY[0]}\n"
+    s += 'echo "MASTER_IP: $MASTER_IP"\n'
+
+    # Add debug comment for deployment pre_cmd before srun command
+    if debug_comment:
+        s += "# Debug contents of deployment pre_cmd\n"
+        s += debug_comment
+        s += "\n"
+
+    s += "srun --mpi pmix --overlap "
+    s += f"--nodes {cfg.execution.num_nodes} --ntasks {cfg.execution.get('deployment', {}).get('n_tasks', 1)} "
+    s += "--container-image {} ".format(cfg.deployment.image)
+    if deployment_mounts_list:
+        s += "--container-mounts {} ".format(",".join(deployment_mounts_list))
+    if not cfg.execution.get("mounts", {}).get("mount_home", True):
+        s += "--no-container-mount-home "
+    s += "--output {} ".format(remote_task_subdir / "logs" / "server-%A-%t.out")
+
+    deployment_env_var_names = list(
+        cfg.execution.get("env_vars", {}).get("deployment", {})
+    )
+    if cfg.deployment.get("env_vars"):
+        warnings.warn(
+            "cfg.deployment.env_vars will be deprecated in future versions. "
+            "Use cfg.execution.env_vars.deployment instead.",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )
+        deployment_env_var_names.extend(list(cfg.deployment["env_vars"]))
+
+    # Always add MASTER_IP to the environment variables
+    if "MASTER_IP" not in deployment_env_var_names:
+        deployment_env_var_names.append("MASTER_IP")
+
+    if deployment_env_var_names:
+        s += f"--container-env {','.join(deployment_env_var_names)} "
+
+    # Wrap deployment command to execute pre_cmd inside container if needed
+    if pre_cmd:
+        # Create a wrapper command that runs inside the container:
+        # 1. Create deployment_pre_cmd.sh file
+        # 2. Source it
+        # 3. Execute the original deployment command
+        create_pre_script_cmd = _str_to_echo_command(
+            pre_cmd, filename="deployment_pre_cmd.sh"
+        )
+        # Escape single quotes in the deployment command for bash -c
+        escaped_deployment_cmd = cfg.deployment.command.replace("'", "'\"'\"'")
+        wrapped_command = (
+            f"bash -c '{create_pre_script_cmd.cmd} && "
+            f"source deployment_pre_cmd.sh && "
+            f"{escaped_deployment_cmd}'"
+        )
+        s += "{} &\n\n".format(wrapped_command)
+    else:
+        s += "{} &\n\n".format(cfg.deployment.command)  # run asynchronously
+
+    s += "SERVER_PID=$! # capture the PID of the server background srun process\n\n"
+
+    return s, is_potentially_unsafe, debug_comment
+
+
+def _get_wait_for_server_handler(
+    ip_list: str,
+    port: int,
+    health_check_path: str,
+    service_name: str = "server",
+    check_pid: bool = False,
+):
+    """Generate wait for server handler that takes a list of IPs."""
+    pid_check = ""
+    if check_pid:
+        pid_check = 'kill -0 "$SERVER_PID" 2>/dev/null || { echo "Server process $SERVER_PID died"; exit 1; }'
+
+    handler = f"""date
+# wait for the {service_name} to initialize
+for ip in {ip_list}; do
+    echo "Waiting for {service_name} on $ip..."
+    while [[ "$(curl -s -o /dev/null -w "%{{http_code}}" http://$ip:{port}{health_check_path})" != "200" ]]; do
+        {pid_check}
+        sleep 5
+    done
+    echo "{service_name} ready on $ip!"
+done
 date
 """.strip()
+
+    return handler
+
+
+def _get_proxy_server_srun_command(cfg, remote_task_subdir):
+    """Generate proxy server srun command based on proxy type."""
+    proxy_type = cfg.execution.get("proxy", {}).get("type", "haproxy")
+
+    if proxy_type == "haproxy":
+        return _generate_haproxy_srun_command(cfg, remote_task_subdir)
+    else:
+        raise ValueError(
+            f"Unsupported proxy type: {proxy_type}. Currently only 'haproxy' is supported."
+        )
+
+
+def _generate_haproxy_srun_command(cfg, remote_task_subdir):
+    """Generate HAProxy-specific srun command using template-based config."""
+    s = ""
+    s += "# Proxy load balancer\n"
+    s += "# Copy template to config file (important for restarts)\n"
+    s += f"cp {remote_task_subdir}/proxy.cfg.template {remote_task_subdir}/proxy.cfg\n"
+    s += "# Replace placeholder IPs with actual node IPs\n"
+    s += f"proxy_config_file={remote_task_subdir}/proxy.cfg\n"
+    s += 'for i in "${!NODES_IPS_ARRAY[@]}"; do\n'
+    s += '    ip="${NODES_IPS_ARRAY[$i]}"\n'
+    s += '    sed -i "s/{IP_$i}/$ip/g" "$proxy_config_file"\n'
+    s += "done\n"
+    s += "\n"
+    s += "srun --mpi pmix --overlap "
+    s += "--nodes 1 --ntasks 1 "
+    s += f"--container-image {cfg.execution.get('proxy', {}).get('image', 'haproxy:latest')} "
+    s += f"--container-mounts {remote_task_subdir}/proxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro "
+    s += f"--output {remote_task_subdir}/logs/proxy-%A.out "
+    s += "haproxy -f /usr/local/etc/haproxy/haproxy.cfg &\n"
+    s += "PROXY_PID=$! # capture the PID of the proxy background srun process\n"
+    s += 'echo "Proxy started with PID: $PROXY_PID"\n\n'
+
+    # Wait for proxy to be ready on localhost
+    proxy_config = cfg.execution.get("proxy", {}).get("config", {})
+    haproxy_port = proxy_config.get("haproxy_port", 5009)
+    health_path = proxy_config.get("health_check_path", "/health")
+    s += _get_wait_for_server_handler(
+        "127.0.0.1", haproxy_port, health_path, "Proxy", check_pid=False
+    )
+    s += "\n"
+
+    return s
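
Note: the packaged proxy.cfg.template (new in this release, +26 lines per the file list) is not reproduced in this diff. The sketch below is a stand-in that exercises the same variables the renderers pass (haproxy_port, health_check_path, health_check_status, nodes) and the {IP_i} placeholders that the generated sbatch script later rewrites with sed; the HAProxy directives are illustrative, not the shipped template:

    from jinja2 import Template

    # Stand-in template, NOT the packaged proxy.cfg.template.
    template = Template(
        "frontend llm_frontend\n"
        "    bind *:{{ haproxy_port }}\n"
        "    default_backend llm_backend\n"
        "\n"
        "backend llm_backend\n"
        "    option httpchk GET {{ health_check_path }}\n"
        "    http-check expect status {{ health_check_status }}\n"
        "{% for node in nodes %}"
        "    server node{{ loop.index0 }} {{ node.ip }}:{{ node.port }} check\n"
        "{% endfor %}"
    )
    print(
        template.render(
            haproxy_port=5009,
            health_check_path="/health",
            health_check_status=200,
            nodes=[{"ip": "{IP_0}", "port": 8000}, {"ip": "{IP_1}", "port": 8000}],
        )
    )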