workbench 0.8.197__py3-none-any.whl → 0.8.201__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. workbench/algorithms/dataframe/proximity.py +19 -12
  2. workbench/api/__init__.py +2 -1
  3. workbench/api/feature_set.py +7 -4
  4. workbench/api/model.py +1 -1
  5. workbench/core/artifacts/__init__.py +11 -2
  6. workbench/core/artifacts/endpoint_core.py +84 -46
  7. workbench/core/artifacts/feature_set_core.py +69 -1
  8. workbench/core/artifacts/model_core.py +37 -7
  9. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  10. workbench/core/transforms/features_to_model/features_to_model.py +23 -20
  11. workbench/core/views/view.py +2 -2
  12. workbench/model_scripts/chemprop/chemprop.template +931 -0
  13. workbench/model_scripts/chemprop/generated_model_script.py +931 -0
  14. workbench/model_scripts/chemprop/requirements.txt +11 -0
  15. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  16. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  17. workbench/model_scripts/custom_models/proximity/proximity.py +19 -12
  18. workbench/model_scripts/custom_models/uq_models/proximity.py +19 -12
  19. workbench/model_scripts/pytorch_model/generated_model_script.py +130 -88
  20. workbench/model_scripts/pytorch_model/pytorch.template +128 -86
  21. workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
  22. workbench/model_scripts/script_generation.py +10 -7
  23. workbench/model_scripts/uq_models/generated_model_script.py +25 -18
  24. workbench/model_scripts/uq_models/mapie.template +23 -16
  25. workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
  26. workbench/model_scripts/xgb_model/xgb_model.template +2 -2
  27. workbench/repl/workbench_shell.py +14 -5
  28. workbench/scripts/endpoint_test.py +162 -0
  29. workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
  30. workbench/utils/chemprop_utils.py +724 -0
  31. workbench/utils/pytorch_utils.py +497 -0
  32. workbench/utils/xgboost_model_utils.py +12 -5
  33. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/METADATA +2 -2
  34. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/RECORD +38 -30
  35. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/entry_points.txt +2 -1
  36. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/WHEEL +0 -0
  37. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/licenses/LICENSE +0 -0
  38. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,931 @@
1
+ # ChemProp Model Template for Workbench
2
+ # Uses ChemProp 2.x Message Passing Neural Networks for molecular property prediction
3
+ #
4
+ # === CHEMPROP REVIEW NOTES ===
5
+ # This script runs on AWS SageMaker. Key areas for ChemProp review:
6
+ #
7
+ # 1. Model Architecture (build_mpnn_model function)
8
+ # - BondMessagePassing, NormAggregation, FFN configuration
9
+ # - Regression uses output_transform (UnscaleTransform) for target scaling
10
+ #
11
+ # 2. Data Handling (create_molecule_datapoints function)
12
+ # - MoleculeDatapoint creation with x_d (extra descriptors)
13
+ # - RDKit validation of SMILES
14
+ #
15
+ # 3. Scaling (training section)
16
+ # - Extra descriptors: normalize_inputs("X_d") + X_d_transform in model
17
+ # - Targets (regression): normalize_targets() + UnscaleTransform in FFN
18
+ # - At inference: pass RAW features, transforms handle scaling automatically
19
+ #
20
+ # 4. Training Loop (search for "pl.Trainer")
21
+ # - PyTorch Lightning Trainer with ChemProp MPNN
22
+ #
23
+ # AWS/SageMaker boilerplate (can skip):
24
+ # - input_fn, output_fn, model_fn: SageMaker serving interface
25
+ # - argparse, file loading, S3 writes
26
+ # =============================
27
+
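A condensed sketch of the scaling contract called out in note 3 above. This is illustrative only (toy SMILES and values), but it uses the same ChemProp 2.x calls that appear later in this script:

import numpy as np
from chemprop import data, nn

smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
targets = [0.3, 1.7, 0.9]
extra = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]], dtype=np.float32)

# Training side: datapoints carry targets (y) and extra descriptors (x_d)
dps = [data.MoleculeDatapoint.from_smi(s, y=[t], x_d=x) for s, t, x in zip(smiles, targets, extra)]
train_ds = data.MoleculeDataset(dps)

feature_scaler = train_ds.normalize_inputs("X_d")   # scale extra descriptors in place
target_scaler = train_ds.normalize_targets()        # scale regression targets in place

# The scalers are baked into the model as transforms, so inference takes RAW inputs
x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
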
28
+ import os
29
+ import argparse
30
+ import json
31
+ from io import StringIO
32
+
33
+ import awswrangler as wr
34
+ import numpy as np
35
+ import pandas as pd
36
+ import torch
37
+ from lightning import pytorch as pl
38
+ from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
39
+ from sklearn.preprocessing import LabelEncoder
40
+ from sklearn.metrics import (
41
+ mean_absolute_error,
42
+ r2_score,
43
+ root_mean_squared_error,
44
+ precision_recall_fscore_support,
45
+ confusion_matrix,
46
+ )
47
+ import joblib
48
+
49
+ # ChemProp imports
50
+ from chemprop import data, models, nn
51
+
52
+ # Template Parameters
53
+ TEMPLATE_PARAMS = {
54
+ "model_type": "{{model_type}}",
55
+ "target": "{{target_column}}",
56
+ "feature_list": "{{feature_list}}",
57
+ "model_metrics_s3_path": "{{model_metrics_s3_path}}",
58
+ "train_all_data": "{{train_all_data}}",
59
+ "hyperparameters": "{{hyperparameters}}",
60
+ }
61
+
62
+
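Script generation fills the {{...}} placeholders before training, presumably as Python literals rather than strings, since the code below calls hyperparameters.get() and iterates feature_list. A hypothetical rendered example:

TEMPLATE_PARAMS = {
    "model_type": "regressor",                      # or "classifier"
    "target": "solubility",                         # hypothetical target column
    "feature_list": ["smiles", "molwt", "tpsa"],    # SMILES column plus optional extra features
    "model_metrics_s3_path": "s3://my-bucket/models/solubility",
    "train_all_data": False,
    "hyperparameters": {"max_epochs": 50, "n_folds": 5},
}
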
63
+ def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
64
+ """Check if the provided dataframe is empty and raise an exception if it is."""
65
+ if df.empty:
66
+ msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
67
+ print(msg)
68
+ raise ValueError(msg)
69
+
70
+
71
+ def find_smiles_column(columns: list[str]) -> str:
72
+ """Find the SMILES column name from a list (case-insensitive match for 'smiles')."""
73
+ smiles_column = next((col for col in columns if col.lower() == "smiles"), None)
74
+ if smiles_column is None:
75
+ raise ValueError(
76
+ "Column list must contain a 'smiles' column (case-insensitive)"
77
+ )
78
+ return smiles_column
79
+
80
+
81
+ def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
82
+ """Expands a column containing a list of probabilities into separate columns.
83
+
84
+ Handles None values for rows where predictions couldn't be made.
85
+ """
86
+ proba_column = "pred_proba"
87
+ if proba_column not in df.columns:
88
+ raise ValueError('DataFrame does not contain a "pred_proba" column')
89
+
90
+ proba_splits = [f"{label}_proba" for label in class_labels]
91
+ n_classes = len(class_labels)
92
+
93
+ # Handle None values by replacing with list of NaNs
94
+ proba_values = []
95
+ for val in df[proba_column]:
96
+ if val is None:
97
+ proba_values.append([np.nan] * n_classes)
98
+ else:
99
+ proba_values.append(val)
100
+
101
+ proba_df = pd.DataFrame(proba_values, columns=proba_splits)
102
+
103
+ df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
104
+ df = df.reset_index(drop=True)
105
+ df = pd.concat([df, proba_df], axis=1)
106
+ return df
107
+
108
+
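A toy illustration of expand_proba_column() (assumes the function above is in scope; values are made up):

import pandas as pd

df = pd.DataFrame({
    "prediction": ["active", "inactive", None],
    "pred_proba": [[0.8, 0.2], [0.1, 0.9], None],   # None = no prediction for that row
})
df = expand_proba_column(df, ["active", "inactive"])
print(df.columns.tolist())   # ['prediction', 'active_proba', 'inactive_proba']
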
109
+ def create_molecule_datapoints(
110
+ smiles_list: list[str],
111
+ targets: list[float] | None = None,
112
+ extra_descriptors: np.ndarray | None = None,
113
+ ) -> tuple[list[data.MoleculeDatapoint], list[int]]:
114
+ """Create ChemProp MoleculeDatapoints from SMILES strings.
115
+
116
+ Args:
117
+ smiles_list: List of SMILES strings
118
+ targets: Optional list of target values (for training)
119
+ extra_descriptors: Optional array of extra features (n_samples, n_features)
120
+
121
+ Returns:
122
+ Tuple of (list of MoleculeDatapoint objects, list of valid indices)
123
+ """
124
+ from rdkit import Chem
125
+
126
+ datapoints = []
127
+ valid_indices = []
128
+ invalid_count = 0
129
+
130
+ for i, smi in enumerate(smiles_list):
131
+ # Validate SMILES with RDKit first
132
+ mol = Chem.MolFromSmiles(smi)
133
+ if mol is None:
134
+ invalid_count += 1
135
+ continue
136
+
137
+ # Build datapoint with optional target and extra descriptors
138
+ y = [targets[i]] if targets is not None else None
139
+ x_d = extra_descriptors[i] if extra_descriptors is not None else None
140
+
141
+ dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
142
+ datapoints.append(dp)
143
+ valid_indices.append(i)
144
+
145
+ if invalid_count > 0:
146
+ print(f"Warning: Skipped {invalid_count} invalid SMILES strings")
147
+
148
+ return datapoints, valid_indices
149
+
150
+
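For example (toy inputs; assumes the function above is in scope), invalid SMILES are dropped and valid_indices maps the surviving datapoints back to the original rows:

smiles = ["CCO", "not-a-smiles", "c1ccccc1"]
targets = [0.5, 1.0, 1.5]
datapoints, valid_indices = create_molecule_datapoints(smiles, targets)
print(len(datapoints), valid_indices)   # 2 [0, 2]  (the warning notes 1 skipped SMILES)
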
151
+ def build_mpnn_model(
152
+ hyperparameters: dict,
153
+ task: str = "regression",
154
+ num_classes: int | None = None,
155
+ n_extra_descriptors: int = 0,
156
+ x_d_transform: nn.ScaleTransform | None = None,
157
+ output_transform: nn.UnscaleTransform | None = None,
158
+ ) -> models.MPNN:
159
+ """Build an MPNN model with the specified hyperparameters.
160
+
161
+ Args:
162
+ hyperparameters: Dictionary of model hyperparameters
163
+ task: Either "regression" or "classification"
164
+ num_classes: Number of classes for classification tasks
165
+ n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
166
+ x_d_transform: Optional transform for extra descriptors (scaling)
167
+ output_transform: Optional transform for regression output (unscaling targets)
168
+
169
+ Returns:
170
+ Configured MPNN model
171
+ """
172
+ # Model hyperparameters with defaults (based on OpenADMET baseline with slight improvements)
173
+ hidden_dim = hyperparameters.get("hidden_dim", 300)
174
+ depth = hyperparameters.get("depth", 4)
175
+ dropout = hyperparameters.get("dropout", 0.10)
176
+ ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
177
+ ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
178
+
179
+ # Message passing component
180
+ mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
181
+
182
+ # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
183
+ agg = nn.NormAggregation()
184
+
185
+ # FFN input_dim = message passing output + extra descriptors
186
+ ffn_input_dim = hidden_dim + n_extra_descriptors
187
+
188
+ # Build FFN based on task type
189
+ if task == "classification" and num_classes is not None:
190
+ # Multi-class classification
191
+ ffn = nn.MulticlassClassificationFFN(
192
+ n_classes=num_classes,
193
+ input_dim=ffn_input_dim,
194
+ hidden_dim=ffn_hidden_dim,
195
+ n_layers=ffn_num_layers,
196
+ dropout=dropout,
197
+ )
198
+ else:
199
+ # Regression with optional output transform to unscale predictions
200
+ ffn = nn.RegressionFFN(
201
+ input_dim=ffn_input_dim,
202
+ hidden_dim=ffn_hidden_dim,
203
+ n_layers=ffn_num_layers,
204
+ dropout=dropout,
205
+ output_transform=output_transform,
206
+ )
207
+
208
+ # Create the MPNN model
209
+ mpnn = models.MPNN(
210
+ message_passing=mp,
211
+ agg=agg,
212
+ predictor=ffn,
213
+ batch_norm=True,
214
+ metrics=None,
215
+ X_d_transform=x_d_transform,
216
+ )
217
+
218
+ return mpnn
219
+
220
+
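An illustrative call with hypothetical hyperparameter values; in real use the two transforms come from the training scalers rather than None:

hp = {"hidden_dim": 300, "depth": 4, "dropout": 0.10, "ffn_hidden_dim": 300, "ffn_num_layers": 2}
mpnn = build_mpnn_model(
    hp,
    task="regression",
    n_extra_descriptors=2,      # hybrid mode: FFN input grows by 2
    x_d_transform=None,         # normally the ScaleTransform built from the X_d scaler
    output_transform=None,      # normally the UnscaleTransform built from the target scaler
)
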
221
+ def model_fn(model_dir: str) -> dict:
222
+ """Load the ChemProp MPNN ensemble models from the specified directory.
223
+
224
+ Args:
225
+ model_dir: Directory containing the saved models
226
+
227
+ Returns:
228
+ Dictionary with ensemble models and metadata
229
+ """
230
+ # Load ensemble metadata
231
+ ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
232
+ if os.path.exists(ensemble_metadata_path):
233
+ ensemble_metadata = joblib.load(ensemble_metadata_path)
234
+ n_ensemble = ensemble_metadata["n_ensemble"]
235
+ else:
236
+ # Backwards compatibility: single model without ensemble metadata
237
+ n_ensemble = 1
238
+
239
+ # Load all ensemble models
240
+ ensemble_models = []
241
+ for ens_idx in range(n_ensemble):
242
+ model_path = os.path.join(model_dir, f"chemprop_model_{ens_idx}.pt")
243
+ if not os.path.exists(model_path):
244
+ # Backwards compatibility: try old single model path
245
+ model_path = os.path.join(model_dir, "chemprop_model.pt")
246
+ model = models.MPNN.load_from_file(model_path)
247
+ model.eval()
248
+ ensemble_models.append(model)
249
+
250
+ print(f"Loaded {len(ensemble_models)} ensemble model(s)")
251
+
252
+ return {
253
+ "ensemble_models": ensemble_models,
254
+ "n_ensemble": n_ensemble,
255
+ }
256
+
257
+
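model_fn() reads back exactly what the training section at the bottom of this script writes; the expected model directory layout (file names taken from this script) and a sample call:

# /opt/ml/model/
#   chemprop_model_0.pt ... chemprop_model_{n-1}.pt   # one MPNN per ensemble member (or fold)
#   ensemble_metadata.joblib                          # {"n_ensemble": n, "n_folds": k}
#   label_encoder.joblib                              # classification only
#   feature_metadata.joblib                           # hybrid mode only (extra feature columns + means)
model_dict = model_fn("/opt/ml/model")
print(model_dict["n_ensemble"], len(model_dict["ensemble_models"]))
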
258
+ def input_fn(input_data, content_type: str) -> pd.DataFrame:
259
+ """Parse input data and return a DataFrame."""
260
+ if not input_data:
261
+ raise ValueError("Empty input data is not supported!")
262
+
263
+ if isinstance(input_data, bytes):
264
+ input_data = input_data.decode("utf-8")
265
+
266
+ if "text/csv" in content_type:
267
+ return pd.read_csv(StringIO(input_data))
268
+ elif "application/json" in content_type:
269
+ return pd.DataFrame(json.loads(input_data))
270
+ else:
271
+ raise ValueError(f"{content_type} not supported!")
272
+
273
+
274
+ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
275
+ """Supports both CSV and JSON output formats."""
276
+ if "text/csv" in accept_type:
277
+ csv_output = output_df.fillna("N/A").to_csv(index=False)
278
+ return csv_output, "text/csv"
279
+ elif "application/json" in accept_type:
280
+ return output_df.to_json(orient="records"), "application/json"
281
+ else:
282
+ raise RuntimeError(
283
+ f"{accept_type} accept type is not supported by this script."
284
+ )
285
+
286
+
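The two I/O hooks can be exercised on their own; for example (illustrative payload):

payload = b"smiles,molwt\nCCO,46.07\nc1ccccc1,78.11\n"
df_in = input_fn(payload, "text/csv")        # bytes are decoded and parsed into a DataFrame
body, content_type = output_fn(df_in, "application/json")
print(content_type)                          # application/json
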
287
+ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
288
+ """Make predictions with the ChemProp MPNN ensemble.
289
+
290
+ Args:
291
+ df: Input DataFrame containing SMILES column (and extra features if hybrid mode)
292
+ model_dict: Dictionary containing ensemble models and metadata
293
+
294
+ Returns:
295
+ DataFrame with predictions added (and prediction_std for ensembles)
296
+ """
297
+ model_type = TEMPLATE_PARAMS["model_type"]
298
+ model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
299
+
300
+ # Extract ensemble models
301
+ ensemble_models = model_dict["ensemble_models"]
302
+ n_ensemble = model_dict["n_ensemble"]
303
+
304
+ # Load label encoder if present (classification)
305
+ label_encoder = None
306
+ label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
307
+ if os.path.exists(label_encoder_path):
308
+ label_encoder = joblib.load(label_encoder_path)
309
+
310
+ # Load feature metadata if present (hybrid mode)
311
+ # Contains column names, NaN fill values, and scaler for feature scaling
312
+ feature_metadata = None
313
+ feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
314
+ if os.path.exists(feature_metadata_path):
315
+ feature_metadata = joblib.load(feature_metadata_path)
316
+ print(
317
+ f"Hybrid mode: using {len(feature_metadata['extra_feature_cols'])} extra features"
318
+ )
319
+
320
+ # Find SMILES column in input DataFrame
321
+ smiles_column = find_smiles_column(df.columns.tolist())
322
+
323
+ smiles_list = df[smiles_column].tolist()
324
+
325
+ # Track invalid SMILES
326
+ valid_mask = []
327
+ valid_smiles = []
328
+ valid_indices = []
329
+ for i, smi in enumerate(smiles_list):
330
+ if smi and isinstance(smi, str) and len(smi.strip()) > 0:
331
+ valid_mask.append(True)
332
+ valid_smiles.append(smi.strip())
333
+ valid_indices.append(i)
334
+ else:
335
+ valid_mask.append(False)
336
+
337
+ valid_mask = np.array(valid_mask)
338
+ print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")
339
+
340
+ # Initialize prediction column (use object dtype for classifiers to avoid FutureWarning)
341
+ if model_type == "classifier":
342
+ df["prediction"] = pd.Series([None] * len(df), dtype=object)
343
+ else:
344
+ df["prediction"] = np.nan
345
+ if n_ensemble > 1:
346
+ df["prediction_std"] = np.nan
347
+
348
+ if sum(valid_mask) == 0:
349
+ print("Warning: No valid SMILES to predict on")
350
+ return df
351
+
352
+ # Prepare extra features if in hybrid mode
353
+ # NOTE: We pass RAW (unscaled) features here - the model's X_d_transform handles scaling
354
+ extra_features = None
355
+ if feature_metadata is not None:
356
+ extra_feature_cols = feature_metadata["extra_feature_cols"]
357
+ col_means = np.array(feature_metadata["col_means"])
358
+
359
+ # Check columns exist
360
+ missing_cols = [col for col in extra_feature_cols if col not in df.columns]
361
+ if missing_cols:
362
+ print(
363
+ f"Warning: Missing extra feature columns: {missing_cols}. Using mean values."
364
+ )
365
+
366
+ # Extract features for valid SMILES rows (raw, unscaled)
367
+ extra_features = np.zeros(
368
+ (len(valid_indices), len(extra_feature_cols)), dtype=np.float32
369
+ )
370
+ for j, col in enumerate(extra_feature_cols):
371
+ if col in df.columns:
372
+ values = df.iloc[valid_indices][col].values.astype(np.float32)
373
+ # Fill NaN with training column means (unscaled means)
374
+ nan_mask = np.isnan(values)
375
+ values[nan_mask] = col_means[j]
376
+ extra_features[:, j] = values
377
+ else:
378
+ # Column missing, use training mean
379
+ extra_features[:, j] = col_means[j]
380
+
381
+ # Create datapoints for prediction (filter out invalid SMILES)
382
+ datapoints, rdkit_valid_indices = create_molecule_datapoints(
383
+ valid_smiles, extra_descriptors=extra_features
384
+ )
385
+
386
+ if len(datapoints) == 0:
387
+ print("Warning: No valid SMILES after RDKit validation")
388
+ return df
389
+
390
+ dataset = data.MoleculeDataset(datapoints)
391
+ dataloader = data.build_dataloader(dataset, shuffle=False)
392
+
393
+ # Make predictions with ensemble
394
+ trainer = pl.Trainer(
395
+ accelerator="auto",
396
+ logger=False,
397
+ enable_progress_bar=False,
398
+ )
399
+
400
+ # Collect predictions from all ensemble members
401
+ all_ensemble_preds = []
402
+ for ens_idx, ens_model in enumerate(ensemble_models):
403
+ with torch.inference_mode():
404
+ predictions = trainer.predict(ens_model, dataloader)
405
+ ens_preds = np.concatenate([p.numpy() for p in predictions], axis=0)
406
+ # Squeeze middle dim if present
407
+ if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
408
+ ens_preds = ens_preds.squeeze(axis=1)
409
+ all_ensemble_preds.append(ens_preds)
410
+
411
+ # Stack and compute mean/std
412
+ ensemble_preds = np.stack(all_ensemble_preds, axis=0)
413
+ preds = np.mean(ensemble_preds, axis=0)
414
+ preds_std = np.std(ensemble_preds, axis=0) if n_ensemble > 1 else None
415
+
416
+ print(f"Inference: Ensemble predictions shape: {preds.shape}")
417
+
418
+ # Map predictions back to valid_mask positions (accounting for RDKit-invalid SMILES)
419
+ # rdkit_valid_indices tells us which of the valid_smiles were actually valid
420
+ valid_positions = np.where(valid_mask)[0][rdkit_valid_indices]
421
+ valid_mask = np.zeros(len(df), dtype=bool)
422
+ valid_mask[valid_positions] = True
423
+
424
+ if model_type == "classifier" and label_encoder is not None:
425
+ # For classification, get class predictions and probabilities
426
+ if preds.ndim == 2 and preds.shape[1] > 1:
427
+ # Multi-class: preds are probabilities (averaged across ensemble)
428
+ class_preds = np.argmax(preds, axis=1)
429
+ decoded_preds = label_encoder.inverse_transform(class_preds)
430
+ df.loc[valid_mask, "prediction"] = decoded_preds
431
+
432
+ # Add probability columns
433
+ proba_series = pd.Series([None] * len(df), index=df.index, dtype=object)
434
+ proba_series.loc[valid_mask] = [p.tolist() for p in preds]
435
+ df["pred_proba"] = proba_series
436
+ df = expand_proba_column(df, label_encoder.classes_)
437
+ else:
438
+ # Binary or single output
439
+ class_preds = (preds.flatten() > 0.5).astype(int)
440
+ decoded_preds = label_encoder.inverse_transform(class_preds)
441
+ df.loc[valid_mask, "prediction"] = decoded_preds
442
+ else:
443
+ # Regression: direct predictions
444
+ df.loc[valid_mask, "prediction"] = preds.flatten()
445
+
446
+ # Add prediction_std for ensemble models
447
+ if preds_std is not None:
448
+ df.loc[valid_mask, "prediction_std"] = preds_std.flatten()
449
+
450
+ return df
451
+
452
+
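Putting the serving pieces together for a local smoke test (the model directory is hypothetical and must already contain the artifacts listed under model_fn above):

import pandas as pd

model_dict = model_fn("/opt/ml/model")                 # hypothetical local artifact directory
df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1", "bad_smiles"]})
df = predict_fn(df, model_dict)
print(df[["smiles", "prediction"]])                    # the invalid row keeps prediction = NaN/None
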
453
+ if __name__ == "__main__":
454
+ """Training script for ChemProp MPNN model"""
455
+
456
+ # Template Parameters
457
+ target = TEMPLATE_PARAMS["target"]
458
+ model_type = TEMPLATE_PARAMS["model_type"]
459
+ feature_list = TEMPLATE_PARAMS["feature_list"]
460
+ model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
461
+ train_all_data = TEMPLATE_PARAMS["train_all_data"]
462
+ hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
463
+ validation_split = 0.2
464
+
465
+ # Get the SMILES column name from feature_list (user defines this, so we use their exact name)
466
+ smiles_column = find_smiles_column(feature_list)
467
+ extra_feature_cols = [f for f in feature_list if f != smiles_column]
468
+ use_extra_features = len(extra_feature_cols) > 0
469
+ print(f"Feature List: {feature_list}")
470
+ print(f"SMILES Column: {smiles_column}")
471
+ print(
472
+ f"Extra Features (hybrid mode): {extra_feature_cols if use_extra_features else 'None (SMILES only)'}"
473
+ )
474
+
475
+ # Script arguments for input/output directories
476
+ parser = argparse.ArgumentParser()
477
+ parser.add_argument(
478
+ "--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
479
+ )
480
+ parser.add_argument(
481
+ "--train",
482
+ type=str,
483
+ default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"),
484
+ )
485
+ parser.add_argument(
486
+ "--output-data-dir",
487
+ type=str,
488
+ default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"),
489
+ )
490
+ args = parser.parse_args()
491
+
492
+ # Read the training data
493
+ training_files = [
494
+ os.path.join(args.train, f)
495
+ for f in os.listdir(args.train)
496
+ if f.endswith(".csv")
497
+ ]
498
+ print(f"Training Files: {training_files}")
499
+
500
+ all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
501
+ print(f"All Data Shape: {all_df.shape}")
502
+
503
+ check_dataframe(all_df, "training_df")
504
+
505
+ # Drop rows with missing SMILES or target values
506
+ initial_count = len(all_df)
507
+ all_df = all_df.dropna(subset=[smiles_column, target])
508
+ dropped = initial_count - len(all_df)
509
+ if dropped > 0:
510
+ print(f"Dropped {dropped} rows with missing SMILES or target values")
511
+
512
+ print(f"Target: {target}")
513
+ print(f"Data Shape after cleaning: {all_df.shape}")
514
+
515
+ # Set up label encoder for classification
516
+ label_encoder = None
517
+ if model_type == "classifier":
518
+ label_encoder = LabelEncoder()
519
+ all_df[target] = label_encoder.fit_transform(all_df[target])
520
+ num_classes = len(label_encoder.classes_)
521
+ print(
522
+ f"Classification task with {num_classes} classes: {label_encoder.classes_}"
523
+ )
524
+ else:
525
+ num_classes = None
526
+
527
+ # Model and training configuration
528
+ print(f"Hyperparameters: {hyperparameters}")
529
+ task = "classification" if model_type == "classifier" else "regression"
530
+ n_extra = len(extra_feature_cols) if use_extra_features else 0
531
+ max_epochs = hyperparameters.get("max_epochs", 50)
532
+ patience = hyperparameters.get("patience", 10)
533
+ n_folds = hyperparameters.get("n_folds", 1) # Number of CV folds (default: 1 = no CV)
534
+ batch_size = hyperparameters.get("batch_size", min(64, max(16, len(all_df) // 16)))
535
+
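All of the keys this script reads from hyperparameters, with hypothetical values (each falls back to the default shown in the code when omitted):

hyperparameters = {
    "hidden_dim": 300, "depth": 4, "dropout": 0.10,      # message passing (build_mpnn_model)
    "ffn_hidden_dim": 300, "ffn_num_layers": 2,          # FFN head
    "max_epochs": 50, "patience": 10,                    # Lightning training loop
    "n_folds": 5,                                        # >1 switches to K-fold CV ensembling
    "batch_size": 64,
}
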
536
+ # Check extra feature columns exist
537
+ if use_extra_features:
538
+ missing_cols = [col for col in extra_feature_cols if col not in all_df.columns]
539
+ if missing_cols:
540
+ raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")
541
+
542
+ # =========================================================================
543
+ # SINGLE MODEL TRAINING (n_folds=1) - uses train/val split
544
+ # =========================================================================
545
+ if n_folds == 1:
546
+ print("Training single model (no cross-validation)...")
547
+
548
+ # Split data
549
+ if train_all_data:
550
+ print("Training on ALL of the data")
551
+ df_train = all_df.copy()
552
+ df_val = all_df.copy()
553
+ elif "training" in all_df.columns:
554
+ print("Found training column, splitting data based on training column")
555
+ df_train = all_df[all_df["training"]].copy()
556
+ df_val = all_df[~all_df["training"]].copy()
557
+ else:
558
+ print("WARNING: No training column found, splitting data with random state=42")
559
+ df_train, df_val = train_test_split(
560
+ all_df, test_size=validation_split, random_state=42
561
+ )
562
+
563
+ print(f"TRAIN: {df_train.shape}")
564
+ print(f"VALIDATION: {df_val.shape}")
565
+
566
+ # Extract and prepare extra features
567
+ train_extra_features = None
568
+ val_extra_features = None
569
+ col_means = None
570
+
571
+ if use_extra_features:
572
+ train_extra_features = df_train[extra_feature_cols].values.astype(np.float32)
573
+ val_extra_features = df_val[extra_feature_cols].values.astype(np.float32)
574
+ col_means = np.nanmean(train_extra_features, axis=0)
575
+ for i in range(train_extra_features.shape[1]):
576
+ train_extra_features[np.isnan(train_extra_features[:, i]), i] = col_means[i]
577
+ val_extra_features[np.isnan(val_extra_features[:, i]), i] = col_means[i]
578
+
579
+ # Create ChemProp datasets
580
+ train_datapoints, train_valid_idx = create_molecule_datapoints(
581
+ df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra_features
582
+ )
583
+ val_datapoints, val_valid_idx = create_molecule_datapoints(
584
+ df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_features
585
+ )
586
+
587
+ df_train = df_train.iloc[train_valid_idx].reset_index(drop=True)
588
+ df_val = df_val.iloc[val_valid_idx].reset_index(drop=True)
589
+
590
+ train_dataset = data.MoleculeDataset(train_datapoints)
591
+ val_dataset = data.MoleculeDataset(val_datapoints)
592
+
593
+ # Save raw validation features for predictions later
594
+ val_extra_raw = val_extra_features[val_valid_idx] if val_extra_features is not None else None
595
+
596
+ # Scale features and targets
597
+ x_d_transform = None
598
+ if use_extra_features:
599
+ feature_scaler = train_dataset.normalize_inputs("X_d")
600
+ val_dataset.normalize_inputs("X_d", feature_scaler)
601
+ x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
602
+
603
+ output_transform = None
604
+ if model_type == "regressor":
605
+ target_scaler = train_dataset.normalize_targets()
606
+ val_dataset.normalize_targets(target_scaler)
607
+ output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
608
+
609
+ train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
610
+ val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
611
+
612
+ # Build and train single model
613
+ pl.seed_everything(42)
614
+ mpnn = build_mpnn_model(
615
+ hyperparameters, task=task, num_classes=num_classes,
616
+ n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
617
+ )
618
+
619
+ callbacks = [
620
+ pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
621
+ pl.callbacks.ModelCheckpoint(
622
+ dirpath=args.model_dir, filename="best_model_0",
623
+ monitor="val_loss", mode="min", save_top_k=1,
624
+ ),
625
+ ]
626
+
627
+ trainer = pl.Trainer(
628
+ accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
629
+ logger=False, enable_progress_bar=True,
630
+ )
631
+
632
+ trainer.fit(mpnn, train_loader, val_loader)
633
+
634
+ if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
635
+ checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
636
+ mpnn.load_state_dict(checkpoint["state_dict"])
637
+
638
+ mpnn.eval()
639
+ ensemble_models = [mpnn]
640
+
641
+ # Make predictions on validation set
642
+ val_datapoints_raw, _ = create_molecule_datapoints(
643
+ df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_raw
644
+ )
645
+ val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
646
+ val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
647
+
648
+ with torch.inference_mode():
649
+ val_predictions = trainer.predict(mpnn, val_loader_pred)
650
+ preds = np.concatenate([p.numpy() for p in val_predictions], axis=0)
651
+ if preds.ndim == 3 and preds.shape[1] == 1:
652
+ preds = preds.squeeze(axis=1)
653
+
654
+ preds_std = None
655
+ y_validate = df_val[target].values
656
+
657
+ # =========================================================================
658
+ # K-FOLD CROSS-VALIDATION (n_folds > 1) - trains n_folds models
659
+ # =========================================================================
660
+ else:
661
+ print(f"Training {n_folds}-fold cross-validation ensemble...")
662
+
663
+ # Validate all SMILES upfront and filter invalid ones
664
+ all_extra_features = None
665
+ if use_extra_features:
666
+ all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
667
+ col_means = np.nanmean(all_extra_features, axis=0)
668
+ for i in range(all_extra_features.shape[1]):
669
+ all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
670
+ else:
671
+ col_means = None
672
+
673
+ # Filter invalid SMILES from the full dataset
674
+ _, valid_indices = create_molecule_datapoints(
675
+ all_df[smiles_column].tolist(), all_df[target].tolist(), all_extra_features
676
+ )
677
+ all_df = all_df.iloc[valid_indices].reset_index(drop=True)
678
+ if all_extra_features is not None:
679
+ all_extra_features = all_extra_features[valid_indices]
680
+ print(f"Data after SMILES validation: {all_df.shape}")
681
+
682
+ # Set up K-Fold
683
+ if model_type == "classifier":
684
+ kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
685
+ split_target = all_df[target]
686
+ else:
687
+ kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
688
+ split_target = None
689
+
690
+ # Initialize storage for out-of-fold predictions
691
+ oof_predictions = np.full(len(all_df), np.nan, dtype=np.float64)
692
+ if model_type == "classifier" and num_classes and num_classes > 1:
693
+ oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
694
+ else:
695
+ oof_proba = None
696
+
697
+ ensemble_models = []
698
+
699
+ for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(all_df, split_target)):
700
+ print(f"\n{'='*50}")
701
+ print(f"Training Fold {fold_idx + 1}/{n_folds}")
702
+ print(f"{'='*50}")
703
+
704
+ # Split data for this fold
705
+ df_train = all_df.iloc[train_idx].reset_index(drop=True)
706
+ df_val = all_df.iloc[val_idx].reset_index(drop=True)
707
+
708
+ train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
709
+ val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
710
+
711
+ print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
712
+
713
+ # Create ChemProp datasets for this fold
714
+ train_datapoints, _ = create_molecule_datapoints(
715
+ df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra
716
+ )
717
+ val_datapoints, _ = create_molecule_datapoints(
718
+ df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra
719
+ )
720
+
721
+ train_dataset = data.MoleculeDataset(train_datapoints)
722
+ val_dataset = data.MoleculeDataset(val_datapoints)
723
+
724
+ # Save raw val features for prediction
725
+ val_extra_raw = val_extra.copy() if val_extra is not None else None
726
+
727
+ # Scale features and targets for this fold
728
+ x_d_transform = None
729
+ if use_extra_features:
730
+ feature_scaler = train_dataset.normalize_inputs("X_d")
731
+ val_dataset.normalize_inputs("X_d", feature_scaler)
732
+ x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
733
+
734
+ output_transform = None
735
+ if model_type == "regressor":
736
+ target_scaler = train_dataset.normalize_targets()
737
+ val_dataset.normalize_targets(target_scaler)
738
+ output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
739
+
740
+ train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
741
+ val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
742
+
743
+ # Build and train model for this fold
744
+ pl.seed_everything(42 + fold_idx)
745
+ mpnn = build_mpnn_model(
746
+ hyperparameters, task=task, num_classes=num_classes,
747
+ n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
748
+ )
749
+
750
+ callbacks = [
751
+ pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
752
+ pl.callbacks.ModelCheckpoint(
753
+ dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
754
+ monitor="val_loss", mode="min", save_top_k=1,
755
+ ),
756
+ ]
757
+
758
+ trainer = pl.Trainer(
759
+ accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
760
+ logger=False, enable_progress_bar=True,
761
+ )
762
+
763
+ trainer.fit(mpnn, train_loader, val_loader)
764
+
765
+ if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
766
+ checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
767
+ mpnn.load_state_dict(checkpoint["state_dict"])
768
+
769
+ mpnn.eval()
770
+ ensemble_models.append(mpnn)
771
+
772
+ # Make out-of-fold predictions using raw features
773
+ val_datapoints_raw, _ = create_molecule_datapoints(
774
+ df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_raw
775
+ )
776
+ val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
777
+ val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
778
+
779
+ with torch.inference_mode():
780
+ fold_predictions = trainer.predict(mpnn, val_loader_pred)
781
+ fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
782
+ if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
783
+ fold_preds = fold_preds.squeeze(axis=1)
784
+
785
+ # Store out-of-fold predictions
786
+ if model_type == "classifier" and fold_preds.ndim == 2:
787
+ oof_predictions[val_idx] = np.argmax(fold_preds, axis=1)
788
+ if oof_proba is not None:
789
+ oof_proba[val_idx] = fold_preds
790
+ else:
791
+ oof_predictions[val_idx] = fold_preds.flatten()
792
+
793
+ print(f"Fold {fold_idx + 1} complete!")
794
+
795
+ print(f"\nCross-validation complete! Trained {len(ensemble_models)} models.")
796
+
797
+ # Use out-of-fold predictions for metrics
798
+ preds = oof_predictions
799
+ preds_std = None # Will compute from ensemble at inference time
800
+ y_validate = all_df[target].values
801
+ df_val = all_df # For saving predictions
802
+
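Because predictions are collected out-of-fold, every row ends up scored by the one model that never trained on it; a tiny stand-alone illustration of the mechanism:

import numpy as np
from sklearn.model_selection import KFold

oof = np.full(6, np.nan)
for fold_idx, (train_idx, val_idx) in enumerate(KFold(n_splits=3, shuffle=True, random_state=42).split(np.arange(6))):
    oof[val_idx] = fold_idx        # stand-in for that fold's predictions
print(oof)                         # no NaNs remain: each position filled exactly once
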
803
+ if model_type == "classifier":
804
+ # Classification metrics - handle multi-class output
805
+ # For CV mode, preds already contains class indices; for single model, preds are probabilities
806
+ if preds.ndim == 2 and preds.shape[1] > 1:
807
+ # Multi-class probabilities: (n_samples, n_classes), take argmax
808
+ class_preds = np.argmax(preds, axis=1)
809
+ has_proba = True
810
+ elif preds.ndim == 1:
811
+ # Either class indices (CV mode) or binary probabilities
812
+ if n_folds > 1:
813
+ # CV mode: preds are already class indices
814
+ class_preds = preds.astype(int)
815
+ has_proba = False
816
+ else:
817
+ # Single model: preds are probabilities
818
+ class_preds = (preds > 0.5).astype(int)
819
+ has_proba = False
820
+ else:
821
+ # Squeeze extra dimensions if needed
822
+ preds = preds.squeeze()
823
+ if preds.ndim == 2:
824
+ class_preds = np.argmax(preds, axis=1)
825
+ has_proba = True
826
+ else:
827
+ class_preds = (preds > 0.5).astype(int)
828
+ has_proba = False
829
+
830
+ print(f"class_preds shape: {class_preds.shape}")
831
+
832
+ # Decode labels for metrics
833
+ y_validate_decoded = label_encoder.inverse_transform(y_validate.astype(int))
834
+ preds_decoded = label_encoder.inverse_transform(class_preds)
835
+
836
+ # Calculate metrics
837
+ label_names = label_encoder.classes_
838
+ scores = precision_recall_fscore_support(
839
+ y_validate_decoded, preds_decoded, average=None, labels=label_names
840
+ )
841
+
842
+ score_df = pd.DataFrame(
843
+ {
844
+ target: label_names,
845
+ "precision": scores[0],
846
+ "recall": scores[1],
847
+ "f1": scores[2],
848
+ "support": scores[3],
849
+ }
850
+ )
851
+
852
+ # Output metrics per class
853
+ metrics = ["precision", "recall", "f1", "support"]
854
+ for t in label_names:
855
+ for m in metrics:
856
+ value = score_df.loc[score_df[target] == t, m].iloc[0]
857
+ print(f"Metrics:{t}:{m} {value}")
858
+
859
+ # Confusion matrix
860
+ conf_mtx = confusion_matrix(
861
+ y_validate_decoded, preds_decoded, labels=label_names
862
+ )
863
+ for i, row_name in enumerate(label_names):
864
+ for j, col_name in enumerate(label_names):
865
+ value = conf_mtx[i, j]
866
+ print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
867
+
868
+ # Save validation predictions
869
+ df_val = df_val.copy()
870
+ df_val["prediction"] = preds_decoded
871
+ if has_proba and preds.ndim == 2 and preds.shape[1] > 1:
872
+ df_val["pred_proba"] = [p.tolist() for p in preds]
873
+ df_val = expand_proba_column(df_val, label_names)
874
+
875
+ else:
876
+ # Regression metrics
877
+ preds_flat = preds.flatten()
878
+ rmse = root_mean_squared_error(y_validate, preds_flat)
879
+ mae = mean_absolute_error(y_validate, preds_flat)
880
+ r2 = r2_score(y_validate, preds_flat)
881
+ print(f"RMSE: {rmse:.3f}")
882
+ print(f"MAE: {mae:.3f}")
883
+ print(f"R2: {r2:.3f}")
884
+ print(f"NumRows: {len(df_val)}")
885
+
886
+ df_val = df_val.copy()
887
+ df_val["prediction"] = preds_flat
888
+
889
+ # Add prediction_std for ensemble models
890
+ if preds_std is not None:
891
+ df_val["prediction_std"] = preds_std.flatten()
892
+ print(f"Ensemble std - mean: {df_val['prediction_std'].mean():.4f}, max: {df_val['prediction_std'].max():.4f}")
893
+
894
+ # Save validation predictions to S3
895
+ output_columns = [target, "prediction"]
896
+ if "prediction_std" in df_val.columns:
897
+ output_columns.append("prediction_std")
898
+ output_columns += [col for col in df_val.columns if col.endswith("_proba")]
899
+ wr.s3.to_csv(
900
+ df_val[output_columns],
901
+ path=f"{model_metrics_s3_path}/validation_predictions.csv",
902
+ index=False,
903
+ )
904
+
905
+ # Save ensemble models (n_folds models if CV, 1 model otherwise)
906
+ for model_idx, ens_model in enumerate(ensemble_models):
907
+ model_path = os.path.join(args.model_dir, f"chemprop_model_{model_idx}.pt")
908
+ models.save_model(model_path, ens_model)
909
+ print(f"Saved model {model_idx + 1} to {model_path}")
910
+
911
+ # Save ensemble metadata (n_ensemble = number of models for inference)
912
+ n_ensemble = len(ensemble_models)
913
+ ensemble_metadata = {"n_ensemble": n_ensemble, "n_folds": n_folds}
914
+ joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
915
+ print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds})")
916
+
917
+ # Save label encoder if classification
918
+ if label_encoder is not None:
919
+ joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
920
+
921
+ # Save extra feature metadata for inference (hybrid mode)
922
+ # Note: We don't need to save the scaler - X_d_transform is embedded in the model
923
+ if use_extra_features:
924
+ feature_metadata = {
925
+ "extra_feature_cols": extra_feature_cols,
926
+ "col_means": col_means.tolist(), # Unscaled means for NaN imputation
927
+ }
928
+ joblib.dump(
929
+ feature_metadata, os.path.join(args.model_dir, "feature_metadata.joblib")
930
+ )
931
+ print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")