workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/model_scripts/chemprop/chemprop.template (new file)

@@ -0,0 +1,774 @@

```python
# ChemProp Model Template for Workbench
#
# This template handles molecular property prediction using ChemProp 2.x MPNN with:
# - K-fold cross-validation ensemble training (or single train/val split)
# - Multi-task regression support
# - Hybrid mode (SMILES + extra molecular descriptors)
# - Classification (single-target only)
#
# NOTE: Imports are structured to minimize serverless endpoint startup time.
# Heavy imports (lightning, sklearn, awswrangler) are deferred to training time.

import json
import os

import joblib
import numpy as np
import pandas as pd
import torch

from chemprop import data, models

from model_script_utils import (
    expand_proba_column,
    input_fn,
    output_fn,
)

# =============================================================================
# Default Hyperparameters
# =============================================================================
DEFAULT_HYPERPARAMETERS = {
    # Training
    "n_folds": 5,
    "max_epochs": 400,
    "patience": 50,
    "batch_size": 32,
    # Message Passing (ignored when using foundation model)
    "hidden_dim": 700,
    "depth": 6,
    "dropout": 0.1,  # Lower dropout - ensemble provides regularization
    # FFN
    "ffn_hidden_dim": 2000,
    "ffn_num_layers": 2,
    # Loss function for regression (mae, mse)
    "criterion": "mae",
    # Random seed
    "seed": 42,
    # Foundation model support
    # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
    # - Path to .pt file: Load custom pretrained Chemprop model
    # - None: Train from scratch (default)
    "from_foundation": None,
    # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
    # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
    "freeze_mpnn_epochs": 0,
}

# Template parameters (filled in by Workbench)
TEMPLATE_PARAMS = {
    "model_type": "{{model_type}}",
    "targets": "{{target_column}}",
    "feature_list": "{{feature_list}}",
    "id_column": "{{id_column}}",
    "model_metrics_s3_path": "{{model_metrics_s3_path}}",
    "hyperparameters": "{{hyperparameters}}",
}
```
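The `{{...}}` placeholders above are filled in by Workbench's script generation (the sibling `generated_model_script.py` in the file list is one rendered result). A minimal sketch of what a rendered parameter block and the merge performed later in this script look like — the concrete values here are illustrative, not taken from the package:

```python
# Illustrative rendered values (hypothetical model and paths):
TEMPLATE_PARAMS = {
    "model_type": "regressor",
    "targets": ["logd"],
    "feature_list": ["smiles"],
    "id_column": "id",
    "model_metrics_s3_path": "s3://my-bucket/models/logd",
    "hyperparameters": {"n_folds": 3, "from_foundation": "CheMeleon"},
}

# Same merge the training block performs below: overrides win over defaults.
hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}
assert hyperparameters["n_folds"] == 3       # overridden
assert hyperparameters["max_epochs"] == 400  # default retained
```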
```python
# =============================================================================
# Helper Functions
# =============================================================================
def _compute_std_confidence(df: pd.DataFrame, median_std: float, std_col: str = "prediction_std") -> pd.DataFrame:
    """Compute confidence score from ensemble prediction_std.

    Uses exponential decay: confidence = exp(-std / median_std)
    - Low std (ensemble agreement) -> high confidence
    - High std (ensemble disagreement) -> low confidence

    Args:
        df: DataFrame with prediction_std column
        median_std: Median std from training validation set (normalization factor)
        std_col: Name of the std column to use

    Returns:
        DataFrame with added 'confidence' column (0.0 to 1.0)
    """
    df["confidence"] = np.exp(-df[std_col] / median_std)
    return df
```
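A quick worked example of the decay curve, using the helper above: `std = 0` gives confidence 1.0, `std = median_std` gives `exp(-1) ≈ 0.368`, and `std = 2 * median_std` gives `exp(-2) ≈ 0.135`.

```python
import pandas as pd

demo = pd.DataFrame({"prediction_std": [0.0, 0.05, 0.10, 0.20]})
demo = _compute_std_confidence(demo, median_std=0.10)
print(demo["confidence"].round(3).tolist())  # [1.0, 0.607, 0.368, 0.135]
```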
```python
def _find_smiles_column(columns: list[str]) -> str:
    """Find SMILES column (case-insensitive match for 'smiles')."""
    smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
    if smiles_col is None:
        raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
    return smiles_col


def _create_molecule_datapoints(
    smiles_list: list[str],
    targets: np.ndarray | None = None,
    extra_descriptors: np.ndarray | None = None,
) -> tuple[list[data.MoleculeDatapoint], list[int]]:
    """Create ChemProp MoleculeDatapoints from SMILES strings."""
    from rdkit import Chem

    datapoints, valid_indices = [], []
    targets = np.atleast_2d(np.array(targets)).T if targets is not None and np.array(targets).ndim == 1 else targets

    for i, smi in enumerate(smiles_list):
        if Chem.MolFromSmiles(smi) is None:
            continue
        y = targets[i].tolist() if targets is not None else None
        x_d = extra_descriptors[i] if extra_descriptors is not None else None
        datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
        valid_indices.append(i)

    return datapoints, valid_indices
```
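The second return value, `valid_indices`, is what lets callers line predictions back up with the original rows after RDKit rejects unparseable SMILES. A small illustration (the bad string is deliberate):

```python
import numpy as np

smiles = ["CCO", "not_a_smiles", "c1ccccc1"]
dps, valid = _create_molecule_datapoints(smiles, targets=np.array([1.0, 2.0, 3.0]))
print(valid)     # [0, 2] -- index 1 was dropped by RDKit
print(len(dps))  # 2
```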
```python
# =============================================================================
# Model Loading (for SageMaker inference)
# =============================================================================
def model_fn(model_dir: str) -> dict:
    """Load ChemProp MPNN ensemble from the specified directory.

    Optimized for serverless cold starts - uses direct PyTorch inference
    instead of Lightning Trainer to minimize startup time.
    """
    metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))

    # Load all ensemble models (keep on CPU for serverless compatibility)
    # ChemProp handles device placement internally
    ensemble_models = []
    for i in range(metadata["n_ensemble"]):
        model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
        model.eval()
        ensemble_models.append(model)

    print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
    return {
        "ensemble_models": ensemble_models,
        "n_ensemble": metadata["n_ensemble"],
        "target_columns": metadata["target_columns"],
        "median_std": metadata["median_std"],
    }
```
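`model_fn` is one of the standard SageMaker inference handlers (`model_fn`, `input_fn`, `predict_fn`, `output_fn`; the two I/O handlers are imported from `model_script_utils`). A hedged local smoke test of the load/predict pair, assuming a hypothetical `./model` directory holding an extracted model.tar.gz and a hypothetical CSV with a `smiles` column:

```python
import os
import pandas as pd

os.environ["SM_MODEL_DIR"] = "./model"   # so predict_fn finds the same artifacts
model_dict = model_fn("./model")         # loads ensemble members + metadata
df = pd.read_csv("test_compounds.csv")
result = predict_fn(df, model_dict)
print(result[["prediction", "prediction_std", "confidence"]].head())
```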
```python
# =============================================================================
# Inference (for SageMaker inference)
# =============================================================================
def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
    """Make predictions with ChemProp MPNN ensemble.

    Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
    """
    model_type = TEMPLATE_PARAMS["model_type"]
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

    ensemble_models = model_dict["ensemble_models"]
    target_columns = model_dict["target_columns"]

    # Load artifacts
    label_encoder = None
    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
    if os.path.exists(encoder_path):
        label_encoder = joblib.load(encoder_path)

    feature_metadata = None
    feature_path = os.path.join(model_dir, "feature_metadata.joblib")
    if os.path.exists(feature_path):
        feature_metadata = joblib.load(feature_path)
        print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")

    # Find SMILES column and validate
    smiles_column = _find_smiles_column(df.columns.tolist())
    smiles_list = df[smiles_column].tolist()

    valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
    valid_smiles = [s.strip() for i, s in enumerate(smiles_list) if valid_mask[i]]
    print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")

    # Initialize output columns
    if model_type == "classifier":
        df["prediction"] = pd.Series([None] * len(df), dtype=object)
    else:
        for tc in target_columns:
            df[f"{tc}_pred"] = np.nan
            df[f"{tc}_pred_std"] = np.nan

    if sum(valid_mask) == 0:
        return df

    # Prepare extra features (raw, unscaled - model handles scaling)
    extra_features = None
    if feature_metadata is not None:
        extra_cols = feature_metadata["extra_feature_cols"]
        col_means = np.array(feature_metadata["col_means"])
        valid_indices = np.where(valid_mask)[0]

        extra_features = np.zeros((len(valid_indices), len(extra_cols)), dtype=np.float32)
        for j, col in enumerate(extra_cols):
            if col in df.columns:
                values = df.iloc[valid_indices][col].values.astype(np.float32)
                values[np.isnan(values)] = col_means[j]
                extra_features[:, j] = values
            else:
                extra_features[:, j] = col_means[j]

    # Create datapoints and predict
    datapoints, rdkit_valid = _create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
    if len(datapoints) == 0:
        return df

    dataset = data.MoleculeDataset(datapoints)
    dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)

    # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
    all_preds = []
    for model in ensemble_models:
        model_preds = []
        model.eval()
        with torch.inference_mode():
            for batch in dataloader:
                # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
                # For inference we only need bmg, V_d, X_d
                bmg, V_d, X_d, *_ = batch
                output = model(bmg, V_d, X_d)
                model_preds.append(output.detach().cpu().numpy())

        if len(model_preds) == 0:
            print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
            continue

        preds = np.concatenate(model_preds, axis=0)
        if preds.ndim == 3 and preds.shape[1] == 1:
            preds = preds.squeeze(axis=1)
        all_preds.append(preds)

    if len(all_preds) == 0:
        print("Error: No ensemble predictions generated")
        return df

    preds = np.mean(np.stack(all_preds), axis=0)
    preds_std = np.std(np.stack(all_preds), axis=0)
    if preds.ndim == 1:
        preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)

    print(f"Inference complete: {preds.shape[0]} predictions")

    # Map predictions back to valid positions
    valid_positions = np.where(valid_mask)[0][rdkit_valid]
    valid_mask = np.zeros(len(df), dtype=bool)
    valid_mask[valid_positions] = True

    if model_type == "classifier" and label_encoder is not None:
        if preds.shape[1] > 1:
            class_preds = np.argmax(preds, axis=1)
            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform(class_preds)
            proba = pd.Series([None] * len(df), dtype=object)
            proba.loc[valid_mask] = [p.tolist() for p in preds]
            df["pred_proba"] = proba
            df = expand_proba_column(df, label_encoder.classes_)
        else:
            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform((preds.flatten() > 0.5).astype(int))
    else:
        for t_idx, tc in enumerate(target_columns):
            df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
            df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
        df["prediction"] = df[f"{target_columns[0]}_pred"]
        df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]

    # Compute confidence from ensemble std (or NaN if single model)
    if model_dict["median_std"] is not None:
        df = _compute_std_confidence(df, model_dict["median_std"])
    else:
        df["confidence"] = np.nan

    return df
```
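The ensemble reduction above is a plain mean/std over stacked per-model outputs: for K models, N rows, and T targets, `np.stack` yields a `(K, N, T)` array and axis 0 collapses the ensemble. A toy check with three "models", two rows, one target:

```python
import numpy as np

all_preds = [np.array([[1.0], [4.0]]),
             np.array([[1.2], [4.4]]),
             np.array([[0.8], [3.6]])]
stacked = np.stack(all_preds)            # shape (3, 2, 1)
print(np.mean(stacked, axis=0).ravel())  # [1. 4.]
print(np.std(stacked, axis=0).ravel())   # [0.1633 0.3266] (approx)
```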
```python
# =============================================================================
# Training
# =============================================================================
if __name__ == "__main__":
    # -------------------------------------------------------------------------
    # Training-only imports (deferred to reduce serverless startup time)
    # -------------------------------------------------------------------------
    import argparse
    import glob

    import awswrangler as wr
    from lightning import pytorch as pl
    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
    from sklearn.preprocessing import LabelEncoder

    # Enable Tensor Core optimization for GPUs that support it
    torch.set_float32_matmul_precision("medium")

    from chemprop import nn

    from model_script_utils import (
        check_dataframe,
        compute_classification_metrics,
        compute_regression_metrics,
        print_classification_metrics,
        print_confusion_matrix,
        print_regression_metrics,
    )

    # -------------------------------------------------------------------------
    # Training-only helper functions
    # -------------------------------------------------------------------------
    def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
        """Load pretrained MPNN weights from foundation model.

        Args:
            from_foundation: "CheMeleon" or path to .pt file

        Returns:
            Tuple of (message_passing, aggregation) modules
        """
        import urllib.request
        from pathlib import Path

        print(f"Loading foundation model: {from_foundation}")

        if from_foundation.lower() == "chemeleon":
            # Download from Zenodo if not cached
            cache_dir = Path.home() / ".chemprop" / "foundation"
            cache_dir.mkdir(parents=True, exist_ok=True)
            chemeleon_path = cache_dir / "chemeleon_mp.pt"

            if not chemeleon_path.exists():
                print("  Downloading CheMeleon weights from Zenodo...")
                urllib.request.urlretrieve(
                    "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
                )
                print(f"  Downloaded to {chemeleon_path}")

            ckpt = torch.load(chemeleon_path, weights_only=True)
            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
            mp.load_state_dict(ckpt["state_dict"])
            print(f"  Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
            return mp, nn.MeanAggregation()

        if not os.path.exists(from_foundation):
            raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")

        ckpt = torch.load(from_foundation, weights_only=False)
        if "hyper_parameters" in ckpt and "state_dict" in ckpt:
            # CheMeleon-style checkpoint
            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
            mp.load_state_dict(ckpt["state_dict"])
            print(f"  Loaded custom foundation weights (hidden_dim={mp.output_dim})")
            return mp, nn.MeanAggregation()

        # Full MPNN model file
        pretrained = models.MPNN.load_from_file(from_foundation)
        print(f"  Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
        return pretrained.message_passing, pretrained.agg
```
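In use, the foundation path is driven entirely by the hyperparameters dict; a sketch of overrides that would route through this helper (the keys are the template's own, the values illustrative):

```python
# Passed through TEMPLATE_PARAMS["hyperparameters"] by the caller:
overrides = {
    "from_foundation": "CheMeleon",   # or a path such as "./my_pretrained_mpnn.pt"
    "freeze_mpnn_epochs": 10,         # phase 1 trains the FFN only, per the docstring above
}
# Note: with from_foundation set, hidden_dim/depth are ignored -- the pretrained
# message-passing stack fixes the MPNN architecture (see build_mpnn_model below).
```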
```python
    def _build_ffn(
        task: str, input_dim: int, hyperparameters: dict,
        num_classes: int | None, n_targets: int,
        output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
    ) -> nn.Predictor:
        """Build task-specific FFN head."""
        dropout = hyperparameters["dropout"]
        ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
        ffn_num_layers = hyperparameters["ffn_num_layers"]

        if task == "classification" and num_classes is not None:
            return nn.MulticlassClassificationFFN(
                n_classes=num_classes, input_dim=input_dim,
                hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
            )

        from chemprop.nn.metrics import MAE, MSE
        criterion_map = {"mae": MAE, "mse": MSE}
        criterion_name = hyperparameters.get("criterion", "mae")
        if criterion_name not in criterion_map:
            raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")

        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
        return nn.RegressionFFN(
            input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
            dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
            task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
        )

    def build_mpnn_model(
        hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
        n_targets: int = 1, n_extra_descriptors: int = 0,
        x_d_transform: nn.ScaleTransform | None = None,
        output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
    ) -> models.MPNN:
        """Build MPNN model, optionally loading pretrained weights."""
        from_foundation = hyperparameters.get("from_foundation")

        if from_foundation:
            mp, agg = _load_foundation_weights(from_foundation)
            ffn_input_dim = mp.output_dim + n_extra_descriptors
        else:
            mp = nn.BondMessagePassing(
                d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
                dropout=hyperparameters["dropout"],
            )
            agg = nn.NormAggregation()
            ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors

        ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
        return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
```
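A hedged construction example: a two-target regressor with three extra descriptors, trained from scratch, with no scaling transforms fitted yet (those are attached per fold during training below):

```python
mpnn = build_mpnn_model(
    DEFAULT_HYPERPARAMETERS,
    task="regression",
    n_targets=2,
    n_extra_descriptors=3,   # FFN input becomes hidden_dim (700) + 3
)
print(type(mpnn).__name__)   # MPNN
```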
```python
    # -------------------------------------------------------------------------
    # Setup: Parse arguments and load data
    # -------------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
    args = parser.parse_args()

    # Extract template parameters
    target_columns = TEMPLATE_PARAMS["targets"]
    model_type = TEMPLATE_PARAMS["model_type"]
    feature_list = TEMPLATE_PARAMS["feature_list"]
    id_column = TEMPLATE_PARAMS["id_column"]
    model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
    hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}

    if not target_columns or not isinstance(target_columns, list):
        raise ValueError("'targets' must be a non-empty list of target column names")
    n_targets = len(target_columns)

    smiles_column = _find_smiles_column(feature_list)
    extra_feature_cols = [f for f in feature_list if f != smiles_column]
    use_extra_features = len(extra_feature_cols) > 0

    print(f"Target columns ({n_targets}): {target_columns}")
    print(f"SMILES column: {smiles_column}")
    print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
    print(f"Hyperparameters: {hyperparameters}")

    # Log foundation model configuration
    if hyperparameters.get("from_foundation"):
        freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
        freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
        print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
    else:
        print("Foundation model: None (training from scratch)")

    # Load training data
    training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
    print(f"Training Files: {training_files}")
    all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
    check_dataframe(all_df, "training_df")

    # Clean data
    initial_count = len(all_df)
    all_df = all_df.dropna(subset=[smiles_column])
    all_df = all_df[all_df[target_columns].notna().any(axis=1)]
    if len(all_df) < initial_count:
        print(f"Dropped {initial_count - len(all_df)} rows with missing SMILES/targets")

    print(f"Data shape: {all_df.shape}")
    for tc in target_columns:
        print(f"  {tc}: {all_df[tc].notna().sum()} samples")

    # -------------------------------------------------------------------------
    # Classification setup
    # -------------------------------------------------------------------------
    label_encoder = None
    num_classes = None
    if model_type == "classifier":
        if n_targets > 1:
            raise ValueError("Multi-task classification not supported")
        label_encoder = LabelEncoder()
        all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
        num_classes = len(label_encoder.classes_)
        print(f"Classification: {num_classes} classes: {label_encoder.classes_}")

    # -------------------------------------------------------------------------
    # Prepare features
    # -------------------------------------------------------------------------
    task = "classification" if model_type == "classifier" else "regression"
    n_extra = len(extra_feature_cols) if use_extra_features else 0

    all_extra_features, col_means = None, None
    if use_extra_features:
        all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
        col_means = np.nanmean(all_extra_features, axis=0)
        for i in range(all_extra_features.shape[1]):
            all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]

    all_targets = all_df[target_columns].values.astype(np.float32)

    # Filter invalid SMILES
    _, valid_indices = _create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
    all_df = all_df.iloc[valid_indices].reset_index(drop=True)
    all_targets = all_targets[valid_indices]
    if all_extra_features is not None:
        all_extra_features = all_extra_features[valid_indices]
    print(f"Data after SMILES validation: {all_df.shape}")

    # Task weights for multi-task (inverse sample count)
    task_weights = None
    if n_targets > 1 and model_type != "classifier":
        counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
        task_weights = (1.0 / counts) / (1.0 / counts).min()
        print(f"Task weights: {dict(zip(target_columns, task_weights.round(3)))}")
```
```python
    # -------------------------------------------------------------------------
    # Cross-validation setup
    # -------------------------------------------------------------------------
    n_folds = hyperparameters["n_folds"]
    batch_size = hyperparameters["batch_size"]

    if n_folds == 1:
        if "training" in all_df.columns:
            print("Using 'training' column for train/val split")
            train_idx = np.where(all_df["training"])[0]
            val_idx = np.where(~all_df["training"])[0]
        else:
            print("WARNING: No 'training' column, using random 80/20 split")
            train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
        folds = [(train_idx, val_idx)]
    else:
        if model_type == "classifier":
            kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
            folds = list(kfold.split(all_df, all_df[target_columns[0]]))
        else:
            kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
            folds = list(kfold.split(all_df))

    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")

    # -------------------------------------------------------------------------
    # Training loop
    # -------------------------------------------------------------------------
    oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
    oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64) if model_type == "classifier" and num_classes else None

    ensemble_models = []
    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        print(f"\n{'='*50}")
        print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
        print(f"{'='*50}")

        # Split data (val_extra_raw preserves unscaled features for OOF predictions)
        df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
        train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
        train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
        val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
        val_extra_raw = val_extra.copy() if val_extra is not None else None

        # Create datasets
        train_dps, _ = _create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
        val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
        train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)

        # Scale features/targets
        x_d_transform = None
        if use_extra_features:
            scaler = train_dataset.normalize_inputs("X_d")
            val_dataset.normalize_inputs("X_d", scaler)
            x_d_transform = nn.ScaleTransform.from_standard_scaler(scaler)

        output_transform = None
        if model_type in ["regressor", "uq_regressor"]:
            target_scaler = train_dataset.normalize_targets()
            val_dataset.normalize_targets(target_scaler)
            output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)

        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

        # Build model
        pl.seed_everything(hyperparameters["seed"] + fold_idx)
        mpnn = build_mpnn_model(
            hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
            n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
            output_transform=output_transform, task_weights=task_weights,
        )

        # Train model (with optional two-phase foundation training)
        freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
        use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0

        def _set_mpnn_frozen(frozen: bool):
            for param in mpnn.message_passing.parameters():
                param.requires_grad = not frozen
            for param in mpnn.agg.parameters():
                param.requires_grad = not frozen

        def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
            callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
            if save_checkpoint:
                callbacks.append(pl.callbacks.ModelCheckpoint(
                    dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
                ))
            return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)

        if use_two_phase:
            # Phase 1: Freeze MPNN, train FFN only
            print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
            _set_mpnn_frozen(True)
            _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)

            # Phase 2: Unfreeze and fine-tune all
            print("Phase 2: Unfreezing MPNN, continuing training...")
            _set_mpnn_frozen(False)
            remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
            trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
            trainer.fit(mpnn, train_loader, val_loader)
        else:
            trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
            trainer.fit(mpnn, train_loader, val_loader)

        # Load best checkpoint
        if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
            checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
            mpnn.load_state_dict(checkpoint["state_dict"])

        mpnn.eval()
        ensemble_models.append(mpnn)

        # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
        val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
        val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)

        with torch.inference_mode():
            fold_preds = np.concatenate([p.numpy() for p in trainer.predict(mpnn, val_loader_pred)], axis=0)
        if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
            fold_preds = fold_preds.squeeze(axis=1)

        if model_type == "classifier" and fold_preds.ndim == 2:
            oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
            if oof_proba is not None:
                oof_proba[val_idx] = fold_preds
        else:
            if fold_preds.ndim == 1:
                fold_preds = fold_preds.reshape(-1, 1)
            oof_predictions[val_idx] = fold_preds

    print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
```
```python
    # -------------------------------------------------------------------------
    # Prepare validation results
    # -------------------------------------------------------------------------
    if n_folds == 1:
        val_mask = ~np.isnan(oof_predictions).all(axis=1)
        df_val = all_df[val_mask].copy()
        preds = oof_predictions[val_mask]
        y_validate = all_targets[val_mask]
        if oof_proba is not None:
            oof_proba = oof_proba[val_mask]
        val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
    else:
        df_val = all_df.copy()
        preds = oof_predictions
        y_validate = all_targets
        val_extra_features = all_extra_features

    # -------------------------------------------------------------------------
    # Compute metrics and prepare output
    # -------------------------------------------------------------------------
    median_std = None  # Only set for regression models with ensemble
    if model_type == "classifier":
        class_preds = preds[:, 0].astype(int)
        target_name = target_columns[0]
        y_true_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
        preds_decoded = label_encoder.inverse_transform(class_preds)

        score_df = compute_classification_metrics(y_true_decoded, preds_decoded, label_encoder.classes_, target_name)
        print_classification_metrics(score_df, target_name, label_encoder.classes_)
        print_confusion_matrix(y_true_decoded, preds_decoded, label_encoder.classes_)

        # Decode target column back to string labels (was encoded for training)
        df_val[target_name] = y_true_decoded
        df_val["prediction"] = preds_decoded
        if oof_proba is not None:
            df_val["pred_proba"] = [p.tolist() for p in oof_proba]
            df_val = expand_proba_column(df_val, label_encoder.classes_)
    else:
        # Compute ensemble std
        preds_std = None
        if len(ensemble_models) > 1:
            print("Computing prediction_std from ensemble...")
            val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
            val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
            trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)

            all_ens_preds = []
            for m in ensemble_models:
                with torch.inference_mode():
                    ens_preds = np.concatenate([p.numpy() for p in trainer_pred.predict(m, val_loader)], axis=0)
                if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
                    ens_preds = ens_preds.squeeze(axis=1)
                all_ens_preds.append(ens_preds)
            preds_std = np.std(np.stack(all_ens_preds), axis=0)
            if preds_std.ndim == 1:
                preds_std = preds_std.reshape(-1, 1)

        print("\n--- Per-target metrics ---")
        for t_idx, t_name in enumerate(target_columns):
            valid_mask = ~np.isnan(y_validate[:, t_idx])
            if valid_mask.sum() > 0:
                metrics = compute_regression_metrics(y_validate[valid_mask, t_idx], preds[valid_mask, t_idx])
                print_regression_metrics(metrics)

            df_val[f"{t_name}_pred"] = preds[:, t_idx]
            df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx] if preds_std is not None else 0.0

        df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
        df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]

        # Compute confidence from ensemble std (or NaN for single model)
        if preds_std is not None:
            median_std = float(np.median(preds_std[:, 0]))
            print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
            df_val = _compute_std_confidence(df_val, median_std)
            print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
        else:
            # Single model - no ensemble std available, confidence is undefined
            median_std = None
            df_val["confidence"] = np.nan
            print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")

    # -------------------------------------------------------------------------
    # Save validation predictions to S3
    # -------------------------------------------------------------------------
    output_columns = [id_column] if id_column in df_val.columns else []
    output_columns += target_columns
    output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
    output_columns += ["prediction", "prediction_std", "confidence"]
    output_columns += [c for c in df_val.columns if c.endswith("_proba")]
    output_columns = [c for c in output_columns if c in df_val.columns]

    wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)

    # -------------------------------------------------------------------------
    # Save model artifacts
    # -------------------------------------------------------------------------
    for idx, m in enumerate(ensemble_models):
        models.save_model(os.path.join(args.model_dir, f"chemprop_model_{idx}.pt"), m)
    print(f"Saved {len(ensemble_models)} model(s)")

    # Clean up checkpoints
    for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
        os.remove(ckpt)

    ensemble_metadata = {
        "n_ensemble": len(ensemble_models),
        "n_folds": n_folds,
        "target_columns": target_columns,
        "median_std": median_std,  # For confidence calculation during inference
        # Foundation model provenance (for tracking/reproducibility)
        "from_foundation": hyperparameters.get("from_foundation", None),
        "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
    }
    joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))

    with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
        json.dump(hyperparameters, f, indent=2)

    if label_encoder:
        joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

    if use_extra_features:
        joblib.dump({"extra_feature_cols": extra_feature_cols, "col_means": col_means.tolist()}, os.path.join(args.model_dir, "feature_metadata.joblib"))
        print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")

    print(f"\nModel training complete! Artifacts saved to {args.model_dir}")
```
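For reference, the artifact contract this block writes under `model_dir` (and that `model_fn` above reads back) — a sketch, with the model count depending on `n_folds`:

```python
# model_dir/ after training (n_folds=5 shown):
#   chemprop_model_0.pt ... chemprop_model_4.pt  # ensemble members
#   ensemble_metadata.joblib   # n_ensemble, target_columns, median_std, provenance
#   hyperparameters.json       # fully resolved hyperparameters
#   label_encoder.joblib       # classifiers only
#   feature_metadata.joblib    # hybrid mode only (extra feature cols + column means)
```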