workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
- workbench/algorithms/dataframe/proximity.py +261 -235
- workbench/algorithms/graph/light/proximity_graph.py +10 -8
- workbench/api/__init__.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +11 -0
- workbench/api/feature_set.py +11 -8
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -15
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +256 -118
- workbench/core/artifacts/feature_set_core.py +265 -16
- workbench/core/artifacts/model_core.py +107 -60
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +42 -32
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/chemprop/chemprop.template +852 -0
- workbench/model_scripts/chemprop/generated_model_script.py +852 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
- workbench/model_scripts/pytorch_model/pytorch.template +370 -187
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +17 -9
- workbench/model_scripts/uq_models/generated_model_script.py +605 -0
- workbench/model_scripts/uq_models/mapie.template +605 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
- workbench/model_scripts/xgb_model/xgb_model.template +44 -46
- workbench/repl/workbench_shell.py +28 -14
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +760 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +95 -34
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +526 -0
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +371 -156
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +9 -7
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,852 @@
|
|
|
1
|
+
# ChemProp Model Template for Workbench
|
|
2
|
+
# Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
|
|
3
|
+
#
|
|
4
|
+
# === CHEMPROP REVIEW NOTES ===
|
|
5
|
+
# This script runs on AWS SageMaker. Key areas for ChemProp review:
|
|
6
|
+
#
|
|
7
|
+
# 1. Model Architecture (build_mpnn_model function)
|
|
8
|
+
# - BondMessagePassing, NormAggregation, FFN configuration
|
|
9
|
+
# - Regression uses output_transform (UnscaleTransform) for target scaling
|
|
10
|
+
#
|
|
11
|
+
# 2. Data Handling (create_molecule_datapoints function)
|
|
12
|
+
# - MoleculeDatapoint creation with x_d (extra descriptors)
|
|
13
|
+
# - RDKit validation of SMILES
|
|
14
|
+
#
|
|
15
|
+
# 3. Scaling (training section)
|
|
16
|
+
# - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
|
|
17
|
+
# - Targets (regression): normalize_targets() + UnscaleTransform in FFN
|
|
18
|
+
# - At inference: pass RAW features, transforms handle scaling automatically
|
|
19
|
+
#
|
|
20
|
+
# 4. Training Loop (search for "pl.Trainer")
|
|
21
|
+
# - PyTorch Lightning Trainer with ChemProp MPNN
|
|
22
|
+
#
|
|
23
|
+
# AWS/SageMaker boilerplate (can skip):
|
|
24
|
+
# - input_fn, output_fn, model_fn: SageMaker serving interface
|
|
25
|
+
# - argparse, file loading, S3 writes
|
|
26
|
+
# =============================
|
|
27
|
+
|
|
28
|
+
import os
|
|
29
|
+
import argparse
|
|
30
|
+
import json
|
|
31
|
+
from io import StringIO
|
|
32
|
+
|
|
33
|
+
import awswrangler as wr
|
|
34
|
+
import numpy as np
|
|
35
|
+
import pandas as pd
|
|
36
|
+
import torch
|
|
37
|
+
from lightning import pytorch as pl
|
|
38
|
+
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
|
|
39
|
+
from sklearn.preprocessing import LabelEncoder
|
|
40
|
+
from sklearn.metrics import (
|
|
41
|
+
mean_absolute_error,
|
|
42
|
+
median_absolute_error,
|
|
43
|
+
r2_score,
|
|
44
|
+
root_mean_squared_error,
|
|
45
|
+
precision_recall_fscore_support,
|
|
46
|
+
confusion_matrix,
|
|
47
|
+
)
|
|
48
|
+
from scipy.stats import spearmanr
|
|
49
|
+
import joblib
|
|
50
|
+
|
|
51
|
+
# ChemProp imports
|
|
52
|
+
from chemprop import data, models, nn
|
|
53
|
+
|
|
54
|
+
# Template Parameters
|
|
55
|
+
# "{{...}}" values are template placeholders. NOTE(review): code below iterates
# "feature_list" as a list and calls .get() on "hyperparameters", so the script
# generator presumably replaces the whole quoted placeholder with a Python
# literal (list / dict) rather than a string — confirm in the generator.
TEMPLATE_PARAMS = {
    "model_type": "{{model_type}}",  # compared against "classifier" in predict_fn and training
    "target": "{{target_column}}",  # name of the target column in the training data
    "feature_list": "{{feature_list}}",  # must include a 'smiles' column; other entries are extra features
    "id_column": "{{id_column}}",
    "model_metrics_s3_path": "{{model_metrics_s3_path}}",
    "hyperparameters": "{{hyperparameters}}",  # read via .get(key, default) during model building/training
}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
    """Stop training early when a DataFrame has no rows.

    The error text is printed to stdout before being raised, so it is visible
    in the job output even if the traceback is not.

    Args:
        df: DataFrame to validate.
        df_name: Label used to identify the DataFrame in the message.

    Raises:
        ValueError: If ``df`` is empty.
    """
    if not df.empty:
        return
    error_message = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
    print(error_message)
    raise ValueError(error_message)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def find_smiles_column(columns: list[str]) -> str:
    """Return the column whose name is "smiles", compared case-insensitively.

    The first match is returned with its original capitalization, so callers
    can index a DataFrame with it directly.

    Args:
        columns: Candidate column names.

    Returns:
        The matching column name.

    Raises:
        ValueError: If no name equals "smiles" ignoring case.
    """
    for column_name in columns:
        if column_name.lower() == "smiles":
            return column_name
    raise ValueError(
        "Column list must contain a 'smiles' column (case-insensitive)"
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
    """Expand the list-valued ``pred_proba`` column into one column per class.

    Each non-missing ``pred_proba`` entry is a sequence of probabilities ordered
    like ``class_labels``; it is spread into columns named ``"<label>_proba"``.
    Rows whose entry is missing -- ``None`` or a float NaN, e.g. rows that
    could not be predicted -- get NaN in every class column.

    Args:
        df: DataFrame containing a ``pred_proba`` column.
        class_labels: Class labels, in the same order as the probabilities.

    Returns:
        A new DataFrame with a fresh RangeIndex: the original columns minus
        ``pred_proba`` (and any pre-existing ``<label>_proba`` columns for these
        labels), followed by the per-class probability columns.

    Raises:
        ValueError: If ``df`` has no ``pred_proba`` column.
    """
    proba_column = "pred_proba"
    if proba_column not in df.columns:
        raise ValueError('DataFrame does not contain a "pred_proba" column')

    proba_splits = [f"{label}_proba" for label in class_labels]
    n_classes = len(class_labels)

    def _is_missing(val) -> bool:
        # Object columns may hold None *or* a float NaN for unpredicted rows;
        # passing a bare NaN to the DataFrame constructor would break it.
        return val is None or (isinstance(val, float) and np.isnan(val))

    proba_values = [
        [np.nan] * n_classes if _is_missing(val) else val for val in df[proba_column]
    ]
    proba_df = pd.DataFrame(proba_values, columns=proba_splits)

    df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
    df = df.reset_index(drop=True)
    return pd.concat([df, proba_df], axis=1)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def create_molecule_datapoints(
    smiles_list: list[str],
    targets: list[float] | None = None,
    extra_descriptors: np.ndarray | None = None,
) -> tuple[list[data.MoleculeDatapoint], list[int]]:
    """Create ChemProp MoleculeDatapoints from SMILES strings.

    Entries that are not strings, or that RDKit cannot parse, are skipped and
    counted; one warning with the total is printed at the end.

    Args:
        smiles_list: SMILES strings, one per molecule.
        targets: Optional target values aligned by position with ``smiles_list``
            (used for training).
        extra_descriptors: Optional extra features; row ``i`` is attached to
            molecule ``i``.

    Returns:
        Tuple of (datapoints for the valid molecules, their positions in
        ``smiles_list``). Callers use the positions to re-align other data.
    """
    from rdkit import Chem

    datapoints = []
    valid_indices = []
    invalid_count = 0

    for i, smi in enumerate(smiles_list):
        # Chem.MolFromSmiles raises (instead of returning None) on non-str input
        # such as None/NaN, so screen those out before asking RDKit.
        if not isinstance(smi, str) or Chem.MolFromSmiles(smi) is None:
            invalid_count += 1
            continue

        # Build datapoint with optional target and extra descriptors
        y = [targets[i]] if targets is not None else None
        x_d = extra_descriptors[i] if extra_descriptors is not None else None

        datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
        valid_indices.append(i)

    if invalid_count > 0:
        print(f"Warning: Skipped {invalid_count} invalid SMILES strings")

    return datapoints, valid_indices
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def build_mpnn_model(
    hyperparameters: dict,
    task: str = "regression",
    num_classes: int | None = None,
    n_extra_descriptors: int = 0,
    x_d_transform: nn.ScaleTransform | None = None,
    output_transform: nn.UnscaleTransform | None = None,
) -> models.MPNN:
    """Build an MPNN model with the specified hyperparameters.

    Args:
        hyperparameters: Model hyperparameters. Recognized keys (defaults):
            ``hidden_dim`` (300), ``depth`` (4), ``dropout`` (0.1, shared by the
            message-passing and FFN layers), ``ffn_hidden_dim`` (300),
            ``ffn_num_layers`` (2).
        task: "classification" builds a multiclass head; any other value builds
            a regression head.
        num_classes: Number of classes; required when task is "classification".
        n_extra_descriptors: Number of extra descriptor features (hybrid mode);
            the FFN input width is ``hidden_dim + n_extra_descriptors``.
        x_d_transform: Optional transform for extra descriptors (scaling).
        output_transform: Optional transform for regression output (unscaling
            targets); not used by the classification head.

    Returns:
        Configured MPNN model.

    Raises:
        ValueError: If task is "classification" and ``num_classes`` is None.
    """
    # Fail loudly instead of silently building a regression head, which would
    # train a regressor on label-encoded classes.
    if task == "classification" and num_classes is None:
        raise ValueError("num_classes must be provided when task='classification'")

    # Model hyperparameters with defaults
    hidden_dim = hyperparameters.get("hidden_dim", 300)
    depth = hyperparameters.get("depth", 4)
    dropout = hyperparameters.get("dropout", 0.1)
    ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
    ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)

    # Message passing component
    mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)

    # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
    agg = nn.NormAggregation()

    # FFN input_dim = message passing output + extra descriptors
    ffn_input_dim = hidden_dim + n_extra_descriptors

    if task == "classification":
        # Multi-class classification
        ffn = nn.MulticlassClassificationFFN(
            n_classes=num_classes,
            input_dim=ffn_input_dim,
            hidden_dim=ffn_hidden_dim,
            n_layers=ffn_num_layers,
            dropout=dropout,
        )
    else:
        # Regression with optional output transform to unscale predictions
        ffn = nn.RegressionFFN(
            input_dim=ffn_input_dim,
            hidden_dim=ffn_hidden_dim,
            n_layers=ffn_num_layers,
            dropout=dropout,
            output_transform=output_transform,
        )

    return models.MPNN(
        message_passing=mp,
        agg=agg,
        predictor=ffn,
        batch_norm=True,
        metrics=None,
        X_d_transform=x_d_transform,
    )
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def model_fn(model_dir: str) -> dict:
    """Load every ChemProp MPNN ensemble member stored in ``model_dir``.

    The ensemble size comes from ``ensemble_metadata.joblib`` when that file is
    present; otherwise a single model is assumed. Member ``k`` is read from
    ``chemprop_model_{k}.pt``, falling back to the legacy ``chemprop_model.pt``
    when the numbered file does not exist. Each model is switched to eval mode.

    Args:
        model_dir: Directory holding the saved model artifacts.

    Returns:
        Dict with ``"ensemble_models"`` (list of loaded MPNNs) and
        ``"n_ensemble"`` (number of members loaded).
    """
    metadata_file = os.path.join(model_dir, "ensemble_metadata.joblib")
    n_ensemble = (
        joblib.load(metadata_file)["n_ensemble"] if os.path.exists(metadata_file) else 1
    )

    def _member_path(member_idx: int) -> str:
        numbered = os.path.join(model_dir, f"chemprop_model_{member_idx}.pt")
        if os.path.exists(numbered):
            return numbered
        # Legacy layout: a single, un-numbered model file
        return os.path.join(model_dir, "chemprop_model.pt")

    loaded_members = []
    for member_idx in range(n_ensemble):
        member = models.MPNN.load_from_file(_member_path(member_idx))
        member.eval()
        loaded_members.append(member)

    print(f"Loaded {len(loaded_members)} ensemble model(s)")

    return {"ensemble_models": loaded_members, "n_ensemble": n_ensemble}
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def input_fn(input_data, content_type: str) -> pd.DataFrame:
    """Deserialize a request payload into a DataFrame.

    Args:
        input_data: Payload as ``str`` or UTF-8 encoded ``bytes``.
        content_type: MIME type; must contain "text/csv" or "application/json".

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the payload is empty or the content type is unsupported.
    """
    if not input_data:
        raise ValueError("Empty input data is not supported!")

    payload = input_data.decode("utf-8") if isinstance(input_data, bytes) else input_data

    if "text/csv" in content_type:
        return pd.read_csv(StringIO(payload))
    if "application/json" in content_type:
        return pd.DataFrame(json.loads(payload))
    raise ValueError(f"{content_type} not supported!")
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
    """Serialize predictions as CSV or JSON records.

    Args:
        output_df: DataFrame to serialize.
        accept_type: Requested MIME type; must contain "text/csv" or
            "application/json".

    Returns:
        Tuple of (serialized body, content type).

    Raises:
        RuntimeError: If ``accept_type`` is not supported.
    """
    if "text/csv" in accept_type:
        # Missing values are written as the literal string "N/A" in CSV output
        return output_df.fillna("N/A").to_csv(index=False), "text/csv"
    if "application/json" in accept_type:
        return output_df.to_json(orient="records"), "application/json"
    raise RuntimeError(
        f"{accept_type} accept type is not supported by this script."
    )
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
    """Make predictions with the ChemProp MPNN ensemble.

    Rows whose SMILES is missing, blank, non-string or rejected by RDKit keep a
    NaN/None prediction; all other rows receive the ensemble-mean prediction.

    Args:
        df: Input DataFrame containing a SMILES column (name matched
            case-insensitively) and, in hybrid mode, the extra feature columns.
        model_dict: Output of ``model_fn``; ``"ensemble_models"`` is used.

    Returns:
        ``df`` with a ``prediction`` column added (``df`` is modified in place).
        Regression also adds ``prediction_std`` (std across ensemble members;
        0 for a single model). Classifiers with a saved label encoder get
        decoded labels and, for multi-output heads, one ``<label>_proba`` column
        per class; that path returns a new frame with a reset index.
    """
    model_type = TEMPLATE_PARAMS["model_type"]
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
    ensemble_models = model_dict["ensemble_models"]

    # Optional artifacts: label encoder (classification) and feature metadata
    # (hybrid mode: extra column names and training means for NaN filling)
    label_encoder = _load_optional_joblib(os.path.join(model_dir, "label_encoder.joblib"))
    feature_metadata = _load_optional_joblib(
        os.path.join(model_dir, "feature_metadata.joblib")
    )
    if feature_metadata is not None:
        print(
            f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
        )

    smiles_column = find_smiles_column(df.columns.tolist())

    # Cheap pre-filter: keep non-blank strings (stripped); RDKit validates later.
    # isinstance() is checked first so values like pd.NA are never truth-tested.
    valid_indices = []
    valid_smiles = []
    for i, smi in enumerate(df[smiles_column].tolist()):
        if isinstance(smi, str) and smi.strip():
            valid_indices.append(i)
            valid_smiles.append(smi.strip())
    print(f"Valid SMILES: {len(valid_indices)} / {len(df)}")

    # Initialize prediction column (use object dtype for classifiers to avoid FutureWarning)
    if model_type == "classifier":
        df["prediction"] = pd.Series([None] * len(df), index=df.index, dtype=object)
    else:
        # Regression (includes uq_regressor)
        df["prediction"] = np.nan
        df["prediction_std"] = np.nan

    if not valid_indices:
        print("Warning: No valid SMILES to predict on")
        return df

    # RAW (unscaled) features: the model's X_d_transform handles scaling
    extra_features = None
    if feature_metadata is not None:
        extra_features = _build_extra_features(df, valid_indices, feature_metadata)

    datapoints, rdkit_valid_indices = create_molecule_datapoints(
        valid_smiles, extra_descriptors=extra_features
    )
    if len(datapoints) == 0:
        print("Warning: No valid SMILES after RDKit validation")
        return df

    dataloader = data.build_dataloader(data.MoleculeDataset(datapoints), shuffle=False)
    preds, preds_std = _ensemble_predict(ensemble_models, dataloader)
    print(f"Inference: Ensemble predictions shape: {preds.shape}")

    # Positions (within df) of the rows that actually received a prediction
    pred_positions = np.asarray(valid_indices)[rdkit_valid_indices]
    pred_mask = np.zeros(len(df), dtype=bool)
    pred_mask[pred_positions] = True

    if model_type == "classifier" and label_encoder is not None:
        if preds.ndim == 2 and preds.shape[1] > 1:
            # Multi-class: preds are probabilities (averaged across ensemble)
            class_preds = np.argmax(preds, axis=1)
            df.loc[pred_mask, "prediction"] = label_encoder.inverse_transform(class_preds)

            # Build the list-valued column explicitly: assigning a list of lists
            # through a boolean .loc mask lets NumPy coerce it to a 2-D array,
            # which cannot be stored into a 1-D object column.
            proba_values = [None] * len(df)
            for pos, row_proba in zip(pred_positions, preds):
                proba_values[pos] = row_proba.tolist()
            df["pred_proba"] = pd.Series(proba_values, index=df.index, dtype=object)
            df = expand_proba_column(df, label_encoder.classes_)
        else:
            # Binary or single output
            class_preds = (preds.flatten() > 0.5).astype(int)
            df.loc[pred_mask, "prediction"] = label_encoder.inverse_transform(class_preds)
    else:
        # Regression: direct predictions
        df.loc[pred_mask, "prediction"] = preds.flatten()
        df.loc[pred_mask, "prediction_std"] = preds_std.flatten()

    return df


