workbench 0.8.197__py3-none-any.whl → 0.8.201__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +19 -12
- workbench/api/__init__.py +2 -1
- workbench/api/feature_set.py +7 -4
- workbench/api/model.py +1 -1
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/endpoint_core.py +84 -46
- workbench/core/artifacts/feature_set_core.py +69 -1
- workbench/core/artifacts/model_core.py +37 -7
- workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
- workbench/core/transforms/features_to_model/features_to_model.py +23 -20
- workbench/core/views/view.py +2 -2
- workbench/model_scripts/chemprop/chemprop.template +931 -0
- workbench/model_scripts/chemprop/generated_model_script.py +931 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/proximity.py +19 -12
- workbench/model_scripts/custom_models/uq_models/proximity.py +19 -12
- workbench/model_scripts/pytorch_model/generated_model_script.py +130 -88
- workbench/model_scripts/pytorch_model/pytorch.template +128 -86
- workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
- workbench/model_scripts/script_generation.py +10 -7
- workbench/model_scripts/uq_models/generated_model_script.py +25 -18
- workbench/model_scripts/uq_models/mapie.template +23 -16
- workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
- workbench/model_scripts/xgb_model/xgb_model.template +2 -2
- workbench/repl/workbench_shell.py +14 -5
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
- workbench/utils/chemprop_utils.py +724 -0
- workbench/utils/pytorch_utils.py +497 -0
- workbench/utils/xgboost_model_utils.py +12 -5
- {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/METADATA +2 -2
- {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/RECORD +38 -30
- {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/entry_points.txt +2 -1
- {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/WHEEL +0 -0
- {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,931 @@
# ChemProp Model Template for Workbench
# Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
#
# === CHEMPROP REVIEW NOTES ===
# This script runs on AWS SageMaker. Key areas for ChemProp review:
#
# 1. Model Architecture (build_mpnn_model function)
#    - BondMessagePassing, NormAggregation, FFN configuration
#    - Regression uses output_transform (UnscaleTransform) for target scaling
#
# 2. Data Handling (create_molecule_datapoints function)
#    - MoleculeDatapoint creation with x_d (extra descriptors)
#    - RDKit validation of SMILES
#
# 3. Scaling (training section)
#    - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
#    - Targets (regression): normalize_targets() + UnscaleTransform in FFN
#    - At inference: pass RAW features, transforms handle scaling automatically
#
# 4. Training Loop (search for "pl.Trainer")
#    - PyTorch Lightning Trainer with ChemProp MPNN
#
# AWS/SageMaker boilerplate (can skip):
#    - input_fn, output_fn, model_fn: SageMaker serving interface
#    - argparse, file loading, S3 writes
# =============================

import os
import argparse
import json
from io import StringIO

import awswrangler as wr
import numpy as np
import pandas as pd
import torch
from lightning import pytorch as pl
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    mean_absolute_error,
    r2_score,
    root_mean_squared_error,
    precision_recall_fscore_support,
    confusion_matrix,
)
import joblib

# ChemProp imports
from chemprop import data, models, nn
# Template Parameters
TEMPLATE_PARAMS = {
    "model_type": "{{model_type}}",
    "target": "{{target_column}}",
    "feature_list": "{{feature_list}}",
    "model_metrics_s3_path": "{{model_metrics_s3_path}}",
    "train_all_data": "{{train_all_data}}",
    "hyperparameters": "{{hyperparameters}}",
}

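# For illustration: Workbench's script generation step fills in the {{...}}
# placeholders above before training. A rendered block might look like the
# following sketch (all values here are hypothetical, not from the package):
#
#   TEMPLATE_PARAMS = {
#       "model_type": "regressor",
#       "target": "solubility",
#       "feature_list": ["smiles", "mol_weight", "logp"],
#       "model_metrics_s3_path": "s3://my-bucket/models/solubility",
#       "train_all_data": False,
#       "hyperparameters": {"max_epochs": 50, "n_folds": 5},
#   }
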
def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
    """Check if the provided dataframe is empty and raise an exception if it is."""
    if df.empty:
        msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
        print(msg)
        raise ValueError(msg)

def find_smiles_column(columns: list[str]) -> str:
    """Find the SMILES column name from a list (case-insensitive match for 'smiles')."""
    smiles_column = next((col for col in columns if col.lower() == "smiles"), None)
    if smiles_column is None:
        raise ValueError(
            "Column list must contain a 'smiles' column (case-insensitive)"
        )
    return smiles_column

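# Example: find_smiles_column(["id", "SMILES", "logp"]) returns "SMILES";
# any casing of "smiles" matches, and a missing column raises ValueError.
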
def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
    """Expands a column containing a list of probabilities into separate columns.

    Handles None values for rows where predictions couldn't be made.
    """
    proba_column = "pred_proba"
    if proba_column not in df.columns:
        raise ValueError('DataFrame does not contain a "pred_proba" column')

    proba_splits = [f"{label}_proba" for label in class_labels]
    n_classes = len(class_labels)

    # Handle None values by replacing with list of NaNs
    proba_values = []
    for val in df[proba_column]:
        if val is None:
            proba_values.append([np.nan] * n_classes)
        else:
            proba_values.append(val)

    proba_df = pd.DataFrame(proba_values, columns=proba_splits)

    df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
    df = df.reset_index(drop=True)
    df = pd.concat([df, proba_df], axis=1)
    return df

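# Example of the expansion above, assuming class_labels == ["active", "inactive"]:
# a pred_proba value of [0.8, 0.2] becomes active_proba=0.8 and inactive_proba=0.2,
# while a None entry (no prediction) becomes NaN in both columns.
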
def create_molecule_datapoints(
    smiles_list: list[str],
    targets: list[float] | None = None,
    extra_descriptors: np.ndarray | None = None,
) -> tuple[list[data.MoleculeDatapoint], list[int]]:
    """Create ChemProp MoleculeDatapoints from SMILES strings.

    Args:
        smiles_list: List of SMILES strings
        targets: Optional list of target values (for training)
        extra_descriptors: Optional array of extra features (n_samples, n_features)

    Returns:
        Tuple of (list of MoleculeDatapoint objects, list of valid indices)
    """
    from rdkit import Chem

    datapoints = []
    valid_indices = []
    invalid_count = 0

    for i, smi in enumerate(smiles_list):
        # Validate SMILES with RDKit first
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            invalid_count += 1
            continue

        # Build datapoint with optional target and extra descriptors
        y = [targets[i]] if targets is not None else None
        x_d = extra_descriptors[i] if extra_descriptors is not None else None

        dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
        datapoints.append(dp)
        valid_indices.append(i)

    if invalid_count > 0:
        print(f"Warning: Skipped {invalid_count} invalid SMILES strings")

    return datapoints, valid_indices

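# Usage sketch (hypothetical SMILES; targets/extra_descriptors are optional):
#
#   dps, idx = create_molecule_datapoints(["CCO", "not-a-smiles"], targets=[1.2, 3.4])
#   # -> one MoleculeDatapoint for "CCO"; idx == [0] since the second string
#   #    fails RDKit parsing and is skipped with a warning
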
def build_mpnn_model(
    hyperparameters: dict,
    task: str = "regression",
    num_classes: int | None = None,
    n_extra_descriptors: int = 0,
    x_d_transform: nn.ScaleTransform | None = None,
    output_transform: nn.UnscaleTransform | None = None,
) -> models.MPNN:
    """Build an MPNN model with the specified hyperparameters.

    Args:
        hyperparameters: Dictionary of model hyperparameters
        task: Either "regression" or "classification"
        num_classes: Number of classes for classification tasks
        n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
        x_d_transform: Optional transform for extra descriptors (scaling)
        output_transform: Optional transform for regression output (unscaling targets)

    Returns:
        Configured MPNN model
    """
    # Model hyperparameters with defaults (based on OpenADMET baseline with slight improvements)
    hidden_dim = hyperparameters.get("hidden_dim", 300)
    depth = hyperparameters.get("depth", 4)
    dropout = hyperparameters.get("dropout", 0.10)
    ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
    ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)

    # Message passing component
    mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)

    # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
    agg = nn.NormAggregation()

    # FFN input_dim = message passing output + extra descriptors
    ffn_input_dim = hidden_dim + n_extra_descriptors

    # Build FFN based on task type
    if task == "classification" and num_classes is not None:
        # Multi-class classification
        ffn = nn.MulticlassClassificationFFN(
            n_classes=num_classes,
            input_dim=ffn_input_dim,
            hidden_dim=ffn_hidden_dim,
            n_layers=ffn_num_layers,
            dropout=dropout,
        )
    else:
        # Regression with optional output transform to unscale predictions
        ffn = nn.RegressionFFN(
            input_dim=ffn_input_dim,
            hidden_dim=ffn_hidden_dim,
            n_layers=ffn_num_layers,
            dropout=dropout,
            output_transform=output_transform,
        )

    # Create the MPNN model
    mpnn = models.MPNN(
        message_passing=mp,
        agg=agg,
        predictor=ffn,
        batch_norm=True,
        metrics=None,
        X_d_transform=x_d_transform,
    )

    return mpnn

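# Usage sketch: with an empty hyperparameter dict, the defaults above apply,
# so the call below builds a regression MPNN with hidden_dim=300, depth=4,
# dropout=0.10, and a 2-layer FFN:
#
#   mpnn = build_mpnn_model({}, task="regression")
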
def model_fn(model_dir: str) -> dict:
    """Load the ChemProp MPNN ensemble models from the specified directory.

    Args:
        model_dir: Directory containing the saved models

    Returns:
        Dictionary with ensemble models and metadata
    """
    # Load ensemble metadata
    ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
    if os.path.exists(ensemble_metadata_path):
        ensemble_metadata = joblib.load(ensemble_metadata_path)
        n_ensemble = ensemble_metadata["n_ensemble"]
    else:
        # Backwards compatibility: single model without ensemble metadata
        n_ensemble = 1

    # Load all ensemble models
    ensemble_models = []
    for ens_idx in range(n_ensemble):
        model_path = os.path.join(model_dir, f"chemprop_model_{ens_idx}.pt")
        if not os.path.exists(model_path):
            # Backwards compatibility: try old single model path
            model_path = os.path.join(model_dir, "chemprop_model.pt")
        model = models.MPNN.load_from_file(model_path)
        model.eval()
        ensemble_models.append(model)

    print(f"Loaded {len(ensemble_models)} ensemble model(s)")

    return {
        "ensemble_models": ensemble_models,
        "n_ensemble": n_ensemble,
    }

def input_fn(input_data, content_type: str) -> pd.DataFrame:
    """Parse input data and return a DataFrame."""
    if not input_data:
        raise ValueError("Empty input data is not supported!")

    if isinstance(input_data, bytes):
        input_data = input_data.decode("utf-8")

    if "text/csv" in content_type:
        return pd.read_csv(StringIO(input_data))
    elif "application/json" in content_type:
        return pd.DataFrame(json.loads(input_data))
    else:
        raise ValueError(f"{content_type} not supported!")

def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
    """Supports both CSV and JSON output formats."""
    if "text/csv" in accept_type:
        csv_output = output_df.fillna("N/A").to_csv(index=False)
        return csv_output, "text/csv"
    elif "application/json" in accept_type:
        return output_df.to_json(orient="records"), "application/json"
    else:
        raise RuntimeError(
            f"{accept_type} accept type is not supported by this script."
        )

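# model_fn, input_fn, output_fn, and predict_fn (below) together form the
# SageMaker serving interface. A rough sketch of how a CSV request flows
# through them (path and payload here are illustrative only):
#
#   model_dict = model_fn("/opt/ml/model")
#   df_in = input_fn("smiles\nCCO\n", "text/csv")
#   df_out = predict_fn(df_in, model_dict)
#   body, mime_type = output_fn(df_out, "text/csv")
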
def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
    """Make predictions with the ChemProp MPNN ensemble.

    Args:
        df: Input DataFrame containing SMILES column (and extra features if hybrid mode)
        model_dict: Dictionary containing ensemble models and metadata

    Returns:
        DataFrame with predictions added (and prediction_std for ensembles)
    """
    model_type = TEMPLATE_PARAMS["model_type"]
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

    # Extract ensemble models
    ensemble_models = model_dict["ensemble_models"]
    n_ensemble = model_dict["n_ensemble"]

    # Load label encoder if present (classification)
    label_encoder = None
    label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
    if os.path.exists(label_encoder_path):
        label_encoder = joblib.load(label_encoder_path)

    # Load feature metadata if present (hybrid mode)
    # Contains column names, NaN fill values, and scaler for feature scaling
    feature_metadata = None
    feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
    if os.path.exists(feature_metadata_path):
        feature_metadata = joblib.load(feature_metadata_path)
        print(
            f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
        )

    # Find SMILES column in input DataFrame
    smiles_column = find_smiles_column(df.columns.tolist())

    smiles_list = df[smiles_column].tolist()

    # Track invalid SMILES
    valid_mask = []
    valid_smiles = []
    valid_indices = []
    for i, smi in enumerate(smiles_list):
        if smi and isinstance(smi, str) and len(smi.strip()) > 0:
            valid_mask.append(True)
            valid_smiles.append(smi.strip())
            valid_indices.append(i)
        else:
            valid_mask.append(False)

    valid_mask = np.array(valid_mask)
    print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")

    # Initialize prediction column (use object dtype for classifiers to avoid FutureWarning)
    if model_type == "classifier":
        df["prediction"] = pd.Series([None] * len(df), dtype=object)
    else:
        df["prediction"] = np.nan
        if n_ensemble > 1:
            df["prediction_std"] = np.nan

    if sum(valid_mask) == 0:
        print("Warning: No valid SMILES to predict on")
        return df

    # Prepare extra features if in hybrid mode
    # NOTE: We pass RAW (unscaled) features here - the model's X_d_transform handles scaling
    extra_features = None
    if feature_metadata is not None:
        extra_feature_cols = feature_metadata["extra_feature_cols"]
        col_means = np.array(feature_metadata["col_means"])

        # Check columns exist
        missing_cols = [col for col in extra_feature_cols if col not in df.columns]
        if missing_cols:
            print(
                f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
            )

        # Extract features for valid SMILES rows (raw, unscaled)
        extra_features = np.zeros(
            (len(valid_indices), len(extra_feature_cols)), dtype=np.float32
        )
        for j, col in enumerate(extra_feature_cols):
            if col in df.columns:
                values = df.iloc[valid_indices][col].values.astype(np.float32)
                # Fill NaN with training column means (unscaled means)
                nan_mask = np.isnan(values)
                values[nan_mask] = col_means[j]
                extra_features[:, j] = values
            else:
                # Column missing, use training mean
                extra_features[:, j] = col_means[j]

    # Create datapoints for prediction (filter out invalid SMILES)
    datapoints, rdkit_valid_indices = create_molecule_datapoints(
        valid_smiles, extra_descriptors=extra_features
    )

    if len(datapoints) == 0:
        print("Warning: No valid SMILES after RDKit validation")
        return df

    dataset = data.MoleculeDataset(datapoints)
    dataloader = data.build_dataloader(dataset, shuffle=False)

    # Make predictions with ensemble
    trainer = pl.Trainer(
        accelerator="auto",
        logger=False,
        enable_progress_bar=False,
    )

    # Collect predictions from all ensemble members
    all_ensemble_preds = []
    for ens_idx, ens_model in enumerate(ensemble_models):
        with torch.inference_mode():
            predictions = trainer.predict(ens_model, dataloader)
        ens_preds = np.concatenate([p.numpy() for p in predictions], axis=0)
        # Squeeze middle dim if present
        if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
            ens_preds = ens_preds.squeeze(axis=1)
        all_ensemble_preds.append(ens_preds)

    # Stack and compute mean/std
    ensemble_preds = np.stack(all_ensemble_preds, axis=0)
    preds = np.mean(ensemble_preds, axis=0)
    preds_std = np.std(ensemble_preds, axis=0) if n_ensemble > 1 else None

    print(f"Inference: Ensemble predictions shape: {preds.shape}")

    # Map predictions back to valid_mask positions (accounting for RDKit-invalid SMILES)
    # rdkit_valid_indices tells us which of the valid_smiles were actually valid
    valid_positions = np.where(valid_mask)[0][rdkit_valid_indices]
    valid_mask = np.zeros(len(df), dtype=bool)
    valid_mask[valid_positions] = True

    if model_type == "classifier" and label_encoder is not None:
        # For classification, get class predictions and probabilities
        if preds.ndim == 2 and preds.shape[1] > 1:
            # Multi-class: preds are probabilities (averaged across ensemble)
            class_preds = np.argmax(preds, axis=1)
            decoded_preds = label_encoder.inverse_transform(class_preds)
            df.loc[valid_mask, "prediction"] = decoded_preds

            # Add probability columns
            proba_series = pd.Series([None] * len(df), index=df.index, dtype=object)
            proba_series.loc[valid_mask] = [p.tolist() for p in preds]
            df["pred_proba"] = proba_series
            df = expand_proba_column(df, label_encoder.classes_)
        else:
            # Binary or single output
            class_preds = (preds.flatten() > 0.5).astype(int)
            decoded_preds = label_encoder.inverse_transform(class_preds)
            df.loc[valid_mask, "prediction"] = decoded_preds
    else:
        # Regression: direct predictions
        df.loc[valid_mask, "prediction"] = preds.flatten()

        # Add prediction_std for ensemble models
        if preds_std is not None:
            df.loc[valid_mask, "prediction_std"] = preds_std.flatten()

    return df

if __name__ == "__main__":
    """Training script for ChemProp MPNN model"""

    # Template Parameters
    target = TEMPLATE_PARAMS["target"]
    model_type = TEMPLATE_PARAMS["model_type"]
    feature_list = TEMPLATE_PARAMS["feature_list"]
    model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
    train_all_data = TEMPLATE_PARAMS["train_all_data"]
    hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
    validation_split = 0.2

    # Get the SMILES column name from feature_list (user defines this, so we use their exact name)
    smiles_column = find_smiles_column(feature_list)
    extra_feature_cols = [f for f in feature_list if f != smiles_column]
    use_extra_features = len(extra_feature_cols) > 0
    print(f"Feature List: {feature_list}")
    print(f"SMILES Column: {smiles_column}")
    print(
        f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
    )

    # Script arguments for input/output directories
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
    )
    parser.add_argument(
        "--train",
        type=str,
        default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
    )
    parser.add_argument(
        "--output-data-dir",
        type=str,
        default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
    )
    args = parser.parse_args()

    # Read the training data
    training_files = [
        os.path.join(args.train, f)
        for f in os.listdir(args.train)
        if f.endswith(".csv")
    ]
    print(f"Training Files: {training_files}")

    all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
    print(f"All Data Shape: {all_df.shape}")

    check_dataframe(all_df, "training_df")

    # Drop rows with missing SMILES or target values
    initial_count = len(all_df)
    all_df = all_df.dropna(subset=[smiles_column, target])
    dropped = initial_count - len(all_df)
    if dropped > 0:
        print(f"Dropped {dropped} rows with missing SMILES or target values")

    print(f"Target: {target}")
    print(f"Data Shape after cleaning: {all_df.shape}")

    # Set up label encoder for classification
    label_encoder = None
    if model_type == "classifier":
        label_encoder = LabelEncoder()
        all_df[target] = label_encoder.fit_transform(all_df[target])
        num_classes = len(label_encoder.classes_)
        print(
            f"Classification task with {num_classes} classes: {label_encoder.classes_}"
        )
    else:
        num_classes = None

    # Model and training configuration
    print(f"Hyperparameters: {hyperparameters}")
    task = "classification" if model_type == "classifier" else "regression"
    n_extra = len(extra_feature_cols) if use_extra_features else 0
    max_epochs = hyperparameters.get("max_epochs", 50)
    patience = hyperparameters.get("patience", 10)
    n_folds = hyperparameters.get("n_folds", 1)  # Number of CV folds (default: 1 = no CV)
    batch_size = hyperparameters.get("batch_size", min(64, max(16, len(all_df) // 16)))

    # Check extra feature columns exist
    if use_extra_features:
        missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
        if missing_cols:
            raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")

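    # Example (assumed values): passing hyperparameters like
    # {"max_epochs": 100, "patience": 15, "n_folds": 5, "batch_size": 32}
    # would train a 5-fold CV ensemble below; any key left out falls back to
    # the defaults read above.
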
    # =========================================================================
    # SINGLE MODEL TRAINING (n_folds=1) - uses train/val split
    # =========================================================================
    if n_folds == 1:
        print("Training single model (no cross-validation)...")

        # Split data
        if train_all_data:
            print("Training on ALL of the data")
            df_train = all_df.copy()
            df_val = all_df.copy()
        elif "training" in all_df.columns:
            print("Found training column, splitting data based on training column")
            df_train = all_df[all_df["training"]].copy()
            df_val = all_df[~all_df["training"]].copy()
        else:
            print("WARNING: No training column found, splitting data with random state=42")
            df_train, df_val = train_test_split(
                all_df, test_size=validation_split, random_state=42
            )

        print(f"TRAIN: {df_train.shape}")
        print(f"VALIDATION: {df_val.shape}")

        # Extract and prepare extra features
        train_extra_features = None
        val_extra_features = None
        col_means = None

        if use_extra_features:
            train_extra_features = df_train[extra_feature_cols].values.astype(np.float32)
            val_extra_features = df_val[extra_feature_cols].values.astype(np.float32)
            col_means = np.nanmean(train_extra_features, axis=0)
            for i in range(train_extra_features.shape[1]):
                train_extra_features[np.isnan(train_extra_features[:, i]), i] = col_means[i]
                val_extra_features[np.isnan(val_extra_features[:, i]), i] = col_means[i]

        # Create ChemProp datasets
        train_datapoints, train_valid_idx = create_molecule_datapoints(
            df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra_features
        )
        val_datapoints, val_valid_idx = create_molecule_datapoints(
            df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_features
        )

        df_train = df_train.iloc[train_valid_idx].reset_index(drop=True)
        df_val = df_val.iloc[val_valid_idx].reset_index(drop=True)

        train_dataset = data.MoleculeDataset(train_datapoints)
        val_dataset = data.MoleculeDataset(val_datapoints)

        # Save raw validation features for predictions later
        val_extra_raw = val_extra_features[val_valid_idx] if val_extra_features is not None else None

        # Scale features and targets
        x_d_transform = None
        if use_extra_features:
            feature_scaler = train_dataset.normalize_inputs("X_d")
            val_dataset.normalize_inputs("X_d", feature_scaler)
            x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)

        output_transform = None
        if model_type == "regressor":
            target_scaler = train_dataset.normalize_targets()
            val_dataset.normalize_targets(target_scaler)
            output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)

        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)

        # Build and train single model
        pl.seed_everything(42)
        mpnn = build_mpnn_model(
            hyperparameters, task=task, num_classes=num_classes,
            n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
        )

        callbacks = [
            pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
            pl.callbacks.ModelCheckpoint(
                dirpath=args.model_dir, filename="best_model_0",
                monitor="val_loss", mode="min", save_top_k=1,
            ),
        ]

        trainer = pl.Trainer(
            accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
            logger=False, enable_progress_bar=True,
        )

        trainer.fit(mpnn, train_loader, val_loader)

        if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
            checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
            mpnn.load_state_dict(checkpoint["state_dict"])

        mpnn.eval()
        ensemble_models = [mpnn]

        # Make predictions on validation set
        val_datapoints_raw, _ = create_molecule_datapoints(
            df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_raw
        )
        val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
        val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)

        with torch.inference_mode():
            val_predictions = trainer.predict(mpnn, val_loader_pred)
        preds = np.concatenate([p.numpy() for p in val_predictions], axis=0)
        if preds.ndim == 3 and preds.shape[1] == 1:
            preds = preds.squeeze(axis=1)

        preds_std = None
        y_validate = df_val[target].values

    # =========================================================================
    # K-FOLD CROSS-VALIDATION (n_folds > 1) - trains n_folds models
    # =========================================================================
    else:
        print(f"Training {n_folds}-fold cross-validation ensemble...")

        # Validate all SMILES upfront and filter invalid ones
        all_extra_features = None
        if use_extra_features:
            all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
            col_means = np.nanmean(all_extra_features, axis=0)
            for i in range(all_extra_features.shape[1]):
                all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
        else:
            col_means = None

        # Filter invalid SMILES from the full dataset
        _, valid_indices = create_molecule_datapoints(
            all_df[smiles_column].tolist(), all_df[target].tolist(), all_extra_features
        )
        all_df = all_df.iloc[valid_indices].reset_index(drop=True)
        if all_extra_features is not None:
            all_extra_features = all_extra_features[valid_indices]
        print(f"Data after SMILES validation: {all_df.shape}")

        # Set up K-Fold
        if model_type == "classifier":
            kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
            split_target = all_df[target]
        else:
            kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
            split_target = None

        # Initialize storage for out-of-fold predictions
        oof_predictions = np.full(len(all_df), np.nan, dtype=np.float64)
        if model_type == "classifier" and num_classes and num_classes > 1:
            oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
        else:
            oof_proba = None

        ensemble_models = []

        for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(all_df, split_target)):
            print(f"\n{'='*50}")
            print(f"Training Fold {fold_idx + 1}/{n_folds}")
            print(f"{'='*50}")

            # Split data for this fold
            df_train = all_df.iloc[train_idx].reset_index(drop=True)
            df_val = all_df.iloc[val_idx].reset_index(drop=True)

            train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
            val_extra = all_extra_features[val_idx] if all_extra_features is not None else None

            print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")

            # Create ChemProp datasets for this fold
            train_datapoints, _ = create_molecule_datapoints(
                df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra
            )
            val_datapoints, _ = create_molecule_datapoints(
                df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra
            )

            train_dataset = data.MoleculeDataset(train_datapoints)
            val_dataset = data.MoleculeDataset(val_datapoints)

            # Save raw val features for prediction
            val_extra_raw = val_extra.copy() if val_extra is not None else None

            # Scale features and targets for this fold
            x_d_transform = None
            if use_extra_features:
                feature_scaler = train_dataset.normalize_inputs("X_d")
                val_dataset.normalize_inputs("X_d", feature_scaler)
                x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)

            output_transform = None
            if model_type == "regressor":
                target_scaler = train_dataset.normalize_targets()
                val_dataset.normalize_targets(target_scaler)
                output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)

            train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)

            # Build and train model for this fold
            pl.seed_everything(42 + fold_idx)
            mpnn = build_mpnn_model(
                hyperparameters, task=task, num_classes=num_classes,
                n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
            )

            callbacks = [
                pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
                pl.callbacks.ModelCheckpoint(
                    dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
                    monitor="val_loss", mode="min", save_top_k=1,
                ),
            ]

            trainer = pl.Trainer(
                accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
                logger=False, enable_progress_bar=True,
            )

            trainer.fit(mpnn, train_loader, val_loader)

            if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
                checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
                mpnn.load_state_dict(checkpoint["state_dict"])

            mpnn.eval()
            ensemble_models.append(mpnn)

            # Make out-of-fold predictions using raw features
            val_datapoints_raw, _ = create_molecule_datapoints(
                df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_raw
            )
            val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
            val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)

            with torch.inference_mode():
                fold_predictions = trainer.predict(mpnn, val_loader_pred)
            fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
            if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
                fold_preds = fold_preds.squeeze(axis=1)

            # Store out-of-fold predictions
            if model_type == "classifier" and fold_preds.ndim == 2:
                oof_predictions[val_idx] = np.argmax(fold_preds, axis=1)
                if oof_proba is not None:
                    oof_proba[val_idx] = fold_preds
            else:
                oof_predictions[val_idx] = fold_preds.flatten()

            print(f"Fold {fold_idx + 1} complete!")

        print(f"\nCross-validation complete! Trained {len(ensemble_models)} models.")

        # Use out-of-fold predictions for metrics
        preds = oof_predictions
        preds_std = None  # Will compute from ensemble at inference time
        y_validate = all_df[target].values
        df_val = all_df  # For saving predictions

    if model_type == "classifier":
        # Classification metrics - handle multi-class output
        # For CV mode, preds already contains class indices; for single model, preds are probabilities
        if preds.ndim == 2 and preds.shape[1] > 1:
            # Multi-class probabilities: (n_samples, n_classes), take argmax
            class_preds = np.argmax(preds, axis=1)
            has_proba = True
        elif preds.ndim == 1:
            # Either class indices (CV mode) or binary probabilities
            if n_folds > 1:
                # CV mode: preds are already class indices
                class_preds = preds.astype(int)
                has_proba = False
            else:
                # Single model: preds are probabilities
                class_preds = (preds > 0.5).astype(int)
                has_proba = False
        else:
            # Squeeze extra dimensions if needed
            preds = preds.squeeze()
            if preds.ndim == 2:
                class_preds = np.argmax(preds, axis=1)
                has_proba = True
            else:
                class_preds = (preds > 0.5).astype(int)
                has_proba = False

        print(f"class_preds shape: {class_preds.shape}")

        # Decode labels for metrics
        y_validate_decoded = label_encoder.inverse_transform(y_validate.astype(int))
        preds_decoded = label_encoder.inverse_transform(class_preds)

        # Calculate metrics
        label_names = label_encoder.classes_
        scores = precision_recall_fscore_support(
            y_validate_decoded, preds_decoded, average=None, labels=label_names
        )

        score_df = pd.DataFrame(
            {
                target: label_names,
                "precision": scores[0],
                "recall": scores[1],
                "f1": scores[2],
                "support": scores[3],
            }
        )

        # Output metrics per class
        metrics = ["precision", "recall", "f1", "support"]
        for t in label_names:
            for m in metrics:
                value = score_df.loc[score_df[target] == t, m].iloc[0]
                print(f"Metrics:{t}:{m} {value}")

        # Confusion matrix
        conf_mtx = confusion_matrix(
            y_validate_decoded, preds_decoded, labels=label_names
        )
        for i, row_name in enumerate(label_names):
            for j, col_name in enumerate(label_names):
                value = conf_mtx[i, j]
                print(f"ConfusionMatrix:{row_name}:{col_name} {value}")

        # Save validation predictions
        df_val = df_val.copy()
        df_val["prediction"] = preds_decoded
        if has_proba and preds.ndim == 2 and preds.shape[1] > 1:
            df_val["pred_proba"] = [p.tolist() for p in preds]
            df_val = expand_proba_column(df_val, label_names)

    else:
        # Regression metrics
        preds_flat = preds.flatten()
        rmse = root_mean_squared_error(y_validate, preds_flat)
        mae = mean_absolute_error(y_validate, preds_flat)
        r2 = r2_score(y_validate, preds_flat)
        print(f"RMSE: {rmse:.3f}")
        print(f"MAE: {mae:.3f}")
        print(f"R2: {r2:.3f}")
        print(f"NumRows: {len(df_val)}")

        df_val = df_val.copy()
        df_val["prediction"] = preds_flat

        # Add prediction_std for ensemble models
        if preds_std is not None:
            df_val["prediction_std"] = preds_std.flatten()
            print(f"Ensemble std - mean: {df_val['prediction_std'].mean():.4f}, max: {df_val['prediction_std'].max():.4f}")

    # Save validation predictions to S3
    output_columns = [target, "prediction"]
    if "prediction_std" in df_val.columns:
        output_columns.append("prediction_std")
    output_columns += [col for col in df_val.columns if col.endswith("_proba")]
    wr.s3.to_csv(
        df_val[output_columns],
        path=f"{model_metrics_s3_path}/validation_predictions.csv",
        index=False,
    )

    # Save ensemble models (n_folds models if CV, 1 model otherwise)
    for model_idx, ens_model in enumerate(ensemble_models):
        model_path = os.path.join(args.model_dir, f"chemprop_model_{model_idx}.pt")
        models.save_model(model_path, ens_model)
        print(f"Saved model {model_idx + 1} to {model_path}")

    # Save ensemble metadata (n_ensemble = number of models for inference)
    n_ensemble = len(ensemble_models)
    ensemble_metadata = {"n_ensemble": n_ensemble, "n_folds": n_folds}
    joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
    print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds})")

    # Save label encoder if classification
    if label_encoder is not None:
        joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

    # Save extra feature metadata for inference (hybrid mode)
    # Note: We don't need to save the scaler - X_d_transform is embedded in the model
    if use_extra_features:
        feature_metadata = {
            "extra_feature_cols": extra_feature_cols,
            "col_means": col_means.tolist(),  # Unscaled means for NaN imputation
        }
        joblib.dump(
            feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
        )
        print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
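    # For reference, the artifacts this script writes to model_dir, which
    # model_fn/predict_fn load back at inference time:
    #   chemprop_model_{i}.pt      - one MPNN per ensemble member (n_folds models, or 1)
    #   ensemble_metadata.joblib   - {"n_ensemble": ..., "n_folds": ...}
    #   label_encoder.joblib       - classification only
    #   feature_metadata.joblib    - hybrid mode only (extra feature columns + column means)
    # (Lightning's ModelCheckpoint also leaves best_model_*.ckpt files there.)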