workbench 0.8.202-py3-none-any.whl → 0.8.220-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of workbench might be problematic.
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +421 -85
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +78 -150
- workbench/algorithms/graph/light/proximity_graph.py +5 -5
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +3 -0
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +13 -11
- workbench/api/feature_set.py +111 -8
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +45 -12
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +228 -237
- workbench/core/artifacts/feature_set_core.py +185 -230
- workbench/core/artifacts/model_core.py +34 -26
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +22 -10
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +41 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +428 -631
- workbench/model_scripts/chemprop/generated_model_script.py +432 -635
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +2 -10
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +374 -613
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +370 -609
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/script_generation.py +6 -5
- workbench/model_scripts/uq_models/generated_model_script.py +65 -422
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -395
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +366 -396
- workbench/repl/workbench_shell.py +0 -5
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +2 -2
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/training_test.py +85 -0
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chemprop_utils.py +36 -655
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +192 -54
- workbench/utils/pytorch_utils.py +33 -472
- workbench/utils/shap_utils.py +1 -55
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +49 -356
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugins/model_details.py +30 -68
- workbench/web_interface/components/plugins/scatter_plot.py +4 -8
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/METADATA +6 -5
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/RECORD +76 -60
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/entry_points.txt +2 -0
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
- workbench/model_scripts/uq_models/mapie.template +0 -605
- workbench/model_scripts/uq_models/requirements.txt +0 -1
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -1,630 +1,479 @@
 # ChemProp Model Template for Workbench
-# Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
 #
-#
-#
+# This template handles molecular property prediction using ChemProp 2.x MPNN with:
+# - K-fold cross-validation ensemble training (or single train/val split)
+# - Multi-task regression support
+# - Hybrid mode (SMILES + extra molecular descriptors)
+# - Classification (single-target only)
 #
-#
-#
-# - Regression uses output_transform (UnscaleTransform) for target scaling
-#
-# 2. Data Handling (create_molecule_datapoints function)
-# - MoleculeDatapoint creation with x_d (extra descriptors)
-# - RDKit validation of SMILES
-#
-# 3. Scaling (training section)
-# - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
-# - Targets (regression): normalize_targets() + UnscaleTransform in FFN
-# - At inference: pass RAW features, transforms handle scaling automatically
-#
-# 4. Training Loop (search for "pl.Trainer")
-# - PyTorch Lightning Trainer with ChemProp MPNN
-#
-# AWS/SageMaker boilerplate (can skip):
-# - input_fn, output_fn, model_fn: SageMaker serving interface
-# - argparse, file loading, S3 writes
-# =============================
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (lightning, sklearn, awswrangler) are deferred to training time.

-import os
-import argparse
 import json
-
+import os

-import
+import joblib
 import numpy as np
 import pandas as pd
 import torch
-
-from
-
-from
-
-
-
-    root_mean_squared_error,
-    precision_recall_fscore_support,
-    confusion_matrix,
+
+from chemprop import data, models
+
+from model_script_utils import (
+    expand_proba_column,
+    input_fn,
+    output_fn,
 )
-from scipy.stats import spearmanr
-import joblib

-#
-
+# =============================================================================
+# Default Hyperparameters
+# =============================================================================
+DEFAULT_HYPERPARAMETERS = {
+    # Training
+    "n_folds": 5,
+    "max_epochs": 400,
+    "patience": 50,
+    "batch_size": 32,
+    # Message Passing
+    "hidden_dim": 700,
+    "depth": 6,
+    "dropout": 0.1,  # Lower dropout - ensemble provides regularization
+    # FFN
+    "ffn_hidden_dim": 2000,
+    "ffn_num_layers": 2,
+    # Loss function for regression (mae, mse)
+    "criterion": "mae",
+    # Random seed
+    "seed": 42,
+}