def _load_optional_joblib(path: str):
    """Return ``joblib.load(path)`` if ``path`` exists, otherwise ``None``."""
    return joblib.load(path) if os.path.exists(path) else None


def _build_extra_features(
    df: pd.DataFrame, row_positions: list[int], feature_metadata: dict
) -> np.ndarray:
    """Assemble the raw (unscaled) extra-feature matrix for ``row_positions``.

    Missing values (including non-numeric entries, which are coerced to NaN)
    and columns absent from ``df`` are filled with the training column means
    in ``feature_metadata["col_means"]``.
    """
    extra_feature_cols = feature_metadata["extra_feature_cols"]
    col_means = np.array(feature_metadata["col_means"])

    missing_cols = [col for col in extra_feature_cols if col not in df.columns]
    if missing_cols:
        print(
            f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
        )

    rows = df.iloc[row_positions]  # select the rows once rather than per column
    features = np.zeros((len(row_positions), len(extra_feature_cols)), dtype=np.float32)
    for j, col in enumerate(extra_feature_cols):
        if col not in df.columns:
            # Column missing, use training mean
            features[:, j] = col_means[j]
            continue
        values = pd.to_numeric(rows[col], errors="coerce").to_numpy(
            dtype=np.float32, na_value=np.nan
        )
        values[np.isnan(values)] = col_means[j]
        features[:, j] = values
    return features


