workbench-0.8.198-py3-none-any.whl → workbench-0.8.203-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. workbench/algorithms/dataframe/proximity.py +11 -4
  2. workbench/api/__init__.py +2 -1
  3. workbench/api/df_store.py +17 -108
  4. workbench/api/feature_set.py +48 -11
  5. workbench/api/model.py +1 -1
  6. workbench/api/parameter_store.py +3 -52
  7. workbench/core/artifacts/__init__.py +11 -2
  8. workbench/core/artifacts/artifact.py +5 -5
  9. workbench/core/artifacts/df_store_core.py +114 -0
  10. workbench/core/artifacts/endpoint_core.py +261 -78
  11. workbench/core/artifacts/feature_set_core.py +69 -1
  12. workbench/core/artifacts/model_core.py +48 -14
  13. workbench/core/artifacts/parameter_store_core.py +98 -0
  14. workbench/core/transforms/features_to_model/features_to_model.py +50 -33
  15. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  16. workbench/core/views/view.py +2 -2
  17. workbench/model_scripts/chemprop/chemprop.template +933 -0
  18. workbench/model_scripts/chemprop/generated_model_script.py +933 -0
  19. workbench/model_scripts/chemprop/requirements.txt +11 -0
  20. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  21. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  22. workbench/model_scripts/custom_models/proximity/proximity.py +11 -4
  23. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +11 -5
  24. workbench/model_scripts/custom_models/uq_models/meta_uq.template +11 -5
  25. workbench/model_scripts/custom_models/uq_models/ngboost.template +11 -5
  26. workbench/model_scripts/custom_models/uq_models/proximity.py +11 -4
  27. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +11 -5
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +365 -173
  29. workbench/model_scripts/pytorch_model/pytorch.template +362 -170
  30. workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
  31. workbench/model_scripts/script_generation.py +10 -7
  32. workbench/model_scripts/uq_models/generated_model_script.py +43 -27
  33. workbench/model_scripts/uq_models/mapie.template +40 -24
  34. workbench/model_scripts/xgb_model/generated_model_script.py +36 -7
  35. workbench/model_scripts/xgb_model/xgb_model.template +36 -7
  36. workbench/repl/workbench_shell.py +14 -5
  37. workbench/resources/open_source_api.key +1 -1
  38. workbench/scripts/endpoint_test.py +162 -0
  39. workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
  40. workbench/utils/chemprop_utils.py +761 -0
  41. workbench/utils/pytorch_utils.py +527 -0
  42. workbench/utils/xgboost_model_utils.py +10 -5
  43. workbench/web_interface/components/model_plot.py +7 -1
  44. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/METADATA +3 -3
  45. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/RECORD +49 -43
  46. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/entry_points.txt +2 -1
  47. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  48. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  49. workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
  50. workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
  51. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/WHEEL +0 -0
  52. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/licenses/LICENSE +0 -0
  53. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,933 @@
