workbench 0.8.198-py3-none-any.whl → 0.8.203-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +11 -4
- workbench/api/__init__.py +2 -1
- workbench/api/df_store.py +17 -108
- workbench/api/feature_set.py +48 -11
- workbench/api/model.py +1 -1
- workbench/api/parameter_store.py +3 -52
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +261 -78
- workbench/core/artifacts/feature_set_core.py +69 -1
- workbench/core/artifacts/model_core.py +48 -14
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/transforms/features_to_model/features_to_model.py +50 -33
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
- workbench/core/views/view.py +2 -2
- workbench/model_scripts/chemprop/chemprop.template +933 -0
- workbench/model_scripts/chemprop/generated_model_script.py +933 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/proximity.py +11 -4
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +11 -5
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +11 -5
- workbench/model_scripts/custom_models/uq_models/ngboost.template +11 -5
- workbench/model_scripts/custom_models/uq_models/proximity.py +11 -4
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +11 -5
- workbench/model_scripts/pytorch_model/generated_model_script.py +365 -173
- workbench/model_scripts/pytorch_model/pytorch.template +362 -170
- workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
- workbench/model_scripts/script_generation.py +10 -7
- workbench/model_scripts/uq_models/generated_model_script.py +43 -27
- workbench/model_scripts/uq_models/mapie.template +40 -24
- workbench/model_scripts/xgb_model/generated_model_script.py +36 -7
- workbench/model_scripts/xgb_model/xgb_model.template +36 -7
- workbench/repl/workbench_shell.py +14 -5
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
- workbench/utils/chemprop_utils.py +761 -0
- workbench/utils/pytorch_utils.py +527 -0
- workbench/utils/xgboost_model_utils.py +10 -5
- workbench/web_interface/components/model_plot.py +7 -1
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/METADATA +3 -3
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/RECORD +49 -43
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/entry_points.txt +2 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
- workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/WHEEL +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,933 @@
# ChemProp Model Template for Workbench
# Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
#
# === CHEMPROP REVIEW NOTES ===
# This script runs on AWS SageMaker. Key areas for ChemProp review:
#
# 1. Model Architecture (build_mpnn_model function)
#    - BondMessagePassing, NormAggregation, FFN configuration
#    - Regression uses output_transform (UnscaleTransform) for target scaling
#
# 2. Data Handling (create_molecule_datapoints function)
#    - MoleculeDatapoint creation with x_d (extra descriptors)
#    - RDKit validation of SMILES
#
# 3. Scaling (training section)
#    - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
#    - Targets (regression): normalize_targets() + UnscaleTransform in FFN
#    - At inference: pass RAW features, transforms handle scaling automatically
#
# 4. Training Loop (search for "pl.Trainer")
#    - PyTorch Lightning Trainer with ChemProp MPNN
#
# AWS/SageMaker boilerplate (can skip):
#    - input_fn, output_fn, model_fn: SageMaker serving interface
#    - argparse, file loading, S3 writes
# =============================

import os
import argparse
import json
from io import StringIO

import awswrangler as wr
import numpy as np
import pandas as pd
import torch
from lightning import pytorch as pl
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    mean_absolute_error,
    median_absolute_error,
    r2_score,
    root_mean_squared_error,
    precision_recall_fscore_support,
    confusion_matrix,
)
from scipy.stats import spearmanr
import joblib

# ChemProp imports
from chemprop import data, models, nn

# Template Parameters
TEMPLATE_PARAMS = {
    "model_type": "{{model_type}}",
    "targets": "{{target_column}}",  # List of target columns (single or multi-task)
    "feature_list": "{{feature_list}}",
    "id_column": "{{id_column}}",
    "model_metrics_s3_path": "{{model_metrics_s3_path}}",
    "hyperparameters": "{{hyperparameters}}",
}
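
The quoted {{...}} placeholders are filled in by the script generation step (script_generation.py, also touched in this diff), and the __main__ block below requires "targets" to be a Python list and calls .get() on "hyperparameters", which implies the rendered dict holds real Python literals rather than strings. A hypothetical rendered version, purely for illustration (values are not taken from the package):

TEMPLATE_PARAMS = {
    "model_type": "regressor",
    "targets": ["solubility"],
    "feature_list": ["smiles", "mol_weight", "logp"],
    "id_column": "compound_id",
    "model_metrics_s3_path": "s3://example-bucket/models/chemprop/metrics",
    "hyperparameters": {"hidden_dim": 700, "depth": 6, "n_folds": 5},
}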


def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
    """Check if the provided dataframe is empty and raise an exception if it is."""
    if df.empty:
        msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
        print(msg)
        raise ValueError(msg)


def find_smiles_column(columns: list[str]) -> str:
    """Find the SMILES column name from a list (case-insensitive match for 'smiles')."""
    smiles_column = next((col for col in columns if col.lower() == "smiles"), None)
    if smiles_column is None:
        raise ValueError(
            "Column list must contain a 'smiles' column (case-insensitive)"
        )
    return smiles_column


def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
    """Expands a column containing a list of probabilities into separate columns.

    Handles None values for rows where predictions couldn't be made.
    """
    proba_column = "pred_proba"
    if proba_column not in df.columns:
        raise ValueError('DataFrame does not contain a "pred_proba" column')

    proba_splits = [f"{label}_proba" for label in class_labels]
    n_classes = len(class_labels)

    # Handle None values by replacing with list of NaNs
    proba_values = []
    for val in df[proba_column]:
        if val is None:
            proba_values.append([np.nan] * n_classes)
        else:
            proba_values.append(val)

    proba_df = pd.DataFrame(proba_values, columns=proba_splits)

    df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
    df = df.reset_index(drop=True)
    df = pd.concat([df, proba_df], axis=1)
    return df


def create_molecule_datapoints(
    smiles_list: list[str],
    targets: list[float] | np.ndarray | None = None,
    extra_descriptors: np.ndarray | None = None,
) -> tuple[list[data.MoleculeDatapoint], list[int]]:
    """Create ChemProp MoleculeDatapoints from SMILES strings.

    Args:
        smiles_list: List of SMILES strings
        targets: Optional target values as 2D array (n_samples, n_targets). NaN allowed for missing targets.
        extra_descriptors: Optional array of extra features (n_samples, n_features)

    Returns:
        Tuple of (list of MoleculeDatapoint objects, list of valid indices)
    """
    from rdkit import Chem

    datapoints = []
    valid_indices = []
    invalid_count = 0

    # Convert targets to 2D array if provided
    if targets is not None:
        targets = np.atleast_2d(np.array(targets))
        if targets.shape[0] == 1 and len(smiles_list) > 1:
            targets = targets.T  # Shape was (1, n_samples), transpose to (n_samples, 1)

    for i, smi in enumerate(smiles_list):
        # Validate SMILES with RDKit first
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            invalid_count += 1
            continue

        # Build datapoint with optional target(s) and extra descriptors
        # For multi-task, y is a list of values (can include NaN for missing targets)
        y = targets[i].tolist() if targets is not None else None
        x_d = extra_descriptors[i] if extra_descriptors is not None else None

        dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
        datapoints.append(dp)
        valid_indices.append(i)

    if invalid_count > 0:
        print(f"Warning: Skipped {invalid_count} invalid SMILES strings")

    return datapoints, valid_indices
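
As an illustration only (not part of the added file), the helper above plugs directly into the ChemProp dataset and dataloader classes used throughout the rest of the script; the SMILES strings and target values here are toy data:

toy_smiles = ["CCO", "c1ccccc1O", "this-is-not-a-smiles"]
toy_targets = np.array([[-0.77], [-0.62], [0.0]])  # shape (n_samples, n_targets)

toy_datapoints, toy_valid_idx = create_molecule_datapoints(toy_smiles, toy_targets)
toy_dataset = data.MoleculeDataset(toy_datapoints)
toy_loader = data.build_dataloader(toy_dataset, batch_size=2, shuffle=False)
print(f"Kept {len(toy_datapoints)} of {len(toy_smiles)} SMILES at indices {toy_valid_idx}")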


def build_mpnn_model(
    hyperparameters: dict,
    task: str = "regression",
    num_classes: int | None = None,
    n_targets: int = 1,
    n_extra_descriptors: int = 0,
    x_d_transform: nn.ScaleTransform | None = None,
    output_transform: nn.UnscaleTransform | None = None,
    task_weights: np.ndarray | None = None,
) -> models.MPNN:
    """Build an MPNN model with the specified hyperparameters.

    Args:
        hyperparameters: Dictionary of model hyperparameters
        task: Either "regression" or "classification"
        num_classes: Number of classes for classification tasks
        n_targets: Number of target columns (for multi-task regression)
        n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
        x_d_transform: Optional transform for extra descriptors (scaling)
        output_transform: Optional transform for regression output (unscaling targets)
        task_weights: Optional array of weights for each task (multi-task learning)

    Returns:
        Configured MPNN model
    """
    # Model hyperparameters with defaults
    hidden_dim = hyperparameters.get("hidden_dim", 700)
    depth = hyperparameters.get("depth", 6)
    dropout = hyperparameters.get("dropout", 0.25)
    ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 2000)
    ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)

    # Message passing component
    mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)

    # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
    agg = nn.NormAggregation()

    # FFN input_dim = message passing output + extra descriptors
    ffn_input_dim = hidden_dim + n_extra_descriptors

    # Build FFN based on task type
    if task == "classification" and num_classes is not None:
        # Multi-class classification
        ffn = nn.MulticlassClassificationFFN(
            n_classes=num_classes,
            input_dim=ffn_input_dim,
            hidden_dim=ffn_hidden_dim,
            n_layers=ffn_num_layers,
            dropout=dropout,
        )
    else:
        # Regression with optional output transform to unscale predictions
        # n_tasks controls the number of output heads for multi-task learning
        # task_weights goes here (in RegressionFFN) to weight loss per task
        weights_tensor = None
        if task_weights is not None:
            weights_tensor = torch.tensor(task_weights, dtype=torch.float32)

        ffn = nn.RegressionFFN(
            input_dim=ffn_input_dim,
            hidden_dim=ffn_hidden_dim,
            n_layers=ffn_num_layers,
            dropout=dropout,
            n_tasks=n_targets,
            output_transform=output_transform,
            task_weights=weights_tensor,
        )

    # Create the MPNN model
    mpnn = models.MPNN(
        message_passing=mp,
        agg=agg,
        predictor=ffn,
        batch_norm=True,
        metrics=None,
        X_d_transform=x_d_transform,
    )

    return mpnn
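
For instance (a sketch only, not part of the added file; the hyperparameter values are made up), a single-task regression model with two extra descriptor columns could be assembled as:

example_hparams = {"hidden_dim": 300, "depth": 4, "dropout": 0.1, "ffn_hidden_dim": 600}
example_mpnn = build_mpnn_model(
    example_hparams,
    task="regression",
    n_targets=1,
    n_extra_descriptors=2,  # e.g. two numeric columns alongside the SMILES column
    x_d_transform=None,     # the training loop below builds this from normalize_inputs("X_d")
    output_transform=None,  # likewise built from normalize_targets()
)
print(example_mpnn)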


def model_fn(model_dir: str) -> dict:
    """Load the ChemProp MPNN ensemble models from the specified directory.

    Args:
        model_dir: Directory containing the saved models

    Returns:
        Dictionary with ensemble models and metadata
    """
    # Load ensemble metadata (required)
    ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
    ensemble_metadata = joblib.load(ensemble_metadata_path)
    n_ensemble = ensemble_metadata["n_ensemble"]
    target_columns = ensemble_metadata["target_columns"]

    # Load all ensemble models
    ensemble_models = []
    for ens_idx in range(n_ensemble):
        model_path = os.path.join(model_dir, f"chemprop_model_{ens_idx}.pt")
        model = models.MPNN.load_from_file(model_path)
        model.eval()
        ensemble_models.append(model)

    print(f"Loaded {len(ensemble_models)} ensemble model(s), n_targets={len(target_columns)}")

    return {
        "ensemble_models": ensemble_models,
        "n_ensemble": n_ensemble,
        "target_columns": target_columns,
    }


def input_fn(input_data, content_type: str) -> pd.DataFrame:
    """Parse input data and return a DataFrame."""
    if not input_data:
        raise ValueError("Empty input data is not supported!")

    if isinstance(input_data, bytes):
        input_data = input_data.decode("utf-8")

    if "text/csv" in content_type:
        return pd.read_csv(StringIO(input_data))
    elif "application/json" in content_type:
        return pd.DataFrame(json.loads(input_data))
    else:
        raise ValueError(f"{content_type} not supported!")


def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
    """Supports both CSV and JSON output formats."""
    if "text/csv" in accept_type:
        csv_output = output_df.fillna("N/A").to_csv(index=False)
        return csv_output, "text/csv"
    elif "application/json" in accept_type:
        return output_df.to_json(orient="records"), "application/json"
    else:
        raise RuntimeError(
            f"{accept_type} accept type is not supported by this script."
        )


def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
    """Make predictions with the ChemProp MPNN ensemble.

    Args:
        df: Input DataFrame containing SMILES column (and extra features if hybrid mode)
        model_dict: Dictionary containing ensemble models and metadata

    Returns:
        DataFrame with predictions added (and prediction_std for ensembles)
    """
    model_type = TEMPLATE_PARAMS["model_type"]
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

    # Extract ensemble models and metadata
    ensemble_models = model_dict["ensemble_models"]
    n_ensemble = model_dict["n_ensemble"]
    target_columns = model_dict["target_columns"]

    # Load label encoder if present (classification)
    label_encoder = None
    label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
    if os.path.exists(label_encoder_path):
        label_encoder = joblib.load(label_encoder_path)

    # Load feature metadata if present (hybrid mode)
    # Contains column names, NaN fill values, and scaler for feature scaling
    feature_metadata = None
    feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
    if os.path.exists(feature_metadata_path):
        feature_metadata = joblib.load(feature_metadata_path)
        print(
            f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
        )

    # Find SMILES column in input DataFrame
    smiles_column = find_smiles_column(df.columns.tolist())

    smiles_list = df[smiles_column].tolist()

    # Track invalid SMILES
    valid_mask = []
    valid_smiles = []
    valid_indices = []
    for i, smi in enumerate(smiles_list):
        if smi and isinstance(smi, str) and len(smi.strip()) > 0:
            valid_mask.append(True)
            valid_smiles.append(smi.strip())
            valid_indices.append(i)
        else:
            valid_mask.append(False)

    valid_mask = np.array(valid_mask)
    print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")

    # Initialize prediction columns (use object dtype for classifiers to avoid FutureWarning)
    if model_type == "classifier":
        df["prediction"] = pd.Series([None] * len(df), dtype=object)
    else:
        # Regression: create prediction column for each target
        for tc in target_columns:
            df[f"{tc}_pred"] = np.nan
            df[f"{tc}_pred_std"] = np.nan

    if sum(valid_mask) == 0:
        print("Warning: No valid SMILES to predict on")
        return df

    # Prepare extra features if in hybrid mode
    # NOTE: We pass RAW (unscaled) features here - the model's X_d_transform handles scaling
    extra_features = None
    if feature_metadata is not None:
        extra_feature_cols = feature_metadata["extra_feature_cols"]
        col_means = np.array(feature_metadata["col_means"])

        # Check columns exist
        missing_cols = [col for col in extra_feature_cols if col not in df.columns]
        if missing_cols:
            print(
                f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
            )

        # Extract features for valid SMILES rows (raw, unscaled)
        extra_features = np.zeros(
            (len(valid_indices), len(extra_feature_cols)), dtype=np.float32
        )
        for j, col in enumerate(extra_feature_cols):
            if col in df.columns:
                values = df.iloc[valid_indices][col].values.astype(np.float32)
                # Fill NaN with training column means (unscaled means)
                nan_mask = np.isnan(values)
                values[nan_mask] = col_means[j]
                extra_features[:, j] = values
            else:
                # Column missing, use training mean
                extra_features[:, j] = col_means[j]

    # Create datapoints for prediction (filter out invalid SMILES)
    datapoints, rdkit_valid_indices = create_molecule_datapoints(
        valid_smiles, extra_descriptors=extra_features
    )

    if len(datapoints) == 0:
        print("Warning: No valid SMILES after RDKit validation")
        return df

    dataset = data.MoleculeDataset(datapoints)
    dataloader = data.build_dataloader(dataset, shuffle=False)

    # Make predictions with ensemble
    trainer = pl.Trainer(
        accelerator="auto",
        logger=False,
        enable_progress_bar=False,
    )

    # Collect predictions from all ensemble members
    all_ensemble_preds = []
    for ens_idx, ens_model in enumerate(ensemble_models):
        with torch.inference_mode():
            predictions = trainer.predict(ens_model, dataloader)
        ens_preds = np.concatenate([p.numpy() for p in predictions], axis=0)
        # Squeeze middle dim if present
        if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
            ens_preds = ens_preds.squeeze(axis=1)
        all_ensemble_preds.append(ens_preds)

    # Stack and compute mean/std (std is 0 for single model)
    ensemble_preds = np.stack(all_ensemble_preds, axis=0)
    preds = np.mean(ensemble_preds, axis=0)
    preds_std = np.std(ensemble_preds, axis=0)  # Will be 0s for n_ensemble=1

    # Ensure 2D: (n_samples, n_targets)
    if preds.ndim == 1:
        preds = preds.reshape(-1, 1)
        preds_std = preds_std.reshape(-1, 1)

    print(f"Inference: Ensemble predictions shape: {preds.shape}")

    # Map predictions back to valid_mask positions (accounting for RDKit-invalid SMILES)
    # rdkit_valid_indices tells us which of the valid_smiles were actually valid
    valid_positions = np.where(valid_mask)[0][rdkit_valid_indices]
    valid_mask = np.zeros(len(df), dtype=bool)
    valid_mask[valid_positions] = True

    if model_type == "classifier" and label_encoder is not None:
        # For classification, get class predictions and probabilities
        if preds.ndim == 2 and preds.shape[1] > 1:
            # Multi-class: preds are probabilities (averaged across ensemble)
            class_preds = np.argmax(preds, axis=1)
            decoded_preds = label_encoder.inverse_transform(class_preds)
            df.loc[valid_mask, "prediction"] = decoded_preds

            # Add probability columns
            proba_series = pd.Series([None] * len(df), index=df.index, dtype=object)
            proba_series.loc[valid_mask] = [p.tolist() for p in preds]
            df["pred_proba"] = proba_series
            df = expand_proba_column(df, label_encoder.classes_)
        else:
            # Binary or single output
            class_preds = (preds.flatten() > 0.5).astype(int)
            decoded_preds = label_encoder.inverse_transform(class_preds)
            df.loc[valid_mask, "prediction"] = decoded_preds
    else:
        # Regression: store predictions for each target
        for t_idx, tc in enumerate(target_columns):
            df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
            df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]

    return df
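
model_fn, input_fn, predict_fn and output_fn together form the SageMaker serving interface. A rough local round-trip, purely as an illustration (the model directory must already contain the artifacts saved by the training section below, and the CSV payload is made up):

artifacts = model_fn("/opt/ml/model")
payload = "id,smiles\n1,CCO\n2,c1ccccc1\n"
request_df = input_fn(payload, "text/csv")
response_df = predict_fn(request_df, artifacts)
body, mime_type = output_fn(response_df, "text/csv")
print(mime_type, body, sep="\n")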


if __name__ == "__main__":
    """Training script for ChemProp MPNN model"""

    # Template Parameters
    target_columns = TEMPLATE_PARAMS["targets"]  # List of target columns
    model_type = TEMPLATE_PARAMS["model_type"]
    feature_list = TEMPLATE_PARAMS["feature_list"]
    id_column = TEMPLATE_PARAMS["id_column"]
    model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
    hyperparameters = TEMPLATE_PARAMS["hyperparameters"]

    # Validate target_columns
    if not target_columns or not isinstance(target_columns, list) or len(target_columns) == 0:
        raise ValueError("'targets' must be a non-empty list of target column names")
    n_targets = len(target_columns)
    print(f"Target columns ({n_targets}): {target_columns}")

    # Get the SMILES column name from feature_list (user defines this, so we use their exact name)
    smiles_column = find_smiles_column(feature_list)
    extra_feature_cols = [f for f in feature_list if f != smiles_column]
    use_extra_features = len(extra_feature_cols) > 0
    print(f"Feature List: {feature_list}")
    print(f"SMILES Column: {smiles_column}")
    print(
        f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
    )

    # Script arguments for input/output directories
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
    )
    parser.add_argument(
        "--train",
        type=str,
        default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
    )
    parser.add_argument(
        "--output-data-dir",
        type=str,
        default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
    )
    args = parser.parse_args()

    # Read the training data
    training_files = [
        os.path.join(args.train, f)
        for f in os.listdir(args.train)
        if f.endswith(".csv")
    ]
    print(f"Training Files: {training_files}")

    all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
    print(f"All Data Shape: {all_df.shape}")

    check_dataframe(all_df, "training_df")

    # Drop rows with missing SMILES or all target values
    initial_count = len(all_df)
    all_df = all_df.dropna(subset=[smiles_column])
    # Keep rows that have at least one non-null target (works for single and multi-task)
    has_any_target = all_df[target_columns].notna().any(axis=1)
    all_df = all_df[has_any_target]
    dropped = initial_count - len(all_df)
    if dropped > 0:
        print(f"Dropped {dropped} rows with missing SMILES or all target values")

    print(f"Target columns: {target_columns}")
    print(f"Data Shape after cleaning: {all_df.shape}")
    for tc in target_columns:
        n_valid = all_df[tc].notna().sum()
        print(f"  {tc}: {n_valid} samples with values")

    # Set up label encoder for classification (single-target only)
    label_encoder = None
    if model_type == "classifier":
        if n_targets > 1:
            raise ValueError("Multi-task classification is not supported. Use regression for multi-task.")
        label_encoder = LabelEncoder()
        all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
        num_classes = len(label_encoder.classes_)
        print(
            f"Classification task with {num_classes} classes: {label_encoder.classes_}"
        )
    else:
        num_classes = None

    # Model and training configuration
    print(f"Hyperparameters: {hyperparameters}")
    task = "classification" if model_type == "classifier" else "regression"
    n_extra = len(extra_feature_cols) if use_extra_features else 0
    max_epochs = hyperparameters.get("max_epochs", 400)
    patience = hyperparameters.get("patience", 40)
    n_folds = hyperparameters.get("n_folds", 5)  # Number of CV folds (default: 5)
    batch_size = hyperparameters.get("batch_size", 16)

    # Check extra feature columns exist
    if use_extra_features:
        missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
        if missing_cols:
            raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")

    # =========================================================================
    # UNIFIED TRAINING: Works for n_folds=1 (single model) or n_folds>1 (K-fold CV)
    # =========================================================================
    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")

    # Prepare extra features and validate SMILES upfront
    all_extra_features = None
    col_means = None
    if use_extra_features:
        all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
        col_means = np.nanmean(all_extra_features, axis=0)
        for i in range(all_extra_features.shape[1]):
            all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]

    # Prepare target array: always 2D (n_samples, n_targets)
    all_targets = all_df[target_columns].values.astype(np.float32)

    # Filter invalid SMILES from the full dataset
    _, valid_indices = create_molecule_datapoints(
        all_df[smiles_column].tolist(), all_targets, all_extra_features
    )
    all_df = all_df.iloc[valid_indices].reset_index(drop=True)
    all_targets = all_targets[valid_indices]
    if all_extra_features is not None:
        all_extra_features = all_extra_features[valid_indices]
    print(f"Data after SMILES validation: {all_df.shape}")

    # Compute dynamic task weights for multi-task regression
    # Weight = inverse of sample count (normalized so min weight = 1.0)
    # This gives higher weight to targets with fewer samples
    task_weights = None
    if n_targets > 1 and model_type != "classifier":
        sample_counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
        # Inverse weighting: fewer samples = higher weight
        inverse_counts = 1.0 / sample_counts
        # Normalize so minimum weight is 1.0
        task_weights = inverse_counts / inverse_counts.min()
        print(f"Task weights (inverse sample count):")
        for t_idx, t_name in enumerate(target_columns):
            print(f"  {t_name}: {task_weights[t_idx]:.3f} (n={sample_counts[t_idx]})")
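
(Worked example with hypothetical counts: targets with 2000, 400, and 100 labelled rows receive weights 1.0, 5.0, and 20.0, i.e. the largest sample count divided by each target's own count.)

counts = np.array([2000, 400, 100])
weights = (1.0 / counts) / (1.0 / counts).min()  # -> array([ 1.,  5., 20.])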

    # Create fold splits
    if n_folds == 1:
        # Single fold: use train/val split from "training" column or random split
        if "training" in all_df.columns:
            print("Found training column, splitting data based on training column")
            train_idx = np.where(all_df["training"])[0]
            val_idx = np.where(~all_df["training"])[0]
        else:
            print("WARNING: No training column found, splitting data with random 80/20 split")
            indices = np.arange(len(all_df))
            train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
        folds = [(train_idx, val_idx)]
    else:
        # K-Fold CV
        if model_type == "classifier":
            kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
            split_target = all_df[target_columns[0]]
        else:
            kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
            split_target = None
        folds = list(kfold.split(all_df, split_target))

    # Initialize storage for out-of-fold predictions: always 2D (n_samples, n_targets)
    oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
    if model_type == "classifier" and num_classes and num_classes > 1:
        oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
    else:
        oof_proba = None

    ensemble_models = []

    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        print(f"\n{'='*50}")
        print(f"Training Fold {fold_idx + 1}/{len(folds)}")
        print(f"{'='*50}")

        # Split data for this fold
        df_train = all_df.iloc[train_idx].reset_index(drop=True)
        df_val = all_df.iloc[val_idx].reset_index(drop=True)
        train_targets = all_targets[train_idx]
        val_targets = all_targets[val_idx]

        train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
        val_extra = all_extra_features[val_idx] if all_extra_features is not None else None

        print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")

        # Create ChemProp datasets for this fold
        train_datapoints, _ = create_molecule_datapoints(
            df_train[smiles_column].tolist(), train_targets, train_extra
        )
        val_datapoints, _ = create_molecule_datapoints(
            df_val[smiles_column].tolist(), val_targets, val_extra
        )

        train_dataset = data.MoleculeDataset(train_datapoints)
        val_dataset = data.MoleculeDataset(val_datapoints)

        # Save raw val features for prediction
        val_extra_raw = val_extra.copy() if val_extra is not None else None

        # Scale features and targets for this fold
        x_d_transform = None
        if use_extra_features:
            feature_scaler = train_dataset.normalize_inputs("X_d")
            val_dataset.normalize_inputs("X_d", feature_scaler)
            x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)

        output_transform = None
        if model_type in ["regressor", "uq_regressor"]:
            target_scaler = train_dataset.normalize_targets()
            val_dataset.normalize_targets(target_scaler)
            output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)

        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)

        # Build and train model for this fold
        pl.seed_everything(42 + fold_idx)
        mpnn = build_mpnn_model(
            hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
            n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
            task_weights=task_weights,
        )

        callbacks = [
            pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
            pl.callbacks.ModelCheckpoint(
                dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
                monitor="val_loss", mode="min", save_top_k=1,
            ),
        ]

        trainer = pl.Trainer(
            accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
            logger=False, enable_progress_bar=True,
        )

        trainer.fit(mpnn, train_loader, val_loader)

        if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
            checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
            mpnn.load_state_dict(checkpoint["state_dict"])

        mpnn.eval()
        ensemble_models.append(mpnn)

        # Make out-of-fold predictions using raw features
        val_datapoints_raw, _ = create_molecule_datapoints(
            df_val[smiles_column].tolist(), val_targets, val_extra_raw
        )
        val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
        val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)

        with torch.inference_mode():
            fold_predictions = trainer.predict(mpnn, val_loader_pred)
        fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
        if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
            fold_preds = fold_preds.squeeze(axis=1)

        # Store out-of-fold predictions
        if model_type == "classifier" and fold_preds.ndim == 2:
            # Store class index in first column for classification
            oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
            if oof_proba is not None:
                oof_proba[val_idx] = fold_preds
        else:
            # Regression: fold_preds shape is (n_val, n_targets) or (n_val,)
            if fold_preds.ndim == 1:
                fold_preds = fold_preds.reshape(-1, 1)
            oof_predictions[val_idx] = fold_preds

        print(f"Fold {fold_idx + 1} complete!")

    print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")

    # Use out-of-fold predictions for metrics
    # For n_folds=1, we only have predictions for val_idx, so filter to those rows
    if n_folds == 1:
        # oof_predictions is always 2D now: check if any column has a value
        val_mask = ~np.isnan(oof_predictions).all(axis=1)
        preds = oof_predictions[val_mask]
        df_val = all_df[val_mask].copy()
        y_validate = all_targets[val_mask]
        if oof_proba is not None:
            oof_proba = oof_proba[val_mask]
        val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
    else:
        preds = oof_predictions
        df_val = all_df.copy()
        y_validate = all_targets
        val_extra_features = all_extra_features

    # Compute prediction_std by running all ensemble models on validation data
    # For n_folds=1, std will be 0 (only one model). For n_folds>1, std shows ensemble disagreement.
    preds_std = None
    if model_type in ["regressor", "uq_regressor"] and len(ensemble_models) > 0:
        print("Computing prediction_std from ensemble predictions on validation data...")
        val_datapoints_for_std, _ = create_molecule_datapoints(
            df_val[smiles_column].tolist(),
            y_validate,
            val_extra_features
        )
        val_dataset_for_std = data.MoleculeDataset(val_datapoints_for_std)
        val_loader_for_std = data.build_dataloader(val_dataset_for_std, batch_size=batch_size, shuffle=False)

        all_ensemble_preds_for_std = []
        trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
        for ens_model in ensemble_models:
            with torch.inference_mode():
                ens_preds = trainer_pred.predict(ens_model, val_loader_for_std)
            ens_preds = np.concatenate([p.numpy() for p in ens_preds], axis=0)
            if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
                ens_preds = ens_preds.squeeze(axis=1)
            all_ensemble_preds_for_std.append(ens_preds)

        # Stack ensemble predictions: shape (n_ensemble, n_samples, n_targets)
        ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
        preds_std = np.std(ensemble_preds_stacked, axis=0)
        # Ensure 2D
        if preds_std.ndim == 1:
            preds_std = preds_std.reshape(-1, 1)
        print(f"Ensemble prediction_std - mean per target: {np.nanmean(preds_std, axis=0)}")

    if model_type == "classifier":
        # Classification metrics - preds contains class indices in first column from OOF predictions
        class_preds = preds[:, 0].astype(int)
        has_proba = oof_proba is not None

        print(f"class_preds shape: {class_preds.shape}")

        # Decode labels for metrics (classification is single-target only)
        target_name = target_columns[0]
        y_validate_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
        preds_decoded = label_encoder.inverse_transform(class_preds)

        # Calculate metrics
        label_names = label_encoder.classes_
        scores = precision_recall_fscore_support(
            y_validate_decoded, preds_decoded, average=None, labels=label_names
        )

        score_df = pd.DataFrame(
            {
                target_name: label_names,
                "precision": scores[0],
                "recall": scores[1],
                "f1": scores[2],
                "support": scores[3],
            }
        )

        # Output metrics per class
        metrics = ["precision", "recall", "f1", "support"]
        for t in label_names:
            for m in metrics:
                value = score_df.loc[score_df[target_name] == t, m].iloc[0]
                print(f"Metrics:{t}:{m} {value}")

        # Confusion matrix
        conf_mtx = confusion_matrix(
            y_validate_decoded, preds_decoded, labels=label_names
        )
        for i, row_name in enumerate(label_names):
            for j, col_name in enumerate(label_names):
                value = conf_mtx[i, j]
                print(f"ConfusionMatrix:{row_name}:{col_name} {value}")

        # Save validation predictions
        df_val = df_val.copy()
        df_val["prediction"] = preds_decoded
        if has_proba and oof_proba is not None:
            df_val["pred_proba"] = [p.tolist() for p in oof_proba]
            df_val = expand_proba_column(df_val, label_names)

    else:
        # Regression metrics: compute per target (works for single or multi-task)
        df_val = df_val.copy()
        print("\n--- Per-target metrics ---")
        for t_idx, t_name in enumerate(target_columns):
            # Get valid (non-NaN) indices for this target
            target_valid_mask = ~np.isnan(y_validate[:, t_idx])
            y_true = y_validate[target_valid_mask, t_idx]
            y_pred = preds[target_valid_mask, t_idx]

            if len(y_true) > 0:
                rmse = root_mean_squared_error(y_true, y_pred)
                mae = mean_absolute_error(y_true, y_pred)
                medae = median_absolute_error(y_true, y_pred)
                r2 = r2_score(y_true, y_pred)
                spearman_corr = spearmanr(y_true, y_pred).correlation
                support = len(y_true)
                # Print metrics in format expected by SageMaker metric definitions
                print(f"rmse: {rmse:.3f}")
                print(f"mae: {mae:.3f}")
                print(f"medae: {medae:.3f}")
                print(f"r2: {r2:.3f}")
                print(f"spearmanr: {spearman_corr:.3f}")
                print(f"support: {support}")

            # Store predictions in dataframe
            df_val[f"{t_name}_pred"] = preds[:, t_idx]
            if preds_std is not None:
                df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx]
            else:
                df_val[f"{t_name}_pred_std"] = 0.0

    # Save validation predictions to S3
    # Include id_column if it exists in df_val
    output_columns = []
    if id_column in df_val.columns:
        output_columns.append(id_column)
    # Include all target columns and their predictions
    output_columns += target_columns
    output_columns += [f"{t}_pred" for t in target_columns]
    output_columns += [f"{t}_pred_std" for t in target_columns]
    # Add proba columns for classifiers
    output_columns += [col for col in df_val.columns if col.endswith("_proba")]
    # Filter to only columns that exist
    output_columns = [c for c in output_columns if c in df_val.columns]
    wr.s3.to_csv(
        df_val[output_columns],
        path=f"{model_metrics_s3_path}/validation_predictions.csv",
        index=False,
    )

    # Save ensemble models (n_folds models if CV, 1 model otherwise)
    for model_idx, ens_model in enumerate(ensemble_models):
        model_path = os.path.join(args.model_dir, f"chemprop_model_{model_idx}.pt")
        models.save_model(model_path, ens_model)
        print(f"Saved model {model_idx + 1} to {model_path}")

    # Save ensemble metadata (n_ensemble = number of models for inference)
    n_ensemble = len(ensemble_models)
    ensemble_metadata = {
        "n_ensemble": n_ensemble,
        "n_folds": n_folds,
        "target_columns": target_columns,
    }
    joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
    print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds}, targets={target_columns})")

    # Save label encoder if classification
    if label_encoder is not None:
        joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

    # Save extra feature metadata for inference (hybrid mode)
    # Note: We don't need to save the scaler - X_d_transform is embedded in the model
    if use_extra_features:
        feature_metadata = {
            "extra_feature_cols": extra_feature_cols,
            "col_means": col_means.tolist(),  # Unscaled means for NaN imputation
        }
        joblib.dump(
            feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
        )
        print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")