workbench 0.8.205__py3-none-any.whl → 0.8.212__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (44)
  1. workbench/algorithms/models/noise_model.py +388 -0
  2. workbench/api/endpoint.py +3 -6
  3. workbench/api/feature_set.py +1 -1
  4. workbench/api/model.py +5 -11
  5. workbench/cached/cached_model.py +4 -4
  6. workbench/core/artifacts/endpoint_core.py +57 -145
  7. workbench/core/artifacts/model_core.py +21 -19
  8. workbench/core/transforms/features_to_model/features_to_model.py +2 -2
  9. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +1 -1
  10. workbench/model_script_utils/model_script_utils.py +335 -0
  11. workbench/model_script_utils/pytorch_utils.py +395 -0
  12. workbench/model_script_utils/uq_harness.py +278 -0
  13. workbench/model_scripts/chemprop/chemprop.template +289 -666
  14. workbench/model_scripts/chemprop/generated_model_script.py +292 -669
  15. workbench/model_scripts/chemprop/model_script_utils.py +335 -0
  16. workbench/model_scripts/chemprop/requirements.txt +2 -10
  17. workbench/model_scripts/pytorch_model/generated_model_script.py +355 -612
  18. workbench/model_scripts/pytorch_model/model_script_utils.py +335 -0
  19. workbench/model_scripts/pytorch_model/pytorch.template +350 -607
  20. workbench/model_scripts/pytorch_model/pytorch_utils.py +395 -0
  21. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  22. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  23. workbench/model_scripts/script_generation.py +2 -5
  24. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  25. workbench/model_scripts/xgb_model/generated_model_script.py +349 -412
  26. workbench/model_scripts/xgb_model/model_script_utils.py +335 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  28. workbench/model_scripts/xgb_model/xgb_model.template +344 -407
  29. workbench/scripts/training_test.py +85 -0
  30. workbench/utils/chemprop_utils.py +18 -656
  31. workbench/utils/metrics_utils.py +172 -0
  32. workbench/utils/model_utils.py +104 -47
  33. workbench/utils/pytorch_utils.py +32 -472
  34. workbench/utils/xgboost_local_crossfold.py +267 -0
  35. workbench/utils/xgboost_model_utils.py +49 -356
  36. workbench/web_interface/components/plugins/model_details.py +30 -68
  37. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/METADATA +5 -5
  38. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/RECORD +42 -31
  39. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/entry_points.txt +1 -0
  40. workbench/model_scripts/uq_models/mapie.template +0 -605
  41. workbench/model_scripts/uq_models/requirements.txt +0 -1
  42. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/WHEEL +0 -0
  43. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/top_level.txt +0 -0
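The largest change in this diff is the rewritten ChemProp template (and its generated script), which now trains one MPNN per cross-validation fold and ensembles the folds at inference: predictions are averaged across fold models, and their spread becomes the `*_pred_std` uncertainty column. A minimal NumPy sketch of that aggregation step, using made-up fold outputs rather than anything produced by Workbench:

```python
import numpy as np

# Hypothetical outputs from 3 fold models over 4 molecules, 1 regression target.
fold_preds = np.array([
    [[0.52], [1.10], [2.05], [0.98]],   # fold 0
    [[0.47], [1.18], [1.95], [1.02]],   # fold 1
    [[0.50], [1.15], [2.10], [0.95]],   # fold 2
])

preds = fold_preds.mean(axis=0)      # ensemble prediction, shape (4, 1)
preds_std = fold_preds.std(axis=0)   # ensemble disagreement; all zeros if n_folds == 1

print(preds.ravel())                 # ≈ [0.497, 1.143, 2.033, 0.983]
print(preds_std.ravel())
```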
@@ -1,160 +1,103 @@
1
1
  # ChemProp Model Template for Workbench
2
- # Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
3
2
  #
4
- # === CHEMPROP REVIEW NOTES ===
5
- # This script runs on AWS SageMaker. Key areas for ChemProp review:
6
- #
7
- # 1. Model Architecture (build_mpnn_model function)
8
- # - BondMessagePassing, NormAggregation, FFN configuration
9
- # - Regression uses output_transform (UnscaleTransform) for target scaling
10
- #
11
- # 2. Data Handling (create_molecule_datapoints function)
12
- # - MoleculeDatapoint creation with x_d (extra descriptors)
13
- # - RDKit validation of SMILES
14
- #
15
- # 3. Scaling (training section)
16
- # - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
17
- # - Targets (regression): normalize_targets() + UnscaleTransform in FFN
18
- # - At inference: pass RAW features, transforms handle scaling automatically
19
- #
20
- # 4. Training Loop (search for "pl.Trainer")
21
- # - PyTorch Lightning Trainer with ChemProp MPNN
22
- #
23
- # AWS/SageMaker boilerplate (can skip):
24
- # - input_fn, output_fn, model_fn: SageMaker serving interface
25
- # - argparse, file loading, S3 writes
26
- # =============================
3
+ # This template handles molecular property prediction using ChemProp 2.x MPNN with:
4
+ # - K-fold cross-validation ensemble training (or single train/val split)
5
+ # - Multi-task regression support
6
+ # - Hybrid mode (SMILES + extra molecular descriptors)
7
+ # - Classification (single-target only)
27
8
 
28
- import glob
29
- import os
30
9
  import argparse
10
+ import glob
31
11
  import json
32
- from io import StringIO
12
+ import os
33
13
 
34
14
  import awswrangler as wr
15
+ import joblib
35
16
  import numpy as np
36
17
  import pandas as pd
37
18
  import torch
38
19
  from lightning import pytorch as pl
39
- from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
20
+ from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
40
21
  from sklearn.preprocessing import LabelEncoder
41
- from sklearn.metrics import (
42
- mean_absolute_error,
43
- median_absolute_error,
44
- r2_score,
45
- root_mean_squared_error,
46
- precision_recall_fscore_support,
47
- confusion_matrix,
48
- )
49
- from scipy.stats import spearmanr
50
- import joblib
51
22
 
52
- # ChemProp imports
23
+ # Enable Tensor Core optimization for GPUs that support it
24
+ torch.set_float32_matmul_precision("medium")
25
+
53
26
  from chemprop import data, models, nn
54
27
 
55
- # Template Parameters
28
+ from model_script_utils import (
29
+ check_dataframe,
30
+ compute_classification_metrics,
31
+ compute_regression_metrics,
32
+ expand_proba_column,
33
+ input_fn,
34
+ output_fn,
35
+ print_classification_metrics,
36
+ print_confusion_matrix,
37
+ print_regression_metrics,
38
+ )
39
+
40
+ # =============================================================================
41
+ # Default Hyperparameters
42
+ # =============================================================================
43
+ DEFAULT_HYPERPARAMETERS = {
44
+ # Training
45
+ "n_folds": 5,
46
+ "max_epochs": 400,
47
+ "patience": 40,
48
+ "batch_size": 16,
49
+ # Message Passing
50
+ "hidden_dim": 700,
51
+ "depth": 6,
52
+ "dropout": 0.15,
53
+ # FFN
54
+ "ffn_hidden_dim": 2000,
55
+ "ffn_num_layers": 2,
56
+ # Random seed
57
+ "seed": 42,
58
+ }
59
+
60
+ # Template parameters (filled in by Workbench)
56
61
  TEMPLATE_PARAMS = {
57
62
  "model_type": "uq_regressor",
58
- "targets": ['udm_asy_res_efflux_ratio'], # List of target columns (single or multi-task)
59
- "feature_list": ['smiles', 'smr_vsa4', 'tpsa', 'nhohcount', 'mollogp', 'peoe_vsa1', 'smr_vsa3', 'nitrogen_span', 'numhdonors', 'minpartialcharge', 'vsa_estate3', 'vsa_estate6', 'tertiary_amine_count', 'hba_hbd_ratio', 'peoe_vsa8', 'estate_vsa4', 'xc_4dv', 'vsa_estate2', 'molmr', 'xp_2dv', 'mi', 'molecular_axis_length', 'vsa_estate4', 'xp_6dv', 'qed', 'estate_vsa8', 'chi1v', 'asphericity', 'axp_1d', 'bcut2d_logphi', 'kappa3', 'axp_7d', 'num_s_centers', 'amphiphilic_moment', 'molecular_asymmetry', 'charge_centroid_distance', 'estate_vsa3', 'vsa_estate8', 'aromatic_interaction_score', 'molecular_volume_3d', 'axp_7dv', 'peoe_vsa3', 'smr_vsa6', 'bcut2d_mrhi', 'radius_of_gyration', 'xpc_4dv', 'minabsestateindex', 'axp_0dv', 'chi4n', 'balabanj', 'bcut2d_mwlow'],
63
+ "targets": ['udm_asy_res_efflux_ratio'],
64
+ "feature_list": ['smiles', 'smr_vsa4', 'tpsa', 'numhdonors', 'nhohcount', 'nbase', 'vsa_estate3', 'fr_guanido', 'mollogp', 'peoe_vsa8', 'peoe_vsa1', 'fr_imine', 'vsa_estate2', 'estate_vsa10', 'asphericity', 'xc_3dv', 'smr_vsa3', 'charge_centroid_distance', 'c3sp3', 'nitrogen_span', 'estate_vsa2', 'minpartialcharge', 'hba_hbd_ratio', 'slogp_vsa1', 'axp_7d', 'nocount', 'vsa_estate4', 'vsa_estate6', 'estate_vsa4', 'xc_4dv', 'xc_4d', 'num_s_centers', 'vsa_estate9', 'chi2v', 'axp_5d', 'mi', 'mse', 'bcut2d_mrhi', 'smr_vsa6', 'hallkieralpha', 'balabanj', 'amphiphilic_moment', 'type_ii_pattern_count', 'minabsestateindex', 'bcut2d_mwlow', 'axp_0dv', 'slogp_vsa5', 'axp_2d', 'axp_1dv', 'xch_5d', 'peoe_vsa10'],
60
65
  "id_column": "udm_mol_bat_id",
61
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-chemprop-reg-hybrid/training",
62
- "hyperparameters": {'n_folds': 5, 'hidden_dim': 700, 'depth': 6, 'dropout': 0.15, 'ffn_hidden_dim': 2000, 'ffn_num_layers': 2},
66
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-chemprop-hybrid/training",
67
+ "hyperparameters": {},
63
68
  }
64
69
 
65
70
 
66
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
67
- """Check if the provided dataframe is empty and raise an exception if it is."""
68
- if df.empty:
69
- msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
70
- print(msg)
71
- raise ValueError(msg)
72
-
73
-
71
+ # =============================================================================
72
+ # Helper Functions
73
+ # =============================================================================
74
74
  def find_smiles_column(columns: list[str]) -> str:
75
- """Find the SMILES column name from a list (case-insensitive match for 'smiles')."""
76
- smiles_column = next((col for col in columns if col.lower() == "smiles"), None)
77
- if smiles_column is None:
78
- raise ValueError(
79
- "Column list must contain a 'smiles' column (case-insensitive)"
80
- )
81
- return smiles_column
82
-
83
-
84
- def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
85
- """Expands a column containing a list of probabilities into separate columns.
86
-
87
- Handles None values for rows where predictions couldn't be made.
88
- """
89
- proba_column = "pred_proba"
90
- if proba_column not in df.columns:
91
- raise ValueError('DataFrame does not contain a "pred_proba" column')
92
-
93
- proba_splits = [f"{label}_proba" for label in class_labels]
94
- n_classes = len(class_labels)
95
-
96
- # Handle None values by replacing with list of NaNs
97
- proba_values = []
98
- for val in df[proba_column]:
99
- if val is None:
100
- proba_values.append([np.nan] * n_classes)
101
- else:
102
- proba_values.append(val)
103
-
104
- proba_df = pd.DataFrame(proba_values, columns=proba_splits)
105
-
106
- df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
107
- df = df.reset_index(drop=True)
108
- df = pd.concat([df, proba_df], axis=1)
109
- return df
75
+ """Find SMILES column (case-insensitive match for 'smiles')."""
76
+ smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
77
+ if smiles_col is None:
78
+ raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
79
+ return smiles_col
110
80
 
111
81
 
112
82
  def create_molecule_datapoints(
113
83
  smiles_list: list[str],
114
- targets: list[float] | np.ndarray | None = None,
84
+ targets: np.ndarray | None = None,
115
85
  extra_descriptors: np.ndarray | None = None,
116
86
  ) -> tuple[list[data.MoleculeDatapoint], list[int]]:
117
- """Create ChemProp MoleculeDatapoints from SMILES strings.
118
-
119
- Args:
120
- smiles_list: List of SMILES strings
121
- targets: Optional target values as 2D array (n_samples, n_targets). NaN allowed for missing targets.
122
- extra_descriptors: Optional array of extra features (n_samples, n_features)
123
-
124
- Returns:
125
- Tuple of (list of MoleculeDatapoint objects, list of valid indices)
126
- """
87
+ """Create ChemProp MoleculeDatapoints from SMILES strings."""
127
88
  from rdkit import Chem
128
89
 
129
- datapoints = []
130
- valid_indices = []
131
- invalid_count = 0
132
-
133
- # Convert targets to 2D array if provided
134
- if targets is not None:
135
- targets = np.atleast_2d(np.array(targets))
136
- if targets.shape[0] == 1 and len(smiles_list) > 1:
137
- targets = targets.T # Shape was (1, n_samples), transpose to (n_samples, 1)
90
+ datapoints, valid_indices = [], []
91
+ targets = np.atleast_2d(np.array(targets)).T if targets is not None and np.array(targets).ndim == 1 else targets
138
92
 
139
93
  for i, smi in enumerate(smiles_list):
140
- # Validate SMILES with RDKit first
141
- mol = Chem.MolFromSmiles(smi)
142
- if mol is None:
143
- invalid_count += 1
94
+ if Chem.MolFromSmiles(smi) is None:
144
95
  continue
145
-
146
- # Build datapoint with optional target(s) and extra descriptors
147
- # For multi-task, y is a list of values (can include NaN for missing targets)
148
96
  y = targets[i].tolist() if targets is not None else None
149
97
  x_d = extra_descriptors[i] if extra_descriptors is not None else None
150
-
151
- dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
152
- datapoints.append(dp)
98
+ datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
153
99
  valid_indices.append(i)
154
100
 
155
- if invalid_count > 0:
156
- print(f"Warning: Skipped {invalid_count} invalid SMILES strings")
157
-
158
101
  return datapoints, valid_indices
159
102
 
160
103
 
@@ -168,525 +111,306 @@ def build_mpnn_model(
168
111
  output_transform: nn.UnscaleTransform | None = None,
169
112
  task_weights: np.ndarray | None = None,
170
113
  ) -> models.MPNN:
171
- """Build an MPNN model with the specified hyperparameters.
172
-
173
- Args:
174
- hyperparameters: Dictionary of model hyperparameters
175
- task: Either "regression" or "classification"
176
- num_classes: Number of classes for classification tasks
177
- n_targets: Number of target columns (for multi-task regression)
178
- n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
179
- x_d_transform: Optional transform for extra descriptors (scaling)
180
- output_transform: Optional transform for regression output (unscaling targets)
181
- task_weights: Optional array of weights for each task (multi-task learning)
182
-
183
- Returns:
184
- Configured MPNN model
185
- """
186
- # Model hyperparameters with defaults
187
- hidden_dim = hyperparameters.get("hidden_dim", 700)
188
- depth = hyperparameters.get("depth", 6)
189
- dropout = hyperparameters.get("dropout", 0.15)
190
- ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 2000)
191
- ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
192
-
193
- # Message passing component
194
- mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
114
+ """Build an MPNN model with specified hyperparameters."""
115
+ hidden_dim = hyperparameters["hidden_dim"]
116
+ depth = hyperparameters["depth"]
117
+ dropout = hyperparameters["dropout"]
118
+ ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
119
+ ffn_num_layers = hyperparameters["ffn_num_layers"]
195
120
 
196
- # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
121
+ mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
197
122
  agg = nn.NormAggregation()
198
-
199
- # FFN input_dim = message passing output + extra descriptors
200
123
  ffn_input_dim = hidden_dim + n_extra_descriptors
201
124
 
202
- # Build FFN based on task type
203
125
  if task == "classification" and num_classes is not None:
204
- # Multi-class classification
205
126
  ffn = nn.MulticlassClassificationFFN(
206
- n_classes=num_classes,
207
- input_dim=ffn_input_dim,
208
- hidden_dim=ffn_hidden_dim,
209
- n_layers=ffn_num_layers,
210
- dropout=dropout,
127
+ n_classes=num_classes, input_dim=ffn_input_dim,
128
+ hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
211
129
  )
212
130
  else:
213
- # Regression with optional output transform to unscale predictions
214
- # n_tasks controls the number of output heads for multi-task learning
215
- # task_weights goes here (in RegressionFFN) to weight loss per task
216
- weights_tensor = None
217
- if task_weights is not None:
218
- weights_tensor = torch.tensor(task_weights, dtype=torch.float32)
219
-
131
+ weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
220
132
  ffn = nn.RegressionFFN(
221
- input_dim=ffn_input_dim,
222
- hidden_dim=ffn_hidden_dim,
223
- n_layers=ffn_num_layers,
224
- dropout=dropout,
225
- n_tasks=n_targets,
226
- output_transform=output_transform,
227
- task_weights=weights_tensor,
133
+ input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
134
+ dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
228
135
  )
229
136
 
230
- # Create the MPNN model
231
- mpnn = models.MPNN(
232
- message_passing=mp,
233
- agg=agg,
234
- predictor=ffn,
235
- batch_norm=True,
236
- metrics=None,
237
- X_d_transform=x_d_transform,
238
- )
239
-
240
- return mpnn
137
+ return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
241
138
 
242
139
 
140
+ # =============================================================================
141
+ # Model Loading (for SageMaker inference)
142
+ # =============================================================================
243
143
  def model_fn(model_dir: str) -> dict:
244
- """Load the ChemProp MPNN ensemble models from the specified directory.
245
-
246
- Args:
247
- model_dir: Directory containing the saved models
248
-
249
- Returns:
250
- Dictionary with ensemble models and metadata
251
- """
252
- # Load ensemble metadata (required)
253
- ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
254
- ensemble_metadata = joblib.load(ensemble_metadata_path)
255
- n_ensemble = ensemble_metadata["n_ensemble"]
256
- target_columns = ensemble_metadata["target_columns"]
257
-
258
- # Load all ensemble models
144
+ """Load ChemProp MPNN ensemble from the specified directory."""
145
+ metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
259
146
  ensemble_models = []
260
- for ens_idx in range(n_ensemble):
261
- model_path = os.path.join(model_dir, f"chemprop_model_{ens_idx}.pt")
262
- model = models.MPNN.load_from_file(model_path)
147
+ for i in range(metadata["n_ensemble"]):
148
+ model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
263
149
  model.eval()
264
150
  ensemble_models.append(model)
265
151
 
266
- print(f"Loaded {len(ensemble_models)} ensemble model(s), n_targets={len(target_columns)}")
267
-
268
- return {
269
- "ensemble_models": ensemble_models,
270
- "n_ensemble": n_ensemble,
271
- "target_columns": target_columns,
272
- }
273
-
274
-
275
- def input_fn(input_data, content_type: str) -> pd.DataFrame:
276
- """Parse input data and return a DataFrame."""
277
- if not input_data:
278
- raise ValueError("Empty input data is not supported!")
279
-
280
- if isinstance(input_data, bytes):
281
- input_data = input_data.decode("utf-8")
282
-
283
- if "text/csv" in content_type:
284
- return pd.read_csv(StringIO(input_data))
285
- elif "application/json" in content_type:
286
- return pd.DataFrame(json.loads(input_data))
287
- else:
288
- raise ValueError(f"{content_type} not supported!")
289
-
290
-
291
- def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
292
- """Supports both CSV and JSON output formats."""
293
- if "text/csv" in accept_type:
294
- csv_output = output_df.fillna("N/A").to_csv(index=False)
295
- return csv_output, "text/csv"
296
- elif "application/json" in accept_type:
297
- return output_df.to_json(orient="records"), "application/json"
298
- else:
299
- raise RuntimeError(
300
- f"{accept_type} accept type is not supported by this script."
301
- )
152
+ print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
153
+ return {"ensemble_models": ensemble_models, "n_ensemble": metadata["n_ensemble"], "target_columns": metadata["target_columns"]}
302
154
 
303
155
 
156
+ # =============================================================================
157
+ # Inference (for SageMaker inference)
158
+ # =============================================================================
304
159
  def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
305
- """Make predictions with the ChemProp MPNN ensemble.
306
-
307
- Args:
308
- df: Input DataFrame containing SMILES column (and extra features if hybrid mode)
309
- model_dict: Dictionary containing ensemble models and metadata
310
-
311
- Returns:
312
- DataFrame with predictions added (and prediction_std for ensembles)
313
- """
160
+ """Make predictions with ChemProp MPNN ensemble."""
314
161
  model_type = TEMPLATE_PARAMS["model_type"]
315
162
  model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
316
163
 
317
- # Extract ensemble models and metadata
318
164
  ensemble_models = model_dict["ensemble_models"]
319
- n_ensemble = model_dict["n_ensemble"]
320
165
  target_columns = model_dict["target_columns"]
321
166
 
322
- # Load label encoder if present (classification)
167
+ # Load artifacts
323
168
  label_encoder = None
324
- label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
325
- if os.path.exists(label_encoder_path):
326
- label_encoder = joblib.load(label_encoder_path)
169
+ encoder_path = os.path.join(model_dir, "label_encoder.joblib")
170
+ if os.path.exists(encoder_path):
171
+ label_encoder = joblib.load(encoder_path)
327
172
 
328
- # Load feature metadata if present (hybrid mode)
329
- # Contains column names, NaN fill values, and scaler for feature scaling
330
173
  feature_metadata = None
331
- feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
332
- if os.path.exists(feature_metadata_path):
333
- feature_metadata = joblib.load(feature_metadata_path)
334
- print(
335
- f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
336
- )
174
+ feature_path = os.path.join(model_dir, "feature_metadata.joblib")
175
+ if os.path.exists(feature_path):
176
+ feature_metadata = joblib.load(feature_path)
177
+ print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")
337
178
 
338
- # Find SMILES column in input DataFrame
179
+ # Find SMILES column and validate
339
180
  smiles_column = find_smiles_column(df.columns.tolist())
340
-
341
181
  smiles_list = df[smiles_column].tolist()
342
182
 
343
- # Track invalid SMILES
344
- valid_mask = []
345
- valid_smiles = []
346
- valid_indices = []
347
- for i, smi in enumerate(smiles_list):
348
- if smi and isinstance(smi, str) and len(smi.strip()) > 0:
349
- valid_mask.append(True)
350
- valid_smiles.append(smi.strip())
351
- valid_indices.append(i)
352
- else:
353
- valid_mask.append(False)
354
-
355
- valid_mask = np.array(valid_mask)
183
+ valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
184
+ valid_smiles = [s.strip() for i, s in enumerate(smiles_list) if valid_mask[i]]
356
185
  print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")
357
186
 
358
- # Initialize prediction columns (use object dtype for classifiers to avoid FutureWarning)
187
+ # Initialize output columns
359
188
  if model_type == "classifier":
360
189
  df["prediction"] = pd.Series([None] * len(df), dtype=object)
361
190
  else:
362
- # Regression: create prediction column for each target
363
191
  for tc in target_columns:
364
192
  df[f"{tc}_pred"] = np.nan
365
193
  df[f"{tc}_pred_std"] = np.nan
366
194
 
367
195
  if sum(valid_mask) == 0:
368
- print("Warning: No valid SMILES to predict on")
369
196
  return df
370
197
 
371
- # Prepare extra features if in hybrid mode
372
- # NOTE: We pass RAW (unscaled) features here - the model's X_d_transform handles scaling
198
+ # Prepare extra features (raw, unscaled - model handles scaling)
373
199
  extra_features = None
374
200
  if feature_metadata is not None:
375
- extra_feature_cols = feature_metadata["extra_feature_cols"]
201
+ extra_cols = feature_metadata["extra_feature_cols"]
376
202
  col_means = np.array(feature_metadata["col_means"])
203
+ valid_indices = np.where(valid_mask)[0]
377
204
 
378
- # Check columns exist
379
- missing_cols = [col for col in extra_feature_cols if col not in df.columns]
380
- if missing_cols:
381
- print(
382
- f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
383
- )
384
-
385
- # Extract features for valid SMILES rows (raw, unscaled)
386
- extra_features = np.zeros(
387
- (len(valid_indices), len(extra_feature_cols)), dtype=np.float32
388
- )
389
- for j, col in enumerate(extra_feature_cols):
205
+ extra_features = np.zeros((len(valid_indices), len(extra_cols)), dtype=np.float32)
206
+ for j, col in enumerate(extra_cols):
390
207
  if col in df.columns:
391
208
  values = df.iloc[valid_indices][col].values.astype(np.float32)
392
- # Fill NaN with training column means (unscaled means)
393
- nan_mask = np.isnan(values)
394
- values[nan_mask] = col_means[j]
209
+ values[np.isnan(values)] = col_means[j]
395
210
  extra_features[:, j] = values
396
211
  else:
397
- # Column missing, use training mean
398
212
  extra_features[:, j] = col_means[j]
399
213
 
400
- # Create datapoints for prediction (filter out invalid SMILES)
401
- datapoints, rdkit_valid_indices = create_molecule_datapoints(
402
- valid_smiles, extra_descriptors=extra_features
403
- )
404
-
214
+ # Create datapoints and predict
215
+ datapoints, rdkit_valid = create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
405
216
  if len(datapoints) == 0:
406
- print("Warning: No valid SMILES after RDKit validation")
407
217
  return df
408
218
 
409
219
  dataset = data.MoleculeDataset(datapoints)
410
220
  dataloader = data.build_dataloader(dataset, shuffle=False)
221
+ trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
411
222
 
412
- # Make predictions with ensemble
413
- trainer = pl.Trainer(
414
- accelerator="auto",
415
- logger=False,
416
- enable_progress_bar=False,
417
- )
418
-
419
- # Collect predictions from all ensemble members
420
- all_ensemble_preds = []
421
- for ens_idx, ens_model in enumerate(ensemble_models):
223
+ # Ensemble predictions
224
+ all_preds = []
225
+ for model in ensemble_models:
422
226
  with torch.inference_mode():
423
- predictions = trainer.predict(ens_model, dataloader)
424
- ens_preds = np.concatenate([p.numpy() for p in predictions], axis=0)
425
- # Squeeze middle dim if present
426
- if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
427
- ens_preds = ens_preds.squeeze(axis=1)
428
- all_ensemble_preds.append(ens_preds)
429
-
430
- # Stack and compute mean/std (std is 0 for single model)
431
- ensemble_preds = np.stack(all_ensemble_preds, axis=0)
432
- preds = np.mean(ensemble_preds, axis=0)
433
- preds_std = np.std(ensemble_preds, axis=0) # Will be 0s for n_ensemble=1
434
-
435
- # Ensure 2D: (n_samples, n_targets)
227
+ predictions = trainer.predict(model, dataloader)
228
+ preds = np.concatenate([p.numpy() for p in predictions], axis=0)
229
+ if preds.ndim == 3 and preds.shape[1] == 1:
230
+ preds = preds.squeeze(axis=1)
231
+ all_preds.append(preds)
232
+
233
+ preds = np.mean(np.stack(all_preds), axis=0)
234
+ preds_std = np.std(np.stack(all_preds), axis=0)
436
235
  if preds.ndim == 1:
437
- preds = preds.reshape(-1, 1)
438
- preds_std = preds_std.reshape(-1, 1)
236
+ preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
439
237
 
440
- print(f"Inference: Ensemble predictions shape: {preds.shape}")
238
+ print(f"Inference complete: {preds.shape[0]} predictions")
441
239
 
442
- # Map predictions back to valid_mask positions (accounting for RDKit-invalid SMILES)
443
- # rdkit_valid_indices tells us which of the valid_smiles were actually valid
444
- valid_positions = np.where(valid_mask)[0][rdkit_valid_indices]
240
+ # Map predictions back to valid positions
241
+ valid_positions = np.where(valid_mask)[0][rdkit_valid]
445
242
  valid_mask = np.zeros(len(df), dtype=bool)
446
243
  valid_mask[valid_positions] = True
447
244
 
448
245
  if model_type == "classifier" and label_encoder is not None:
449
- # For classification, get class predictions and probabilities
450
- if preds.ndim == 2 and preds.shape[1] > 1:
451
- # Multi-class: preds are probabilities (averaged across ensemble)
246
+ if preds.shape[1] > 1:
452
247
  class_preds = np.argmax(preds, axis=1)
453
- decoded_preds = label_encoder.inverse_transform(class_preds)
454
- df.loc[valid_mask, "prediction"] = decoded_preds
455
-
456
- # Add probability columns
457
- proba_series = pd.Series([None] * len(df), index=df.index, dtype=object)
458
- proba_series.loc[valid_mask] = [p.tolist() for p in preds]
459
- df["pred_proba"] = proba_series
248
+ df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform(class_preds)
249
+ proba = pd.Series([None] * len(df), dtype=object)
250
+ proba.loc[valid_mask] = [p.tolist() for p in preds]
251
+ df["pred_proba"] = proba
460
252
  df = expand_proba_column(df, label_encoder.classes_)
461
253
  else:
462
- # Binary or single output
463
- class_preds = (preds.flatten() > 0.5).astype(int)
464
- decoded_preds = label_encoder.inverse_transform(class_preds)
465
- df.loc[valid_mask, "prediction"] = decoded_preds
254
+ df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform((preds.flatten() > 0.5).astype(int))
466
255
  else:
467
- # Regression: store predictions for each target
468
256
  for t_idx, tc in enumerate(target_columns):
469
257
  df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
470
258
  df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
471
-
472
- # Add prediction/prediction_std aliases for first target
473
- first_target = target_columns[0]
474
- df["prediction"] = df[f"{first_target}_pred"]
475
- df["prediction_std"] = df[f"{first_target}_pred_std"]
259
+ df["prediction"] = df[f"{target_columns[0]}_pred"]
260
+ df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
476
261
 
477
262
  return df
478
263
 
479
264
 
265
+ # =============================================================================
266
+ # Training
267
+ # =============================================================================
480
268
  if __name__ == "__main__":
481
- """Training script for ChemProp MPNN model"""
269
+ # -------------------------------------------------------------------------
270
+ # Setup: Parse arguments and load data
271
+ # -------------------------------------------------------------------------
272
+ parser = argparse.ArgumentParser()
273
+ parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
274
+ parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
275
+ parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
276
+ args = parser.parse_args()
482
277
 
483
- # Template Parameters
484
- target_columns = TEMPLATE_PARAMS["targets"] # List of target columns
278
+ # Extract template parameters
279
+ target_columns = TEMPLATE_PARAMS["targets"]
485
280
  model_type = TEMPLATE_PARAMS["model_type"]
486
281
  feature_list = TEMPLATE_PARAMS["feature_list"]
487
282
  id_column = TEMPLATE_PARAMS["id_column"]
488
283
  model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
489
- hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
284
+ hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}
490
285
 
491
- # Validate target_columns
492
- if not target_columns or not isinstance(target_columns, list) or len(target_columns) == 0:
286
+ if not target_columns or not isinstance(target_columns, list):
493
287
  raise ValueError("'targets' must be a non-empty list of target column names")
494
288
  n_targets = len(target_columns)
495
- print(f"Target columns ({n_targets}): {target_columns}")
496
289
 
497
- # Get the SMILES column name from feature_list (user defines this, so we use their exact name)
498
290
  smiles_column = find_smiles_column(feature_list)
499
291
  extra_feature_cols = [f for f in feature_list if f != smiles_column]
500
292
  use_extra_features = len(extra_feature_cols) > 0
501
- print(f"Feature List: {feature_list}")
502
- print(f"SMILES Column: {smiles_column}")
503
- print(
504
- f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
505
- )
506
293
 
507
- # Script arguments for input/output directories
508
- parser = argparse.ArgumentParser()
509
- parser.add_argument(
510
- "--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
511
- )
512
- parser.add_argument(
513
- "--train",
514
- type=str,
515
- default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
516
- )
517
- parser.add_argument(
518
- "--output-data-dir",
519
- type=str,
520
- default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
521
- )
522
- args = parser.parse_args()
294
+ print(f"Target columns ({n_targets}): {target_columns}")
295
+ print(f"SMILES column: {smiles_column}")
296
+ print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
297
+ print(f"Hyperparameters: {hyperparameters}")
523
298
 
524
- # Read the training data
525
- training_files = [
526
- os.path.join(args.train, f)
527
- for f in os.listdir(args.train)
528
- if f.endswith(".csv")
529
- ]
299
+ # Load training data
300
+ training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
530
301
  print(f"Training Files: {training_files}")
531
-
532
302
  all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
533
- print(f"All Data Shape: {all_df.shape}")
534
-
535
303
  check_dataframe(all_df, "training_df")
536
304
 
537
- # Drop rows with missing SMILES or all target values
305
+ # Clean data
538
306
  initial_count = len(all_df)
539
307
  all_df = all_df.dropna(subset=[smiles_column])
540
- # Keep rows that have at least one non-null target (works for single and multi-task)
541
- has_any_target = all_df[target_columns].notna().any(axis=1)
542
- all_df = all_df[has_any_target]
543
- dropped = initial_count - len(all_df)
544
- if dropped > 0:
545
- print(f"Dropped {dropped} rows with missing SMILES or all target values")
546
-
547
- print(f"Target columns: {target_columns}")
548
- print(f"Data Shape after cleaning: {all_df.shape}")
308
+ all_df = all_df[all_df[target_columns].notna().any(axis=1)]
309
+ if len(all_df) < initial_count:
310
+ print(f"Dropped {initial_count - len(all_df)} rows with missing SMILES/targets")
311
+
312
+ print(f"Data shape: {all_df.shape}")
549
313
  for tc in target_columns:
550
- n_valid = all_df[tc].notna().sum()
551
- print(f" {tc}: {n_valid} samples with values")
314
+ print(f" {tc}: {all_df[tc].notna().sum()} samples")
552
315
 
553
- # Set up label encoder for classification (single-target only)
316
+ # -------------------------------------------------------------------------
317
+ # Classification setup
318
+ # -------------------------------------------------------------------------
554
319
  label_encoder = None
320
+ num_classes = None
555
321
  if model_type == "classifier":
556
322
  if n_targets > 1:
557
- raise ValueError("Multi-task classification is not supported. Use regression for multi-task.")
323
+ raise ValueError("Multi-task classification not supported")
558
324
  label_encoder = LabelEncoder()
559
325
  all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
560
326
  num_classes = len(label_encoder.classes_)
561
- print(
562
- f"Classification task with {num_classes} classes: {label_encoder.classes_}"
563
- )
564
- else:
565
- num_classes = None
327
+ print(f"Classification: {num_classes} classes: {label_encoder.classes_}")
566
328
 
567
- # Model and training configuration
568
- print(f"Hyperparameters: {hyperparameters}")
329
+ # -------------------------------------------------------------------------
330
+ # Prepare features
331
+ # -------------------------------------------------------------------------
569
332
  task = "classification" if model_type == "classifier" else "regression"
570
333
  n_extra = len(extra_feature_cols) if use_extra_features else 0
571
- max_epochs = hyperparameters.get("max_epochs", 400)
572
- patience = hyperparameters.get("patience", 40)
573
- n_folds = hyperparameters.get("n_folds", 5) # Number of CV folds (default: 5)
574
- batch_size = hyperparameters.get("batch_size", 16)
575
334
 
576
- # Check extra feature columns exist
577
- if use_extra_features:
578
- missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
579
- if missing_cols:
580
- raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")
581
-
582
- # =========================================================================
583
- # UNIFIED TRAINING: Works for n_folds=1 (single model) or n_folds>1 (K-fold CV)
584
- # =========================================================================
585
- print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")
586
-
587
- # Prepare extra features and validate SMILES upfront
588
- all_extra_features = None
589
- col_means = None
335
+ all_extra_features, col_means = None, None
590
336
  if use_extra_features:
591
337
  all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
592
338
  col_means = np.nanmean(all_extra_features, axis=0)
593
339
  for i in range(all_extra_features.shape[1]):
594
340
  all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
595
341
 
596
- # Prepare target array: always 2D (n_samples, n_targets)
597
342
  all_targets = all_df[target_columns].values.astype(np.float32)
598
343
 
599
- # Filter invalid SMILES from the full dataset
600
- _, valid_indices = create_molecule_datapoints(
601
- all_df[smiles_column].tolist(), all_targets, all_extra_features
602
- )
344
+ # Filter invalid SMILES
345
+ _, valid_indices = create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
603
346
  all_df = all_df.iloc[valid_indices].reset_index(drop=True)
604
347
  all_targets = all_targets[valid_indices]
605
348
  if all_extra_features is not None:
606
349
  all_extra_features = all_extra_features[valid_indices]
607
350
  print(f"Data after SMILES validation: {all_df.shape}")
608
351
 
609
- # Compute dynamic task weights for multi-task regression
610
- # Weight = inverse of sample count (normalized so min weight = 1.0)
611
- # This gives higher weight to targets with fewer samples
352
+ # Task weights for multi-task (inverse sample count)
612
353
  task_weights = None
613
354
  if n_targets > 1 and model_type != "classifier":
614
- sample_counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
615
- # Inverse weighting: fewer samples = higher weight
616
- inverse_counts = 1.0 / sample_counts
617
- # Normalize so minimum weight is 1.0
618
- task_weights = inverse_counts / inverse_counts.min()
619
- print(f"Task weights (inverse sample count):")
620
- for t_idx, t_name in enumerate(target_columns):
621
- print(f" {t_name}: {task_weights[t_idx]:.3f} (n={sample_counts[t_idx]})")
355
+ counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
356
+ task_weights = (1.0 / counts) / (1.0 / counts).min()
357
+ print(f"Task weights: {dict(zip(target_columns, task_weights.round(3)))}")
358
+
359
+ # -------------------------------------------------------------------------
360
+ # Cross-validation setup
361
+ # -------------------------------------------------------------------------
362
+ n_folds = hyperparameters["n_folds"]
363
+ batch_size = hyperparameters["batch_size"]
622
364
 
623
- # Create fold splits
624
365
  if n_folds == 1:
625
- # Single fold: use train/val split from "training" column or random split
626
366
  if "training" in all_df.columns:
627
- print("Found training column, splitting data based on training column")
367
+ print("Using 'training' column for train/val split")
628
368
  train_idx = np.where(all_df["training"])[0]
629
369
  val_idx = np.where(~all_df["training"])[0]
630
370
  else:
631
- print("WARNING: No training column found, splitting data with random 80/20 split")
632
- indices = np.arange(len(all_df))
633
- train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
371
+ print("WARNING: No 'training' column, using random 80/20 split")
372
+ train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
634
373
  folds = [(train_idx, val_idx)]
635
374
  else:
636
- # K-Fold CV
637
375
  if model_type == "classifier":
638
376
  kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
639
- split_target = all_df[target_columns[0]]
377
+ folds = list(kfold.split(all_df, all_df[target_columns[0]]))
640
378
  else:
641
379
  kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
642
- split_target = None
643
- folds = list(kfold.split(all_df, split_target))
380
+ folds = list(kfold.split(all_df))
644
381
 
645
- # Initialize storage for out-of-fold predictions: always 2D (n_samples, n_targets)
382
+ print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
383
+
384
+ # -------------------------------------------------------------------------
385
+ # Training loop
386
+ # -------------------------------------------------------------------------
646
387
  oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
647
- if model_type == "classifier" and num_classes and num_classes > 1:
648
- oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
649
- else:
650
- oof_proba = None
388
+ oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64) if model_type == "classifier" and num_classes else None
651
389
 
652
390
  ensemble_models = []
653
-
654
391
  for fold_idx, (train_idx, val_idx) in enumerate(folds):
655
392
  print(f"\n{'='*50}")
656
- print(f"Training Fold {fold_idx + 1}/{len(folds)}")
393
+ print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
657
394
  print(f"{'='*50}")
658
395
 
659
- # Split data for this fold
660
- df_train = all_df.iloc[train_idx].reset_index(drop=True)
661
- df_val = all_df.iloc[val_idx].reset_index(drop=True)
662
- train_targets = all_targets[train_idx]
663
- val_targets = all_targets[val_idx]
664
-
396
+ # Split data
397
+ df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
398
+ train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
665
399
  train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
666
400
  val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
667
-
668
- print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
669
-
670
- # Create ChemProp datasets for this fold
671
- train_datapoints, _ = create_molecule_datapoints(
672
- df_train[smiles_column].tolist(), train_targets, train_extra
673
- )
674
- val_datapoints, _ = create_molecule_datapoints(
675
- df_val[smiles_column].tolist(), val_targets, val_extra
676
- )
677
-
678
- train_dataset = data.MoleculeDataset(train_datapoints)
679
- val_dataset = data.MoleculeDataset(val_datapoints)
680
-
681
- # Save raw val features for prediction
682
401
  val_extra_raw = val_extra.copy() if val_extra is not None else None
683
402
 
684
- # Scale features and targets for this fold
403
+ # Create datasets
404
+ train_dps, _ = create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
405
+ val_dps, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
406
+ train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)
407
+
408
+ # Scale features/targets
685
409
  x_d_transform = None
686
410
  if use_extra_features:
687
- feature_scaler = train_dataset.normalize_inputs("X_d")
688
- val_dataset.normalize_inputs("X_d", feature_scaler)
689
- x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
411
+ scaler = train_dataset.normalize_inputs("X_d")
412
+ val_dataset.normalize_inputs("X_d", scaler)
413
+ x_d_transform = nn.ScaleTransform.from_standard_scaler(scaler)
690
414
 
691
415
  output_transform = None
692
416
  if model_type in ["regressor", "uq_regressor"]:
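In the hunk above, each fold fits a scaler on the training split's extra descriptors via `normalize_inputs("X_d")`, applies it to the validation split, and wraps it into the model as `X_d_transform`, so callers pass raw descriptors at inference. A scikit-learn-only sketch of that idea, with invented descriptor values (this is not the ChemProp API itself):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Invented extra descriptors: rows = molecules, columns = e.g. TPSA, MolLogP.
X_train = np.array([[120.1, 0.8], [98.4, 1.2], [143.0, 0.5]])
X_infer = np.array([[110.0, 1.0]])        # raw values, as passed at inference

scaler = StandardScaler().fit(X_train)    # fitted on the training fold only

X_train_scaled = scaler.transform(X_train)   # what the FFN sees during training
X_infer_scaled = scaler.transform(X_infer)   # same shift/scale applied automatically
print(X_infer_scaled)
```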
@@ -697,29 +421,24 @@ if __name__ == "__main__":
697
421
  train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
698
422
  val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
699
423
 
700
- # Build and train model for this fold
701
- pl.seed_everything(42 + fold_idx)
424
+ # Build and train model
425
+ pl.seed_everything(hyperparameters["seed"] + fold_idx)
702
426
  mpnn = build_mpnn_model(
703
427
  hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
704
- n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
705
- task_weights=task_weights,
428
+ n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
429
+ output_transform=output_transform, task_weights=task_weights,
706
430
  )
707
431
 
708
- callbacks = [
709
- pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
710
- pl.callbacks.ModelCheckpoint(
711
- dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
712
- monitor="val_loss", mode="min", save_top_k=1,
713
- ),
714
- ]
715
-
716
432
  trainer = pl.Trainer(
717
- accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
718
- logger=False, enable_progress_bar=True,
433
+ accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
434
+ callbacks=[
435
+ pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
436
+ pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
437
+ ],
719
438
  )
720
-
721
439
  trainer.fit(mpnn, train_loader, val_loader)
722
440
 
441
+ # Load best checkpoint
723
442
  if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
724
443
  checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
725
444
  mpnn.load_state_dict(checkpoint["state_dict"])
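The fold loop above uses Lightning's `EarlyStopping` and `ModelCheckpoint` callbacks and then reloads the lowest-`val_loss` checkpoint before evaluation. A plain-PyTorch sketch of the same save-best/reload pattern, with a placeholder model and placeholder loss values instead of the template's MPNN:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)                     # stand-in for the fold's MPNN
best_path, best_val = "best_fold0.pt", float("inf")

for epoch in range(3):
    # ... optimizer step omitted ...
    val_loss = [0.9, 0.6, 0.7][epoch]       # placeholder validation losses
    if val_loss < best_val:                 # checkpoint only on improvement
        best_val = val_loss
        torch.save({"state_dict": model.state_dict()}, best_path)

checkpoint = torch.load(best_path, weights_only=False)
model.load_state_dict(checkpoint["state_dict"])   # restore the best epoch's weights
model.eval()
```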
@@ -727,224 +446,128 @@ if __name__ == "__main__":
727
446
  mpnn.eval()
728
447
  ensemble_models.append(mpnn)
729
448
 
730
- # Make out-of-fold predictions using raw features
731
- val_datapoints_raw, _ = create_molecule_datapoints(
732
- df_val[smiles_column].tolist(), val_targets, val_extra_raw
733
- )
734
- val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
735
- val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
449
+ # Out-of-fold predictions (using raw features)
450
+ val_dps_raw, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
451
+ val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
736
452
 
737
453
  with torch.inference_mode():
738
- fold_predictions = trainer.predict(mpnn, val_loader_pred)
739
- fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
454
+ fold_preds = np.concatenate([p.numpy() for p in trainer.predict(mpnn, val_loader_pred)], axis=0)
740
455
  if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
741
456
  fold_preds = fold_preds.squeeze(axis=1)
742
457
 
743
- # Store out-of-fold predictions
744
458
  if model_type == "classifier" and fold_preds.ndim == 2:
745
- # Store class index in first column for classification
746
459
  oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
747
460
  if oof_proba is not None:
748
461
  oof_proba[val_idx] = fold_preds
749
462
  else:
750
- # Regression: fold_preds shape is (n_val, n_targets) or (n_val,)
751
463
  if fold_preds.ndim == 1:
752
464
  fold_preds = fold_preds.reshape(-1, 1)
753
465
  oof_predictions[val_idx] = fold_preds
754
466
 
755
- print(f"Fold {fold_idx + 1} complete!")
756
-
757
467
  print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
758
468
 
759
- # Use out-of-fold predictions for metrics
760
- # For n_folds=1, we only have predictions for val_idx, so filter to those rows
469
+ # -------------------------------------------------------------------------
470
+ # Prepare validation results
471
+ # -------------------------------------------------------------------------
761
472
  if n_folds == 1:
762
- # oof_predictions is always 2D now: check if any column has a value
763
473
  val_mask = ~np.isnan(oof_predictions).all(axis=1)
764
- preds = oof_predictions[val_mask]
765
474
  df_val = all_df[val_mask].copy()
475
+ preds = oof_predictions[val_mask]
766
476
  y_validate = all_targets[val_mask]
767
477
  if oof_proba is not None:
768
478
  oof_proba = oof_proba[val_mask]
769
479
  val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
770
480
  else:
771
- preds = oof_predictions
772
481
  df_val = all_df.copy()
482
+ preds = oof_predictions
773
483
  y_validate = all_targets
774
484
  val_extra_features = all_extra_features
775
485
 
776
- # Compute prediction_std by running all ensemble models on validation data
777
- # For n_folds=1, std will be 0 (only one model). For n_folds>1, std shows ensemble disagreement.
778
- preds_std = None
779
- if model_type in ["regressor", "uq_regressor"] and len(ensemble_models) > 0:
780
- print("Computing prediction_std from ensemble predictions on validation data...")
781
- val_datapoints_for_std, _ = create_molecule_datapoints(
782
- df_val[smiles_column].tolist(),
783
- y_validate,
784
- val_extra_features
785
- )
786
- val_dataset_for_std = data.MoleculeDataset(val_datapoints_for_std)
787
- val_loader_for_std = data.build_dataloader(val_dataset_for_std, batch_size=batch_size, shuffle=False)
788
-
789
- all_ensemble_preds_for_std = []
790
- trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
791
- for ens_model in ensemble_models:
792
- with torch.inference_mode():
793
- ens_preds = trainer_pred.predict(ens_model, val_loader_for_std)
794
- ens_preds = np.concatenate([p.numpy() for p in ens_preds], axis=0)
795
- if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
796
- ens_preds = ens_preds.squeeze(axis=1)
797
- all_ensemble_preds_for_std.append(ens_preds)
798
-
799
- # Stack ensemble predictions: shape (n_ensemble, n_samples, n_targets)
800
- ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
801
- preds_std = np.std(ensemble_preds_stacked, axis=0)
802
- # Ensure 2D
803
- if preds_std.ndim == 1:
804
- preds_std = preds_std.reshape(-1, 1)
805
- print(f"Ensemble prediction_std - mean per target: {np.nanmean(preds_std, axis=0)}")
806
-
486
+ # -------------------------------------------------------------------------
487
+ # Compute metrics and prepare output
488
+ # -------------------------------------------------------------------------
807
489
  if model_type == "classifier":
808
- # Classification metrics - preds contains class indices in first column from OOF predictions
809
490
  class_preds = preds[:, 0].astype(int)
810
- has_proba = oof_proba is not None
811
-
812
- print(f"class_preds shape: {class_preds.shape}")
813
-
814
- # Decode labels for metrics (classification is single-target only)
815
491
  target_name = target_columns[0]
816
- y_validate_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
492
+ y_true_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
817
493
  preds_decoded = label_encoder.inverse_transform(class_preds)
818
494
 
819
- # Calculate metrics
820
- label_names = label_encoder.classes_
821
- scores = precision_recall_fscore_support(
822
- y_validate_decoded, preds_decoded, average=None, labels=label_names
823
- )
824
-
825
- score_df = pd.DataFrame(
826
- {
827
- target_name: label_names,
828
- "precision": scores[0],
829
- "recall": scores[1],
830
- "f1": scores[2],
831
- "support": scores[3],
832
- }
833
- )
834
-
835
- # Output metrics per class
836
- metrics = ["precision", "recall", "f1", "support"]
837
- for t in label_names:
838
- for m in metrics:
839
- value = score_df.loc[score_df[target_name] == t, m].iloc[0]
840
- print(f"Metrics:{t}:{m} {value}")
495
+ score_df = compute_classification_metrics(y_true_decoded, preds_decoded, label_encoder.classes_, target_name)
496
+ print_classification_metrics(score_df, target_name, label_encoder.classes_)
497
+ print_confusion_matrix(y_true_decoded, preds_decoded, label_encoder.classes_)
841
498
 
842
- # Confusion matrix
843
- conf_mtx = confusion_matrix(
844
- y_validate_decoded, preds_decoded, labels=label_names
845
- )
846
- for i, row_name in enumerate(label_names):
847
- for j, col_name in enumerate(label_names):
848
- value = conf_mtx[i, j]
849
- print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
850
-
851
- # Save validation predictions
852
- df_val = df_val.copy()
499
+ # Decode target column back to string labels (was encoded for training)
500
+ df_val[target_name] = y_true_decoded
853
501
  df_val["prediction"] = preds_decoded
854
- if has_proba and oof_proba is not None:
502
+ if oof_proba is not None:
855
503
  df_val["pred_proba"] = [p.tolist() for p in oof_proba]
856
- df_val = expand_proba_column(df_val, label_names)
857
-
504
+ df_val = expand_proba_column(df_val, label_encoder.classes_)
858
505
  else:
859
- # Regression metrics: compute per target (works for single or multi-task)
860
- df_val = df_val.copy()
506
+ # Compute ensemble std
507
+ preds_std = None
508
+ if len(ensemble_models) > 1:
509
+ print("Computing prediction_std from ensemble...")
510
+ val_dps, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
511
+ val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
512
+ trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
513
+
514
+ all_ens_preds = []
515
+ for m in ensemble_models:
516
+ with torch.inference_mode():
517
+ ens_preds = np.concatenate([p.numpy() for p in trainer_pred.predict(m, val_loader)], axis=0)
518
+ if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
519
+ ens_preds = ens_preds.squeeze(axis=1)
520
+ all_ens_preds.append(ens_preds)
521
+ preds_std = np.std(np.stack(all_ens_preds), axis=0)
522
+ if preds_std.ndim == 1:
523
+ preds_std = preds_std.reshape(-1, 1)
524
+
861
525
  print("\n--- Per-target metrics ---")
862
526
  for t_idx, t_name in enumerate(target_columns):
863
- # Get valid (non-NaN) indices for this target
864
- target_valid_mask = ~np.isnan(y_validate[:, t_idx])
865
- y_true = y_validate[target_valid_mask, t_idx]
866
- y_pred = preds[target_valid_mask, t_idx]
867
-
868
- if len(y_true) > 0:
869
- rmse = root_mean_squared_error(y_true, y_pred)
870
- mae = mean_absolute_error(y_true, y_pred)
871
- medae = median_absolute_error(y_true, y_pred)
872
- r2 = r2_score(y_true, y_pred)
873
- spearman_corr = spearmanr(y_true, y_pred).correlation
874
- support = len(y_true)
875
- # Print metrics in format expected by SageMaker metric definitions
876
- print(f"rmse: {rmse:.3f}")
877
- print(f"mae: {mae:.3f}")
878
- print(f"medae: {medae:.3f}")
879
- print(f"r2: {r2:.3f}")
880
- print(f"spearmanr: {spearman_corr:.3f}")
881
- print(f"support: {support}")
882
-
883
- # Store predictions in dataframe
527
+ valid_mask = ~np.isnan(y_validate[:, t_idx])
528
+ if valid_mask.sum() > 0:
529
+ metrics = compute_regression_metrics(y_validate[valid_mask, t_idx], preds[valid_mask, t_idx])
530
+ print_regression_metrics(metrics)
531
+
884
532
  df_val[f"{t_name}_pred"] = preds[:, t_idx]
885
- if preds_std is not None:
886
- df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx]
887
- else:
888
- df_val[f"{t_name}_pred_std"] = 0.0
533
+ df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx] if preds_std is not None else 0.0
889
534
 
890
- # Add prediction/prediction_std aliases for first target
891
- first_target = target_columns[0]
892
- df_val["prediction"] = df_val[f"{first_target}_pred"]
893
- df_val["prediction_std"] = df_val[f"{first_target}_pred_std"]
535
+ df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
536
+ df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
894
537
 
538
+ # -------------------------------------------------------------------------
895
539
  # Save validation predictions to S3
896
- # Include id_column if it exists in df_val
897
- output_columns = []
898
- if id_column in df_val.columns:
899
- output_columns.append(id_column)
900
- # Include all target columns and their predictions
540
+ # -------------------------------------------------------------------------
541
+ output_columns = [id_column] if id_column in df_val.columns else []
901
542
  output_columns += target_columns
902
- output_columns += [f"{t}_pred" for t in target_columns]
903
- output_columns += [f"{t}_pred_std" for t in target_columns]
543
+ output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
904
544
  output_columns += ["prediction", "prediction_std"]
905
- # Add proba columns for classifiers
906
- output_columns += [col for col in df_val.columns if col.endswith("_proba")]
907
- # Filter to only columns that exist
545
+ output_columns += [c for c in df_val.columns if c.endswith("_proba")]
908
546
  output_columns = [c for c in output_columns if c in df_val.columns]
909
- wr.s3.to_csv(
910
- df_val[output_columns],
911
- path=f"{model_metrics_s3_path}/validation_predictions.csv",
912
- index=False,
913
- )
914
-
915
- # Save ensemble models (n_folds models if CV, 1 model otherwise)
916
- for model_idx, ens_model in enumerate(ensemble_models):
917
- model_path = os.path.join(args.model_dir, f"chemprop_model_{model_idx}.pt")
918
- models.save_model(model_path, ens_model)
919
- print(f"Saved model {model_idx + 1} to {model_path}")
920
-
921
- # Clean up checkpoint files (not needed for inference, reduces artifact size)
922
- for ckpt_file in glob.glob(os.path.join(args.model_dir, "best_model_*.ckpt")):
923
- os.remove(ckpt_file)
924
- print(f"Removed checkpoint: {ckpt_file}")
925
-
926
- # Save ensemble metadata (n_ensemble = number of models for inference)
927
- n_ensemble = len(ensemble_models)
928
- ensemble_metadata = {
929
- "n_ensemble": n_ensemble,
930
- "n_folds": n_folds,
931
- "target_columns": target_columns,
932
- }
933
- joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
934
- print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds}, targets={target_columns})")
935
-
936
- # Save label encoder if classification
937
- if label_encoder is not None:
547
+
548
+ wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
549
+
550
+ # -------------------------------------------------------------------------
551
+ # Save model artifacts
552
+ # -------------------------------------------------------------------------
553
+ for idx, m in enumerate(ensemble_models):
554
+ models.save_model(os.path.join(args.model_dir, f"chemprop_model_{idx}.pt"), m)
555
+ print(f"Saved {len(ensemble_models)} model(s)")
556
+
557
+ # Clean up checkpoints
558
+ for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
559
+ os.remove(ckpt)
560
+
561
+ joblib.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds, "target_columns": target_columns}, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
562
+
563
+ with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
564
+ json.dump(hyperparameters, f, indent=2)
565
+
566
+ if label_encoder:
938
567
  joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
939
568
 
940
- # Save extra feature metadata for inference (hybrid mode)
941
- # Note: We don't need to save the scaler - X_d_transform is embedded in the model
942
569
  if use_extra_features:
943
- feature_metadata = {
944
- "extra_feature_cols": extra_feature_cols,
945
- "col_means": col_means.tolist(), # Unscaled means for NaN imputation
946
- }
947
- joblib.dump(
948
- feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
949
- )
570
+ joblib.dump({"extra_feature_cols": extra_feature_cols, "col_means": col_means.tolist()}, os.path.join(args.model_dir, "feature_metadata.joblib"))
950
571
  print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
572
+
573
+ print(f"\nModel training complete! Artifacts saved to {args.model_dir}")