workbench-0.8.219-py3-none-any.whl → workbench-0.8.231-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +2 -0
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  5. workbench/algorithms/dataframe/projection_2d.py +8 -2
  6. workbench/algorithms/dataframe/proximity.py +3 -0
  7. workbench/algorithms/dataframe/smart_aggregator.py +161 -0
  8. workbench/algorithms/sql/column_stats.py +0 -1
  9. workbench/algorithms/sql/correlations.py +0 -1
  10. workbench/algorithms/sql/descriptive_stats.py +0 -1
  11. workbench/api/feature_set.py +0 -1
  12. workbench/api/meta.py +0 -1
  13. workbench/cached/cached_meta.py +0 -1
  14. workbench/cached/cached_model.py +37 -7
  15. workbench/core/artifacts/endpoint_core.py +12 -2
  16. workbench/core/artifacts/feature_set_core.py +238 -225
  17. workbench/core/cloud_platform/cloud_meta.py +0 -1
  18. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  19. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  20. workbench/model_script_utils/model_script_utils.py +30 -0
  21. workbench/model_script_utils/uq_harness.py +0 -1
  22. workbench/model_scripts/chemprop/chemprop.template +196 -68
  23. workbench/model_scripts/chemprop/generated_model_script.py +197 -72
  24. workbench/model_scripts/chemprop/model_script_utils.py +30 -0
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  26. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  27. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
  29. workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
  30. workbench/model_scripts/pytorch_model/pytorch.template +47 -29
  31. workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
  32. workbench/model_scripts/script_generation.py +0 -1
  33. workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
  34. workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
  35. workbench/model_scripts/xgb_model/uq_harness.py +0 -1
  36. workbench/scripts/ml_pipeline_sqs.py +71 -2
  37. workbench/themes/dark/custom.css +85 -8
  38. workbench/themes/dark/plotly.json +6 -6
  39. workbench/themes/light/custom.css +172 -64
  40. workbench/themes/light/plotly.json +9 -9
  41. workbench/themes/midnight_blue/custom.css +82 -29
  42. workbench/themes/midnight_blue/plotly.json +1 -1
  43. workbench/utils/aws_utils.py +0 -1
  44. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  45. workbench/utils/chem_utils/projections.py +16 -6
  46. workbench/utils/chem_utils/vis.py +137 -27
  47. workbench/utils/clientside_callbacks.py +41 -0
  48. workbench/utils/markdown_utils.py +57 -0
  49. workbench/utils/model_utils.py +0 -1
  50. workbench/utils/pipeline_utils.py +0 -1
  51. workbench/utils/plot_utils.py +52 -36
  52. workbench/utils/theme_manager.py +95 -30
  53. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  54. workbench/web_interface/components/model_plot.py +2 -0
  55. workbench/web_interface/components/plugin_unit_test.py +0 -1
  56. workbench/web_interface/components/plugins/ag_table.py +2 -4
  57. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  58. workbench/web_interface/components/plugins/model_details.py +10 -6
  59. workbench/web_interface/components/plugins/scatter_plot.py +184 -85
  60. workbench/web_interface/components/settings_menu.py +185 -0
  61. workbench/web_interface/page_views/main_page.py +0 -1
  62. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
  63. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
  64. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
  65. workbench/themes/quartz/base_css.url +0 -1
  66. workbench/themes/quartz/custom.css +0 -117
  67. workbench/themes/quartz/plotly.json +0 -642
  68. workbench/themes/quartz_dark/base_css.url +0 -1
  69. workbench/themes/quartz_dark/custom.css +0 -131
  70. workbench/themes/quartz_dark/plotly.json +0 -642
  71. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
  72. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
  73. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,6 @@ import joblib
 from lightgbm import LGBMRegressor
 from mapie.regression import ConformalizedQuantileRegressor
 
-
 # Default confidence levels for prediction intervals
 DEFAULT_CONFIDENCE_LEVELS = [0.50, 0.68, 0.80, 0.90, 0.95]
 
@@ -20,6 +20,7 @@ import torch
 from chemprop import data, models
 
 from model_script_utils import (
+    cap_std_outliers,
     expand_proba_column,
     input_fn,
     output_fn,
@@ -34,7 +35,7 @@ DEFAULT_HYPERPARAMETERS = {
     "max_epochs": 400,
     "patience": 50,
     "batch_size": 32,
-    # Message Passing
+    # Message Passing (ignored when using foundation model)
     "hidden_dim": 700,
     "depth": 6,
     "dropout": 0.1,  # Lower dropout - ensemble provides regularization
@@ -45,6 +46,14 @@ DEFAULT_HYPERPARAMETERS = {
     "criterion": "mae",
     # Random seed
     "seed": 42,
+    # Foundation model support
+    # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
+    # - Path to .pt file: Load custom pretrained Chemprop model
+    # - None: Train from scratch (default)
+    "from_foundation": None,
+    # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
+    # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
+    "freeze_mpnn_epochs": 0,
 }
 
 # Template parameters (filled in by Workbench)
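A minimal usage sketch (not part of the diff; the override pattern and the values shown are illustrative) of opting into the new foundation-model path via these hyperparameters:

    hyperparameters = {
        **DEFAULT_HYPERPARAMETERS,
        "from_foundation": "CheMeleon",  # or a path to a custom .pt checkpoint
        "freeze_mpnn_epochs": 10,        # phase 1: train only the FFN head for 10 epochs
    }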
@@ -114,26 +123,27 @@ def _create_molecule_datapoints(
 # Model Loading (for SageMaker inference)
 # =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load ChemProp MPNN ensemble from the specified directory."""
-    from lightning import pytorch as pl
+    """Load ChemProp MPNN ensemble from the specified directory.
 
+    Optimized for serverless cold starts - uses direct PyTorch inference
+    instead of Lightning Trainer to minimize startup time.
+    """
     metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+
+    # Load all ensemble models (keep on CPU for serverless compatibility)
+    # ChemProp handles device placement internally
     ensemble_models = []
     for i in range(metadata["n_ensemble"]):
         model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
         model.eval()
         ensemble_models.append(model)
 
-    # Pre-initialize trainer once during model loading (expensive operation)
-    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
-
     print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
     return {
         "ensemble_models": ensemble_models,
         "n_ensemble": metadata["n_ensemble"],
         "target_columns": metadata["target_columns"],
         "median_std": metadata["median_std"],
-        "trainer": trainer,
     }
 
 
@@ -141,13 +151,15 @@ def model_fn(model_dir: str) -> dict:
 # Inference (for SageMaker inference)
 # =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make predictions with ChemProp MPNN ensemble."""
+    """Make predictions with ChemProp MPNN ensemble.
+
+    Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
+    """
     model_type = TEMPLATE_PARAMS["model_type"]
     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
 
     ensemble_models = model_dict["ensemble_models"]
     target_columns = model_dict["target_columns"]
-    trainer = model_dict["trainer"]  # Use pre-initialized trainer
 
     # Load artifacts
     label_encoder = None
@@ -202,22 +214,39 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         return df
 
     dataset = data.MoleculeDataset(datapoints)
-    dataloader = data.build_dataloader(dataset, shuffle=False)
+    dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)
 
-    # Ensemble predictions
+    # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
    all_preds = []
     for model in ensemble_models:
+        model_preds = []
+        model.eval()
         with torch.inference_mode():
-            predictions = trainer.predict(model, dataloader)
-            preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+            for batch in dataloader:
+                # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
+                # For inference we only need bmg, V_d, X_d
+                bmg, V_d, X_d, *_ = batch
+                output = model(bmg, V_d, X_d)
+                model_preds.append(output.detach().cpu().numpy())
+
+        if len(model_preds) == 0:
+            print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
+            continue
+
+        preds = np.concatenate(model_preds, axis=0)
         if preds.ndim == 3 and preds.shape[1] == 1:
             preds = preds.squeeze(axis=1)
         all_preds.append(preds)
 
+    if len(all_preds) == 0:
+        print("Error: No ensemble predictions generated")
+        return df
+
     preds = np.mean(np.stack(all_preds), axis=0)
     preds_std = np.std(np.stack(all_preds), axis=0)
     if preds.ndim == 1:
         preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
+    preds_std = cap_std_outliers(preds_std)
 
     print(f"Inference complete: {preds.shape[0]} predictions")
 
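cap_std_outliers comes from model_script_utils.py (+30 lines; its body is not shown in this diff). A hypothetical sketch of such a helper, assuming simple per-target percentile clipping:

    import numpy as np

    def cap_std_outliers_sketch(preds_std: np.ndarray, percentile: float = 99.0) -> np.ndarray:
        # Clip extreme ensemble stds per target column so a few outlier
        # molecules don't dominate the downstream confidence scale.
        capped = preds_std.copy()
        for t in range(capped.shape[1]):
            capped[:, t] = np.minimum(capped[:, t], np.percentile(capped[:, t], percentile))
        return capped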
@@ -243,8 +272,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
     df["prediction"] = df[f"{target_columns[0]}_pred"]
     df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
 
-    # Compute confidence from ensemble std
-    df = _compute_std_confidence(df, model_dict["median_std"])
+    # Compute confidence from ensemble std (or NaN if single model)
+    if model_dict["median_std"] is not None:
+        df = _compute_std_confidence(df, model_dict["median_std"])
+    else:
+        df["confidence"] = np.nan
 
     return df
 
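_compute_std_confidence is defined elsewhere in the template and not shown here. A hypothetical sketch of the idea, assuming confidence decays as prediction_std grows relative to the stored median_std:

    import numpy as np
    import pandas as pd

    def compute_std_confidence_sketch(df: pd.DataFrame, median_std: float) -> pd.DataFrame:
        # Illustrative mapping only: std at the median -> 0.5 confidence,
        # near-zero std -> toward 1.0, large std -> toward 0.
        ratio = df["prediction_std"] / max(median_std, 1e-12)
        df["confidence"] = 1.0 / (1.0 + ratio)
        return df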
@@ -279,54 +311,107 @@ if __name__ == "__main__":
     )
 
     # -------------------------------------------------------------------------
-    # Training-only helper function
+    # Training-only helper functions
     # -------------------------------------------------------------------------
-    def build_mpnn_model(
-        hyperparameters: dict,
-        task: str = "regression",
-        num_classes: int | None = None,
-        n_targets: int = 1,
-        n_extra_descriptors: int = 0,
-        x_d_transform: nn.ScaleTransform | None = None,
-        output_transform: nn.UnscaleTransform | None = None,
-        task_weights: np.ndarray | None = None,
-    ) -> models.MPNN:
-        """Build an MPNN model with specified hyperparameters."""
-        hidden_dim = hyperparameters["hidden_dim"]
-        depth = hyperparameters["depth"]
+    def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
+        """Load pretrained MPNN weights from foundation model.
+
+        Args:
+            from_foundation: "CheMeleon" or path to .pt file
+
+        Returns:
+            Tuple of (message_passing, aggregation) modules
+        """
+        import urllib.request
+        from pathlib import Path
+
+        print(f"Loading foundation model: {from_foundation}")
+
+        if from_foundation.lower() == "chemeleon":
+            # Download from Zenodo if not cached
+            cache_dir = Path.home() / ".chemprop" / "foundation"
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            chemeleon_path = cache_dir / "chemeleon_mp.pt"
+
+            if not chemeleon_path.exists():
+                print(" Downloading CheMeleon weights from Zenodo...")
+                urllib.request.urlretrieve(
+                    "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
+                )
+                print(f" Downloaded to {chemeleon_path}")
+
+            ckpt = torch.load(chemeleon_path, weights_only=True)
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f" Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        if not os.path.exists(from_foundation):
+            raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")
+
+        ckpt = torch.load(from_foundation, weights_only=False)
+        if "hyper_parameters" in ckpt and "state_dict" in ckpt:
+            # CheMeleon-style checkpoint
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f" Loaded custom foundation weights (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        # Full MPNN model file
+        pretrained = models.MPNN.load_from_file(from_foundation)
+        print(f" Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
+        return pretrained.message_passing, pretrained.agg
+
+    def _build_ffn(
+        task: str, input_dim: int, hyperparameters: dict,
+        num_classes: int | None, n_targets: int,
+        output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
+    ) -> nn.Predictor:
+        """Build task-specific FFN head."""
         dropout = hyperparameters["dropout"]
         ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
         ffn_num_layers = hyperparameters["ffn_num_layers"]
 
-        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-        agg = nn.NormAggregation()
-        ffn_input_dim = hidden_dim + n_extra_descriptors
-
         if task == "classification" and num_classes is not None:
-            ffn = nn.MulticlassClassificationFFN(
-                n_classes=num_classes, input_dim=ffn_input_dim,
+            return nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=input_dim,
                 hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
             )
+
+        from chemprop.nn.metrics import MAE, MSE
+        criterion_map = {"mae": MAE, "mse": MSE}
+        criterion_name = hyperparameters.get("criterion", "mae")
+        if criterion_name not in criterion_map:
+            raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+
+        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+        return nn.RegressionFFN(
+            input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+            dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
+            task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
+        )
+
+    def build_mpnn_model(
+        hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
+        n_targets: int = 1, n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build MPNN model, optionally loading pretrained weights."""
+        from_foundation = hyperparameters.get("from_foundation")
+
+        if from_foundation:
+            mp, agg = _load_foundation_weights(from_foundation)
+            ffn_input_dim = mp.output_dim + n_extra_descriptors
         else:
-            # Map criterion name to ChemProp metric class (must have .clone() method)
-            from chemprop.nn.metrics import MAE, MSE
-
-            criterion_map = {
-                "mae": MAE,
-                "mse": MSE,
-            }
-            criterion_name = hyperparameters.get("criterion", "mae")
-            if criterion_name not in criterion_map:
-                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
-            criterion = criterion_map[criterion_name]()
-
-            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
-            ffn = nn.RegressionFFN(
-                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
-                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
-                criterion=criterion,
+            mp = nn.BondMessagePassing(
+                d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
+                dropout=hyperparameters["dropout"],
             )
+            agg = nn.NormAggregation()
+            ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors
 
+        ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
         return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
 
     # -------------------------------------------------------------------------
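For custom .pt paths, _load_foundation_weights accepts either a CheMeleon-style checkpoint ({"hyper_parameters", "state_dict"}) or a full MPNN file. A sketch (assumption, mirroring the loader above) of saving a compatible message-passing checkpoint:

    import torch
    from chemprop import nn

    def save_foundation_checkpoint(mp: nn.BondMessagePassing, d_h: int, depth: int, path: str) -> None:
        # Store the kwargs needed to rebuild nn.BondMessagePassing plus its
        # weights, matching the layout _load_foundation_weights unpacks.
        torch.save({"hyper_parameters": {"d_h": d_h, "depth": depth}, "state_dict": mp.state_dict()}, path)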
@@ -359,6 +444,14 @@ if __name__ == "__main__":
     print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
     print(f"Hyperparameters: {hyperparameters}")
 
+    # Log foundation model configuration
+    if hyperparameters.get("from_foundation"):
+        freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
+        print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
+    else:
+        print("Foundation model: None (training from scratch)")
+
     # Load training data
     training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
     print(f"Training Files: {training_files}")
@@ -456,7 +549,7 @@ if __name__ == "__main__":
         print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
         print(f"{'='*50}")
 
-        # Split data
+        # Split data (val_extra_raw preserves unscaled features for OOF predictions)
         df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
         train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
@@ -484,7 +577,7 @@ if __name__ == "__main__":
         train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
         val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
 
-        # Build and train model
+        # Build model
         pl.seed_everything(hyperparameters["seed"] + fold_idx)
         mpnn = build_mpnn_model(
@@ -492,14 +585,39 @@ if __name__ == "__main__":
             output_transform=output_transform, task_weights=task_weights,
         )
 
-        trainer = pl.Trainer(
-            accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
-            callbacks=[
-                pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
-                pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
-            ],
-        )
-        trainer.fit(mpnn, train_loader, val_loader)
+        # Train model (with optional two-phase foundation training)
+        freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0
+
+        def _set_mpnn_frozen(frozen: bool):
+            for param in mpnn.message_passing.parameters():
+                param.requires_grad = not frozen
+            for param in mpnn.agg.parameters():
+                param.requires_grad = not frozen
+
+        def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
+            callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
+            if save_checkpoint:
+                callbacks.append(pl.callbacks.ModelCheckpoint(
+                    dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
+                ))
+            return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)
+
+        if use_two_phase:
+            # Phase 1: Freeze MPNN, train FFN only
+            print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
+            _set_mpnn_frozen(True)
+            _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)
+
+            # Phase 2: Unfreeze and fine-tune all
+            print("Phase 2: Unfreezing MPNN, continuing training...")
+            _set_mpnn_frozen(False)
+            remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
+            trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
+        else:
+            trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
 
         # Load best checkpoint
         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
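The freeze/unfreeze mechanic in miniature (a self-contained sketch, not from the diff; plain Linear layers stand in for the MPNN and FFN):

    import torch

    encoder = torch.nn.Linear(16, 8)  # stands in for mpnn.message_passing + agg
    head = torch.nn.Linear(8, 1)      # stands in for the FFN predictor

    def n_trainable(m: torch.nn.Module) -> int:
        return sum(p.numel() for p in m.parameters() if p.requires_grad)

    for p in encoder.parameters():    # Phase 1: freeze the encoder
        p.requires_grad = False
    print(n_trainable(encoder), n_trainable(head))  # 0 9 -> only the head trains

    for p in encoder.parameters():    # Phase 2: unfreeze for full fine-tuning
        p.requires_grad = True
    print(n_trainable(encoder), n_trainable(head))  # 136 9 -> everything trains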
@@ -509,7 +627,7 @@ if __name__ == "__main__":
         mpnn.eval()
         ensemble_models.append(mpnn)
 
-        # Out-of-fold predictions (using raw features)
+        # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
         val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
         val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
 
@@ -585,6 +703,7 @@ if __name__ == "__main__":
     preds_std = np.std(np.stack(all_ens_preds), axis=0)
     if preds_std.ndim == 1:
         preds_std = preds_std.reshape(-1, 1)
+    preds_std = cap_std_outliers(preds_std)
 
     print("\n--- Per-target metrics ---")
     for t_idx, t_name in enumerate(target_columns):
@@ -599,11 +718,17 @@ if __name__ == "__main__":
     df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
     df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
 
-    # Compute confidence from ensemble std
-    median_std = float(np.median(preds_std[:, 0]))
-    print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
-    df_val = _compute_std_confidence(df_val, median_std)
-    print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+    # Compute confidence from ensemble std (or NaN for single model)
+    if preds_std is not None:
+        median_std = float(np.median(preds_std[:, 0]))
+        print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+        df_val = _compute_std_confidence(df_val, median_std)
+        print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+    else:
+        # Single model - no ensemble std available, confidence is undefined
+        median_std = None
+        df_val["confidence"] = np.nan
+        print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")
 
     # -------------------------------------------------------------------------
     # Save validation predictions to S3
@@ -633,6 +758,9 @@ if __name__ == "__main__":
         "n_folds": n_folds,
         "target_columns": target_columns,
         "median_std": median_std,  # For confidence calculation during inference
+        # Foundation model provenance (for tracking/reproducibility)
+        "from_foundation": hyperparameters.get("from_foundation", None),
+        "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
     }
     joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
 
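A sketch (assumption; the path is illustrative) of reading the saved provenance back after training:

    import joblib

    meta = joblib.load("model/ensemble_metadata.joblib")
    print(meta["from_foundation"], meta["freeze_mpnn_epochs"], meta["median_std"])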