-# Template
+# Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
-    "
-    "feature_list": ['smiles', '
-    "id_column": "
-    "model_metrics_s3_path": "s3://
-    "hyperparameters": {
+    "targets": ['logd'],
+    "feature_list": ['smiles', 'mollogp', 'fr_halogen', 'nbase', 'peoe_vsa6', 'bcut2d_mrlow', 'peoe_vsa7', 'peoe_vsa9', 'vsa_estate1', 'peoe_vsa1', 'numhdonors', 'vsa_estate5', 'smr_vsa3', 'slogp_vsa1', 'vsa_estate7', 'bcut2d_mwhi', 'axp_2dv', 'axp_3dv', 'mi', 'smr_vsa9', 'vsa_estate3', 'estate_vsa9', 'bcut2d_mwlow', 'tpsa', 'vsa_estate10', 'xch_5dv', 'slogp_vsa2', 'nhohcount', 'bcut2d_logplow', 'hallkieralpha', 'c2sp2', 'bcut2d_chglo', 'smr_vsa4', 'maxabspartialcharge', 'estate_vsa6', 'qed', 'slogp_vsa6', 'vsa_estate2', 'bcut2d_logphi', 'vsa_estate8', 'xch_7dv', 'fpdensitymorgan3', 'xpc_6d', 'smr_vsa10', 'axp_0d', 'fr_nh1', 'axp_4dv', 'peoe_vsa2', 'estate_vsa8', 'peoe_vsa5', 'vsa_estate6'],
+    "id_column": "molecule_name",
+    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-reg-chemprop-hybrid/training",
+    "hyperparameters": {},
 }


-
-
-
-
-
-        raise ValueError(msg)
-
+# =============================================================================
+# Helper Functions
+# =============================================================================
+def _compute_std_confidence(df: pd.DataFrame, median_std: float, std_col: str = "prediction_std") -> pd.DataFrame:
+    """Compute confidence score from ensemble prediction_std.

-
-
-
-    if smiles_column is None:
-        raise ValueError(
-            "Column list must contain a 'smiles' column (case-insensitive)"
-        )
-    return smiles_column
+    Uses exponential decay: confidence = exp(-std / median_std)
+    - Low std (ensemble agreement) -> high confidence
+    - High std (ensemble disagreement) -> low confidence

+    Args:
+        df: DataFrame with prediction_std column
+        median_std: Median std from training validation set (normalization factor)
+        std_col: Name of the std column to use

-
-
-
-    Handles None values for rows where predictions couldn't be made.
+    Returns:
+        DataFrame with added 'confidence' column (0.0 to 1.0)
     """
-
-
-        raise ValueError('DataFrame does not contain a "pred_proba" column')
-
-    proba_splits = [f"{label}_proba" for label in class_labels]
-    n_classes = len(class_labels)
-
-    # Handle None values by replacing with list of NaNs
-    proba_values = []
-    for val in df[proba_column]:
-        if val is None:
-            proba_values.append([np.nan] * n_classes)
-        else:
-            proba_values.append(val)
+    df["confidence"] = np.exp(-df[std_col] / median_std)
+    return df

-    proba_df = pd.DataFrame(proba_values, columns=proba_splits)

-
-
-
-
+def _find_smiles_column(columns: list[str]) -> str:
+    """Find SMILES column (case-insensitive match for 'smiles')."""
+    smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
+    if smiles_col is None:
+        raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
+    return smiles_col


-def
+def _create_molecule_datapoints(
     smiles_list: list[str],
-    targets:
+    targets: np.ndarray | None = None,
     extra_descriptors: np.ndarray | None = None,
 ) -> tuple[list[data.MoleculeDatapoint], list[int]]:
-    """Create ChemProp MoleculeDatapoints from SMILES strings.
-
-    Args:
-        smiles_list: List of SMILES strings
-        targets: Optional list of target values (for training)
-        extra_descriptors: Optional array of extra features (n_samples, n_features)
-
-    Returns:
-        Tuple of (list of MoleculeDatapoint objects, list of valid indices)
-    """
+    """Create ChemProp MoleculeDatapoints from SMILES strings."""
     from rdkit import Chem

-    datapoints = []
-
-    invalid_count = 0
+    datapoints, valid_indices = [], []
+    targets = np.atleast_2d(np.array(targets)).T if targets is not None and np.array(targets).ndim == 1 else targets

     for i, smi in enumerate(smiles_list):
-
-        mol = Chem.MolFromSmiles(smi)
-        if mol is None:
-            invalid_count += 1
+        if Chem.MolFromSmiles(smi) is None:
             continue
-
-        # Build datapoint with optional target and extra descriptors
-        y = [targets[i]] if targets is not None else None
+        y = targets[i].tolist() if targets is not None else None
         x_d = extra_descriptors[i] if extra_descriptors is not None else None
-
-        dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
-        datapoints.append(dp)
+        datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
         valid_indices.append(i)

-    if invalid_count > 0:
-        print(f"Warning: Skipped {invalid_count} invalid SMILES strings")
-
     return datapoints, valid_indices


-
-
-
-    num_classes: int | None = None,
-    n_extra_descriptors: int = 0,
-    x_d_transform: nn.ScaleTransform | None = None,
-    output_transform: nn.UnscaleTransform | None = None,
-) -> models.MPNN:
-    """Build an MPNN model with the specified hyperparameters.
-
-    Args:
-        hyperparameters: Dictionary of model hyperparameters
-        task: Either "regression" or "classification"
-        num_classes: Number of classes for classification tasks
-        n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
-        x_d_transform: Optional transform for extra descriptors (scaling)
-        output_transform: Optional transform for regression output (unscaling targets)
-
-    Returns:
-        Configured MPNN model
-    """
-    # Model hyperparameters with defaults
-    hidden_dim = hyperparameters.get("hidden_dim", 300)
-    depth = hyperparameters.get("depth", 4)
-    dropout = hyperparameters.get("dropout", 0.1)
-    ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
-    ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
-
-    # Message passing component
-    mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-
-    # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
-    agg = nn.NormAggregation()
-
-    # FFN input_dim = message passing output + extra descriptors
-    ffn_input_dim = hidden_dim + n_extra_descriptors
-
-    # Build FFN based on task type
-    if task == "classification" and num_classes is not None:
-        # Multi-class classification
-        ffn = nn.MulticlassClassificationFFN(
-            n_classes=num_classes,
-            input_dim=ffn_input_dim,
-            hidden_dim=ffn_hidden_dim,
-            n_layers=ffn_num_layers,
-            dropout=dropout,
-        )
-    else:
-        # Regression with optional output transform to unscale predictions
-        ffn = nn.RegressionFFN(
-            input_dim=ffn_input_dim,
-            hidden_dim=ffn_hidden_dim,
-            n_layers=ffn_num_layers,
-            dropout=dropout,
-            output_transform=output_transform,
-        )
-
-    # Create the MPNN model
-    mpnn = models.MPNN(
-        message_passing=mp,
-        agg=agg,
-        predictor=ffn,
-        batch_norm=True,
-        metrics=None,
-        X_d_transform=x_d_transform,
-    )
-
-    return mpnn
-
-
+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load
-
-    Args:
-        model_dir: Directory containing the saved models
-
-    Returns:
-        Dictionary with ensemble models and metadata
-    """
-    # Load ensemble metadata
-    ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
-    if os.path.exists(ensemble_metadata_path):
-        ensemble_metadata = joblib.load(ensemble_metadata_path)
-        n_ensemble = ensemble_metadata["n_ensemble"]
-    else:
-        # Backwards compatibility: single model without ensemble metadata
-        n_ensemble = 1
+    """Load ChemProp MPNN ensemble from the specified directory."""
+    from lightning import pytorch as pl

-
+    metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
     ensemble_models = []
-    for
-
-        if not os.path.exists(model_path):
-            # Backwards compatibility: try old single model path
-            model_path = os.path.join(model_dir, "chemprop_model.pt")
-        model = models.MPNN.load_from_file(model_path)
+    for i in range(metadata["n_ensemble"]):
+        model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
         model.eval()
         ensemble_models.append(model)

-
+    # Pre-initialize trainer once during model loading (expensive operation)
+    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)

+    print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
     return {
         "ensemble_models": ensemble_models,
-        "n_ensemble": n_ensemble,
+        "n_ensemble": metadata["n_ensemble"],
+        "target_columns": metadata["target_columns"],
+        "median_std": metadata["median_std"],
+        "trainer": trainer,
     }


-
-
-
-        raise ValueError("Empty input data is not supported!")
-
-    if isinstance(input_data, bytes):
-        input_data = input_data.decode("utf-8")
-
-    if "text/csv" in content_type:
-        return pd.read_csv(StringIO(input_data))
-    elif "application/json" in content_type:
-        return pd.DataFrame(json.loads(input_data))
-    else:
-        raise ValueError(f"{content_type} not supported!")
-
-
-def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
-    """Supports both CSV and JSON output formats."""
-    if "text/csv" in accept_type:
-        csv_output = output_df.fillna("N/A").to_csv(index=False)
-        return csv_output, "text/csv"
-    elif "application/json" in accept_type:
-        return output_df.to_json(orient="records"), "application/json"
-    else:
-        raise RuntimeError(
-            f"{accept_type} accept type is not supported by this script."
-        )
-
-
+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make predictions with
-
-    Args:
-        df: Input DataFrame containing SMILES column (and extra features if hybrid mode)
-        model_dict: Dictionary containing ensemble models and metadata
-
-    Returns:
-        DataFrame with predictions added (and prediction_std for ensembles)
-    """
+    """Make predictions with ChemProp MPNN ensemble."""
     model_type = TEMPLATE_PARAMS["model_type"]
     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

