workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,852 @@
+ # ChemProp Model Template for Workbench
+ # Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
+ #
+ # === CHEMPROP REVIEW NOTES ===
+ # This script runs on AWS SageMaker. Key areas for ChemProp review:
+ #
+ # 1. Model Architecture (build_mpnn_model function)
+ #    - BondMessagePassing, NormAggregation, FFN configuration
+ #    - Regression uses output_transform (UnscaleTransform) for target scaling
+ #
+ # 2. Data Handling (create_molecule_datapoints function)
+ #    - MoleculeDatapoint creation with x_d (extra descriptors)
+ #    - RDKit validation of SMILES
+ #
+ # 3. Scaling (training section)
+ #    - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
+ #    - Targets (regression): normalize_targets() + UnscaleTransform in FFN
+ #    - At inference: pass RAW features, transforms handle scaling automatically
+ #
+ # 4. Training Loop (search for "pl.Trainer")
+ #    - PyTorch Lightning Trainer with ChemProp MPNN
+ #
+ # AWS/SageMaker boilerplate (can skip):
+ #    - input_fn, output_fn, model_fn: SageMaker serving interface
+ #    - argparse, file loading, S3 writes
+ # =============================
+
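Note 3 is the subtlest contract in this file, so here it is in isolation: fit the scalers on the training set, then bake them into the model so raw inputs work at inference. A minimal sketch, assuming the same ChemProp 2.x calls the script uses below (the molecules, targets, and descriptor values are toy data):

import numpy as np
from chemprop import data, nn

smis = ["CCO", "c1ccccc1"]                          # toy molecules
ys = [1.2, 3.4]                                     # toy targets
x_d = np.array([[10.0], [20.0]], dtype=np.float32)  # one toy extra descriptor each

dps = [data.MoleculeDatapoint.from_smi(s, y=[y], x_d=xd) for s, y, xd in zip(smis, ys, x_d)]
train = data.MoleculeDataset(dps)

feature_scaler = train.normalize_inputs("X_d")  # scales the dataset in place, returns the scaler
target_scaler = train.normalize_targets()

# Baked into the model, these let predict-time callers pass RAW features
x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)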
+ import os
+ import argparse
+ import json
+ from io import StringIO
+
+ import awswrangler as wr
+ import numpy as np
+ import pandas as pd
+ import torch
+ from lightning import pytorch as pl
+ from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
+ from sklearn.preprocessing import LabelEncoder
+ from sklearn.metrics import (
+     mean_absolute_error,
+     median_absolute_error,
+     r2_score,
+     root_mean_squared_error,
+     precision_recall_fscore_support,
+     confusion_matrix,
+ )
+ from scipy.stats import spearmanr
+ import joblib
+
+ # ChemProp imports
+ from chemprop import data, models, nn
+
+ # Template Parameters
+ TEMPLATE_PARAMS = {
+     "model_type": "uq_regressor",
+     "target": "udm_asy_res_efflux_ratio",
+     "feature_list": ['smiles', 'smr_vsa4', 'tpsa', 'nhohcount', 'peoe_vsa1', 'mollogp', 'numhdonors', 'tertiary_amine_count', 'smr_vsa3', 'nitrogen_span', 'vsa_estate2', 'hba_hbd_ratio', 'minpartialcharge', 'estate_vsa4', 'asphericity', 'charge_centroid_distance', 'peoe_vsa8', 'mi', 'estate_vsa8', 'vsa_estate6', 'vsa_estate3', 'molecular_volume_3d', 'kappa3', 'smr_vsa5', 'sv', 'xp_6dv', 'xc_4dv', 'si', 'molecular_axis_length', 'axp_5d', 'estate_vsa3', 'estate_vsa10', 'axp_7dv', 'slogp_vsa1', 'molecular_asymmetry', 'molmr', 'qed', 'xp_3d', 'axp_0dv', 'fpdensitymorgan1', 'minabsestateindex', 'numatomstereocenters', 'fpdensitymorgan2', 'slogp_vsa2', 'xch_5dv', 'num_s_centers', 'aromatic_interaction_score', 'axp_2dv', 'chi1v', 'hallkieralpha', 'vsa_estate8'],
+     "id_column": "udm_mol_bat_id",
+     "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-chemprop-reg-hybrid/training",
+     "hyperparameters": {'n_folds': 5, 'hidden_dim': 300, 'depth': 4, 'dropout': 0.1, 'ffn_hidden_dim': 300, 'ffn_num_layers': 2},
+ }
+
+
+ def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
+     """Check if the provided dataframe is empty and raise an exception if it is."""
+     if df.empty:
+         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
+         print(msg)
+         raise ValueError(msg)
+
+
+ def find_smiles_column(columns: list[str]) -> str:
+     """Find the SMILES column name from a list (case-insensitive match for 'smiles')."""
+     smiles_column = next((col for col in columns if col.lower() == "smiles"), None)
+     if smiles_column is None:
+         raise ValueError(
+             "Column list must contain a 'smiles' column (case-insensitive)"
+         )
+     return smiles_column
+
+
+ def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
+     """Expands a column containing a list of probabilities into separate columns.
+
+     Handles None values for rows where predictions couldn't be made.
+     """
+     proba_column = "pred_proba"
+     if proba_column not in df.columns:
+         raise ValueError('DataFrame does not contain a "pred_proba" column')
+
+     proba_splits = [f"{label}_proba" for label in class_labels]
+     n_classes = len(class_labels)
+
+     # Handle None values by replacing with list of NaNs
+     proba_values = []
+     for val in df[proba_column]:
+         if val is None:
+             proba_values.append([np.nan] * n_classes)
+         else:
+             proba_values.append(val)
+
+     proba_df = pd.DataFrame(proba_values, columns=proba_splits)
+
+     df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
+     df = df.reset_index(drop=True)
+     df = pd.concat([df, proba_df], axis=1)
+     return df
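A quick illustration of the contract, using only the function above: with class labels ["active", "inactive"], a probability list expands into per-class columns and a None row becomes NaN in both:

df = pd.DataFrame({"pred_proba": [[0.8, 0.2], None]})
df = expand_proba_column(df, ["active", "inactive"])
# df now has active_proba / inactive_proba columns; the None row is NaN in each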
+
+
+ def create_molecule_datapoints(
+     smiles_list: list[str],
+     targets: list[float] | None = None,
+     extra_descriptors: np.ndarray | None = None,
+ ) -> tuple[list[data.MoleculeDatapoint], list[int]]:
+     """Create ChemProp MoleculeDatapoints from SMILES strings.
+
+     Args:
+         smiles_list: List of SMILES strings
+         targets: Optional list of target values (for training)
+         extra_descriptors: Optional array of extra features (n_samples, n_features)
+
+     Returns:
+         Tuple of (list of MoleculeDatapoint objects, list of valid indices)
+     """
+     from rdkit import Chem
+
+     datapoints = []
+     valid_indices = []
+     invalid_count = 0
+
+     for i, smi in enumerate(smiles_list):
+         # Validate SMILES with RDKit first
+         mol = Chem.MolFromSmiles(smi)
+         if mol is None:
+             invalid_count += 1
+             continue
+
+         # Build datapoint with optional target and extra descriptors
+         y = [targets[i]] if targets is not None else None
+         x_d = extra_descriptors[i] if extra_descriptors is not None else None
+
+         dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
+         datapoints.append(dp)
+         valid_indices.append(i)
+
+     if invalid_count > 0:
+         print(f"Warning: Skipped {invalid_count} invalid SMILES strings")
+
+     return datapoints, valid_indices
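The valid_indices return value is how callers line predictions back up with their original rows. A small example (the middle SMILES is deliberately invalid):

dps, idx = create_molecule_datapoints(["CCO", "not_a_smiles", "c1ccccc1"], targets=[1.0, 2.0, 3.0])
# len(dps) == 2 and idx == [0, 2]: row 1 was skipped by RDKit, so any
# downstream predictions must be mapped back through idx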
+
+
+ def build_mpnn_model(
+     hyperparameters: dict,
+     task: str = "regression",
+     num_classes: int | None = None,
+     n_extra_descriptors: int = 0,
+     x_d_transform: nn.ScaleTransform | None = None,
+     output_transform: nn.UnscaleTransform | None = None,
+ ) -> models.MPNN:
+     """Build an MPNN model with the specified hyperparameters.
+
+     Args:
+         hyperparameters: Dictionary of model hyperparameters
+         task: Either "regression" or "classification"
+         num_classes: Number of classes for classification tasks
+         n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
+         x_d_transform: Optional transform for extra descriptors (scaling)
+         output_transform: Optional transform for regression output (unscaling targets)
+
+     Returns:
+         Configured MPNN model
+     """
+     # Model hyperparameters with defaults
+     hidden_dim = hyperparameters.get("hidden_dim", 300)
+     depth = hyperparameters.get("depth", 4)
+     dropout = hyperparameters.get("dropout", 0.1)
+     ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
+     ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
+
+     # Message passing component
+     mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+
+     # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
+     agg = nn.NormAggregation()
+
+     # FFN input_dim = message passing output + extra descriptors
+     ffn_input_dim = hidden_dim + n_extra_descriptors
+
+     # Build FFN based on task type
+     if task == "classification" and num_classes is not None:
+         # Multi-class classification
+         ffn = nn.MulticlassClassificationFFN(
+             n_classes=num_classes,
+             input_dim=ffn_input_dim,
+             hidden_dim=ffn_hidden_dim,
+             n_layers=ffn_num_layers,
+             dropout=dropout,
+         )
+     else:
+         # Regression with optional output transform to unscale predictions
+         ffn = nn.RegressionFFN(
+             input_dim=ffn_input_dim,
+             hidden_dim=ffn_hidden_dim,
+             n_layers=ffn_num_layers,
+             dropout=dropout,
+             output_transform=output_transform,
+         )
+
+     # Create the MPNN model
+     mpnn = models.MPNN(
+         message_passing=mp,
+         agg=agg,
+         predictor=ffn,
+         batch_norm=True,
+         metrics=None,
+         X_d_transform=x_d_transform,
+     )
+
+     return mpnn
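A quick smoke test of the builder: a hybrid regression model with the 50 extra descriptors implied by the feature_list in TEMPLATE_PARAMS (one SMILES column plus 50 descriptor columns), everything else falling back to the defaults above:

mpnn = build_mpnn_model(hyperparameters={}, task="regression", n_extra_descriptors=50)
# FFN input is the 300-dim message-passing output + 50 descriptors = 350
n_params = sum(p.numel() for p in mpnn.parameters())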
+
+
+ def model_fn(model_dir: str) -> dict:
+     """Load the ChemProp MPNN ensemble models from the specified directory.
+
+     Args:
+         model_dir: Directory containing the saved models
+
+     Returns:
+         Dictionary with ensemble models and metadata
+     """
+     # Load ensemble metadata
+     ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
+     if os.path.exists(ensemble_metadata_path):
+         ensemble_metadata = joblib.load(ensemble_metadata_path)
+         n_ensemble = ensemble_metadata["n_ensemble"]
+     else:
+         # Backwards compatibility: single model without ensemble metadata
+         n_ensemble = 1
+
+     # Load all ensemble models
+     ensemble_models = []
+     for ens_idx in range(n_ensemble):
+         model_path = os.path.join(model_dir, f"chemprop_model_{ens_idx}.pt")
+         if not os.path.exists(model_path):
+             # Backwards compatibility: try old single model path
+             model_path = os.path.join(model_dir, "chemprop_model.pt")
+         model = models.MPNN.load_from_file(model_path)
+         model.eval()
+         ensemble_models.append(model)
+
+     print(f"Loaded {len(ensemble_models)} ensemble model(s)")
+
+     return {
+         "ensemble_models": ensemble_models,
+         "n_ensemble": n_ensemble,
+     }
+
+
+ def input_fn(input_data, content_type: str) -> pd.DataFrame:
+     """Parse input data and return a DataFrame."""
+     if not input_data:
+         raise ValueError("Empty input data is not supported!")
+
+     if isinstance(input_data, bytes):
+         input_data = input_data.decode("utf-8")
+
+     if "text/csv" in content_type:
+         return pd.read_csv(StringIO(input_data))
+     elif "application/json" in content_type:
+         return pd.DataFrame(json.loads(input_data))
+     else:
+         raise ValueError(f"{content_type} not supported!")
+
+
+ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
+     """Supports both CSV and JSON output formats."""
+     if "text/csv" in accept_type:
+         csv_output = output_df.fillna("N/A").to_csv(index=False)
+         return csv_output, "text/csv"
+     elif "application/json" in accept_type:
+         return output_df.to_json(orient="records"), "application/json"
+     else:
+         raise RuntimeError(
+             f"{accept_type} accept type is not supported by this script."
+         )
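These two functions are pure content-negotiation shims, so they can be exercised without any model on hand:

df = input_fn("smiles,tpsa\nCCO,20.2\n", "text/csv")
body, mime = output_fn(df, "application/json")
# mime == "application/json" and body == '[{"smiles":"CCO","tpsa":20.2}]'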
+
+
+ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+     """Make predictions with the ChemProp MPNN ensemble.
+
+     Args:
+         df: Input DataFrame containing SMILES column (and extra features if hybrid mode)
+         model_dict: Dictionary containing ensemble models and metadata
+
+     Returns:
+         DataFrame with predictions added (and prediction_std for ensembles)
+     """
+     model_type = TEMPLATE_PARAMS["model_type"]
+     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+
+     # Extract ensemble models
+     ensemble_models = model_dict["ensemble_models"]
+     n_ensemble = model_dict["n_ensemble"]
+
+     # Load label encoder if present (classification)
+     label_encoder = None
+     label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+     if os.path.exists(label_encoder_path):
+         label_encoder = joblib.load(label_encoder_path)
+
+     # Load feature metadata if present (hybrid mode)
+     # Contains column names, NaN fill values, and scaler for feature scaling
+     feature_metadata = None
+     feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
+     if os.path.exists(feature_metadata_path):
+         feature_metadata = joblib.load(feature_metadata_path)
+         print(
+             f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
+         )
+
+     # Find SMILES column in input DataFrame
+     smiles_column = find_smiles_column(df.columns.tolist())
+
+     smiles_list = df[smiles_column].tolist()
+
+     # Track invalid SMILES
+     valid_mask = []
+     valid_smiles = []
+     valid_indices = []
+     for i, smi in enumerate(smiles_list):
+         if smi and isinstance(smi, str) and len(smi.strip()) > 0:
+             valid_mask.append(True)
+             valid_smiles.append(smi.strip())
+             valid_indices.append(i)
+         else:
+             valid_mask.append(False)
+
+     valid_mask = np.array(valid_mask)
+     print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")
+
+     # Initialize prediction column (use object dtype for classifiers to avoid FutureWarning)
+     if model_type == "classifier":
+         df["prediction"] = pd.Series([None] * len(df), dtype=object)
+     else:
+         # Regression (includes uq_regressor)
+         df["prediction"] = np.nan
+         df["prediction_std"] = np.nan
+
+     if sum(valid_mask) == 0:
+         print("Warning: No valid SMILES to predict on")
+         return df
+
+     # Prepare extra features if in hybrid mode
+     # NOTE: We pass RAW (unscaled) features here - the model's X_d_transform handles scaling
+     extra_features = None
+     if feature_metadata is not None:
+         extra_feature_cols = feature_metadata["extra_feature_cols"]
+         col_means = np.array(feature_metadata["col_means"])
+
+         # Check columns exist
+         missing_cols = [col for col in extra_feature_cols if col not in df.columns]
+         if missing_cols:
+             print(
+                 f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
+             )
+
+         # Extract features for valid SMILES rows (raw, unscaled)
+         extra_features = np.zeros(
+             (len(valid_indices), len(extra_feature_cols)), dtype=np.float32
+         )
+         for j, col in enumerate(extra_feature_cols):
+             if col in df.columns:
+                 values = df.iloc[valid_indices][col].values.astype(np.float32)
+                 # Fill NaN with training column means (unscaled means)
+                 nan_mask = np.isnan(values)
+                 values[nan_mask] = col_means[j]
+                 extra_features[:, j] = values
+             else:
+                 # Column missing, use training mean
+                 extra_features[:, j] = col_means[j]
+
+     # Create datapoints for prediction (filter out invalid SMILES)
+     datapoints, rdkit_valid_indices = create_molecule_datapoints(
+         valid_smiles, extra_descriptors=extra_features
+     )
+
+     if len(datapoints) == 0:
+         print("Warning: No valid SMILES after RDKit validation")
+         return df
+
+     dataset = data.MoleculeDataset(datapoints)
+     dataloader = data.build_dataloader(dataset, shuffle=False)
+
+     # Make predictions with ensemble
+     trainer = pl.Trainer(
+         accelerator="auto",
+         logger=False,
+         enable_progress_bar=False,
+     )
+
+     # Collect predictions from all ensemble members
+     all_ensemble_preds = []
+     for ens_idx, ens_model in enumerate(ensemble_models):
+         with torch.inference_mode():
+             predictions = trainer.predict(ens_model, dataloader)
+         ens_preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+         # Squeeze middle dim if present
+         if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+             ens_preds = ens_preds.squeeze(axis=1)
+         all_ensemble_preds.append(ens_preds)
+
+     # Stack and compute mean/std (std is 0 for single model)
+     ensemble_preds = np.stack(all_ensemble_preds, axis=0)
+     preds = np.mean(ensemble_preds, axis=0)
+     preds_std = np.std(ensemble_preds, axis=0)  # Will be 0s for n_ensemble=1
+
+     print(f"Inference: Ensemble predictions shape: {preds.shape}")
+
+     # Map predictions back to valid_mask positions (accounting for RDKit-invalid SMILES)
+     # rdkit_valid_indices tells us which of the valid_smiles were actually valid
+     valid_positions = np.where(valid_mask)[0][rdkit_valid_indices]
+     valid_mask = np.zeros(len(df), dtype=bool)
+     valid_mask[valid_positions] = True
+
+     if model_type == "classifier" and label_encoder is not None:
+         # For classification, get class predictions and probabilities
+         if preds.ndim == 2 and preds.shape[1] > 1:
+             # Multi-class: preds are probabilities (averaged across ensemble)
+             class_preds = np.argmax(preds, axis=1)
+             decoded_preds = label_encoder.inverse_transform(class_preds)
+             df.loc[valid_mask, "prediction"] = decoded_preds
+
+             # Add probability columns
+             proba_series = pd.Series([None] * len(df), index=df.index, dtype=object)
+             proba_series.loc[valid_mask] = [p.tolist() for p in preds]
+             df["pred_proba"] = proba_series
+             df = expand_proba_column(df, label_encoder.classes_)
+         else:
+             # Binary or single output
+             class_preds = (preds.flatten() > 0.5).astype(int)
+             decoded_preds = label_encoder.inverse_transform(class_preds)
+             df.loc[valid_mask, "prediction"] = decoded_preds
+     else:
+         # Regression: direct predictions
+         df.loc[valid_mask, "prediction"] = preds.flatten()
+         df.loc[valid_mask, "prediction_std"] = preds_std.flatten()
+
+     return df
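Putting the four SageMaker hooks together the way the serving container would. The model directory here is hypothetical; it must hold the artifacts written by the training section below, and SM_MODEL_DIR has to point at it because predict_fn reads side files (label_encoder.joblib, feature_metadata.joblib) from there:

os.environ["SM_MODEL_DIR"] = "/tmp/chemprop_model"  # hypothetical dir with saved artifacts
model_dict = model_fn(os.environ["SM_MODEL_DIR"])
df_in = input_fn('[{"smiles": "CCO"}]', "application/json")
df_out = predict_fn(df_in, model_dict)
body, mime = output_fn(df_out, "text/csv")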
+
+
+ if __name__ == "__main__":
+     """Training script for ChemProp MPNN model"""
+
+     # Template Parameters
+     target = TEMPLATE_PARAMS["target"]
+     model_type = TEMPLATE_PARAMS["model_type"]
+     feature_list = TEMPLATE_PARAMS["feature_list"]
+     id_column = TEMPLATE_PARAMS["id_column"]
+     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
+     hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
+
+     # Get the SMILES column name from feature_list (user defines this, so we use their exact name)
+     smiles_column = find_smiles_column(feature_list)
+     extra_feature_cols = [f for f in feature_list if f != smiles_column]
+     use_extra_features = len(extra_feature_cols) > 0
+     print(f"Feature List: {feature_list}")
+     print(f"SMILES Column: {smiles_column}")
+     print(
+         f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
+     )
+
+     # Script arguments for input/output directories
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+     )
+     parser.add_argument(
+         "--train",
+         type=str,
+         default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
+     )
+     parser.add_argument(
+         "--output-data-dir",
+         type=str,
+         default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
+     )
+     args = parser.parse_args()
+
+     # Read the training data
+     training_files = [
+         os.path.join(args.train, f)
+         for f in os.listdir(args.train)
+         if f.endswith(".csv")
+     ]
+     print(f"Training Files: {training_files}")
+
+     all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
+     print(f"All Data Shape: {all_df.shape}")
+
+     check_dataframe(all_df, "training_df")
+
+     # Drop rows with missing SMILES or target values
+     initial_count = len(all_df)
+     all_df = all_df.dropna(subset=[smiles_column, target])
+     dropped = initial_count - len(all_df)
+     if dropped > 0:
+         print(f"Dropped {dropped} rows with missing SMILES or target values")
+
+     print(f"Target: {target}")
+     print(f"Data Shape after cleaning: {all_df.shape}")
+
+     # Set up label encoder for classification
+     label_encoder = None
+     if model_type == "classifier":
+         label_encoder = LabelEncoder()
+         all_df[target] = label_encoder.fit_transform(all_df[target])
+         num_classes = len(label_encoder.classes_)
+         print(
+             f"Classification task with {num_classes} classes: {label_encoder.classes_}"
+         )
+     else:
+         num_classes = None
+
+     # Model and training configuration
+     print(f"Hyperparameters: {hyperparameters}")
+     task = "classification" if model_type == "classifier" else "regression"
+     n_extra = len(extra_feature_cols) if use_extra_features else 0
+     max_epochs = hyperparameters.get("max_epochs", 200)
+     patience = hyperparameters.get("patience", 20)
+     n_folds = hyperparameters.get("n_folds", 5)  # Number of CV folds (default: 5)
+     batch_size = hyperparameters.get("batch_size", min(64, max(16, len(all_df) // 16)))
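For scale, the adaptive batch-size default works out as: 1,000 training rows give min(64, max(16, 1000 // 16)) = 62; 10,000 rows cap at 64; a few hundred rows or fewer floor at 16.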
+
+     # Check extra feature columns exist
+     if use_extra_features:
+         missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
+         if missing_cols:
+             raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")
+
+     # =========================================================================
+     # UNIFIED TRAINING: Works for n_folds=1 (single model) or n_folds>1 (K-fold CV)
+     # =========================================================================
+     print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")
+
+     # Prepare extra features and validate SMILES upfront
+     all_extra_features = None
+     col_means = None
+     if use_extra_features:
+         all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
+         col_means = np.nanmean(all_extra_features, axis=0)
+         for i in range(all_extra_features.shape[1]):
+             all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
+
+     # Filter invalid SMILES from the full dataset
+     _, valid_indices = create_molecule_datapoints(
+         all_df[smiles_column].tolist(), all_df[target].tolist(), all_extra_features
+     )
+     all_df = all_df.iloc[valid_indices].reset_index(drop=True)
+     if all_extra_features is not None:
+         all_extra_features = all_extra_features[valid_indices]
+     print(f"Data after SMILES validation: {all_df.shape}")
+
+     # Create fold splits
+     if n_folds == 1:
+         # Single fold: use train/val split from "training" column or random split
+         if "training" in all_df.columns:
+             print("Found training column, splitting data based on training column")
+             train_idx = np.where(all_df["training"])[0]
+             val_idx = np.where(~all_df["training"])[0]
+         else:
+             print("WARNING: No training column found, splitting data with random 80/20 split")
+             indices = np.arange(len(all_df))
+             train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
+         folds = [(train_idx, val_idx)]
+     else:
+         # K-Fold CV
+         if model_type == "classifier":
+             kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
+             split_target = all_df[target]
+         else:
+             kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
+             split_target = None
+         folds = list(kfold.split(all_df, split_target))
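Across all folds, every row index lands in exactly one val_idx, which is what lets the out-of-fold arrays below fill completely. A toy check of that property with the same sklearn splitter:

kf = KFold(n_splits=5, shuffle=True, random_state=42)
covered = np.concatenate([val for _, val in kf.split(np.arange(100))])
assert sorted(covered) == list(range(100))  # each index validated exactly once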
+
+     # Initialize storage for out-of-fold predictions
+     oof_predictions = np.full(len(all_df), np.nan, dtype=np.float64)
+     if model_type == "classifier" and num_classes and num_classes > 1:
+         oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
+     else:
+         oof_proba = None
+
+     ensemble_models = []
+
+     for fold_idx, (train_idx, val_idx) in enumerate(folds):
+         print(f"\n{'='*50}")
+         print(f"Training Fold {fold_idx + 1}/{len(folds)}")
+         print(f"{'='*50}")
+
+         # Split data for this fold
+         df_train = all_df.iloc[train_idx].reset_index(drop=True)
+         df_val = all_df.iloc[val_idx].reset_index(drop=True)
+
+         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
+         val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
+
+         print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
+
+         # Create ChemProp datasets for this fold
+         train_datapoints, _ = create_molecule_datapoints(
+             df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra
+         )
+         val_datapoints, _ = create_molecule_datapoints(
+             df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra
+         )
+
+         train_dataset = data.MoleculeDataset(train_datapoints)
+         val_dataset = data.MoleculeDataset(val_datapoints)
+
+         # Save raw val features for prediction
+         val_extra_raw = val_extra.copy() if val_extra is not None else None
+
+         # Scale features and targets for this fold
+         x_d_transform = None
+         if use_extra_features:
+             feature_scaler = train_dataset.normalize_inputs("X_d")
+             val_dataset.normalize_inputs("X_d", feature_scaler)
+             x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
+
+         output_transform = None
+         if model_type in ["regressor", "uq_regressor"]:
+             target_scaler = train_dataset.normalize_targets()
+             val_dataset.normalize_targets(target_scaler)
+             output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
+
+         train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
+         val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
+
+         # Build and train model for this fold
+         pl.seed_everything(42 + fold_idx)
+         mpnn = build_mpnn_model(
+             hyperparameters, task=task, num_classes=num_classes,
+             n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
+         )
+
+         callbacks = [
+             pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
+             pl.callbacks.ModelCheckpoint(
+                 dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
+                 monitor="val_loss", mode="min", save_top_k=1,
+             ),
+         ]
+
+         trainer = pl.Trainer(
+             accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
+             logger=False, enable_progress_bar=True,
+         )
+
+         trainer.fit(mpnn, train_loader, val_loader)
+
+         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
+             checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
+             mpnn.load_state_dict(checkpoint["state_dict"])
+
+         mpnn.eval()
+         ensemble_models.append(mpnn)
+
+         # Make out-of-fold predictions using raw features
+         val_datapoints_raw, _ = create_molecule_datapoints(
+             df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_raw
+         )
+         val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
+         val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
+
+         with torch.inference_mode():
+             fold_predictions = trainer.predict(mpnn, val_loader_pred)
+         fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
+         if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
+             fold_preds = fold_preds.squeeze(axis=1)
+
+         # Store out-of-fold predictions
+         if model_type == "classifier" and fold_preds.ndim == 2:
+             oof_predictions[val_idx] = np.argmax(fold_preds, axis=1)
+             if oof_proba is not None:
+                 oof_proba[val_idx] = fold_preds
+         else:
+             oof_predictions[val_idx] = fold_preds.flatten()
+
+         print(f"Fold {fold_idx + 1} complete!")
+
+     print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
+
+     # Use out-of-fold predictions for metrics
+     # For n_folds=1, we only have predictions for val_idx, so filter to those rows
+     if n_folds == 1:
+         val_mask = ~np.isnan(oof_predictions)
+         preds = oof_predictions[val_mask]
+         df_val = all_df[val_mask].copy()
+         y_validate = df_val[target].values
+         if oof_proba is not None:
+             oof_proba = oof_proba[val_mask]
+         val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
+     else:
+         preds = oof_predictions
+         df_val = all_df.copy()
+         y_validate = all_df[target].values
+         val_extra_features = all_extra_features
+
+     # Compute prediction_std by running all ensemble models on validation data
+     # For n_folds=1, std will be 0 (only one model). For n_folds>1, std shows ensemble disagreement.
+     preds_std = None
+     if model_type in ["regressor", "uq_regressor"] and len(ensemble_models) > 0:
+         print("Computing prediction_std from ensemble predictions on validation data...")
+         val_datapoints_for_std, _ = create_molecule_datapoints(
+             df_val[smiles_column].tolist(),
+             df_val[target].tolist(),
+             val_extra_features
+         )
+         val_dataset_for_std = data.MoleculeDataset(val_datapoints_for_std)
+         val_loader_for_std = data.build_dataloader(val_dataset_for_std, batch_size=batch_size, shuffle=False)
+
+         all_ensemble_preds_for_std = []
+         trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+         for ens_model in ensemble_models:
+             with torch.inference_mode():
+                 ens_preds = trainer_pred.predict(ens_model, val_loader_for_std)
+             ens_preds = np.concatenate([p.numpy() for p in ens_preds], axis=0)
+             if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+                 ens_preds = ens_preds.squeeze(axis=1)
+             all_ensemble_preds_for_std.append(ens_preds.flatten())
+
+         ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
+         preds_std = np.std(ensemble_preds_stacked, axis=0)
+         print(f"Ensemble prediction_std - mean: {np.mean(preds_std):.4f}, max: {np.max(preds_std):.4f}")
+
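For intuition on what this buys: if two fold models predict 1.0 and 1.2 for the same row, prediction_std is np.std([1.0, 1.2]) = 0.1; identical predictions across folds (or a single model, n_folds=1) give exactly 0. A real uncertainty signal therefore only appears with n_folds > 1.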
+     if model_type == "classifier":
+         # Classification metrics - preds contains class indices from OOF predictions
+         class_preds = preds.astype(int)
+         has_proba = oof_proba is not None
+
+         print(f"class_preds shape: {class_preds.shape}")
+
+         # Decode labels for metrics
+         y_validate_decoded = label_encoder.inverse_transform(y_validate.astype(int))
+         preds_decoded = label_encoder.inverse_transform(class_preds)
+
+         # Calculate metrics
+         label_names = label_encoder.classes_
+         scores = precision_recall_fscore_support(
+             y_validate_decoded, preds_decoded, average=None, labels=label_names
+         )
+
+         score_df = pd.DataFrame(
+             {
+                 target: label_names,
+                 "precision": scores[0],
+                 "recall": scores[1],
+                 "f1": scores[2],
+                 "support": scores[3],
+             }
+         )
+
+         # Output metrics per class
+         metrics = ["precision", "recall", "f1", "support"]
+         for t in label_names:
+             for m in metrics:
+                 value = score_df.loc[score_df[target] == t, m].iloc[0]
+                 print(f"Metrics:{t}:{m} {value}")
+
+         # Confusion matrix
+         conf_mtx = confusion_matrix(
+             y_validate_decoded, preds_decoded, labels=label_names
+         )
+         for i, row_name in enumerate(label_names):
+             for j, col_name in enumerate(label_names):
+                 value = conf_mtx[i, j]
+                 print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
+
+         # Save validation predictions
+         df_val = df_val.copy()
+         df_val["prediction"] = preds_decoded
+         if has_proba and oof_proba is not None:
+             df_val["pred_proba"] = [p.tolist() for p in oof_proba]
+             df_val = expand_proba_column(df_val, label_names)
+
+     else:
+         # Regression metrics
+         preds_flat = preds.flatten()
+         rmse = root_mean_squared_error(y_validate, preds_flat)
+         mae = mean_absolute_error(y_validate, preds_flat)
+         medae = median_absolute_error(y_validate, preds_flat)
+         r2 = r2_score(y_validate, preds_flat)
+         spearman_corr = spearmanr(y_validate, preds_flat).correlation
+         support = len(df_val)
+         print(f"rmse: {rmse:.3f}")
+         print(f"mae: {mae:.3f}")
+         print(f"medae: {medae:.3f}")
+         print(f"r2: {r2:.3f}")
+         print(f"spearmanr: {spearman_corr:.3f}")
+         print(f"support: {support}")
+
+         df_val = df_val.copy()
+         df_val["prediction"] = preds_flat
+
+         # Add prediction_std (always present for regressors, 0 for single model)
+         if preds_std is not None:
+             df_val["prediction_std"] = preds_std.flatten()
+         else:
+             df_val["prediction_std"] = 0.0
+         print(f"Ensemble std - mean: {df_val['prediction_std'].mean():.4f}, max: {df_val['prediction_std'].max():.4f}")
+
+     # Save validation predictions to S3
+     # Include id_column if it exists in df_val
+     output_columns = []
+     if id_column in df_val.columns:
+         output_columns.append(id_column)
+     output_columns += [target, "prediction"]
+     if "prediction_std" in df_val.columns:
+         output_columns.append("prediction_std")
+     output_columns += [col for col in df_val.columns if col.endswith("_proba")]
+     wr.s3.to_csv(
+         df_val[output_columns],
+         path=f"{model_metrics_s3_path}/validation_predictions.csv",
+         index=False,
+     )
+
+     # Save ensemble models (n_folds models if CV, 1 model otherwise)
+     for model_idx, ens_model in enumerate(ensemble_models):
+         model_path = os.path.join(args.model_dir, f"chemprop_model_{model_idx}.pt")
+         models.save_model(model_path, ens_model)
+         print(f"Saved model {model_idx + 1} to {model_path}")
+
+     # Save ensemble metadata (n_ensemble = number of models for inference)
+     n_ensemble = len(ensemble_models)
+     ensemble_metadata = {"n_ensemble": n_ensemble, "n_folds": n_folds}
+     joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
+     print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds})")
+
+     # Save label encoder if classification
+     if label_encoder is not None:
+         joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
+
+     # Save extra feature metadata for inference (hybrid mode)
+     # Note: We don't need to save the scaler - X_d_transform is embedded in the model
+     if use_extra_features:
+         feature_metadata = {
+             "extra_feature_cols": extra_feature_cols,
+             "col_means": col_means.tolist(),  # Unscaled means for NaN imputation
+         }
+         joblib.dump(
+             feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
+         )
+         print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
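Taken together, the model directory shipped to serving contains chemprop_model_0.pt through chemprop_model_{n-1}.pt, ensemble_metadata.joblib, plus label_encoder.joblib (classification) and feature_metadata.joblib (hybrid mode): exactly the files model_fn and predict_fn probe for above.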