workbench 0.8.217__py3-none-any.whl → 0.8.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  3. workbench/algorithms/dataframe/projection_2d.py +8 -2
  4. workbench/algorithms/dataframe/proximity.py +3 -0
  5. workbench/algorithms/sql/outliers.py +3 -3
  6. workbench/api/feature_set.py +0 -1
  7. workbench/core/artifacts/endpoint_core.py +2 -2
  8. workbench/core/artifacts/feature_set_core.py +185 -230
  9. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  10. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  11. workbench/model_script_utils/model_script_utils.py +15 -11
  12. workbench/model_scripts/chemprop/chemprop.template +195 -70
  13. workbench/model_scripts/chemprop/generated_model_script.py +198 -73
  14. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  15. workbench/model_scripts/custom_models/chem_info/fingerprints.py +80 -43
  16. workbench/model_scripts/pytorch_model/generated_model_script.py +2 -2
  17. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  18. workbench/model_scripts/xgb_model/generated_model_script.py +7 -7
  19. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  20. workbench/scripts/meta_model_sim.py +35 -0
  21. workbench/scripts/ml_pipeline_sqs.py +71 -2
  22. workbench/themes/light/custom.css +7 -1
  23. workbench/themes/midnight_blue/custom.css +34 -0
  24. workbench/utils/chem_utils/fingerprints.py +80 -43
  25. workbench/utils/chem_utils/projections.py +16 -6
  26. workbench/utils/meta_model_simulator.py +41 -13
  27. workbench/utils/model_utils.py +0 -1
  28. workbench/utils/plot_utils.py +146 -28
  29. workbench/utils/shap_utils.py +1 -55
  30. workbench/utils/theme_manager.py +95 -30
  31. workbench/web_interface/components/plugins/scatter_plot.py +152 -66
  32. workbench/web_interface/components/settings_menu.py +184 -0
  33. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
  34. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/RECORD +38 -37
  35. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +1 -0
  36. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  37. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  38. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
  39. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
  40. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,7 @@ DEFAULT_HYPERPARAMETERS = {
     "max_epochs": 400,
     "patience": 50,
     "batch_size": 32,
-    # Message Passing
+    # Message Passing (ignored when using foundation model)
     "hidden_dim": 700,
     "depth": 6,
     "dropout": 0.1,  # Lower dropout - ensemble provides regularization
@@ -45,16 +45,24 @@ DEFAULT_HYPERPARAMETERS = {
     "criterion": "mae",
     # Random seed
     "seed": 42,
+    # Foundation model support
+    # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
+    # - Path to .pt file: Load custom pretrained Chemprop model
+    # - None: Train from scratch (default)
+    "from_foundation": None,
+    # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
+    # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
+    "freeze_mpnn_epochs": 0,
 }
 
 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
-    "targets": ['udm_asy_res_efflux_ratio'],
+    "targets": ['udm_asy_res_free_percent'],
     "feature_list": ['smiles'],
     "id_column": "udm_mol_bat_id",
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-chemprop/training",
-    "hyperparameters": {},
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/ppb-human-free-reg-chemprop-foundation-1-dt/training",
+    "hyperparameters": {'from_foundation': 'CheMeleon', 'freeze_mpnn_epochs': 10, 'n_folds': 5, 'max_epochs': 100, 'patience': 20, 'ffn_hidden_dim': 512, 'dropout': 0.15},
 }
 
 
@@ -114,26 +122,27 @@ def _create_molecule_datapoints(
 # Model Loading (for SageMaker inference)
 # =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load ChemProp MPNN ensemble from the specified directory."""
-    from lightning import pytorch as pl
+    """Load ChemProp MPNN ensemble from the specified directory.
 
+    Optimized for serverless cold starts - uses direct PyTorch inference
+    instead of Lightning Trainer to minimize startup time.
+    """
     metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+
+    # Load all ensemble models (keep on CPU for serverless compatibility)
+    # ChemProp handles device placement internally
     ensemble_models = []
     for i in range(metadata["n_ensemble"]):
         model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
         model.eval()
         ensemble_models.append(model)
 
-    # Pre-initialize trainer once during model loading (expensive operation)
-    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
-
     print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
     return {
         "ensemble_models": ensemble_models,
         "n_ensemble": metadata["n_ensemble"],
         "target_columns": metadata["target_columns"],
         "median_std": metadata["median_std"],
-        "trainer": trainer,
     }
 
 
@@ -141,13 +150,15 @@ def model_fn(model_dir: str) -> dict:
 # Inference (for SageMaker inference)
 # =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make predictions with ChemProp MPNN ensemble."""
+    """Make predictions with ChemProp MPNN ensemble.
+
+    Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
+    """
     model_type = TEMPLATE_PARAMS["model_type"]
     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
 
     ensemble_models = model_dict["ensemble_models"]
     target_columns = model_dict["target_columns"]
-    trainer = model_dict["trainer"]  # Use pre-initialized trainer
 
     # Load artifacts
     label_encoder = None
@@ -202,18 +213,34 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         return df
 
     dataset = data.MoleculeDataset(datapoints)
-    dataloader = data.build_dataloader(dataset, shuffle=False)
+    dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)
 
-    # Ensemble predictions
+    # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
     all_preds = []
     for model in ensemble_models:
+        model_preds = []
+        model.eval()
         with torch.inference_mode():
-            predictions = trainer.predict(model, dataloader)
-            preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+            for batch in dataloader:
+                # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
+                # For inference we only need bmg, V_d, X_d
+                bmg, V_d, X_d, *_ = batch
+                output = model(bmg, V_d, X_d)
+                model_preds.append(output.detach().cpu().numpy())
+
+        if len(model_preds) == 0:
+            print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
+            continue
+
+        preds = np.concatenate(model_preds, axis=0)
         if preds.ndim == 3 and preds.shape[1] == 1:
             preds = preds.squeeze(axis=1)
         all_preds.append(preds)
 
+    if len(all_preds) == 0:
+        print("Error: No ensemble predictions generated")
+        return df
+
     preds = np.mean(np.stack(all_preds), axis=0)
     preds_std = np.std(np.stack(all_preds), axis=0)
     if preds.ndim == 1:
@@ -243,8 +270,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         df["prediction"] = df[f"{target_columns[0]}_pred"]
         df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
 
-    # Compute confidence from ensemble std
-    df = _compute_std_confidence(df, model_dict["median_std"])
+    # Compute confidence from ensemble std (or NaN if single model)
+    if model_dict["median_std"] is not None:
+        df = _compute_std_confidence(df, model_dict["median_std"])
+    else:
+        df["confidence"] = np.nan
 
     return df
 
@@ -279,54 +309,107 @@ if __name__ == "__main__":
     )
 
     # -------------------------------------------------------------------------
-    # Training-only helper function
+    # Training-only helper functions
     # -------------------------------------------------------------------------
-    def build_mpnn_model(
-        hyperparameters: dict,
-        task: str = "regression",
-        num_classes: int | None = None,
-        n_targets: int = 1,
-        n_extra_descriptors: int = 0,
-        x_d_transform: nn.ScaleTransform | None = None,
-        output_transform: nn.UnscaleTransform | None = None,
-        task_weights: np.ndarray | None = None,
-    ) -> models.MPNN:
-        """Build an MPNN model with specified hyperparameters."""
-        hidden_dim = hyperparameters["hidden_dim"]
-        depth = hyperparameters["depth"]
+    def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
+        """Load pretrained MPNN weights from foundation model.
+
+        Args:
+            from_foundation: "CheMeleon" or path to .pt file
+
+        Returns:
+            Tuple of (message_passing, aggregation) modules
+        """
+        import urllib.request
+        from pathlib import Path
+
+        print(f"Loading foundation model: {from_foundation}")
+
+        if from_foundation.lower() == "chemeleon":
+            # Download from Zenodo if not cached
+            cache_dir = Path.home() / ".chemprop" / "foundation"
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            chemeleon_path = cache_dir / "chemeleon_mp.pt"
+
+            if not chemeleon_path.exists():
+                print(" Downloading CheMeleon weights from Zenodo...")
+                urllib.request.urlretrieve(
+                    "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
+                )
+                print(f" Downloaded to {chemeleon_path}")
+
+            ckpt = torch.load(chemeleon_path, weights_only=True)
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f" Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        if not os.path.exists(from_foundation):
+            raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")
+
+        ckpt = torch.load(from_foundation, weights_only=False)
+        if "hyper_parameters" in ckpt and "state_dict" in ckpt:
+            # CheMeleon-style checkpoint
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f" Loaded custom foundation weights (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        # Full MPNN model file
+        pretrained = models.MPNN.load_from_file(from_foundation)
+        print(f" Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
+        return pretrained.message_passing, pretrained.agg
+
+    def _build_ffn(
+        task: str, input_dim: int, hyperparameters: dict,
+        num_classes: int | None, n_targets: int,
+        output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
+    ) -> nn.Predictor:
+        """Build task-specific FFN head."""
         dropout = hyperparameters["dropout"]
         ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
         ffn_num_layers = hyperparameters["ffn_num_layers"]
 
-        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-        agg = nn.NormAggregation()
-        ffn_input_dim = hidden_dim + n_extra_descriptors
-
         if task == "classification" and num_classes is not None:
-            ffn = nn.MulticlassClassificationFFN(
-                n_classes=num_classes, input_dim=ffn_input_dim,
+            return nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=input_dim,
                 hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
             )
+
+        from chemprop.nn.metrics import MAE, MSE
+        criterion_map = {"mae": MAE, "mse": MSE}
+        criterion_name = hyperparameters.get("criterion", "mae")
+        if criterion_name not in criterion_map:
+            raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+
+        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+        return nn.RegressionFFN(
+            input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+            dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
+            task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
+        )
+
+    def build_mpnn_model(
+        hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
+        n_targets: int = 1, n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build MPNN model, optionally loading pretrained weights."""
+        from_foundation = hyperparameters.get("from_foundation")
+
+        if from_foundation:
+            mp, agg = _load_foundation_weights(from_foundation)
+            ffn_input_dim = mp.output_dim + n_extra_descriptors
         else:
-            # Map criterion name to ChemProp metric class (must have .clone() method)
-            from chemprop.nn.metrics import MAE, MSE
-
-            criterion_map = {
-                "mae": MAE,
-                "mse": MSE,
-            }
-            criterion_name = hyperparameters.get("criterion", "mae")
-            if criterion_name not in criterion_map:
-                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
-            criterion = criterion_map[criterion_name]()
-
-            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
-            ffn = nn.RegressionFFN(
-                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
-                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
-                criterion=criterion,
+            mp = nn.BondMessagePassing(
+                d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
+                dropout=hyperparameters["dropout"],
             )
+            agg = nn.NormAggregation()
+            ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors
 
+        ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
         return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
 
     # -------------------------------------------------------------------------
@@ -359,6 +442,14 @@ if __name__ == "__main__":
     print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
     print(f"Hyperparameters: {hyperparameters}")
 
+    # Log foundation model configuration
+    if hyperparameters.get("from_foundation"):
+        freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
+        print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
+    else:
+        print("Foundation model: None (training from scratch)")
+
     # Load training data
     training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
     print(f"Training Files: {training_files}")
@@ -456,7 +547,7 @@ if __name__ == "__main__":
         print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
         print(f"{'='*50}")
 
-        # Split data
+        # Split data (val_extra_raw preserves unscaled features for OOF predictions)
         df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
         train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
@@ -481,10 +572,10 @@ if __name__ == "__main__":
         val_dataset.normalize_targets(target_scaler)
         output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
 
-        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
-        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
+        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
+        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
 
-        # Build and train model
+        # Build model
         pl.seed_everything(hyperparameters["seed"] + fold_idx)
         mpnn = build_mpnn_model(
             hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
@@ -492,14 +583,39 @@ if __name__ == "__main__":
             output_transform=output_transform, task_weights=task_weights,
         )
 
-        trainer = pl.Trainer(
-            accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
-            callbacks=[
-                pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
-                pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
-            ],
-        )
-        trainer.fit(mpnn, train_loader, val_loader)
+        # Train model (with optional two-phase foundation training)
+        freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0
+
+        def _set_mpnn_frozen(frozen: bool):
+            for param in mpnn.message_passing.parameters():
+                param.requires_grad = not frozen
+            for param in mpnn.agg.parameters():
+                param.requires_grad = not frozen
+
+        def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
+            callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
+            if save_checkpoint:
+                callbacks.append(pl.callbacks.ModelCheckpoint(
+                    dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
+                ))
+            return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)
+
+        if use_two_phase:
+            # Phase 1: Freeze MPNN, train FFN only
+            print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
+            _set_mpnn_frozen(True)
+            _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)
+
+            # Phase 2: Unfreeze and fine-tune all
+            print("Phase 2: Unfreezing MPNN, continuing training...")
+            _set_mpnn_frozen(False)
+            remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
+            trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
+        else:
+            trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
 
         # Load best checkpoint
         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
@@ -509,7 +625,7 @@ if __name__ == "__main__":
         mpnn.eval()
         ensemble_models.append(mpnn)
 
-        # Out-of-fold predictions (using raw features)
+        # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
         val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
         val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
 
@@ -599,11 +715,17 @@ if __name__ == "__main__":
         df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
         df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
 
-    # Compute confidence from ensemble std
-    median_std = float(np.median(preds_std[:, 0]))
-    print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
-    df_val = _compute_std_confidence(df_val, median_std)
-    print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+    # Compute confidence from ensemble std (or NaN for single model)
+    if preds_std is not None:
+        median_std = float(np.median(preds_std[:, 0]))
+        print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+        df_val = _compute_std_confidence(df_val, median_std)
+        print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+    else:
+        # Single model - no ensemble std available, confidence is undefined
+        median_std = None
+        df_val["confidence"] = np.nan
+        print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")
 
     # -------------------------------------------------------------------------
     # Save validation predictions to S3
  # Save validation predictions to S3
@@ -633,6 +755,9 @@ if __name__ == "__main__":
633
755
  "n_folds": n_folds,
634
756
  "target_columns": target_columns,
635
757
  "median_std": median_std, # For confidence calculation during inference
758
+ # Foundation model provenance (for tracking/reproducibility)
759
+ "from_foundation": hyperparameters.get("from_foundation", None),
760
+ "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
636
761
  }
637
762
  joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
638
763
 
@@ -148,12 +148,16 @@ def convert_categorical_types(
 def decompress_features(
     df: pd.DataFrame, features: list[str], compressed_features: list[str]
 ) -> tuple[pd.DataFrame, list[str]]:
-    """Decompress bitstring features into individual bit columns.
+    """Decompress compressed features (bitstrings or count vectors) into individual columns.
+
+    Supports two formats (auto-detected):
+    - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
+    - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)
 
     Args:
         df: The features DataFrame
         features: Full list of feature names
-        compressed_features: List of feature names to decompress (bitstrings)
+        compressed_features: List of feature names to decompress
 
     Returns:
         Tuple of (DataFrame with decompressed features, updated feature list)
@@ -178,18 +182,18 @@ def decompress_features(
         # Remove the feature from the list to avoid duplication
         decompressed_features.remove(feature)
 
-        # Handle all compressed features as bitstrings
-        bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-        prefix = feature[:3]
+        # Auto-detect format and parse: comma-separated counts or bitstring
+        sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
+        parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
+        feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)
 
-        # Create all new columns at once - avoids fragmentation
-        new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-        new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
+        # Create new columns with prefix from feature name
+        prefix = feature[:3]
+        new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
+        new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)
 
-        # Add to features list
+        # Update features list and dataframe
         decompressed_features.extend(new_col_names)
-
-        # Drop original column and concatenate new ones
         df = df.drop(columns=[feature])
         df = pd.concat([df, new_df], axis=1)
 
@@ -1,11 +1,19 @@
-"""Molecular fingerprint computation utilities"""
+"""Molecular fingerprint computation utilities for ADMET modeling.
+
+This module provides Morgan count fingerprints, the standard for ADMET prediction.
+Count fingerprints outperform binary fingerprints for molecular property prediction.
+
+References:
+    - Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
+    - ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
+"""
 
 import logging
-import pandas as pd
 
-# Molecular Descriptor Imports
+import numpy as np
+import pandas as pd
 from rdkit import Chem, RDLogger
-from rdkit.Chem import rdFingerprintGenerator
+from rdkit.Chem import AllChem
 from rdkit.Chem.MolStandardize import rdMolStandardize
 
 # Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
@@ -16,20 +24,25 @@ RDLogger.DisableLog("rdApp.warning")
 log = logging.getLogger("workbench")
 
 
-def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
-    """Compute and add Morgan fingerprints to the DataFrame.
+def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
+    """Compute Morgan count fingerprints for ADMET modeling.
+
+    Generates true count fingerprints where each bit position contains the
+    number of times that substructure appears in the molecule (clamped to 0-255).
+    This is the recommended approach for ADMET prediction per 2025 research.
 
     Args:
-        df (pd.DataFrame): Input DataFrame containing SMILES strings.
-        radius (int): Radius for the Morgan fingerprint.
-        n_bits (int): Number of bits for the fingerprint.
-        counts (bool): Count simulation for the fingerprint.
+        df: Input DataFrame containing SMILES strings.
+        radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
+        n_bits: Number of bits for the fingerprint (default 2048).
 
     Returns:
-        pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
+        pd.DataFrame: Input DataFrame with 'fingerprint' column added.
+            Values are comma-separated uint8 counts.
 
     Note:
-        See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
+        Count fingerprints outperform binary for ADMET prediction.
+        See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
     """
     delete_mol_column = False
 
@@ -43,7 +56,7 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
         log.warning("Detected serialized molecules in 'molecule' column. Removing...")
         del df["molecule"]
 
-    # Convert SMILES to RDKit molecule objects (vectorized)
+    # Convert SMILES to RDKit molecule objects
     if "molecule" not in df.columns:
         log.info("Converting SMILES to RDKit Molecules...")
         delete_mol_column = True
@@ -59,15 +72,24 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
         lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
     )
 
-    # Create a Morgan fingerprint generator
-    if counts:
-        n_bits *= 4  # Multiply by 4 to simulate counts
-    morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
+    def mol_to_count_string(mol):
+        """Convert molecule to comma-separated count fingerprint string."""
+        if mol is None:
+            return pd.NA
 
-    # Compute Morgan fingerprints (vectorized)
-    fingerprints = largest_frags.apply(
-        lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
-    )
+        # Get hashed Morgan fingerprint with counts
+        fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
+
+        # Initialize array and populate with counts (clamped to uint8 range)
+        counts = np.zeros(n_bits, dtype=np.uint8)
+        for idx, count in fp.GetNonzeroElements().items():
+            counts[idx] = min(count, 255)
+
+        # Return as comma-separated string
+        return ",".join(map(str, counts))
+
+    # Compute Morgan count fingerprints
+    fingerprints = largest_frags.apply(mol_to_count_string)
 
     # Add the fingerprints to the DataFrame
     df["fingerprint"] = fingerprints
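The encoding above can be sanity-checked directly against RDKit: `GetHashedMorganFingerprint` returns a sparse count vector whose nonzero entries are the per-substructure counts that end up in the comma-separated string.

    from rdkit import Chem
    from rdkit.Chem import AllChem

    mol = Chem.MolFromSmiles("CC(=O)OC1=CC=CC=C1C(=O)O")  # aspirin
    fp = AllChem.GetHashedMorganFingerprint(mol, 2, nBits=2048)
    # bit index -> substructure count (only nonzero entries are stored)
    print(sorted(fp.GetNonzeroElements().items())[:5])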
@@ -75,59 +97,62 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
     # Drop the intermediate 'molecule' column if it was added
     if delete_mol_column:
         del df["molecule"]
+
     return df
 
 
 if __name__ == "__main__":
-    print("Running molecular fingerprint tests...")
-    print("Note: This requires molecular_screening module to be available")
+    print("Running Morgan count fingerprint tests...")
 
     # Test molecules
     test_molecules = {
         "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
         "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
         "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # With stereochemistry
-        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt
+        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt (largest fragment used)
         "benzene": "c1ccccc1",
         "butene_e": "C/C=C/C",  # E-butene
        "butene_z": "C/C=C\\C",  # Z-butene
     }
 
-    # Test 1: Morgan Fingerprints
-    print("\n1. Testing Morgan fingerprint generation...")
+    # Test 1: Morgan Count Fingerprints (default parameters)
+    print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
 
     test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
-
-    fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
+    fp_df = compute_morgan_fingerprints(test_df.copy())
 
     print(" Fingerprint generation results:")
     for _, row in fp_df.iterrows():
         fp = row.get("fingerprint", "N/A")
-        fp_len = len(fp) if fp != "N/A" else 0
-        print(f" {row['name']:15} {fp_len} bits")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            non_zero = sum(1 for c in counts if c > 0)
+            max_count = max(counts)
+            print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
+        else:
+            print(f" {row['name']:15} → N/A")
 
-    # Test 2: Different fingerprint parameters
-    print("\n2. Testing different fingerprint parameters...")
+    # Test 2: Different parameters
+    print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
 
-    # Test with counts enabled
-    fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
+    fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
 
-    print(" With count simulation (256 bits * 4):")
-    for _, row in fp_counts_df.iterrows():
+    for _, row in fp_df_custom.iterrows():
         fp = row.get("fingerprint", "N/A")
-        fp_len = len(fp) if fp != "N/A" else 0
-        print(f" {row['name']:15} {fp_len} bits")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            non_zero = sum(1 for c in counts if c > 0)
+            print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
+        else:
+            print(f" {row['name']:15} → N/A")
 
     # Test 3: Edge cases
     print("\n3. Testing edge cases...")
 
     # Invalid SMILES
     invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
-    try:
-        fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
-        print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
-    except Exception as e:
-        print(f" ✓ Invalid SMILES properly raised error: {type(e).__name__}")
+    fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
+    print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
 
     # Test with pre-existing molecule column
     mol_df = test_df.copy()
@@ -135,4 +160,16 @@ if __name__ == "__main__":
     fp_with_mol = compute_morgan_fingerprints(mol_df)
     print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
 
+    # Test 4: Verify count values are reasonable
+    print("\n4. Verifying count distribution...")
+    all_counts = []
+    for _, row in fp_df.iterrows():
+        fp = row.get("fingerprint", "N/A")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            all_counts.extend([c for c in counts if c > 0])
+
+    if all_counts:
+        print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
+
     print("\n✅ All fingerprint tests completed!")