workbench 0.8.202__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.
Files changed (84)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  3. workbench/algorithms/dataframe/fingerprint_proximity.py +421 -85
  4. workbench/algorithms/dataframe/projection_2d.py +44 -21
  5. workbench/algorithms/dataframe/proximity.py +78 -150
  6. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  7. workbench/algorithms/models/cleanlab_model.py +382 -0
  8. workbench/algorithms/models/noise_model.py +388 -0
  9. workbench/algorithms/sql/outliers.py +3 -3
  10. workbench/api/__init__.py +3 -0
  11. workbench/api/df_store.py +17 -108
  12. workbench/api/endpoint.py +13 -11
  13. workbench/api/feature_set.py +111 -8
  14. workbench/api/meta_model.py +289 -0
  15. workbench/api/model.py +45 -12
  16. workbench/api/parameter_store.py +3 -52
  17. workbench/cached/cached_model.py +4 -4
  18. workbench/core/artifacts/artifact.py +5 -5
  19. workbench/core/artifacts/df_store_core.py +114 -0
  20. workbench/core/artifacts/endpoint_core.py +228 -237
  21. workbench/core/artifacts/feature_set_core.py +185 -230
  22. workbench/core/artifacts/model_core.py +34 -26
  23. workbench/core/artifacts/parameter_store_core.py +98 -0
  24. workbench/core/pipelines/pipeline_executor.py +1 -1
  25. workbench/core/transforms/features_to_model/features_to_model.py +22 -10
  26. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +41 -10
  27. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  28. workbench/model_script_utils/model_script_utils.py +339 -0
  29. workbench/model_script_utils/pytorch_utils.py +405 -0
  30. workbench/model_script_utils/uq_harness.py +278 -0
  31. workbench/model_scripts/chemprop/chemprop.template +428 -631
  32. workbench/model_scripts/chemprop/generated_model_script.py +432 -635
  33. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  34. workbench/model_scripts/chemprop/requirements.txt +2 -10
  35. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  36. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  37. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  38. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  39. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  40. workbench/model_scripts/meta_model/meta_model.template +209 -0
  41. workbench/model_scripts/pytorch_model/generated_model_script.py +374 -613
  42. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  43. workbench/model_scripts/pytorch_model/pytorch.template +370 -609
  44. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  45. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  46. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  47. workbench/model_scripts/script_generation.py +6 -5
  48. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  49. workbench/model_scripts/xgb_model/generated_model_script.py +372 -395
  50. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  51. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  52. workbench/model_scripts/xgb_model/xgb_model.template +366 -396
  53. workbench/repl/workbench_shell.py +0 -5
  54. workbench/resources/open_source_api.key +1 -1
  55. workbench/scripts/endpoint_test.py +2 -2
  56. workbench/scripts/meta_model_sim.py +35 -0
  57. workbench/scripts/training_test.py +85 -0
  58. workbench/utils/chem_utils/fingerprints.py +87 -46
  59. workbench/utils/chem_utils/projections.py +16 -6
  60. workbench/utils/chemprop_utils.py +36 -655
  61. workbench/utils/meta_model_simulator.py +499 -0
  62. workbench/utils/metrics_utils.py +256 -0
  63. workbench/utils/model_utils.py +192 -54
  64. workbench/utils/pytorch_utils.py +33 -472
  65. workbench/utils/shap_utils.py +1 -55
  66. workbench/utils/xgboost_local_crossfold.py +267 -0
  67. workbench/utils/xgboost_model_utils.py +49 -356
  68. workbench/web_interface/components/model_plot.py +7 -1
  69. workbench/web_interface/components/plugins/model_details.py +30 -68
  70. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  71. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/METADATA +6 -5
  72. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/RECORD +76 -60
  73. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/entry_points.txt +2 -0
  74. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  75. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
  76. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  77. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  78. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  79. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  80. workbench/model_scripts/uq_models/mapie.template +0 -605
  81. workbench/model_scripts/uq_models/requirements.txt +0 -1
  82. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  83. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +0 -0
  84. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -1,630 +1,479 @@
  # ChemProp Model Template for Workbench
- # Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
  #
- # === CHEMPROP REVIEW NOTES ===
- # This script runs on AWS SageMaker. Key areas for ChemProp review:
+ # This template handles molecular property prediction using ChemProp 2.x MPNN with:
+ # - K-fold cross-validation ensemble training (or single train/val split)
+ # - Multi-task regression support
+ # - Hybrid mode (SMILES + extra molecular descriptors)
+ # - Classification (single-target only)
  #
- # 1. Model Architecture (build_mpnn_model function)
- # - BondMessagePassing, NormAggregation, FFN configuration
- # - Regression uses output_transform (UnscaleTransform) for target scaling
- #
- # 2. Data Handling (create_molecule_datapoints function)
- # - MoleculeDatapoint creation with x_d (extra descriptors)
- # - RDKit validation of SMILES
- #
- # 3. Scaling (training section)
- # - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
- # - Targets (regression): normalize_targets() + UnscaleTransform in FFN
- # - At inference: pass RAW features, transforms handle scaling automatically
- #
- # 4. Training Loop (search for "pl.Trainer")
- # - PyTorch Lightning Trainer with ChemProp MPNN
- #
- # AWS/SageMaker boilerplate (can skip):
- # - input_fn, output_fn, model_fn: SageMaker serving interface
- # - argparse, file loading, S3 writes
- # =============================
+ # NOTE: Imports are structured to minimize serverless endpoint startup time.
+ # Heavy imports (lightning, sklearn, awswrangler) are deferred to training time.

- import os
- import argparse
  import json
- from io import StringIO
+ import os

- import awswrangler as wr
+ import joblib
  import numpy as np
  import pandas as pd
  import torch
- from lightning import pytorch as pl
- from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
- from sklearn.preprocessing import LabelEncoder
- from sklearn.metrics import (
- mean_absolute_error,
- median_absolute_error,
- r2_score,
- root_mean_squared_error,
- precision_recall_fscore_support,
- confusion_matrix,
+
+ from chemprop import data, models
+
+ from model_script_utils import (
+ expand_proba_column,
+ input_fn,
+ output_fn,
  )
- from scipy.stats import spearmanr
- import joblib

- # ChemProp imports
- from chemprop import data, models, nn
+ # =============================================================================
+ # Default Hyperparameters
+ # =============================================================================
+ DEFAULT_HYPERPARAMETERS = {
+ # Training
+ "n_folds": 5,
+ "max_epochs": 400,
+ "patience": 50,
+ "batch_size": 32,
+ # Message Passing
+ "hidden_dim": 700,
+ "depth": 6,
+ "dropout": 0.1, # Lower dropout - ensemble provides regularization
+ # FFN
+ "ffn_hidden_dim": 2000,
+ "ffn_num_layers": 2,
+ # Loss function for regression (mae, mse)
+ "criterion": "mae",
+ # Random seed
+ "seed": 42,
+ }

- # Template Parameters
+ # Template parameters (filled in by Workbench)
  TEMPLATE_PARAMS = {
  "model_type": "uq_regressor",
- "target": "udm_asy_res_efflux_ratio",
- "feature_list": ['smiles', 'smr_vsa4', 'tpsa', 'nhohcount', 'peoe_vsa1', 'mollogp', 'numhdonors', 'tertiary_amine_count', 'smr_vsa3', 'nitrogen_span', 'vsa_estate2', 'hba_hbd_ratio', 'minpartialcharge', 'estate_vsa4', 'asphericity', 'charge_centroid_distance', 'peoe_vsa8', 'mi', 'estate_vsa8', 'vsa_estate6', 'vsa_estate3', 'molecular_volume_3d', 'kappa3', 'smr_vsa5', 'sv', 'xp_6dv', 'xc_4dv', 'si', 'molecular_axis_length', 'axp_5d', 'estate_vsa3', 'estate_vsa10', 'axp_7dv', 'slogp_vsa1', 'molecular_asymmetry', 'molmr', 'qed', 'xp_3d', 'axp_0dv', 'fpdensitymorgan1', 'minabsestateindex', 'numatomstereocenters', 'fpdensitymorgan2', 'slogp_vsa2', 'xch_5dv', 'num_s_centers', 'aromatic_interaction_score', 'axp_2dv', 'chi1v', 'hallkieralpha', 'vsa_estate8'],
- "id_column": "udm_mol_bat_id",
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-chemprop-reg-hybrid/training",
- "hyperparameters": {'n_folds': 5, 'hidden_dim': 300, 'depth': 4, 'dropout': 0.1, 'ffn_hidden_dim': 300, 'ffn_num_layers': 2},
+ "targets": ['logd'],
+ "feature_list": ['smiles', 'mollogp', 'fr_halogen', 'nbase', 'peoe_vsa6', 'bcut2d_mrlow', 'peoe_vsa7', 'peoe_vsa9', 'vsa_estate1', 'peoe_vsa1', 'numhdonors', 'vsa_estate5', 'smr_vsa3', 'slogp_vsa1', 'vsa_estate7', 'bcut2d_mwhi', 'axp_2dv', 'axp_3dv', 'mi', 'smr_vsa9', 'vsa_estate3', 'estate_vsa9', 'bcut2d_mwlow', 'tpsa', 'vsa_estate10', 'xch_5dv', 'slogp_vsa2', 'nhohcount', 'bcut2d_logplow', 'hallkieralpha', 'c2sp2', 'bcut2d_chglo', 'smr_vsa4', 'maxabspartialcharge', 'estate_vsa6', 'qed', 'slogp_vsa6', 'vsa_estate2', 'bcut2d_logphi', 'vsa_estate8', 'xch_7dv', 'fpdensitymorgan3', 'xpc_6d', 'smr_vsa10', 'axp_0d', 'fr_nh1', 'axp_4dv', 'peoe_vsa2', 'estate_vsa8', 'peoe_vsa5', 'vsa_estate6'],
+ "id_column": "molecule_name",
+ "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-reg-chemprop-hybrid/training",
+ "hyperparameters": {},
  }


- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
- """Check if the provided dataframe is empty and raise an exception if it is."""
- if df.empty:
- msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
- print(msg)
- raise ValueError(msg)
-
+ # =============================================================================
+ # Helper Functions
+ # =============================================================================
+ def _compute_std_confidence(df: pd.DataFrame, median_std: float, std_col: str = "prediction_std") -> pd.DataFrame:
+ """Compute confidence score from ensemble prediction_std.

- def find_smiles_column(columns: list[str]) -> str:
- """Find the SMILES column name from a list (case-insensitive match for 'smiles')."""
- smiles_column = next((col for col in columns if col.lower() == "smiles"), None)
- if smiles_column is None:
- raise ValueError(
- "Column list must contain a 'smiles' column (case-insensitive)"
- )
- return smiles_column
+ Uses exponential decay: confidence = exp(-std / median_std)
+ - Low std (ensemble agreement) -> high confidence
+ - High std (ensemble disagreement) -> low confidence

+ Args:
+ df: DataFrame with prediction_std column
+ median_std: Median std from training validation set (normalization factor)
+ std_col: Name of the std column to use

- def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
- """Expands a column containing a list of probabilities into separate columns.
-
- Handles None values for rows where predictions couldn't be made.
+ Returns:
+ DataFrame with added 'confidence' column (0.0 to 1.0)
  """
- proba_column = "pred_proba"
- if proba_column not in df.columns:
- raise ValueError('DataFrame does not contain a "pred_proba" column')
-
- proba_splits = [f"{label}_proba" for label in class_labels]
- n_classes = len(class_labels)
-
- # Handle None values by replacing with list of NaNs
- proba_values = []
- for val in df[proba_column]:
- if val is None:
- proba_values.append([np.nan] * n_classes)
- else:
- proba_values.append(val)
+ df["confidence"] = np.exp(-df[std_col] / median_std)
+ return df

- proba_df = pd.DataFrame(proba_values, columns=proba_splits)

- df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
- df = df.reset_index(drop=True)
- df = pd.concat([df, proba_df], axis=1)
- return df
+ def _find_smiles_column(columns: list[str]) -> str:
+ """Find SMILES column (case-insensitive match for 'smiles')."""
+ smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
+ if smiles_col is None:
+ raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
+ return smiles_col


- def create_molecule_datapoints(
+ def _create_molecule_datapoints(
  smiles_list: list[str],
- targets: list[float] | None = None,
+ targets: np.ndarray | None = None,
  extra_descriptors: np.ndarray | None = None,
  ) -> tuple[list[data.MoleculeDatapoint], list[int]]:
- """Create ChemProp MoleculeDatapoints from SMILES strings.
-
- Args:
- smiles_list: List of SMILES strings
- targets: Optional list of target values (for training)
- extra_descriptors: Optional array of extra features (n_samples, n_features)
-
- Returns:
- Tuple of (list of MoleculeDatapoint objects, list of valid indices)
- """
+ """Create ChemProp MoleculeDatapoints from SMILES strings."""
  from rdkit import Chem

- datapoints = []
- valid_indices = []
- invalid_count = 0
+ datapoints, valid_indices = [], []
+ targets = np.atleast_2d(np.array(targets)).T if targets is not None and np.array(targets).ndim == 1 else targets

  for i, smi in enumerate(smiles_list):
- # Validate SMILES with RDKit first
- mol = Chem.MolFromSmiles(smi)
- if mol is None:
- invalid_count += 1
+ if Chem.MolFromSmiles(smi) is None:
  continue
-
- # Build datapoint with optional target and extra descriptors
- y = [targets[i]] if targets is not None else None
+ y = targets[i].tolist() if targets is not None else None
  x_d = extra_descriptors[i] if extra_descriptors is not None else None
-
- dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
- datapoints.append(dp)
+ datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
  valid_indices.append(i)

- if invalid_count > 0:
- print(f"Warning: Skipped {invalid_count} invalid SMILES strings")
-
  return datapoints, valid_indices


- def build_mpnn_model(
- hyperparameters: dict,
- task: str = "regression",
- num_classes: int | None = None,
- n_extra_descriptors: int = 0,
- x_d_transform: nn.ScaleTransform | None = None,
- output_transform: nn.UnscaleTransform | None = None,
- ) -> models.MPNN:
- """Build an MPNN model with the specified hyperparameters.
-
- Args:
- hyperparameters: Dictionary of model hyperparameters
- task: Either "regression" or "classification"
- num_classes: Number of classes for classification tasks
- n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
- x_d_transform: Optional transform for extra descriptors (scaling)
- output_transform: Optional transform for regression output (unscaling targets)
-
- Returns:
- Configured MPNN model
- """
- # Model hyperparameters with defaults
- hidden_dim = hyperparameters.get("hidden_dim", 300)
- depth = hyperparameters.get("depth", 4)
- dropout = hyperparameters.get("dropout", 0.1)
- ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
- ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
-
- # Message passing component
- mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-
- # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
- agg = nn.NormAggregation()
-
- # FFN input_dim = message passing output + extra descriptors
- ffn_input_dim = hidden_dim + n_extra_descriptors
-
- # Build FFN based on task type
- if task == "classification" and num_classes is not None:
- # Multi-class classification
- ffn = nn.MulticlassClassificationFFN(
- n_classes=num_classes,
- input_dim=ffn_input_dim,
- hidden_dim=ffn_hidden_dim,
- n_layers=ffn_num_layers,
- dropout=dropout,
- )
- else:
- # Regression with optional output transform to unscale predictions
- ffn = nn.RegressionFFN(
- input_dim=ffn_input_dim,
- hidden_dim=ffn_hidden_dim,
- n_layers=ffn_num_layers,
- dropout=dropout,
- output_transform=output_transform,
- )
-
- # Create the MPNN model
- mpnn = models.MPNN(
- message_passing=mp,
- agg=agg,
- predictor=ffn,
- batch_norm=True,
- metrics=None,
- X_d_transform=x_d_transform,
- )
-
- return mpnn
-
-
+ # =============================================================================
+ # Model Loading (for SageMaker inference)
+ # =============================================================================
  def model_fn(model_dir: str) -> dict:
- """Load the ChemProp MPNN ensemble models from the specified directory.
-
- Args:
- model_dir: Directory containing the saved models
-
- Returns:
- Dictionary with ensemble models and metadata
- """
- # Load ensemble metadata
- ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
- if os.path.exists(ensemble_metadata_path):
- ensemble_metadata = joblib.load(ensemble_metadata_path)
- n_ensemble = ensemble_metadata["n_ensemble"]
- else:
- # Backwards compatibility: single model without ensemble metadata
- n_ensemble = 1
+ """Load ChemProp MPNN ensemble from the specified directory."""
+ from lightning import pytorch as pl

- # Load all ensemble models
+ metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
  ensemble_models = []
- for ens_idx in range(n_ensemble):
- model_path = os.path.join(model_dir, f"chemprop_model_{ens_idx}.pt")
- if not os.path.exists(model_path):
- # Backwards compatibility: try old single model path
- model_path = os.path.join(model_dir, "chemprop_model.pt")
- model = models.MPNN.load_from_file(model_path)
+ for i in range(metadata["n_ensemble"]):
+ model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
  model.eval()
  ensemble_models.append(model)

- print(f"Loaded {len(ensemble_models)} ensemble model(s)")
+ # Pre-initialize trainer once during model loading (expensive operation)
+ trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)

+ print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
  return {
  "ensemble_models": ensemble_models,
- "n_ensemble": n_ensemble,
+ "n_ensemble": metadata["n_ensemble"],
+ "target_columns": metadata["target_columns"],
+ "median_std": metadata["median_std"],
+ "trainer": trainer,
  }


- def input_fn(input_data, content_type: str) -> pd.DataFrame:
- """Parse input data and return a DataFrame."""
- if not input_data:
- raise ValueError("Empty input data is not supported!")
-
- if isinstance(input_data, bytes):
- input_data = input_data.decode("utf-8")
-
- if "text/csv" in content_type:
- return pd.read_csv(StringIO(input_data))
- elif "application/json" in content_type:
- return pd.DataFrame(json.loads(input_data))
- else:
- raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
- """Supports both CSV and JSON output formats."""
- if "text/csv" in accept_type:
- csv_output = output_df.fillna("N/A").to_csv(index=False)
- return csv_output, "text/csv"
- elif "application/json" in accept_type:
- return output_df.to_json(orient="records"), "application/json"
- else:
- raise RuntimeError(
- f"{accept_type} accept type is not supported by this script."
- )
-
-
+ # =============================================================================
+ # Inference (for SageMaker inference)
+ # =============================================================================
  def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
- """Make predictions with the ChemProp MPNN ensemble.
-
- Args:
- df: Input DataFrame containing SMILES column (and extra features if hybrid mode)
- model_dict: Dictionary containing ensemble models and metadata
-
- Returns:
- DataFrame with predictions added (and prediction_std for ensembles)
- """
+ """Make predictions with ChemProp MPNN ensemble."""
  model_type = TEMPLATE_PARAMS["model_type"]
  model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

- # Extract ensemble models
  ensemble_models = model_dict["ensemble_models"]
- n_ensemble = model_dict["n_ensemble"]
+ target_columns = model_dict["target_columns"]
+ trainer = model_dict["trainer"] # Use pre-initialized trainer

- # Load label encoder if present (classification)
+ # Load artifacts
  label_encoder = None
- label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
- if os.path.exists(label_encoder_path):
- label_encoder = joblib.load(label_encoder_path)
+ encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+ if os.path.exists(encoder_path):
+ label_encoder = joblib.load(encoder_path)

- # Load feature metadata if present (hybrid mode)
- # Contains column names, NaN fill values, and scaler for feature scaling
  feature_metadata = None
- feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
- if os.path.exists(feature_metadata_path):
- feature_metadata = joblib.load(feature_metadata_path)
- print(
- f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
- )
-
- # Find SMILES column in input DataFrame
- smiles_column = find_smiles_column(df.columns.tolist())
+ feature_path = os.path.join(model_dir, "feature_metadata.joblib")
+ if os.path.exists(feature_path):
+ feature_metadata = joblib.load(feature_path)
+ print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")

+ # Find SMILES column and validate
+ smiles_column = _find_smiles_column(df.columns.tolist())
  smiles_list = df[smiles_column].tolist()

- # Track invalid SMILES
- valid_mask = []
- valid_smiles = []
- valid_indices = []
- for i, smi in enumerate(smiles_list):
- if smi and isinstance(smi, str) and len(smi.strip()) > 0:
- valid_mask.append(True)
- valid_smiles.append(smi.strip())
- valid_indices.append(i)
- else:
- valid_mask.append(False)
-
- valid_mask = np.array(valid_mask)
+ valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
+ valid_smiles = [s.strip() for i, s in enumerate(smiles_list) if valid_mask[i]]
  print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")

- # Initialize prediction column (use object dtype for classifiers to avoid FutureWarning)
+ # Initialize output columns
  if model_type == "classifier":
  df["prediction"] = pd.Series([None] * len(df), dtype=object)
  else:
- # Regression (includes uq_regressor)
- df["prediction"] = np.nan
- df["prediction_std"] = np.nan
+ for tc in target_columns:
+ df[f"{tc}_pred"] = np.nan
+ df[f"{tc}_pred_std"] = np.nan

  if sum(valid_mask) == 0:
- print("Warning: No valid SMILES to predict on")
  return df

- # Prepare extra features if in hybrid mode
- # NOTE: We pass RAW (unscaled) features here - the model's X_d_transform handles scaling
+ # Prepare extra features (raw, unscaled - model handles scaling)
  extra_features = None
  if feature_metadata is not None:
- extra_feature_cols = feature_metadata["extra_feature_cols"]
+ extra_cols = feature_metadata["extra_feature_cols"]
  col_means = np.array(feature_metadata["col_means"])
+ valid_indices = np.where(valid_mask)[0]

- # Check columns exist
- missing_cols = [col for col in extra_feature_cols if col not in df.columns]
- if missing_cols:
- print(
- f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
- )
-
- # Extract features for valid SMILES rows (raw, unscaled)
- extra_features = np.zeros(
- (len(valid_indices), len(extra_feature_cols)), dtype=np.float32
- )
- for j, col in enumerate(extra_feature_cols):
+ extra_features = np.zeros((len(valid_indices), len(extra_cols)), dtype=np.float32)
+ for j, col in enumerate(extra_cols):
  if col in df.columns:
  values = df.iloc[valid_indices][col].values.astype(np.float32)
- # Fill NaN with training column means (unscaled means)
- nan_mask = np.isnan(values)
- values[nan_mask] = col_means[j]
+ values[np.isnan(values)] = col_means[j]
  extra_features[:, j] = values
  else:
- # Column missing, use training mean
  extra_features[:, j] = col_means[j]

- # Create datapoints for prediction (filter out invalid SMILES)
- datapoints, rdkit_valid_indices = create_molecule_datapoints(
- valid_smiles, extra_descriptors=extra_features
- )
-
+ # Create datapoints and predict
+ datapoints, rdkit_valid = _create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
  if len(datapoints) == 0:
- print("Warning: No valid SMILES after RDKit validation")
  return df

  dataset = data.MoleculeDataset(datapoints)
  dataloader = data.build_dataloader(dataset, shuffle=False)

- # Make predictions with ensemble
- trainer = pl.Trainer(
- accelerator="auto",
- logger=False,
- enable_progress_bar=False,
- )
-
- # Collect predictions from all ensemble members
- all_ensemble_preds = []
- for ens_idx, ens_model in enumerate(ensemble_models):
+ # Ensemble predictions
+ all_preds = []
+ for model in ensemble_models:
  with torch.inference_mode():
- predictions = trainer.predict(ens_model, dataloader)
- ens_preds = np.concatenate([p.numpy() for p in predictions], axis=0)
- # Squeeze middle dim if present
- if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
- ens_preds = ens_preds.squeeze(axis=1)
- all_ensemble_preds.append(ens_preds)
-
- # Stack and compute mean/std (std is 0 for single model)
- ensemble_preds = np.stack(all_ensemble_preds, axis=0)
- preds = np.mean(ensemble_preds, axis=0)
- preds_std = np.std(ensemble_preds, axis=0) # Will be 0s for n_ensemble=1
-
- print(f"Inference: Ensemble predictions shape: {preds.shape}")
-
- # Map predictions back to valid_mask positions (accounting for RDKit-invalid SMILES)
- # rdkit_valid_indices tells us which of the valid_smiles were actually valid
- valid_positions = np.where(valid_mask)[0][rdkit_valid_indices]
+ predictions = trainer.predict(model, dataloader)
+ preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+ if preds.ndim == 3 and preds.shape[1] == 1:
+ preds = preds.squeeze(axis=1)
+ all_preds.append(preds)
+
+ preds = np.mean(np.stack(all_preds), axis=0)
+ preds_std = np.std(np.stack(all_preds), axis=0)
+ if preds.ndim == 1:
+ preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
+
+ print(f"Inference complete: {preds.shape[0]} predictions")
+
+ # Map predictions back to valid positions
+ valid_positions = np.where(valid_mask)[0][rdkit_valid]
  valid_mask = np.zeros(len(df), dtype=bool)
  valid_mask[valid_positions] = True

  if model_type == "classifier" and label_encoder is not None:
- # For classification, get class predictions and probabilities
- if preds.ndim == 2 and preds.shape[1] > 1:
- # Multi-class: preds are probabilities (averaged across ensemble)
+ if preds.shape[1] > 1:
  class_preds = np.argmax(preds, axis=1)
- decoded_preds = label_encoder.inverse_transform(class_preds)
- df.loc[valid_mask, "prediction"] = decoded_preds
-
- # Add probability columns
- proba_series = pd.Series([None] * len(df), index=df.index, dtype=object)
- proba_series.loc[valid_mask] = [p.tolist() for p in preds]
- df["pred_proba"] = proba_series
+ df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform(class_preds)
+ proba = pd.Series([None] * len(df), dtype=object)
+ proba.loc[valid_mask] = [p.tolist() for p in preds]
+ df["pred_proba"] = proba
  df = expand_proba_column(df, label_encoder.classes_)
  else:
- # Binary or single output
- class_preds = (preds.flatten() > 0.5).astype(int)
- decoded_preds = label_encoder.inverse_transform(class_preds)
- df.loc[valid_mask, "prediction"] = decoded_preds
+ df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform((preds.flatten() > 0.5).astype(int))
  else:
- # Regression: direct predictions
- df.loc[valid_mask, "prediction"] = preds.flatten()
- df.loc[valid_mask, "prediction_std"] = preds_std.flatten()
+ for t_idx, tc in enumerate(target_columns):
+ df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
+ df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
+ df["prediction"] = df[f"{target_columns[0]}_pred"]
+ df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
+
+ # Compute confidence from ensemble std
+ df = _compute_std_confidence(df, model_dict["median_std"])

  return df


+ # =============================================================================
+ # Training
+ # =============================================================================
  if __name__ == "__main__":
- """Training script for ChemProp MPNN model"""
+ # -------------------------------------------------------------------------
+ # Training-only imports (deferred to reduce serverless startup time)
+ # -------------------------------------------------------------------------
+ import argparse
+ import glob
+
+ import awswrangler as wr
+ from lightning import pytorch as pl
+ from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+ from sklearn.preprocessing import LabelEncoder
+
+ # Enable Tensor Core optimization for GPUs that support it
+ torch.set_float32_matmul_precision("medium")
+
+ from chemprop import nn
+
+ from model_script_utils import (
+ check_dataframe,
+ compute_classification_metrics,
+ compute_regression_metrics,
+ print_classification_metrics,
+ print_confusion_matrix,
+ print_regression_metrics,
+ )
+
+ # -------------------------------------------------------------------------
+ # Training-only helper function
+ # -------------------------------------------------------------------------
+ def build_mpnn_model(
+ hyperparameters: dict,
+ task: str = "regression",
+ num_classes: int | None = None,
+ n_targets: int = 1,
+ n_extra_descriptors: int = 0,
+ x_d_transform: nn.ScaleTransform | None = None,
+ output_transform: nn.UnscaleTransform | None = None,
+ task_weights: np.ndarray | None = None,
+ ) -> models.MPNN:
+ """Build an MPNN model with specified hyperparameters."""
+ hidden_dim = hyperparameters["hidden_dim"]
+ depth = hyperparameters["depth"]
+ dropout = hyperparameters["dropout"]
+ ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
+ ffn_num_layers = hyperparameters["ffn_num_layers"]
+
+ mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+ agg = nn.NormAggregation()
+ ffn_input_dim = hidden_dim + n_extra_descriptors
+
+ if task == "classification" and num_classes is not None:
+ ffn = nn.MulticlassClassificationFFN(
+ n_classes=num_classes, input_dim=ffn_input_dim,
+ hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
+ )
+ else:
+ # Map criterion name to ChemProp metric class (must have .clone() method)
+ from chemprop.nn.metrics import MAE, MSE
+
+ criterion_map = {
+ "mae": MAE,
+ "mse": MSE,
+ }
+ criterion_name = hyperparameters.get("criterion", "mae")
+ if criterion_name not in criterion_map:
+ raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+ criterion = criterion_map[criterion_name]()
+
+ weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+ ffn = nn.RegressionFFN(
+ input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+ dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
+ criterion=criterion,
+ )
+
+ return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)

- # Template Parameters
- target = TEMPLATE_PARAMS["target"]
+ # -------------------------------------------------------------------------
+ # Setup: Parse arguments and load data
+ # -------------------------------------------------------------------------
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+ parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+ parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+ args = parser.parse_args()
+
+ # Extract template parameters
+ target_columns = TEMPLATE_PARAMS["targets"]
  model_type = TEMPLATE_PARAMS["model_type"]
  feature_list = TEMPLATE_PARAMS["feature_list"]
  id_column = TEMPLATE_PARAMS["id_column"]
  model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
- hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
+ hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}

- # Get the SMILES column name from feature_list (user defines this, so we use their exact name)
- smiles_column = find_smiles_column(feature_list)
+ if not target_columns or not isinstance(target_columns, list):
+ raise ValueError("'targets' must be a non-empty list of target column names")
+ n_targets = len(target_columns)
+
+ smiles_column = _find_smiles_column(feature_list)
  extra_feature_cols = [f for f in feature_list if f != smiles_column]
  use_extra_features = len(extra_feature_cols) > 0
- print(f"Feature List: {feature_list}")
- print(f"SMILES Column: {smiles_column}")
- print(
- f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
- )

- # Script arguments for input/output directories
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
- )
- parser.add_argument(
- "--train",
- type=str,
- default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
- )
- parser.add_argument(
- "--output-data-dir",
- type=str,
- default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
- )
- args = parser.parse_args()
+ print(f"Target columns ({n_targets}): {target_columns}")
+ print(f"SMILES column: {smiles_column}")
+ print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
+ print(f"Hyperparameters: {hyperparameters}")

- # Read the training data
- training_files = [
- os.path.join(args.train, f)
- for f in os.listdir(args.train)
- if f.endswith(".csv")
- ]
+ # Load training data
+ training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
  print(f"Training Files: {training_files}")
-
  all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
- print(f"All Data Shape: {all_df.shape}")
-
  check_dataframe(all_df, "training_df")

- # Drop rows with missing SMILES or target values
+ # Clean data
  initial_count = len(all_df)
- all_df = all_df.dropna(subset=[smiles_column, target])
- dropped = initial_count - len(all_df)
- if dropped > 0:
- print(f"Dropped {dropped} rows with missing SMILES or target values")
-
- print(f"Target: {target}")
- print(f"Data Shape after cleaning: {all_df.shape}")
-
- # Set up label encoder for classification
+ all_df = all_df.dropna(subset=[smiles_column])
+ all_df = all_df[all_df[target_columns].notna().any(axis=1)]
+ if len(all_df) < initial_count:
+ print(f"Dropped {initial_count - len(all_df)} rows with missing SMILES/targets")
+
+ print(f"Data shape: {all_df.shape}")
+ for tc in target_columns:
+ print(f" {tc}: {all_df[tc].notna().sum()} samples")
+
+ # -------------------------------------------------------------------------
+ # Classification setup
+ # -------------------------------------------------------------------------
  label_encoder = None
+ num_classes = None
  if model_type == "classifier":
+ if n_targets > 1:
+ raise ValueError("Multi-task classification not supported")
  label_encoder = LabelEncoder()
- all_df[target] = label_encoder.fit_transform(all_df[target])
+ all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
  num_classes = len(label_encoder.classes_)
- print(
- f"Classification task with {num_classes} classes: {label_encoder.classes_}"
- )
- else:
- num_classes = None
+ print(f"Classification: {num_classes} classes: {label_encoder.classes_}")

- # Model and training configuration
- print(f"Hyperparameters: {hyperparameters}")
+ # -------------------------------------------------------------------------
+ # Prepare features
+ # -------------------------------------------------------------------------
  task = "classification" if model_type == "classifier" else "regression"
  n_extra = len(extra_feature_cols) if use_extra_features else 0
- max_epochs = hyperparameters.get("max_epochs", 200)
- patience = hyperparameters.get("patience", 20)
- n_folds = hyperparameters.get("n_folds", 5) # Number of CV folds (default: 5)
- batch_size = hyperparameters.get("batch_size", min(64, max(16, len(all_df) // 16)))

- # Check extra feature columns exist
- if use_extra_features:
- missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
- if missing_cols:
- raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")
-
- # =========================================================================
- # UNIFIED TRAINING: Works for n_folds=1 (single model) or n_folds>1 (K-fold CV)
- # =========================================================================
- print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")
-
- # Prepare extra features and validate SMILES upfront
- all_extra_features = None
- col_means = None
+ all_extra_features, col_means = None, None
  if use_extra_features:
  all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
  col_means = np.nanmean(all_extra_features, axis=0)
  for i in range(all_extra_features.shape[1]):
  all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]

- # Filter invalid SMILES from the full dataset
- _, valid_indices = create_molecule_datapoints(
- all_df[smiles_column].tolist(), all_df[target].tolist(), all_extra_features
- )
+ all_targets = all_df[target_columns].values.astype(np.float32)
+
+ # Filter invalid SMILES
+ _, valid_indices = _create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
  all_df = all_df.iloc[valid_indices].reset_index(drop=True)
+ all_targets = all_targets[valid_indices]
  if all_extra_features is not None:
  all_extra_features = all_extra_features[valid_indices]
  print(f"Data after SMILES validation: {all_df.shape}")

- # Create fold splits
+ # Task weights for multi-task (inverse sample count)
+ task_weights = None
+ if n_targets > 1 and model_type != "classifier":
+ counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
+ task_weights = (1.0 / counts) / (1.0 / counts).min()
+ print(f"Task weights: {dict(zip(target_columns, task_weights.round(3)))}")
+
+ # -------------------------------------------------------------------------
+ # Cross-validation setup
+ # -------------------------------------------------------------------------
+ n_folds = hyperparameters["n_folds"]
+ batch_size = hyperparameters["batch_size"]
+
  if n_folds == 1:
- # Single fold: use train/val split from "training" column or random split
  if "training" in all_df.columns:
- print("Found training column, splitting data based on training column")
+ print("Using 'training' column for train/val split")
  train_idx = np.where(all_df["training"])[0]
  val_idx = np.where(~all_df["training"])[0]
  else:
- print("WARNING: No training column found, splitting data with random 80/20 split")
- indices = np.arange(len(all_df))
- train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
+ print("WARNING: No 'training' column, using random 80/20 split")
+ train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
  folds = [(train_idx, val_idx)]
  else:
- # K-Fold CV
  if model_type == "classifier":
  kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
- split_target = all_df[target]
+ folds = list(kfold.split(all_df, all_df[target_columns[0]]))
  else:
  kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
- split_target = None
- folds = list(kfold.split(all_df, split_target))
+ folds = list(kfold.split(all_df))

- # Initialize storage for out-of-fold predictions
- oof_predictions = np.full(len(all_df), np.nan, dtype=np.float64)
- if model_type == "classifier" and num_classes and num_classes > 1:
- oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
- else:
- oof_proba = None
+ print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")

- ensemble_models = []
+ # -------------------------------------------------------------------------
+ # Training loop
+ # -------------------------------------------------------------------------
+ oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
+ oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64) if model_type == "classifier" and num_classes else None

+ ensemble_models = []
  for fold_idx, (train_idx, val_idx) in enumerate(folds):
  print(f"\n{'='*50}")
- print(f"Training Fold {fold_idx + 1}/{len(folds)}")
+ print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
  print(f"{'='*50}")

- # Split data for this fold
- df_train = all_df.iloc[train_idx].reset_index(drop=True)
- df_val = all_df.iloc[val_idx].reset_index(drop=True)
-
+ # Split data
+ df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
+ train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
  train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
  val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
-
- print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
-
- # Create ChemProp datasets for this fold
- train_datapoints, _ = create_molecule_datapoints(
- df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra
- )
- val_datapoints, _ = create_molecule_datapoints(
- df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra
- )
-
- train_dataset = data.MoleculeDataset(train_datapoints)
- val_dataset = data.MoleculeDataset(val_datapoints)
-
- # Save raw val features for prediction
  val_extra_raw = val_extra.copy() if val_extra is not None else None

- # Scale features and targets for this fold
+ # Create datasets
+ train_dps, _ = _create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
+ val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
+ train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)
+
+ # Scale features/targets
  x_d_transform = None
  if use_extra_features:
- feature_scaler = train_dataset.normalize_inputs("X_d")
- val_dataset.normalize_inputs("X_d", feature_scaler)
- x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
+ scaler = train_dataset.normalize_inputs("X_d")
+ val_dataset.normalize_inputs("X_d", scaler)
+ x_d_transform = nn.ScaleTransform.from_standard_scaler(scaler)

  output_transform = None
  if model_type in ["regressor", "uq_regressor"]:
@@ -632,31 +481,27 @@ if __name__ == "__main__":
  val_dataset.normalize_targets(target_scaler)
  output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)

- train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
- val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
+ train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
+ val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

- # Build and train model for this fold
- pl.seed_everything(42 + fold_idx)
+ # Build and train model
+ pl.seed_everything(hyperparameters["seed"] + fold_idx)
  mpnn = build_mpnn_model(
- hyperparameters, task=task, num_classes=num_classes,
- n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
+ hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
+ n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
+ output_transform=output_transform, task_weights=task_weights,
  )

- callbacks = [
- pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
- pl.callbacks.ModelCheckpoint(
- dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
- monitor="val_loss", mode="min", save_top_k=1,
- ),
- ]
-
  trainer = pl.Trainer(
- accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
- logger=False, enable_progress_bar=True,
+ accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
+ callbacks=[
+ pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
+ pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
+ ],
  )
-
  trainer.fit(mpnn, train_loader, val_loader)

+ # Load best checkpoint
  if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
  checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
  mpnn.load_state_dict(checkpoint["state_dict"])
@@ -664,189 +509,141 @@ if __name__ == "__main__":
  mpnn.eval()
  ensemble_models.append(mpnn)

- # Make out-of-fold predictions using raw features
- val_datapoints_raw, _ = create_molecule_datapoints(
- df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_raw
- )
- val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
- val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
+ # Out-of-fold predictions (using raw features)
+ val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
+ val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)

  with torch.inference_mode():
- fold_predictions = trainer.predict(mpnn, val_loader_pred)
- fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
+ fold_preds = np.concatenate([p.numpy() for p in trainer.predict(mpnn, val_loader_pred)], axis=0)
  if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
  fold_preds = fold_preds.squeeze(axis=1)

- # Store out-of-fold predictions
  if model_type == "classifier" and fold_preds.ndim == 2:
- oof_predictions[val_idx] = np.argmax(fold_preds, axis=1)
+ oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
  if oof_proba is not None:
  oof_proba[val_idx] = fold_preds
  else:
- oof_predictions[val_idx] = fold_preds.flatten()
-
- print(f"Fold {fold_idx + 1} complete!")
+ if fold_preds.ndim == 1:
+ fold_preds = fold_preds.reshape(-1, 1)
+ oof_predictions[val_idx] = fold_preds

  print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")

- # Use out-of-fold predictions for metrics
- # For n_folds=1, we only have predictions for val_idx, so filter to those rows
+ # -------------------------------------------------------------------------
+ # Prepare validation results
+ # -------------------------------------------------------------------------
  if n_folds == 1:
- val_mask = ~np.isnan(oof_predictions)
- preds = oof_predictions[val_mask]
+ val_mask = ~np.isnan(oof_predictions).all(axis=1)
  df_val = all_df[val_mask].copy()
- y_validate = df_val[target].values
+ preds = oof_predictions[val_mask]
+ y_validate = all_targets[val_mask]
  if oof_proba is not None:
  oof_proba = oof_proba[val_mask]
  val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
  else:
- preds = oof_predictions
  df_val = all_df.copy()
- y_validate = all_df[target].values
+ preds = oof_predictions
+ y_validate = all_targets
  val_extra_features = all_extra_features

- # Compute prediction_std by running all ensemble models on validation data
- # For n_folds=1, std will be 0 (only one model). For n_folds>1, std shows ensemble disagreement.
- preds_std = None
- if model_type in ["regressor", "uq_regressor"] and len(ensemble_models) > 0:
- print("Computing prediction_std from ensemble predictions on validation data...")
- val_datapoints_for_std, _ = create_molecule_datapoints(
- df_val[smiles_column].tolist(),
- df_val[target].tolist(),
- val_extra_features
- )
- val_dataset_for_std = data.MoleculeDataset(val_datapoints_for_std)
- val_loader_for_std = data.build_dataloader(val_dataset_for_std, batch_size=batch_size, shuffle=False)
-
- all_ensemble_preds_for_std = []
- trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
- for ens_model in ensemble_models:
- with torch.inference_mode():
- ens_preds = trainer_pred.predict(ens_model, val_loader_for_std)
- ens_preds = np.concatenate([p.numpy() for p in ens_preds], axis=0)
- if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
- ens_preds = ens_preds.squeeze(axis=1)
- all_ensemble_preds_for_std.append(ens_preds.flatten())
-
- ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
- preds_std = np.std(ensemble_preds_stacked, axis=0)
- print(f"Ensemble prediction_std - mean: {np.mean(preds_std):.4f}, max: {np.max(preds_std):.4f}")
-
+ # -------------------------------------------------------------------------
+ # Compute metrics and prepare output
+ # -------------------------------------------------------------------------
+ median_std = None # Only set for regression models with ensemble
  if model_type == "classifier":
- # Classification metrics - preds contains class indices from OOF predictions
- class_preds = preds.astype(int)
- has_proba = oof_proba is not None
-
- print(f"class_preds shape: {class_preds.shape}")
-
- # Decode labels for metrics
- y_validate_decoded = label_encoder.inverse_transform(y_validate.astype(int))
+ class_preds = preds[:, 0].astype(int)
+ target_name = target_columns[0]
+ y_true_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
  preds_decoded = label_encoder.inverse_transform(class_preds)

- # Calculate metrics
- label_names = label_encoder.classes_
- scores = precision_recall_fscore_support(
- y_validate_decoded, preds_decoded, average=None, labels=label_names
- )
-
- score_df = pd.DataFrame(
- {
- target: label_names,
- "precision": scores[0],
- "recall": scores[1],
- "f1": scores[2],
- "support": scores[3],
- }
- )
-
- # Output metrics per class
- metrics = ["precision", "recall", "f1", "support"]
- for t in label_names:
- for m in metrics:
- value = score_df.loc[score_df[target] == t, m].iloc[0]
- print(f"Metrics:{t}:{m} {value}")
+ score_df = compute_classification_metrics(y_true_decoded, preds_decoded, label_encoder.classes_, target_name)
+ print_classification_metrics(score_df, target_name, label_encoder.classes_)
+ print_confusion_matrix(y_true_decoded, preds_decoded, label_encoder.classes_)

- # Confusion matrix
- conf_mtx = confusion_matrix(
- y_validate_decoded, preds_decoded, labels=label_names
- )
- for i, row_name in enumerate(label_names):
- for j, col_name in enumerate(label_names):
- value = conf_mtx[i, j]
- print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
-
- # Save validation predictions
- df_val = df_val.copy()
+ # Decode target column back to string labels (was encoded for training)
+ df_val[target_name] = y_true_decoded
  df_val["prediction"] = preds_decoded
- if has_proba and oof_proba is not None:
+ if oof_proba is not None:
  df_val["pred_proba"] = [p.tolist() for p in oof_proba]
- df_val = expand_proba_column(df_val, label_names)
-
+ df_val = expand_proba_column(df_val, label_encoder.classes_)
  else:
- # Regression metrics
- preds_flat = preds.flatten()
- rmse = root_mean_squared_error(y_validate, preds_flat)
- mae = mean_absolute_error(y_validate, preds_flat)
- medae = median_absolute_error(y_validate, preds_flat)
- r2 = r2_score(y_validate, preds_flat)
- spearman_corr = spearmanr(y_validate, preds_flat).correlation
- support = len(df_val)
- print(f"rmse: {rmse:.3f}")
- print(f"mae: {mae:.3f}")
- print(f"medae: {medae:.3f}")
- print(f"r2: {r2:.3f}")
- print(f"spearmanr: {spearman_corr:.3f}")
- print(f"support: {support}")
-
- df_val = df_val.copy()
- df_val["prediction"] = preds_flat
-
- # Add prediction_std (always present for regressors, 0 for single model)
- if preds_std is not None:
- df_val["prediction_std"] = preds_std.flatten()
- else:
- df_val["prediction_std"] = 0.0
- print(f"Ensemble std - mean: {df_val['prediction_std'].mean():.4f}, max: {df_val['prediction_std'].max():.4f}")
-
+ # Compute ensemble std
+ preds_std = None
+ if len(ensemble_models) > 1:
+ print("Computing prediction_std from ensemble...")
+ val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
+ val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
+ trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
+ all_ens_preds = []
+ for m in ensemble_models:
+ with torch.inference_mode():
+ ens_preds = np.concatenate([p.numpy() for p in trainer_pred.predict(m, val_loader)], axis=0)
+ if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+ ens_preds = ens_preds.squeeze(axis=1)
+ all_ens_preds.append(ens_preds)
+ preds_std = np.std(np.stack(all_ens_preds), axis=0)
+ if preds_std.ndim == 1:
+ preds_std = preds_std.reshape(-1, 1)
+
+ print("\n--- Per-target metrics ---")
+ for t_idx, t_name in enumerate(target_columns):
+ valid_mask = ~np.isnan(y_validate[:, t_idx])
+ if valid_mask.sum() > 0:
+ metrics = compute_regression_metrics(y_validate[valid_mask, t_idx], preds[valid_mask, t_idx])
+ print_regression_metrics(metrics)
+
+ df_val[f"{t_name}_pred"] = preds[:, t_idx]
+ df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx] if preds_std is not None else 0.0
+
+ df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
+ df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
+
+ # Compute confidence from ensemble std
+ median_std = float(np.median(preds_std[:, 0]))
+ print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+ df_val = _compute_std_confidence(df_val, median_std)
+ print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+
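Note: the confidence scores above come from _compute_std_confidence(df_val, median_std), whose body is defined earlier in the template and is not part of this hunk. The following is a minimal illustrative sketch of one plausible mapping, assuming confidence simply decays as prediction_std grows relative to the validation median_std; the template's actual formula may differ.

# Illustrative sketch only -- not the template's actual implementation.
import numpy as np
import pandas as pd

def _compute_std_confidence(df: pd.DataFrame, median_std: float) -> pd.DataFrame:
    # Hypothetical mapping: confidence near 1.0 when prediction_std is well below
    # the validation median, decaying toward 0.0 as the std grows past it.
    eps = 1e-9
    ratio = df["prediction_std"] / max(median_std, eps)
    df["confidence"] = np.clip(1.0 / (1.0 + ratio), 0.0, 1.0)
    return df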
+ # -------------------------------------------------------------------------
  # Save validation predictions to S3
- # Include id_column if it exists in df_val
- output_columns = []
- if id_column in df_val.columns:
- output_columns.append(id_column)
- output_columns += [target, "prediction"]
- if "prediction_std" in df_val.columns:
- output_columns.append("prediction_std")
- output_columns += [col for col in df_val.columns if col.endswith("_proba")]
- wr.s3.to_csv(
- df_val[output_columns],
- path=f"{model_metrics_s3_path}/validation_predictions.csv",
- index=False,
- )
-
- # Save ensemble models (n_folds models if CV, 1 model otherwise)
- for model_idx, ens_model in enumerate(ensemble_models):
- model_path = os.path.join(args.model_dir, f"chemprop_model_{model_idx}.pt")
- models.save_model(model_path, ens_model)
- print(f"Saved model {model_idx + 1} to {model_path}")
-
- # Save ensemble metadata (n_ensemble = number of models for inference)
- n_ensemble = len(ensemble_models)
- ensemble_metadata = {"n_ensemble": n_ensemble, "n_folds": n_folds}
+ # -------------------------------------------------------------------------
+ output_columns = [id_column] if id_column in df_val.columns else []
+ output_columns += target_columns
+ output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
+ output_columns += ["prediction", "prediction_std", "confidence"]
+ output_columns += [c for c in df_val.columns if c.endswith("_proba")]
+ output_columns = [c for c in output_columns if c in df_val.columns]
+
+ wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
+
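The validation predictions land in S3 as a single CSV with the new per-target _pred / _pred_std columns plus the legacy prediction, prediction_std, and confidence columns. A quick post-training sanity check is to read the file back with awswrangler; the bucket path and the id/target column names below are placeholders, not values from the template.

import awswrangler as wr

# Placeholder S3 path -- substitute the model's metrics prefix
df_check = wr.s3.read_csv("s3://my-bucket/models/my-model/validation_predictions.csv")
print(df_check.columns.tolist())
# For a hypothetical single-target regressor this would print something like:
# ['id', 'solubility', 'solubility_pred', 'solubility_pred_std',
#  'prediction', 'prediction_std', 'confidence']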
+ # -------------------------------------------------------------------------
+ # Save model artifacts
+ # -------------------------------------------------------------------------
+ for idx, m in enumerate(ensemble_models):
+ models.save_model(os.path.join(args.model_dir, f"chemprop_model_{idx}.pt"), m)
+ print(f"Saved {len(ensemble_models)} model(s)")
+
+ # Clean up checkpoints
+ for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
+ os.remove(ckpt)
+
+ ensemble_metadata = {
+ "n_ensemble": len(ensemble_models),
+ "n_folds": n_folds,
+ "target_columns": target_columns,
+ "median_std": median_std, # For confidence calculation during inference
+ }
  joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
- print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds})")

- # Save label encoder if classification
- if label_encoder is not None:
+ with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
+ json.dump(hyperparameters, f, indent=2)
+
+ if label_encoder:
  joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

- # Save extra feature metadata for inference (hybrid mode)
- # Note: We don't need to save the scaler - X_d_transform is embedded in the model
  if use_extra_features:
- feature_metadata = {
- "extra_feature_cols": extra_feature_cols,
- "col_means": col_means.tolist(), # Unscaled means for NaN imputation
- }
- joblib.dump(
- feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
- )
+ joblib.dump({"extra_feature_cols": extra_feature_cols, "col_means": col_means.tolist()}, os.path.join(args.model_dir, "feature_metadata.joblib"))
  print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
+
+ print(f"\nModel training complete! Artifacts saved to {args.model_dir}")
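The artifacts saved here (one chemprop_model_{idx}.pt per ensemble member, ensemble_metadata.joblib carrying n_ensemble, target_columns and median_std, plus the optional label encoder and feature metadata) are what the inference side consumes. A rough sketch of how an inference handler might combine them, assuming chemprop exposes models.load_model as the counterpart of the models.save_model call above (exact signature not verified here) and reusing the same std-to-confidence idea; none of this is taken verbatim from the template.

import os
import joblib
import numpy as np
from chemprop import models

model_dir = "/opt/ml/model"  # hypothetical SageMaker model directory
meta = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))

# Load every ensemble member saved by the training script.
# models.load_model is assumed here as the inverse of models.save_model.
ensemble = [
    models.load_model(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
    for i in range(meta["n_ensemble"])
]

# Suppose each member has already produced predictions for a request batch
# (dataloader/Trainer plumbing omitted); shape: (n_models, n_rows, n_targets).
per_model_preds = np.zeros((meta["n_ensemble"], 4, len(meta["target_columns"])))

mean_preds = per_model_preds.mean(axis=0)  # point predictions
std_preds = per_model_preds.std(axis=0)    # ensemble disagreement
median_std = meta["median_std"] or 1.0     # stored from validation (None for classifiers)
confidence = 1.0 / (1.0 + std_preds[:, 0] / max(median_std, 1e-9))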