workbench 0.8.231__py3-none-any.whl → 0.8.236__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/smart_aggregator.py +17 -12
- workbench/api/endpoint.py +13 -4
- workbench/cached/cached_model.py +2 -2
- workbench/core/artifacts/endpoint_core.py +30 -5
- workbench/model_script_utils/model_script_utils.py +225 -0
- workbench/model_script_utils/uq_harness.py +39 -21
- workbench/model_scripts/chemprop/chemprop.template +29 -14
- workbench/model_scripts/chemprop/generated_model_script.py +35 -18
- workbench/model_scripts/chemprop/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +34 -20
- workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/pytorch.template +28 -14
- workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
- workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
- workbench/model_scripts/xgb_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/xgb_model.template +29 -18
- workbench/themes/dark/custom.css +29 -0
- workbench/themes/light/custom.css +29 -0
- workbench/themes/midnight_blue/custom.css +28 -0
- workbench/utils/markdown_utils.py +5 -1
- workbench/utils/model_utils.py +9 -0
- workbench/utils/theme_manager.py +95 -0
- workbench/web_interface/components/component_interface.py +3 -0
- workbench/web_interface/components/plugin_interface.py +26 -0
- workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
- workbench/web_interface/components/plugins/model_details.py +18 -5
- workbench/web_interface/components/plugins/model_plot.py +156 -0
- workbench/web_interface/components/plugins/scatter_plot.py +9 -2
- workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
- workbench/web_interface/components/settings_menu.py +10 -49
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/METADATA +1 -1
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/RECORD +37 -37
- workbench/web_interface/components/model_plot.py +0 -75
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/WHEEL +0 -0
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/top_level.txt +0 -0
|
@@ -20,6 +20,7 @@ import torch
|
|
|
20
20
|
from chemprop import data, models
|
|
21
21
|
|
|
22
22
|
from model_script_utils import (
|
|
23
|
+
cap_std_outliers,
|
|
23
24
|
expand_proba_column,
|
|
24
25
|
input_fn,
|
|
25
26
|
output_fn,
|
|
@@ -43,6 +44,12 @@ DEFAULT_HYPERPARAMETERS = {
|
|
|
43
44
|
"ffn_num_layers": 2,
|
|
44
45
|
# Loss function for regression (mae, mse)
|
|
45
46
|
"criterion": "mae",
|
|
47
|
+
# Split strategy: "random", "scaffold", or "butina"
|
|
48
|
+
# - random: Standard random split
|
|
49
|
+
# - scaffold: Bemis-Murcko scaffold-based grouping
|
|
50
|
+
# - butina: Morgan fingerprint clustering (recommended for ADMET)
|
|
51
|
+
"split_strategy": "butina",
|
|
52
|
+
"butina_cutoff": 0.4, # Tanimoto distance cutoff for Butina clustering
|
|
46
53
|
# Random seed
|
|
47
54
|
"seed": 42,
|
|
48
55
|
# Foundation model support
|
|
@@ -58,11 +65,11 @@ DEFAULT_HYPERPARAMETERS = {
|
|
|
58
65
|
# Template parameters (filled in by Workbench)
|
|
59
66
|
TEMPLATE_PARAMS = {
|
|
60
67
|
"model_type": "uq_regressor",
|
|
61
|
-
"targets": ['
|
|
68
|
+
"targets": ['logd'],
|
|
62
69
|
"feature_list": ['smiles'],
|
|
63
|
-
"id_column": "
|
|
64
|
-
"model_metrics_s3_path": "s3://
|
|
65
|
-
"hyperparameters": {},
|
|
70
|
+
"id_column": "molecule_name",
|
|
71
|
+
"model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-chemprop-split-butina/training",
|
|
72
|
+
"hyperparameters": {'split_strategy': 'butina'},
|
|
66
73
|
}
|
|
67
74
|
|
|
68
75
|
|
|
@@ -245,6 +252,7 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
|
|
|
245
252
|
preds_std = np.std(np.stack(all_preds), axis=0)
|
|
246
253
|
if preds.ndim == 1:
|
|
247
254
|
preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
|
|
255
|
+
preds_std = cap_std_outliers(preds_std)
|
|
248
256
|
|
|
249
257
|
print(f"Inference complete: {preds.shape[0]} predictions")
|
|
250
258
|
|
|
@@ -303,6 +311,7 @@ if __name__ == "__main__":
|
|
|
303
311
|
check_dataframe,
|
|
304
312
|
compute_classification_metrics,
|
|
305
313
|
compute_regression_metrics,
|
|
314
|
+
get_split_indices,
|
|
306
315
|
print_classification_metrics,
|
|
307
316
|
print_confusion_matrix,
|
|
308
317
|
print_regression_metrics,
|
|
@@ -516,22 +525,29 @@ if __name__ == "__main__":
|
|
|
516
525
|
n_folds = hyperparameters["n_folds"]
|
|
517
526
|
batch_size = hyperparameters["batch_size"]
|
|
518
527
|
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
528
|
+
# Get split strategy parameters
|
|
529
|
+
split_strategy = hyperparameters.get("split_strategy", "random")
|
|
530
|
+
butina_cutoff = hyperparameters.get("butina_cutoff", 0.4)
|
|
531
|
+
|
|
532
|
+
# Check for pre-defined training column (overrides split strategy)
|
|
533
|
+
if n_folds == 1 and "training" in all_df.columns:
|
|
534
|
+
print("Using 'training' column for train/val split")
|
|
535
|
+
train_idx = np.where(all_df["training"])[0]
|
|
536
|
+
val_idx = np.where(~all_df["training"])[0]
|
|
527
537
|
folds = [(train_idx, val_idx)]
|
|
528
538
|
else:
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
539
|
+
# Use unified split interface (auto-detects 'smiles' column for scaffold/butina)
|
|
540
|
+
target_col = target_columns[0] if model_type == "classifier" else None
|
|
541
|
+
folds = get_split_indices(
|
|
542
|
+
all_df,
|
|
543
|
+
n_splits=n_folds,
|
|
544
|
+
strategy=split_strategy,
|
|
545
|
+
target_column=target_col,
|
|
546
|
+
test_size=0.2,
|
|
547
|
+
random_state=42,
|
|
548
|
+
butina_cutoff=butina_cutoff,
|
|
549
|
+
)
|
|
550
|
+
print(f"Split strategy: {split_strategy}")
|
|
535
551
|
|
|
536
552
|
print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
|
|
537
553
|
|
|
@@ -701,6 +717,7 @@ if __name__ == "__main__":
|
|
|
701
717
|
preds_std = np.std(np.stack(all_ens_preds), axis=0)
|
|
702
718
|
if preds_std.ndim == 1:
|
|
703
719
|
preds_std = preds_std.reshape(-1, 1)
|
|
720
|
+
preds_std = cap_std_outliers(preds_std)
|
|
704
721
|
|
|
705
722
|
print("\n--- Per-target metrics ---")
|
|
706
723
|
for t_idx, t_name in enumerate(target_columns):
|
|
@@ -16,6 +16,7 @@ from sklearn.metrics import (
|
|
|
16
16
|
r2_score,
|
|
17
17
|
root_mean_squared_error,
|
|
18
18
|
)
|
|
19
|
+
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
|
|
19
20
|
from scipy.stats import spearmanr
|
|
20
21
|
|
|
21
22
|
|
|
@@ -367,3 +368,227 @@ def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names:
|
|
|
367
368
|
for j, col_name in enumerate(label_names):
|
|
368
369
|
value = conf_mtx[i, j]
|
|
369
370
|
print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
# =============================================================================
|
|
374
|
+
# Dataset Splitting Utilities for Molecular Data
|
|
375
|
+
# =============================================================================
|
|
376
|
+
def get_scaffold(smiles: str) -> str:
    """Extract the Bemis-Murcko scaffold from a SMILES string.

    Args:
        smiles: SMILES string of the molecule. Non-string values (e.g. ``None``
            or a float ``NaN`` coming from a missing DataFrame cell) are treated
            as invalid instead of raising.

    Returns:
        SMILES string of the scaffold, or an empty string if the molecule is
        missing/invalid or scaffold extraction fails.
        NOTE(review): acyclic molecules presumably also yield an empty scaffold
        (RDKit behaviour), so they share a group with invalid inputs — confirm.
    """
    # Guard before touching RDKit: MolFromSmiles raises on non-str input
    # (e.g. NaN for a missing SMILES) rather than returning None.
    if not isinstance(smiles, str):
        return ""

    from rdkit import Chem
    from rdkit.Chem.Scaffolds import MurckoScaffold

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return ""
    try:
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
        return Chem.MolToSmiles(scaffold)
    except Exception:
        # Deliberate best-effort: treat any scaffold failure like an invalid molecule
        return ""
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def get_scaffold_groups(smiles_list: list[str]) -> np.ndarray:
    """Map every molecule to an integer scaffold-group id.

    Ids are handed out in order of first appearance: the first distinct
    scaffold string gets 0, the next new one gets 1, and so on. Molecules whose
    ``get_scaffold`` result is identical (including the empty string returned
    for invalid input) land in the same group.

    Args:
        smiles_list: SMILES strings, one per molecule.

    Returns:
        Array of group ids aligned with ``smiles_list``.
    """
    group_of: dict[str, int] = {}
    assignments = []
    for smiles in smiles_list:
        key = get_scaffold(smiles)
        # len(group_of) is evaluated before insertion, so a new key gets the next id
        assignments.append(group_of.setdefault(key, len(group_of)))

    print(f"Found {len(group_of)} unique scaffolds from {len(smiles_list)} molecules")
    return np.array(assignments)
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def get_butina_clusters(smiles_list: list[str], cutoff: float = 0.4) -> np.ndarray:
    """Cluster molecules using the Butina algorithm on Morgan fingerprints.

    Uses RDKit's Butina clustering with Tanimoto distance (1 - similarity) on
    radius-2, 2048-bit Morgan fingerprints. This is Pat Walters' recommended
    approach for creating diverse train/test splits.

    Args:
        smiles_list: SMILES strings, one per molecule. Unparseable SMILES and
            non-string entries (e.g. ``None``/``NaN`` from missing cells) are
            not fingerprinted; each becomes its own singleton cluster.
        cutoff: Tanimoto distance cutoff for clustering (default 0.4).
            Lower values give more, tighter clusters.

    Returns:
        Integer array of cluster indices aligned with ``smiles_list``.

    Raises:
        ValueError: If no molecule in ``smiles_list`` could be parsed.

    Note:
        The full pairwise distance list is built, so time and memory grow
        quadratically with the number of valid molecules.
    """
    from rdkit import Chem, DataStructs
    from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
    from rdkit.ML.Cluster import Butina

    fp_gen = GetMorganGenerator(radius=2, fpSize=2048)

    # Fingerprint every parseable molecule, remembering its original position
    fps = []
    valid_indices = []
    for i, smi in enumerate(smiles_list):
        # MolFromSmiles raises on non-str input (e.g. NaN), so treat it as invalid
        mol = Chem.MolFromSmiles(smi) if isinstance(smi, str) else None
        if mol is not None:
            fps.append(fp_gen.GetFingerprint(mol))
            valid_indices.append(i)

    if not fps:
        raise ValueError("No valid molecules found for clustering")

    # One triangle of the symmetric distance matrix, row by row: for each i the
    # distances to fps[0..i-1] (the flat layout ClusterData takes with isDistData=True)
    n = len(fps)
    dists = []
    for i in range(1, n):
        sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
        dists.extend(1 - s for s in sims)

    clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True)

    # Map cluster members (indices into fps) back to positions in smiles_list
    cluster_labels = np.zeros(len(smiles_list), dtype=int)
    for cluster_idx, cluster in enumerate(clusters):
        for mol_idx in cluster:
            cluster_labels[valid_indices[mol_idx]] = cluster_idx

    # Assign invalid molecules to their own clusters. A set makes the
    # membership test O(1); the former list lookup made this loop O(n^2).
    valid_set = set(valid_indices)
    next_cluster = len(clusters)
    for i in range(len(smiles_list)):
        if i not in valid_set:
            cluster_labels[i] = next_cluster
            next_cluster += 1

    n_clusters = len(set(cluster_labels))
    print(f"Butina clustering: {n_clusters} clusters from {len(smiles_list)} molecules (cutoff={cutoff})")
    return cluster_labels
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def _find_smiles_column(columns: list[str]) -> str | None:
|
|
485
|
+
"""Find SMILES column (case-insensitive match for 'smiles').
|
|
486
|
+
|
|
487
|
+
Args:
|
|
488
|
+
columns: List of column names
|
|
489
|
+
|
|
490
|
+
Returns:
|
|
491
|
+
The matching column name, or None if not found
|
|
492
|
+
"""
|
|
493
|
+
return next((c for c in columns if c.lower() == "smiles"), None)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def get_split_indices(
|
|
497
|
+
df: pd.DataFrame,
|
|
498
|
+
n_splits: int = 5,
|
|
499
|
+
strategy: str = "random",
|
|
500
|
+
smiles_column: str | None = None,
|
|
501
|
+
target_column: str | None = None,
|
|
502
|
+
test_size: float = 0.2,
|
|
503
|
+
random_state: int = 42,
|
|
504
|
+
butina_cutoff: float = 0.4,
|
|
505
|
+
) -> list[tuple[np.ndarray, np.ndarray]]:
|
|
506
|
+
"""Get train/validation split indices using various strategies.
|
|
507
|
+
|
|
508
|
+
This is a unified interface for generating splits that can be used across
|
|
509
|
+
all model templates (XGBoost, PyTorch, ChemProp).
|
|
510
|
+
|
|
511
|
+
Args:
|
|
512
|
+
df: DataFrame containing the data
|
|
513
|
+
n_splits: Number of CV folds (1 = single train/val split)
|
|
514
|
+
strategy: Split strategy - one of:
|
|
515
|
+
- "random": Standard random split (default sklearn behavior)
|
|
516
|
+
- "scaffold": Bemis-Murcko scaffold-based grouping
|
|
517
|
+
- "butina": Morgan fingerprint clustering (recommended for ADMET)
|
|
518
|
+
smiles_column: Column containing SMILES. If None, auto-detects 'smiles' (case-insensitive)
|
|
519
|
+
target_column: Column containing target values (for stratification, optional)
|
|
520
|
+
test_size: Fraction for validation set when n_splits=1 (default 0.2)
|
|
521
|
+
random_state: Random seed for reproducibility
|
|
522
|
+
butina_cutoff: Tanimoto distance cutoff for Butina clustering (default 0.4)
|
|
523
|
+
|
|
524
|
+
Returns:
|
|
525
|
+
List of (train_indices, val_indices) tuples
|
|
526
|
+
|
|
527
|
+
Note:
|
|
528
|
+
If scaffold/butina strategy is requested but no SMILES column is found,
|
|
529
|
+
automatically falls back to random split with a warning message.
|
|
530
|
+
|
|
531
|
+
Example:
|
|
532
|
+
>>> folds = get_split_indices(df, n_splits=5, strategy="scaffold")
|
|
533
|
+
>>> for train_idx, val_idx in folds:
|
|
534
|
+
... X_train, X_val = df.iloc[train_idx], df.iloc[val_idx]
|
|
535
|
+
"""
|
|
536
|
+
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
|
|
537
|
+
|
|
538
|
+
n_samples = len(df)
|
|
539
|
+
|
|
540
|
+
# Random split (original behavior)
|
|
541
|
+
if strategy == "random":
|
|
542
|
+
if n_splits == 1:
|
|
543
|
+
indices = np.arange(n_samples)
|
|
544
|
+
train_idx, val_idx = train_test_split(indices, test_size=test_size, random_state=random_state)
|
|
545
|
+
return [(train_idx, val_idx)]
|
|
546
|
+
else:
|
|
547
|
+
if target_column and df[target_column].dtype in ["object", "category", "bool"]:
|
|
548
|
+
kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
|
|
549
|
+
return list(kfold.split(df, df[target_column]))
|
|
550
|
+
else:
|
|
551
|
+
kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
|
|
552
|
+
return list(kfold.split(df))
|
|
553
|
+
|
|
554
|
+
# Scaffold or Butina split requires SMILES - auto-detect if not provided
|
|
555
|
+
if smiles_column is None:
|
|
556
|
+
smiles_column = _find_smiles_column(df.columns.tolist())
|
|
557
|
+
|
|
558
|
+
# Fall back to random split if no SMILES column available
|
|
559
|
+
if smiles_column is None or smiles_column not in df.columns:
|
|
560
|
+
print(f"No 'smiles' column found for strategy='{strategy}', falling back to random split")
|
|
561
|
+
return get_split_indices(
|
|
562
|
+
df,
|
|
563
|
+
n_splits=n_splits,
|
|
564
|
+
strategy="random",
|
|
565
|
+
target_column=target_column,
|
|
566
|
+
test_size=test_size,
|
|
567
|
+
random_state=random_state,
|
|
568
|
+
)
|
|
569
|
+
|
|
570
|
+
smiles_list = df[smiles_column].tolist()
|
|
571
|
+
|
|
572
|
+
# Get group assignments
|
|
573
|
+
if strategy == "scaffold":
|
|
574
|
+
groups = get_scaffold_groups(smiles_list)
|
|
575
|
+
elif strategy == "butina":
|
|
576
|
+
groups = get_butina_clusters(smiles_list, cutoff=butina_cutoff)
|
|
577
|
+
else:
|
|
578
|
+
raise ValueError(f"Unknown strategy: {strategy}. Use 'random', 'scaffold', or 'butina'")
|
|
579
|
+
|
|
580
|
+
# Generate splits using GroupKFold or GroupShuffleSplit
|
|
581
|
+
if n_splits == 1:
|
|
582
|
+
# Single split: use GroupShuffleSplit
|
|
583
|
+
splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
|
|
584
|
+
return list(splitter.split(df, groups=groups))
|
|
585
|
+
else:
|
|
586
|
+
# K-fold: use GroupKFold (ensures no group appears in both train and val)
|
|
587
|
+
# Note: GroupKFold doesn't shuffle, so we shuffle group order first
|
|
588
|
+
unique_groups = np.unique(groups)
|
|
589
|
+
rng = np.random.default_rng(random_state)
|
|
590
|
+
shuffled_group_map = {g: i for i, g in enumerate(rng.permutation(unique_groups))}
|
|
591
|
+
shuffled_groups = np.array([shuffled_group_map[g] for g in groups])
|
|
592
|
+
|
|
593
|
+
gkf = GroupKFold(n_splits=n_splits)
|
|
594
|
+
return list(gkf.split(df, groups=shuffled_groups))
|
|
@@ -53,19 +53,25 @@ DEFAULT_HYPERPARAMETERS = {
|
|
|
53
53
|
"use_batch_norm": True,
|
|
54
54
|
# Loss function for regression (L1Loss=MAE, MSELoss=MSE, HuberLoss, SmoothL1Loss)
|
|
55
55
|
"loss": "L1Loss",
|
|
56
|
+
# Split strategy: "random", "scaffold", or "butina"
|
|
57
|
+
# - random: Standard random split
|
|
58
|
+
# - scaffold: Bemis-Murcko scaffold-based grouping (requires 'smiles' column in data)
|
|
59
|
+
# - butina: Morgan fingerprint clustering (requires 'smiles' column, recommended for ADMET)
|
|
60
|
+
"split_strategy": "butina",
|
|
61
|
+
"butina_cutoff": 0.4, # Tanimoto distance cutoff for Butina clustering
|
|
56
62
|
# Random seed
|
|
57
63
|
"seed": 42,
|
|
58
64
|
}
|
|
59
65
|
|
|
60
66
|
# Template parameters (filled in by Workbench)
|
|
61
67
|
TEMPLATE_PARAMS = {
|
|
62
|
-
"model_type": "
|
|
63
|
-
"target": "
|
|
68
|
+
"model_type": "uq_regressor",
|
|
69
|
+
"target": "logd",
|
|
64
70
|
"features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 
'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 
'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
|
|
65
|
-
"id_column": "
|
|
66
|
-
"compressed_features": [],
|
|
67
|
-
"model_metrics_s3_path": "s3://
|
|
68
|
-
"hyperparameters": {},
|
|
71
|
+
"id_column": "molecule_name",
|
|
72
|
+
"compressed_features": ['fingerprint'],
|
|
73
|
+
"model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-pytorch-split-butina/training",
|
|
74
|
+
"hyperparameters": {'split_strategy': 'butina'},
|
|
69
75
|
}
|
|
70
76
|
|
|
71
77
|
|
|
@@ -234,6 +240,7 @@ if __name__ == "__main__":
|
|
|
234
240
|
check_dataframe,
|
|
235
241
|
compute_classification_metrics,
|
|
236
242
|
compute_regression_metrics,
|
|
243
|
+
get_split_indices,
|
|
237
244
|
print_classification_metrics,
|
|
238
245
|
print_confusion_matrix,
|
|
239
246
|
print_regression_metrics,
|
|
@@ -337,22 +344,29 @@ if __name__ == "__main__":
|
|
|
337
344
|
# Get categorical cardinalities
|
|
338
345
|
categorical_cardinalities = [len(category_mappings.get(col, {})) for col in categorical_cols]
|
|
339
346
|
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
347
|
+
# Get split strategy parameters
|
|
348
|
+
split_strategy = hyperparameters.get("split_strategy", "random")
|
|
349
|
+
butina_cutoff = hyperparameters.get("butina_cutoff", 0.4)
|
|
350
|
+
|
|
351
|
+
# Check for pre-defined training column (overrides split strategy)
|
|
352
|
+
if n_folds == 1 and "training" in all_df.columns:
|
|
353
|
+
print("Using 'training' column for train/val split")
|
|
354
|
+
train_idx = np.where(all_df["training"])[0]
|
|
355
|
+
val_idx = np.where(~all_df["training"])[0]
|
|
348
356
|
folds = [(train_idx, val_idx)]
|
|
349
357
|
else:
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
358
|
+
# Use unified split interface (auto-detects 'smiles' column for scaffold/butina)
|
|
359
|
+
target_col = target if model_type == "classifier" else None
|
|
360
|
+
folds = get_split_indices(
|
|
361
|
+
all_df,
|
|
362
|
+
n_splits=n_folds,
|
|
363
|
+
strategy=split_strategy,
|
|
364
|
+
target_column=target_col,
|
|
365
|
+
test_size=0.2,
|
|
366
|
+
random_state=42,
|
|
367
|
+
butina_cutoff=butina_cutoff,
|
|
368
|
+
)
|
|
369
|
+
print(f"Split strategy: {split_strategy}")
|
|
356
370
|
|
|
357
371
|
print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
|
|
358
372
|
|
|
@@ -16,6 +16,7 @@ from sklearn.metrics import (
|
|
|
16
16
|
r2_score,
|
|
17
17
|
root_mean_squared_error,
|
|
18
18
|
)
|
|
19
|
+
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
|
|
19
20
|
from scipy.stats import spearmanr
|
|
20
21
|
|
|
21
22
|
|
|
@@ -367,3 +368,227 @@ def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names:
|
|
|
367
368
|
for j, col_name in enumerate(label_names):
|
|
368
369
|
value = conf_mtx[i, j]
|
|
369
370
|
print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
# =============================================================================
|
|
374
|
+
# Dataset Splitting Utilities for Molecular Data
|
|
375
|
+
# =============================================================================
|
|
376
|
+
def get_scaffold(smiles: str) -> str:
    """Extract the Bemis-Murcko scaffold from a SMILES string.

    Args:
        smiles: SMILES string of the molecule. Non-string values (e.g. ``None``
            or a float ``NaN`` coming from a missing DataFrame cell) are treated
            as invalid instead of raising.

    Returns:
        SMILES string of the scaffold, or an empty string if the molecule is
        missing/invalid or scaffold extraction fails.
        NOTE(review): acyclic molecules presumably also yield an empty scaffold
        (RDKit behaviour), so they share a group with invalid inputs — confirm.
    """
    # Guard before touching RDKit: MolFromSmiles raises on non-str input
    # (e.g. NaN for a missing SMILES) rather than returning None.
    if not isinstance(smiles, str):
        return ""

    from rdkit import Chem
    from rdkit.Chem.Scaffolds import MurckoScaffold

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return ""
    try:
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
        return Chem.MolToSmiles(scaffold)
    except Exception:
        # Deliberate best-effort: treat any scaffold failure like an invalid molecule
        return ""
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def get_scaffold_groups(smiles_list: list[str]) -> np.ndarray:
    """Map every molecule to an integer scaffold-group id.

    Ids are handed out in order of first appearance: the first distinct
    scaffold string gets 0, the next new one gets 1, and so on. Molecules whose
    ``get_scaffold`` result is identical (including the empty string returned
    for invalid input) land in the same group.

    Args:
        smiles_list: SMILES strings, one per molecule.

    Returns:
        Array of group ids aligned with ``smiles_list``.
    """
    group_of: dict[str, int] = {}
    assignments = []
    for smiles in smiles_list:
        key = get_scaffold(smiles)
        # len(group_of) is evaluated before insertion, so a new key gets the next id
        assignments.append(group_of.setdefault(key, len(group_of)))

    print(f"Found {len(group_of)} unique scaffolds from {len(smiles_list)} molecules")
    return np.array(assignments)
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def get_butina_clusters(smiles_list: list[str], cutoff: float = 0.4) -> np.ndarray:
    """Cluster molecules using the Butina algorithm on Morgan fingerprints.

    Uses RDKit's Butina clustering with Tanimoto distance (1 - similarity) on
    radius-2, 2048-bit Morgan fingerprints. This is Pat Walters' recommended
    approach for creating diverse train/test splits.

    Args:
        smiles_list: SMILES strings, one per molecule. Unparseable SMILES and
            non-string entries (e.g. ``None``/``NaN`` from missing cells) are
            not fingerprinted; each becomes its own singleton cluster.
        cutoff: Tanimoto distance cutoff for clustering (default 0.4).
            Lower values give more, tighter clusters.

    Returns:
        Integer array of cluster indices aligned with ``smiles_list``.

    Raises:
        ValueError: If no molecule in ``smiles_list`` could be parsed.

    Note:
        The full pairwise distance list is built, so time and memory grow
        quadratically with the number of valid molecules.
    """
    from rdkit import Chem, DataStructs
    from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
    from rdkit.ML.Cluster import Butina

    fp_gen = GetMorganGenerator(radius=2, fpSize=2048)

    # Fingerprint every parseable molecule, remembering its original position
    fps = []
    valid_indices = []
    for i, smi in enumerate(smiles_list):
        # MolFromSmiles raises on non-str input (e.g. NaN), so treat it as invalid
        mol = Chem.MolFromSmiles(smi) if isinstance(smi, str) else None
        if mol is not None:
            fps.append(fp_gen.GetFingerprint(mol))
            valid_indices.append(i)

    if not fps:
        raise ValueError("No valid molecules found for clustering")

    # One triangle of the symmetric distance matrix, row by row: for each i the
    # distances to fps[0..i-1] (the flat layout ClusterData takes with isDistData=True)
    n = len(fps)
    dists = []
    for i in range(1, n):
        sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
        dists.extend(1 - s for s in sims)

    clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True)

    # Map cluster members (indices into fps) back to positions in smiles_list
    cluster_labels = np.zeros(len(smiles_list), dtype=int)
    for cluster_idx, cluster in enumerate(clusters):
        for mol_idx in cluster:
            cluster_labels[valid_indices[mol_idx]] = cluster_idx

    # Assign invalid molecules to their own clusters. A set makes the
    # membership test O(1); the former list lookup made this loop O(n^2).
    valid_set = set(valid_indices)
    next_cluster = len(clusters)
    for i in range(len(smiles_list)):
        if i not in valid_set:
            cluster_labels[i] = next_cluster
            next_cluster += 1

    n_clusters = len(set(cluster_labels))
    print(f"Butina clustering: {n_clusters} clusters from {len(smiles_list)} molecules (cutoff={cutoff})")
    return cluster_labels
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def _find_smiles_column(columns: list[str]) -> str | None:
|
|
485
|
+
"""Find SMILES column (case-insensitive match for 'smiles').
|
|
486
|
+
|
|
487
|
+
Args:
|
|
488
|
+
columns: List of column names
|
|
489
|
+
|
|
490
|
+
Returns:
|
|
491
|
+
The matching column name, or None if not found
|
|
492
|
+
"""
|
|
493
|
+
return next((c for c in columns if c.lower() == "smiles"), None)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def get_split_indices(
|
|
497
|
+
df: pd.DataFrame,
|
|
498
|
+
n_splits: int = 5,
|
|
499
|
+
strategy: str = "random",
|
|
500
|
+
smiles_column: str | None = None,
|
|
501
|
+
target_column: str | None = None,
|
|
502
|
+
test_size: float = 0.2,
|
|
503
|
+
random_state: int = 42,
|
|
504
|
+
butina_cutoff: float = 0.4,
|
|
505
|
+
) -> list[tuple[np.ndarray, np.ndarray]]:
|
|
506
|
+
"""Get train/validation split indices using various strategies.
|
|
507
|
+
|
|
508
|
+
This is a unified interface for generating splits that can be used across
|
|
509
|
+
all model templates (XGBoost, PyTorch, ChemProp).
|
|
510
|
+
|
|
511
|
+
Args:
|
|
512
|
+
df: DataFrame containing the data
|
|
513
|
+
n_splits: Number of CV folds (1 = single train/val split)
|
|
514
|
+
strategy: Split strategy - one of:
|
|
515
|
+
- "random": Standard random split (default sklearn behavior)
|
|
516
|
+
- "scaffold": Bemis-Murcko scaffold-based grouping
|
|
517
|
+
- "butina": Morgan fingerprint clustering (recommended for ADMET)
|
|
518
|
+
smiles_column: Column containing SMILES. If None, auto-detects 'smiles' (case-insensitive)
|
|
519
|
+
target_column: Column containing target values (for stratification, optional)
|
|
520
|
+
test_size: Fraction for validation set when n_splits=1 (default 0.2)
|
|
521
|
+
random_state: Random seed for reproducibility
|
|
522
|
+
butina_cutoff: Tanimoto distance cutoff for Butina clustering (default 0.4)
|
|
523
|
+
|
|
524
|
+
Returns:
|
|
525
|
+
List of (train_indices, val_indices) tuples
|
|
526
|
+
|
|
527
|
+
Note:
|
|
528
|
+
If scaffold/butina strategy is requested but no SMILES column is found,
|
|
529
|
+
automatically falls back to random split with a warning message.
|
|
530
|
+
|
|
531
|
+
Example:
|
|
532
|
+
>>> folds = get_split_indices(df, n_splits=5, strategy="scaffold")
|
|
533
|
+
>>> for train_idx, val_idx in folds:
|
|
534
|
+
... X_train, X_val = df.iloc[train_idx], df.iloc[val_idx]
|
|
535
|
+
"""
|
|
536
|
+
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
|
|
537
|
+
|
|
538
|
+
n_samples = len(df)
|
|
539
|
+
|
|
540
|
+
# Random split (original behavior)
|
|
541
|
+
if strategy == "random":
|
|
542
|
+
if n_splits == 1:
|
|
543
|
+
indices = np.arange(n_samples)
|
|
544
|
+
train_idx, val_idx = train_test_split(indices, test_size=test_size, random_state=random_state)
|
|
545
|
+
return [(train_idx, val_idx)]
|
|
546
|
+
else:
|
|
547
|
+
if target_column and df[target_column].dtype in ["object", "category", "bool"]:
|
|
548
|
+
kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
|
|
549
|
+
return list(kfold.split(df, df[target_column]))
|
|
550
|
+
else:
|
|
551
|
+
kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
|
|
552
|
+
return list(kfold.split(df))
|
|
553
|
+
|
|
554
|
+
# Scaffold or Butina split requires SMILES - auto-detect if not provided
|
|
555
|
+
if smiles_column is None:
|
|
556
|
+
smiles_column = _find_smiles_column(df.columns.tolist())
|
|
557
|
+
|
|
558
|
+
# Fall back to random split if no SMILES column available
|
|
559
|
+
if smiles_column is None or smiles_column not in df.columns:
|
|
560
|
+
print(f"No 'smiles' column found for strategy='{strategy}', falling back to random split")
|
|
561
|
+
return get_split_indices(
|
|
562
|
+
df,
|
|
563
|
+
n_splits=n_splits,
|
|
564
|
+
strategy="random",
|
|
565
|
+
target_column=target_column,
|
|
566
|
+
test_size=test_size,
|
|
567
|
+
random_state=random_state,
|
|
568
|
+
)
|
|
569
|
+
|
|
570
|
+
smiles_list = df[smiles_column].tolist()
|
|
571
|
+
|
|
572
|
+
# Get group assignments
|
|
573
|
+
if strategy == "scaffold":
|
|
574
|
+
groups = get_scaffold_groups(smiles_list)
|
|
575
|
+
elif strategy == "butina":
|
|
576
|
+
groups = get_butina_clusters(smiles_list, cutoff=butina_cutoff)
|
|
577
|
+
else:
|
|
578
|
+
raise ValueError(f"Unknown strategy: {strategy}. Use 'random', 'scaffold', or 'butina'")
|
|
579
|
+
|
|
580
|
+
# Generate splits using GroupKFold or GroupShuffleSplit
|
|
581
|
+
if n_splits == 1:
|
|
582
|
+
# Single split: use GroupShuffleSplit
|
|
583
|
+
splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
|
|
584
|
+
return list(splitter.split(df, groups=groups))
|
|
585
|
+
else:
|
|
586
|
+
# K-fold: use GroupKFold (ensures no group appears in both train and val)
|
|
587
|
+
# Note: GroupKFold doesn't shuffle, so we shuffle group order first
|
|
588
|
+
unique_groups = np.unique(groups)
|
|
589
|
+
rng = np.random.default_rng(random_state)
|
|
590
|
+
shuffled_group_map = {g: i for i, g in enumerate(rng.permutation(unique_groups))}
|
|
591
|
+
shuffled_groups = np.array([shuffled_group_map[g] for g in groups])
|
|
592
|
+
|
|
593
|
+
gkf = GroupKFold(n_splits=n_splits)
|
|
594
|
+
return list(gkf.split(df, groups=shuffled_groups))
|