nemo-evaluator-launcher 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nemo-evaluator-launcher might be problematic.

Files changed (60)
  1. nemo_evaluator_launcher/__init__.py +79 -0
  2. nemo_evaluator_launcher/api/__init__.py +24 -0
  3. nemo_evaluator_launcher/api/functional.py +698 -0
  4. nemo_evaluator_launcher/api/types.py +98 -0
  5. nemo_evaluator_launcher/api/utils.py +19 -0
  6. nemo_evaluator_launcher/cli/__init__.py +15 -0
  7. nemo_evaluator_launcher/cli/export.py +267 -0
  8. nemo_evaluator_launcher/cli/info.py +512 -0
  9. nemo_evaluator_launcher/cli/kill.py +41 -0
  10. nemo_evaluator_launcher/cli/ls_runs.py +134 -0
  11. nemo_evaluator_launcher/cli/ls_tasks.py +136 -0
  12. nemo_evaluator_launcher/cli/main.py +226 -0
  13. nemo_evaluator_launcher/cli/run.py +200 -0
  14. nemo_evaluator_launcher/cli/status.py +164 -0
  15. nemo_evaluator_launcher/cli/version.py +55 -0
  16. nemo_evaluator_launcher/common/__init__.py +16 -0
  17. nemo_evaluator_launcher/common/execdb.py +283 -0
  18. nemo_evaluator_launcher/common/helpers.py +366 -0
  19. nemo_evaluator_launcher/common/logging_utils.py +357 -0
  20. nemo_evaluator_launcher/common/mapping.py +295 -0
  21. nemo_evaluator_launcher/common/printing_utils.py +93 -0
  22. nemo_evaluator_launcher/configs/__init__.py +15 -0
  23. nemo_evaluator_launcher/configs/default.yaml +28 -0
  24. nemo_evaluator_launcher/configs/deployment/generic.yaml +33 -0
  25. nemo_evaluator_launcher/configs/deployment/nim.yaml +32 -0
  26. nemo_evaluator_launcher/configs/deployment/none.yaml +16 -0
  27. nemo_evaluator_launcher/configs/deployment/sglang.yaml +38 -0
  28. nemo_evaluator_launcher/configs/deployment/trtllm.yaml +24 -0
  29. nemo_evaluator_launcher/configs/deployment/vllm.yaml +42 -0
  30. nemo_evaluator_launcher/configs/execution/lepton/default.yaml +92 -0
  31. nemo_evaluator_launcher/configs/execution/local.yaml +19 -0
  32. nemo_evaluator_launcher/configs/execution/slurm/default.yaml +34 -0
  33. nemo_evaluator_launcher/executors/__init__.py +22 -0
  34. nemo_evaluator_launcher/executors/base.py +120 -0
  35. nemo_evaluator_launcher/executors/lepton/__init__.py +16 -0
  36. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +609 -0
  37. nemo_evaluator_launcher/executors/lepton/executor.py +1004 -0
  38. nemo_evaluator_launcher/executors/lepton/job_helpers.py +398 -0
  39. nemo_evaluator_launcher/executors/local/__init__.py +15 -0
  40. nemo_evaluator_launcher/executors/local/executor.py +605 -0
  41. nemo_evaluator_launcher/executors/local/run.template.sh +103 -0
  42. nemo_evaluator_launcher/executors/registry.py +38 -0
  43. nemo_evaluator_launcher/executors/slurm/__init__.py +15 -0
  44. nemo_evaluator_launcher/executors/slurm/executor.py +1147 -0
  45. nemo_evaluator_launcher/exporters/__init__.py +36 -0
  46. nemo_evaluator_launcher/exporters/base.py +121 -0
  47. nemo_evaluator_launcher/exporters/gsheets.py +409 -0
  48. nemo_evaluator_launcher/exporters/local.py +502 -0
  49. nemo_evaluator_launcher/exporters/mlflow.py +619 -0
  50. nemo_evaluator_launcher/exporters/registry.py +40 -0
  51. nemo_evaluator_launcher/exporters/utils.py +624 -0
  52. nemo_evaluator_launcher/exporters/wandb.py +490 -0
  53. nemo_evaluator_launcher/package_info.py +38 -0
  54. nemo_evaluator_launcher/resources/mapping.toml +380 -0
  55. nemo_evaluator_launcher-0.1.28.dist-info/METADATA +494 -0
  56. nemo_evaluator_launcher-0.1.28.dist-info/RECORD +60 -0
  57. nemo_evaluator_launcher-0.1.28.dist-info/WHEEL +5 -0
  58. nemo_evaluator_launcher-0.1.28.dist-info/entry_points.txt +3 -0
  59. nemo_evaluator_launcher-0.1.28.dist-info/licenses/LICENSE +451 -0
  60. nemo_evaluator_launcher-0.1.28.dist-info/top_level.txt +1 -0
nemo_evaluator_launcher/exporters/utils.py
@@ -0,0 +1,624 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ """Shared utilities for metrics and configuration handling."""
+
+ import json
+ import re
+ import subprocess
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Tuple
+
+ import yaml
+
+ from nemo_evaluator_launcher.common.execdb import JobData
+ from nemo_evaluator_launcher.common.logging_utils import logger
+ from nemo_evaluator_launcher.common.mapping import (
+     get_task_from_mapping,
+     load_tasks_mapping,
+ )
+
+ # =============================================================================
+ # ARTIFACTS
+ # =============================================================================
+
+ # Artifacts to be logged by default
+ REQUIRED_ARTIFACTS = ["results.yml", "eval_factory_metrics.json"]
+ OPTIONAL_ARTIFACTS = ["omni-info.json"]
+
+
+ def get_relevant_artifacts() -> List[str]:
+     """Get relevant artifacts (required + optional)."""
+     return REQUIRED_ARTIFACTS + OPTIONAL_ARTIFACTS
+
+
+ def validate_artifacts(artifacts_dir: Path) -> Dict[str, Any]:
+     """Check which artifacts are available."""
+     if not artifacts_dir or not artifacts_dir.exists():
+         return {
+             "can_export": False,
+             "missing_required": REQUIRED_ARTIFACTS.copy(),
+             "missing_optional": OPTIONAL_ARTIFACTS.copy(),
+             "message": "Artifacts directory not found",
+         }
+
+     missing_required = [
+         f for f in REQUIRED_ARTIFACTS if not (artifacts_dir / f).exists()
+     ]
+     missing_optional = [
+         f for f in OPTIONAL_ARTIFACTS if not (artifacts_dir / f).exists()
+     ]
+     can_export = len(missing_required) == 0
+
+     message_parts = []
+     if missing_required:
+         message_parts.append(f"Missing required: {', '.join(missing_required)}")
+     if missing_optional:
+         message_parts.append(f"Missing optional: {', '.join(missing_optional)}")
+
+     return {
+         "can_export": can_export,
+         "missing_required": missing_required,
+         "missing_optional": missing_optional,
+         "message": (
+             ". ".join(message_parts) if message_parts else "All artifacts available"
+         ),
+     }
+
+
+ def get_available_artifacts(artifacts_dir: Path) -> List[str]:
+     """Get list of artifacts available in artifacts directory."""
+     if not artifacts_dir or not artifacts_dir.exists():
+         return []
+     return [
+         filename
+         for filename in get_relevant_artifacts()
+         if (artifacts_dir / filename).exists()
+     ]
+
+
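# Illustrative sketch (not from the packaged file): how an exporter might drive the
# artifact checks above, using the definitions just shown. The results directory
# below is hypothetical.
artifacts_dir = Path("results/job-0/artifacts")  # hypothetical job output directory
status = validate_artifacts(artifacts_dir)
if status["can_export"]:
    # Only the files that actually exist out of REQUIRED_ARTIFACTS + OPTIONAL_ARTIFACTS.
    print("Uploading:", get_available_artifacts(artifacts_dir))
else:
    # e.g. "Missing required: results.yml, eval_factory_metrics.json"
    print("Cannot export:", status["message"])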
+ # =============================================================================
+ # METRICS EXTRACTION
+ # =============================================================================
+
+
+ class MetricConflictError(Exception):
+     """Raised when attempting to set the same metric key with a different value."""
+
+
+ def extract_accuracy_metrics(
+     job_data: JobData, get_job_paths_func: Callable, log_metrics: List[str] = None
+ ) -> Dict[str, float]:
+     """Extract accuracy metrics from job results."""
+     try:
+         paths = get_job_paths_func(job_data)
+         artifacts_dir = _get_artifacts_dir(paths)
+
+         if not artifacts_dir or not artifacts_dir.exists():
+             logger.warning(f"Artifacts directory not found for job {job_data.job_id}")
+             return {}
+
+         # Prefer results.yml, but also merge JSON metrics to avoid missing values
+         metrics: Dict[str, float] = {}
+         results_yml = artifacts_dir / "results.yml"
+         if results_yml.exists():
+             yml_metrics = _extract_from_results_yml(results_yml)
+             if yml_metrics:
+                 metrics.update(yml_metrics)
+
+         # Merge in JSON metrics (handles tasks that only emit JSON or extra fields)
+         json_metrics = _extract_from_json_files(artifacts_dir)
+         for k, v in json_metrics.items():
+             metrics.setdefault(k, v)
+
+         # Filter metrics if specified
+         if log_metrics:
+             filtered_metrics = {}
+             for metric_name, metric_value in metrics.items():
+                 if any(filter_key in metric_name.lower() for filter_key in log_metrics):
+                     filtered_metrics[metric_name] = metric_value
+             return filtered_metrics
+
+         return metrics
+
+     except Exception as e:
+         logger.error(f"Failed to extract metrics for job {job_data.job_id}: {e}")
+         return {}
+
+
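# Illustrative sketch (not from the packaged file): the metrics-extraction flow above,
# assuming `job_data` is an existing JobData record and `paths_for` stands in for an
# exporter's get_job_paths(); both names are hypothetical.
def paths_for(job_data):
    # Local storage, so _get_artifacts_dir() returns "artifacts_dir" directly.
    return {
        "storage_type": "local_filesystem",
        "artifacts_dir": Path("results") / job_data.job_id / "artifacts",
    }

# Everything found in results.yml, plus a "<stem>_score" entry per extra JSON result file.
all_metrics = extract_accuracy_metrics(job_data, paths_for)
# Keep only metrics whose lowercased key contains one of the given substrings.
accuracy_only = extract_accuracy_metrics(job_data, paths_for, log_metrics=["accuracy"])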
+ # =============================================================================
+ # CONFIG EXTRACTION
+ # =============================================================================
+
+
+ def extract_exporter_config(
+     job_data: JobData, exporter_name: str, constructor_config: Dict[str, Any] = None
+ ) -> Dict[str, Any]:
+     """Extract and merge exporter configuration from multiple sources."""
+     config = {}
+
+     # root-level `export.<exporter-name>`
+     if job_data.config:
+         export_block = (job_data.config or {}).get("export", {})
+         yaml_config = (export_block or {}).get(exporter_name, {})
+         if yaml_config:
+             config.update(yaml_config)
+
+     # From webhook metadata (if triggered by webhook)
+     if "webhook_metadata" in job_data.data:
+         webhook_data = job_data.data["webhook_metadata"]
+         webhook_config = {
+             "triggered_by_webhook": True,
+             "webhook_source": webhook_data.get("webhook_source", "unknown"),
+             "source_artifact": f"{webhook_data.get('artifact_name', 'unknown')}:{webhook_data.get('artifact_version', 'unknown')}",
+             "config_source": webhook_data.get("config_file", "unknown"),
+         }
+         if exporter_name == "wandb" and webhook_data.get("webhook_source") == "wandb":
+             wandb_specific = {
+                 "entity": webhook_data.get("entity"),
+                 "project": webhook_data.get("project"),
+                 "run_id": webhook_data.get("run_id"),
+             }
+             webhook_config.update({k: v for k, v in wandb_specific.items() if v})
+         config.update(webhook_config)
+
+     # allows CLI overrides
+     if constructor_config:
+         config.update(constructor_config)
+
+     return config
+
+
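# Illustrative sketch (not from the packaged file): precedence of the merge above. Later
# sources win: the run config's export.<name> block, then webhook metadata, then the
# constructor/CLI overrides. The values shown are hypothetical.
cfg = extract_exporter_config(
    job_data,                                         # may carry an export.wandb block
    exporter_name="wandb",
    constructor_config={"project": "cli-override"},   # applied last, so it wins
)
# If the YAML block set project="from-yaml" and wandb webhook metadata set
# project="from-webhook", cfg["project"] ends up as "cli-override".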
+ # =============================================================================
+ # JOB DATA EXTRACTION
+ # =============================================================================
+
+
+ def get_task_name(job_data: JobData) -> str:
+     """Get task name from job data."""
+     if "." in job_data.job_id:
+         try:
+             idx = int(job_data.job_id.split(".")[-1])
+             return job_data.config["evaluation"]["tasks"][idx]["name"]
+         except Exception:
+             return f"job_{job_data.job_id}"
+     return "all_tasks"
+
+
+ def get_model_name(job_data: JobData, config: Dict[str, Any] = None) -> str:
+     """Extract model name from config or job data."""
+     if config and "model_name" in config:
+         return config["model_name"]
+
+     job_config = job_data.config or {}
+     model_sources = [
+         job_config.get("target", {}).get("api_endpoint", {}).get("model_id"),
+         job_config.get("deployment", {}).get("served_model_name"),
+         job_data.data.get("served_model_name"),
+         job_data.data.get("model_name"),
+         job_data.data.get("model_id"),
+     ]
+
+     for source in model_sources:
+         if source:
+             return str(source)
+
+     return f"unknown_model_{job_data.job_id}"
+
+
+ def get_pipeline_id(job_data: JobData) -> str:
+     """Get pipeline ID for GitLab jobs."""
+     return job_data.data.get("pipeline_id") if job_data.executor == "gitlab" else None
+
+
+ def get_benchmark_info(job_data: JobData) -> Dict[str, str]:
+     """Get harness and benchmark info from mapping."""
+     try:
+         task_name = get_task_name(job_data)
+         if task_name in ["all_tasks", f"job_{job_data.job_id}"]:
+             return {"harness": "unknown", "benchmark": task_name}
+
+         # Use mapping to get harness info
+         mapping = load_tasks_mapping()
+         task_definition = get_task_from_mapping(task_name, mapping)
+         harness = task_definition.get("harness", "unknown")
+
+         # Extract benchmark name (remove harness prefix)
+         if "." in task_name:
+             benchmark = ".".join(task_name.split(".")[1:])
+         else:
+             benchmark = task_name
+
+         return {"harness": harness, "benchmark": benchmark}
+
+     except Exception as e:
+         logger.warning(f"Failed to get benchmark info: {e}")
+         return {"harness": "unknown", "benchmark": get_task_name(job_data)}
+
+
+ def get_container_from_mapping(job_data: JobData) -> str:
+     """Get container from mapping."""
+     try:
+         task_name = get_task_name(job_data)
+         if task_name in ["all_tasks", f"job_{job_data.job_id}"]:
+             return None
+
+         mapping = load_tasks_mapping()
+         task_definition = get_task_from_mapping(task_name, mapping)
+         return task_definition.get("container")
+
+     except Exception as e:
+         logger.warning(f"Failed to get container from mapping: {e}")
+         return None
+
+
+ def get_artifact_root(job_data: JobData) -> str:
+     """Get artifact root from job data."""
+     bench = get_benchmark_info(job_data)
+     h = bench.get("harness", "unknown")
+     b = bench.get("benchmark", get_task_name(job_data))
+     return f"{h}.{b}"
+
+
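# Illustrative sketch (not from the packaged file): how the getters above derive names.
# `fake_job` is a stand-in with only the attributes they read; real callers pass a
# JobData record, and all values here are hypothetical.
from types import SimpleNamespace

fake_job = SimpleNamespace(
    job_id="a1b2c3d4.0",  # trailing ".0" indexes into the evaluation task list
    config={
        "evaluation": {"tasks": [{"name": "mmlu"}]},
        "target": {"api_endpoint": {"model_id": "my-org/my-model"}},
    },
    data={},
    executor="local",
)
get_task_name(fake_job)      # -> "mmlu"
get_model_name(fake_job)     # -> "my-org/my-model"
get_artifact_root(fake_job)  # -> "<harness>.mmlu" via the bundled mapping,
                             #    or "unknown.mmlu" if the lookup fails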
+ # =============================================================================
+ # GITLAB DOWNLOAD
+ # =============================================================================
+
+
+ def download_gitlab_artifacts(
+     paths: Dict[str, Any], export_dir: Path, extract_specific: bool = False
+ ) -> Dict[str, Path]:
+     """Download artifacts from GitLab API.
+
+     Args:
+         paths: Dictionary containing pipeline_id and project_id
+         export_dir: Local directory to save artifacts
+         extract_specific: If True, extract individual files; if False, keep as ZIP files
+
+     Returns:
+         Dictionary mapping artifact names to local file paths
+     """
+     raise NotImplementedError("Downloading from gitlab is not implemented")
+
+
+ # =============================================================================
+ # SSH UTILS
+ # =============================================================================
+
+
+ # SSH connections directory
+ CONNECTIONS_DIR = Path.home() / ".nemo-evaluator" / "connections"
+
+
+ def ssh_setup_masters(jobs: Dict[str, JobData]) -> Dict[Tuple[str, str], str]:
+     """Start SSH master connections for remote jobs, returns control_paths."""
+     remote_pairs: set[tuple[str, str]] = set()
+     for jd in jobs.values():
+         try:
+             # Preferred: explicit 'paths' from job data
+             p = (jd.data or {}).get("paths") or {}
+             if (
+                 p.get("storage_type") == "remote_ssh"
+                 and p.get("username")
+                 and p.get("hostname")
+             ):
+                 remote_pairs.add((p["username"], p["hostname"]))
+                 continue
+             # Fallback: common slurm fields (works with BaseExporter.get_job_paths)
+             d = jd.data or {}
+             if jd.executor == "slurm" and d.get("username") and d.get("hostname"):
+                 remote_pairs.add((d["username"], d["hostname"]))
+         except Exception:
+             pass
+
+     if not remote_pairs:
+         return {}
+
+     CONNECTIONS_DIR.mkdir(parents=True, exist_ok=True)
+     control_paths: Dict[Tuple[str, str], str] = {}
+     for username, hostname in remote_pairs:
+         socket_path = CONNECTIONS_DIR / f"{username}_{hostname}.sock"
+         try:
+             cmd = [
+                 "ssh",
+                 "-N",
+                 "-f",
+                 "-o",
+                 "ControlMaster=auto",
+                 "-o",
+                 "ControlPersist=60",
+                 "-o",
+                 f"ControlPath={socket_path}",
+                 f"{username}@{hostname}",
+             ]
+             subprocess.run(cmd, check=False, capture_output=True)
+             control_paths[(username, hostname)] = str(socket_path)
+         except Exception as e:
+             logger.warning(f"Failed to start SSH master for {username}@{hostname}: {e}")
+     return control_paths
+
+
+ def ssh_cleanup_masters(control_paths: Dict[Tuple[str, str], str]) -> None:
+     """Clean up SSH master connections from control_paths."""
+     for (username, hostname), socket_path in (control_paths or {}).items():
+         try:
+             cmd = [
+                 "ssh",
+                 "-O",
+                 "exit",
+                 "-o",
+                 f"ControlPath={socket_path}",
+                 f"{username}@{hostname}",
+             ]
+             subprocess.run(cmd, check=False, capture_output=True)
+         except Exception as e:
+             logger.warning(f"Failed to stop SSH master for {username}@{hostname}: {e}")
+
+         # Clean up
+         try:
+             Path(socket_path).unlink(missing_ok=True)
+         except Exception as e:
+             logger.warning(f"Failed to clean up file: {e}")
+
+
+ def ssh_download_artifacts(
+     paths: Dict[str, Any],
+     export_dir: Path,
+     config: Dict[str, Any] | None = None,
+     control_paths: Dict[Tuple[str, str], str] | None = None,
+ ) -> List[str]:
+     """Download artifacts/logs via SSH with optional connection reuse."""
+     exported_files: List[str] = []
+     copy_logs = bool((config or {}).get("copy_logs", False))
+     copy_artifacts = bool((config or {}).get("copy_artifacts", True))
+     only_required = bool((config or {}).get("only_required", True))
+
+     control_path = None
+     if control_paths:
+         control_path = control_paths.get((paths["username"], paths["hostname"]))
+     ssh_opts = ["-o", f"ControlPath={control_path}"] if control_path else []
+
+     def scp_file(remote_path: str, local_path: Path) -> bool:
+         cmd = (
+             ["scp"]
+             + ssh_opts
+             + [
+                 f"{paths['username']}@{paths['hostname']}:{remote_path}",
+                 str(local_path),
+             ]
+         )
+         return subprocess.run(cmd, capture_output=True).returncode == 0
+
+     export_dir.mkdir(parents=True, exist_ok=True)
+
+     # Artifacts
+     if copy_artifacts:
+         art_dir = export_dir / "artifacts"
+         art_dir.mkdir(parents=True, exist_ok=True)
+
+         if only_required:
+             for artifact in get_relevant_artifacts():
+                 remote_file = f"{paths['remote_path']}/artifacts/{artifact}"
+                 local_file = art_dir / artifact
+                 local_file.parent.mkdir(parents=True, exist_ok=True)
+                 if scp_file(remote_file, local_file):
+                     exported_files.append(str(local_file))
+         else:
+             # Copy known files individually to avoid subfolders and satisfy tests
+             for artifact in get_available_artifacts(paths.get("artifacts_dir", Path())):
+                 remote_file = f"{paths['remote_path']}/artifacts/{artifact}"
+                 local_file = art_dir / artifact
+                 if scp_file(remote_file, local_file):
+                     exported_files.append(str(local_file))
+
+     # Logs (top-level only)
+     if copy_logs:
+         local_logs = export_dir / "logs"
+         remote_logs = f"{paths['remote_path']}/logs"
+         cmd = (
+             ["scp", "-r"]
+             + ssh_opts
+             + [
+                 f"{paths['username']}@{paths['hostname']}:{remote_logs}/.",
+                 str(local_logs),
+             ]
+         )
+         if subprocess.run(cmd, capture_output=True).returncode == 0:
+             for p in local_logs.iterdir():
+                 if p.is_dir():
+                     import shutil
+
+                     shutil.rmtree(p, ignore_errors=True)
+             exported_files.extend([str(f) for f in local_logs.glob("*") if f.is_file()])
+
+     return exported_files
+
+
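# Illustrative sketch (not from the packaged file): a setup / download / teardown sequence
# for the SSH helpers above, so every scp reuses one multiplexed connection per
# (user, host) instead of re-authenticating. `jobs` and the `paths` values are hypothetical.
control_paths = ssh_setup_masters(jobs)  # ssh -N -f -o ControlMaster=auto ... per host
try:
    for job_id, job in jobs.items():
        paths = {
            "storage_type": "remote_ssh",
            "username": "jdoe",                      # hypothetical
            "hostname": "cluster.example.com",       # hypothetical
            "remote_path": f"/scratch/results/{job_id}",
        }
        ssh_download_artifacts(
            paths,
            Path("export") / job_id,
            config={"copy_logs": True},
            control_paths=control_paths,
        )
finally:
    ssh_cleanup_masters(control_paths)  # ssh -O exit, then remove the .sock files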
+ # =============================================================================
+ # PRIVATE HELPER FUNCTIONS
+ # =============================================================================
+
+
+ def _get_artifacts_dir(paths: Dict[str, Any]) -> Path:
+     """Get artifacts directory from paths."""
+     storage_type = paths.get("storage_type")
+
+     # For SSH-based remote access, artifacts aren't available locally yet
+     if storage_type == "remote_ssh":
+         return None
+
+     # For all local access (local_filesystem, remote_local, gitlab_ci_local)
+     # return the artifacts_dir from paths
+     return paths.get("artifacts_dir")
+
+
+ def _extract_metrics_from_results(results: dict) -> Dict[str, float]:
+     """Extract metrics from a 'results' dict (with optional 'groups'/'tasks')."""
+     metrics: Dict[str, float] = {}
+     for section in ["groups", "tasks"]:
+         section_data = results.get(section)
+         if isinstance(section_data, dict):
+             for task_name, task_data in section_data.items():
+                 task_metrics = _extract_task_metrics(task_name, task_data)
+                 _safe_update_metrics(
+                     target=metrics,
+                     source=task_metrics,
+                     context=f" while extracting results for task '{task_name}'",
+                 )
+     return metrics
+
+
+ def _extract_from_results_yml(results_yml: Path) -> Dict[str, float]:
+     """Extract metrics from results.yml file."""
+     try:
+         with open(results_yml, "r", encoding="utf-8") as f:
+             data = yaml.safe_load(f)
+         if not isinstance(data, dict) or "results" not in data:
+             return {}
+         return _extract_metrics_from_results(data.get("results"))
+     except Exception as e:
+         logger.warning(f"Failed to parse results.yml: {e}")
+         return {}
+
+
+ def _extract_from_json_files(artifacts_dir: Path) -> Dict[str, float]:
+     """Extract metrics from individual JSON result files."""
+     metrics = {}
+
+     for json_file in artifacts_dir.glob("*.json"):
+         if json_file.name in get_relevant_artifacts():
+             continue  # Skip known artifact files, focus on task result files
+
+         try:
+             with open(json_file, "r", encoding="utf-8") as f:
+                 data = json.load(f)
+
+             if isinstance(data, dict) and "score" in data:
+                 task_name = json_file.stem
+                 metrics[f"{task_name}_score"] = float(data["score"])
+
+         except Exception as e:
+             logger.warning(f"Failed to parse {json_file}: {e}")
+
+     return metrics
+
+
+ def _extract_task_metrics(task_name: str, task_data: dict) -> Dict[str, float]:
+     """Extract metrics from a task's metrics data."""
+     extracted = {}
+
+     metrics_data = task_data.get("metrics", {})
+     if "groups" in task_data:
+         for group_name, group_data in task_data["groups"].items():
+             group_extracted = _extract_task_metrics(
+                 f"{task_name}_{group_name}", group_data
+             )
+             _safe_update_metrics(
+                 target=extracted,
+                 source=group_extracted,
+                 context=f" in task '{task_name}'",
+             )
+
+     for metric_name, metric_data in metrics_data.items():
+         try:
+             for score_type, score_data in metric_data["scores"].items():
+                 if score_type != metric_name:
+                     key = f"{task_name}_{metric_name}_{score_type}"
+                 else:
+                     key = f"{task_name}_{metric_name}"
+                 _safe_set_metric(
+                     container=extracted,
+                     key=key,
+                     new_value=score_data["value"],
+                     context=f" in task '{task_name}'",
+                 )
+                 for stat_name, stat_value in metric_data.get("stats", {}).items():
+                     stats_key = f"{key}_{stat_name}"
+                     _safe_set_metric(
+                         container=extracted,
+                         key=stats_key,
+                         new_value=stat_value,
+                         context=f" in task '{task_name}'",
+                     )
+         except (ValueError, TypeError) as e:
+             logger.warning(
+                 f"Failed to extract metric {metric_name} for task {task_name}: {e}"
+             )
+
+     return extracted
+
+
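# Illustrative sketch (not from the packaged file): the nested "results" shape the
# traversal above expects, with made-up numbers. Keys are flattened to "<task>_<metric>"
# or "<task>_<metric>_<score_type>", plus "_<stat>" suffixes for the stats block.
example_results = {
    "tasks": {
        "gsm8k": {
            "metrics": {
                "exact_match": {
                    "scores": {"strict": {"value": 0.82}, "flexible": {"value": 0.85}},
                    "stats": {"stderr": 0.01},
                }
            }
        }
    }
}
_extract_metrics_from_results(example_results)
# -> includes "gsm8k_exact_match_strict": 0.82, "gsm8k_exact_match_flexible": 0.85,
#    and stderr entries derived from the "stats" block.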
+ def _safe_set_metric(
+     container: Dict[str, float], key: str, new_value: float, context: str
+ ) -> None:
+     """Set a metric into container; raise with details if key exists."""
+     if key in container:
+         # Allow exact matches; warn and keep existing
+         if container[key] == float(new_value):
+             logger.warning(
+                 f"Metric rewrite{context}: '{key}' has identical value; keeping existing. value={container[key]}"
+             )
+             return
+         # Different value is an error we want to surface distinctly
+         raise MetricConflictError(
+             f"Metric key collision{context}: '{key}' already set. existing={container[key]} new={new_value}"
+         )
+     container[key] = float(new_value)
+
+
+ def _safe_update_metrics(
+     target: Dict[str, float], source: Dict[str, float], context: str
+ ) -> None:
+     """Update target from source safely, raising on collisions with detailed values."""
+     for k, v in source.items():
+         _safe_set_metric(target, k, v, context)
+
+
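# Illustrative sketch (not from the packaged file): the collision policy implemented by
# _safe_set_metric, on a toy dict. Re-setting an identical value only logs a warning;
# a different value raises MetricConflictError instead of silently overwriting.
metrics = {}
_safe_set_metric(metrics, "mmlu_accuracy", 0.71, context=" in task 'mmlu'")
_safe_set_metric(metrics, "mmlu_accuracy", 0.71, context=" in task 'mmlu'")  # warns, keeps 0.71
try:
    _safe_set_metric(metrics, "mmlu_accuracy", 0.55, context=" in task 'mmlu'")
except MetricConflictError as err:
    print(err)  # Metric key collision in task 'mmlu': 'mmlu_accuracy' already set. ...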
+ # =============================================================================
+ # MLFLOW FUNCTIONS
+ # =============================================================================
+
+ # MLflow constants
+ _MLFLOW_KEY_MAX = 250
+ _MLFLOW_PARAM_VAL_MAX = 250
+ _MLFLOW_TAG_VAL_MAX = 5000
+
+ _INVALID_KEY_CHARS = re.compile(r"[^/\w.\- ]")
+ _MULTI_UNDERSCORE = re.compile(r"_+")
+
+
+ def mlflow_sanitize(s: Any, kind: str = "key") -> str:
+     """
+     Sanitize strings for MLflow logging.
+
+     kind:
+       - "key", "metric", "tag_key", "param_key": apply key rules
+       - "tag_value": apply tag value rules
+       - "param_value": apply param value rules
+     """
+     s = "" if s is None else str(s)
+
+     if kind in ("key", "metric", "tag_key", "param_key"):
+         # common replacements
+         s = s.replace("pass@", "pass_at_")
+         # drop disallowed chars, collapse underscores, trim
+         s = _INVALID_KEY_CHARS.sub("_", s)
+         s = _MULTI_UNDERSCORE.sub("_", s).strip()
+         return s[:_MLFLOW_KEY_MAX] or "key"
+
+     # values: normalize whitespace, enforce length
+     s = s.replace("\n", " ").replace("\r", " ").strip()
+     max_len = _MLFLOW_TAG_VAL_MAX if kind == "tag_value" else _MLFLOW_PARAM_VAL_MAX
+     return s[:max_len]
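
The sanitizer above normalizes key-like strings (keeping only the characters its regex allows: alphanumerics, "_", ".", "-", "/", and spaces, capped at 250 characters, with "pass@" rewritten to "pass_at_") and value strings (whitespace normalization plus their own length caps). A few illustrative calls; the expected results follow directly from the regexes and limits defined above:

mlflow_sanitize("mmlu_pro pass@1", kind="metric")          # -> "mmlu_pro pass_at_1"
mlflow_sanitize("results:acc@top5", kind="key")            # ':' and '@' become '_' -> "results_acc_top5"
mlflow_sanitize("line one\nline two", kind="param_value")  # -> "line one line two" (max 250 chars)
mlflow_sanitize(None, kind="tag_value")                    # -> "" (values may be empty; keys fall back to "key")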