workbench 0.8.217__py3-none-any.whl → 0.8.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  3. workbench/algorithms/dataframe/projection_2d.py +8 -2
  4. workbench/algorithms/dataframe/proximity.py +3 -0
  5. workbench/algorithms/sql/outliers.py +3 -3
  6. workbench/api/feature_set.py +0 -1
  7. workbench/core/artifacts/endpoint_core.py +2 -2
  8. workbench/core/artifacts/feature_set_core.py +185 -230
  9. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  10. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  11. workbench/model_script_utils/model_script_utils.py +15 -11
  12. workbench/model_scripts/chemprop/chemprop.template +195 -70
  13. workbench/model_scripts/chemprop/generated_model_script.py +198 -73
  14. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  15. workbench/model_scripts/custom_models/chem_info/fingerprints.py +80 -43
  16. workbench/model_scripts/pytorch_model/generated_model_script.py +2 -2
  17. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  18. workbench/model_scripts/xgb_model/generated_model_script.py +7 -7
  19. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  20. workbench/scripts/meta_model_sim.py +35 -0
  21. workbench/scripts/ml_pipeline_sqs.py +71 -2
  22. workbench/themes/light/custom.css +7 -1
  23. workbench/themes/midnight_blue/custom.css +34 -0
  24. workbench/utils/chem_utils/fingerprints.py +80 -43
  25. workbench/utils/chem_utils/projections.py +16 -6
  26. workbench/utils/meta_model_simulator.py +41 -13
  27. workbench/utils/model_utils.py +0 -1
  28. workbench/utils/plot_utils.py +146 -28
  29. workbench/utils/shap_utils.py +1 -55
  30. workbench/utils/theme_manager.py +95 -30
  31. workbench/web_interface/components/plugins/scatter_plot.py +152 -66
  32. workbench/web_interface/components/settings_menu.py +184 -0
  33. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
  34. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/RECORD +38 -37
  35. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +1 -0
  36. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  37. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  38. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
  39. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
  40. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,7 @@ DEFAULT_HYPERPARAMETERS = {
34
34
  "max_epochs": 400,
35
35
  "patience": 50,
36
36
  "batch_size": 32,
37
- # Message Passing
37
+ # Message Passing (ignored when using foundation model)
38
38
  "hidden_dim": 700,
39
39
  "depth": 6,
40
40
  "dropout": 0.1, # Lower dropout - ensemble provides regularization
@@ -45,6 +45,14 @@ DEFAULT_HYPERPARAMETERS = {
45
45
  "criterion": "mae",
46
46
  # Random seed
47
47
  "seed": 42,
48
+ # Foundation model support
49
+ # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
50
+ # - Path to .pt file: Load custom pretrained Chemprop model
51
+ # - None: Train from scratch (default)
52
+ "from_foundation": None,
53
+ # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
54
+ # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
55
+ "freeze_mpnn_epochs": 0,
48
56
  }
49
57
 
50
58
  # Template parameters (filled in by Workbench)
@@ -114,26 +122,27 @@ def _create_molecule_datapoints(
114
122
  # Model Loading (for SageMaker inference)
115
123
  # =============================================================================
116
124
  def model_fn(model_dir: str) -> dict:
117
- """Load ChemProp MPNN ensemble from the specified directory."""
118
- from lightning import pytorch as pl
125
+ """Load ChemProp MPNN ensemble from the specified directory.
119
126
 
127
+ Optimized for serverless cold starts - uses direct PyTorch inference
128
+ instead of Lightning Trainer to minimize startup time.
129
+ """
120
130
  metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
131
+
132
+ # Load all ensemble models (keep on CPU for serverless compatibility)
133
+ # ChemProp handles device placement internally
121
134
  ensemble_models = []
122
135
  for i in range(metadata["n_ensemble"]):
123
136
  model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
124
137
  model.eval()
125
138
  ensemble_models.append(model)
126
139
 
127
- # Pre-initialize trainer once during model loading (expensive operation)
128
- trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
129
-
130
140
  print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
131
141
  return {
132
142
  "ensemble_models": ensemble_models,
133
143
  "n_ensemble": metadata["n_ensemble"],
134
144
  "target_columns": metadata["target_columns"],
135
145
  "median_std": metadata["median_std"],
136
- "trainer": trainer,
137
146
  }
138
147
 
139
148
 
@@ -141,13 +150,15 @@ def model_fn(model_dir: str) -> dict:
141
150
  # Inference (for SageMaker inference)
142
151
  # =============================================================================
143
152
  def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
144
- """Make predictions with ChemProp MPNN ensemble."""
153
+ """Make predictions with ChemProp MPNN ensemble.
154
+
155
+ Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
156
+ """
145
157
  model_type = TEMPLATE_PARAMS["model_type"]
146
158
  model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
147
159
 
148
160
  ensemble_models = model_dict["ensemble_models"]
149
161
  target_columns = model_dict["target_columns"]
150
- trainer = model_dict["trainer"] # Use pre-initialized trainer
151
162
 
152
163
  # Load artifacts
153
164
  label_encoder = None
@@ -202,18 +213,34 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
202
213
  return df
203
214
 
204
215
  dataset = data.MoleculeDataset(datapoints)
205
- dataloader = data.build_dataloader(dataset, shuffle=False)
216
+ dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)
206
217
 
207
- # Ensemble predictions
218
+ # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
208
219
  all_preds = []
209
220
  for model in ensemble_models:
221
+ model_preds = []
222
+ model.eval()
210
223
  with torch.inference_mode():
211
- predictions = trainer.predict(model, dataloader)
212
- preds = np.concatenate([p.numpy() for p in predictions], axis=0)
224
+ for batch in dataloader:
225
+ # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
226
+ # For inference we only need bmg, V_d, X_d
227
+ bmg, V_d, X_d, *_ = batch
228
+ output = model(bmg, V_d, X_d)
229
+ model_preds.append(output.detach().cpu().numpy())
230
+
231
+ if len(model_preds) == 0:
232
+ print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
233
+ continue
234
+
235
+ preds = np.concatenate(model_preds, axis=0)
213
236
  if preds.ndim == 3 and preds.shape[1] == 1:
214
237
  preds = preds.squeeze(axis=1)
215
238
  all_preds.append(preds)
216
239
 
240
+ if len(all_preds) == 0:
241
+ print("Error: No ensemble predictions generated")
242
+ return df
243
+
217
244
  preds = np.mean(np.stack(all_preds), axis=0)
218
245
  preds_std = np.std(np.stack(all_preds), axis=0)
219
246
  if preds.ndim == 1:
@@ -243,8 +270,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
243
270
  df["prediction"] = df[f"{target_columns[0]}_pred"]
244
271
  df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
245
272
 
246
- # Compute confidence from ensemble std
247
- df = _compute_std_confidence(df, model_dict["median_std"])
273
+ # Compute confidence from ensemble std (or NaN if single model)
274
+ if model_dict["median_std"] is not None:
275
+ df = _compute_std_confidence(df, model_dict["median_std"])
276
+ else:
277
+ df["confidence"] = np.nan
248
278
 
249
279
  return df
250
280
 
@@ -279,54 +309,107 @@ if __name__ == "__main__":
279
309
  )
