nemo-evaluator-launcher 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nemo-evaluator-launcher might be problematic.

Files changed (23)
  1. nemo_evaluator_launcher/api/functional.py +28 -2
  2. nemo_evaluator_launcher/cli/export.py +128 -10
  3. nemo_evaluator_launcher/cli/run.py +22 -3
  4. nemo_evaluator_launcher/cli/status.py +3 -1
  5. nemo_evaluator_launcher/configs/execution/local.yaml +1 -0
  6. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +24 -4
  7. nemo_evaluator_launcher/executors/lepton/executor.py +3 -5
  8. nemo_evaluator_launcher/executors/local/executor.py +30 -5
  9. nemo_evaluator_launcher/executors/local/run.template.sh +1 -1
  10. nemo_evaluator_launcher/executors/slurm/executor.py +90 -26
  11. nemo_evaluator_launcher/exporters/base.py +9 -0
  12. nemo_evaluator_launcher/exporters/gsheets.py +27 -9
  13. nemo_evaluator_launcher/exporters/local.py +5 -0
  14. nemo_evaluator_launcher/exporters/mlflow.py +105 -32
  15. nemo_evaluator_launcher/exporters/utils.py +22 -105
  16. nemo_evaluator_launcher/exporters/wandb.py +117 -38
  17. nemo_evaluator_launcher/package_info.py +1 -1
  18. {nemo_evaluator_launcher-0.1.11.dist-info → nemo_evaluator_launcher-0.1.13.dist-info}/METADATA +1 -1
  19. {nemo_evaluator_launcher-0.1.11.dist-info → nemo_evaluator_launcher-0.1.13.dist-info}/RECORD +23 -23
  20. {nemo_evaluator_launcher-0.1.11.dist-info → nemo_evaluator_launcher-0.1.13.dist-info}/WHEEL +0 -0
  21. {nemo_evaluator_launcher-0.1.11.dist-info → nemo_evaluator_launcher-0.1.13.dist-info}/entry_points.txt +0 -0
  22. {nemo_evaluator_launcher-0.1.11.dist-info → nemo_evaluator_launcher-0.1.13.dist-info}/licenses/LICENSE +0 -0
  23. {nemo_evaluator_launcher-0.1.11.dist-info → nemo_evaluator_launcher-0.1.13.dist-info}/top_level.txt +0 -0
nemo_evaluator_launcher/executors/slurm/executor.py

@@ -174,10 +174,11 @@ class SlurmExecutor(BaseExecutor):
         for idx, (slurm_job_id, remote_runsub_path) in enumerate(
             zip(slurm_job_ids, remote_runsub_paths)
         ):
+            job_id = generate_job_id(invocation_id, idx)
             db.write_job(
                 job=JobData(
                     invocation_id=invocation_id,
-                    job_id=generate_job_id(invocation_id, idx),
+                    job_id=job_id,
                     timestamp=time.time(),
                     executor="slurm",
                     data={
@@ -204,7 +205,7 @@ class SlurmExecutor(BaseExecutor):
         """
         db = ExecutionDB()

-        # If id looks like an invocation_id (no dot), get all jobs for it
+        # If id looks like an invocation_id
         if "." not in id:
             jobs = db.get_jobs(id)
             if not jobs:
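Note: job ids in this launcher appear to take the form "<invocation_id>.<index>" (hunk 1 builds them with generate_job_id(invocation_id, idx), and the dot check above relies on it). A standalone illustration, with a hypothetical is_invocation_id helper that is not part of the package:

    # Hypothetical helper; mirrors the "." check above.
    def is_invocation_id(id_str: str) -> bool:
        """True if id_str names a whole invocation rather than one job."""
        return "." not in id_str

    assert is_invocation_id("abc123")        # whole invocation
    assert not is_invocation_id("abc123.0")  # job 0 within that invocation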
@@ -605,20 +606,27 @@ def _create_slurm_sbatch_script(
     s += "kill $SERVER_PID # terminate the server to finish gracefully\n\n"

     # auto-export
-    if cfg.execution.get("auto_export", {}).get("destinations", []):
-        s += _generate_auto_export_section(cfg, job_id)
+    ae_cfg = cfg.execution.get("auto_export")
+    destinations: list = []
+    if isinstance(ae_cfg, list):
+        destinations = list(ae_cfg)
+    elif isinstance(ae_cfg, dict) or isinstance(ae_cfg, DictConfig):
+        destinations = list(ae_cfg.get("destinations", []) or [])
+
+    if destinations:
+        export_env = dict(cfg.execution.get("env_vars", {}).get("export", {}) or {})
+        s += _generate_auto_export_section(cfg, job_id, destinations, export_env)

     return s


 def _generate_auto_export_section(
     cfg: DictConfig,
-    job_id: str,  # Complete job_id string
+    job_id: str,
+    destinations: list,
+    export_env: dict,
 ) -> str:
     """Generate simple auto-export section for sbatch script."""
-    auto_export_config = cfg.execution.get("auto_export", {})
-    destinations = auto_export_config.get("destinations", [])
-
     if not destinations:
         return ""

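Note: execution.auto_export is now accepted either as a bare list of destinations or as a mapping with a destinations key. A minimal sketch of the same normalization, using plain lists and dicts in place of OmegaConf containers (an assumption made for brevity):

    # Plain-container sketch of the normalization above.
    def normalize_destinations(ae_cfg) -> list:
        if isinstance(ae_cfg, list):   # auto_export: [wandb, mlflow]
            return list(ae_cfg)
        if isinstance(ae_cfg, dict):   # auto_export: {destinations: [...]}
            return list(ae_cfg.get("destinations", []) or [])
        return []                      # absent or unrecognized

    assert normalize_destinations(["wandb", "mlflow"]) == ["wandb", "mlflow"]
    assert normalize_destinations({"destinations": ["wandb"]}) == ["wandb"]
    assert normalize_destinations(None) == []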
@@ -626,18 +634,65 @@ def _generate_auto_export_section(
     s += "EVAL_EXIT_CODE=$?\n"
     s += "if [ $EVAL_EXIT_CODE -eq 0 ]; then\n"
     s += " echo 'Evaluation completed successfully. Starting auto-export...'\n"
-    s += " set +e\n"  # per exporter failure allowed
+    s += " set +e\n"
     s += " set +x\n"
+    s += " set +u\n"
     s += ' cd "$TASK_DIR/artifacts"\n'
-    auto_export_cfg = OmegaConf.to_container(
-        cfg.execution.get("auto_export", {}), resolve=True
+
+    # Work with DictConfig; convert only for YAML at the end
+    exec_type = (
+        cfg.execution.type
+        if hasattr(cfg.execution, "type")
+        else cfg.execution.get("type", "slurm")
     )
-    yaml_str = yaml.safe_dump(
-        {"execution": {"auto_export": auto_export_cfg}}, sort_keys=False
+    eval_tasks = (
+        list(cfg.evaluation.tasks)
+        if hasattr(cfg, "evaluation") and hasattr(cfg.evaluation, "tasks")
+        else list((cfg.get("evaluation", {}) or {}).get("tasks", []) or [])
     )
+    export_block = cfg.get("export", {}) or {}
+
+    payload = {
+        "execution": {
+            "auto_export": {
+                "destinations": list(destinations),
+                **({"env_vars": dict(export_env)} if export_env else {}),
+            },
+            "type": exec_type,
+        },
+        "evaluation": {"tasks": eval_tasks},
+    }
+    if export_block:
+        # Convert just this block to plain for YAML
+        payload["export"] = (
+            OmegaConf.to_object(export_block)
+            if OmegaConf.is_config(export_block)
+            else dict(export_block)
+        )
+
+    # Final YAML (single conversion at the end)
+    payload_clean = OmegaConf.to_container(OmegaConf.create(payload), resolve=True)
+    yaml_str = yaml.safe_dump(payload_clean, sort_keys=False)
     s += " cat > export_config.yml << 'EOF'\n"
     s += yaml_str
     s += "EOF\n"
+
+    # write launcher config as config.yml for exporters (no core command)
+    submitted_yaml = yaml.safe_dump(
+        OmegaConf.to_container(cfg, resolve=True), sort_keys=False
+    )
+    s += " cat > config.yml << 'EOF'\n"
+    s += submitted_yaml
+    s += "EOF\n"
+
+    # Export host only env before running auto export
+    for k, v in (export_env or {}).items():
+        if isinstance(v, str) and re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", v):
+            s += f' export {k}="${{{v}}}"\n'
+        else:
+            esc = str(v).replace('"', '\\"')
+            s += f' export {k}="{esc}"\n'
+
     for dest in destinations:
         s += f" echo 'Exporting to {dest}...'\n"
         s += f" nemo-evaluator-launcher export {job_id} --dest {dest} || echo 'Export to {dest} failed'\n"
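Note: based on the payload built above, the export_config.yml written next to the artifacts should look roughly like the output of this sketch; all values here are invented for illustration, and the task entries' exact shape is an assumption:

    # Illustrative reconstruction of the export_config.yml payload;
    # real values come from the submitted launcher config.
    import yaml

    payload = {
        "execution": {
            "auto_export": {
                "destinations": ["wandb", "mlflow"],
                "env_vars": {"WANDB_API_KEY": "WANDB_API_KEY"},
            },
            "type": "slurm",
        },
        "evaluation": {"tasks": [{"name": "mmlu"}]},
    }
    print(yaml.safe_dump(payload, sort_keys=False))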
@@ -656,7 +711,9 @@ def _open_master_connection(
     socket: str,
 ) -> str | None:
     ssh_command = f"ssh -MNf -S {socket} {username}@{hostname}"
-    completed_process = subprocess.run(args=shlex.split(ssh_command))
+    completed_process = subprocess.run(
+        args=shlex.split(ssh_command), capture_output=True
+    )
     if completed_process.returncode == 0:
         return socket
     return None
@@ -694,12 +751,17 @@ def _make_remote_execution_output_dir(
     ssh_command.append(f"{username}@{hostname}")
     ssh_command.append(mkdir_command)
     ssh_command = " ".join(ssh_command)
-    completed_process = subprocess.run(args=shlex.split(ssh_command))
+    completed_process = subprocess.run(
+        args=shlex.split(ssh_command), capture_output=True
+    )
     if completed_process.returncode != 0:
+        error_msg = (
+            completed_process.stderr.decode("utf-8")
+            if completed_process.stderr
+            else "Unknown error"
+        )
         raise RuntimeError(
-            "failed to make a remote execution output dir\n{}".format(
-                completed_process.stderr.decode("utf-8")
-            )
+            "failed to make a remote execution output dir\n{}".format(error_msg)
         )


@@ -725,13 +787,16 @@ def _rsync_upload_rundirs(
     remote_destination_str = f"{username}@{hostname}:{remote_target}"
     local_sources_str = " ".join(map(str, local_sources))
     rsync_upload_command = f"rsync -qcaz {local_sources_str} {remote_destination_str}"
-    completed_process = subprocess.run(args=shlex.split(rsync_upload_command))
+    completed_process = subprocess.run(
+        args=shlex.split(rsync_upload_command), capture_output=True
+    )
     if completed_process.returncode != 0:
-        raise RuntimeError(
-            "failed to upload local sources\n{}".format(
-                completed_process.stderr.decode("utf-8")
-            )
+        error_msg = (
+            completed_process.stderr.decode("utf-8")
+            if completed_process.stderr
+            else "Unknown error"
         )
+        raise RuntimeError("failed to upload local sources\n{}".format(error_msg))


 def _sbatch_remote_runsubs(
@@ -757,10 +822,9 @@ def _sbatch_remote_runsubs(
         args=shlex.split(ssh_command), capture_output=True
     )
     if completed_process.returncode != 0:
+        error_msg = completed_process.stderr.decode("utf-8")
         raise RuntimeError(
-            "failed to submit sbatch scripts for execution\n{}".format(
-                completed_process.stderr.decode("utf-8")
-            )
+            "failed to submit sbatch scripts for execution\n{}".format(error_msg)
         )

     sbatch_output = completed_process.stdout.decode("utf-8")
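Note: the ssh/rsync/sbatch hunks above repeat one fix: run the command with capture_output=True and surface stderr in the raised error instead of discarding it. The shared pattern as a standalone sketch (the _run_or_raise name is hypothetical, not a helper in the package):

    import shlex
    import subprocess

    # Hypothetical helper distilling the repeated pattern above.
    def _run_or_raise(command: str, what: str) -> subprocess.CompletedProcess:
        completed = subprocess.run(args=shlex.split(command), capture_output=True)
        if completed.returncode != 0:
            error_msg = (
                completed.stderr.decode("utf-8") if completed.stderr else "Unknown error"
            )
            raise RuntimeError(f"failed to {what}\n{error_msg}")
        return completed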
nemo_evaluator_launcher/exporters/base.py

@@ -70,6 +70,15 @@ class BaseExporter(ABC):

     def get_job_paths(self, job_data: JobData) -> Dict[str, Any]:
         """Get result paths based on executor type from job metadata."""
+        # Special case: remote executor artifacts accessed locally (remote auto-export)
+        if job_data.data.get("storage_type") == "remote_local":
+            output_dir = Path(job_data.data["output_dir"])
+            return {
+                "artifacts_dir": output_dir / "artifacts",
+                "logs_dir": output_dir / "logs",
+                "storage_type": "remote_local",
+            }
+
         if job_data.executor == "local":
             output_dir = Path(job_data.data["output_dir"])
             return {
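Note: with the new remote_local storage type, a job whose metadata was produced by remote auto-export resolves to plain local paths on the machine where the exporter runs. An illustration with an invented output_dir:

    from pathlib import Path

    # Illustrative only: the output_dir value is invented.
    output_dir = Path("/lustre/results/abc123.0")
    paths = {
        "artifacts_dir": output_dir / "artifacts",
        "logs_dir": output_dir / "logs",
        "storage_type": "remote_local",
    }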
nemo_evaluator_launcher/exporters/gsheets.py

@@ -15,6 +15,7 @@
 #
 """Google Sheets evaluation results exporter."""

+import os
 import shutil
 import tempfile
 from pathlib import Path
@@ -89,28 +90,38 @@ class GSheetsExporter(BaseExporter):
         }

         try:
+            # Load exporter config from the first job (supports job-embedded config and CLI overrides)
+            first_job = next(iter(jobs.values()))
+            gsheets_config = extract_exporter_config(first_job, "gsheets", self.config)
+
             # Connect to Google Sheets
-            service_account_file = self.config.get("service_account_file")
-            spreadsheet_name = self.config.get(
+            service_account_file = gsheets_config.get("service_account_file")
+            spreadsheet_name = gsheets_config.get(
                 "spreadsheet_name", "NeMo Evaluator Launcher Results"
             )

             if service_account_file:
-                gc = gspread.service_account(filename=service_account_file)
+                gc = gspread.service_account(
+                    filename=os.path.expanduser(service_account_file)
+                )
             else:
                 gc = gspread.service_account()

             # Get or create spreadsheet
+            spreadsheet_id = gsheets_config.get("spreadsheet_id")
             try:
-                sh = gc.open(spreadsheet_name)
+                if spreadsheet_id:
+                    sh = gc.open_by_key(spreadsheet_id)
+                else:
+                    sh = gc.open(spreadsheet_name)
                 logger.info(f"Opened existing spreadsheet: {spreadsheet_name}")
             except gspread.SpreadsheetNotFound:
+                if spreadsheet_id:
+                    raise  # Can't create with explicit ID
                 sh = gc.create(spreadsheet_name)
                 logger.info(f"Created new spreadsheet: {spreadsheet_name}")
-                sh.share("", perm_type="anyone", role="reader")

             worksheet = sh.sheet1
-
             # Extract metrics from ALL jobs first to determine headers
             all_job_metrics = {}
             results = {}
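Note: the exporter now prefers an explicit spreadsheet_id (opened via gspread's open_by_key and never auto-created) over opening by name, and expands "~" in the service-account path. A sketch of the relevant keys; the names come from the diff, the values are invented:

    gsheets_config = {
        "service_account_file": "~/keys/service_account.json",  # "~" is expanded
        "spreadsheet_id": "1AbCdEfG",  # preferred; raises if it cannot be opened
        "spreadsheet_name": "NeMo Evaluator Launcher Results",  # fallback lookup/create
    }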
@@ -226,16 +237,23 @@ class GSheetsExporter(BaseExporter):
             )

             if service_account_file:
-                gc = gspread.service_account(filename=service_account_file)
+                gc = gspread.service_account(
+                    filename=os.path.expanduser(service_account_file)
+                )
             else:
                 gc = gspread.service_account()

             # Get or create spreadsheet
+            spreadsheet_id = gsheets_config.get("spreadsheet_id")
             try:
-                sh = gc.open(spreadsheet_name)
+                if spreadsheet_id:
+                    sh = gc.open_by_key(spreadsheet_id)
+                else:
+                    sh = gc.open(spreadsheet_name)
             except gspread.SpreadsheetNotFound:
+                if spreadsheet_id:
+                    raise  # Can't create with explicit ID
                 sh = gc.create(spreadsheet_name)
-                sh.share("", perm_type="anyone", role="reader")

             worksheet = sh.sheet1

nemo_evaluator_launcher/exporters/local.py

@@ -74,6 +74,9 @@ class LocalExporter(BaseExporter):
         # Stage artifacts per storage type
         if paths["storage_type"] == "local_filesystem":
             exported_files = self._copy_local_artifacts(paths, job_export_dir, cfg)
+        elif paths["storage_type"] == "remote_local":
+            # Same as local_filesystem (we're on the remote machine, accessing locally)
+            exported_files = self._copy_local_artifacts(paths, job_export_dir, cfg)
         elif paths["storage_type"] == "remote_ssh":
             exported_files = ssh_download_artifacts(
                 paths, job_export_dir, cfg, None
@@ -125,6 +128,8 @@ class LocalExporter(BaseExporter):
                 logger.warning(f"Failed to create {fmt} summary: {e}")
                 msg += " (summary failed)"

+        meta["output_dir"] = str(job_export_dir.resolve())
+
         return ExportResult(
             success=True, dest=str(job_export_dir), message=msg, metadata=meta
         )
nemo_evaluator_launcher/exporters/mlflow.py

@@ -15,6 +15,7 @@
 #
 """Evaluation results exporter for MLflow tracking."""

+import os
 import shutil
 import tempfile
 from pathlib import Path
@@ -37,6 +38,7 @@ from nemo_evaluator_launcher.exporters.registry import register_exporter
 from nemo_evaluator_launcher.exporters.utils import (
     extract_accuracy_metrics,
     extract_exporter_config,
+    get_artifact_root,
     get_available_artifacts,
     get_benchmark_info,
     get_task_name,
@@ -100,6 +102,21 @@ class MLflowExporter(BaseExporter):
         # Extract config using common utility
         mlflow_config = extract_exporter_config(job_data, "mlflow", self.config)

+        # resolve tracking_uri with fallbacks
+        tracking_uri = mlflow_config.get("tracking_uri")
+        if not tracking_uri:
+            tracking_uri = os.getenv("MLFLOW_TRACKING_URI")
+        # allow env var name
+        if tracking_uri and "://" not in tracking_uri:
+            tracking_uri = os.getenv(tracking_uri, tracking_uri)
+
+        if not tracking_uri:
+            return ExportResult(
+                success=False,
+                dest="mlflow",
+                message="tracking_uri is required (set export.mlflow.tracking_uri or MLFLOW_TRACKING_URI)",
+            )
+
         # Extract metrics
         log_metrics = mlflow_config.get("log_metrics", [])
         accuracy_metrics = extract_accuracy_metrics(
@@ -112,12 +129,6 @@
         )

         # Set up MLflow
-        tracking_uri = mlflow_config.get("tracking_uri")
-        if not tracking_uri:
-            return ExportResult(
-                success=False, dest="mlflow", message="tracking_uri is required"
-            )
-
         tracking_uri = tracking_uri.rstrip("/")
         mlflow.set_tracking_uri(tracking_uri)

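Note: tracking_uri now resolves in three steps: the configured export.mlflow.tracking_uri, then the MLFLOW_TRACKING_URI environment variable, and finally, if the value carries no "://" scheme, it is treated as the name of an environment variable to dereference. A standalone sketch (resolve_tracking_uri is not a function in the package):

    import os

    # Sketch of the fallback order introduced above.
    def resolve_tracking_uri(configured: str | None) -> str | None:
        uri = configured or os.getenv("MLFLOW_TRACKING_URI")
        if uri and "://" not in uri:
            # a bare name is read as an env var holding the real URI
            uri = os.getenv(uri, uri)
        return uri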
@@ -253,37 +264,91 @@ class MLflowExporter(BaseExporter):
         try:
             # Use LocalExporter to get files locally first
             temp_dir = tempfile.mkdtemp(prefix="mlflow_artifacts_")
-            local_exporter = LocalExporter({"output_dir": temp_dir})
+            local_exporter = LocalExporter(
+                {
+                    "output_dir": temp_dir,
+                    "copy_logs": mlflow_config.get(
+                        "log_logs", mlflow_config.get("copy_logs", False)
+                    ),
+                    "only_required": mlflow_config.get("only_required", True),
+                    "format": mlflow_config.get("format", None),
+                    "log_metrics": mlflow_config.get("log_metrics", []),
+                    "output_filename": mlflow_config.get("output_filename", None),
+                }
+            )
             local_result = local_exporter.export_job(job_data)

             if not local_result.success:
                 logger.error(f"Failed to download artifacts: {local_result.message}")
                 return []

-            artifacts_dir = Path(local_result.dest) / "artifacts"
-            logged_names = []
+            base_dir = Path(local_result.dest)
+            artifacts_dir = base_dir / "artifacts"
+            logs_dir = base_dir / "logs"
+            logged_names: list[str] = []

-            task_name = get_task_name(job_data)
-            artifact_path = task_name
+            artifact_path = get_artifact_root(job_data)  # "<harness>.<benchmark>"

             # Log config at root level
-            with tempfile.TemporaryDirectory() as tmpdir:
-                cfg_file = Path(tmpdir) / "config.yaml"
-                with cfg_file.open("w") as f:
-                    yaml.dump(
-                        job_data.config or {},
-                        f,
-                        default_flow_style=False,
-                        sort_keys=False,
-                    )
-                mlflow.log_artifact(str(cfg_file))
-
-            # Then log results files
-            for fname in get_available_artifacts(artifacts_dir):
-                file_path = artifacts_dir / fname
-                if file_path.exists():
-                    mlflow.log_artifact(str(file_path), artifact_path=artifact_path)
-                    logged_names.append(fname)
+            cfg_logged = False
+            for fname in ("config.yml", "run_config.yml"):
+                p = artifacts_dir / fname
+                if p.exists():
+                    mlflow.log_artifact(str(p))
+                    cfg_logged = True
+                    break
+            if not cfg_logged:
+                with tempfile.TemporaryDirectory() as tmpdir:
+                    cfg_file = Path(tmpdir) / "config.yaml"
+                    with cfg_file.open("w") as f:
+                        yaml.dump(
+                            job_data.config or {},
+                            f,
+                            default_flow_style=False,
+                            sort_keys=False,
+                        )
+                    mlflow.log_artifact(str(cfg_file))
+
+            files_to_upload: list[Path] = []
+            if mlflow_config.get("only_required", True):
+                for fname in get_available_artifacts(artifacts_dir):
+                    p = artifacts_dir / fname
+                    if p.exists():
+                        files_to_upload.append(p)
+            else:
+                for p in artifacts_dir.iterdir():
+                    if p.is_file():
+                        files_to_upload.append(p)
+
+            for fpath in files_to_upload:
+                rel = fpath.relative_to(artifacts_dir).as_posix()
+                parent = os.path.dirname(rel)
+                mlflow.log_artifact(
+                    str(fpath),
+                    artifact_path=f"{artifact_path}/artifacts/{parent}".rstrip("/"),
+                )
+                logged_names.append(rel)
+
+            # Optionally upload logs under "<harness.task>/logs"
+            if mlflow_config.get("log_logs", False) and logs_dir.exists():
+                for p in logs_dir.rglob("*"):
+                    if p.is_file():
+                        mlflow.log_artifact(
+                            str(p),
+                            artifact_path=f"{artifact_path}/logs",
+                        )
+                        logged_names.append(f"logs/{p.name}")
+
+            # Debug summary of what we uploaded
+            logger.info(
+                f"MLflow upload summary: files={len(logged_names)}, only_required={mlflow_config.get('only_required', True)}, log_logs={mlflow_config.get('log_logs', False)}"
+            )
+            if logger.isEnabledFor(10):  # DEBUG
+                try:
+                    preview = "\n - " + "\n - ".join(sorted(logged_names)[:50])
+                    logger.debug(f"Uploaded files preview (first 50):{preview}")
+                except Exception:
+                    pass

             # cleanup temp
             shutil.rmtree(temp_dir)
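Note: artifacts now land under "<harness>.<benchmark>/artifacts/..." with subdirectories preserved and, when log_logs is enabled, logs under "<harness>.<benchmark>/logs". The destination-path arithmetic from the upload loop above, in isolation (example names invented):

    import os

    artifact_root = "lm-eval.mmlu"  # invented "<harness>.<benchmark>" example
    rel = "groups/summary.json"     # path relative to the artifacts dir
    dest = f"{artifact_root}/artifacts/{os.path.dirname(rel)}".rstrip("/")
    assert dest == "lm-eval.mmlu/artifacts/groups"  # nesting preserved

    rel_flat = "results.json"       # a top-level file
    dest_flat = f"{artifact_root}/artifacts/{os.path.dirname(rel_flat)}".rstrip("/")
    assert dest_flat == "lm-eval.mmlu/artifacts"    # trailing "/" stripped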
@@ -312,6 +377,18 @@ class MLflowExporter(BaseExporter):
         # Extract config using common utility
         mlflow_config = extract_exporter_config(first_job, "mlflow", self.config)

+        # resolve tracking_uri with fallbacks
+        tracking_uri = mlflow_config.get("tracking_uri") or os.getenv(
+            "MLFLOW_TRACKING_URI"
+        )
+        if tracking_uri and "://" not in tracking_uri:
+            tracking_uri = os.getenv(tracking_uri, tracking_uri)
+        if not tracking_uri:
+            return {
+                "success": False,
+                "error": "tracking_uri is required (set export.mlflow.tracking_uri or MLFLOW_TRACKING_URI)",
+            }
+
         # Collect metrics from ALL jobs
         all_metrics = {}
         for job_id, job_data in jobs.items():
@@ -328,10 +405,6 @@ class MLflowExporter(BaseExporter):
         }

         # Set up MLflow
-        tracking_uri = mlflow_config.get("tracking_uri")
-        if not tracking_uri:
-            return {"success": False, "error": "tracking_uri is required"}
-
         tracking_uri = tracking_uri.rstrip("/")
         mlflow.set_tracking_uri(tracking_uri)

nemo_evaluator_launcher/exporters/utils.py

@@ -148,15 +148,12 @@ def extract_exporter_config(
     """Extract and merge exporter configuration from multiple sources."""
     config = {}

-    # Get config from dedicated field
+    # root-level `export.<exporter-name>`
     if job_data.config:
-        execution_config = job_data.config.get("execution", {})
-        auto_export_config = execution_config.get("auto_export", {})
-        exporter_configs = auto_export_config.get("configs", {})
-        yaml_config = exporter_configs.get(exporter_name, {})
-
-        # No conversion needed
-        config.update(yaml_config)
+        export_block = (job_data.config or {}).get("export", {})
+        yaml_config = (export_block or {}).get(exporter_name, {})
+        if yaml_config:
+            config.update(yaml_config)

     # From webhook metadata (if triggered by webhook)
     if "webhook_metadata" in job_data.data:
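Note: exporter settings thus move from execution.auto_export.configs.<name> to a root-level export.<name> block. The new lookup against a plain-dict config; key names come from the diff, values are invented:

    # Plain-dict sketch of the new lookup.
    job_config = {
        "export": {
            "mlflow": {"tracking_uri": "http://mlflow.internal:5000"},
            "wandb": {"entity": "my-team"},
        }
    }
    export_block = (job_config or {}).get("export", {})
    yaml_config = (export_block or {}).get("mlflow", {})
    assert yaml_config == {"tracking_uri": "http://mlflow.internal:5000"}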
@@ -167,8 +164,6 @@ def extract_exporter_config(
             "source_artifact": f"{webhook_data.get('artifact_name', 'unknown')}:{webhook_data.get('artifact_version', 'unknown')}",
             "config_source": webhook_data.get("config_file", "unknown"),
         }
-
-        # For W&B specifically, extract run info if available
         if exporter_name == "wandb" and webhook_data.get("webhook_source") == "wandb":
             wandb_specific = {
                 "entity": webhook_data.get("entity"),
@@ -176,10 +171,9 @@ def extract_exporter_config(
                 "run_id": webhook_data.get("run_id"),
             }
             webhook_config.update({k: v for k, v in wandb_specific.items() if v})
-
         config.update(webhook_config)

-    # Constructor config: allows CLI overrides
+    # allows CLI overrides
     if constructor_config:
         config.update(constructor_config)

@@ -269,6 +263,14 @@ def get_container_from_mapping(job_data: JobData) -> str:
     return None


+def get_artifact_root(job_data: JobData) -> str:
+    """Get artifact root from job data."""
+    bench = get_benchmark_info(job_data)
+    h = bench.get("harness", "unknown")
+    b = bench.get("benchmark", get_task_name(job_data))
+    return f"{h}.{b}"
+
+

 # =============================================================================
 # GITLAB DOWNLOAD
@@ -288,91 +290,6 @@ def download_gitlab_artifacts(
         Dictionary mapping artifact names to local file paths
     """
     raise NotImplementedError("Downloading from gitlab is not implemented")
-    # TODO: rework this logic
-    # pipeline_id = paths["pipeline_id"]
-    # project_id = paths["project_id"]
-    # gitlab_token = os.getenv("GITLAB_TOKEN")
-    #
-    # if not gitlab_token:
-    #     raise RuntimeError(
-    #         "GITLAB_TOKEN environment variable required for GitLab remote downloads"
-    #     )
-    #
-    # # GitLab API endpoint for artifacts
-    # base_url = "TODO: replace"
-    # artifacts_url = "TODO: replace"
-    #
-    # headers = {"Private-Token": gitlab_token}
-    # downloaded_artifacts = {}
-    #
-    # try:
-    #     # Get pipeline jobs
-    #     response = requests.get(artifacts_url, headers=headers, timeout=30)
-    #     response.raise_for_status()
-    #     jobs = response.json()
-    #
-    #     for job in jobs:
-    #         if job.get("artifacts_file"):
-    #             job_id = job["id"]
-    #             job_name = job.get("name", f"job_{job_id}")
-    #             artifacts_download_url = (
-    #                 f"{base_url}/api/v4/projects/{project_id}/jobs/{job_id}/artifacts"
-    #             )
-    #
-    #             logger.info(f"Downloading artifacts from job: {job_name}")
-    #
-    #             # Download job artifacts
-    #             response = requests.get(
-    #                 artifacts_download_url, headers=headers, timeout=300
-    #             )
-    #             response.raise_for_status()
-    #
-    #             if extract_specific:
-    #                 # Extract specific files from ZIP
-    #                 with tempfile.NamedTemporaryFile(
-    #                     suffix=".zip", delete=False
-    #                 ) as temp_zip:
-    #                     temp_zip.write(response.content)
-    #                     temp_zip_path = temp_zip.name
-    #
-    #                 try:
-    #                     with zipfile.ZipFile(temp_zip_path, "r") as zip_ref:
-    #                         # Create artifacts directory
-    #                         artifacts_dir = export_dir / "artifacts"
-    #                         artifacts_dir.mkdir(parents=True, exist_ok=True)
-    #
-    #                         # Extract to be logged artifacts
-    #                         for member in zip_ref.namelist():
-    #                             filename = Path(member).name
-    #                             if filename in get_relevant_artifacts():
-    #                                 # Extract the file
-    #                                 source = zip_ref.open(member)
-    #                                 target_path = artifacts_dir / filename
-    #                                 with open(target_path, "wb") as f:
-    #                                     f.write(source.read())
-    #                                 source.close()
-    #
-    #                                 downloaded_artifacts[filename] = target_path
-    #                                 logger.info(f"Extracted: {filename}")
-    #                 finally:
-    #                     os.unlink(temp_zip_path)
-    #             else:
-    #                 # Save as ZIP files (original behavior)
-    #                 artifacts_zip = export_dir / f"job_{job_id}_artifacts.zip"
-    #                 with open(artifacts_zip, "wb") as f:
-    #                     f.write(response.content)
-    #
-    #                 downloaded_artifacts[f"job_{job_id}_artifacts.zip"] = artifacts_zip
-    #                 logger.info(f"Downloaded: {artifacts_zip.name}")
-    #
-    # except requests.RequestException as e:
-    #     logger.error(f"GitLab API request failed: {e}")
-    #     raise RuntimeError(f"GitLab API request failed: {e}")
-    # except Exception as e:
-    #     logger.error(f"GitLab remote download failed: {e}")
-    #     raise RuntimeError(f"GitLab remote download failed: {e}")
-    #
-    # return downloaded_artifacts


 # =============================================================================
@@ -522,16 +439,16 @@ def ssh_download_artifacts(

 def _get_artifacts_dir(paths: Dict[str, Any]) -> Path:
     """Get artifacts directory from paths."""
-    if paths["storage_type"] == "local_filesystem":
-        return paths["artifacts_dir"]
-    elif paths["storage_type"] == "gitlab_ci_local":
-        return paths["artifacts_dir"]
-    elif paths["storage_type"] == "remote_ssh":
-        return None
-    else:
-        logger.error(f"Unsupported storage type: {paths['storage_type']}")
+    storage_type = paths.get("storage_type")
+
+    # For SSH-based remote access, artifacts aren't available locally yet
+    if storage_type == "remote_ssh":
         return None

+    # For all local access (local_filesystem, remote_local, gitlab_ci_local)
+    # return the artifacts_dir from paths
+    return paths.get("artifacts_dir")
+

 def _extract_metrics_from_results(results: dict) -> Dict[str, float]:
     """Extract metrics from a 'results' dict (with optional 'groups'/'tasks')."""