-    # Extract ensemble models
     ensemble_models = model_dict["ensemble_models"]
-
+    target_columns = model_dict["target_columns"]
+    trainer = model_dict["trainer"]  # Use pre-initialized trainer

-    # Load
+    # Load artifacts
     label_encoder = None
-
-    if os.path.exists(
-        label_encoder = joblib.load(
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)

-    # Load feature metadata if present (hybrid mode)
-    # Contains column names, NaN fill values, and scaler for feature scaling
     feature_metadata = None
-
-    if os.path.exists(
-        feature_metadata = joblib.load(
-        print(
-            f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
-        )
-
-    # Find SMILES column in input DataFrame
-    smiles_column = find_smiles_column(df.columns.tolist())
+    feature_path = os.path.join(model_dir, "feature_metadata.joblib")
+    if os.path.exists(feature_path):
+        feature_metadata = joblib.load(feature_path)
+        print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")

+    # Find SMILES column and validate
+    smiles_column = _find_smiles_column(df.columns.tolist())
     smiles_list = df[smiles_column].tolist()

-
-
-    valid_smiles = []
-    valid_indices = []
-    for i, smi in enumerate(smiles_list):
-        if smi and isinstance(smi, str) and len(smi.strip()) > 0:
-            valid_mask.append(True)
-            valid_smiles.append(smi.strip())
-            valid_indices.append(i)
-        else:
-            valid_mask.append(False)
-
-    valid_mask = np.array(valid_mask)
+    valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
+    valid_smiles = [s.strip() for i, s in enumerate(smiles_list) if valid_mask[i]]
     print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")

-    # Initialize
+    # Initialize output columns
     if model_type == "classifier":
         df["prediction"] = pd.Series([None] * len(df), dtype=object)
     else:
-
-
-
+        for tc in target_columns:
+            df[f"{tc}_pred"] = np.nan
+            df[f"{tc}_pred_std"] = np.nan

     if sum(valid_mask) == 0:
-        print("Warning: No valid SMILES to predict on")
         return df

-    # Prepare extra features
-    # NOTE: We pass RAW (unscaled) features here - the model's X_d_transform handles scaling
+    # Prepare extra features (raw, unscaled - model handles scaling)
     extra_features = None
     if feature_metadata is not None:
-
+        extra_cols = feature_metadata["extra_feature_cols"]
         col_means = np.array(feature_metadata["col_means"])
+        valid_indices = np.where(valid_mask)[0]

-
-
-        if missing_cols:
-            print(
-                f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
-            )
-
-        # Extract features for valid SMILES rows (raw, unscaled)
-        extra_features = np.zeros(
-            (len(valid_indices), len(extra_feature_cols)), dtype=np.float32
-        )
-        for j, col in enumerate(extra_feature_cols):
+        extra_features = np.zeros((len(valid_indices), len(extra_cols)), dtype=np.float32)
+        for j, col in enumerate(extra_cols):
             if col in df.columns:
                 values = df.iloc[valid_indices][col].values.astype(np.float32)
-
-                nan_mask = np.isnan(values)
-                values[nan_mask] = col_means[j]
+                values[np.isnan(values)] = col_means[j]
                 extra_features[:, j] = values
             else:
-                # Column missing, use training mean
                 extra_features[:, j] = col_means[j]

-    # Create datapoints
-    datapoints,
-        valid_smiles, extra_descriptors=extra_features
-    )
-
+    # Create datapoints and predict
+    datapoints, rdkit_valid = _create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
     if len(datapoints) == 0:
-        print("Warning: No valid SMILES after RDKit validation")
         return df

     dataset = data.MoleculeDataset(datapoints)
     dataloader = data.build_dataloader(dataset, shuffle=False)

-    #
-
-
-        logger=False,
-        enable_progress_bar=False,
-    )
-
-    # Collect predictions from all ensemble members
-    all_ensemble_preds = []
-    for ens_idx, ens_model in enumerate(ensemble_models):
+    # Ensemble predictions
+    all_preds = []
+    for model in ensemble_models:
         with torch.inference_mode():
-            predictions = trainer.predict(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    # rdkit_valid_indices tells us which of the valid_smiles were actually valid
-    valid_positions = np.where(valid_mask)[0][rdkit_valid_indices]
+            predictions = trainer.predict(model, dataloader)
+            preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+            if preds.ndim == 3 and preds.shape[1] == 1:
+                preds = preds.squeeze(axis=1)
+            all_preds.append(preds)
+
+    preds = np.mean(np.stack(all_preds), axis=0)
+    preds_std = np.std(np.stack(all_preds), axis=0)
+    if preds.ndim == 1:
+        preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
+
+    print(f"Inference complete: {preds.shape[0]} predictions")
+
+    # Map predictions back to valid positions
+    valid_positions = np.where(valid_mask)[0][rdkit_valid]
     valid_mask = np.zeros(len(df), dtype=bool)
     valid_mask[valid_positions] = True

     if model_type == "classifier" and label_encoder is not None:
-
-        if preds.ndim == 2 and preds.shape[1] > 1:
-            # Multi-class: preds are probabilities (averaged across ensemble)
+        if preds.shape[1] > 1:
             class_preds = np.argmax(preds, axis=1)
-
-
-
-
-            proba_series = pd.Series([None] * len(df), index=df.index, dtype=object)
-            proba_series.loc[valid_mask] = [p.tolist() for p in preds]
-            df["pred_proba"] = proba_series
+            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform(class_preds)
+            proba = pd.Series([None] * len(df), dtype=object)
+            proba.loc[valid_mask] = [p.tolist() for p in preds]
+            df["pred_proba"] = proba
             df = expand_proba_column(df, label_encoder.classes_)
         else:
-
-            class_preds = (preds.flatten() > 0.5).astype(int)
-            decoded_preds = label_encoder.inverse_transform(class_preds)
-            df.loc[valid_mask, "prediction"] = decoded_preds
+            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform((preds.flatten() > 0.5).astype(int))
     else:
-
-
-
+        for t_idx, tc in enumerate(target_columns):
+            df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
+            df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
+        df["prediction"] = df[f"{target_columns[0]}_pred"]
+        df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
+
+        # Compute confidence from ensemble std
+        df = _compute_std_confidence(df, model_dict["median_std"])

     return df


+# =============================================================================
+# Training
+# =============================================================================
 if __name__ == "__main__":
-
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+    import glob
+
+    import awswrangler as wr
+    from lightning import pytorch as pl
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    # Enable Tensor Core optimization for GPUs that support it
+    torch.set_float32_matmul_precision("medium")
+
+    from chemprop import nn
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+
+    # -------------------------------------------------------------------------
+    # Training-only helper function
+    # -------------------------------------------------------------------------
+    def build_mpnn_model(
+        hyperparameters: dict,
+        task: str = "regression",
+        num_classes: int | None = None,
+        n_targets: int = 1,
+        n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None,
+        task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build an MPNN model with specified hyperparameters."""
+        hidden_dim = hyperparameters["hidden_dim"]
+        depth = hyperparameters["depth"]
+        dropout = hyperparameters["dropout"]
+        ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
+        ffn_num_layers = hyperparameters["ffn_num_layers"]
+
+        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+        agg = nn.NormAggregation()
+        ffn_input_dim = hidden_dim + n_extra_descriptors
+
+        if task == "classification" and num_classes is not None:
+            ffn = nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=ffn_input_dim,
+                hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
+            )
+        else:
+            # Map criterion name to ChemProp metric class (must have .clone() method)
+            from chemprop.nn.metrics import MAE, MSE
+
+            criterion_map = {
+                "mae": MAE,
+                "mse": MSE,
+            }
+            criterion_name = hyperparameters.get("criterion", "mae")
+            if criterion_name not in criterion_map:
+                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+            criterion = criterion_map[criterion_name]()
+
+            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+            ffn = nn.RegressionFFN(
+                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
+                criterion=criterion,
+            )
+
+        return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)

-    #
-
+    # -------------------------------------------------------------------------
+    # Setup: Parse arguments and load data
+    # -------------------------------------------------------------------------
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+    parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+    args = parser.parse_args()
+
+    # Extract template parameters
+    target_columns = TEMPLATE_PARAMS["targets"]
     model_type = TEMPLATE_PARAMS["model_type"]
     feature_list = TEMPLATE_PARAMS["feature_list"]
     id_column = TEMPLATE_PARAMS["id_column"]
     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
-    hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
+    hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}

-
-
+    if not target_columns or not isinstance(target_columns, list):
+        raise ValueError("'targets' must be a non-empty list of target column names")
+    n_targets = len(target_columns)
+
+    smiles_column = _find_smiles_column(feature_list)
     extra_feature_cols = [f for f in feature_list if f != smiles_column]
     use_extra_features = len(extra_feature_cols) > 0
-    print(f"Feature List: {feature_list}")
-    print(f"SMILES Column: {smiles_column}")
-    print(
-        f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
-    )

-
-
-
-
-    )
-    parser.add_argument(
-        "--train",
-        type=str,
-        default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
-    )
-    parser.add_argument(
-        "--output-data-dir",
-        type=str,
-        default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
-    )
-    args = parser.parse_args()
+    print(f"Target columns ({n_targets}): {target_columns}")
+    print(f"SMILES column: {smiles_column}")
+    print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
+    print(f"Hyperparameters: {hyperparameters}")

-    #
-    training_files = [
-        os.path.join(args.train, f)
-        for f in os.listdir(args.train)
-        if f.endswith(".csv")
-    ]
+    # Load training data
+    training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
     print(f"Training Files: {training_files}")
-
     all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
-    print(f"All Data Shape: {all_df.shape}")
-
     check_dataframe(all_df, "training_df")

-    #
+    # Clean data
     initial_count = len(all_df)
-    all_df = all_df.dropna(subset=[smiles_column
-
-    if
-        print(f"Dropped {
-
-    print(f"
-
-
-
+    all_df = all_df.dropna(subset=[smiles_column])
+    all_df = all_df[all_df[target_columns].notna().any(axis=1)]
+    if len(all_df) < initial_count:
+        print(f"Dropped {initial_count - len(all_df)} rows with missing SMILES/targets")
+
+    print(f"Data shape: {all_df.shape}")
+    for tc in target_columns:
+        print(f"  {tc}: {all_df[tc].notna().sum()} samples")
+
+    # -------------------------------------------------------------------------
+    # Classification setup
+    # -------------------------------------------------------------------------
     label_encoder = None
+    num_classes = None
     if model_type == "classifier":
+        if n_targets > 1:
+            raise ValueError("Multi-task classification not supported")
         label_encoder = LabelEncoder()
-        all_df[
+        all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
         num_classes = len(label_encoder.classes_)
-        print(
-            f"Classification task with {num_classes} classes: {label_encoder.classes_}"
-        )
-    else:
-        num_classes = None
+        print(f"Classification: {num_classes} classes: {label_encoder.classes_}")

-    #
-
+    # -------------------------------------------------------------------------
+    # Prepare features
+    # -------------------------------------------------------------------------
     task = "classification" if model_type == "classifier" else "regression"
     n_extra = len(extra_feature_cols) if use_extra_features else 0
-    max_epochs = hyperparameters.get("max_epochs", 200)
-    patience = hyperparameters.get("patience", 20)
-    n_folds = hyperparameters.get("n_folds", 5)  # Number of CV folds (default: 5)
-    batch_size = hyperparameters.get("batch_size", min(64, max(16, len(all_df) // 16)))

-
-    if use_extra_features:
-        missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
-        if missing_cols:
-            raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")
-
-    # =========================================================================
-    # UNIFIED TRAINING: Works for n_folds=1 (single model) or n_folds>1 (K-fold CV)
-    # =========================================================================
-    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")
-
-    # Prepare extra features and validate SMILES upfront
-    all_extra_features = None
-    col_means = None
+    all_extra_features, col_means = None, None
     if use_extra_features:
         all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
         col_means = np.nanmean(all_extra_features, axis=0)
         for i in range(all_extra_features.shape[1]):
             all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]

-
-
-
-    )
+    all_targets = all_df[target_columns].values.astype(np.float32)
+
+    # Filter invalid SMILES
+    _, valid_indices = _create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
     all_df = all_df.iloc[valid_indices].reset_index(drop=True)
+    all_targets = all_targets[valid_indices]
     if all_extra_features is not None:
         all_extra_features = all_extra_features[valid_indices]
     print(f"Data after SMILES validation: {all_df.shape}")

-    #
+    # Task weights for multi-task (inverse sample count)
+    task_weights = None
+    if n_targets > 1 and model_type != "classifier":
+        counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
+        task_weights = (1.0 / counts) / (1.0 / counts).min()
+        print(f"Task weights: {dict(zip(target_columns, task_weights.round(3)))}")
+
+    # -------------------------------------------------------------------------
+    # Cross-validation setup
+    # -------------------------------------------------------------------------
+    n_folds = hyperparameters["n_folds"]
+    batch_size = hyperparameters["batch_size"]
+
     if n_folds == 1:
-        # Single fold: use train/val split from "training" column or random split
         if "training" in all_df.columns:
-            print("
+            print("Using 'training' column for train/val split")
             train_idx = np.where(all_df["training"])[0]
             val_idx = np.where(~all_df["training"])[0]
         else:
-            print("WARNING: No training column
-
-            train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
+            print("WARNING: No 'training' column, using random 80/20 split")
+            train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
         folds = [(train_idx, val_idx)]
     else:
-        # K-Fold CV
         if model_type == "classifier":
             kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
-
+            folds = list(kfold.split(all_df, all_df[target_columns[0]]))
         else:
             kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
-
-            folds = list(kfold.split(all_df, split_target))
+            folds = list(kfold.split(all_df))

-
-    oof_predictions = np.full(len(all_df), np.nan, dtype=np.float64)
-    if model_type == "classifier" and num_classes and num_classes > 1:
-        oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
-    else:
-        oof_proba = None
+    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")

-
+    # -------------------------------------------------------------------------
+    # Training loop
+    # -------------------------------------------------------------------------
+    oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
+    oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64) if model_type == "classifier" and num_classes else None

+    ensemble_models = []
     for fold_idx, (train_idx, val_idx) in enumerate(folds):
         print(f"\n{'='*50}")
-        print(f"
+        print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
         print(f"{'='*50}")

-        # Split data
-        df_train = all_df.iloc[train_idx].reset_index(drop=True)
-
-
+        # Split data
+        df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
+        train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
         val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
-
-        print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
-
-        # Create ChemProp datasets for this fold
-        train_datapoints, _ = create_molecule_datapoints(
-            df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra
-        )
-        val_datapoints, _ = create_molecule_datapoints(
-            df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra
-        )
-
-        train_dataset = data.MoleculeDataset(train_datapoints)
-        val_dataset = data.MoleculeDataset(val_datapoints)
-
-        # Save raw val features for prediction
         val_extra_raw = val_extra.copy() if val_extra is not None else None

-        #
+        # Create datasets
+        train_dps, _ = _create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
+        val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
+        train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)
+
+        # Scale features/targets
         x_d_transform = None
         if use_extra_features:
-
-            val_dataset.normalize_inputs("X_d",
-            x_d_transform = nn.ScaleTransform.from_standard_scaler(
+            scaler = train_dataset.normalize_inputs("X_d")
+            val_dataset.normalize_inputs("X_d", scaler)
+            x_d_transform = nn.ScaleTransform.from_standard_scaler(scaler)

         output_transform = None
         if model_type in ["regressor", "uq_regressor"]:
@@ -632,31 +481,27 @@ if __name__ == "__main__":
             val_dataset.normalize_targets(target_scaler)
             output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)

-        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
-        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
+        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
+        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

-        # Build and train model
-        pl.seed_everything(
+        # Build and train model
+        pl.seed_everything(hyperparameters["seed"] + fold_idx)
         mpnn = build_mpnn_model(
-            hyperparameters, task=task, num_classes=num_classes,
-            n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
+            hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
+            n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
+            output_transform=output_transform, task_weights=task_weights,
         )

-        callbacks = [
-            pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
-            pl.callbacks.ModelCheckpoint(
-                dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
-                monitor="val_loss", mode="min", save_top_k=1,
-            ),
-        ]
-
         trainer = pl.Trainer(
-            accelerator="auto", max_epochs=max_epochs,
-
+            accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
+            callbacks=[
+                pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
+                pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
+            ],
         )
-
         trainer.fit(mpnn, train_loader, val_loader)

+        # Load best checkpoint
         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
             checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
             mpnn.load_state_dict(checkpoint["state_dict"])
@@ -664,189 +509,141 @@ if __name__ == "__main__":
         mpnn.eval()
         ensemble_models.append(mpnn)

-        #
-
-
-        )
-        val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
-        val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
+        # Out-of-fold predictions (using raw features)
+        val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
+        val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)

         with torch.inference_mode():
-
-            fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
+            fold_preds = np.concatenate([p.numpy() for p in trainer.predict(mpnn, val_loader_pred)], axis=0)
         if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
             fold_preds = fold_preds.squeeze(axis=1)

-        # Store out-of-fold predictions
         if model_type == "classifier" and fold_preds.ndim == 2:
-            oof_predictions[val_idx] = np.argmax(fold_preds, axis=1)
+            oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
             if oof_proba is not None:
                 oof_proba[val_idx] = fold_preds
         else:
-
-
-
+            if fold_preds.ndim == 1:
+                fold_preds = fold_preds.reshape(-1, 1)
+            oof_predictions[val_idx] = fold_preds

     print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")

-    #
-    #
+    # -------------------------------------------------------------------------
+    # Prepare validation results
+    # -------------------------------------------------------------------------
     if n_folds == 1:
-        val_mask = ~np.isnan(oof_predictions)
-        preds = oof_predictions[val_mask]
+        val_mask = ~np.isnan(oof_predictions).all(axis=1)
         df_val = all_df[val_mask].copy()
-
+        preds = oof_predictions[val_mask]
+        y_validate = all_targets[val_mask]
         if oof_proba is not None:
             oof_proba = oof_proba[val_mask]
         val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
     else:
-        preds = oof_predictions
         df_val = all_df.copy()
-
+        preds = oof_predictions
+        y_validate = all_targets
         val_extra_features = all_extra_features

-    #
-    #
-
-
-    print("Computing prediction_std from ensemble predictions on validation data...")
-    val_datapoints_for_std, _ = create_molecule_datapoints(
-        df_val[smiles_column].tolist(),
-        df_val[target].tolist(),
-        val_extra_features
-    )
-    val_dataset_for_std = data.MoleculeDataset(val_datapoints_for_std)
-    val_loader_for_std = data.build_dataloader(val_dataset_for_std, batch_size=batch_size, shuffle=False)
-
-    all_ensemble_preds_for_std = []
-    trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
-    for ens_model in ensemble_models:
-        with torch.inference_mode():
-            ens_preds = trainer_pred.predict(ens_model, val_loader_for_std)
-        ens_preds = np.concatenate([p.numpy() for p in ens_preds], axis=0)
-        if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
-            ens_preds = ens_preds.squeeze(axis=1)
-        all_ensemble_preds_for_std.append(ens_preds.flatten())
-
-    ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
-    preds_std = np.std(ensemble_preds_stacked, axis=0)
-    print(f"Ensemble prediction_std - mean: {np.mean(preds_std):.4f}, max: {np.max(preds_std):.4f}")
-
+    # -------------------------------------------------------------------------
+    # Compute metrics and prepare output
+    # -------------------------------------------------------------------------
+    median_std = None  # Only set for regression models with ensemble
     if model_type == "classifier":
-
-
-
-
-        print(f"class_preds shape: {class_preds.shape}")
-
-        # Decode labels for metrics
-        y_validate_decoded = label_encoder.inverse_transform(y_validate.astype(int))
+        class_preds = preds[:, 0].astype(int)
+        target_name = target_columns[0]
+        y_true_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
         preds_decoded = label_encoder.inverse_transform(class_preds)

-
-
-
-            y_validate_decoded, preds_decoded, average=None, labels=label_names
-        )
-
-        score_df = pd.DataFrame(
-            {
-                target: label_names,
-                "precision": scores[0],
-                "recall": scores[1],
-                "f1": scores[2],
-                "support": scores[3],
-            }
-        )
-
-        # Output metrics per class
-        metrics = ["precision", "recall", "f1", "support"]
-        for t in label_names:
-            for m in metrics:
-                value = score_df.loc[score_df[target] == t, m].iloc[0]
-                print(f"Metrics:{t}:{m} {value}")
+        score_df = compute_classification_metrics(y_true_decoded, preds_decoded, label_encoder.classes_, target_name)
+        print_classification_metrics(score_df, target_name, label_encoder.classes_)
+        print_confusion_matrix(y_true_decoded, preds_decoded, label_encoder.classes_)

-        #
-
-            y_validate_decoded, preds_decoded, labels=label_names
-        )
-        for i, row_name in enumerate(label_names):
-            for j, col_name in enumerate(label_names):
-                value = conf_mtx[i, j]
-                print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
-
-        # Save validation predictions
-        df_val = df_val.copy()
+        # Decode target column back to string labels (was encoded for training)
+        df_val[target_name] = y_true_decoded
         df_val["prediction"] = preds_decoded
-        if
+        if oof_proba is not None:
             df_val["pred_proba"] = [p.tolist() for p in oof_proba]
-            df_val = expand_proba_column(df_val,
-
+            df_val = expand_proba_column(df_val, label_encoder.classes_)
     else:
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Compute ensemble std
+        preds_std = None
+        if len(ensemble_models) > 1:
+            print("Computing prediction_std from ensemble...")
+            val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
+            val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
+            trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
+            all_ens_preds = []
+            for m in ensemble_models:
+                with torch.inference_mode():
+                    ens_preds = np.concatenate([p.numpy() for p in trainer_pred.predict(m, val_loader)], axis=0)
+                if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+                    ens_preds = ens_preds.squeeze(axis=1)
+                all_ens_preds.append(ens_preds)
+            preds_std = np.std(np.stack(all_ens_preds), axis=0)
+            if preds_std.ndim == 1:
+                preds_std = preds_std.reshape(-1, 1)
+
+        print("\n--- Per-target metrics ---")
+        for t_idx, t_name in enumerate(target_columns):
+            valid_mask = ~np.isnan(y_validate[:, t_idx])
+            if valid_mask.sum() > 0:
+                metrics = compute_regression_metrics(y_validate[valid_mask, t_idx], preds[valid_mask, t_idx])
+                print_regression_metrics(metrics)
+
+            df_val[f"{t_name}_pred"] = preds[:, t_idx]
+            df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx] if preds_std is not None else 0.0
+
+        df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
+        df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
+
+        # Compute confidence from ensemble std
+        median_std = float(np.median(preds_std[:, 0]))
+        print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+        df_val = _compute_std_confidence(df_val, median_std)
+        print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+
+    # -------------------------------------------------------------------------
     # Save validation predictions to S3
-    #
-    output_columns = []
-
-
-    output_columns += [
-
-
-
-    wr.s3.to_csv(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # -------------------------------------------------------------------------
+    output_columns = [id_column] if id_column in df_val.columns else []
+    output_columns += target_columns
+    output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
+    output_columns += ["prediction", "prediction_std", "confidence"]
+    output_columns += [c for c in df_val.columns if c.endswith("_proba")]
+    output_columns = [c for c in output_columns if c in df_val.columns]
+
+    wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
+
+    # -------------------------------------------------------------------------
+    # Save model artifacts
+    # -------------------------------------------------------------------------
+    for idx, m in enumerate(ensemble_models):
+        models.save_model(os.path.join(args.model_dir, f"chemprop_model_{idx}.pt"), m)
+    print(f"Saved {len(ensemble_models)} model(s)")
+
+    # Clean up checkpoints
+    for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
+        os.remove(ckpt)
+
+    ensemble_metadata = {
+        "n_ensemble": len(ensemble_models),
+        "n_folds": n_folds,
+        "target_columns": target_columns,
+        "median_std": median_std,  # For confidence calculation during inference
+    }
     joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
-    print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds})")

-
-
+    with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
+        json.dump(hyperparameters, f, indent=2)
+
+    if label_encoder:
         joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

-    # Save extra feature metadata for inference (hybrid mode)
-    # Note: We don't need to save the scaler - X_d_transform is embedded in the model
     if use_extra_features:
-
-            "extra_feature_cols": extra_feature_cols,
-            "col_means": col_means.tolist(),  # Unscaled means for NaN imputation
-        }
-        joblib.dump(
-            feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
-        )
+        joblib.dump({"extra_feature_cols": extra_feature_cols, "col_means": col_means.tolist()}, os.path.join(args.model_dir, "feature_metadata.joblib"))
         print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
+
+    print(f"\nModel training complete! Artifacts saved to {args.model_dir}")