280
310
 
281
311
  # -------------------------------------------------------------------------
282
- # Training-only helper function
312
+ # Training-only helper functions
283
313
  # -------------------------------------------------------------------------
284
- def build_mpnn_model(
285
- hyperparameters: dict,
286
- task: str = "regression",
287
- num_classes: int | None = None,
288
- n_targets: int = 1,
289
- n_extra_descriptors: int = 0,
290
- x_d_transform: nn.ScaleTransform | None = None,
291
- output_transform: nn.UnscaleTransform | None = None,
292
- task_weights: np.ndarray | None = None,
293
- ) -> models.MPNN:
294
- """Build an MPNN model with specified hyperparameters."""
295
- hidden_dim = hyperparameters["hidden_dim"]
296
- depth = hyperparameters["depth"]
314
+ def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
315
+ """Load pretrained MPNN weights from foundation model.
316
+
317
+ Args:
318
+ from_foundation: "CheMeleon" or path to .pt file
319
+
320
+ Returns:
321
+ Tuple of (message_passing, aggregation) modules
322
+ """
323
+ import urllib.request
324
+ from pathlib import Path
325
+
326
+ print(f"Loading foundation model: {from_foundation}")
327
+
328
+ if from_foundation.lower() == "chemeleon":
329
+ # Download from Zenodo if not cached
330
+ cache_dir = Path.home() / ".chemprop" / "foundation"
331
+ cache_dir.mkdir(parents=True, exist_ok=True)
332
+ chemeleon_path = cache_dir / "chemeleon_mp.pt"
333
+
334
+ if not chemeleon_path.exists():
335
+ print(" Downloading CheMeleon weights from Zenodo...")
336
+ urllib.request.urlretrieve(
337
+ "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
338
+ )
339
+ print(f" Downloaded to {chemeleon_path}")
340
+
341
+ ckpt = torch.load(chemeleon_path, weights_only=True)
342
+ mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
343
+ mp.load_state_dict(ckpt["state_dict"])
344
+ print(f" Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
345
+ return mp, nn.MeanAggregation()
346
+
347
+ if not os.path.exists(from_foundation):
348
+ raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")
349
+
350
+ ckpt = torch.load(from_foundation, weights_only=False)
351
+ if "hyper_parameters" in ckpt and "state_dict" in ckpt:
352
+ # CheMeleon-style checkpoint
353
+ mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
354
+ mp.load_state_dict(ckpt["state_dict"])
355
+ print(f" Loaded custom foundation weights (hidden_dim={mp.output_dim})")
356
+ return mp, nn.MeanAggregation()
357
+
358
+ # Full MPNN model file
359
+ pretrained = models.MPNN.load_from_file(from_foundation)
360
+ print(f" Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
361
+ return pretrained.message_passing, pretrained.agg
362
+
363
+ def _build_ffn(
364
+ task: str, input_dim: int, hyperparameters: dict,
365
+ num_classes: int | None, n_targets: int,
366
+ output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
367
+ ) -> nn.Predictor:
368
+ """Build task-specific FFN head."""
297
369
  dropout = hyperparameters["dropout"]
298
370
  ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
299
371
  ffn_num_layers = hyperparameters["ffn_num_layers"]
300
372
 
301
- mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
302
- agg = nn.NormAggregation()
303
- ffn_input_dim = hidden_dim + n_extra_descriptors
304
-
305
373
  if task == "classification" and num_classes is not None:
306
- ffn = nn.MulticlassClassificationFFN(
307
- n_classes=num_classes, input_dim=ffn_input_dim,
374
+ return nn.MulticlassClassificationFFN(
375
+ n_classes=num_classes, input_dim=input_dim,
308
376
  hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
309
377
  )
378
+
379
+ from chemprop.nn.metrics import MAE, MSE
380
+ criterion_map = {"mae": MAE, "mse": MSE}
381
+ criterion_name = hyperparameters.get("criterion", "mae")
382
+ if criterion_name not in criterion_map:
383
+ raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
384
+
385
+ weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
386
+ return nn.RegressionFFN(
387
+ input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
388
+ dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
389
+ task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
390
+ )
391
+
392
+ def build_mpnn_model(
393
+ hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
394
+ n_targets: int = 1, n_extra_descriptors: int = 0,
395
+ x_d_transform: nn.ScaleTransform | None = None,
396
+ output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
397
+ ) -> models.MPNN:
398
+ """Build MPNN model, optionally loading pretrained weights."""
399
+ from_foundation = hyperparameters.get("from_foundation")
400
+
401
+ if from_foundation:
402
+ mp, agg = _load_foundation_weights(from_foundation)
403
+ ffn_input_dim = mp.output_dim + n_extra_descriptors
310
404
  else:
311
- # Map criterion name to ChemProp metric class (must have .clone() method)
312
- from chemprop.nn.metrics import MAE, MSE
313
-
314
- criterion_map = {
315
- "mae": MAE,
316
- "mse": MSE,
317
- }
318
- criterion_name = hyperparameters.get("criterion", "mae")
319
- if criterion_name not in criterion_map:
320
- raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
321
- criterion = criterion_map[criterion_name]()
322
-
323
- weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
324
- ffn = nn.RegressionFFN(
325
- input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
326
- dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
327
- criterion=criterion,
405
+ mp = nn.BondMessagePassing(
406
+ d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
407
+ dropout=hyperparameters["dropout"],
328
408
  )
409
+ agg = nn.NormAggregation()
410
+ ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors
329
411
 
412
+ ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
330
413
  return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
331
414
 
332
415
  # -------------------------------------------------------------------------
@@ -359,6 +442,14 @@ if __name__ == "__main__":
359
442
  print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
360
443
  print(f"Hyperparameters: {hyperparameters}")
361
444
 
445
+ # Log foundation model configuration
446
+ if hyperparameters.get("from_foundation"):
447
+ freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
448
+ freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
449
+ print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
450
+ else:
451
+ print("Foundation model: None (training from scratch)")
452
+
362
453
  # Load training data
363
454
  training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
364
455
  print(f"Training Files: {training_files}")
@@ -456,7 +547,7 @@ if __name__ == "__main__":
456
547
  print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
457
548
  print(f"{'='*50}")
458
549
 
459
- # Split data
550
+ # Split data (val_extra_raw preserves unscaled features for OOF predictions)
460
551
  df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
461
552
  train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
462
553
  train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
@@ -481,10 +572,10 @@ if __name__ == "__main__":
481
572
  val_dataset.normalize_targets(target_scaler)
482
573
  output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
483
574
 
484
- train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
485
- val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
575
+ train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
576
+ val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
486
577
 
487
- # Build and train model
578
+ # Build model
488
579
  pl.seed_everything(hyperparameters["seed"] + fold_idx)
489
580
  mpnn = build_mpnn_model(
490
581
  hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
@@ -492,14 +583,39 @@ if __name__ == "__main__":
492
583
  output_transform=output_transform, task_weights=task_weights,
493
584
  )
494
585
 
495
- trainer = pl.Trainer(
496
- accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
497
- callbacks=[
498
- pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
499
- pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
500
- ],
501
- )
502
- trainer.fit(mpnn, train_loader, val_loader)
586
+ # Train model (with optional two-phase foundation training)
587
+ freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
588
+ use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0
589
+
590
+ def _set_mpnn_frozen(frozen: bool):
591
+ for param in mpnn.message_passing.parameters():
592
+ param.requires_grad = not frozen
593
+ for param in mpnn.agg.parameters():
594
+ param.requires_grad = not frozen
595
+
596
+ def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
597
+ callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
598
+ if save_checkpoint:
599
+ callbacks.append(pl.callbacks.ModelCheckpoint(
600
+ dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
601
+ ))
602
+ return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)
603
+
604
+ if use_two_phase:
605
+ # Phase 1: Freeze MPNN, train FFN only
606
+ print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
607
+ _set_mpnn_frozen(True)
608
+ _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)
609
+
610
+ # Phase 2: Unfreeze and fine-tune all
611
+ print("Phase 2: Unfreezing MPNN, continuing training...")
612
+ _set_mpnn_frozen(False)
613
+ remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
614
+ trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
615
+ trainer.fit(mpnn, train_loader, val_loader)
616
+ else:
617
+ trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
618
+ trainer.fit(mpnn, train_loader, val_loader)
503
619
 
504
620
  # Load best checkpoint
505
621
  if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
@@ -509,7 +625,7 @@ if __name__ == "__main__":
509
625
  mpnn.eval()
510
626
  ensemble_models.append(mpnn)
511
627
 
512
- # Out-of-fold predictions (using raw features)
628
+ # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
513
629
  val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
514
630
  val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
515
631
 
@@ -599,11 +715,17 @@ if __name__ == "__main__":
599
715
  df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
600
716
  df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
601
717
 
602
- # Compute confidence from ensemble std
603
- median_std = float(np.median(preds_std[:, 0]))
604
- print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
605
- df_val = _compute_std_confidence(df_val, median_std)
606
- print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
718
+ # Compute confidence from ensemble std (or NaN for single model)
719
+ if preds_std is not None:
720
+ median_std = float(np.median(preds_std[:, 0]))
721
+ print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
722
+ df_val = _compute_std_confidence(df_val, median_std)
723
+ print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
724
+ else:
725
+ # Single model - no ensemble std available, confidence is undefined
726
+ median_std = None
727
+ df_val["confidence"] = np.nan
728
+ print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")
607
729
 
608
730
  # -------------------------------------------------------------------------
609
731
  # Save validation predictions to S3
@@ -633,6 +755,9 @@ if __name__ == "__main__":
633
755
  "n_folds": n_folds,
634
756
  "target_columns": target_columns,
635
757
  "median_std": median_std, # For confidence calculation during inference
758
+ # Foundation model provenance (for tracking/reproducibility)
759
+ "from_foundation": hyperparameters.get("from_foundation", None),
760
+ "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
636
761
  }
637
762
  joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
638
763