workbench 0.8.217__py3-none-any.whl → 0.8.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/feature_set.py +0 -1
- workbench/core/artifacts/endpoint_core.py +2 -2
- workbench/core/artifacts/feature_set_core.py +185 -230
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +15 -11
- workbench/model_scripts/chemprop/chemprop.template +195 -70
- workbench/model_scripts/chemprop/generated_model_script.py +198 -73
- workbench/model_scripts/chemprop/model_script_utils.py +15 -11
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +80 -43
- workbench/model_scripts/pytorch_model/generated_model_script.py +2 -2
- workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
- workbench/model_scripts/xgb_model/generated_model_script.py +7 -7
- workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/light/custom.css +7 -1
- workbench/themes/midnight_blue/custom.css +34 -0
- workbench/utils/chem_utils/fingerprints.py +80 -43
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/meta_model_simulator.py +41 -13
- workbench/utils/model_utils.py +0 -1
- workbench/utils/plot_utils.py +146 -28
- workbench/utils/shap_utils.py +1 -55
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/plugins/scatter_plot.py +152 -66
- workbench/web_interface/components/settings_menu.py +184 -0
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/RECORD +38 -37
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +1 -0
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
workbench/model_scripts/chemprop/generated_model_script.py

@@ -34,7 +34,7 @@ DEFAULT_HYPERPARAMETERS = {
     "max_epochs": 400,
     "patience": 50,
     "batch_size": 32,
-    # Message Passing
+    # Message Passing (ignored when using foundation model)
     "hidden_dim": 700,
     "depth": 6,
     "dropout": 0.1,  # Lower dropout - ensemble provides regularization
@@ -45,16 +45,24 @@
     "criterion": "mae",
     # Random seed
     "seed": 42,
+    # Foundation model support
+    # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
+    # - Path to .pt file: Load custom pretrained Chemprop model
+    # - None: Train from scratch (default)
+    "from_foundation": None,
+    # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
+    # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
+    "freeze_mpnn_epochs": 0,
 }
 
 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
-    "targets": ['
+    "targets": ['udm_asy_res_free_percent'],
     "feature_list": ['smiles'],
     "id_column": "udm_mol_bat_id",
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/
-    "hyperparameters": {},
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/ppb-human-free-reg-chemprop-foundation-1-dt/training",
+    "hyperparameters": {'from_foundation': 'CheMeleon', 'freeze_mpnn_epochs': 10, 'n_folds': 5, 'max_epochs': 100, 'patience': 20, 'ffn_hidden_dim': 512, 'dropout': 0.15},
 }
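Note: a minimal sketch of how the two new knobs above are consumed by the training code further down (the override dict shown is illustrative; the names match this diff):

    hyperparameters = {**DEFAULT_HYPERPARAMETERS, "from_foundation": "CheMeleon", "freeze_mpnn_epochs": 10}
    # "CheMeleon"   -> _load_foundation_weights() downloads/caches pretrained MPNN weights
    # "/path/to.pt" -> loads a custom pretrained Chemprop checkpoint or full MPNN file
    # None          -> nn.BondMessagePassing built from hidden_dim/depth/dropout (from scratch)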
@@ -114,26 +122,27 @@ def _create_molecule_datapoints(
 # Model Loading (for SageMaker inference)
 # =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load ChemProp MPNN ensemble from the specified directory."""
-    from lightning import pytorch as pl
+    """Load ChemProp MPNN ensemble from the specified directory.
 
+    Optimized for serverless cold starts - uses direct PyTorch inference
+    instead of Lightning Trainer to minimize startup time.
+    """
     metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+
+    # Load all ensemble models (keep on CPU for serverless compatibility)
+    # ChemProp handles device placement internally
     ensemble_models = []
     for i in range(metadata["n_ensemble"]):
         model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
         model.eval()
         ensemble_models.append(model)
 
-    # Pre-initialize trainer once during model loading (expensive operation)
-    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
-
     print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
     return {
         "ensemble_models": ensemble_models,
         "n_ensemble": metadata["n_ensemble"],
         "target_columns": metadata["target_columns"],
         "median_std": metadata["median_std"],
-        "trainer": trainer,
     }
 
 
@@ -141,13 +150,15 @@ def model_fn(model_dir: str) -> dict:
 # Inference (for SageMaker inference)
 # =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make predictions with ChemProp MPNN ensemble."""
+    """Make predictions with ChemProp MPNN ensemble.
+
+    Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
+    """
     model_type = TEMPLATE_PARAMS["model_type"]
     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
 
     ensemble_models = model_dict["ensemble_models"]
     target_columns = model_dict["target_columns"]
-    trainer = model_dict["trainer"]  # Use pre-initialized trainer
 
     # Load artifacts
     label_encoder = None
@@ -202,18 +213,34 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         return df
 
     dataset = data.MoleculeDataset(datapoints)
-    dataloader = data.build_dataloader(dataset, shuffle=False)
+    dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)
 
-    # Ensemble predictions
+    # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
     all_preds = []
     for model in ensemble_models:
+        model_preds = []
+        model.eval()
         with torch.inference_mode():
-
-
+            for batch in dataloader:
+                # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
+                # For inference we only need bmg, V_d, X_d
+                bmg, V_d, X_d, *_ = batch
+                output = model(bmg, V_d, X_d)
+                model_preds.append(output.detach().cpu().numpy())
+
+        if len(model_preds) == 0:
+            print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
+            continue
+
+        preds = np.concatenate(model_preds, axis=0)
         if preds.ndim == 3 and preds.shape[1] == 1:
             preds = preds.squeeze(axis=1)
         all_preds.append(preds)
 
+    if len(all_preds) == 0:
+        print("Error: No ensemble predictions generated")
+        return df
+
     preds = np.mean(np.stack(all_preds), axis=0)
     preds_std = np.std(np.stack(all_preds), axis=0)
     if preds.ndim == 1:
@@ -243,8 +270,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         df["prediction"] = df[f"{target_columns[0]}_pred"]
         df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
 
-    # Compute confidence from ensemble std
-
+    # Compute confidence from ensemble std (or NaN if single model)
+    if model_dict["median_std"] is not None:
+        df = _compute_std_confidence(df, model_dict["median_std"])
+    else:
+        df["confidence"] = np.nan
 
     return df
 
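Note: _compute_std_confidence is called above but defined outside the hunks shown in this diff. A minimal sketch of a plausible shape, assuming it maps ensemble spread to a 0-1 score relative to the training-time median std (hypothetical, not the packaged implementation):

    def _compute_std_confidence(df: pd.DataFrame, median_std: float) -> pd.DataFrame:
        # Hypothetical mapping: spread equal to the training median -> 0.5; tighter -> higher
        ratio = df["prediction_std"] / max(median_std, 1e-9)
        df["confidence"] = (1.0 / (1.0 + ratio)).clip(0.0, 1.0)
        return df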
@@ -279,54 +309,107 @@ if __name__ == "__main__":
     )
 
     # -------------------------------------------------------------------------
-    # Training-only helper
+    # Training-only helper functions
     # -------------------------------------------------------------------------
-    def
-
-
-
-
-
-
-
-
-
-
-
-
+    def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
+        """Load pretrained MPNN weights from foundation model.
+
+        Args:
+            from_foundation: "CheMeleon" or path to .pt file
+
+        Returns:
+            Tuple of (message_passing, aggregation) modules
+        """
+        import urllib.request
+        from pathlib import Path
+
+        print(f"Loading foundation model: {from_foundation}")
+
+        if from_foundation.lower() == "chemeleon":
+            # Download from Zenodo if not cached
+            cache_dir = Path.home() / ".chemprop" / "foundation"
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            chemeleon_path = cache_dir / "chemeleon_mp.pt"
+
+            if not chemeleon_path.exists():
+                print("  Downloading CheMeleon weights from Zenodo...")
+                urllib.request.urlretrieve(
+                    "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
+                )
+                print(f"  Downloaded to {chemeleon_path}")
+
+            ckpt = torch.load(chemeleon_path, weights_only=True)
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f"  Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        if not os.path.exists(from_foundation):
+            raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")
+
+        ckpt = torch.load(from_foundation, weights_only=False)
+        if "hyper_parameters" in ckpt and "state_dict" in ckpt:
+            # CheMeleon-style checkpoint
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f"  Loaded custom foundation weights (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        # Full MPNN model file
+        pretrained = models.MPNN.load_from_file(from_foundation)
+        print(f"  Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
+        return pretrained.message_passing, pretrained.agg
+
+    def _build_ffn(
+        task: str, input_dim: int, hyperparameters: dict,
+        num_classes: int | None, n_targets: int,
+        output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
+    ) -> nn.Predictor:
+        """Build task-specific FFN head."""
         dropout = hyperparameters["dropout"]
         ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
         ffn_num_layers = hyperparameters["ffn_num_layers"]
 
-        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-        agg = nn.NormAggregation()
-        ffn_input_dim = hidden_dim + n_extra_descriptors
-
         if task == "classification" and num_classes is not None:
-
-                n_classes=num_classes, input_dim=
+            return nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=input_dim,
                 hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
             )
+
+        from chemprop.nn.metrics import MAE, MSE
+        criterion_map = {"mae": MAE, "mse": MSE}
+        criterion_name = hyperparameters.get("criterion", "mae")
+        if criterion_name not in criterion_map:
+            raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+
+        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+        return nn.RegressionFFN(
+            input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+            dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
+            task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
+        )
+
+    def build_mpnn_model(
+        hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
+        n_targets: int = 1, n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build MPNN model, optionally loading pretrained weights."""
+        from_foundation = hyperparameters.get("from_foundation")
+
+        if from_foundation:
+            mp, agg = _load_foundation_weights(from_foundation)
+            ffn_input_dim = mp.output_dim + n_extra_descriptors
         else:
-
-
-
-            criterion_map = {
-                "mae": MAE,
-                "mse": MSE,
-            }
-            criterion_name = hyperparameters.get("criterion", "mae")
-            if criterion_name not in criterion_map:
-                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
-            criterion = criterion_map[criterion_name]()
-
-            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
-            ffn = nn.RegressionFFN(
-                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
-                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
-                criterion=criterion,
+            mp = nn.BondMessagePassing(
+                d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
                dropout=hyperparameters["dropout"],
             )
+            agg = nn.NormAggregation()
+            ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors
 
+        ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
         return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
 
     # -------------------------------------------------------------------------
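Note: a quick usage sketch of the new builder (hyperparameter names are from this diff; assumes DEFAULT_HYPERPARAMETERS also carries ffn_hidden_dim and ffn_num_layers):

    hp = {**DEFAULT_HYPERPARAMETERS, "from_foundation": "CheMeleon"}
    mpnn = build_mpnn_model(hp, task="regression", n_targets=1)
    # With a foundation model the FFN input width comes from mp.output_dim, so the
    # hidden_dim/depth entries are ignored (see the "# Message Passing" comment above)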
@@ -359,6 +442,14 @@ if __name__ == "__main__":
     print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
     print(f"Hyperparameters: {hyperparameters}")
 
+    # Log foundation model configuration
+    if hyperparameters.get("from_foundation"):
+        freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
+        print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
+    else:
+        print("Foundation model: None (training from scratch)")
+
     # Load training data
     training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
     print(f"Training Files: {training_files}")
@@ -456,7 +547,7 @@
         print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
         print(f"{'='*50}")
 
-        # Split data
+        # Split data (val_extra_raw preserves unscaled features for OOF predictions)
         df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
         train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
@@ -481,10 +572,10 @@
         val_dataset.normalize_targets(target_scaler)
         output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
 
-        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
-        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
+        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
+        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
 
-        # Build
+        # Build model
         pl.seed_everything(hyperparameters["seed"] + fold_idx)
         mpnn = build_mpnn_model(
             hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
@@ -492,14 +583,39 @@ if __name__ == "__main__":
             output_transform=output_transform, task_weights=task_weights,
         )
 
-
-
-
-
-
-
-
-
+        # Train model (with optional two-phase foundation training)
+        freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0
+
+        def _set_mpnn_frozen(frozen: bool):
+            for param in mpnn.message_passing.parameters():
+                param.requires_grad = not frozen
+            for param in mpnn.agg.parameters():
+                param.requires_grad = not frozen
+
+        def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
+            callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
+            if save_checkpoint:
+                callbacks.append(pl.callbacks.ModelCheckpoint(
+                    dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
+                ))
+            return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)
+
+        if use_two_phase:
+            # Phase 1: Freeze MPNN, train FFN only
+            print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
+            _set_mpnn_frozen(True)
+            _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)
+
+            # Phase 2: Unfreeze and fine-tune all
+            print("Phase 2: Unfreezing MPNN, continuing training...")
+            _set_mpnn_frozen(False)
+            remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
+            trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
+        else:
+            trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
 
         # Load best checkpoint
         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
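Note: with the TEMPLATE_PARAMS above (freeze_mpnn_epochs=10, max_epochs=100) the two-phase schedule works out to:

    phase1_epochs = 10                # MPNN and aggregation frozen; only the FFN trains
    phase2_epochs = max(1, 100 - 10)  # 90 epochs with all parameters unfrozen
    # Only the Phase 2 trainer registers a ModelCheckpoint, so the "best" checkpoint
    # loaded below is always a fully fine-tuned model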
@@ -509,7 +625,7 @@
         mpnn.eval()
         ensemble_models.append(mpnn)
 
-        # Out-of-fold predictions (using
+        # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
         val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
         val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
 
@@ -599,11 +715,17 @@
     df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
     df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
 
-    # Compute confidence from ensemble std
-
-
-
-
+    # Compute confidence from ensemble std (or NaN for single model)
+    if preds_std is not None:
+        median_std = float(np.median(preds_std[:, 0]))
+        print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+        df_val = _compute_std_confidence(df_val, median_std)
+        print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+    else:
+        # Single model - no ensemble std available, confidence is undefined
+        median_std = None
+        df_val["confidence"] = np.nan
+        print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")
 
     # -------------------------------------------------------------------------
     # Save validation predictions to S3
@@ -633,6 +755,9 @@
         "n_folds": n_folds,
         "target_columns": target_columns,
         "median_std": median_std,  # For confidence calculation during inference
+        # Foundation model provenance (for tracking/reproducibility)
+        "from_foundation": hyperparameters.get("from_foundation", None),
+        "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
     }
     joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
 
workbench/model_script_utils/model_script_utils.py

@@ -148,12 +148,16 @@ def convert_categorical_types(
 def decompress_features(
     df: pd.DataFrame, features: list[str], compressed_features: list[str]
 ) -> tuple[pd.DataFrame, list[str]]:
-    """Decompress
+    """Decompress compressed features (bitstrings or count vectors) into individual columns.
+
+    Supports two formats (auto-detected):
+    - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
+    - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)
 
     Args:
         df: The features DataFrame
         features: Full list of feature names
-        compressed_features: List of feature names to decompress
+        compressed_features: List of feature names to decompress
 
     Returns:
         Tuple of (DataFrame with decompressed features, updated feature list)
@@ -178,18 +182,18 @@
         # Remove the feature from the list to avoid duplication
         decompressed_features.remove(feature)
 
-        #
-
-
+        # Auto-detect format and parse: comma-separated counts or bitstring
+        sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
+        parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
+        feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)
 
-        # Create
-
-
+        # Create new columns with prefix from feature name
+        prefix = feature[:3]
+        new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
+        new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)
 
-        #
+        # Update features list and dataframe
         decompressed_features.extend(new_col_names)
-
-        # Drop original column and concatenate new ones
         df = df.drop(columns=[feature])
         df = pd.concat([df, new_df], axis=1)
 
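Note: a standalone sketch of the auto-detection above, with illustrative inputs (not taken from the package):

    import numpy as np

    def parse_compressed(sample: str) -> np.ndarray:
        # Comma present -> count vector ("0,3,0,1"); otherwise a bitstring ("1011")
        parts = sample.split(",") if "," in sample else list(sample)
        return np.array([int(x) for x in parts], dtype=np.uint8)

    parse_compressed("0,3,0,1")  # -> array([0, 3, 0, 1], dtype=uint8)
    parse_compressed("1011")     # -> array([1, 0, 1, 1], dtype=uint8)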
workbench/utils/chem_utils/fingerprints.py

@@ -1,11 +1,19 @@
-"""Molecular fingerprint computation utilities
+"""Molecular fingerprint computation utilities for ADMET modeling.
+
+This module provides Morgan count fingerprints, the standard for ADMET prediction.
+Count fingerprints outperform binary fingerprints for molecular property prediction.
+
+References:
+    - Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
+    - ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
+"""
 
 import logging
-import pandas as pd
 
-
+import numpy as np
+import pandas as pd
 from rdkit import Chem, RDLogger
-from rdkit.Chem import
+from rdkit.Chem import AllChem
 from rdkit.Chem.MolStandardize import rdMolStandardize
 
 # Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
@@ -16,20 +24,25 @@ RDLogger.DisableLog("rdApp.warning")
 log = logging.getLogger("workbench")
 
 
-def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
-    """Compute
+def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
+    """Compute Morgan count fingerprints for ADMET modeling.
+
+    Generates true count fingerprints where each bit position contains the
+    number of times that substructure appears in the molecule (clamped to 0-255).
+    This is the recommended approach for ADMET prediction per 2025 research.
 
     Args:
-        df
-        radius
-        n_bits
-        counts (bool): Count simulation for the fingerprint.
+        df: Input DataFrame containing SMILES strings.
+        radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
+        n_bits: Number of bits for the fingerprint (default 2048).
 
     Returns:
-        pd.DataFrame:
+        pd.DataFrame: Input DataFrame with 'fingerprint' column added.
+            Values are comma-separated uint8 counts.
 
     Note:
-
+        Count fingerprints outperform binary for ADMET prediction.
+        See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
     """
     delete_mol_column = False
 
@@ -43,7 +56,7 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
         log.warning("Detected serialized molecules in 'molecule' column. Removing...")
         del df["molecule"]
 
-    # Convert SMILES to RDKit molecule objects
+    # Convert SMILES to RDKit molecule objects
     if "molecule" not in df.columns:
         log.info("Converting SMILES to RDKit Molecules...")
         delete_mol_column = True
@@ -59,15 +72,24 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
         lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
     )
 
-
-
-
-
+    def mol_to_count_string(mol):
+        """Convert molecule to comma-separated count fingerprint string."""
+        if mol is None:
+            return pd.NA
 
-
-
-
-
+        # Get hashed Morgan fingerprint with counts
+        fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
+
+        # Initialize array and populate with counts (clamped to uint8 range)
+        counts = np.zeros(n_bits, dtype=np.uint8)
+        for idx, count in fp.GetNonzeroElements().items():
+            counts[idx] = min(count, 255)
+
+        # Return as comma-separated string
+        return ",".join(map(str, counts))
+
+    # Compute Morgan count fingerprints
+    fingerprints = largest_frags.apply(mol_to_count_string)
 
     # Add the fingerprints to the DataFrame
     df["fingerprint"] = fingerprints
@@ -75,59 +97,62 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
     # Drop the intermediate 'molecule' column if it was added
     if delete_mol_column:
         del df["molecule"]
+
     return df
 
 
 if __name__ == "__main__":
-    print("Running
-    print("Note: This requires molecular_screening module to be available")
+    print("Running Morgan count fingerprint tests...")
 
     # Test molecules
     test_molecules = {
         "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
         "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
         "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # With stereochemistry
-        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt
+        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt (largest fragment used)
         "benzene": "c1ccccc1",
         "butene_e": "C/C=C/C",  # E-butene
        "butene_z": "C/C=C\\C",  # Z-butene
     }
 
-    # Test 1: Morgan Fingerprints
-    print("\n1. Testing Morgan fingerprint generation...")
+    # Test 1: Morgan Count Fingerprints (default parameters)
+    print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
 
     test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
-
-    fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
+    fp_df = compute_morgan_fingerprints(test_df.copy())
 
     print("  Fingerprint generation results:")
     for _, row in fp_df.iterrows():
         fp = row.get("fingerprint", "N/A")
-
-
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            non_zero = sum(1 for c in counts if c > 0)
+            max_count = max(counts)
+            print(f"    {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
+        else:
+            print(f"    {row['name']:15} → N/A")
 
-    # Test 2: Different
-    print("\n2. Testing different
+    # Test 2: Different parameters
+    print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
 
-
-    fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
+    fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
 
-
-    for _, row in fp_counts_df.iterrows():
+    for _, row in fp_df_custom.iterrows():
         fp = row.get("fingerprint", "N/A")
-
-
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            non_zero = sum(1 for c in counts if c > 0)
+            print(f"    {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
+        else:
+            print(f"    {row['name']:15} → N/A")
 
     # Test 3: Edge cases
     print("\n3. Testing edge cases...")
 
     # Invalid SMILES
     invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
-
-
-        print(f"  ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
-    except Exception as e:
-        print(f"  ✓ Invalid SMILES properly raised error: {type(e).__name__}")
+    fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
+    print(f"  ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
 
     # Test with pre-existing molecule column
     mol_df = test_df.copy()
@@ -135,4 +160,16 @@ if __name__ == "__main__":
     fp_with_mol = compute_morgan_fingerprints(mol_df)
     print(f"  ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
 
+    # Test 4: Verify count values are reasonable
+    print("\n4. Verifying count distribution...")
+    all_counts = []
+    for _, row in fp_df.iterrows():
+        fp = row.get("fingerprint", "N/A")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            all_counts.extend([c for c in counts if c > 0])
+
+    if all_counts:
+        print(f"  Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
+
     print("\n✅ All fingerprint tests completed!")