workbench 0.8.231__py3-none-any.whl → 0.8.236__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. workbench/algorithms/dataframe/smart_aggregator.py +17 -12
  2. workbench/api/endpoint.py +13 -4
  3. workbench/cached/cached_model.py +2 -2
  4. workbench/core/artifacts/endpoint_core.py +30 -5
  5. workbench/model_script_utils/model_script_utils.py +225 -0
  6. workbench/model_script_utils/uq_harness.py +39 -21
  7. workbench/model_scripts/chemprop/chemprop.template +29 -14
  8. workbench/model_scripts/chemprop/generated_model_script.py +35 -18
  9. workbench/model_scripts/chemprop/model_script_utils.py +225 -0
  10. workbench/model_scripts/pytorch_model/generated_model_script.py +34 -20
  11. workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
  12. workbench/model_scripts/pytorch_model/pytorch.template +28 -14
  13. workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
  14. workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
  15. workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
  16. workbench/model_scripts/xgb_model/uq_harness.py +39 -21
  17. workbench/model_scripts/xgb_model/xgb_model.template +29 -18
  18. workbench/themes/dark/custom.css +29 -0
  19. workbench/themes/light/custom.css +29 -0
  20. workbench/themes/midnight_blue/custom.css +28 -0
  21. workbench/utils/markdown_utils.py +5 -1
  22. workbench/utils/model_utils.py +9 -0
  23. workbench/utils/theme_manager.py +95 -0
  24. workbench/web_interface/components/component_interface.py +3 -0
  25. workbench/web_interface/components/plugin_interface.py +26 -0
  26. workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
  27. workbench/web_interface/components/plugins/model_details.py +18 -5
  28. workbench/web_interface/components/plugins/model_plot.py +156 -0
  29. workbench/web_interface/components/plugins/scatter_plot.py +9 -2
  30. workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
  31. workbench/web_interface/components/settings_menu.py +10 -49
  32. {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/METADATA +1 -1
  33. {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/RECORD +37 -37
  34. workbench/web_interface/components/model_plot.py +0 -75
  35. {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/WHEEL +0 -0
  36. {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/entry_points.txt +0 -0
  37. {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/licenses/LICENSE +0 -0
  38. {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ import torch
20
20
  from chemprop import data, models
21
21
 
22
22
  from model_script_utils import (
23
+ cap_std_outliers,
23
24
  expand_proba_column,
24
25
  input_fn,
25
26
  output_fn,
@@ -43,6 +44,12 @@ DEFAULT_HYPERPARAMETERS = {
43
44
  "ffn_num_layers": 2,
44
45
  # Loss function for regression (mae, mse)
45
46
  "criterion": "mae",
47
+ # Split strategy: "random", "scaffold", or "butina"
48
+ # - random: Standard random split
49
+ # - scaffold: Bemis-Murcko scaffold-based grouping
50
+ # - butina: Morgan fingerprint clustering (recommended for ADMET)
51
+ "split_strategy": "butina",
52
+ "butina_cutoff": 0.4, # Tanimoto distance cutoff for Butina clustering
46
53
  # Random seed
47
54
  "seed": 42,
48
55
  # Foundation model support
@@ -58,11 +65,11 @@ DEFAULT_HYPERPARAMETERS = {
58
65
  # Template parameters (filled in by Workbench)
59
66
  TEMPLATE_PARAMS = {
60
67
  "model_type": "uq_regressor",
61
- "targets": ['udm_asy_res_value'],
68
+ "targets": ['logd'],
62
69
  "feature_list": ['smiles'],
63
- "id_column": "udm_mol_bat_id",
64
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/logd-value-reg-chemprop-1-dt/training",
65
- "hyperparameters": {},
70
+ "id_column": "molecule_name",
71
+ "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-chemprop-split-butina/training",
72
+ "hyperparameters": {'split_strategy': 'butina'},
66
73
  }
67
74
 
68
75
 
@@ -245,6 +252,7 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
245
252
  preds_std = np.std(np.stack(all_preds), axis=0)
246
253
  if preds.ndim == 1:
247
254
  preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
255
+ preds_std = cap_std_outliers(preds_std)
248
256
 
249
257
  print(f"Inference complete: {preds.shape[0]} predictions")
250
258
 
@@ -303,6 +311,7 @@ if __name__ == "__main__":
303
311
  check_dataframe,
304
312
  compute_classification_metrics,
305
313
  compute_regression_metrics,
314
+ get_split_indices,
306
315
  print_classification_metrics,
307
316
  print_confusion_matrix,
308
317
  print_regression_metrics,
@@ -516,22 +525,29 @@ if __name__ == "__main__":
516
525
  n_folds = hyperparameters["n_folds"]
517
526
  batch_size = hyperparameters["batch_size"]
518
527
 
519
- if n_folds == 1:
520
- if "training" in all_df.columns:
521
- print("Using 'training' column for train/val split")
522
- train_idx = np.where(all_df["training"])[0]
523
- val_idx = np.where(~all_df["training"])[0]
524
- else:
525
- print("WARNING: No 'training' column, using random 80/20 split")
526
- train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
528
+ # Get split strategy parameters
529
+ split_strategy = hyperparameters.get("split_strategy", "random")
530
+ butina_cutoff = hyperparameters.get("butina_cutoff", 0.4)
531
+
532
+ # Check for pre-defined training column (overrides split strategy)
533
+ if n_folds == 1 and "training" in all_df.columns:
534
+ print("Using 'training' column for train/val split")
535
+ train_idx = np.where(all_df["training"])[0]
536
+ val_idx = np.where(~all_df["training"])[0]
527
537
  folds = [(train_idx, val_idx)]
528
538
  else:
529
- if model_type == "classifier":
530
- kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
531
- folds = list(kfold.split(all_df, all_df[target_columns[0]]))
532
- else:
533
- kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
534
- folds = list(kfold.split(all_df))
539
+ # Use unified split interface (auto-detects 'smiles' column for scaffold/butina)
540
+ target_col = target_columns[0] if model_type == "classifier" else None
541
+ folds = get_split_indices(
542
+ all_df,
543
+ n_splits=n_folds,
544
+ strategy=split_strategy,
545
+ target_column=target_col,
546
+ test_size=0.2,
547
+ random_state=42,
548
+ butina_cutoff=butina_cutoff,
549
+ )
550
+ print(f"Split strategy: {split_strategy}")
535
551
 
536
552
  print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
537
553
 
@@ -701,6 +717,7 @@ if __name__ == "__main__":
701
717
  preds_std = np.std(np.stack(all_ens_preds), axis=0)
702
718
  if preds_std.ndim == 1:
703
719
  preds_std = preds_std.reshape(-1, 1)
720
+ preds_std = cap_std_outliers(preds_std)
704
721
 
705
722
  print("\n--- Per-target metrics ---")
706
723
  for t_idx, t_name in enumerate(target_columns):
@@ -16,6 +16,7 @@ from sklearn.metrics import (
16
16
  r2_score,
17
17
  root_mean_squared_error,
18
18
  )
19
+ from sklearn.model_selection import GroupKFold, GroupShuffleSplit
19
20
  from scipy.stats import spearmanr
20
21
 
21
22
 
@@ -367,3 +368,227 @@ def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names:
367
368
  for j, col_name in enumerate(label_names):
368
369
  value = conf_mtx[i, j]
369
370
  print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
371
+
372
+
373
+ # =============================================================================
374
+ # Dataset Splitting Utilities for Molecular Data
375
+ # =============================================================================
376
+ def get_scaffold(smiles: str) -> str:
377
+ """Extract Bemis-Murcko scaffold from a SMILES string.
378
+
379
+ Args:
380
+ smiles: SMILES string of the molecule
381
+
382
+ Returns:
383
+ SMILES string of the scaffold, or empty string if molecule is invalid
384
+ """
385
+ from rdkit import Chem
386
+ from rdkit.Chem.Scaffolds import MurckoScaffold
387
+
388
+ mol = Chem.MolFromSmiles(smiles)
389
+ if mol is None:
390
+ return ""
391
+ try:
392
+ scaffold = MurckoScaffold.GetScaffoldForMol(mol)
393
+ return Chem.MolToSmiles(scaffold)
394
+ except Exception:
395
+ return ""
396
+
397
+
398
+ def get_scaffold_groups(smiles_list: list[str]) -> np.ndarray:
399
+ """Assign each molecule to a scaffold group.
400
+
401
+ Args:
402
+ smiles_list: List of SMILES strings
403
+
404
+ Returns:
405
+ Array of group indices (same scaffold = same group)
406
+ """
407
+ scaffold_to_group = {}
408
+ groups = []
409
+
410
+ for smi in smiles_list:
411
+ scaffold = get_scaffold(smi)
412
+ if scaffold not in scaffold_to_group:
413
+ scaffold_to_group[scaffold] = len(scaffold_to_group)
414
+ groups.append(scaffold_to_group[scaffold])
415
+
416
+ n_scaffolds = len(scaffold_to_group)
417
+ print(f"Found {n_scaffolds} unique scaffolds from {len(smiles_list)} molecules")
418
+ return np.array(groups)
419
+
420
+
421
+ def get_butina_clusters(smiles_list: list[str], cutoff: float = 0.4) -> np.ndarray:
422
+ """Cluster molecules using Butina algorithm on Morgan fingerprints.
423
+
424
+ Uses RDKit's Butina clustering with Tanimoto distance on Morgan fingerprints.
425
+ This is Pat Walters' recommended approach for creating diverse train/test splits.
426
+
427
+ Args:
428
+ smiles_list: List of SMILES strings
429
+ cutoff: Tanimoto distance cutoff for clustering (default 0.4)
430
+ Lower values = more clusters = more similar molecules per cluster
431
+
432
+ Returns:
433
+ Array of cluster indices
434
+ """
435
+ from rdkit import Chem, DataStructs
436
+ from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
437
+ from rdkit.ML.Cluster import Butina
438
+
439
+ # Create Morgan fingerprint generator
440
+ fp_gen = GetMorganGenerator(radius=2, fpSize=2048)
441
+
442
+ # Generate Morgan fingerprints
443
+ fps = []
444
+ valid_indices = []
445
+ for i, smi in enumerate(smiles_list):
446
+ mol = Chem.MolFromSmiles(smi)
447
+ if mol is not None:
448
+ fp = fp_gen.GetFingerprint(mol)
449
+ fps.append(fp)
450
+ valid_indices.append(i)
451
+
452
+ if len(fps) == 0:
453
+ raise ValueError("No valid molecules found for clustering")
454
+
455
+ # Compute distance matrix (upper triangle only for efficiency)
456
+ n = len(fps)
457
+ dists = []
458
+ for i in range(1, n):
459
+ sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
460
+ dists.extend([1 - s for s in sims])
461
+
462
+ # Butina clustering
463
+ clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True)
464
+
465
+ # Map back to original indices
466
+ cluster_labels = np.zeros(len(smiles_list), dtype=int)
467
+ for cluster_idx, cluster in enumerate(clusters):
468
+ for mol_idx in cluster:
469
+ original_idx = valid_indices[mol_idx]
470
+ cluster_labels[original_idx] = cluster_idx
471
+
472
+ # Assign invalid molecules to their own clusters
473
+ next_cluster = len(clusters)
474
+ for i in range(len(smiles_list)):
475
+ if i not in valid_indices:
476
+ cluster_labels[i] = next_cluster
477
+ next_cluster += 1
478
+
479
+ n_clusters = len(set(cluster_labels))
480
+ print(f"Butina clustering: {n_clusters} clusters from {len(smiles_list)} molecules (cutoff={cutoff})")
481
+ return cluster_labels
482
+
483
+
484
+ def _find_smiles_column(columns: list[str]) -> str | None:
485
+ """Find SMILES column (case-insensitive match for 'smiles').
486
+
487
+ Args:
488
+ columns: List of column names
489
+
490
+ Returns:
491
+ The matching column name, or None if not found
492
+ """
493
+ return next((c for c in columns if c.lower() == "smiles"), None)
494
+
495
+
496
+ def get_split_indices(
497
+ df: pd.DataFrame,
498
+ n_splits: int = 5,
499
+ strategy: str = "random",
500
+ smiles_column: str | None = None,
501
+ target_column: str | None = None,
502
+ test_size: float = 0.2,
503
+ random_state: int = 42,
504
+ butina_cutoff: float = 0.4,
505
+ ) -> list[tuple[np.ndarray, np.ndarray]]:
506
+ """Get train/validation split indices using various strategies.
507
+
508
+ This is a unified interface for generating splits that can be used across
509
+ all model templates (XGBoost, PyTorch, ChemProp).
510
+
511
+ Args:
512
+ df: DataFrame containing the data
513
+ n_splits: Number of CV folds (1 = single train/val split)
514
+ strategy: Split strategy - one of:
515
+ - "random": Standard random split (default sklearn behavior)
516
+ - "scaffold": Bemis-Murcko scaffold-based grouping
517
+ - "butina": Morgan fingerprint clustering (recommended for ADMET)
518
+ smiles_column: Column containing SMILES. If None, auto-detects 'smiles' (case-insensitive)
519
+ target_column: Column containing target values (for stratification, optional)
520
+ test_size: Fraction for validation set when n_splits=1 (default 0.2)
521
+ random_state: Random seed for reproducibility
522
+ butina_cutoff: Tanimoto distance cutoff for Butina clustering (default 0.4)
523
+
524
+ Returns:
525
+ List of (train_indices, val_indices) tuples
526
+
527
+ Note:
528
+ If scaffold/butina strategy is requested but no SMILES column is found,
529
+ automatically falls back to random split with a warning message.
530
+
531
+ Example:
532
+ >>> folds = get_split_indices(df, n_splits=5, strategy="scaffold")
533
+ >>> for train_idx, val_idx in folds:
534
+ ... X_train, X_val = df.iloc[train_idx], df.iloc[val_idx]
535
+ """
536
+ from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
537
+
538
+ n_samples = len(df)
539
+
540
+ # Random split (original behavior)
541
+ if strategy == "random":
542
+ if n_splits == 1:
543
+ indices = np.arange(n_samples)
544
+ train_idx, val_idx = train_test_split(indices, test_size=test_size, random_state=random_state)
545
+ return [(train_idx, val_idx)]
546
+ else:
547
+ if target_column and df[target_column].dtype in ["object", "category", "bool"]:
548
+ kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
549
+ return list(kfold.split(df, df[target_column]))
550
+ else:
551
+ kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
552
+ return list(kfold.split(df))
553
+
554
+ # Scaffold or Butina split requires SMILES - auto-detect if not provided
555
+ if smiles_column is None:
556
+ smiles_column = _find_smiles_column(df.columns.tolist())
557
+
558
+ # Fall back to random split if no SMILES column available
559
+ if smiles_column is None or smiles_column not in df.columns:
560
+ print(f"No 'smiles' column found for strategy='{strategy}', falling back to random split")
561
+ return get_split_indices(
562
+ df,
563
+ n_splits=n_splits,
564
+ strategy="random",
565
+ target_column=target_column,
566
+ test_size=test_size,
567
+ random_state=random_state,
568
+ )
569
+
570
+ smiles_list = df[smiles_column].tolist()
571
+
572
+ # Get group assignments
573
+ if strategy == "scaffold":
574
+ groups = get_scaffold_groups(smiles_list)
575
+ elif strategy == "butina":
576
+ groups = get_butina_clusters(smiles_list, cutoff=butina_cutoff)
577
+ else:
578
+ raise ValueError(f"Unknown strategy: {strategy}. Use 'random', 'scaffold', or 'butina'")
579
+
580
+ # Generate splits using GroupKFold or GroupShuffleSplit
581
+ if n_splits == 1:
582
+ # Single split: use GroupShuffleSplit
583
+ splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
584
+ return list(splitter.split(df, groups=groups))
585
+ else:
586
+ # K-fold: use GroupKFold (ensures no group appears in both train and val)
587
+ # Note: GroupKFold doesn't shuffle, so we shuffle group order first
588
+ unique_groups = np.unique(groups)
589
+ rng = np.random.default_rng(random_state)
590
+ shuffled_group_map = {g: i for i, g in enumerate(rng.permutation(unique_groups))}
591
+ shuffled_groups = np.array([shuffled_group_map[g] for g in groups])
592
+
593
+ gkf = GroupKFold(n_splits=n_splits)
594
+ return list(gkf.split(df, groups=shuffled_groups))
@@ -53,19 +53,25 @@ DEFAULT_HYPERPARAMETERS = {
53
53
  "use_batch_norm": True,
54
54
  # Loss function for regression (L1Loss=MAE, MSELoss=MSE, HuberLoss, SmoothL1Loss)
55
55
  "loss": "L1Loss",
56
+ # Split strategy: "random", "scaffold", or "butina"
57
+ # - random: Standard random split
58
+ # - scaffold: Bemis-Murcko scaffold-based grouping (requires 'smiles' column in data)
59
+ # - butina: Morgan fingerprint clustering (requires 'smiles' column, recommended for ADMET)
60
+ "split_strategy": "butina",
61
+ "butina_cutoff": 0.4, # Tanimoto distance cutoff for Butina clustering
56
62
  # Random seed
57
63
  "seed": 42,
58
64
  }
59
65
 
60
66
  # Template parameters (filled in by Workbench)
61
67
  TEMPLATE_PARAMS = {
62
- "model_type": "classifier",
63
- "target": "class",
68
+ "model_type": "uq_regressor",
69
+ "target": "logd",
64
70
  "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 
'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 
'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
65
- "id_column": "udm_mol_bat_id",
66
- "compressed_features": [],
67
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-class-pytorch-1-fr/training",
68
- "hyperparameters": {},
71
+ "id_column": "molecule_name",
72
+ "compressed_features": ['fingerprint'],
73
+ "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-pytorch-split-butina/training",
74
+ "hyperparameters": {'split_strategy': 'butina'},
69
75
  }
70
76
 
71
77
 
@@ -234,6 +240,7 @@ if __name__ == "__main__":
234
240
  check_dataframe,
235
241
  compute_classification_metrics,
236
242
  compute_regression_metrics,
243
+ get_split_indices,
237
244
  print_classification_metrics,
238
245
  print_confusion_matrix,
239
246
  print_regression_metrics,
@@ -337,22 +344,29 @@ if __name__ == "__main__":
337
344
  # Get categorical cardinalities
338
345
  categorical_cardinalities = [len(category_mappings.get(col, {})) for col in categorical_cols]
339
346
 
340
- if n_folds == 1:
341
- if "training" in all_df.columns:
342
- print("Using 'training' column for train/val split")
343
- train_idx = np.where(all_df["training"])[0]
344
- val_idx = np.where(~all_df["training"])[0]
345
- else:
346
- print("WARNING: No 'training' column found, using random 80/20 split")
347
- train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
347
+ # Get split strategy parameters
348
+ split_strategy = hyperparameters.get("split_strategy", "random")
349
+ butina_cutoff = hyperparameters.get("butina_cutoff", 0.4)
350
+
351
+ # Check for pre-defined training column (overrides split strategy)
352
+ if n_folds == 1 and "training" in all_df.columns:
353
+ print("Using 'training' column for train/val split")
354
+ train_idx = np.where(all_df["training"])[0]
355
+ val_idx = np.where(~all_df["training"])[0]
348
356
  folds = [(train_idx, val_idx)]
349
357
  else:
350
- if model_type == "classifier":
351
- kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
352
- folds = list(kfold.split(all_df, all_df[target]))
353
- else:
354
- kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
355
- folds = list(kfold.split(all_df))
358
+ # Use unified split interface (auto-detects 'smiles' column for scaffold/butina)
359
+ target_col = target if model_type == "classifier" else None
360
+ folds = get_split_indices(
361
+ all_df,
362
+ n_splits=n_folds,
363
+ strategy=split_strategy,
364
+ target_column=target_col,
365
+ test_size=0.2,
366
+ random_state=42,
367
+ butina_cutoff=butina_cutoff,
368
+ )
369
+ print(f"Split strategy: {split_strategy}")
356
370
 
357
371
  print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
358
372
 
@@ -16,6 +16,7 @@ from sklearn.metrics import (
16
16
  r2_score,
17
17
  root_mean_squared_error,
18
18
  )
19
+ from sklearn.model_selection import GroupKFold, GroupShuffleSplit
19
20
  from scipy.stats import spearmanr
20
21
 
21
22
 
@@ -367,3 +368,227 @@ def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names:
367
368
  for j, col_name in enumerate(label_names):
368
369
  value = conf_mtx[i, j]
369
370
  print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
371
+
372
+
373
+ # =============================================================================
374
+ # Dataset Splitting Utilities for Molecular Data
375
+ # =============================================================================
376
+ def get_scaffold(smiles: str) -> str:
377
+ """Extract Bemis-Murcko scaffold from a SMILES string.
378
+
379
+ Args:
380
+ smiles: SMILES string of the molecule
381
+
382
+ Returns:
383
+ SMILES string of the scaffold, or empty string if molecule is invalid
384
+ """
385
+ from rdkit import Chem
386
+ from rdkit.Chem.Scaffolds import MurckoScaffold
387
+
388
+ mol = Chem.MolFromSmiles(smiles)
389
+ if mol is None:
390
+ return ""
391
+ try:
392
+ scaffold = MurckoScaffold.GetScaffoldForMol(mol)
393
+ return Chem.MolToSmiles(scaffold)
394
+ except Exception:
395
+ return ""
396
+
397
+
398
+ def get_scaffold_groups(smiles_list: list[str]) -> np.ndarray:
399
+ """Assign each molecule to a scaffold group.
400
+
401
+ Args:
402
+ smiles_list: List of SMILES strings
403
+
404
+ Returns:
405
+ Array of group indices (same scaffold = same group)
406
+ """
407
+ scaffold_to_group = {}
408
+ groups = []
409
+
410
+ for smi in smiles_list:
411
+ scaffold = get_scaffold(smi)
412
+ if scaffold not in scaffold_to_group:
413
+ scaffold_to_group[scaffold] = len(scaffold_to_group)
414
+ groups.append(scaffold_to_group[scaffold])
415
+
416
+ n_scaffolds = len(scaffold_to_group)
417
+ print(f"Found {n_scaffolds} unique scaffolds from {len(smiles_list)} molecules")
418
+ return np.array(groups)
419
+
420
+
421
+ def get_butina_clusters(smiles_list: list[str], cutoff: float = 0.4) -> np.ndarray:
422
+ """Cluster molecules using Butina algorithm on Morgan fingerprints.
423
+
424
+ Uses RDKit's Butina clustering with Tanimoto distance on Morgan fingerprints.
425
+ This is Pat Walters' recommended approach for creating diverse train/test splits.
426
+
427
+ Args:
428
+ smiles_list: List of SMILES strings
429
+ cutoff: Tanimoto distance cutoff for clustering (default 0.4)
430
+ Lower values = more clusters = more similar molecules per cluster
431
+
432
+ Returns:
433
+ Array of cluster indices
434
+ """
435
+ from rdkit import Chem, DataStructs
436
+ from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
437
+ from rdkit.ML.Cluster import Butina
438
+
439
+ # Create Morgan fingerprint generator
440
+ fp_gen = GetMorganGenerator(radius=2, fpSize=2048)
441
+
442
+ # Generate Morgan fingerprints
443
+ fps = []
444
+ valid_indices = []
445
+ for i, smi in enumerate(smiles_list):
446
+ mol = Chem.MolFromSmiles(smi)
447
+ if mol is not None:
448
+ fp = fp_gen.GetFingerprint(mol)
449
+ fps.append(fp)
450
+ valid_indices.append(i)
451
+
452
+ if len(fps) == 0:
453
+ raise ValueError("No valid molecules found for clustering")
454
+
455
+ # Compute distance matrix (upper triangle only for efficiency)
456
+ n = len(fps)
457
+ dists = []
458
+ for i in range(1, n):
459
+ sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
460
+ dists.extend([1 - s for s in sims])
461
+
462
+ # Butina clustering
463
+ clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True)
464
+
465
+ # Map back to original indices
466
+ cluster_labels = np.zeros(len(smiles_list), dtype=int)
467
+ for cluster_idx, cluster in enumerate(clusters):
468
+ for mol_idx in cluster:
469
+ original_idx = valid_indices[mol_idx]
470
+ cluster_labels[original_idx] = cluster_idx
471
+
472
+ # Assign invalid molecules to their own clusters
473
+ next_cluster = len(clusters)
474
+ for i in range(len(smiles_list)):
475
+ if i not in valid_indices:
476
+ cluster_labels[i] = next_cluster
477
+ next_cluster += 1
478
+
479
+ n_clusters = len(set(cluster_labels))
480
+ print(f"Butina clustering: {n_clusters} clusters from {len(smiles_list)} molecules (cutoff={cutoff})")
481
+ return cluster_labels
482
+
483
+
484
+ def _find_smiles_column(columns: list[str]) -> str | None:
485
+ """Find SMILES column (case-insensitive match for 'smiles').
486
+
487
+ Args:
488
+ columns: List of column names
489
+
490
+ Returns:
491
+ The matching column name, or None if not found
492
+ """
493
+ return next((c for c in columns if c.lower() == "smiles"), None)
494
+
495
+
496
+ def get_split_indices(
497
+ df: pd.DataFrame,
498
+ n_splits: int = 5,
499
+ strategy: str = "random",
500
+ smiles_column: str | None = None,
501
+ target_column: str | None = None,
502
+ test_size: float = 0.2,
503
+ random_state: int = 42,
504
+ butina_cutoff: float = 0.4,
505
+ ) -> list[tuple[np.ndarray, np.ndarray]]:
506
+ """Get train/validation split indices using various strategies.
507
+
508
+ This is a unified interface for generating splits that can be used across
509
+ all model templates (XGBoost, PyTorch, ChemProp).
510
+
511
+ Args:
512
+ df: DataFrame containing the data
513
+ n_splits: Number of CV folds (1 = single train/val split)
514
+ strategy: Split strategy - one of:
515
+ - "random": Standard random split (default sklearn behavior)
516
+ - "scaffold": Bemis-Murcko scaffold-based grouping
517
+ - "butina": Morgan fingerprint clustering (recommended for ADMET)
518
+ smiles_column: Column containing SMILES. If None, auto-detects 'smiles' (case-insensitive)
519
+ target_column: Column containing target values (for stratification, optional)
520
+ test_size: Fraction for validation set when n_splits=1 (default 0.2)
521
+ random_state: Random seed for reproducibility
522
+ butina_cutoff: Tanimoto distance cutoff for Butina clustering (default 0.4)
523
+
524
+ Returns:
525
+ List of (train_indices, val_indices) tuples
526
+
527
+ Note:
528
+ If scaffold/butina strategy is requested but no SMILES column is found,
529
+ automatically falls back to random split with a warning message.
530
+
531
+ Example:
532
+ >>> folds = get_split_indices(df, n_splits=5, strategy="scaffold")
533
+ >>> for train_idx, val_idx in folds:
534
+ ... X_train, X_val = df.iloc[train_idx], df.iloc[val_idx]
535
+ """
536
+ from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
537
+
538
+ n_samples = len(df)
539
+
540
+ # Random split (original behavior)
541
+ if strategy == "random":
542
+ if n_splits == 1:
543
+ indices = np.arange(n_samples)
544
+ train_idx, val_idx = train_test_split(indices, test_size=test_size, random_state=random_state)
545
+ return [(train_idx, val_idx)]
546
+ else:
547
+ if target_column and df[target_column].dtype in ["object", "category", "bool"]:
548
+ kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
549
+ return list(kfold.split(df, df[target_column]))
550
+ else:
551
+ kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
552
+ return list(kfold.split(df))
553
+
554
+ # Scaffold or Butina split requires SMILES - auto-detect if not provided
555
+ if smiles_column is None:
556
+ smiles_column = _find_smiles_column(df.columns.tolist())
557
+
558
+ # Fall back to random split if no SMILES column available
559
+ if smiles_column is None or smiles_column not in df.columns:
560
+ print(f"No 'smiles' column found for strategy='{strategy}', falling back to random split")
561
+ return get_split_indices(
562
+ df,
563
+ n_splits=n_splits,
564
+ strategy="random",
565
+ target_column=target_column,
566
+ test_size=test_size,
567
+ random_state=random_state,
568
+ )
569
+
570
+ smiles_list = df[smiles_column].tolist()
571
+
572
+ # Get group assignments
573
+ if strategy == "scaffold":
574
+ groups = get_scaffold_groups(smiles_list)
575
+ elif strategy == "butina":
576
+ groups = get_butina_clusters(smiles_list, cutoff=butina_cutoff)
577
+ else:
578
+ raise ValueError(f"Unknown strategy: {strategy}. Use 'random', 'scaffold', or 'butina'")
579
+
580
+ # Generate splits using GroupKFold or GroupShuffleSplit
581
+ if n_splits == 1:
582
+ # Single split: use GroupShuffleSplit
583
+ splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
584
+ return list(splitter.split(df, groups=groups))
585
+ else:
586
+ # K-fold: use GroupKFold (ensures no group appears in both train and val)
587
+ # Note: GroupKFold doesn't shuffle, so we shuffle group order first
588
+ unique_groups = np.unique(groups)
589
+ rng = np.random.default_rng(random_state)
590
+ shuffled_group_map = {g: i for i, g in enumerate(rng.permutation(unique_groups))}
591
+ shuffled_groups = np.array([shuffled_group_map[g] for g in groups])
592
+
593
+ gkf = GroupKFold(n_splits=n_splits)
594
+ return list(gkf.split(df, groups=shuffled_groups))