workbench 0.8.205-py3-none-any.whl → 0.8.213-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/api/endpoint.py +3 -6
- workbench/api/feature_set.py +1 -1
- workbench/api/model.py +5 -11
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/endpoint_core.py +63 -153
- workbench/core/artifacts/model_core.py +21 -19
- workbench/core/transforms/features_to_model/features_to_model.py +2 -2
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +1 -1
- workbench/model_script_utils/model_script_utils.py +335 -0
- workbench/model_script_utils/pytorch_utils.py +395 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +289 -666
- workbench/model_scripts/chemprop/generated_model_script.py +292 -669
- workbench/model_scripts/chemprop/model_script_utils.py +335 -0
- workbench/model_scripts/chemprop/requirements.txt +2 -10
- workbench/model_scripts/pytorch_model/generated_model_script.py +355 -612
- workbench/model_scripts/pytorch_model/model_script_utils.py +335 -0
- workbench/model_scripts/pytorch_model/pytorch.template +350 -607
- workbench/model_scripts/pytorch_model/pytorch_utils.py +395 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/script_generation.py +2 -5
- workbench/model_scripts/uq_models/generated_model_script.py +65 -422
- workbench/model_scripts/xgb_model/generated_model_script.py +349 -412
- workbench/model_scripts/xgb_model/model_script_utils.py +335 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +344 -407
- workbench/scripts/training_test.py +85 -0
- workbench/utils/chemprop_utils.py +18 -656
- workbench/utils/metrics_utils.py +172 -0
- workbench/utils/model_utils.py +104 -47
- workbench/utils/pytorch_utils.py +32 -472
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +49 -356
- workbench/web_interface/components/plugins/model_details.py +30 -68
- {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/METADATA +5 -5
- {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/RECORD +42 -31
- {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/entry_points.txt +1 -0
- workbench/model_scripts/uq_models/mapie.template +0 -605
- workbench/model_scripts/uq_models/requirements.txt +0 -1
- {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/WHEEL +0 -0
- {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/top_level.txt +0 -0
@@ -1,160 +1,103 @@
 # ChemProp Model Template for Workbench
-# Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
 #
-#
-#
-#
-#
-#
-# - Regression uses output_transform (UnscaleTransform) for target scaling
-#
-# 2. Data Handling (create_molecule_datapoints function)
-# - MoleculeDatapoint creation with x_d (extra descriptors)
-# - RDKit validation of SMILES
-#
-# 3. Scaling (training section)
-# - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
-# - Targets (regression): normalize_targets() + UnscaleTransform in FFN
-# - At inference: pass RAW features, transforms handle scaling automatically
-#
-# 4. Training Loop (search for "pl.Trainer")
-# - PyTorch Lightning Trainer with ChemProp MPNN
-#
-# AWS/SageMaker boilerplate (can skip):
-# - input_fn, output_fn, model_fn: SageMaker serving interface
-# - argparse, file loading, S3 writes
-# =============================
+# This template handles molecular property prediction using ChemProp 2.x MPNN with:
+# - K-fold cross-validation ensemble training (or single train/val split)
+# - Multi-task regression support
+# - Hybrid mode (SMILES + extra molecular descriptors)
+# - Classification (single-target only)

-import glob
-import os
 import argparse
+import glob
 import json
-
+import os

 import awswrangler as wr
+import joblib
 import numpy as np
 import pandas as pd
 import torch
 from lightning import pytorch as pl
-from sklearn.model_selection import
+from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
 from sklearn.preprocessing import LabelEncoder
-from sklearn.metrics import (
-    mean_absolute_error,
-    median_absolute_error,
-    r2_score,
-    root_mean_squared_error,
-    precision_recall_fscore_support,
-    confusion_matrix,
-)
-from scipy.stats import spearmanr
-import joblib

-#
+# Enable Tensor Core optimization for GPUs that support it
+torch.set_float32_matmul_precision("medium")
+
 from chemprop import data, models, nn

-
+from model_script_utils import (
+    check_dataframe,
+    compute_classification_metrics,
+    compute_regression_metrics,
+    expand_proba_column,
+    input_fn,
+    output_fn,
+    print_classification_metrics,
+    print_confusion_matrix,
+    print_regression_metrics,
+)
+
+# =============================================================================
+# Default Hyperparameters
+# =============================================================================
+DEFAULT_HYPERPARAMETERS = {
+    # Training
+    "n_folds": 5,
+    "max_epochs": 400,
+    "patience": 40,
+    "batch_size": 16,
+    # Message Passing
+    "hidden_dim": 700,
+    "depth": 6,
+    "dropout": 0.15,
+    # FFN
+    "ffn_hidden_dim": 2000,
+    "ffn_num_layers": 2,
+    # Random seed
+    "seed": 42,
+}
+
+# Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
-    "targets": ['udm_asy_res_efflux_ratio'],
-    "feature_list": ['smiles', 'smr_vsa4', 'tpsa', '
+    "targets": ['udm_asy_res_efflux_ratio'],
+    "feature_list": ['smiles', 'smr_vsa4', 'tpsa', 'numhdonors', 'nhohcount', 'nbase', 'vsa_estate3', 'fr_guanido', 'mollogp', 'peoe_vsa8', 'peoe_vsa1', 'fr_imine', 'vsa_estate2', 'estate_vsa10', 'asphericity', 'xc_3dv', 'smr_vsa3', 'charge_centroid_distance', 'c3sp3', 'nitrogen_span', 'estate_vsa2', 'minpartialcharge', 'hba_hbd_ratio', 'slogp_vsa1', 'axp_7d', 'nocount', 'vsa_estate4', 'vsa_estate6', 'estate_vsa4', 'xc_4dv', 'xc_4d', 'num_s_centers', 'vsa_estate9', 'chi2v', 'axp_5d', 'mi', 'mse', 'bcut2d_mrhi', 'smr_vsa6', 'hallkieralpha', 'balabanj', 'amphiphilic_moment', 'type_ii_pattern_count', 'minabsestateindex', 'bcut2d_mwlow', 'axp_0dv', 'slogp_vsa5', 'axp_2d', 'axp_1dv', 'xch_5d', 'peoe_vsa10'],
     "id_column": "udm_mol_bat_id",
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-chemprop-
-    "hyperparameters": {
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-chemprop-hybrid/training",
+    "hyperparameters": {},
 }


-
-
-
-        msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-        print(msg)
-        raise ValueError(msg)
-
-
+# =============================================================================
+# Helper Functions
+# =============================================================================
 def find_smiles_column(columns: list[str]) -> str:
-    """Find
-
-    if
-        raise ValueError(
-
-        )
-    return smiles_column
-
-
-def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
-    """Expands a column containing a list of probabilities into separate columns.
-
-    Handles None values for rows where predictions couldn't be made.
-    """
-    proba_column = "pred_proba"
-    if proba_column not in df.columns:
-        raise ValueError('DataFrame does not contain a "pred_proba" column')
-
-    proba_splits = [f"{label}_proba" for label in class_labels]
-    n_classes = len(class_labels)
-
-    # Handle None values by replacing with list of NaNs
-    proba_values = []
-    for val in df[proba_column]:
-        if val is None:
-            proba_values.append([np.nan] * n_classes)
-        else:
-            proba_values.append(val)
-
-    proba_df = pd.DataFrame(proba_values, columns=proba_splits)
-
-    df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
-    df = df.reset_index(drop=True)
-    df = pd.concat([df, proba_df], axis=1)
-    return df
+    """Find SMILES column (case-insensitive match for 'smiles')."""
+    smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
+    if smiles_col is None:
+        raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
+    return smiles_col


 def create_molecule_datapoints(
     smiles_list: list[str],
-    targets:
+    targets: np.ndarray | None = None,
     extra_descriptors: np.ndarray | None = None,
 ) -> tuple[list[data.MoleculeDatapoint], list[int]]:
-    """Create ChemProp MoleculeDatapoints from SMILES strings.
-
-    Args:
-        smiles_list: List of SMILES strings
-        targets: Optional target values as 2D array (n_samples, n_targets). NaN allowed for missing targets.
-        extra_descriptors: Optional array of extra features (n_samples, n_features)
-
-    Returns:
-        Tuple of (list of MoleculeDatapoint objects, list of valid indices)
-    """
+    """Create ChemProp MoleculeDatapoints from SMILES strings."""
     from rdkit import Chem

-    datapoints = []
-
-    invalid_count = 0
-
-    # Convert targets to 2D array if provided
-    if targets is not None:
-        targets = np.atleast_2d(np.array(targets))
-        if targets.shape[0] == 1 and len(smiles_list) > 1:
-            targets = targets.T  # Shape was (1, n_samples), transpose to (n_samples, 1)
+    datapoints, valid_indices = [], []
+    targets = np.atleast_2d(np.array(targets)).T if targets is not None and np.array(targets).ndim == 1 else targets

     for i, smi in enumerate(smiles_list):
-
-        mol = Chem.MolFromSmiles(smi)
-        if mol is None:
-            invalid_count += 1
+        if Chem.MolFromSmiles(smi) is None:
             continue
-
-        # Build datapoint with optional target(s) and extra descriptors
-        # For multi-task, y is a list of values (can include NaN for missing targets)
         y = targets[i].tolist() if targets is not None else None
         x_d = extra_descriptors[i] if extra_descriptors is not None else None
-
-        dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
-        datapoints.append(dp)
+        datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
         valid_indices.append(i)

-    if invalid_count > 0:
-        print(f"Warning: Skipped {invalid_count} invalid SMILES strings")
-
     return datapoints, valid_indices


@@ -168,525 +111,306 @@ def build_mpnn_model(
     output_transform: nn.UnscaleTransform | None = None,
     task_weights: np.ndarray | None = None,
 ) -> models.MPNN:
-    """Build an MPNN model with
-
-
-
-
-
-        n_targets: Number of target columns (for multi-task regression)
-        n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
-        x_d_transform: Optional transform for extra descriptors (scaling)
-        output_transform: Optional transform for regression output (unscaling targets)
-        task_weights: Optional array of weights for each task (multi-task learning)
-
-    Returns:
-        Configured MPNN model
-    """
-    # Model hyperparameters with defaults
-    hidden_dim = hyperparameters.get("hidden_dim", 700)
-    depth = hyperparameters.get("depth", 6)
-    dropout = hyperparameters.get("dropout", 0.15)
-    ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 2000)
-    ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
-
-    # Message passing component
-    mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+    """Build an MPNN model with specified hyperparameters."""
+    hidden_dim = hyperparameters["hidden_dim"]
+    depth = hyperparameters["depth"]
+    dropout = hyperparameters["dropout"]
+    ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
+    ffn_num_layers = hyperparameters["ffn_num_layers"]

-
+    mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
     agg = nn.NormAggregation()
-
-    # FFN input_dim = message passing output + extra descriptors
     ffn_input_dim = hidden_dim + n_extra_descriptors

-    # Build FFN based on task type
     if task == "classification" and num_classes is not None:
-        # Multi-class classification
         ffn = nn.MulticlassClassificationFFN(
-            n_classes=num_classes,
-
-            hidden_dim=ffn_hidden_dim,
-            n_layers=ffn_num_layers,
-            dropout=dropout,
+            n_classes=num_classes, input_dim=ffn_input_dim,
+            hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
         )
     else:
-
-        # n_tasks controls the number of output heads for multi-task learning
-        # task_weights goes here (in RegressionFFN) to weight loss per task
-        weights_tensor = None
-        if task_weights is not None:
-            weights_tensor = torch.tensor(task_weights, dtype=torch.float32)
-
+        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
         ffn = nn.RegressionFFN(
-            input_dim=ffn_input_dim,
-
-            n_layers=ffn_num_layers,
-            dropout=dropout,
-            n_tasks=n_targets,
-            output_transform=output_transform,
-            task_weights=weights_tensor,
+            input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+            dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
         )

-
-    mpnn = models.MPNN(
-        message_passing=mp,
-        agg=agg,
-        predictor=ffn,
-        batch_norm=True,
-        metrics=None,
-        X_d_transform=x_d_transform,
-    )
-
-    return mpnn
+    return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)


+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load
-
-    Args:
-        model_dir: Directory containing the saved models
-
-    Returns:
-        Dictionary with ensemble models and metadata
-    """
-    # Load ensemble metadata (required)
-    ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
-    ensemble_metadata = joblib.load(ensemble_metadata_path)
-    n_ensemble = ensemble_metadata["n_ensemble"]
-    target_columns = ensemble_metadata["target_columns"]
-
-    # Load all ensemble models
+    """Load ChemProp MPNN ensemble from the specified directory."""
+    metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
     ensemble_models = []
-    for
-
-        model = models.MPNN.load_from_file(model_path)
+    for i in range(metadata["n_ensemble"]):
+        model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
         model.eval()
         ensemble_models.append(model)

-    print(f"Loaded {len(ensemble_models)}
-
-    return {
-        "ensemble_models": ensemble_models,
-        "n_ensemble": n_ensemble,
-        "target_columns": target_columns,
-    }
-
-
-def input_fn(input_data, content_type: str) -> pd.DataFrame:
-    """Parse input data and return a DataFrame."""
-    if not input_data:
-        raise ValueError("Empty input data is not supported!")
-
-    if isinstance(input_data, bytes):
-        input_data = input_data.decode("utf-8")
-
-    if "text/csv" in content_type:
-        return pd.read_csv(StringIO(input_data))
-    elif "application/json" in content_type:
-        return pd.DataFrame(json.loads(input_data))
-    else:
-        raise ValueError(f"{content_type} not supported!")
-
-
-def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
-    """Supports both CSV and JSON output formats."""
-    if "text/csv" in accept_type:
-        csv_output = output_df.fillna("N/A").to_csv(index=False)
-        return csv_output, "text/csv"
-    elif "application/json" in accept_type:
-        return output_df.to_json(orient="records"), "application/json"
-    else:
-        raise RuntimeError(
-            f"{accept_type} accept type is not supported by this script."
-        )
+    print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
+    return {"ensemble_models": ensemble_models, "n_ensemble": metadata["n_ensemble"], "target_columns": metadata["target_columns"]}


+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make predictions with
-
-    Args:
-        df: Input DataFrame containing SMILES column (and extra features if hybrid mode)
-        model_dict: Dictionary containing ensemble models and metadata
-
-    Returns:
-        DataFrame with predictions added (and prediction_std for ensembles)
-    """
+    """Make predictions with ChemProp MPNN ensemble."""
     model_type = TEMPLATE_PARAMS["model_type"]
     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

-    # Extract ensemble models and metadata
     ensemble_models = model_dict["ensemble_models"]
-    n_ensemble = model_dict["n_ensemble"]
     target_columns = model_dict["target_columns"]

-    # Load
+    # Load artifacts
     label_encoder = None
-
-    if os.path.exists(
-        label_encoder = joblib.load(
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)

-    # Load feature metadata if present (hybrid mode)
-    # Contains column names, NaN fill values, and scaler for feature scaling
     feature_metadata = None
-
-    if os.path.exists(
-        feature_metadata = joblib.load(
-        print(
-            f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
-        )
+    feature_path = os.path.join(model_dir, "feature_metadata.joblib")
+    if os.path.exists(feature_path):
+        feature_metadata = joblib.load(feature_path)
+        print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")

-    # Find SMILES column
+    # Find SMILES column and validate
     smiles_column = find_smiles_column(df.columns.tolist())
-
     smiles_list = df[smiles_column].tolist()

-
-
-    valid_smiles = []
-    valid_indices = []
-    for i, smi in enumerate(smiles_list):
-        if smi and isinstance(smi, str) and len(smi.strip()) > 0:
-            valid_mask.append(True)
-            valid_smiles.append(smi.strip())
-            valid_indices.append(i)
-        else:
-            valid_mask.append(False)
-
-    valid_mask = np.array(valid_mask)
+    valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
+    valid_smiles = [s.strip() for i, s in enumerate(smiles_list) if valid_mask[i]]
     print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")

-    # Initialize
+    # Initialize output columns
     if model_type == "classifier":
         df["prediction"] = pd.Series([None] * len(df), dtype=object)
     else:
-        # Regression: create prediction column for each target
         for tc in target_columns:
             df[f"{tc}_pred"] = np.nan
             df[f"{tc}_pred_std"] = np.nan

     if sum(valid_mask) == 0:
-        print("Warning: No valid SMILES to predict on")
         return df

-    # Prepare extra features
-    # NOTE: We pass RAW (unscaled) features here - the model's X_d_transform handles scaling
+    # Prepare extra features (raw, unscaled - model handles scaling)
     extra_features = None
     if feature_metadata is not None:
-
+        extra_cols = feature_metadata["extra_feature_cols"]
         col_means = np.array(feature_metadata["col_means"])
+        valid_indices = np.where(valid_mask)[0]

-
-
-        if missing_cols:
-            print(
-                f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
-            )
-
-        # Extract features for valid SMILES rows (raw, unscaled)
-        extra_features = np.zeros(
-            (len(valid_indices), len(extra_feature_cols)), dtype=np.float32
-        )
-        for j, col in enumerate(extra_feature_cols):
+        extra_features = np.zeros((len(valid_indices), len(extra_cols)), dtype=np.float32)
+        for j, col in enumerate(extra_cols):
             if col in df.columns:
                 values = df.iloc[valid_indices][col].values.astype(np.float32)
-
-                nan_mask = np.isnan(values)
-                values[nan_mask] = col_means[j]
+                values[np.isnan(values)] = col_means[j]
                 extra_features[:, j] = values
             else:
-                # Column missing, use training mean
                 extra_features[:, j] = col_means[j]

-    # Create datapoints
-    datapoints,
-        valid_smiles, extra_descriptors=extra_features
-    )
-
+    # Create datapoints and predict
+    datapoints, rdkit_valid = create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
     if len(datapoints) == 0:
-        print("Warning: No valid SMILES after RDKit validation")
         return df

     dataset = data.MoleculeDataset(datapoints)
     dataloader = data.build_dataloader(dataset, shuffle=False)
+    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)

-    #
-
-
-        logger=False,
-        enable_progress_bar=False,
-    )
-
-    # Collect predictions from all ensemble members
-    all_ensemble_preds = []
-    for ens_idx, ens_model in enumerate(ensemble_models):
+    # Ensemble predictions
+    all_preds = []
+    for model in ensemble_models:
         with torch.inference_mode():
-            predictions = trainer.predict(
-
-
-
-
-
-
-    ensemble_preds = np.stack(all_ensemble_preds, axis=0)
-    preds = np.mean(ensemble_preds, axis=0)
-    preds_std = np.std(ensemble_preds, axis=0)  # Will be 0s for n_ensemble=1
-
-    # Ensure 2D: (n_samples, n_targets)
+            predictions = trainer.predict(model, dataloader)
+        preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+        if preds.ndim == 3 and preds.shape[1] == 1:
+            preds = preds.squeeze(axis=1)
+        all_preds.append(preds)
+
+    preds = np.mean(np.stack(all_preds), axis=0)
+    preds_std = np.std(np.stack(all_preds), axis=0)
     if preds.ndim == 1:
-        preds = preds.reshape(-1, 1)
-        preds_std = preds_std.reshape(-1, 1)
+        preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)

-    print(f"Inference
+    print(f"Inference complete: {preds.shape[0]} predictions")

-    # Map predictions back to
-
-    valid_positions = np.where(valid_mask)[0][rdkit_valid_indices]
+    # Map predictions back to valid positions
+    valid_positions = np.where(valid_mask)[0][rdkit_valid]
     valid_mask = np.zeros(len(df), dtype=bool)
     valid_mask[valid_positions] = True

     if model_type == "classifier" and label_encoder is not None:
-
-        if preds.ndim == 2 and preds.shape[1] > 1:
-            # Multi-class: preds are probabilities (averaged across ensemble)
+        if preds.shape[1] > 1:
             class_preds = np.argmax(preds, axis=1)
-
-
-
-
-            proba_series = pd.Series([None] * len(df), index=df.index, dtype=object)
-            proba_series.loc[valid_mask] = [p.tolist() for p in preds]
-            df["pred_proba"] = proba_series
+            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform(class_preds)
+            proba = pd.Series([None] * len(df), dtype=object)
+            proba.loc[valid_mask] = [p.tolist() for p in preds]
+            df["pred_proba"] = proba
             df = expand_proba_column(df, label_encoder.classes_)
         else:
-
-            class_preds = (preds.flatten() > 0.5).astype(int)
-            decoded_preds = label_encoder.inverse_transform(class_preds)
-            df.loc[valid_mask, "prediction"] = decoded_preds
+            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform((preds.flatten() > 0.5).astype(int))
     else:
-        # Regression: store predictions for each target
         for t_idx, tc in enumerate(target_columns):
             df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
             df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
-
-
-        first_target = target_columns[0]
-        df["prediction"] = df[f"{first_target}_pred"]
-        df["prediction_std"] = df[f"{first_target}_pred_std"]
+        df["prediction"] = df[f"{target_columns[0]}_pred"]
+        df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]

     return df


+# =============================================================================
+# Training
+# =============================================================================
 if __name__ == "__main__":
-
+    # -------------------------------------------------------------------------
+    # Setup: Parse arguments and load data
+    # -------------------------------------------------------------------------
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+    parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+    args = parser.parse_args()

-    #
-    target_columns = TEMPLATE_PARAMS["targets"]
+    # Extract template parameters
+    target_columns = TEMPLATE_PARAMS["targets"]
     model_type = TEMPLATE_PARAMS["model_type"]
     feature_list = TEMPLATE_PARAMS["feature_list"]
     id_column = TEMPLATE_PARAMS["id_column"]
     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
-    hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
+    hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}

-
-    if not target_columns or not isinstance(target_columns, list) or len(target_columns) == 0:
+    if not target_columns or not isinstance(target_columns, list):
         raise ValueError("'targets' must be a non-empty list of target column names")
     n_targets = len(target_columns)
-    print(f"Target columns ({n_targets}): {target_columns}")

-    # Get the SMILES column name from feature_list (user defines this, so we use their exact name)
     smiles_column = find_smiles_column(feature_list)
     extra_feature_cols = [f for f in feature_list if f != smiles_column]
     use_extra_features = len(extra_feature_cols) > 0
-    print(f"Feature List: {feature_list}")
-    print(f"SMILES Column: {smiles_column}")
-    print(
-        f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
-    )

-
-
-
-
-    )
-    parser.add_argument(
-        "--train",
-        type=str,
-        default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
-    )
-    parser.add_argument(
-        "--output-data-dir",
-        type=str,
-        default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
-    )
-    args = parser.parse_args()
+    print(f"Target columns ({n_targets}): {target_columns}")
+    print(f"SMILES column: {smiles_column}")
+    print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
+    print(f"Hyperparameters: {hyperparameters}")

-    #
-    training_files = [
-        os.path.join(args.train, f)
-        for f in os.listdir(args.train)
-        if f.endswith(".csv")
-    ]
+    # Load training data
+    training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
     print(f"Training Files: {training_files}")
-
     all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
-    print(f"All Data Shape: {all_df.shape}")
-
     check_dataframe(all_df, "training_df")

-    #
+    # Clean data
     initial_count = len(all_df)
     all_df = all_df.dropna(subset=[smiles_column])
-
-
-
-
-
-        print(f"Dropped {dropped} rows with missing SMILES or all target values")
-
-    print(f"Target columns: {target_columns}")
-    print(f"Data Shape after cleaning: {all_df.shape}")
+    all_df = all_df[all_df[target_columns].notna().any(axis=1)]
+    if len(all_df) < initial_count:
+        print(f"Dropped {initial_count - len(all_df)} rows with missing SMILES/targets")
+
+    print(f"Data shape: {all_df.shape}")
     for tc in target_columns:
-
-        print(f"  {tc}: {n_valid} samples with values")
+        print(f"  {tc}: {all_df[tc].notna().sum()} samples")

-    #
+    # -------------------------------------------------------------------------
+    # Classification setup
+    # -------------------------------------------------------------------------
     label_encoder = None
+    num_classes = None
     if model_type == "classifier":
         if n_targets > 1:
-            raise ValueError("Multi-task classification
+            raise ValueError("Multi-task classification not supported")
         label_encoder = LabelEncoder()
         all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
         num_classes = len(label_encoder.classes_)
-        print(
-            f"Classification task with {num_classes} classes: {label_encoder.classes_}"
-        )
-    else:
-        num_classes = None
+        print(f"Classification: {num_classes} classes: {label_encoder.classes_}")

-    #
-
+    # -------------------------------------------------------------------------
+    # Prepare features
+    # -------------------------------------------------------------------------
     task = "classification" if model_type == "classifier" else "regression"
     n_extra = len(extra_feature_cols) if use_extra_features else 0
-    max_epochs = hyperparameters.get("max_epochs", 400)
-    patience = hyperparameters.get("patience", 40)
-    n_folds = hyperparameters.get("n_folds", 5)  # Number of CV folds (default: 5)
-    batch_size = hyperparameters.get("batch_size", 16)

-
-    if use_extra_features:
-        missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
-        if missing_cols:
-            raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")
-
-    # =========================================================================
-    # UNIFIED TRAINING: Works for n_folds=1 (single model) or n_folds>1 (K-fold CV)
-    # =========================================================================
-    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")
-
-    # Prepare extra features and validate SMILES upfront
-    all_extra_features = None
-    col_means = None
+    all_extra_features, col_means = None, None
     if use_extra_features:
         all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
         col_means = np.nanmean(all_extra_features, axis=0)
         for i in range(all_extra_features.shape[1]):
             all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]

-    # Prepare target array: always 2D (n_samples, n_targets)
     all_targets = all_df[target_columns].values.astype(np.float32)

-    # Filter invalid SMILES
-    _, valid_indices = create_molecule_datapoints(
-        all_df[smiles_column].tolist(), all_targets, all_extra_features
-    )
+    # Filter invalid SMILES
+    _, valid_indices = create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
     all_df = all_df.iloc[valid_indices].reset_index(drop=True)
     all_targets = all_targets[valid_indices]
     if all_extra_features is not None:
         all_extra_features = all_extra_features[valid_indices]
     print(f"Data after SMILES validation: {all_df.shape}")

-    #
-    # Weight = inverse of sample count (normalized so min weight = 1.0)
-    # This gives higher weight to targets with fewer samples
+    # Task weights for multi-task (inverse sample count)
     task_weights = None
     if n_targets > 1 and model_type != "classifier":
-
-
-
-
-
-
-
-
+        counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
+        task_weights = (1.0 / counts) / (1.0 / counts).min()
+        print(f"Task weights: {dict(zip(target_columns, task_weights.round(3)))}")
+
+    # -------------------------------------------------------------------------
+    # Cross-validation setup
+    # -------------------------------------------------------------------------
+    n_folds = hyperparameters["n_folds"]
+    batch_size = hyperparameters["batch_size"]

-    # Create fold splits
     if n_folds == 1:
-        # Single fold: use train/val split from "training" column or random split
         if "training" in all_df.columns:
-            print("
+            print("Using 'training' column for train/val split")
             train_idx = np.where(all_df["training"])[0]
             val_idx = np.where(~all_df["training"])[0]
         else:
-            print("WARNING: No training column
-
-            train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
+            print("WARNING: No 'training' column, using random 80/20 split")
+            train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
         folds = [(train_idx, val_idx)]
     else:
-        # K-Fold CV
         if model_type == "classifier":
             kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
-
+            folds = list(kfold.split(all_df, all_df[target_columns[0]]))
         else:
             kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
-
-            folds = list(kfold.split(all_df, split_target))
+            folds = list(kfold.split(all_df))

-
+    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
+
+    # -------------------------------------------------------------------------
+    # Training loop
+    # -------------------------------------------------------------------------
     oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
-    if model_type == "classifier" and num_classes
-        oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
-    else:
-        oof_proba = None
+    oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64) if model_type == "classifier" and num_classes else None

     ensemble_models = []
-
     for fold_idx, (train_idx, val_idx) in enumerate(folds):
         print(f"\n{'='*50}")
-        print(f"
+        print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
         print(f"{'='*50}")

-        # Split data
-        df_train = all_df.iloc[train_idx].reset_index(drop=True)
-
-        train_targets = all_targets[train_idx]
-        val_targets = all_targets[val_idx]
-
+        # Split data
+        df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
+        train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
         val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
-
-        print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
-
-        # Create ChemProp datasets for this fold
-        train_datapoints, _ = create_molecule_datapoints(
-            df_train[smiles_column].tolist(), train_targets, train_extra
-        )
-        val_datapoints, _ = create_molecule_datapoints(
-            df_val[smiles_column].tolist(), val_targets, val_extra
-        )
-
-        train_dataset = data.MoleculeDataset(train_datapoints)
-        val_dataset = data.MoleculeDataset(val_datapoints)
-
-        # Save raw val features for prediction
         val_extra_raw = val_extra.copy() if val_extra is not None else None

-        #
+        # Create datasets
+        train_dps, _ = create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
+        val_dps, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
+        train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)
+
+        # Scale features/targets
         x_d_transform = None
         if use_extra_features:
-
-            val_dataset.normalize_inputs("X_d",
-            x_d_transform = nn.ScaleTransform.from_standard_scaler(
+            scaler = train_dataset.normalize_inputs("X_d")
+            val_dataset.normalize_inputs("X_d", scaler)
+            x_d_transform = nn.ScaleTransform.from_standard_scaler(scaler)

         output_transform = None
         if model_type in ["regressor", "uq_regressor"]:

@@ -697,29 +421,24 @@ if __name__ == "__main__":
         train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
         val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)

-        # Build and train model
-        pl.seed_everything(
+        # Build and train model
+        pl.seed_everything(hyperparameters["seed"] + fold_idx)
         mpnn = build_mpnn_model(
             hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
-            n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
-            task_weights=task_weights,
+            n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
+            output_transform=output_transform, task_weights=task_weights,
         )

-        callbacks = [
-            pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
-            pl.callbacks.ModelCheckpoint(
-                dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
-                monitor="val_loss", mode="min", save_top_k=1,
-            ),
-        ]
-
         trainer = pl.Trainer(
-            accelerator="auto", max_epochs=max_epochs,
-
+            accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
+            callbacks=[
+                pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
+                pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
+            ],
         )
-
         trainer.fit(mpnn, train_loader, val_loader)

+        # Load best checkpoint
         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
             checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
             mpnn.load_state_dict(checkpoint["state_dict"])

@@ -727,224 +446,128 @@ if __name__ == "__main__":
         mpnn.eval()
         ensemble_models.append(mpnn)

-        #
-
-
-        )
-        val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
-        val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
+        # Out-of-fold predictions (using raw features)
+        val_dps_raw, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
+        val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)

         with torch.inference_mode():
-
-            fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
+            fold_preds = np.concatenate([p.numpy() for p in trainer.predict(mpnn, val_loader_pred)], axis=0)
         if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
             fold_preds = fold_preds.squeeze(axis=1)

-        # Store out-of-fold predictions
         if model_type == "classifier" and fold_preds.ndim == 2:
-            # Store class index in first column for classification
             oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
             if oof_proba is not None:
                 oof_proba[val_idx] = fold_preds
         else:
-            # Regression: fold_preds shape is (n_val, n_targets) or (n_val,)
             if fold_preds.ndim == 1:
                 fold_preds = fold_preds.reshape(-1, 1)
             oof_predictions[val_idx] = fold_preds

-        print(f"Fold {fold_idx + 1} complete!")
-
     print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")

-    #
-    #
+    # -------------------------------------------------------------------------
+    # Prepare validation results
+    # -------------------------------------------------------------------------
     if n_folds == 1:
-        # oof_predictions is always 2D now: check if any column has a value
         val_mask = ~np.isnan(oof_predictions).all(axis=1)
-        preds = oof_predictions[val_mask]
         df_val = all_df[val_mask].copy()
+        preds = oof_predictions[val_mask]
         y_validate = all_targets[val_mask]
         if oof_proba is not None:
             oof_proba = oof_proba[val_mask]
         val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
     else:
-        preds = oof_predictions
         df_val = all_df.copy()
+        preds = oof_predictions
         y_validate = all_targets
         val_extra_features = all_extra_features

-    #
-    #
-
-    if model_type in ["regressor", "uq_regressor"] and len(ensemble_models) > 0:
-        print("Computing prediction_std from ensemble predictions on validation data...")
-        val_datapoints_for_std, _ = create_molecule_datapoints(
-            df_val[smiles_column].tolist(),
-            y_validate,
-            val_extra_features
-        )
-        val_dataset_for_std = data.MoleculeDataset(val_datapoints_for_std)
-        val_loader_for_std = data.build_dataloader(val_dataset_for_std, batch_size=batch_size, shuffle=False)
-
-        all_ensemble_preds_for_std = []
-        trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
-        for ens_model in ensemble_models:
-            with torch.inference_mode():
-                ens_preds = trainer_pred.predict(ens_model, val_loader_for_std)
-            ens_preds = np.concatenate([p.numpy() for p in ens_preds], axis=0)
-            if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
-                ens_preds = ens_preds.squeeze(axis=1)
-            all_ensemble_preds_for_std.append(ens_preds)
-
-        # Stack ensemble predictions: shape (n_ensemble, n_samples, n_targets)
-        ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
-        preds_std = np.std(ensemble_preds_stacked, axis=0)
-        # Ensure 2D
-        if preds_std.ndim == 1:
-            preds_std = preds_std.reshape(-1, 1)
-        print(f"Ensemble prediction_std - mean per target: {np.nanmean(preds_std, axis=0)}")
-
+    # -------------------------------------------------------------------------
+    # Compute metrics and prepare output
+    # -------------------------------------------------------------------------
     if model_type == "classifier":
-        # Classification metrics - preds contains class indices in first column from OOF predictions
         class_preds = preds[:, 0].astype(int)
-        has_proba = oof_proba is not None
-
-        print(f"class_preds shape: {class_preds.shape}")
-
-        # Decode labels for metrics (classification is single-target only)
         target_name = target_columns[0]
-
+        y_true_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
         preds_decoded = label_encoder.inverse_transform(class_preds)

-
-
-
-            y_validate_decoded, preds_decoded, average=None, labels=label_names
-        )
-
-        score_df = pd.DataFrame(
-            {
-                target_name: label_names,
-                "precision": scores[0],
-                "recall": scores[1],
-                "f1": scores[2],
-                "support": scores[3],
-            }
-        )
-
-        # Output metrics per class
-        metrics = ["precision", "recall", "f1", "support"]
-        for t in label_names:
-            for m in metrics:
-                value = score_df.loc[score_df[target_name] == t, m].iloc[0]
-                print(f"Metrics:{t}:{m} {value}")
+        score_df = compute_classification_metrics(y_true_decoded, preds_decoded, label_encoder.classes_, target_name)
+        print_classification_metrics(score_df, target_name, label_encoder.classes_)
+        print_confusion_matrix(y_true_decoded, preds_decoded, label_encoder.classes_)

-        #
-
-            y_validate_decoded, preds_decoded, labels=label_names
-        )
-        for i, row_name in enumerate(label_names):
-            for j, col_name in enumerate(label_names):
-                value = conf_mtx[i, j]
-                print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
-
-        # Save validation predictions
-        df_val = df_val.copy()
+        # Decode target column back to string labels (was encoded for training)
+        df_val[target_name] = y_true_decoded
         df_val["prediction"] = preds_decoded
-        if
+        if oof_proba is not None:
             df_val["pred_proba"] = [p.tolist() for p in oof_proba]
-            df_val = expand_proba_column(df_val,
-
+            df_val = expand_proba_column(df_val, label_encoder.classes_)
     else:
-        #
-
+        # Compute ensemble std
+        preds_std = None
+        if len(ensemble_models) > 1:
+            print("Computing prediction_std from ensemble...")
+            val_dps, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
+            val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
+            trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
+            all_ens_preds = []
+            for m in ensemble_models:
+                with torch.inference_mode():
+                    ens_preds = np.concatenate([p.numpy() for p in trainer_pred.predict(m, val_loader)], axis=0)
+                if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+                    ens_preds = ens_preds.squeeze(axis=1)
+                all_ens_preds.append(ens_preds)
+            preds_std = np.std(np.stack(all_ens_preds), axis=0)
+            if preds_std.ndim == 1:
+                preds_std = preds_std.reshape(-1, 1)
+
         print("\n--- Per-target metrics ---")
         for t_idx, t_name in enumerate(target_columns):
-
-
-
-
-
-            if len(y_true) > 0:
-                rmse = root_mean_squared_error(y_true, y_pred)
-                mae = mean_absolute_error(y_true, y_pred)
-                medae = median_absolute_error(y_true, y_pred)
-                r2 = r2_score(y_true, y_pred)
-                spearman_corr = spearmanr(y_true, y_pred).correlation
-                support = len(y_true)
-                # Print metrics in format expected by SageMaker metric definitions
-                print(f"rmse: {rmse:.3f}")
-                print(f"mae: {mae:.3f}")
-                print(f"medae: {medae:.3f}")
-                print(f"r2: {r2:.3f}")
-                print(f"spearmanr: {spearman_corr:.3f}")
-                print(f"support: {support}")
-
-            # Store predictions in dataframe
+            valid_mask = ~np.isnan(y_validate[:, t_idx])
+            if valid_mask.sum() > 0:
+                metrics = compute_regression_metrics(y_validate[valid_mask, t_idx], preds[valid_mask, t_idx])
+                print_regression_metrics(metrics)
+
             df_val[f"{t_name}_pred"] = preds[:, t_idx]
-            if preds_std is not None
-                df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx]
-            else:
-                df_val[f"{t_name}_pred_std"] = 0.0
+            df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx] if preds_std is not None else 0.0

-
-
-        df_val["prediction"] = df_val[f"{first_target}_pred"]
-        df_val["prediction_std"] = df_val[f"{first_target}_pred_std"]
+        df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
+        df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]

+    # -------------------------------------------------------------------------
     # Save validation predictions to S3
-    #
-    output_columns = []
-    if id_column in df_val.columns:
-        output_columns.append(id_column)
-    # Include all target columns and their predictions
+    # -------------------------------------------------------------------------
+    output_columns = [id_column] if id_column in df_val.columns else []
     output_columns += target_columns
-    output_columns += [f"{t}_pred" for t in target_columns]
-    output_columns += [f"{t}_pred_std" for t in target_columns]
+    output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
     output_columns += ["prediction", "prediction_std"]
-
-    output_columns += [col for col in df_val.columns if col.endswith("_proba")]
-    # Filter to only columns that exist
+    output_columns += [c for c in df_val.columns if c.endswith("_proba")]
     output_columns = [c for c in output_columns if c in df_val.columns]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        "n_ensemble": n_ensemble,
-        "n_folds": n_folds,
-        "target_columns": target_columns,
-    }
-    joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
-    print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds}, targets={target_columns})")
-
-    # Save label encoder if classification
-    if label_encoder is not None:
+
+    wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
+
+    # -------------------------------------------------------------------------
+    # Save model artifacts
+    # -------------------------------------------------------------------------
+    for idx, m in enumerate(ensemble_models):
+        models.save_model(os.path.join(args.model_dir, f"chemprop_model_{idx}.pt"), m)
+    print(f"Saved {len(ensemble_models)} model(s)")
+
+    # Clean up checkpoints
+    for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
+        os.remove(ckpt)
+
+    joblib.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds, "target_columns": target_columns}, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
+
+    with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
+        json.dump(hyperparameters, f, indent=2)
+
+    if label_encoder:
         joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

-    # Save extra feature metadata for inference (hybrid mode)
-    # Note: We don't need to save the scaler - X_d_transform is embedded in the model
     if use_extra_features:
-
-            "extra_feature_cols": extra_feature_cols,
-            "col_means": col_means.tolist(),  # Unscaled means for NaN imputation
-        }
-        joblib.dump(
-            feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
-        )
+        joblib.dump({"extra_feature_cols": extra_feature_cols, "col_means": col_means.tolist()}, os.path.join(args.model_dir, "feature_metadata.joblib"))
         print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
+
+    print(f"\nModel training complete! Artifacts saved to {args.model_dir}")