workbench-0.8.162-py3-none-any.whl → workbench-0.8.202-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
workbench/utils/chemprop_utils.py (new file)
@@ -0,0 +1,760 @@
+ """ChemProp utilities for Workbench models."""
+
+ # flake8: noqa: E402
+ import logging
+ import os
+ import tempfile
+ from typing import Any, Tuple
+
+ import numpy as np
+ import pandas as pd
+ from scipy.stats import spearmanr
+ from sklearn.metrics import (
+     mean_absolute_error,
+     mean_squared_error,
+     median_absolute_error,
+     precision_recall_fscore_support,
+     r2_score,
+     roc_auc_score,
+ )
+ from sklearn.model_selection import KFold, StratifiedKFold
+ from sklearn.preprocessing import LabelEncoder
+
+ from workbench.utils.model_utils import safe_extract_tarfile
+ from workbench.utils.pandas_utils import expand_proba_column
+ from workbench.utils.aws_utils import pull_s3_data
+
+ log = logging.getLogger("workbench")
+
+
+ def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
+     """Download model artifact from S3 and extract it.
+
+     Args:
+         s3_uri: S3 URI to the model artifact (model.tar.gz)
+         model_dir: Directory to extract model artifacts to
+     """
+     import awswrangler as wr
+
+     log.info(f"Downloading model from {s3_uri}...")
+
+     # Download to temp file
+     local_tar_path = os.path.join(model_dir, "model.tar.gz")
+     wr.s3.download(path=s3_uri, local_file=local_tar_path)
+
+     # Extract using safe extraction
+     log.info(f"Extracting to {model_dir}...")
+     safe_extract_tarfile(local_tar_path, model_dir)
+
+     # Cleanup tar file
+     os.unlink(local_tar_path)
+
+
+ def load_chemprop_model_artifacts(model_dir: str) -> Tuple[Any, dict]:
+     """Load ChemProp MPNN model and artifacts from an extracted model directory.
+
+     Args:
+         model_dir: Directory containing extracted model artifacts
+
+     Returns:
+         Tuple of (MPNN model, artifacts_dict).
+         artifacts_dict contains 'label_encoder' and 'feature_metadata' if present.
+     """
+     import joblib
+     from chemprop import models
+
+     model_path = os.path.join(model_dir, "chemprop_model.pt")
+     if not os.path.exists(model_path):
+         raise FileNotFoundError(f"No chemprop_model.pt found in {model_dir}")
+
+     model = models.MPNN.load_from_file(model_path)
+     model.eval()
+
+     # Load additional artifacts
+     artifacts = {}
+
+     label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+     if os.path.exists(label_encoder_path):
+         artifacts["label_encoder"] = joblib.load(label_encoder_path)
+
+     feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
+     if os.path.exists(feature_metadata_path):
+         artifacts["feature_metadata"] = joblib.load(feature_metadata_path)
+
+     return model, artifacts
+
+
+ def _find_smiles_column(columns: list) -> str:
+     """Find the SMILES column name from a list (case-insensitive match for 'smiles')."""
+     smiles_column = next((col for col in columns if col.lower() == "smiles"), None)
+     if smiles_column is None:
+         raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
+     return smiles_column
+
+
+ def _create_molecule_datapoints(
+     smiles_list: list,
+     targets: list = None,
+     extra_descriptors: np.ndarray = None,
+ ) -> Tuple[list, list]:
+     """Create ChemProp MoleculeDatapoints from SMILES strings.
+
+     Args:
+         smiles_list: List of SMILES strings
+         targets: Optional list of target values (for training)
+         extra_descriptors: Optional array of extra features (n_samples, n_features)
+
+     Returns:
+         Tuple of (list of MoleculeDatapoint objects, list of valid indices)
+     """
+     from chemprop import data
+     from rdkit import Chem
+
+     datapoints = []
+     valid_indices = []
+     invalid_count = 0
+
+     for i, smi in enumerate(smiles_list):
+         # Validate SMILES with RDKit first
+         mol = Chem.MolFromSmiles(smi)
+         if mol is None:
+             invalid_count += 1
+             continue
+
+         # Build datapoint with optional target and extra descriptors
+         y = [targets[i]] if targets is not None else None
+         x_d = extra_descriptors[i] if extra_descriptors is not None else None
+
+         dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
+         datapoints.append(dp)
+         valid_indices.append(i)
+
+     if invalid_count > 0:
+         print(f"Warning: Skipped {invalid_count} invalid SMILES strings")
+
+     return datapoints, valid_indices
+
+
+ def _build_mpnn_model(
+     hyperparameters: dict,
+     task: str = "regression",
+     num_classes: int = None,
+     n_extra_descriptors: int = 0,
+     x_d_transform: Any = None,
+     output_transform: Any = None,
+ ) -> Any:
+     """Build an MPNN model with the specified hyperparameters.
+
+     Args:
+         hyperparameters: Dictionary of model hyperparameters
+         task: Either "regression" or "classification"
+         num_classes: Number of classes for classification tasks
+         n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
+         x_d_transform: Optional transform for extra descriptors (scaling)
+         output_transform: Optional transform for regression output (unscaling targets)
+
+     Returns:
+         Configured MPNN model
+     """
+     from chemprop import models, nn
+
+     # Model hyperparameters with defaults
+     hidden_dim = hyperparameters.get("hidden_dim", 300)
+     depth = hyperparameters.get("depth", 3)
+     dropout = hyperparameters.get("dropout", 0.1)
+     ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
+     ffn_num_layers = hyperparameters.get("ffn_num_layers", 1)
+
+     # Message passing component
+     mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+
+     # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
+     agg = nn.NormAggregation()
+
+     # FFN input_dim = message passing output + extra descriptors
+     ffn_input_dim = hidden_dim + n_extra_descriptors
+
+     # Build FFN based on task type
+     if task == "classification" and num_classes is not None:
+         # Multi-class classification
+         ffn = nn.MulticlassClassificationFFN(
+             n_classes=num_classes,
+             input_dim=ffn_input_dim,
+             hidden_dim=ffn_hidden_dim,
+             n_layers=ffn_num_layers,
+             dropout=dropout,
+         )
+     else:
+         # Regression with optional output transform to unscale predictions
+         ffn = nn.RegressionFFN(
+             input_dim=ffn_input_dim,
+             hidden_dim=ffn_hidden_dim,
+             n_layers=ffn_num_layers,
+             dropout=dropout,
+             output_transform=output_transform,
+         )
+
+     # Create the MPNN model
+     mpnn = models.MPNN(
+         message_passing=mp,
+         agg=agg,
+         predictor=ffn,
+         batch_norm=True,
+         metrics=None,
+         X_d_transform=x_d_transform,
+     )
+
+     return mpnn
+
+
+ def _extract_model_hyperparameters(loaded_model: Any) -> dict:
+     """Extract hyperparameters from a loaded ChemProp MPNN model.
+
+     Extracts architecture parameters from the model's components to replicate
+     the exact same model configuration during cross-validation.
+
+     Args:
+         loaded_model: Loaded MPNN model instance
+
+     Returns:
+         Dictionary of hyperparameters matching the training template
+     """
+     hyperparameters = {}
+
+     # Extract from message passing layer (BondMessagePassing)
+     mp = loaded_model.message_passing
+     hyperparameters["hidden_dim"] = getattr(mp, "d_h", 300)
+     hyperparameters["depth"] = getattr(mp, "depth", 3)
+
+     # Dropout is stored as a nn.Dropout module, get the p value
+     if hasattr(mp, "dropout"):
+         dropout_module = mp.dropout
+         hyperparameters["dropout"] = getattr(dropout_module, "p", 0.0)
+     else:
+         hyperparameters["dropout"] = 0.0
+
+     # Extract from predictor (FFN - either RegressionFFN or MulticlassClassificationFFN)
+     ffn = loaded_model.predictor
+
+     # FFN hidden_dim - try multiple attribute names
+     if hasattr(ffn, "hidden_dim"):
+         hyperparameters["ffn_hidden_dim"] = ffn.hidden_dim
+     elif hasattr(ffn, "d_h"):
+         hyperparameters["ffn_hidden_dim"] = ffn.d_h
+     else:
+         hyperparameters["ffn_hidden_dim"] = 300
+
+     # FFN num_layers - try multiple attribute names
+     if hasattr(ffn, "n_layers"):
+         hyperparameters["ffn_num_layers"] = ffn.n_layers
+     elif hasattr(ffn, "num_layers"):
+         hyperparameters["ffn_num_layers"] = ffn.num_layers
+     else:
+         hyperparameters["ffn_num_layers"] = 1
+
+     # Training hyperparameters (use defaults matching the template)
+     hyperparameters["max_epochs"] = 50
+     hyperparameters["patience"] = 10
+
+     return hyperparameters
+
+
+ def _get_n_extra_descriptors(loaded_model: Any) -> int:
+     """Get the number of extra descriptors from the loaded model.
+
+     The model's X_d_transform contains the scaler which knows the feature dimension.
+
+     Args:
+         loaded_model: Loaded MPNN model instance
+
+     Returns:
+         Number of extra descriptors (0 if none)
+     """
+     x_d_transform = loaded_model.X_d_transform
+     if x_d_transform is None:
+         return 0
+
+     # ScaleTransform wraps a StandardScaler, check its mean_ attribute
+     if hasattr(x_d_transform, "mean"):
+         # x_d_transform.mean is a tensor
+         return len(x_d_transform.mean)
+     elif hasattr(x_d_transform, "scaler") and hasattr(x_d_transform.scaler, "mean_"):
+         return len(x_d_transform.scaler.mean_)
+
+     return 0
+
+
+ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
+     """Pull cross-validation results from AWS training artifacts.
+
+     This retrieves the validation predictions and training metrics that were
+     saved during model training.
+
+     Args:
+         workbench_model: Workbench model object
+
+     Returns:
+         Tuple of:
+             - DataFrame with training metrics
+             - DataFrame with validation predictions
+     """
+     # Get the validation predictions from S3
+     s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
+     predictions_df = pull_s3_data(s3_path)
+
+     if predictions_df is None:
+         raise ValueError(f"No validation predictions found at {s3_path}")
+
+     log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
+
+     # Get training metrics from model metadata
+     training_metrics = workbench_model.workbench_meta().get("workbench_training_metrics")
+
+     if training_metrics is None:
+         raise ValueError(f"No training metrics found in model metadata for {workbench_model.model_name}")
+
+     metrics_df = pd.DataFrame.from_dict(training_metrics)
+     log.info(f"Metrics summary:\n{metrics_df.to_string(index=False)}")
+
+     return metrics_df, predictions_df
+
+
+ def cross_fold_inference(
+     workbench_model: Any,
+     nfolds: int = 5,
+ ) -> Tuple[pd.DataFrame, pd.DataFrame]:
+     """Performs K-fold cross-validation for ChemProp MPNN models.
+
+     Replicates the training setup from the original model to ensure
+     cross-validation results are comparable to the deployed model.
+
+     Args:
+         workbench_model: Workbench model object
+         nfolds: Number of folds for cross-validation (default is 5)
+
+     Returns:
+         Tuple of:
+             - DataFrame with per-class metrics (and 'all' row for overall metrics)
+             - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
+     """
+     import shutil
+
+     import joblib
+     import torch
+     from chemprop import data, nn
+     from lightning import pytorch as pl
+
+     from workbench.api import FeatureSet
+
+     # Create a temporary model directory
+     model_dir = tempfile.mkdtemp(prefix="chemprop_cv_")
+     log.info(f"Using model directory: {model_dir}")
+
+     try:
+         # Download and extract model artifacts to get config and artifacts
+         model_artifact_uri = workbench_model.model_data_url()
+         download_and_extract_model(model_artifact_uri, model_dir)
+
+         # Load model and artifacts
+         loaded_model, artifacts = load_chemprop_model_artifacts(model_dir)
+         feature_metadata = artifacts.get("feature_metadata", {})
+
+         # Determine if classifier from predictor type
+         from chemprop.nn import MulticlassClassificationFFN
+
+         is_classifier = isinstance(loaded_model.predictor, MulticlassClassificationFFN)
+
+         # Use saved label encoder if available, otherwise create fresh one
+         if is_classifier:
+             label_encoder = artifacts.get("label_encoder")
+             if label_encoder is None:
+                 log.warning("No saved label encoder found, creating fresh one")
+                 label_encoder = LabelEncoder()
+         else:
+             label_encoder = None
+
+         # Prepare data
+         fs = FeatureSet(workbench_model.get_input())
+         df = workbench_model.training_view().pull_dataframe()
+
+         # Get columns
+         id_col = fs.id_column
+         target_col = workbench_model.target()
+         feature_cols = workbench_model.features()
+         print(f"Target column: {target_col}")
+         print(f"Feature columns: {len(feature_cols)} features")
+
+         # Find SMILES column
+         smiles_column = _find_smiles_column(feature_cols)
+
+         # Determine extra feature columns:
+         # 1. First try feature_metadata (saved during training)
+         # 2. Fall back to inferring from feature_cols (exclude SMILES column)
+         # 3. Verify against model's X_d_transform dimension
+         if feature_metadata and "extra_feature_cols" in feature_metadata:
+             extra_feature_cols = feature_metadata["extra_feature_cols"]
+         else:
+             # Infer from feature list - everything except SMILES is an extra feature
+             extra_feature_cols = [f for f in feature_cols if f.lower() != "smiles"]
+
+         # Verify against model's actual extra descriptor dimension
+         n_extra_from_model = _get_n_extra_descriptors(loaded_model)
+         if n_extra_from_model > 0 and len(extra_feature_cols) != n_extra_from_model:
+             log.warning(
+                 f"Inferred {len(extra_feature_cols)} extra features but model expects "
+                 f"{n_extra_from_model}. Using inferred columns."
+             )
+
+         use_extra_features = len(extra_feature_cols) > 0
+
+         print(f"SMILES column: {smiles_column}")
+         print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
+
+         # Drop rows with missing SMILES or target values
+         initial_count = len(df)
+         df = df.dropna(subset=[smiles_column, target_col])
+         dropped = initial_count - len(df)
+         if dropped > 0:
+             print(f"Dropped {dropped} rows with missing SMILES or target values")
+
+         # Extract hyperparameters from loaded model
+         hyperparameters = _extract_model_hyperparameters(loaded_model)
+         print(f"Extracted hyperparameters: {hyperparameters}")
+
+         # Get number of classes for classifier
+         num_classes = None
+         if is_classifier:
+             # Try to get from loaded model's FFN first (most reliable)
+             ffn = loaded_model.predictor
+             if hasattr(ffn, "n_classes"):
+                 num_classes = ffn.n_classes
+             elif label_encoder is not None and hasattr(label_encoder, "classes_"):
+                 num_classes = len(label_encoder.classes_)
+             else:
+                 # Fit label encoder to get classes
+                 if label_encoder is None:
+                     label_encoder = LabelEncoder()
+                 label_encoder.fit(df[target_col])
+                 num_classes = len(label_encoder.classes_)
+             print(f"Classification task with {num_classes} classes")
+
+         X = df[[smiles_column] + extra_feature_cols]
+         y = df[target_col]
+         ids = df[id_col]
+
+         # Encode target if classifier
+         if label_encoder is not None:
+             if not hasattr(label_encoder, "classes_"):
+                 label_encoder.fit(y)
+             y_encoded = label_encoder.transform(y)
+             y_for_cv = pd.Series(y_encoded, index=y.index, name=target_col)
+         else:
+             y_for_cv = y
+
+         # Prepare KFold
+         kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
+
+         # Initialize results collection
+         fold_metrics = []
+         predictions_df = pd.DataFrame({id_col: ids, target_col: y})
+         if is_classifier:
+             predictions_df["pred_proba"] = [None] * len(predictions_df)
+
+         # Perform cross-validation
+         for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
+             print(f"\n{'='*50}")
+             print(f"Fold {fold_idx}/{nfolds}")
+             print(f"{'='*50}")
+
+             # Split data
+             df_train = df.iloc[train_idx].copy()
+             df_val = df.iloc[val_idx].copy()
+
+             # Encode target for this fold
+             if is_classifier:
+                 df_train[target_col] = label_encoder.transform(df_train[target_col])
+                 df_val[target_col] = label_encoder.transform(df_val[target_col])
+
+             # Prepare extra features if using hybrid mode
+             train_extra_features = None
+             val_extra_features = None
+             col_means = None
+
+             if use_extra_features:
+                 train_extra_features = df_train[extra_feature_cols].values.astype(np.float32)
+                 val_extra_features = df_val[extra_feature_cols].values.astype(np.float32)
+
+                 # Fill NaN with column means from training data
+                 col_means = np.nanmean(train_extra_features, axis=0)
+                 for i in range(train_extra_features.shape[1]):
+                     train_nan_mask = np.isnan(train_extra_features[:, i])
+                     val_nan_mask = np.isnan(val_extra_features[:, i])
+                     train_extra_features[train_nan_mask, i] = col_means[i]
+                     val_extra_features[val_nan_mask, i] = col_means[i]
+
+             # Create ChemProp datasets
+             train_datapoints, train_valid_idx = _create_molecule_datapoints(
+                 df_train[smiles_column].tolist(),
+                 df_train[target_col].tolist(),
+                 train_extra_features,
+             )
+             val_datapoints, val_valid_idx = _create_molecule_datapoints(
+                 df_val[smiles_column].tolist(),
+                 df_val[target_col].tolist(),
+                 val_extra_features,
+             )
+
+             # Update dataframes to only include valid molecules
+             df_train_valid = df_train.iloc[train_valid_idx].reset_index(drop=True)
+             df_val_valid = df_val.iloc[val_valid_idx].reset_index(drop=True)
+
+             train_dataset = data.MoleculeDataset(train_datapoints)
+             val_dataset = data.MoleculeDataset(val_datapoints)
+
+             # Save raw validation features before scaling
+             val_extra_raw = val_extra_features[val_valid_idx] if val_extra_features is not None else None
+
+             # Scale extra descriptors
+             feature_scaler = None
+             x_d_transform = None
+             if use_extra_features:
+                 feature_scaler = train_dataset.normalize_inputs("X_d")
+                 val_dataset.normalize_inputs("X_d", feature_scaler)
+                 x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
+
+             # Scale targets for regression
+             target_scaler = None
+             output_transform = None
+             if not is_classifier:
+                 target_scaler = train_dataset.normalize_targets()
+                 val_dataset.normalize_targets(target_scaler)
+                 output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
+
+             # Get batch size
+             batch_size = min(64, max(16, len(df_train_valid) // 16))
+
+             train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
+             val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
+
+             # Build the model
+             n_extra = len(extra_feature_cols) if use_extra_features else 0
+             mpnn = _build_mpnn_model(
+                 hyperparameters,
+                 task="classification" if is_classifier else "regression",
+                 num_classes=num_classes,
+                 n_extra_descriptors=n_extra,
+                 x_d_transform=x_d_transform,
+                 output_transform=output_transform,
+             )
+
+             # Training configuration
+             max_epochs = hyperparameters.get("max_epochs", 50)
+             patience = hyperparameters.get("patience", 10)
+
+             # Set up trainer
+             checkpoint_dir = os.path.join(model_dir, f"fold_{fold_idx}")
+             os.makedirs(checkpoint_dir, exist_ok=True)
+
+             callbacks = [
+                 pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
+                 pl.callbacks.ModelCheckpoint(
+                     dirpath=checkpoint_dir,
+                     filename="best_model",
+                     monitor="val_loss",
+                     mode="min",
+                     save_top_k=1,
+                 ),
+             ]
+
+             trainer = pl.Trainer(
+                 accelerator="auto",
+                 max_epochs=max_epochs,
+                 callbacks=callbacks,
+                 logger=False,
+                 enable_progress_bar=True,
+             )
+
+             # Train the model
+             trainer.fit(mpnn, train_loader, val_loader)
+
+             # Load the best checkpoint
+             if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
+                 best_ckpt_path = trainer.checkpoint_callback.best_model_path
+                 checkpoint = torch.load(best_ckpt_path, weights_only=False)
+                 mpnn.load_state_dict(checkpoint["state_dict"])
+
+             mpnn.eval()
+
+             # Make predictions using raw features
+             val_datapoints_raw, _ = _create_molecule_datapoints(
+                 df_val_valid[smiles_column].tolist(),
+                 df_val_valid[target_col].tolist(),
+                 val_extra_raw,
+             )
+             val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
+             val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
+
+             with torch.inference_mode():
+                 val_predictions = trainer.predict(mpnn, val_loader_pred)
+
+             preds = np.concatenate([p.numpy() for p in val_predictions], axis=0)
+
+             # ChemProp may return (n_samples, 1, n_classes) for multiclass - squeeze middle dim
+             if preds.ndim == 3 and preds.shape[1] == 1:
+                 preds = preds.squeeze(axis=1)
+
+             # Map predictions back to original indices
+             original_val_indices = df.iloc[val_idx].index[val_valid_idx]
+
+             if is_classifier:
+                 # Get class predictions
+                 if preds.ndim == 2 and preds.shape[1] > 1:
+                     class_preds = np.argmax(preds, axis=1)
+                 else:
+                     class_preds = (preds.flatten() > 0.5).astype(int)
+
+                 preds_decoded = label_encoder.inverse_transform(class_preds)
+                 predictions_df.loc[original_val_indices, "prediction"] = preds_decoded
+
+                 # Store probabilities
+                 if preds.ndim == 2 and preds.shape[1] > 1:
+                     for i, idx in enumerate(original_val_indices):
+                         predictions_df.at[idx, "pred_proba"] = preds[i].tolist()
+             else:
+                 predictions_df.loc[original_val_indices, "prediction"] = preds.flatten()
+
+             # Calculate fold metrics
+             y_val = df_val_valid[target_col].values
+
+             if is_classifier:
+                 y_val_orig = label_encoder.inverse_transform(y_val.astype(int))
+                 preds_orig = preds_decoded
+
+                 prec, rec, f1, _ = precision_recall_fscore_support(
+                     y_val_orig, preds_orig, average="weighted", zero_division=0
+                 )
+
+                 prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
+                     y_val_orig, preds_orig, average=None, zero_division=0, labels=label_encoder.classes_
+                 )
+
+                 # ROC AUC
+                 if preds.ndim == 2 and preds.shape[1] > 1:
+                     roc_auc_overall = roc_auc_score(y_val, preds, multi_class="ovr", average="macro")
+                     roc_auc_per_class = roc_auc_score(y_val, preds, multi_class="ovr", average=None)
+                 else:
+                     roc_auc_overall = roc_auc_score(y_val, preds.flatten())
+                     roc_auc_per_class = [roc_auc_overall]
+
+                 fold_metrics.append(
+                     {
+                         "fold": fold_idx,
+                         "precision": prec,
+                         "recall": rec,
+                         "f1": f1,
+                         "roc_auc": roc_auc_overall,
+                         "precision_per_class": prec_per_class,
+                         "recall_per_class": rec_per_class,
+                         "f1_per_class": f1_per_class,
+                         "roc_auc_per_class": roc_auc_per_class,
+                     }
+                 )
+
+                 print(f"Fold {fold_idx} - F1: {f1:.4f}, ROC-AUC: {roc_auc_overall:.4f}")
+             else:
+                 spearman_corr, _ = spearmanr(y_val, preds.flatten())
+                 rmse = np.sqrt(mean_squared_error(y_val, preds.flatten()))
+
+                 fold_metrics.append(
+                     {
+                         "fold": fold_idx,
+                         "rmse": rmse,
+                         "mae": mean_absolute_error(y_val, preds.flatten()),
+                         "medae": median_absolute_error(y_val, preds.flatten()),
+                         "r2": r2_score(y_val, preds.flatten()),
+                         "spearmanr": spearman_corr,
+                     }
+                 )
+
+                 print(f"Fold {fold_idx} - RMSE: {rmse:.4f}, R2: {fold_metrics[-1]['r2']:.4f}")
+
+         # Calculate summary metrics
+         fold_df = pd.DataFrame(fold_metrics)
+
+         if is_classifier:
+             if "pred_proba" in predictions_df.columns:
+                 predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
+
+             metric_rows = []
+             for idx, class_name in enumerate(label_encoder.classes_):
+                 prec_scores = np.array([fold["precision_per_class"][idx] for fold in fold_metrics])
+                 rec_scores = np.array([fold["recall_per_class"][idx] for fold in fold_metrics])
+                 f1_scores = np.array([fold["f1_per_class"][idx] for fold in fold_metrics])
+                 roc_auc_scores = np.array([fold["roc_auc_per_class"][idx] for fold in fold_metrics])
+
+                 y_orig = label_encoder.inverse_transform(y_for_cv)
+                 support = int((y_orig == class_name).sum())
+
+                 metric_rows.append(
+                     {
+                         "class": class_name,
+                         "precision": prec_scores.mean(),
+                         "recall": rec_scores.mean(),
+                         "f1": f1_scores.mean(),
+                         "roc_auc": roc_auc_scores.mean(),
+                         "support": support,
+                     }
+                 )
+
+             metric_rows.append(
+                 {
+                     "class": "all",
+                     "precision": fold_df["precision"].mean(),
+                     "recall": fold_df["recall"].mean(),
+                     "f1": fold_df["f1"].mean(),
+                     "roc_auc": fold_df["roc_auc"].mean(),
+                     "support": len(y_for_cv),
+                 }
+             )
+
+             metrics_df = pd.DataFrame(metric_rows)
+         else:
+             metrics_df = pd.DataFrame(
+                 [
+                     {
+                         "rmse": fold_df["rmse"].mean(),
+                         "mae": fold_df["mae"].mean(),
+                         "medae": fold_df["medae"].mean(),
+                         "r2": fold_df["r2"].mean(),
+                         "spearmanr": fold_df["spearmanr"].mean(),
+                         "support": len(y_for_cv),
+                     }
+                 ]
+             )
+
+         print(f"\n{'='*50}")
+         print("Cross-Validation Summary")
+         print(f"{'='*50}")
+         print(metrics_df.to_string(index=False))
+
+         return metrics_df, predictions_df
+
+     finally:
+         log.info(f"Cleaning up model directory: {model_dir}")
+         shutil.rmtree(model_dir, ignore_errors=True)
+
+
+ if __name__ == "__main__":
+
+     # Tests for the ChemProp utilities
+     from workbench.api import Endpoint, Model
+
+     # Initialize Workbench model
+     model_name = "aqsol-chemprop-reg"
+     print(f"Loading Workbench model: {model_name}")
+     model = Model(model_name)
+     print(f"Model Framework: {model.model_framework}")
+
+     # Perform cross-fold inference
+     end = Endpoint(model.endpoints()[0])
+     end.cross_fold_inference()
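
Note: the following is a minimal usage sketch for the new chemprop_utils module, not part of the diff above. It assumes a trained ChemProp model is already registered in Workbench; the model name "aqsol-chemprop-reg" is borrowed from the module's own __main__ test block, and only the two public functions added in this file are used:

# Hypothetical usage sketch (assumes a deployed ChemProp regression model
# named "aqsol-chemprop-reg" exists in this Workbench account)
from workbench.api import Model
from workbench.utils.chemprop_utils import pull_cv_results, cross_fold_inference

model = Model("aqsol-chemprop-reg")

# Cheap path: reuse the validation predictions and metrics saved at training time
metrics_df, predictions_df = pull_cv_results(model)
print(metrics_df.to_string(index=False))

# Expensive path: rebuild the same MPNN architecture and run fresh 5-fold CV
cv_metrics_df, cv_predictions_df = cross_fold_inference(model, nfolds=5)
print(cv_predictions_df.head())

pull_cv_results only reads artifacts already stored in S3 and model metadata, while cross_fold_inference retrains the network nfolds times, so the latter can take significantly longer and benefits from a GPU.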