workbench-0.8.198-py3-none-any.whl → workbench-0.8.201-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. workbench/algorithms/dataframe/proximity.py +11 -4
  2. workbench/api/__init__.py +2 -1
  3. workbench/api/feature_set.py +7 -4
  4. workbench/api/model.py +1 -1
  5. workbench/core/artifacts/__init__.py +11 -2
  6. workbench/core/artifacts/endpoint_core.py +84 -46
  7. workbench/core/artifacts/feature_set_core.py +69 -1
  8. workbench/core/artifacts/model_core.py +37 -7
  9. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  10. workbench/core/transforms/features_to_model/features_to_model.py +23 -20
  11. workbench/core/views/view.py +2 -2
  12. workbench/model_scripts/chemprop/chemprop.template +931 -0
  13. workbench/model_scripts/chemprop/generated_model_script.py +931 -0
  14. workbench/model_scripts/chemprop/requirements.txt +11 -0
  15. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  16. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  17. workbench/model_scripts/custom_models/proximity/proximity.py +11 -4
  18. workbench/model_scripts/custom_models/uq_models/proximity.py +11 -4
  19. workbench/model_scripts/pytorch_model/generated_model_script.py +130 -88
  20. workbench/model_scripts/pytorch_model/pytorch.template +128 -86
  21. workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
  22. workbench/model_scripts/script_generation.py +10 -7
  23. workbench/model_scripts/uq_models/generated_model_script.py +25 -18
  24. workbench/model_scripts/uq_models/mapie.template +23 -16
  25. workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
  26. workbench/model_scripts/xgb_model/xgb_model.template +2 -2
  27. workbench/repl/workbench_shell.py +14 -5
  28. workbench/scripts/endpoint_test.py +162 -0
  29. workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
  30. workbench/utils/chemprop_utils.py +724 -0
  31. workbench/utils/pytorch_utils.py +497 -0
  32. workbench/utils/xgboost_model_utils.py +10 -5
  33. {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/METADATA +2 -2
  34. {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/RECORD +38 -32
  35. {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/entry_points.txt +2 -1
  36. workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
  37. workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
  38. {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/WHEEL +0 -0
  39. {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/licenses/LICENSE +0 -0
  40. {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/top_level.txt +0 -0
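
The headline addition is workbench/utils/chemprop_utils.py (item 30, +724 lines), whose full diff follows below. As a quick orientation, here is a minimal sketch of how the new helpers compose, assuming you have a reachable Workbench model artifact; the S3 URI is a placeholder, and download_and_extract_model / load_chemprop_model_artifacts are the functions defined in the new module:

import tempfile

from workbench.utils.chemprop_utils import (
    download_and_extract_model,
    load_chemprop_model_artifacts,
)

# Placeholder URI -- point this at a real model.tar.gz produced by Workbench
S3_URI = "s3://example-bucket/models/model.tar.gz"

with tempfile.TemporaryDirectory(prefix="chemprop_demo_") as model_dir:
    download_and_extract_model(S3_URI, model_dir)  # downloads and safely extracts the tarball
    mpnn, artifacts = load_chemprop_model_artifacts(model_dir)  # MPNN plus optional encoder/metadata
    print(type(mpnn).__name__, sorted(artifacts))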
workbench/utils/chemprop_utils.py (new file)
@@ -0,0 +1,724 @@
+ """ChemProp utilities for Workbench models."""
+
+ # flake8: noqa: E402
+ import logging
+ import os
+ import tempfile
+ from typing import Any, Tuple
+
+ import numpy as np
+ import pandas as pd
+ from scipy.stats import spearmanr
+ from sklearn.metrics import (
+     mean_absolute_error,
+     mean_squared_error,
+     median_absolute_error,
+     precision_recall_fscore_support,
+     r2_score,
+     roc_auc_score,
+ )
+ from sklearn.model_selection import KFold, StratifiedKFold
+ from sklearn.preprocessing import LabelEncoder
+
+ from workbench.utils.model_utils import safe_extract_tarfile
+ from workbench.utils.pandas_utils import expand_proba_column
+
+ log = logging.getLogger("workbench")
+
+
+ def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
+     """Download model artifact from S3 and extract it.
+
+     Args:
+         s3_uri: S3 URI to the model artifact (model.tar.gz)
+         model_dir: Directory to extract model artifacts to
+     """
+     import awswrangler as wr
+
+     log.info(f"Downloading model from {s3_uri}...")
+
+     # Download to temp file
+     local_tar_path = os.path.join(model_dir, "model.tar.gz")
+     wr.s3.download(path=s3_uri, local_file=local_tar_path)
+
+     # Extract using safe extraction
+     log.info(f"Extracting to {model_dir}...")
+     safe_extract_tarfile(local_tar_path, model_dir)
+
+     # Cleanup tar file
+     os.unlink(local_tar_path)
+
+
+ def load_chemprop_model_artifacts(model_dir: str) -> Tuple[Any, dict]:
+     """Load ChemProp MPNN model and artifacts from an extracted model directory.
+
+     Args:
+         model_dir: Directory containing extracted model artifacts
+
+     Returns:
+         Tuple of (MPNN model, artifacts_dict).
+         artifacts_dict contains 'label_encoder' and 'feature_metadata' if present.
+     """
+     import joblib
+     from chemprop import models
+
+     model_path = os.path.join(model_dir, "chemprop_model.pt")
+     if not os.path.exists(model_path):
+         raise FileNotFoundError(f"No chemprop_model.pt found in {model_dir}")
+
+     model = models.MPNN.load_from_file(model_path)
+     model.eval()
+
+     # Load additional artifacts
+     artifacts = {}
+
+     label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+     if os.path.exists(label_encoder_path):
+         artifacts["label_encoder"] = joblib.load(label_encoder_path)
+
+     feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
+     if os.path.exists(feature_metadata_path):
+         artifacts["feature_metadata"] = joblib.load(feature_metadata_path)
+
+     return model, artifacts
+
+
+ def _find_smiles_column(columns: list) -> str:
+     """Find the SMILES column name from a list (case-insensitive match for 'smiles')."""
+     smiles_column = next((col for col in columns if col.lower() == "smiles"), None)
+     if smiles_column is None:
+         raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
+     return smiles_column
+
+
+ def _create_molecule_datapoints(
+     smiles_list: list,
+     targets: list = None,
+     extra_descriptors: np.ndarray = None,
+ ) -> Tuple[list, list]:
+     """Create ChemProp MoleculeDatapoints from SMILES strings.
+
+     Args:
+         smiles_list: List of SMILES strings
+         targets: Optional list of target values (for training)
+         extra_descriptors: Optional array of extra features (n_samples, n_features)
+
+     Returns:
+         Tuple of (list of MoleculeDatapoint objects, list of valid indices)
+     """
+     from chemprop import data
+     from rdkit import Chem
+
+     datapoints = []
+     valid_indices = []
+     invalid_count = 0
+
+     for i, smi in enumerate(smiles_list):
+         # Validate SMILES with RDKit first
+         mol = Chem.MolFromSmiles(smi)
+         if mol is None:
+             invalid_count += 1
+             continue
+
+         # Build datapoint with optional target and extra descriptors
+         y = [targets[i]] if targets is not None else None
+         x_d = extra_descriptors[i] if extra_descriptors is not None else None
+
+         dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
+         datapoints.append(dp)
+         valid_indices.append(i)
+
+     if invalid_count > 0:
+         print(f"Warning: Skipped {invalid_count} invalid SMILES strings")
+
+     return datapoints, valid_indices
+
+
+ def _build_mpnn_model(
+     hyperparameters: dict,
+     task: str = "regression",
+     num_classes: int = None,
+     n_extra_descriptors: int = 0,
+     x_d_transform: Any = None,
+     output_transform: Any = None,
+ ) -> Any:
+     """Build an MPNN model with the specified hyperparameters.
+
+     Args:
+         hyperparameters: Dictionary of model hyperparameters
+         task: Either "regression" or "classification"
+         num_classes: Number of classes for classification tasks
+         n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
+         x_d_transform: Optional transform for extra descriptors (scaling)
+         output_transform: Optional transform for regression output (unscaling targets)
+
+     Returns:
+         Configured MPNN model
+     """
+     from chemprop import models, nn
+
+     # Model hyperparameters with defaults
+     hidden_dim = hyperparameters.get("hidden_dim", 300)
+     depth = hyperparameters.get("depth", 3)
+     dropout = hyperparameters.get("dropout", 0.0)
+     ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
+     ffn_num_layers = hyperparameters.get("ffn_num_layers", 1)
+
+     # Message passing component
+     mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+
+     # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
+     agg = nn.NormAggregation()
+
+     # FFN input_dim = message passing output + extra descriptors
+     ffn_input_dim = hidden_dim + n_extra_descriptors
+
+     # Build FFN based on task type
+     if task == "classification" and num_classes is not None:
+         # Multi-class classification
+         ffn = nn.MulticlassClassificationFFN(
+             n_classes=num_classes,
+             input_dim=ffn_input_dim,
+             hidden_dim=ffn_hidden_dim,
+             n_layers=ffn_num_layers,
+             dropout=dropout,
+         )
+     else:
+         # Regression with optional output transform to unscale predictions
+         ffn = nn.RegressionFFN(
+             input_dim=ffn_input_dim,
+             hidden_dim=ffn_hidden_dim,
+             n_layers=ffn_num_layers,
+             dropout=dropout,
+             output_transform=output_transform,
+         )
+
+     # Create the MPNN model
+     mpnn = models.MPNN(
+         message_passing=mp,
+         agg=agg,
+         predictor=ffn,
+         batch_norm=True,
+         metrics=None,
+         X_d_transform=x_d_transform,
+     )
+
+     return mpnn
+
+
+ def _extract_model_hyperparameters(loaded_model: Any) -> dict:
+     """Extract hyperparameters from a loaded ChemProp MPNN model.
+
+     Extracts architecture parameters from the model's components to replicate
+     the exact same model configuration during cross-validation.
+
+     Args:
+         loaded_model: Loaded MPNN model instance
+
+     Returns:
+         Dictionary of hyperparameters matching the training template
+     """
+     hyperparameters = {}
+
+     # Extract from message passing layer (BondMessagePassing)
+     mp = loaded_model.message_passing
+     hyperparameters["hidden_dim"] = getattr(mp, "d_h", 300)
+     hyperparameters["depth"] = getattr(mp, "depth", 3)
+
+     # Dropout is stored as a nn.Dropout module, get the p value
+     if hasattr(mp, "dropout"):
+         dropout_module = mp.dropout
+         hyperparameters["dropout"] = getattr(dropout_module, "p", 0.0)
+     else:
+         hyperparameters["dropout"] = 0.0
+
+     # Extract from predictor (FFN - either RegressionFFN or MulticlassClassificationFFN)
+     ffn = loaded_model.predictor
+
+     # FFN hidden_dim - try multiple attribute names
+     if hasattr(ffn, "hidden_dim"):
+         hyperparameters["ffn_hidden_dim"] = ffn.hidden_dim
+     elif hasattr(ffn, "d_h"):
+         hyperparameters["ffn_hidden_dim"] = ffn.d_h
+     else:
+         hyperparameters["ffn_hidden_dim"] = 300
+
+     # FFN num_layers - try multiple attribute names
+     if hasattr(ffn, "n_layers"):
+         hyperparameters["ffn_num_layers"] = ffn.n_layers
+     elif hasattr(ffn, "num_layers"):
+         hyperparameters["ffn_num_layers"] = ffn.num_layers
+     else:
+         hyperparameters["ffn_num_layers"] = 1
+
+     # Training hyperparameters (use defaults matching the template)
+     hyperparameters["max_epochs"] = 50
+     hyperparameters["patience"] = 10
+
+     return hyperparameters
+
+
+ def _get_n_extra_descriptors(loaded_model: Any) -> int:
+     """Get the number of extra descriptors from the loaded model.
+
+     The model's X_d_transform contains the scaler which knows the feature dimension.
+
+     Args:
+         loaded_model: Loaded MPNN model instance
+
+     Returns:
+         Number of extra descriptors (0 if none)
+     """
+     x_d_transform = loaded_model.X_d_transform
+     if x_d_transform is None:
+         return 0
+
+     # ScaleTransform wraps a StandardScaler, check its mean_ attribute
+     if hasattr(x_d_transform, "mean"):
+         # x_d_transform.mean is a tensor
+         return len(x_d_transform.mean)
+     elif hasattr(x_d_transform, "scaler") and hasattr(x_d_transform.scaler, "mean_"):
+         return len(x_d_transform.scaler.mean_)
+
+     return 0
+
+
+ def cross_fold_inference(
+     workbench_model: Any,
+     nfolds: int = 5,
+ ) -> Tuple[pd.DataFrame, pd.DataFrame]:
+     """Performs K-fold cross-validation for ChemProp MPNN models.
+
+     Replicates the training setup from the original model to ensure
+     cross-validation results are comparable to the deployed model.
+
+     Args:
+         workbench_model: Workbench model object
+         nfolds: Number of folds for cross-validation (default is 5)
+
+     Returns:
+         Tuple of:
+             - DataFrame with per-class metrics (and 'all' row for overall metrics)
+             - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
+     """
+     import shutil
+
+     import joblib
+     import torch
+     from chemprop import data, nn
+     from lightning import pytorch as pl
+
+     from workbench.api import FeatureSet
+
+     # Create a temporary model directory
+     model_dir = tempfile.mkdtemp(prefix="chemprop_cv_")
+     log.info(f"Using model directory: {model_dir}")
+
+     try:
+         # Download and extract model artifacts to get config and artifacts
+         model_artifact_uri = workbench_model.model_data_url()
+         download_and_extract_model(model_artifact_uri, model_dir)
+
+         # Load model and artifacts
+         loaded_model, artifacts = load_chemprop_model_artifacts(model_dir)
+         feature_metadata = artifacts.get("feature_metadata", {})
+
+         # Determine if classifier from predictor type
+         from chemprop.nn import MulticlassClassificationFFN
+
+         is_classifier = isinstance(loaded_model.predictor, MulticlassClassificationFFN)
+
+         # Use saved label encoder if available, otherwise create fresh one
+         if is_classifier:
+             label_encoder = artifacts.get("label_encoder")
+             if label_encoder is None:
+                 log.warning("No saved label encoder found, creating fresh one")
+                 label_encoder = LabelEncoder()
+         else:
+             label_encoder = None
+
+         # Prepare data
+         fs = FeatureSet(workbench_model.get_input())
+         df = workbench_model.training_view().pull_dataframe()
+
+         # Get columns
+         id_col = fs.id_column
+         target_col = workbench_model.target()
+         feature_cols = workbench_model.features()
+         print(f"Target column: {target_col}")
+         print(f"Feature columns: {len(feature_cols)} features")
+
+         # Find SMILES column
+         smiles_column = _find_smiles_column(feature_cols)
+
+         # Determine extra feature columns:
+         # 1. First try feature_metadata (saved during training)
+         # 2. Fall back to inferring from feature_cols (exclude SMILES column)
+         # 3. Verify against model's X_d_transform dimension
+         if feature_metadata and "extra_feature_cols" in feature_metadata:
+             extra_feature_cols = feature_metadata["extra_feature_cols"]
+         else:
+             # Infer from feature list - everything except SMILES is an extra feature
+             extra_feature_cols = [f for f in feature_cols if f.lower() != "smiles"]
+
+         # Verify against model's actual extra descriptor dimension
+         n_extra_from_model = _get_n_extra_descriptors(loaded_model)
+         if n_extra_from_model > 0 and len(extra_feature_cols) != n_extra_from_model:
+             log.warning(
+                 f"Inferred {len(extra_feature_cols)} extra features but model expects "
+                 f"{n_extra_from_model}. Using inferred columns."
+             )
+
+         use_extra_features = len(extra_feature_cols) > 0
+
+         print(f"SMILES column: {smiles_column}")
+         print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
+
+         # Drop rows with missing SMILES or target values
+         initial_count = len(df)
+         df = df.dropna(subset=[smiles_column, target_col])
+         dropped = initial_count - len(df)
+         if dropped > 0:
+             print(f"Dropped {dropped} rows with missing SMILES or target values")
+
+         # Extract hyperparameters from loaded model
+         hyperparameters = _extract_model_hyperparameters(loaded_model)
+         print(f"Extracted hyperparameters: {hyperparameters}")
+
+         # Get number of classes for classifier
+         num_classes = None
+         if is_classifier:
+             # Try to get from loaded model's FFN first (most reliable)
+             ffn = loaded_model.predictor
+             if hasattr(ffn, "n_classes"):
+                 num_classes = ffn.n_classes
+             elif label_encoder is not None and hasattr(label_encoder, "classes_"):
+                 num_classes = len(label_encoder.classes_)
+             else:
+                 # Fit label encoder to get classes
+                 if label_encoder is None:
+                     label_encoder = LabelEncoder()
+                 label_encoder.fit(df[target_col])
+                 num_classes = len(label_encoder.classes_)
+             print(f"Classification task with {num_classes} classes")
+
+         X = df[[smiles_column] + extra_feature_cols]
+         y = df[target_col]
+         ids = df[id_col]
+
+         # Encode target if classifier
+         if label_encoder is not None:
+             if not hasattr(label_encoder, "classes_"):
+                 label_encoder.fit(y)
+             y_encoded = label_encoder.transform(y)
+             y_for_cv = pd.Series(y_encoded, index=y.index, name=target_col)
+         else:
+             y_for_cv = y
+
+         # Prepare KFold
+         kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
+
+         # Initialize results collection
+         fold_metrics = []
+         predictions_df = pd.DataFrame({id_col: ids, target_col: y})
+         if is_classifier:
+             predictions_df["pred_proba"] = [None] * len(predictions_df)
+
+         # Perform cross-validation
+         for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
+             print(f"\n{'='*50}")
+             print(f"Fold {fold_idx}/{nfolds}")
+             print(f"{'='*50}")
+
+             # Split data
+             df_train = df.iloc[train_idx].copy()
+             df_val = df.iloc[val_idx].copy()
+
+             # Encode target for this fold
+             if is_classifier:
+                 df_train[target_col] = label_encoder.transform(df_train[target_col])
+                 df_val[target_col] = label_encoder.transform(df_val[target_col])
+
+             # Prepare extra features if using hybrid mode
+             train_extra_features = None
+             val_extra_features = None
+             col_means = None
+
+             if use_extra_features:
+                 train_extra_features = df_train[extra_feature_cols].values.astype(np.float32)
+                 val_extra_features = df_val[extra_feature_cols].values.astype(np.float32)
+
+                 # Fill NaN with column means from training data
+                 col_means = np.nanmean(train_extra_features, axis=0)
+                 for i in range(train_extra_features.shape[1]):
+                     train_nan_mask = np.isnan(train_extra_features[:, i])
+                     val_nan_mask = np.isnan(val_extra_features[:, i])
+                     train_extra_features[train_nan_mask, i] = col_means[i]
+                     val_extra_features[val_nan_mask, i] = col_means[i]
+
+             # Create ChemProp datasets
+             train_datapoints, train_valid_idx = _create_molecule_datapoints(
+                 df_train[smiles_column].tolist(),
+                 df_train[target_col].tolist(),
+                 train_extra_features,
+             )
+             val_datapoints, val_valid_idx = _create_molecule_datapoints(
+                 df_val[smiles_column].tolist(),
+                 df_val[target_col].tolist(),
+                 val_extra_features,
+             )
+
+             # Update dataframes to only include valid molecules
+             df_train_valid = df_train.iloc[train_valid_idx].reset_index(drop=True)
+             df_val_valid = df_val.iloc[val_valid_idx].reset_index(drop=True)
+
+             train_dataset = data.MoleculeDataset(train_datapoints)
+             val_dataset = data.MoleculeDataset(val_datapoints)
+
+             # Save raw validation features before scaling
+             val_extra_raw = val_extra_features[val_valid_idx] if val_extra_features is not None else None
+
+             # Scale extra descriptors
+             feature_scaler = None
+             x_d_transform = None
+             if use_extra_features:
+                 feature_scaler = train_dataset.normalize_inputs("X_d")
+                 val_dataset.normalize_inputs("X_d", feature_scaler)
+                 x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
+
+             # Scale targets for regression
+             target_scaler = None
+             output_transform = None
+             if not is_classifier:
+                 target_scaler = train_dataset.normalize_targets()
+                 val_dataset.normalize_targets(target_scaler)
+                 output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
+
+             # Get batch size
+             batch_size = min(64, max(16, len(df_train_valid) // 16))
+
+             train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
+             val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
+
+             # Build the model
+             n_extra = len(extra_feature_cols) if use_extra_features else 0
+             mpnn = _build_mpnn_model(
+                 hyperparameters,
+                 task="classification" if is_classifier else "regression",
+                 num_classes=num_classes,
+                 n_extra_descriptors=n_extra,
+                 x_d_transform=x_d_transform,
+                 output_transform=output_transform,
+             )
+
+             # Training configuration
+             max_epochs = hyperparameters.get("max_epochs", 50)
+             patience = hyperparameters.get("patience", 10)
+
+             # Set up trainer
+             checkpoint_dir = os.path.join(model_dir, f"fold_{fold_idx}")
+             os.makedirs(checkpoint_dir, exist_ok=True)
+
+             callbacks = [
+                 pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
+                 pl.callbacks.ModelCheckpoint(
+                     dirpath=checkpoint_dir,
+                     filename="best_model",
+                     monitor="val_loss",
+                     mode="min",
+                     save_top_k=1,
+                 ),
+             ]
+
+             trainer = pl.Trainer(
+                 accelerator="auto",
+                 max_epochs=max_epochs,
+                 callbacks=callbacks,
+                 logger=False,
+                 enable_progress_bar=True,
+             )
+
+             # Train the model
+             trainer.fit(mpnn, train_loader, val_loader)
+
+             # Load the best checkpoint
+             if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
+                 best_ckpt_path = trainer.checkpoint_callback.best_model_path
+                 checkpoint = torch.load(best_ckpt_path, weights_only=False)
+                 mpnn.load_state_dict(checkpoint["state_dict"])
+
+             mpnn.eval()
+
+             # Make predictions using raw features
+             val_datapoints_raw, _ = _create_molecule_datapoints(
+                 df_val_valid[smiles_column].tolist(),
+                 df_val_valid[target_col].tolist(),
+                 val_extra_raw,
+             )
+             val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
+             val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
+
+             with torch.inference_mode():
+                 val_predictions = trainer.predict(mpnn, val_loader_pred)
+
+             preds = np.concatenate([p.numpy() for p in val_predictions], axis=0)
+
+             # ChemProp may return (n_samples, 1, n_classes) for multiclass - squeeze middle dim
+             if preds.ndim == 3 and preds.shape[1] == 1:
+                 preds = preds.squeeze(axis=1)
+
+             # Map predictions back to original indices
+             original_val_indices = df.iloc[val_idx].index[val_valid_idx]
+
+             if is_classifier:
+                 # Get class predictions
+                 if preds.ndim == 2 and preds.shape[1] > 1:
+                     class_preds = np.argmax(preds, axis=1)
+                 else:
+                     class_preds = (preds.flatten() > 0.5).astype(int)
+
+                 preds_decoded = label_encoder.inverse_transform(class_preds)
+                 predictions_df.loc[original_val_indices, "prediction"] = preds_decoded
+
+                 # Store probabilities
+                 if preds.ndim == 2 and preds.shape[1] > 1:
+                     for i, idx in enumerate(original_val_indices):
+                         predictions_df.at[idx, "pred_proba"] = preds[i].tolist()
+             else:
+                 predictions_df.loc[original_val_indices, "prediction"] = preds.flatten()
+
+             # Calculate fold metrics
+             y_val = df_val_valid[target_col].values
+
+             if is_classifier:
+                 y_val_orig = label_encoder.inverse_transform(y_val.astype(int))
+                 preds_orig = preds_decoded
+
+                 prec, rec, f1, _ = precision_recall_fscore_support(
+                     y_val_orig, preds_orig, average="weighted", zero_division=0
+                 )
+
+                 prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
+                     y_val_orig, preds_orig, average=None, zero_division=0, labels=label_encoder.classes_
+                 )
+
+                 # ROC AUC
+                 if preds.ndim == 2 and preds.shape[1] > 1:
+                     roc_auc_overall = roc_auc_score(y_val, preds, multi_class="ovr", average="macro")
+                     roc_auc_per_class = roc_auc_score(y_val, preds, multi_class="ovr", average=None)
+                 else:
+                     roc_auc_overall = roc_auc_score(y_val, preds.flatten())
+                     roc_auc_per_class = [roc_auc_overall]
+
+                 fold_metrics.append(
+                     {
+                         "fold": fold_idx,
+                         "precision": prec,
+                         "recall": rec,
+                         "f1": f1,
+                         "roc_auc": roc_auc_overall,
+                         "precision_per_class": prec_per_class,
+                         "recall_per_class": rec_per_class,
+                         "f1_per_class": f1_per_class,
+                         "roc_auc_per_class": roc_auc_per_class,
+                     }
+                 )
+
+                 print(f"Fold {fold_idx} - F1: {f1:.4f}, ROC-AUC: {roc_auc_overall:.4f}")
+             else:
+                 spearman_corr, _ = spearmanr(y_val, preds.flatten())
+                 rmse = np.sqrt(mean_squared_error(y_val, preds.flatten()))
+
+                 fold_metrics.append(
+                     {
+                         "fold": fold_idx,
+                         "rmse": rmse,
+                         "mae": mean_absolute_error(y_val, preds.flatten()),
+                         "medae": median_absolute_error(y_val, preds.flatten()),
+                         "r2": r2_score(y_val, preds.flatten()),
+                         "spearmanr": spearman_corr,
+                     }
+                 )
+
+                 print(f"Fold {fold_idx} - RMSE: {rmse:.4f}, R2: {fold_metrics[-1]['r2']:.4f}")
+
+         # Calculate summary metrics
+         fold_df = pd.DataFrame(fold_metrics)
+
+         if is_classifier:
+             if "pred_proba" in predictions_df.columns:
+                 predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
+
+             metric_rows = []
+             for idx, class_name in enumerate(label_encoder.classes_):
+                 prec_scores = np.array([fold["precision_per_class"][idx] for fold in fold_metrics])
+                 rec_scores = np.array([fold["recall_per_class"][idx] for fold in fold_metrics])
+                 f1_scores = np.array([fold["f1_per_class"][idx] for fold in fold_metrics])
+                 roc_auc_scores = np.array([fold["roc_auc_per_class"][idx] for fold in fold_metrics])
+
+                 y_orig = label_encoder.inverse_transform(y_for_cv)
+                 support = int((y_orig == class_name).sum())
+
+                 metric_rows.append(
+                     {
+                         "class": class_name,
+                         "precision": prec_scores.mean(),
+                         "recall": rec_scores.mean(),
+                         "f1": f1_scores.mean(),
+                         "roc_auc": roc_auc_scores.mean(),
+                         "support": support,
+                     }
+                 )
+
+             metric_rows.append(
+                 {
+                     "class": "all",
+                     "precision": fold_df["precision"].mean(),
+                     "recall": fold_df["recall"].mean(),
+                     "f1": fold_df["f1"].mean(),
+                     "roc_auc": fold_df["roc_auc"].mean(),
+                     "support": len(y_for_cv),
+                 }
+             )
+
+             metrics_df = pd.DataFrame(metric_rows)
+         else:
+             metrics_df = pd.DataFrame(
+                 [
+                     {
+                         "rmse": fold_df["rmse"].mean(),
+                         "mae": fold_df["mae"].mean(),
+                         "medae": fold_df["medae"].mean(),
+                         "r2": fold_df["r2"].mean(),
+                         "spearmanr": fold_df["spearmanr"].mean(),
+                         "support": len(y_for_cv),
+                     }
+                 ]
+             )
+
+         print(f"\n{'='*50}")
+         print("Cross-Validation Summary")
+         print(f"{'='*50}")
+         print(metrics_df.to_string(index=False))
+
+         return metrics_df, predictions_df
+
+     finally:
+         log.info(f"Cleaning up model directory: {model_dir}")
+         shutil.rmtree(model_dir, ignore_errors=True)
+
+
+ if __name__ == "__main__":
+
+     # Tests for the ChemProp utilities
+     from workbench.api import Endpoint, Model
+
+     # Initialize Workbench model
+     model_name = "aqsol-chemprop-reg"
+     print(f"Loading Workbench model: {model_name}")
+     model = Model(model_name)
+     print(f"Model Framework: {model.model_framework}")
+
+     # Perform cross-fold inference
+     end = Endpoint(model.endpoints()[0])
+     end.cross_fold_inference()
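
For reference, a minimal sketch of driving the new cross-validation utility directly, rather than through the Endpoint wrapper the test block above uses; the model name is the one the module's own __main__ loads, and the return value is the (metrics_df, predictions_df) pair described in the docstring:

from workbench.api import Model
from workbench.utils.chemprop_utils import cross_fold_inference

model = Model("aqsol-chemprop-reg")  # same model the module's test block loads
metrics_df, predictions_df = cross_fold_inference(model, nfolds=5)
print(metrics_df.to_string(index=False))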