nemo-evaluator-launcher 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nemo-evaluator-launcher might be problematic. Click here for more details.
- nemo_evaluator_launcher/__init__.py +79 -0
- nemo_evaluator_launcher/api/__init__.py +24 -0
- nemo_evaluator_launcher/api/functional.py +698 -0
- nemo_evaluator_launcher/api/types.py +98 -0
- nemo_evaluator_launcher/api/utils.py +19 -0
- nemo_evaluator_launcher/cli/__init__.py +15 -0
- nemo_evaluator_launcher/cli/export.py +267 -0
- nemo_evaluator_launcher/cli/info.py +512 -0
- nemo_evaluator_launcher/cli/kill.py +41 -0
- nemo_evaluator_launcher/cli/ls_runs.py +134 -0
- nemo_evaluator_launcher/cli/ls_tasks.py +136 -0
- nemo_evaluator_launcher/cli/main.py +226 -0
- nemo_evaluator_launcher/cli/run.py +200 -0
- nemo_evaluator_launcher/cli/status.py +164 -0
- nemo_evaluator_launcher/cli/version.py +55 -0
- nemo_evaluator_launcher/common/__init__.py +16 -0
- nemo_evaluator_launcher/common/execdb.py +283 -0
- nemo_evaluator_launcher/common/helpers.py +366 -0
- nemo_evaluator_launcher/common/logging_utils.py +357 -0
- nemo_evaluator_launcher/common/mapping.py +295 -0
- nemo_evaluator_launcher/common/printing_utils.py +93 -0
- nemo_evaluator_launcher/configs/__init__.py +15 -0
- nemo_evaluator_launcher/configs/default.yaml +28 -0
- nemo_evaluator_launcher/configs/deployment/generic.yaml +33 -0
- nemo_evaluator_launcher/configs/deployment/nim.yaml +32 -0
- nemo_evaluator_launcher/configs/deployment/none.yaml +16 -0
- nemo_evaluator_launcher/configs/deployment/sglang.yaml +38 -0
- nemo_evaluator_launcher/configs/deployment/trtllm.yaml +24 -0
- nemo_evaluator_launcher/configs/deployment/vllm.yaml +42 -0
- nemo_evaluator_launcher/configs/execution/lepton/default.yaml +92 -0
- nemo_evaluator_launcher/configs/execution/local.yaml +19 -0
- nemo_evaluator_launcher/configs/execution/slurm/default.yaml +34 -0
- nemo_evaluator_launcher/executors/__init__.py +22 -0
- nemo_evaluator_launcher/executors/base.py +120 -0
- nemo_evaluator_launcher/executors/lepton/__init__.py +16 -0
- nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +609 -0
- nemo_evaluator_launcher/executors/lepton/executor.py +1004 -0
- nemo_evaluator_launcher/executors/lepton/job_helpers.py +398 -0
- nemo_evaluator_launcher/executors/local/__init__.py +15 -0
- nemo_evaluator_launcher/executors/local/executor.py +605 -0
- nemo_evaluator_launcher/executors/local/run.template.sh +103 -0
- nemo_evaluator_launcher/executors/registry.py +38 -0
- nemo_evaluator_launcher/executors/slurm/__init__.py +15 -0
- nemo_evaluator_launcher/executors/slurm/executor.py +1147 -0
- nemo_evaluator_launcher/exporters/__init__.py +36 -0
- nemo_evaluator_launcher/exporters/base.py +121 -0
- nemo_evaluator_launcher/exporters/gsheets.py +409 -0
- nemo_evaluator_launcher/exporters/local.py +502 -0
- nemo_evaluator_launcher/exporters/mlflow.py +619 -0
- nemo_evaluator_launcher/exporters/registry.py +40 -0
- nemo_evaluator_launcher/exporters/utils.py +624 -0
- nemo_evaluator_launcher/exporters/wandb.py +490 -0
- nemo_evaluator_launcher/package_info.py +38 -0
- nemo_evaluator_launcher/resources/mapping.toml +380 -0
- nemo_evaluator_launcher-0.1.28.dist-info/METADATA +494 -0
- nemo_evaluator_launcher-0.1.28.dist-info/RECORD +60 -0
- nemo_evaluator_launcher-0.1.28.dist-info/WHEEL +5 -0
- nemo_evaluator_launcher-0.1.28.dist-info/entry_points.txt +3 -0
- nemo_evaluator_launcher-0.1.28.dist-info/licenses/LICENSE +451 -0
- nemo_evaluator_launcher-0.1.28.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,624 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
#
|
|
16
|
+
"""Shared utilities for metrics and configuration handling."""
|
|
17
|
+
|
|
18
|
+
import json
|
|
19
|
+
import re
|
|
20
|
+
import subprocess
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Any, Callable, Dict, List, Tuple
|
|
23
|
+
|
|
24
|
+
import yaml
|
|
25
|
+
|
|
26
|
+
from nemo_evaluator_launcher.common.execdb import JobData
|
|
27
|
+
from nemo_evaluator_launcher.common.logging_utils import logger
|
|
28
|
+
from nemo_evaluator_launcher.common.mapping import (
|
|
29
|
+
get_task_from_mapping,
|
|
30
|
+
load_tasks_mapping,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# =============================================================================
|
|
34
|
+
# ARTIFACTS
|
|
35
|
+
# =============================================================================
|
|
36
|
+
|
|
37
|
+
# Artifacts to be logged by default
REQUIRED_ARTIFACTS = ["results.yml", "eval_factory_metrics.json"]
OPTIONAL_ARTIFACTS = ["omni-info.json"]


def get_relevant_artifacts() -> List[str]:
    """Return every artifact filename of interest (required first, then optional)."""
    return [*REQUIRED_ARTIFACTS, *OPTIONAL_ARTIFACTS]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def validate_artifacts(artifacts_dir: Path) -> Dict[str, Any]:
    """Report which required/optional artifacts exist under ``artifacts_dir``.

    Returns a dict with ``can_export`` (bool), ``missing_required`` and
    ``missing_optional`` filename lists, and a human-readable ``message``.
    """
    # No directory at all: nothing can be exported.
    if not artifacts_dir or not artifacts_dir.exists():
        return {
            "can_export": False,
            "missing_required": list(REQUIRED_ARTIFACTS),
            "missing_optional": list(OPTIONAL_ARTIFACTS),
            "message": "Artifacts directory not found",
        }

    def _missing(names: List[str]) -> List[str]:
        # Files from `names` that are absent from the directory.
        return [name for name in names if not (artifacts_dir / name).exists()]

    missing_required = _missing(REQUIRED_ARTIFACTS)
    missing_optional = _missing(OPTIONAL_ARTIFACTS)

    parts = []
    if missing_required:
        parts.append(f"Missing required: {', '.join(missing_required)}")
    if missing_optional:
        parts.append(f"Missing optional: {', '.join(missing_optional)}")

    return {
        "can_export": not missing_required,
        "missing_required": missing_required,
        "missing_optional": missing_optional,
        "message": ". ".join(parts) if parts else "All artifacts available",
    }
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_available_artifacts(artifacts_dir: Path) -> List[str]:
    """Return the relevant artifact filenames that actually exist on disk."""
    if not artifacts_dir or not artifacts_dir.exists():
        return []
    present: List[str] = []
    for name in get_relevant_artifacts():
        if (artifacts_dir / name).exists():
            present.append(name)
    return present
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# =============================================================================
|
|
93
|
+
# METRICS EXTRACTION
|
|
94
|
+
# =============================================================================
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Distinct from a benign duplicate: re-setting a key with an identical value is
# only warned about (see _safe_set_metric); a different value raises this error.
class MetricConflictError(Exception):
    """Raised when attempting to set the same metric key with a different value."""
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def extract_accuracy_metrics(
    job_data: JobData,
    get_job_paths_func: Callable,
    log_metrics: List[str] | None = None,
) -> Dict[str, float]:
    """Extract accuracy metrics from a job's result artifacts.

    Args:
        job_data: Job record whose artifacts should be inspected.
        get_job_paths_func: Callable mapping a JobData to its paths dict.
        log_metrics: Optional list of substrings; when given, only metrics
            whose lowercased name contains one of them are returned.

    Returns:
        Mapping of metric name to value; empty on any failure (errors are logged).
    """
    try:
        paths = get_job_paths_func(job_data)
        artifacts_dir = _get_artifacts_dir(paths)

        if not artifacts_dir or not artifacts_dir.exists():
            logger.warning(f"Artifacts directory not found for job {job_data.job_id}")
            return {}

        # Prefer results.yml, but also merge JSON metrics to avoid missing values
        metrics: Dict[str, float] = {}
        results_yml = artifacts_dir / "results.yml"
        if results_yml.exists():
            yml_metrics = _extract_from_results_yml(results_yml)
            if yml_metrics:
                metrics.update(yml_metrics)

        # Merge in JSON metrics (handles tasks that only emit JSON or extra fields);
        # setdefault keeps the results.yml value on any key overlap.
        json_metrics = _extract_from_json_files(artifacts_dir)
        for k, v in json_metrics.items():
            metrics.setdefault(k, v)

        # Filter metrics if a substring filter was provided.
        if log_metrics:
            return {
                name: value
                for name, value in metrics.items()
                if any(filter_key in name.lower() for filter_key in log_metrics)
            }

        return metrics

    except Exception as e:
        # Boundary handler: metric extraction is best-effort, so log the
        # failure and return an empty mapping rather than failing the export.
        logger.error(f"Failed to extract metrics for job {job_data.job_id}: {e}")
        return {}
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# =============================================================================
|
|
142
|
+
# CONFIG EXTRACTION
|
|
143
|
+
# =============================================================================
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def extract_exporter_config(
    job_data: JobData,
    exporter_name: str,
    constructor_config: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
    """Extract and merge exporter configuration from multiple sources.

    Precedence (lowest to highest): the run config's ``export.<exporter_name>``
    block, webhook metadata stored on the job, then ``constructor_config``
    (CLI overrides).
    """
    config: Dict[str, Any] = {}

    # root-level `export.<exporter-name>`
    if job_data.config:
        export_block = (job_data.config or {}).get("export", {})
        yaml_config = (export_block or {}).get(exporter_name, {})
        if yaml_config:
            config.update(yaml_config)

    # From webhook metadata (if triggered by webhook)
    if "webhook_metadata" in job_data.data:
        webhook_data = job_data.data["webhook_metadata"]
        webhook_config = {
            "triggered_by_webhook": True,
            "webhook_source": webhook_data.get("webhook_source", "unknown"),
            "source_artifact": f"{webhook_data.get('artifact_name', 'unknown')}:{webhook_data.get('artifact_version', 'unknown')}",
            "config_source": webhook_data.get("config_file", "unknown"),
        }
        # wandb-triggered runs may also carry the originating run coordinates.
        if exporter_name == "wandb" and webhook_data.get("webhook_source") == "wandb":
            wandb_specific = {
                "entity": webhook_data.get("entity"),
                "project": webhook_data.get("project"),
                "run_id": webhook_data.get("run_id"),
            }
            webhook_config.update({k: v for k, v in wandb_specific.items() if v})
        config.update(webhook_config)

    # allows CLI overrides
    if constructor_config:
        config.update(constructor_config)

    return config
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
# =============================================================================
|
|
185
|
+
# JOB DATA EXTRACTION
|
|
186
|
+
# =============================================================================
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def get_task_name(job_data: JobData) -> str:
    """Resolve a task name for the job.

    A ``job_id`` of the form ``<invocation>.<idx>`` indexes into the
    configured task list; any other id means the job covers all tasks.
    """
    if "." not in job_data.job_id:
        return "all_tasks"
    try:
        task_index = int(job_data.job_id.rsplit(".", 1)[-1])
        return job_data.config["evaluation"]["tasks"][task_index]["name"]
    except Exception:
        # Malformed index or missing config entry: fall back to a unique label.
        return f"job_{job_data.job_id}"
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def get_model_name(job_data: JobData, config: Dict[str, Any] | None = None) -> str:
    """Extract the model name from config or job data.

    Resolution order: explicit ``model_name`` in ``config``, then the target
    endpoint's ``model_id``, the deployment's ``served_model_name``, and
    finally model fields stored directly on the job data.
    """
    if config and "model_name" in config:
        return config["model_name"]

    job_config = job_data.config or {}
    # Candidate sources, highest priority first.
    model_sources = [
        job_config.get("target", {}).get("api_endpoint", {}).get("model_id"),
        job_config.get("deployment", {}).get("served_model_name"),
        job_data.data.get("served_model_name"),
        job_data.data.get("model_name"),
        job_data.data.get("model_id"),
    ]

    for source in model_sources:
        if source:
            return str(source)

    # Nothing found: return a per-job placeholder so exports stay distinguishable.
    return f"unknown_model_{job_data.job_id}"
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def get_pipeline_id(job_data: JobData) -> str | None:
    """Return the pipeline ID for GitLab jobs, or None otherwise.

    The original annotation claimed ``-> str``, but None is returned both for
    non-GitLab executors and when 'pipeline_id' is absent from the job data.
    """
    if job_data.executor != "gitlab":
        return None
    return job_data.data.get("pipeline_id")
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def get_benchmark_info(job_data: JobData) -> Dict[str, str]:
    """Get harness and benchmark info from the tasks mapping."""
    try:
        task_name = get_task_name(job_data)
        # Aggregate/fallback task names have no mapping entry.
        if task_name in ("all_tasks", f"job_{job_data.job_id}"):
            return {"harness": "unknown", "benchmark": task_name}

        # Use mapping to get harness info
        task_definition = get_task_from_mapping(task_name, load_tasks_mapping())
        harness = task_definition.get("harness", "unknown")

        # Benchmark name is the task name minus its harness prefix, if any.
        _, dot, suffix = task_name.partition(".")
        benchmark = suffix if dot else task_name

        return {"harness": harness, "benchmark": benchmark}

    except Exception as e:
        logger.warning(f"Failed to get benchmark info: {e}")
        return {"harness": "unknown", "benchmark": get_task_name(job_data)}
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def get_container_from_mapping(job_data: JobData) -> str | None:
    """Return the task's container image from the mapping, or None.

    None is returned for aggregate/fallback task names and when the mapping
    lookup fails (the failure is logged). The original ``-> str`` annotation
    was wrong for these paths.
    """
    try:
        task_name = get_task_name(job_data)
        if task_name in ["all_tasks", f"job_{job_data.job_id}"]:
            return None

        mapping = load_tasks_mapping()
        task_definition = get_task_from_mapping(task_name, mapping)
        return task_definition.get("container")

    except Exception as e:
        logger.warning(f"Failed to get container from mapping: {e}")
        return None
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def get_artifact_root(job_data: JobData) -> str:
    """Return '<harness>.<benchmark>' used as the artifact root for a job."""
    info = get_benchmark_info(job_data)
    harness = info.get("harness", "unknown")
    benchmark = info.get("benchmark", get_task_name(job_data))
    return f"{harness}.{benchmark}"
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# =============================================================================
|
|
276
|
+
# GITLAB DOWNLOAD
|
|
277
|
+
# =============================================================================
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def download_gitlab_artifacts(
    paths: Dict[str, Any], export_dir: Path, extract_specific: bool = False
) -> Dict[str, Path]:
    """Download artifacts from GitLab API.

    Args:
        paths: Dictionary containing pipeline_id and project_id
        export_dir: Local directory to save artifacts
        extract_specific: If True, extract individual files; if False, keep as ZIP files

    Returns:
        Dictionary mapping artifact names to local file paths

    Raises:
        NotImplementedError: Always; GitLab artifact download is a stub kept
            for interface symmetry with the SSH download path.
    """
    raise NotImplementedError("Downloading from gitlab is not implemented")
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# =============================================================================
|
|
297
|
+
# SSH UTILS
|
|
298
|
+
# =============================================================================
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
# SSH connections directory
CONNECTIONS_DIR = Path.home() / ".nemo-evaluator" / "connections"


def ssh_setup_masters(jobs: Dict[str, JobData]) -> Dict[Tuple[str, str], str]:
    """Start SSH master connections for remote jobs, returns control_paths."""
    # Collect the unique (username, hostname) pairs that need a master.
    targets: set[tuple[str, str]] = set()
    for job in jobs.values():
        try:
            # Preferred: explicit 'paths' from job data
            job_paths = (job.data or {}).get("paths") or {}
            if (
                job_paths.get("storage_type") == "remote_ssh"
                and job_paths.get("username")
                and job_paths.get("hostname")
            ):
                targets.add((job_paths["username"], job_paths["hostname"]))
                continue
            # Fallback: common slurm fields (works with BaseExporter.get_job_paths)
            data = job.data or {}
            if job.executor == "slurm" and data.get("username") and data.get("hostname"):
                targets.add((data["username"], data["hostname"]))
        except Exception:
            pass

    if not targets:
        return {}

    CONNECTIONS_DIR.mkdir(parents=True, exist_ok=True)
    control_paths: Dict[Tuple[str, str], str] = {}
    for username, hostname in targets:
        socket_path = CONNECTIONS_DIR / f"{username}_{hostname}.sock"
        try:
            # Background (-f) master with no remote command (-N); ControlPersist
            # keeps the socket alive after the last client disconnects.
            subprocess.run(
                [
                    "ssh",
                    "-N",
                    "-f",
                    "-o",
                    "ControlMaster=auto",
                    "-o",
                    "ControlPersist=60",
                    "-o",
                    f"ControlPath={socket_path}",
                    f"{username}@{hostname}",
                ],
                check=False,
                capture_output=True,
            )
            control_paths[(username, hostname)] = str(socket_path)
        except Exception as e:
            logger.warning(f"Failed to start SSH master for {username}@{hostname}: {e}")
    return control_paths
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def ssh_cleanup_masters(control_paths: Dict[Tuple[str, str], str]) -> None:
    """Clean up SSH master connections from control_paths."""
    for (username, hostname), socket_path in (control_paths or {}).items():
        # Ask the running master to exit via its control socket.
        try:
            subprocess.run(
                [
                    "ssh",
                    "-O",
                    "exit",
                    "-o",
                    f"ControlPath={socket_path}",
                    f"{username}@{hostname}",
                ],
                check=False,
                capture_output=True,
            )
        except Exception as e:
            logger.warning(f"Failed to stop SSH master for {username}@{hostname}: {e}")

        # Clean up
        try:
            Path(socket_path).unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"Failed to clean up file: {e}")
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def ssh_download_artifacts(
|
|
377
|
+
paths: Dict[str, Any],
|
|
378
|
+
export_dir: Path,
|
|
379
|
+
config: Dict[str, Any] | None = None,
|
|
380
|
+
control_paths: Dict[Tuple[str, str], str] | None = None,
|
|
381
|
+
) -> List[str]:
|
|
382
|
+
"""Download artifacts/logs via SSH with optional connection reuse."""
|
|
383
|
+
exported_files: List[str] = []
|
|
384
|
+
copy_logs = bool((config or {}).get("copy_logs", False))
|
|
385
|
+
copy_artifacts = bool((config or {}).get("copy_artifacts", True))
|
|
386
|
+
only_required = bool((config or {}).get("only_required", True))
|
|
387
|
+
|
|
388
|
+
control_path = None
|
|
389
|
+
if control_paths:
|
|
390
|
+
control_path = control_paths.get((paths["username"], paths["hostname"]))
|
|
391
|
+
ssh_opts = ["-o", f"ControlPath={control_path}"] if control_path else []
|
|
392
|
+
|
|
393
|
+
def scp_file(remote_path: str, local_path: Path) -> bool:
|
|
394
|
+
cmd = (
|
|
395
|
+
["scp"]
|
|
396
|
+
+ ssh_opts
|
|
397
|
+
+ [
|
|
398
|
+
f"{paths['username']}@{paths['hostname']}:{remote_path}",
|
|
399
|
+
str(local_path),
|
|
400
|
+
]
|
|
401
|
+
)
|
|
402
|
+
return subprocess.run(cmd, capture_output=True).returncode == 0
|
|
403
|
+
|
|
404
|
+
export_dir.mkdir(parents=True, exist_ok=True)
|
|
405
|
+
|
|
406
|
+
# Artifacts
|
|
407
|
+
if copy_artifacts:
|
|
408
|
+
art_dir = export_dir / "artifacts"
|
|
409
|
+
art_dir.mkdir(parents=True, exist_ok=True)
|
|
410
|
+
|
|
411
|
+
if only_required:
|
|
412
|
+
for artifact in get_relevant_artifacts():
|
|
413
|
+
remote_file = f"{paths['remote_path']}/artifacts/{artifact}"
|
|
414
|
+
local_file = art_dir / artifact
|
|
415
|
+
local_file.parent.mkdir(parents=True, exist_ok=True)
|
|
416
|
+
if scp_file(remote_file, local_file):
|
|
417
|
+
exported_files.append(str(local_file))
|
|
418
|
+
else:
|
|
419
|
+
# Copy known files individually to avoid subfolders and satisfy tests
|
|
420
|
+
for artifact in get_available_artifacts(paths.get("artifacts_dir", Path())):
|
|
421
|
+
remote_file = f"{paths['remote_path']}/artifacts/{artifact}"
|
|
422
|
+
local_file = art_dir / artifact
|
|
423
|
+
if scp_file(remote_file, local_file):
|
|
424
|
+
exported_files.append(str(local_file))
|
|
425
|
+
|
|
426
|
+
# Logs (top-level only)
|
|
427
|
+
if copy_logs:
|
|
428
|
+
local_logs = export_dir / "logs"
|
|
429
|
+
remote_logs = f"{paths['remote_path']}/logs"
|
|
430
|
+
cmd = (
|
|
431
|
+
["scp", "-r"]
|
|
432
|
+
+ ssh_opts
|
|
433
|
+
+ [
|
|
434
|
+
f"{paths['username']}@{paths['hostname']}:{remote_logs}/.",
|
|
435
|
+
str(local_logs),
|
|
436
|
+
]
|
|
437
|
+
)
|
|
438
|
+
if subprocess.run(cmd, capture_output=True).returncode == 0:
|
|
439
|
+
for p in local_logs.iterdir():
|
|
440
|
+
if p.is_dir():
|
|
441
|
+
import shutil
|
|
442
|
+
|
|
443
|
+
shutil.rmtree(p, ignore_errors=True)
|
|
444
|
+
exported_files.extend([str(f) for f in local_logs.glob("*") if f.is_file()])
|
|
445
|
+
|
|
446
|
+
return exported_files
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
# =============================================================================
|
|
450
|
+
# PRIVATE HELPER FUNCTIONS
|
|
451
|
+
# =============================================================================
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def _get_artifacts_dir(paths: Dict[str, Any]) -> Path:
|
|
455
|
+
"""Get artifacts directory from paths."""
|
|
456
|
+
storage_type = paths.get("storage_type")
|
|
457
|
+
|
|
458
|
+
# For SSH-based remote access, artifacts aren't available locally yet
|
|
459
|
+
if storage_type == "remote_ssh":
|
|
460
|
+
return None
|
|
461
|
+
|
|
462
|
+
# For all local access (local_filesystem, remote_local, gitlab_ci_local)
|
|
463
|
+
# return the artifacts_dir from paths
|
|
464
|
+
return paths.get("artifacts_dir")
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def _extract_metrics_from_results(results: dict) -> Dict[str, float]:
    """Extract metrics from a 'results' dict (with optional 'groups'/'tasks')."""
    metrics: Dict[str, float] = {}
    for section_name in ("groups", "tasks"):
        section = results.get(section_name)
        if not isinstance(section, dict):
            continue
        for task_name, task_data in section.items():
            _safe_update_metrics(
                target=metrics,
                source=_extract_task_metrics(task_name, task_data),
                context=f" while extracting results for task '{task_name}'",
            )
    return metrics
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def _extract_from_results_yml(results_yml: Path) -> Dict[str, float]:
    """Extract metrics from results.yml file."""
    try:
        data = yaml.safe_load(results_yml.read_text(encoding="utf-8"))
        # Only a top-level mapping with a 'results' key is usable.
        if isinstance(data, dict) and "results" in data:
            return _extract_metrics_from_results(data.get("results"))
        return {}
    except Exception as e:
        logger.warning(f"Failed to parse results.yml: {e}")
        return {}
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def _extract_from_json_files(artifacts_dir: Path) -> Dict[str, float]:
    """Extract metrics from individual JSON result files."""
    metrics: Dict[str, float] = {}
    known_artifacts = get_relevant_artifacts()

    for json_file in artifacts_dir.glob("*.json"):
        # Skip known artifact files, focus on task result files.
        if json_file.name in known_artifacts:
            continue
        try:
            data = json.loads(json_file.read_text(encoding="utf-8"))
            if isinstance(data, dict) and "score" in data:
                metrics[f"{json_file.stem}_score"] = float(data["score"])
        except Exception as e:
            logger.warning(f"Failed to parse {json_file}: {e}")

    return metrics
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def _extract_task_metrics(task_name: str, task_data: dict) -> Dict[str, float]:
    """Extract metrics from a task's metrics data.

    Recurses into nested 'groups', then flattens each metric's scores and
    stats into '<task>_<metric>[_<score_type>][_<stat>]' keys.
    """
    extracted = {}

    metrics_data = task_data.get("metrics", {})
    if "groups" in task_data:
        for group_name, group_data in task_data["groups"].items():
            group_extracted = _extract_task_metrics(
                f"{task_name}_{group_name}", group_data
            )
            _safe_update_metrics(
                target=extracted,
                source=group_extracted,
                context=f" in task '{task_name}'",
            )

    for metric_name, metric_data in metrics_data.items():
        try:
            for score_type, score_data in metric_data["scores"].items():
                if score_type != metric_name:
                    key = f"{task_name}_{metric_name}_{score_type}"
                else:
                    key = f"{task_name}_{metric_name}"
                _safe_set_metric(
                    container=extracted,
                    key=key,
                    new_value=score_data["value"],
                    context=f" in task '{task_name}'",
                )
            # NOTE(review): stats keys reuse the last score's `key`, which
            # assumes a single score per metric when stats are present — TODO
            # confirm against the results.yml schema.
            for stat_name, stat_value in metric_data.get("stats", {}).items():
                stats_key = f"{key}_{stat_name}"
                _safe_set_metric(
                    container=extracted,
                    key=stats_key,
                    new_value=stat_value,
                    context=f" in task '{task_name}'",
                )
        except (KeyError, ValueError, TypeError) as e:
            # KeyError added: entries missing 'scores' or 'value' previously
            # escaped past the (ValueError, TypeError) handler instead of
            # being logged and skipped as intended.
            logger.warning(
                f"Failed to extract metric {metric_name} for task {task_name}: {e}"
            )

    return extracted
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def _safe_set_metric(
|
|
564
|
+
container: Dict[str, float], key: str, new_value: float, context: str
|
|
565
|
+
) -> None:
|
|
566
|
+
"""Set a metric into container; raise with details if key exists."""
|
|
567
|
+
if key in container:
|
|
568
|
+
# Allow exact matches; warn and keep existing
|
|
569
|
+
if container[key] == float(new_value):
|
|
570
|
+
logger.warning(
|
|
571
|
+
f"Metric rewrite{context}: '{key}' has identical value; keeping existing. value={container[key]}"
|
|
572
|
+
)
|
|
573
|
+
return
|
|
574
|
+
# Different value is an error we want to surface distinctly
|
|
575
|
+
raise MetricConflictError(
|
|
576
|
+
f"Metric key collision{context}: '{key}' already set. existing={container[key]} new={new_value}"
|
|
577
|
+
)
|
|
578
|
+
container[key] = float(new_value)
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
def _safe_update_metrics(
    target: Dict[str, float], source: Dict[str, float], context: str
) -> None:
    """Merge `source` into `target` via _safe_set_metric.

    Collisions with differing values raise MetricConflictError with details.
    """
    for metric_key, metric_value in source.items():
        _safe_set_metric(target, metric_key, metric_value, context)
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
# =============================================================================
|
|
590
|
+
# MLFLOW FUNCTIONS
|
|
591
|
+
# =============================================================================
|
|
592
|
+
|
|
593
|
+
# MLflow constants
_MLFLOW_KEY_MAX = 250
_MLFLOW_PARAM_VAL_MAX = 250
_MLFLOW_TAG_VAL_MAX = 5000

_INVALID_KEY_CHARS = re.compile(r"[^/\w.\- ]")
_MULTI_UNDERSCORE = re.compile(r"_+")


def mlflow_sanitize(s: Any, kind: str = "key") -> str:
    """
    Sanitize strings for MLflow logging.

    kind:
    - "key", "metric", "tag_key", "param_key": apply key rules
    - "tag_value": apply tag value rules
    - "param_value": apply param value rules
    """
    text = "" if s is None else str(s)

    if kind in ("key", "metric", "tag_key", "param_key"):
        # Common replacement first, then drop disallowed characters, collapse
        # repeated underscores, trim, and enforce MLflow's key-length limit.
        text = text.replace("pass@", "pass_at_")
        text = _MULTI_UNDERSCORE.sub("_", _INVALID_KEY_CHARS.sub("_", text)).strip()
        return text[:_MLFLOW_KEY_MAX] or "key"

    # Values: normalize whitespace, enforce the per-kind length cap.
    text = text.replace("\n", " ").replace("\r", " ").strip()
    limit = _MLFLOW_TAG_VAL_MAX if kind == "tag_value" else _MLFLOW_PARAM_VAL_MAX
    return text[:limit]