def _ensemble_predict(ensemble_models: list, dataloader) -> tuple[np.ndarray, np.ndarray]:
    """Predict with every ensemble member; return (mean, std) across members.

    The std is all zeros when the ensemble has a single member.
    """
    trainer = pl.Trainer(
        accelerator="auto",
        logger=False,
        enable_progress_bar=False,
    )

    member_preds = []
    for member in ensemble_models:
        with torch.inference_mode():
            batches = trainer.predict(member, dataloader)
        # .cpu() is a no-op for CPU tensors and keeps .numpy() safe on GPU
        member_out = np.concatenate([b.detach().cpu().numpy() for b in batches], axis=0)
        # Squeeze a singleton middle axis: (n, 1, k) -> (n, k)
        if member_out.ndim == 3 and member_out.shape[1] == 1:
            member_out = member_out.squeeze(axis=1)
        member_preds.append(member_out)

    stacked = np.stack(member_preds, axis=0)
    return stacked.mean(axis=0), stacked.std(axis=0)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
if __name__ == "__main__":
|
|
453
|
+
"""Training script for ChemProp MPNN model"""
|
|
454
|
+
|
|
455
|
+
# Template Parameters
|
|
456
|
+
target = TEMPLATE_PARAMS["target"]
|
|
457
|
+
model_type = TEMPLATE_PARAMS["model_type"]
|
|
458
|
+
feature_list = TEMPLATE_PARAMS["feature_list"]
|
|
459
|
+
id_column = TEMPLATE_PARAMS["id_column"]
|
|
460
|
+
model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
|
|
461
|
+
hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
|
|
462
|
+
|
|
463
|
+
# Get the SMILES column name from feature_list (user defines this, so we use their exact name)
|
|
464
|
+
smiles_column = find_smiles_column(feature_list)
|
|
465
|
+
extra_feature_cols = [f for f in feature_list if f != smiles_column]
|
|
466
|
+
use_extra_features = len(extra_feature_cols) > 0
|
|
467
|
+
print(f"Feature List: {feature_list}")
|
|
468
|
+
print(f"SMILES Column: {smiles_column}")
|
|
469
|
+
print(
|
|
470
|
+
f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
|
|
471
|
+
)
|
|
472
|
+
|
|
473
|
+
# Script arguments for input/output directories
|
|
474
|
+
parser = argparse.ArgumentParser()
|
|
475
|
+
parser.add_argument(
|
|
476
|
+
"--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
|
|
477
|
+
)
|
|
478
|
+
parser.add_argument(
|
|
479
|
+
"--train",
|
|
480
|
+
type=str,
|
|
481
|
+
default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
|
|
482
|
+
)
|
|
483
|
+
parser.add_argument(
|
|
484
|
+
"--output-data-dir",
|
|
485
|
+
type=str,
|
|
486
|
+
default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
|
|
487
|
+
)
|
|
488
|
+
args = parser.parse_args()
|
|
489
|
+
|
|
490
|
+
# Read the training data
|
|
491
|
+
training_files = [
|
|
492
|
+
os.path.join(args.train, f)
|
|
493
|
+
for f in os.listdir(args.train)
|
|
494
|
+
if f.endswith(".csv")
|
|
495
|
+
]
|
|
496
|
+
print(f"Training Files: {training_files}")
|
|
497
|
+
|
|
498
|
+
all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
|
|
499
|
+
print(f"All Data Shape: {all_df.shape}")
|
|
500
|
+
|
|
501
|
+
check_dataframe(all_df, "training_df")
|
|
502
|
+
|
|
503
|
+
# Drop rows with missing SMILES or target values
|
|
504
|
+
initial_count = len(all_df)
|
|
505
|
+
all_df = all_df.dropna(subset=[smiles_column, target])
|
|
506
|
+
dropped = initial_count - len(all_df)
|
|
507
|
+
if dropped > 0:
|
|
508
|
+
print(f"Dropped {dropped} rows with missing SMILES or target values")
|
|
509
|
+
|
|
510
|
+
print(f"Target: {target}")
|
|
511
|
+
print(f"Data Shape after cleaning: {all_df.shape}")
|
|
512
|
+
|
|
513
|
+
# Set up label encoder for classification
|
|
514
|
+
label_encoder = None
|
|
515
|
+
if model_type == "classifier":
|
|
516
|
+
label_encoder = LabelEncoder()
|
|
517
|
+
all_df[target] = label_encoder.fit_transform(all_df[target])
|
|
518
|
+
num_classes = len(label_encoder.classes_)
|
|
519
|
+
print(
|
|
520
|
+
f"Classification task with {num_classes} classes: {label_encoder.classes_}"
|
|
521
|
+
)
|
|
522
|
+
else:
|
|
523
|
+
num_classes = None
|
|
524
|
+
|
|
525
|
+
# Model and training configuration
|
|
526
|
+
print(f"Hyperparameters: {hyperparameters}")
|
|
527
|
+
task = "classification" if model_type == "classifier" else "regression"
|
|
528
|
+
n_extra = len(extra_feature_cols) if use_extra_features else 0
|
|
529
|
+
max_epochs = hyperparameters.get("max_epochs", 200)
|
|
530
|
+
patience = hyperparameters.get("patience", 20)
|
|
531
|
+
n_folds = hyperparameters.get("n_folds", 5) # Number of CV folds (default: 5)
|
|
532
|
+
batch_size = hyperparameters.get("batch_size", min(64, max(16, len(all_df) // 16)))
|
|
533
|
+
|
|
534
|
+
# Check extra feature columns exist
|
|
535
|
+
if use_extra_features:
|
|
536
|
+
missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
|
|
537
|
+
if missing_cols:
|
|
538
|
+
raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")
|
|
539
|
+
|
|
540
|
+
# =========================================================================
|
|
541
|
+
# UNIFIED TRAINING: Works for n_folds=1 (single model) or n_folds>1 (K-fold CV)
|
|
542
|
+
# =========================================================================
|
|
543
|
+
print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")
|
|
544
|
+
|
|
545
|
+
# Prepare extra features and validate SMILES upfront
|
|
546
|
+
all_extra_features = None
|
|
547
|
+
col_means = None
|
|
548
|
+
if use_extra_features:
|
|
549
|
+
all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
|
|
550
|
+
col_means = np.nanmean(all_extra_features, axis=0)
|
|
551
|
+
for i in range(all_extra_features.shape[1]):
|
|
552
|
+
all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
|
|
553
|
+
|
|
554
|
+
# Filter invalid SMILES from the full dataset
|
|
555
|
+
_, valid_indices = create_molecule_datapoints(
|
|
556
|
+
all_df[smiles_column].tolist(), all_df[target].tolist(), all_extra_features
|
|
557
|
+
)
|
|
558
|
+
all_df = all_df.iloc[valid_indices].reset_index(drop=True)
|
|
559
|
+
if all_extra_features is not None:
|
|
560
|
+
all_extra_features = all_extra_features[valid_indices]
|
|
561
|
+
print(f"Data after SMILES validation: {all_df.shape}")
|
|
562
|
+
|
|
563
|
+
# Create fold splits
|
|
564
|
+
if n_folds == 1:
|
|
565
|
+
# Single fold: use train/val split from "training" column or random split
|
|
566
|
+
if "training" in all_df.columns:
|
|
567
|
+
print("Found training column, splitting data based on training column")
|
|
568
|
+
train_idx = np.where(all_df["training"])[0]
|
|
569
|
+
val_idx = np.where(~all_df["training"])[0]
|
|
570
|
+
else:
|
|
571
|
+
print("WARNING: No training column found, splitting data with random 80/20 split")
|
|
572
|
+
indices = np.arange(len(all_df))
|
|
573
|
+
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
|
|
574
|
+
folds = [(train_idx, val_idx)]
|
|
575
|
+
else:
|
|
576
|
+
# K-Fold CV
|
|
577
|
+
if model_type == "classifier":
|
|
578
|
+
kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
|
|
579
|
+
split_target = all_df[target]
|
|
580
|
+
else:
|
|
581
|
+
kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
|
|
582
|
+
split_target = None
|
|
583
|
+
folds = list(kfold.split(all_df, split_target))
|
|
584
|
+
|
|
585
|
+
# Initialize storage for out-of-fold predictions
|
|
586
|
+
oof_predictions = np.full(len(all_df), np.nan, dtype=np.float64)
|
|
587
|
+
if model_type == "classifier" and num_classes and num_classes > 1:
|
|
588
|
+
oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
|
|
589
|
+
else:
|
|
590
|
+
oof_proba = None
|
|
591
|
+
|
|
592
|
+
ensemble_models = []
|
|
593
|
+
|
|
594
|
+
for fold_idx, (train_idx, val_idx) in enumerate(folds):
|
|
595
|
+
print(f"\n{'='*50}")
|
|
596
|
+
print(f"Training Fold {fold_idx + 1}/{len(folds)}")
|
|
597
|
+
print(f"{'='*50}")
|
|
598
|
+
|
|
599
|
+
# Split data for this fold
|
|
600
|
+
df_train = all_df.iloc[train_idx].reset_index(drop=True)
|
|
601
|
+
df_val = all_df.iloc[val_idx].reset_index(drop=True)
|
|
602
|
+
|
|
603
|
+
train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
|
|
604
|
+
val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
|
|
605
|
+
|
|
606
|
+
print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
|
|
607
|
+
|
|
608
|
+
# Create ChemProp datasets for this fold
|
|
609
|
+
train_datapoints, _ = create_molecule_datapoints(
|
|
610
|
+
df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra
|
|
611
|
+
)
|
|
612
|
+
val_datapoints, _ = create_molecule_datapoints(
|
|
613
|
+
df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra
|
|
614
|
+
)
|
|
615
|
+
|
|
616
|
+
train_dataset = data.MoleculeDataset(train_datapoints)
|
|
617
|
+
val_dataset = data.MoleculeDataset(val_datapoints)
|
|
618
|
+
|
|
619
|
+
# Save raw val features for prediction
|
|
620
|
+
val_extra_raw = val_extra.copy() if val_extra is not None else None
|
|
621
|
+
|
|
622
|
+
# Scale features and targets for this fold
|
|
623
|
+
x_d_transform = None
|
|
624
|
+
if use_extra_features:
|
|
625
|
+
feature_scaler = train_dataset.normalize_inputs("X_d")
|
|
626
|
+
val_dataset.normalize_inputs("X_d", feature_scaler)
|
|
627
|
+
x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
|
|
628
|
+
|
|
629
|
+
output_transform = None
|
|
630
|
+
if model_type in ["regressor", "uq_regressor"]:
|
|
631
|
+
target_scaler = train_dataset.normalize_targets()
|
|
632
|
+
val_dataset.normalize_targets(target_scaler)
|
|
633
|
+
output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
|
|
634
|
+
|
|
635
|
+
train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
|
|
636
|
+
val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
|
|
637
|
+
|
|
638
|
+
# Build and train model for this fold
|
|
639
|
+
pl.seed_everything(42 + fold_idx)
|
|
640
|
+
mpnn = build_mpnn_model(
|
|
641
|
+
hyperparameters, task=task, num_classes=num_classes,
|
|
642
|
+
n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
callbacks = [
|
|
646
|
+
pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
|
|
647
|
+
pl.callbacks.ModelCheckpoint(
|
|
648
|
+
dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
|
|
649
|
+
monitor="val_loss", mode="min", save_top_k=1,
|
|
650
|
+
),
|
|
651
|
+
]
|
|
652
|
+
|
|
653
|
+
trainer = pl.Trainer(
|
|
654
|
+
accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
|
|
655
|
+
logger=False, enable_progress_bar=True,
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
trainer.fit(mpnn, train_loader, val_loader)
|
|
659
|
+
|
|
660
|
+
if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
|
|
661
|
+
checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
|
|
662
|
+
mpnn.load_state_dict(checkpoint["state_dict"])
|
|
663
|
+
|
|
664
|
+
mpnn.eval()
|
|
665
|
+
ensemble_models.append(mpnn)
|
|
666
|
+
|
|
667
|
+
# Make out-of-fold predictions using raw features
|
|
668
|
+
val_datapoints_raw, _ = create_molecule_datapoints(
|
|
669
|
+
df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_raw
|
|
670
|
+
)
|
|
671
|
+
val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
|
|
672
|
+
val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
|
|
673
|
+
|
|
674
|
+
with torch.inference_mode():
|
|
675
|
+
fold_predictions = trainer.predict(mpnn, val_loader_pred)
|
|
676
|
+
fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
|
|
677
|
+
if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
|
|
678
|
+
fold_preds = fold_preds.squeeze(axis=1)
|
|
679
|
+
|
|
680
|
+
# Store out-of-fold predictions
|
|
681
|
+
if model_type == "classifier" and fold_preds.ndim == 2:
|
|
682
|
+
oof_predictions[val_idx] = np.argmax(fold_preds, axis=1)
|
|
683
|
+
if oof_proba is not None:
|
|
684
|
+
oof_proba[val_idx] = fold_preds
|
|
685
|
+
else:
|
|
686
|
+
oof_predictions[val_idx] = fold_preds.flatten()
|
|
687
|
+
|
|
688
|
+
print(f"Fold {fold_idx + 1} complete!")
|
|
689
|
+
|
|
690
|
+
print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
|
|
691
|
+
|
|
692
|
+
# Use out-of-fold predictions for metrics
|
|
693
|
+
# For n_folds=1, we only have predictions for val_idx, so filter to those rows
|
|
694
|
+
if n_folds == 1:
|
|
695
|
+
val_mask = ~np.isnan(oof_predictions)
|
|
696
|
+
preds = oof_predictions[val_mask]
|
|
697
|
+
df_val = all_df[val_mask].copy()
|
|
698
|
+
y_validate = df_val[target].values
|
|
699
|
+
if oof_proba is not None:
|
|
700
|
+
oof_proba = oof_proba[val_mask]
|
|
701
|
+
val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
|
|
702
|
+
else:
|
|
703
|
+
preds = oof_predictions
|
|
704
|
+
df_val = all_df.copy()
|
|
705
|
+
y_validate = all_df[target].values
|
|
706
|
+
val_extra_features = all_extra_features
|
|
707
|
+
|
|
708
|
+
# Compute prediction_std by running all ensemble models on validation data
|
|
709
|
+
# For n_folds=1, std will be 0 (only one model). For n_folds>1, std shows ensemble disagreement.
|
|
710
|
+
preds_std = None
|
|
711
|
+
if model_type in ["regressor", "uq_regressor"] and len(ensemble_models) > 0:
|
|
712
|
+
print("Computing prediction_std from ensemble predictions on validation data...")
|
|
713
|
+
val_datapoints_for_std, _ = create_molecule_datapoints(
|
|
714
|
+
df_val[smiles_column].tolist(),
|
|
715
|
+
df_val[target].tolist(),
|
|
716
|
+
val_extra_features
|
|
717
|
+
)
|
|
718
|
+
val_dataset_for_std = data.MoleculeDataset(val_datapoints_for_std)
|
|
719
|
+
val_loader_for_std = data.build_dataloader(val_dataset_for_std, batch_size=batch_size, shuffle=False)
|
|
720
|
+
|
|
721
|
+
all_ensemble_preds_for_std = []
|
|
722
|
+
trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
|
|
723
|
+
for ens_model in ensemble_models:
|
|
724
|
+
with torch.inference_mode():
|
|
725
|
+
ens_preds = trainer_pred.predict(ens_model, val_loader_for_std)
|
|
726
|
+
ens_preds = np.concatenate([p.numpy() for p in ens_preds], axis=0)
|
|
727
|
+
if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
|
|
728
|
+
ens_preds = ens_preds.squeeze(axis=1)
|
|
729
|
+
all_ensemble_preds_for_std.append(ens_preds.flatten())
|
|
730
|
+
|
|
731
|
+
ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
|
|
732
|
+
preds_std = np.std(ensemble_preds_stacked, axis=0)
|
|
733
|
+
print(f"Ensemble prediction_std - mean: {np.mean(preds_std):.4f}, max: {np.max(preds_std):.4f}")
|
|
734
|
+
|
|
735
|
+
if model_type == "classifier":
|
|
736
|
+
# Classification metrics - preds contains class indices from OOF predictions
|
|
737
|
+
class_preds = preds.astype(int)
|
|
738
|
+
has_proba = oof_proba is not None
|
|
739
|
+
|
|
740
|
+
print(f"class_preds shape: {class_preds.shape}")
|
|
741
|
+
|
|
742
|
+
# Decode labels for metrics
|
|
743
|
+
y_validate_decoded = label_encoder.inverse_transform(y_validate.astype(int))
|
|
744
|
+
preds_decoded = label_encoder.inverse_transform(class_preds)
|
|
745
|
+
|
|
746
|
+
# Calculate metrics
|
|
747
|
+
label_names = label_encoder.classes_
|
|
748
|
+
scores = precision_recall_fscore_support(
|
|
749
|
+
y_validate_decoded, preds_decoded, average=None, labels=label_names
|
|
750
|
+
)
|
|
751
|
+
|
|
752
|
+
score_df = pd.DataFrame(
|
|
753
|
+
{
|
|
754
|
+
target: label_names,
|
|
755
|
+
"precision": scores[0],
|
|
756
|
+
"recall": scores[1],
|
|
757
|
+
"f1": scores[2],
|
|
758
|
+
"support": scores[3],
|
|
759
|
+
}
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
# Output metrics per class
|
|
763
|
+
metrics = ["precision", "recall", "f1", "support"]
|
|
764
|
+
for t in label_names:
|
|
765
|
+
for m in metrics:
|
|
766
|
+
value = score_df.loc[score_df[target] == t, m].iloc[0]
|
|
767
|
+
print(f"Metrics:{t}:{m} {value}")
|
|
768
|
+
|
|
769
|
+
# Confusion matrix
|
|
770
|
+
conf_mtx = confusion_matrix(
|
|
771
|
+
y_validate_decoded, preds_decoded, labels=label_names
|
|
772
|
+
)
|
|
773
|
+
for i, row_name in enumerate(label_names):
|
|
774
|
+
for j, col_name in enumerate(label_names):
|
|
775
|
+
value = conf_mtx[i, j]
|
|
776
|
+
print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
|
|
777
|
+
|
|
778
|
+
# Save validation predictions
|
|
779
|
+
df_val = df_val.copy()
|
|
780
|
+
df_val["prediction"] = preds_decoded
|
|
781
|
+
if has_proba and oof_proba is not None:
|
|
782
|
+
df_val["pred_proba"] = [p.tolist() for p in oof_proba]
|
|
783
|
+
df_val = expand_proba_column(df_val, label_names)
|
|
784
|
+
|
|
785
|
+
else:
|
|
786
|
+
# Regression metrics
|
|
787
|
+
preds_flat = preds.flatten()
|
|
788
|
+
rmse = root_mean_squared_error(y_validate, preds_flat)
|
|
789
|
+
mae = mean_absolute_error(y_validate, preds_flat)
|
|
790
|
+
medae = median_absolute_error(y_validate, preds_flat)
|
|
791
|
+
r2 = r2_score(y_validate, preds_flat)
|
|
792
|
+
spearman_corr = spearmanr(y_validate, preds_flat).correlation
|
|
793
|
+
support = len(df_val)
|
|
794
|
+
print(f"rmse: {rmse:.3f}")
|
|
795
|
+
print(f"mae: {mae:.3f}")
|
|
796
|
+
print(f"medae: {medae:.3f}")
|
|
797
|
+
print(f"r2: {r2:.3f}")
|
|
798
|
+
print(f"spearmanr: {spearman_corr:.3f}")
|
|
799
|
+
print(f"support: {support}")
|
|
800
|
+
|
|
801
|
+
df_val = df_val.copy()
|
|
802
|
+
df_val["prediction"] = preds_flat
|
|
803
|
+
|
|
804
|
+
# Add prediction_std (always present for regressors, 0 for single model)
|
|
805
|
+
if preds_std is not None:
|
|
806
|
+
df_val["prediction_std"] = preds_std.flatten()
|
|
807
|
+
else:
|
|
808
|
+
df_val["prediction_std"] = 0.0
|
|
809
|
+
print(f"Ensemble std - mean: {df_val['prediction_std'].mean():.4f}, max: {df_val['prediction_std'].max():.4f}")
|
|
810
|
+
|
|
811
|
+
# Save validation predictions to S3
|
|
812
|
+
# Include id_column if it exists in df_val
|
|
813
|
+
output_columns = []
|
|
814
|
+
if id_column in df_val.columns:
|
|
815
|
+
output_columns.append(id_column)
|
|
816
|
+
output_columns += [target, "prediction"]
|
|
817
|
+
if "prediction_std" in df_val.columns:
|
|
818
|
+
output_columns.append("prediction_std")
|
|
819
|
+
output_columns += [col for col in df_val.columns if col.endswith("_proba")]
|
|
820
|
+
wr.s3.to_csv(
|
|
821
|
+
df_val[output_columns],
|
|
822
|
+
path=f"{model_metrics_s3_path}/validation_predictions.csv",
|
|
823
|
+
index=False,
|
|
824
|
+
)
|
|
825
|
+
|
|
826
|
+
# Save ensemble models (n_folds models if CV, 1 model otherwise)
|
|
827
|
+
for model_idx, ens_model in enumerate(ensemble_models):
|
|
828
|
+
model_path = os.path.join(args.model_dir, f"chemprop_model_{model_idx}.pt")
|
|
829
|
+
models.save_model(model_path, ens_model)
|
|
830
|
+
print(f"Saved model {model_idx + 1} to {model_path}")
|
|
831
|
+
|
|
832
|
+
# Save ensemble metadata (n_ensemble = number of models for inference)
|
|
833
|
+
n_ensemble = len(ensemble_models)
|
|
834
|
+
ensemble_metadata = {"n_ensemble": n_ensemble, "n_folds": n_folds}
|
|
835
|
+
joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
|
|
836
|
+
print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds})")
|
|
837
|
+
|
|
838
|
+
# Save label encoder if classification
|
|
839
|
+
if label_encoder is not None:
|
|
840
|
+
joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
|
|
841
|
+
|
|
842
|
+
# Save extra feature metadata for inference (hybrid mode)
|
|
843
|
+
# Note: We don't need to save the scaler - X_d_transform is embedded in the model
|
|
844
|
+
if use_extra_features:
|
|
845
|
+
feature_metadata = {
|
|
846
|
+
"extra_feature_cols": extra_feature_cols,
|
|
847
|
+
"col_means": col_means.tolist(), # Unscaled means for NaN imputation
|
|
848
|
+
}
|
|
849
|
+
joblib.dump(
|
|
850
|
+
feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
|
|
851
|
+
)
|
|
852
|
+
print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
|