workbench-0.8.219-py3-none-any.whl → workbench-0.8.231-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/feature_set.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +238 -225
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +196 -68
- workbench/model_scripts/chemprop/generated_model_script.py +197 -72
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -64
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +82 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/model_utils.py +0 -1
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +52 -36
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +10 -6
- workbench/web_interface/components/plugins/scatter_plot.py +184 -85
- workbench/web_interface/components/settings_menu.py +185 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ import torch
 from chemprop import data, models

 from model_script_utils import (
+    cap_std_outliers,
     expand_proba_column,
     input_fn,
     output_fn,
@@ -34,7 +35,7 @@ DEFAULT_HYPERPARAMETERS = {
     "max_epochs": 400,
     "patience": 50,
     "batch_size": 32,
-    # Message Passing
+    # Message Passing (ignored when using foundation model)
     "hidden_dim": 700,
     "depth": 6,
     "dropout": 0.1,  # Lower dropout - ensemble provides regularization
@@ -45,6 +46,14 @@ DEFAULT_HYPERPARAMETERS = {
     "criterion": "mae",
     # Random seed
     "seed": 42,
+    # Foundation model support
+    # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
+    # - Path to .pt file: Load custom pretrained Chemprop model
+    # - None: Train from scratch (default)
+    "from_foundation": None,
+    # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
+    # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
+    "freeze_mpnn_epochs": 0,
 }

 # Template parameters (filled in by Workbench)
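The two new keys are ordinary entries in the template's hyperparameter dict, so enabling CheMeleon fine-tuning is a matter of overriding them. A minimal sketch (the `defaults`/`overrides` merge below is an assumption for illustration; only the two key names and their documented values come from this diff):

```python
# Sketch: overriding the new foundation-model keys.
defaults = {"max_epochs": 400, "from_foundation": None, "freeze_mpnn_epochs": 0}
overrides = {
    "from_foundation": "CheMeleon",  # auto-downloads pretrained MPNN weights
    "freeze_mpnn_epochs": 10,        # train the FFN head alone for 10 epochs first
}
hyperparameters = {**defaults, **overrides}
assert hyperparameters["from_foundation"] == "CheMeleon"
```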
@@ -114,26 +123,27 @@ def _create_molecule_datapoints(
 # Model Loading (for SageMaker inference)
 # =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load ChemProp MPNN ensemble from the specified directory."""
-    from lightning import pytorch as pl
+    """Load ChemProp MPNN ensemble from the specified directory.

+    Optimized for serverless cold starts - uses direct PyTorch inference
+    instead of Lightning Trainer to minimize startup time.
+    """
     metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+
+    # Load all ensemble models (keep on CPU for serverless compatibility)
+    # ChemProp handles device placement internally
     ensemble_models = []
     for i in range(metadata["n_ensemble"]):
         model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
         model.eval()
         ensemble_models.append(model)

-    # Pre-initialize trainer once during model loading (expensive operation)
-    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
-
     print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
     return {
         "ensemble_models": ensemble_models,
         "n_ensemble": metadata["n_ensemble"],
         "target_columns": metadata["target_columns"],
         "median_std": metadata["median_std"],
-        "trainer": trainer,
     }

@@ -141,13 +151,15 @@ def model_fn(model_dir: str) -> dict:
 # Inference (for SageMaker inference)
 # =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make predictions with ChemProp MPNN ensemble."""
+    """Make predictions with ChemProp MPNN ensemble.
+
+    Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
+    """
     model_type = TEMPLATE_PARAMS["model_type"]
     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

     ensemble_models = model_dict["ensemble_models"]
     target_columns = model_dict["target_columns"]
-    trainer = model_dict["trainer"]  # Use pre-initialized trainer

     # Load artifacts
     label_encoder = None
@@ -202,22 +214,39 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         return df

     dataset = data.MoleculeDataset(datapoints)
-    dataloader = data.build_dataloader(dataset, shuffle=False)
+    dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)

-    # Ensemble predictions
+    # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
     all_preds = []
     for model in ensemble_models:
+        model_preds = []
+        model.eval()
         with torch.inference_mode():
-
-
+            for batch in dataloader:
+                # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
+                # For inference we only need bmg, V_d, X_d
+                bmg, V_d, X_d, *_ = batch
+                output = model(bmg, V_d, X_d)
+                model_preds.append(output.detach().cpu().numpy())
+
+        if len(model_preds) == 0:
+            print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
+            continue
+
+        preds = np.concatenate(model_preds, axis=0)
         if preds.ndim == 3 and preds.shape[1] == 1:
             preds = preds.squeeze(axis=1)
         all_preds.append(preds)

+    if len(all_preds) == 0:
+        print("Error: No ensemble predictions generated")
+        return df
+
     preds = np.mean(np.stack(all_preds), axis=0)
     preds_std = np.std(np.stack(all_preds), axis=0)
     if preds.ndim == 1:
         preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
+    preds_std = cap_std_outliers(preds_std)

     print(f"Inference complete: {preds.shape[0]} predictions")

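The aggregation at the end of this hunk is plain NumPy: per-model predictions are stacked, the mean becomes the point prediction, and the across-model standard deviation becomes the uncertainty. A self-contained illustration with made-up numbers:

```python
import numpy as np

# Three hypothetical ensemble members predicting one target for two molecules.
all_preds = [np.array([[1.0], [2.0]]),
             np.array([[1.2], [1.8]]),
             np.array([[0.8], [2.2]])]

stacked = np.stack(all_preds)        # shape (n_models, n_samples, n_targets)
preds = np.mean(stacked, axis=0)     # point predictions: [[1.0], [2.0]]
preds_std = np.std(stacked, axis=0)  # ensemble disagreement: ~0.163 each
print(preds.ravel(), preds_std.ravel())
```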
@@ -243,8 +272,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
     df["prediction"] = df[f"{target_columns[0]}_pred"]
     df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]

-    # Compute confidence from ensemble std
-
+    # Compute confidence from ensemble std (or NaN if single model)
+    if model_dict["median_std"] is not None:
+        df = _compute_std_confidence(df, model_dict["median_std"])
+    else:
+        df["confidence"] = np.nan

     return df

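`_compute_std_confidence` itself is outside this diff, so the sketch below is a hypothetical stand-in that only illustrates the contract visible here: it takes the DataFrame plus the training-time `median_std` and fills a `confidence` column in [0, 1], with larger ensemble std meaning lower confidence.

```python
import numpy as np
import pandas as pd

def compute_std_confidence(df: pd.DataFrame, median_std: float) -> pd.DataFrame:
    # Hypothetical mapping: confidence decays as prediction_std grows
    # relative to the training-time median std, clipped to [0, 1].
    ratio = df["prediction_std"] / max(median_std, 1e-9)
    df["confidence"] = np.clip(1.0 - 0.5 * (ratio - 1.0), 0.0, 1.0)
    return df

df = pd.DataFrame({"prediction_std": [0.05, 0.10, 0.40]})
print(compute_std_confidence(df, median_std=0.10)["confidence"].tolist())
# [1.0, 1.0, 0.0] under this illustrative mapping
```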
@@ -279,54 +311,107 @@ if __name__ == "__main__":
     )

     # -------------------------------------------------------------------------
-    # Training-only helper
+    # Training-only helper functions
     # -------------------------------------------------------------------------
-    def
-
-
-
-
-
-
-
-
-
-
-
-
+    def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
+        """Load pretrained MPNN weights from foundation model.
+
+        Args:
+            from_foundation: "CheMeleon" or path to .pt file
+
+        Returns:
+            Tuple of (message_passing, aggregation) modules
+        """
+        import urllib.request
+        from pathlib import Path
+
+        print(f"Loading foundation model: {from_foundation}")
+
+        if from_foundation.lower() == "chemeleon":
+            # Download from Zenodo if not cached
+            cache_dir = Path.home() / ".chemprop" / "foundation"
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            chemeleon_path = cache_dir / "chemeleon_mp.pt"
+
+            if not chemeleon_path.exists():
+                print("  Downloading CheMeleon weights from Zenodo...")
+                urllib.request.urlretrieve(
+                    "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
+                )
+                print(f"  Downloaded to {chemeleon_path}")
+
+            ckpt = torch.load(chemeleon_path, weights_only=True)
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f"  Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        if not os.path.exists(from_foundation):
+            raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")
+
+        ckpt = torch.load(from_foundation, weights_only=False)
+        if "hyper_parameters" in ckpt and "state_dict" in ckpt:
+            # CheMeleon-style checkpoint
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f"  Loaded custom foundation weights (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        # Full MPNN model file
+        pretrained = models.MPNN.load_from_file(from_foundation)
+        print(f"  Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
+        return pretrained.message_passing, pretrained.agg
+
+    def _build_ffn(
+        task: str, input_dim: int, hyperparameters: dict,
+        num_classes: int | None, n_targets: int,
+        output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
+    ) -> nn.Predictor:
+        """Build task-specific FFN head."""
         dropout = hyperparameters["dropout"]
         ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
         ffn_num_layers = hyperparameters["ffn_num_layers"]

-        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-        agg = nn.NormAggregation()
-        ffn_input_dim = hidden_dim + n_extra_descriptors
-
         if task == "classification" and num_classes is not None:
-
-            n_classes=num_classes, input_dim=
+            return nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=input_dim,
                 hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
             )
+
+        from chemprop.nn.metrics import MAE, MSE
+        criterion_map = {"mae": MAE, "mse": MSE}
+        criterion_name = hyperparameters.get("criterion", "mae")
+        if criterion_name not in criterion_map:
+            raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+
+        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+        return nn.RegressionFFN(
+            input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+            dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
+            task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
+        )
+
+    def build_mpnn_model(
+        hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
+        n_targets: int = 1, n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build MPNN model, optionally loading pretrained weights."""
+        from_foundation = hyperparameters.get("from_foundation")
+
+        if from_foundation:
+            mp, agg = _load_foundation_weights(from_foundation)
+            ffn_input_dim = mp.output_dim + n_extra_descriptors
         else:
-
-
-
-            criterion_map = {
-                "mae": MAE,
-                "mse": MSE,
-            }
-            criterion_name = hyperparameters.get("criterion", "mae")
-            if criterion_name not in criterion_map:
-                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
-            criterion = criterion_map[criterion_name]()
-
-            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
-            ffn = nn.RegressionFFN(
-                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
-                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
-                criterion=criterion,
+            mp = nn.BondMessagePassing(
+                d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
+                dropout=hyperparameters["dropout"],
             )
+            agg = nn.NormAggregation()
+            ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors

+        ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
         return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)

     # -------------------------------------------------------------------------
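The loader above assumes a checkpoint laid out as `{"hyper_parameters": <constructor kwargs>, "state_dict": <weights>}`, which per the code is how the CheMeleon file on Zenodo is organized. A toy round-trip of that convention, with a stand-in module in place of `nn.BondMessagePassing`:

```python
import torch
from torch import nn as torch_nn

class TinyMP(torch_nn.Module):  # stand-in for nn.BondMessagePassing
    def __init__(self, d_h: int):
        super().__init__()
        self.proj = torch_nn.Linear(d_h, d_h)

# Save: constructor kwargs + weights, the layout _load_foundation_weights expects.
ckpt = {"hyper_parameters": {"d_h": 8}, "state_dict": TinyMP(8).state_dict()}
torch.save(ckpt, "tiny_mp.pt")

# Load: rebuild from kwargs, then restore weights. weights_only=True is safe
# here because the file holds only tensors and primitive containers.
loaded = torch.load("tiny_mp.pt", weights_only=True)
mp = TinyMP(**loaded["hyper_parameters"])
mp.load_state_dict(loaded["state_dict"])
```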
@@ -359,6 +444,14 @@ if __name__ == "__main__":
     print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
     print(f"Hyperparameters: {hyperparameters}")

+    # Log foundation model configuration
+    if hyperparameters.get("from_foundation"):
+        freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
+        print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
+    else:
+        print("Foundation model: None (training from scratch)")
+
     # Load training data
     training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
     print(f"Training Files: {training_files}")
@@ -456,7 +549,7 @@ if __name__ == "__main__":
         print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
         print(f"{'='*50}")

-        # Split data
+        # Split data (val_extra_raw preserves unscaled features for OOF predictions)
         df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
         train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
@@ -484,7 +577,7 @@ if __name__ == "__main__":
         train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
         val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

-        # Build
+        # Build model
         pl.seed_everything(hyperparameters["seed"] + fold_idx)
         mpnn = build_mpnn_model(
             hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
@@ -492,14 +585,39 @@ if __name__ == "__main__":
             output_transform=output_transform, task_weights=task_weights,
         )

-
-
-
-
-
-
-
-
+        # Train model (with optional two-phase foundation training)
+        freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0
+
+        def _set_mpnn_frozen(frozen: bool):
+            for param in mpnn.message_passing.parameters():
+                param.requires_grad = not frozen
+            for param in mpnn.agg.parameters():
+                param.requires_grad = not frozen
+
+        def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
+            callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
+            if save_checkpoint:
+                callbacks.append(pl.callbacks.ModelCheckpoint(
+                    dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
+                ))
+            return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)
+
+        if use_two_phase:
+            # Phase 1: Freeze MPNN, train FFN only
+            print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
+            _set_mpnn_frozen(True)
+            _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)
+
+            # Phase 2: Unfreeze and fine-tune all
+            print("Phase 2: Unfreezing MPNN, continuing training...")
+            _set_mpnn_frozen(False)
+            remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
+            trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
+        else:
+            trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)

         # Load best checkpoint
         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
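The two-phase schedule relies on nothing more exotic than `requires_grad` flags; because each phase calls `fit` on a fresh Trainer, Lightning re-runs `configure_optimizers`, so parameters unfrozen between phases are actually optimized in phase 2. A toy version of the freeze/unfreeze mechanics with stand-in modules:

```python
from torch import nn as torch_nn

encoder = torch_nn.Linear(16, 8)  # stand-in for the MPNN (message passing + aggregation)
head = torch_nn.Linear(8, 1)      # stand-in for the FFN predictor

def set_frozen(module: torch_nn.Module, frozen: bool) -> None:
    # Same mechanics as _set_mpnn_frozen above: toggle requires_grad per parameter.
    for p in module.parameters():
        p.requires_grad = not frozen

set_frozen(encoder, True)  # phase 1: only the head trains
trainable = [p for m in (encoder, head) for p in m.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # 9 -> head weight (8) + bias (1)

set_frozen(encoder, False)  # phase 2: unfreeze for full fine-tuning
```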
@@ -509,7 +627,7 @@ if __name__ == "__main__":
         mpnn.eval()
         ensemble_models.append(mpnn)

-        # Out-of-fold predictions (using
+        # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
         val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
         val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
@@ -585,6 +703,7 @@ if __name__ == "__main__":
     preds_std = np.std(np.stack(all_ens_preds), axis=0)
     if preds_std.ndim == 1:
         preds_std = preds_std.reshape(-1, 1)
+    preds_std = cap_std_outliers(preds_std)

     print("\n--- Per-target metrics ---")
     for t_idx, t_name in enumerate(target_columns):
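`cap_std_outliers` lives in `model_script_utils` and its body is not part of this diff; a plausible sketch of what such a helper might do is to clip extreme ensemble stds so a handful of unstable molecules cannot dominate the downstream confidence scores:

```python
import numpy as np

def cap_std_outliers(preds_std: np.ndarray, quantile: float = 0.99) -> np.ndarray:
    # Hypothetical implementation: cap each target's std column at a high quantile.
    cap = np.quantile(preds_std, quantile, axis=0)
    return np.minimum(preds_std, cap)

stds = np.array([[0.10], [0.15], [0.20], [5.00]])  # one wildly unstable sample
print(cap_std_outliers(stds).ravel())              # the 5.0 gets pulled down
```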
@@ -599,11 +718,17 @@ if __name__ == "__main__":
     df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
     df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]

-    # Compute confidence from ensemble std
-
-
-
-
+    # Compute confidence from ensemble std (or NaN for single model)
+    if preds_std is not None:
+        median_std = float(np.median(preds_std[:, 0]))
+        print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+        df_val = _compute_std_confidence(df_val, median_std)
+        print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+    else:
+        # Single model - no ensemble std available, confidence is undefined
+        median_std = None
+        df_val["confidence"] = np.nan
+        print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")

     # -------------------------------------------------------------------------
     # Save validation predictions to S3
@@ -633,6 +758,9 @@ if __name__ == "__main__":
         "n_folds": n_folds,
         "target_columns": target_columns,
         "median_std": median_std,  # For confidence calculation during inference
+        # Foundation model provenance (for tracking/reproducibility)
+        "from_foundation": hyperparameters.get("from_foundation", None),
+        "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
     }
     joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
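Since inference-time `model_fn` reads this same joblib file, the two new provenance keys travel with the model artifact for free. A round-trip sketch using only keys visible in this hunk (the values are made up):

```python
import joblib

metadata = {
    "n_folds": 5,
    "target_columns": ["solubility"],
    "median_std": 0.12,
    "from_foundation": "CheMeleon",
    "freeze_mpnn_epochs": 10,
}
joblib.dump(metadata, "ensemble_metadata.joblib")
print(joblib.load("ensemble_metadata.joblib")["from_foundation"])  # CheMeleon
```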