+ # ChemProp Model Template for Workbench
+ # Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
+ #
+ # === CHEMPROP REVIEW NOTES ===
+ # This script runs on AWS SageMaker. Key areas for ChemProp review:
+ #
+ # 1. Model Architecture (build_mpnn_model function)
+ #    - BondMessagePassing, NormAggregation, FFN configuration
+ #    - Regression uses output_transform (UnscaleTransform) for target scaling
+ #
+ # 2. Data Handling (create_molecule_datapoints function)
+ #    - MoleculeDatapoint creation with x_d (extra descriptors)
+ #    - RDKit validation of SMILES
+ #
+ # 3. Scaling (training section)
+ #    - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
+ #    - Targets (regression): normalize_targets() + UnscaleTransform in FFN
+ #    - At inference: pass RAW features, transforms handle scaling automatically
+ #
+ # 4. Training Loop (search for "pl.Trainer")
+ #    - PyTorch Lightning Trainer with ChemProp MPNN
+ #
+ # AWS/SageMaker boilerplate (can skip):
+ #    - input_fn, output_fn, model_fn: SageMaker serving interface
+ #    - argparse, file loading, S3 writes
+ # =============================
+
+ import os
+ import argparse
+ import json
+ from io import StringIO
+
+ import awswrangler as wr
+ import numpy as np
+ import pandas as pd
+ import torch
+ from lightning import pytorch as pl
+ from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
+ from sklearn.preprocessing import LabelEncoder
+ from sklearn.metrics import (
+     mean_absolute_error,
+     median_absolute_error,
+     r2_score,
+     root_mean_squared_error,
+     precision_recall_fscore_support,
+     confusion_matrix,
+ )
+ from scipy.stats import spearmanr
+ import joblib
+
+ # ChemProp imports
+ from chemprop import data, models, nn
+
+ # Template Parameters
+ TEMPLATE_PARAMS = {
+     "model_type": "{{model_type}}",
+     "targets": "{{target_column}}",  # List of target columns (single or multi-task)
+     "feature_list": "{{feature_list}}",
+     "id_column": "{{id_column}}",
+     "model_metrics_s3_path": "{{model_metrics_s3_path}}",
+     "hyperparameters": "{{hyperparameters}}",
+ }
+
+
+ def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
+     """Check if the provided dataframe is empty and raise an exception if it is."""
+     if df.empty:
+         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
+         print(msg)
+         raise ValueError(msg)
+
+
+ def find_smiles_column(columns: list[str]) -> str:
+     """Find the SMILES column name from a list (case-insensitive match for 'smiles')."""
+     smiles_column = next((col for col in columns if col.lower() == "smiles"), None)
+     if smiles_column is None:
+         raise ValueError(
+             "Column list must contain a 'smiles' column (case-insensitive)"
+         )
+     return smiles_column
+
+
+ def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
+     """Expands a column containing a list of probabilities into separate columns.
+
+     Handles None values for rows where predictions couldn't be made.
+     """
+     proba_column = "pred_proba"
+     if proba_column not in df.columns:
+         raise ValueError('DataFrame does not contain a "pred_proba" column')
+
+     proba_splits = [f"{label}_proba" for label in class_labels]
+     n_classes = len(class_labels)
+
+     # Handle None values by replacing with list of NaNs
+     proba_values = []
+     for val in df[proba_column]:
+         if val is None:
+             proba_values.append([np.nan] * n_classes)
+         else:
+             proba_values.append(val)
+
+     proba_df = pd.DataFrame(proba_values, columns=proba_splits)
+
+     df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
+     df = df.reset_index(drop=True)
+     df = pd.concat([df, proba_df], axis=1)
+     return df
+
+
+ def create_molecule_datapoints(
+     smiles_list: list[str],
+     targets: list[float] | np.ndarray | None = None,
+     extra_descriptors: np.ndarray | None = None,
+ ) -> tuple[list[data.MoleculeDatapoint], list[int]]:
+     """Create ChemProp MoleculeDatapoints from SMILES strings.
+
+     Args:
+         smiles_list: List of SMILES strings
+         targets: Optional target values as 2D array (n_samples, n_targets). NaN allowed for missing targets.
+         extra_descriptors: Optional array of extra features (n_samples, n_features)
+
+     Returns:
+         Tuple of (list of MoleculeDatapoint objects, list of valid indices)
+     """
+     from rdkit import Chem
+
+     datapoints = []
+     valid_indices = []
+     invalid_count = 0
+
+     # Convert targets to 2D array if provided
+     if targets is not None:
+         targets = np.atleast_2d(np.array(targets))
+         if targets.shape[0] == 1 and len(smiles_list) > 1:
+             targets = targets.T  # Shape was (1, n_samples), transpose to (n_samples, 1)
+
+     for i, smi in enumerate(smiles_list):
+         # Validate SMILES with RDKit first
+         mol = Chem.MolFromSmiles(smi)
+         if mol is None:
+             invalid_count += 1
+             continue
+
+         # Build datapoint with optional target(s) and extra descriptors
+         # For multi-task, y is a list of values (can include NaN for missing targets)
+         y = targets[i].tolist() if targets is not None else None
+         x_d = extra_descriptors[i] if extra_descriptors is not None else None
+
+         dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
+         datapoints.append(dp)
+         valid_indices.append(i)
+
+     if invalid_count > 0:
+         print(f"Warning: Skipped {invalid_count} invalid SMILES strings")
+
+     return datapoints, valid_indices
+
+
+ def build_mpnn_model(
+     hyperparameters: dict,
+     task: str = "regression",
+     num_classes: int | None = None,
+     n_targets: int = 1,
+     n_extra_descriptors: int = 0,
+     x_d_transform: nn.ScaleTransform | None = None,
+     output_transform: nn.UnscaleTransform | None = None,
+     task_weights: np.ndarray | None = None,
+ ) -> models.MPNN:
+     """Build an MPNN model with the specified hyperparameters.
+
+     Args:
+         hyperparameters: Dictionary of model hyperparameters
+         task: Either "regression" or "classification"
+         num_classes: Number of classes for classification tasks
+         n_targets: Number of target columns (for multi-task regression)
+         n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
+         x_d_transform: Optional transform for extra descriptors (scaling)
+         output_transform: Optional transform for regression output (unscaling targets)
+         task_weights: Optional array of weights for each task (multi-task learning)
+
+     Returns:
+         Configured MPNN model
+     """
+     # Model hyperparameters with defaults
+     hidden_dim = hyperparameters.get("hidden_dim", 700)
+     depth = hyperparameters.get("depth", 6)
+     dropout = hyperparameters.get("dropout", 0.25)
+     ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 2000)
+     ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
+
+     # Message passing component
+     mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+
+     # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
+     agg = nn.NormAggregation()
+
+     # FFN input_dim = message passing output + extra descriptors
+     ffn_input_dim = hidden_dim + n_extra_descriptors
+
+     # Build FFN based on task type
+     if task == "classification" and num_classes is not None:
+         # Multi-class classification
+         ffn = nn.MulticlassClassificationFFN(
+             n_classes=num_classes,
+             input_dim=ffn_input_dim,
+             hidden_dim=ffn_hidden_dim,
+             n_layers=ffn_num_layers,
+             dropout=dropout,
+         )
+     else:
+         # Regression with optional output transform to unscale predictions
+         # n_tasks controls the number of output heads for multi-task learning
+         # task_weights goes here (in RegressionFFN) to weight loss per task
+         weights_tensor = None
+         if task_weights is not None:
+             weights_tensor = torch.tensor(task_weights, dtype=torch.float32)
+
+         ffn = nn.RegressionFFN(
+             input_dim=ffn_input_dim,
+             hidden_dim=ffn_hidden_dim,
+             n_layers=ffn_num_layers,
+             dropout=dropout,
+             n_tasks=n_targets,
+             output_transform=output_transform,
+             task_weights=weights_tensor,
+         )
+
+     # Create the MPNN model
+     mpnn = models.MPNN(
+         message_passing=mp,
+         agg=agg,
+         predictor=ffn,
+         batch_norm=True,
+         metrics=None,
+         X_d_transform=x_d_transform,
+     )
+
+     return mpnn
+
+
+ def model_fn(model_dir: str) -> dict:
+     """Load the ChemProp MPNN ensemble models from the specified directory.
+
+     Args:
+         model_dir: Directory containing the saved models
+
+     Returns:
+         Dictionary with ensemble models and metadata
+     """
+     # Load ensemble metadata (required)
+     ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
+     ensemble_metadata = joblib.load(ensemble_metadata_path)
+     n_ensemble = ensemble_metadata["n_ensemble"]
+     target_columns = ensemble_metadata["target_columns"]
+
+     # Load all ensemble models
+     ensemble_models = []
+     for ens_idx in range(n_ensemble):
+         model_path = os.path.join(model_dir, f"chemprop_model_{ens_idx}.pt")
+         model = models.MPNN.load_from_file(model_path)
+         model.eval()
+         ensemble_models.append(model)
+
+     print(f"Loaded {len(ensemble_models)} ensemble model(s), n_targets={len(target_columns)}")
+
+     return {
+         "ensemble_models": ensemble_models,
+         "n_ensemble": n_ensemble,
+         "target_columns": target_columns,
+     }
+
+
+ def input_fn(input_data, content_type: str) -> pd.DataFrame:
+     """Parse input data and return a DataFrame."""
+     if not input_data:
+         raise ValueError("Empty input data is not supported!")
+
+     if isinstance(input_data, bytes):
+         input_data = input_data.decode("utf-8")
+
+     if "text/csv" in content_type:
+         return pd.read_csv(StringIO(input_data))
+     elif "application/json" in content_type:
+         return pd.DataFrame(json.loads(input_data))
+     else:
+         raise ValueError(f"{content_type} not supported!")
+
+
+ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
+     """Supports both CSV and JSON output formats."""
+     if "text/csv" in accept_type:
+         csv_output = output_df.fillna("N/A").to_csv(index=False)
+         return csv_output, "text/csv"
+     elif "application/json" in accept_type:
+         return output_df.to_json(orient="records"), "application/json"
+     else:
+         raise RuntimeError(
+             f"{accept_type} accept type is not supported by this script."
+         )
+
+
+ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+     """Make predictions with the ChemProp MPNN ensemble.
+
+     Args:
+         df: Input DataFrame containing SMILES column (and extra features if hybrid mode)
+         model_dict: Dictionary containing ensemble models and metadata
+
+     Returns:
+         DataFrame with predictions added (and prediction_std for ensembles)
+     """
+     model_type = TEMPLATE_PARAMS["model_type"]
+     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+
+     # Extract ensemble models and metadata
+     ensemble_models = model_dict["ensemble_models"]
+     n_ensemble = model_dict["n_ensemble"]
+     target_columns = model_dict["target_columns"]
+
+     # Load label encoder if present (classification)
+     label_encoder = None
+     label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+     if os.path.exists(label_encoder_path):
+         label_encoder = joblib.load(label_encoder_path)
+
+     # Load feature metadata if present (hybrid mode)
+     # Contains column names, NaN fill values, and scaler for feature scaling
+     feature_metadata = None
+     feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
+     if os.path.exists(feature_metadata_path):
+         feature_metadata = joblib.load(feature_metadata_path)
+         print(
+             f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
+         )
+
+     # Find SMILES column in input DataFrame
+     smiles_column = find_smiles_column(df.columns.tolist())
+
+     smiles_list = df[smiles_column].tolist()
+
+     # Track invalid SMILES
+     valid_mask = []
+     valid_smiles = []
+     valid_indices = []
+     for i, smi in enumerate(smiles_list):
+         if smi and isinstance(smi, str) and len(smi.strip()) > 0:
+             valid_mask.append(True)
+             valid_smiles.append(smi.strip())
+             valid_indices.append(i)
+         else:
+             valid_mask.append(False)
+
+     valid_mask = np.array(valid_mask)
+     print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")
+
+     # Initialize prediction columns (use object dtype for classifiers to avoid FutureWarning)
+     if model_type == "classifier":
+         df["prediction"] = pd.Series([None] * len(df), dtype=object)
+     else:
+         # Regression: create prediction column for each target
+         for tc in target_columns:
+             df[f"{tc}_pred"] = np.nan
+             df[f"{tc}_pred_std"] = np.nan
+
+     if sum(valid_mask) == 0:
+         print("Warning: No valid SMILES to predict on")
+         return df
+
+     # Prepare extra features if in hybrid mode
+     # NOTE: We pass RAW (unscaled) features here - the model's X_d_transform handles scaling
+     extra_features = None
+     if feature_metadata is not None:
+         extra_feature_cols = feature_metadata["extra_feature_cols"]
+         col_means = np.array(feature_metadata["col_means"])
+
+         # Check columns exist
+         missing_cols = [col for col in extra_feature_cols if col not in df.columns]
+         if missing_cols:
+             print(
+                 f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
+             )
+
+         # Extract features for valid SMILES rows (raw, unscaled)
+         extra_features = np.zeros(
+             (len(valid_indices), len(extra_feature_cols)), dtype=np.float32
+         )
+         for j, col in enumerate(extra_feature_cols):
+             if col in df.columns:
+                 values = df.iloc[valid_indices][col].values.astype(np.float32)
+                 # Fill NaN with training column means (unscaled means)
+                 nan_mask = np.isnan(values)
+                 values[nan_mask] = col_means[j]
+                 extra_features[:, j] = values
+             else:
+                 # Column missing, use training mean
+                 extra_features[:, j] = col_means[j]
+
+     # Create datapoints for prediction (filter out invalid SMILES)
+     datapoints, rdkit_valid_indices = create_molecule_datapoints(
+         valid_smiles, extra_descriptors=extra_features
+     )
+
+     if len(datapoints) == 0:
+         print("Warning: No valid SMILES after RDKit validation")
+         return df
+
+     dataset = data.MoleculeDataset(datapoints)
+     dataloader = data.build_dataloader(dataset, shuffle=False)
+
+     # Make predictions with ensemble
+     trainer = pl.Trainer(
+         accelerator="auto",
+         logger=False,
+         enable_progress_bar=False,
+     )
+
+     # Collect predictions from all ensemble members
+     all_ensemble_preds = []
+     for ens_idx, ens_model in enumerate(ensemble_models):
+         with torch.inference_mode():
+             predictions = trainer.predict(ens_model, dataloader)
+         ens_preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+         # Squeeze middle dim if present
+         if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+             ens_preds = ens_preds.squeeze(axis=1)
+         all_ensemble_preds.append(ens_preds)
+
+     # Stack and compute mean/std (std is 0 for single model)
+     ensemble_preds = np.stack(all_ensemble_preds, axis=0)
+     preds = np.mean(ensemble_preds, axis=0)
+     preds_std = np.std(ensemble_preds, axis=0)  # Will be 0s for n_ensemble=1
+
+     # Ensure 2D: (n_samples, n_targets)
+     if preds.ndim == 1:
+         preds = preds.reshape(-1, 1)
+         preds_std = preds_std.reshape(-1, 1)
+
+     print(f"Inference: Ensemble predictions shape: {preds.shape}")
+
+     # Map predictions back to valid_mask positions (accounting for RDKit-invalid SMILES)
+     # rdkit_valid_indices tells us which of the valid_smiles were actually valid
+     valid_positions = np.where(valid_mask)[0][rdkit_valid_indices]
+     valid_mask = np.zeros(len(df), dtype=bool)
+     valid_mask[valid_positions] = True
+
+     if model_type == "classifier" and label_encoder is not None:
+         # For classification, get class predictions and probabilities
+         if preds.ndim == 2 and preds.shape[1] > 1:
+             # Multi-class: preds are probabilities (averaged across ensemble)
+             class_preds = np.argmax(preds, axis=1)
+             decoded_preds = label_encoder.inverse_transform(class_preds)
+             df.loc[valid_mask, "prediction"] = decoded_preds
+
+             # Add probability columns
+             proba_series = pd.Series([None] * len(df), index=df.index, dtype=object)
+             proba_series.loc[valid_mask] = [p.tolist() for p in preds]
+             df["pred_proba"] = proba_series
+             df = expand_proba_column(df, label_encoder.classes_)
+         else:
+             # Binary or single output
+             class_preds = (preds.flatten() > 0.5).astype(int)
+             decoded_preds = label_encoder.inverse_transform(class_preds)
+             df.loc[valid_mask, "prediction"] = decoded_preds
+     else:
+         # Regression: store predictions for each target
+         for t_idx, tc in enumerate(target_columns):
+             df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
+             df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
+
+     return df
+
+
+ if __name__ == "__main__":
+     """Training script for ChemProp MPNN model"""
+
+     # Template Parameters
+     target_columns = TEMPLATE_PARAMS["targets"]  # List of target columns
+     model_type = TEMPLATE_PARAMS["model_type"]
+     feature_list = TEMPLATE_PARAMS["feature_list"]
+     id_column = TEMPLATE_PARAMS["id_column"]
+     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
+     hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
+
+     # Validate target_columns
+     if not target_columns or not isinstance(target_columns, list) or len(target_columns) == 0:
+         raise ValueError("'targets' must be a non-empty list of target column names")
+     n_targets = len(target_columns)
+     print(f"Target columns ({n_targets}): {target_columns}")
+
+     # Get the SMILES column name from feature_list (user defines this, so we use their exact name)
+     smiles_column = find_smiles_column(feature_list)
+     extra_feature_cols = [f for f in feature_list if f != smiles_column]
+     use_extra_features = len(extra_feature_cols) > 0
+     print(f"Feature List: {feature_list}")
+     print(f"SMILES Column: {smiles_column}")
+     print(
+         f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
+     )
+
+     # Script arguments for input/output directories
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+     )
+     parser.add_argument(
+         "--train",
+         type=str,
+         default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
+     )
+     parser.add_argument(
+         "--output-data-dir",
+         type=str,
+         default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
+     )
+     args = parser.parse_args()
+
+     # Read the training data
+     training_files = [
+         os.path.join(args.train, f)
+         for f in os.listdir(args.train)
+         if f.endswith(".csv")
+     ]
+     print(f"Training Files: {training_files}")
+
+     all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
+     print(f"All Data Shape: {all_df.shape}")
+
+     check_dataframe(all_df, "training_df")
+
+     # Drop rows with missing SMILES or all target values
+     initial_count = len(all_df)
+     all_df = all_df.dropna(subset=[smiles_column])
+     # Keep rows that have at least one non-null target (works for single and multi-task)
+     has_any_target = all_df[target_columns].notna().any(axis=1)
+     all_df = all_df[has_any_target]
+     dropped = initial_count - len(all_df)
+     if dropped > 0:
+         print(f"Dropped {dropped} rows with missing SMILES or all target values")
+
+     print(f"Target columns: {target_columns}")
+     print(f"Data Shape after cleaning: {all_df.shape}")
+     for tc in target_columns:
+         n_valid = all_df[tc].notna().sum()
+         print(f" {tc}: {n_valid} samples with values")
+
+     # Set up label encoder for classification (single-target only)
+     label_encoder = None
+     if model_type == "classifier":
+         if n_targets > 1:
+             raise ValueError("Multi-task classification is not supported. Use regression for multi-task.")
+         label_encoder = LabelEncoder()
+         all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
+         num_classes = len(label_encoder.classes_)
+         print(
+             f"Classification task with {num_classes} classes: {label_encoder.classes_}"
+         )
+     else:
+         num_classes = None
+
+     # Model and training configuration
+     print(f"Hyperparameters: {hyperparameters}")
+     task = "classification" if model_type == "classifier" else "regression"
+     n_extra = len(extra_feature_cols) if use_extra_features else 0
+     max_epochs = hyperparameters.get("max_epochs", 400)
+     patience = hyperparameters.get("patience", 40)
+     n_folds = hyperparameters.get("n_folds", 5)  # Number of CV folds (default: 5)
+     batch_size = hyperparameters.get("batch_size", 16)
+
+     # Check extra feature columns exist
+     if use_extra_features:
+         missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
+         if missing_cols:
+             raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")
+
+     # =========================================================================
+     # UNIFIED TRAINING: Works for n_folds=1 (single model) or n_folds>1 (K-fold CV)
+     # =========================================================================
+     print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")
+
+     # Prepare extra features and validate SMILES upfront
+     all_extra_features = None
+     col_means = None
+     if use_extra_features:
+         all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
+         col_means = np.nanmean(all_extra_features, axis=0)
+         for i in range(all_extra_features.shape[1]):
+             all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
+
+     # Prepare target array: always 2D (n_samples, n_targets)
+     all_targets = all_df[target_columns].values.astype(np.float32)
+
+     # Filter invalid SMILES from the full dataset
+     _, valid_indices = create_molecule_datapoints(
+         all_df[smiles_column].tolist(), all_targets, all_extra_features
+     )
+     all_df = all_df.iloc[valid_indices].reset_index(drop=True)
+     all_targets = all_targets[valid_indices]
+     if all_extra_features is not None:
+         all_extra_features = all_extra_features[valid_indices]
+     print(f"Data after SMILES validation: {all_df.shape}")
+
+     # Compute dynamic task weights for multi-task regression
+     # Weight = inverse of sample count (normalized so min weight = 1.0)
+     # This gives higher weight to targets with fewer samples
+     task_weights = None
+     if n_targets > 1 and model_type != "classifier":
+         sample_counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
+         # Inverse weighting: fewer samples = higher weight
+         inverse_counts = 1.0 / sample_counts
+         # Normalize so minimum weight is 1.0
+         task_weights = inverse_counts / inverse_counts.min()
+         print(f"Task weights (inverse sample count):")
+         for t_idx, t_name in enumerate(target_columns):
+             print(f" {t_name}: {task_weights[t_idx]:.3f} (n={sample_counts[t_idx]})")
+
+     # Create fold splits
+     if n_folds == 1:
+         # Single fold: use train/val split from "training" column or random split
+         if "training" in all_df.columns:
+             print("Found training column, splitting data based on training column")
+             train_idx = np.where(all_df["training"])[0]
+             val_idx = np.where(~all_df["training"])[0]
+         else:
+             print("WARNING: No training column found, splitting data with random 80/20 split")
+             indices = np.arange(len(all_df))
+             train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
+         folds = [(train_idx, val_idx)]
+     else:
+         # K-Fold CV
+         if model_type == "classifier":
+             kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
+             split_target = all_df[target_columns[0]]
+         else:
+             kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
+             split_target = None
+         folds = list(kfold.split(all_df, split_target))
+
+     # Initialize storage for out-of-fold predictions: always 2D (n_samples, n_targets)
+     oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
+     if model_type == "classifier" and num_classes and num_classes > 1:
+         oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
+     else:
+         oof_proba = None
+
+     ensemble_models = []
+
+     for fold_idx, (train_idx, val_idx) in enumerate(folds):
+         print(f"\n{'='*50}")
+         print(f"Training Fold {fold_idx + 1}/{len(folds)}")
+         print(f"{'='*50}")
+
+         # Split data for this fold
+         df_train = all_df.iloc[train_idx].reset_index(drop=True)
+         df_val = all_df.iloc[val_idx].reset_index(drop=True)
+         train_targets = all_targets[train_idx]
+         val_targets = all_targets[val_idx]
+
+         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
+         val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
+
+         print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
+
+         # Create ChemProp datasets for this fold
+         train_datapoints, _ = create_molecule_datapoints(
+             df_train[smiles_column].tolist(), train_targets, train_extra
+         )
+         val_datapoints, _ = create_molecule_datapoints(
+             df_val[smiles_column].tolist(), val_targets, val_extra
+         )
+
+         train_dataset = data.MoleculeDataset(train_datapoints)
+         val_dataset = data.MoleculeDataset(val_datapoints)
+
+         # Save raw val features for prediction
+         val_extra_raw = val_extra.copy() if val_extra is not None else None
+
+         # Scale features and targets for this fold
+         x_d_transform = None
+         if use_extra_features:
+             feature_scaler = train_dataset.normalize_inputs("X_d")
+             val_dataset.normalize_inputs("X_d", feature_scaler)
+             x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
+
+         output_transform = None
+         if model_type in ["regressor", "uq_regressor"]:
+             target_scaler = train_dataset.normalize_targets()
+             val_dataset.normalize_targets(target_scaler)
+             output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
+
+         train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
+         val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
+
+         # Build and train model for this fold
+         pl.seed_everything(42 + fold_idx)
+         mpnn = build_mpnn_model(
+             hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
+             n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
+             task_weights=task_weights,
+         )
+
+         callbacks = [
+             pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
+             pl.callbacks.ModelCheckpoint(
+                 dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
+                 monitor="val_loss", mode="min", save_top_k=1,
+             ),
+         ]
+
+         trainer = pl.Trainer(
+             accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
+             logger=False, enable_progress_bar=True,
+         )
+
+         trainer.fit(mpnn, train_loader, val_loader)
+
+         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
+             checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
+             mpnn.load_state_dict(checkpoint["state_dict"])
+
+         mpnn.eval()
+         ensemble_models.append(mpnn)
+
+         # Make out-of-fold predictions using raw features
+         val_datapoints_raw, _ = create_molecule_datapoints(
+             df_val[smiles_column].tolist(), val_targets, val_extra_raw
+         )
+         val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
+         val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
+
+         with torch.inference_mode():
+             fold_predictions = trainer.predict(mpnn, val_loader_pred)
+         fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
+         if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
+             fold_preds = fold_preds.squeeze(axis=1)
+
+         # Store out-of-fold predictions
+         if model_type == "classifier" and fold_preds.ndim == 2:
+             # Store class index in first column for classification
+             oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
+             if oof_proba is not None:
+                 oof_proba[val_idx] = fold_preds
+         else:
+             # Regression: fold_preds shape is (n_val, n_targets) or (n_val,)
+             if fold_preds.ndim == 1:
+                 fold_preds = fold_preds.reshape(-1, 1)
+             oof_predictions[val_idx] = fold_preds
+
+         print(f"Fold {fold_idx + 1} complete!")
+
+     print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
+
+     # Use out-of-fold predictions for metrics
+     # For n_folds=1, we only have predictions for val_idx, so filter to those rows
+     if n_folds == 1:
+         # oof_predictions is always 2D now: check if any column has a value
+         val_mask = ~np.isnan(oof_predictions).all(axis=1)
+         preds = oof_predictions[val_mask]
+         df_val = all_df[val_mask].copy()
+         y_validate = all_targets[val_mask]
+         if oof_proba is not None:
+             oof_proba = oof_proba[val_mask]
+         val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
+     else:
+         preds = oof_predictions
+         df_val = all_df.copy()
+         y_validate = all_targets
+         val_extra_features = all_extra_features
+
+     # Compute prediction_std by running all ensemble models on validation data
+     # For n_folds=1, std will be 0 (only one model). For n_folds>1, std shows ensemble disagreement.
+     preds_std = None
+     if model_type in ["regressor", "uq_regressor"] and len(ensemble_models) > 0:
+         print("Computing prediction_std from ensemble predictions on validation data...")
+         val_datapoints_for_std, _ = create_molecule_datapoints(
+             df_val[smiles_column].tolist(),
+             y_validate,
+             val_extra_features
+         )
+         val_dataset_for_std = data.MoleculeDataset(val_datapoints_for_std)
+         val_loader_for_std = data.build_dataloader(val_dataset_for_std, batch_size=batch_size, shuffle=False)
+
+         all_ensemble_preds_for_std = []
+         trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+         for ens_model in ensemble_models:
+             with torch.inference_mode():
+                 ens_preds = trainer_pred.predict(ens_model, val_loader_for_std)
+             ens_preds = np.concatenate([p.numpy() for p in ens_preds], axis=0)
+             if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+                 ens_preds = ens_preds.squeeze(axis=1)
+             all_ensemble_preds_for_std.append(ens_preds)
+
+         # Stack ensemble predictions: shape (n_ensemble, n_samples, n_targets)
+         ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
+         preds_std = np.std(ensemble_preds_stacked, axis=0)
+         # Ensure 2D
+         if preds_std.ndim == 1:
+             preds_std = preds_std.reshape(-1, 1)
+         print(f"Ensemble prediction_std - mean per target: {np.nanmean(preds_std, axis=0)}")
+
+     if model_type == "classifier":
+         # Classification metrics - preds contains class indices in first column from OOF predictions
+         class_preds = preds[:, 0].astype(int)
+         has_proba = oof_proba is not None
+
+         print(f"class_preds shape: {class_preds.shape}")
+
+         # Decode labels for metrics (classification is single-target only)
+         target_name = target_columns[0]
+         y_validate_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
+         preds_decoded = label_encoder.inverse_transform(class_preds)
+
+         # Calculate metrics
+         label_names = label_encoder.classes_
+         scores = precision_recall_fscore_support(
+             y_validate_decoded, preds_decoded, average=None, labels=label_names
+         )
+
+         score_df = pd.DataFrame(
+             {
+                 target_name: label_names,
+                 "precision": scores[0],
+                 "recall": scores[1],
+                 "f1": scores[2],
+                 "support": scores[3],
+             }
+         )
+
+         # Output metrics per class
+         metrics = ["precision", "recall", "f1", "support"]
+         for t in label_names:
+             for m in metrics:
+                 value = score_df.loc[score_df[target_name] == t, m].iloc[0]
+                 print(f"Metrics:{t}:{m} {value}")
+
+         # Confusion matrix
+         conf_mtx = confusion_matrix(
+             y_validate_decoded, preds_decoded, labels=label_names
+         )
+         for i, row_name in enumerate(label_names):
+             for j, col_name in enumerate(label_names):
+                 value = conf_mtx[i, j]
+                 print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
+
+         # Save validation predictions
+         df_val = df_val.copy()
+         df_val["prediction"] = preds_decoded
+         if has_proba and oof_proba is not None:
+             df_val["pred_proba"] = [p.tolist() for p in oof_proba]
+             df_val = expand_proba_column(df_val, label_names)
+
+     else:
+         # Regression metrics: compute per target (works for single or multi-task)
+         df_val = df_val.copy()
+         print("\n--- Per-target metrics ---")
+         for t_idx, t_name in enumerate(target_columns):
+             # Get valid (non-NaN) indices for this target
+             target_valid_mask = ~np.isnan(y_validate[:, t_idx])
+             y_true = y_validate[target_valid_mask, t_idx]
+             y_pred = preds[target_valid_mask, t_idx]
+
+             if len(y_true) > 0:
+                 rmse = root_mean_squared_error(y_true, y_pred)
+                 mae = mean_absolute_error(y_true, y_pred)
+                 medae = median_absolute_error(y_true, y_pred)
+                 r2 = r2_score(y_true, y_pred)
+                 spearman_corr = spearmanr(y_true, y_pred).correlation
+                 support = len(y_true)
+                 # Print metrics in format expected by SageMaker metric definitions
+                 print(f"rmse: {rmse:.3f}")
+                 print(f"mae: {mae:.3f}")
+                 print(f"medae: {medae:.3f}")
+                 print(f"r2: {r2:.3f}")
+                 print(f"spearmanr: {spearman_corr:.3f}")
+                 print(f"support: {support}")
+
+             # Store predictions in dataframe
+             df_val[f"{t_name}_pred"] = preds[:, t_idx]
+             if preds_std is not None:
+                 df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx]
+             else:
+                 df_val[f"{t_name}_pred_std"] = 0.0
+
+     # Save validation predictions to S3
+     # Include id_column if it exists in df_val
+     output_columns = []
+     if id_column in df_val.columns:
+         output_columns.append(id_column)
+     # Include all target columns and their predictions
+     output_columns += target_columns
+     output_columns += [f"{t}_pred" for t in target_columns]
+     output_columns += [f"{t}_pred_std" for t in target_columns]
+     # Add proba columns for classifiers
+     output_columns += [col for col in df_val.columns if col.endswith("_proba")]
+     # Filter to only columns that exist
+     output_columns = [c for c in output_columns if c in df_val.columns]
+     wr.s3.to_csv(
+         df_val[output_columns],
+         path=f"{model_metrics_s3_path}/validation_predictions.csv",
+         index=False,
+     )
+
+     # Save ensemble models (n_folds models if CV, 1 model otherwise)
+     for model_idx, ens_model in enumerate(ensemble_models):
+         model_path = os.path.join(args.model_dir, f"chemprop_model_{model_idx}.pt")
+         models.save_model(model_path, ens_model)
+         print(f"Saved model {model_idx + 1} to {model_path}")
+
+     # Save ensemble metadata (n_ensemble = number of models for inference)
+     n_ensemble = len(ensemble_models)
+     ensemble_metadata = {
+         "n_ensemble": n_ensemble,
+         "n_folds": n_folds,
+         "target_columns": target_columns,
+     }
+     joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
+     print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds}, targets={target_columns})")
+
+     # Save label encoder if classification
+     if label_encoder is not None:
+         joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
+
+     # Save extra feature metadata for inference (hybrid mode)
+     # Note: We don't need to save the scaler - X_d_transform is embedded in the model
+     if use_extra_features:
+         feature_metadata = {
+             "extra_feature_cols": extra_feature_cols,
+             "col_means": col_means.tolist(),  # Unscaled means for NaN imputation
+         }
+         joblib.dump(
+             feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
+         )
+         print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")