workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
- workbench/algorithms/dataframe/proximity.py +261 -235
- workbench/algorithms/graph/light/proximity_graph.py +10 -8
- workbench/api/__init__.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +11 -0
- workbench/api/feature_set.py +11 -8
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -15
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +256 -118
- workbench/core/artifacts/feature_set_core.py +265 -16
- workbench/core/artifacts/model_core.py +107 -60
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +42 -32
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/chemprop/chemprop.template +852 -0
- workbench/model_scripts/chemprop/generated_model_script.py +852 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
- workbench/model_scripts/pytorch_model/pytorch.template +370 -187
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +17 -9
- workbench/model_scripts/uq_models/generated_model_script.py +605 -0
- workbench/model_scripts/uq_models/mapie.template +605 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
- workbench/model_scripts/xgb_model/xgb_model.template +44 -46
- workbench/repl/workbench_shell.py +28 -14
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +760 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +95 -34
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +526 -0
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +371 -156
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +9 -7
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,760 @@
|
|
|
1
|
+
"""ChemProp utilities for Workbench models."""
|
|
2
|
+
|
|
3
|
+
# flake8: noqa: E402
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import tempfile
|
|
7
|
+
from typing import Any, Tuple
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
from scipy.stats import spearmanr
|
|
12
|
+
from sklearn.metrics import (
|
|
13
|
+
mean_absolute_error,
|
|
14
|
+
mean_squared_error,
|
|
15
|
+
median_absolute_error,
|
|
16
|
+
precision_recall_fscore_support,
|
|
17
|
+
r2_score,
|
|
18
|
+
roc_auc_score,
|
|
19
|
+
)
|
|
20
|
+
from sklearn.model_selection import KFold, StratifiedKFold
|
|
21
|
+
from sklearn.preprocessing import LabelEncoder
|
|
22
|
+
|
|
23
|
+
from workbench.utils.model_utils import safe_extract_tarfile
|
|
24
|
+
from workbench.utils.pandas_utils import expand_proba_column
|
|
25
|
+
from workbench.utils.aws_utils import pull_s3_data
|
|
26
|
+
|
|
27
|
+
log = logging.getLogger("workbench")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
    """Download a model artifact from S3 and extract it into ``model_dir``.

    The archive is downloaded to a uniquely named temporary file inside
    ``model_dir`` (so it cannot clash with a member of the archive), extracted with
    ``safe_extract_tarfile``, and then removed. The temporary archive is removed
    even when the download or the extraction raises.

    Args:
        s3_uri: S3 URI to the model artifact (model.tar.gz)
        model_dir: Directory to extract model artifacts to (must already exist)

    Raises:
        Whatever the S3 download or ``safe_extract_tarfile`` raise; the exception
        propagates after the temporary archive has been cleaned up.
    """
    import awswrangler as wr

    log.info(f"Downloading model from {s3_uri}...")

    # Download to a unique temp file: a fixed "model.tar.gz" name could collide with
    # an extracted member of the same name, which the cleanup below would then delete.
    fd, local_tar_path = tempfile.mkstemp(suffix=".tar.gz", dir=model_dir)
    os.close(fd)
    try:
        wr.s3.download(path=s3_uri, local_file=local_tar_path)

        # Extract using safe extraction
        log.info(f"Extracting to {model_dir}...")
        safe_extract_tarfile(local_tar_path, model_dir)
    finally:
        # Always clean up the tar file, even if download/extraction failed
        if os.path.exists(local_tar_path):
            os.unlink(local_tar_path)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def load_chemprop_model_artifacts(model_dir: str) -> Tuple[Any, dict]:
    """Load a ChemProp MPNN and its optional companion artifacts from ``model_dir``.

    Args:
        model_dir: Directory holding the extracted model artifacts

    Returns:
        Tuple of (MPNN model switched to eval mode, artifacts dict). The artifacts
        dict gets a 'label_encoder' entry when label_encoder.joblib exists and a
        'feature_metadata' entry when feature_metadata.joblib exists.

    Raises:
        FileNotFoundError: If chemprop_model.pt is not present in ``model_dir``.
    """
    import joblib
    from chemprop import models

    checkpoint_file = os.path.join(model_dir, "chemprop_model.pt")
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f"No chemprop_model.pt found in {model_dir}")

    mpnn = models.MPNN.load_from_file(checkpoint_file)
    mpnn.eval()

    # Optional joblib artifacts, as (dict key, file name), loaded only if present
    optional_files = (
        ("label_encoder", "label_encoder.joblib"),
        ("feature_metadata", "feature_metadata.joblib"),
    )
    extras = {}
    for key, filename in optional_files:
        candidate = os.path.join(model_dir, filename)
        if os.path.exists(candidate):
            extras[key] = joblib.load(candidate)

    return mpnn, extras
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _find_smiles_column(columns: list) -> str:
    """Return the first name in ``columns`` that equals 'smiles' ignoring case.

    Raises:
        ValueError: If no column name matches.
    """
    for candidate in columns:
        if candidate.lower() == "smiles":
            return candidate
    raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _create_molecule_datapoints(
    smiles_list: list,
    targets: list = None,
    extra_descriptors: np.ndarray = None,
) -> Tuple[list, list]:
    """Create ChemProp MoleculeDatapoints from SMILES strings, skipping invalid entries.

    An entry is skipped and counted when it is not a string (e.g. NaN/None coming
    from a DataFrame), when RDKit cannot parse it, or when it parses to a molecule
    with no atoms (e.g. ""). One warning reports the number of skipped entries.

    Args:
        smiles_list: List of SMILES strings
        targets: Optional list of target values aligned with ``smiles_list`` (for training)
        extra_descriptors: Optional array of extra features (n_samples, n_features),
            indexed by the same row position as ``smiles_list``

    Returns:
        Tuple of (list of MoleculeDatapoint objects, list of positions in
        ``smiles_list`` that produced them). Callers use the positions to re-align
        other per-row data with the kept molecules.
    """
    from chemprop import data
    from rdkit import Chem

    datapoints = []
    valid_indices = []
    invalid_count = 0

    for i, smi in enumerate(smiles_list):
        # RDKit raises (rather than returning None) for non-string input such as NaN
        if not isinstance(smi, str):
            invalid_count += 1
            continue

        # Validate SMILES with RDKit first; "" parses to an empty (0-atom) molecule
        mol = Chem.MolFromSmiles(smi)
        if mol is None or mol.GetNumAtoms() == 0:
            invalid_count += 1
            continue

        # Build datapoint with optional target and extra descriptors
        y = [targets[i]] if targets is not None else None
        x_d = extra_descriptors[i] if extra_descriptors is not None else None

        datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
        valid_indices.append(i)

    if invalid_count > 0:
        log.warning(f"Skipped {invalid_count} invalid SMILES strings")

    return datapoints, valid_indices
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _build_mpnn_model(
    hyperparameters: dict,
    task: str = "regression",
    num_classes: int = None,
    n_extra_descriptors: int = 0,
    x_d_transform: Any = None,
    output_transform: Any = None,
) -> Any:
    """Build an (untrained) MPNN model with the specified hyperparameters.

    Args:
        hyperparameters: Dictionary of model hyperparameters. Keys read (with
            defaults): hidden_dim (300), depth (3), dropout (0.1),
            ffn_hidden_dim (300), ffn_num_layers (1). Other keys are ignored.
        task: Either "regression" or "classification"
        num_classes: Number of classes; required when ``task`` is "classification"
        n_extra_descriptors: Number of extra descriptor features (for hybrid mode);
            added to the FFN input width
        x_d_transform: Optional transform for extra descriptors (scaling)
        output_transform: Optional transform for regression output (unscaling
            targets); not used for classification

    Returns:
        Configured MPNN model

    Raises:
        ValueError: If ``task`` is not "regression" or "classification", or if
            classification is requested without ``num_classes``.
    """
    # Validate first so a misspelled task, or a classification request without
    # num_classes, fails loudly instead of quietly building a regression FFN.
    if task not in ("regression", "classification"):
        raise ValueError(f"Unsupported task '{task}'; expected 'regression' or 'classification'")
    if task == "classification" and num_classes is None:
        raise ValueError("num_classes is required for classification tasks")

    from chemprop import models, nn

    # Model hyperparameters with defaults
    hidden_dim = hyperparameters.get("hidden_dim", 300)
    depth = hyperparameters.get("depth", 3)
    dropout = hyperparameters.get("dropout", 0.1)
    ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
    ffn_num_layers = hyperparameters.get("ffn_num_layers", 1)

    # Message passing component (the same dropout value is reused in the FFN)
    mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)

    # Aggregation - NormAggregation normalizes output, recommended when using extra descriptors
    agg = nn.NormAggregation()

    # FFN input_dim = message passing output + extra descriptors
    ffn_input_dim = hidden_dim + n_extra_descriptors

    if task == "classification":
        # Multi-class classification
        ffn = nn.MulticlassClassificationFFN(
            n_classes=num_classes,
            input_dim=ffn_input_dim,
            hidden_dim=ffn_hidden_dim,
            n_layers=ffn_num_layers,
            dropout=dropout,
        )
    else:
        # Regression with optional output transform to unscale predictions
        ffn = nn.RegressionFFN(
            input_dim=ffn_input_dim,
            hidden_dim=ffn_hidden_dim,
            n_layers=ffn_num_layers,
            dropout=dropout,
            output_transform=output_transform,
        )

    return models.MPNN(
        message_passing=mp,
        agg=agg,
        predictor=ffn,
        batch_norm=True,
        metrics=None,
        X_d_transform=x_d_transform,
    )
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _lookup_hparam(module: Any, names: tuple, default: Any, use_attributes: bool = True) -> Any:
    """Return the first of ``names`` found on ``module`` (attribute, then ``module.hparams``), else ``default``."""
    if use_attributes:
        for name in names:
            if hasattr(module, name):
                return getattr(module, name)
    hparams = getattr(module, "hparams", None)
    if hparams is not None:
        for name in names:
            try:
                return hparams[name]
            except (KeyError, TypeError):
                continue
    return default


def _extract_model_hyperparameters(loaded_model: Any) -> dict:
    """Extract hyperparameters from a loaded ChemProp MPNN model.

    The architecture parameters are read from the model's components so the same
    model configuration can be rebuilt during cross-validation. Each value is looked
    up first as a direct attribute of the relevant submodule, then in the
    submodule's ``hparams`` mapping if it has one, and only then falls back to the
    training-template default. NOTE(review): ChemProp v2 components appear to keep
    constructor args such as ``d_h``/``hidden_dim``/``n_layers`` in ``hparams``
    rather than as attributes; confirm against the installed chemprop version.

    Args:
        loaded_model: Loaded MPNN model instance (uses its ``message_passing`` and
            ``predictor`` submodules)

    Returns:
        Dictionary with keys hidden_dim, depth, dropout, ffn_hidden_dim,
        ffn_num_layers, max_epochs and patience, matching the training template
    """
    hyperparameters = {}

    # Message passing layer (BondMessagePassing)
    mp = loaded_model.message_passing
    hyperparameters["hidden_dim"] = _lookup_hparam(mp, ("d_h",), 300)
    hyperparameters["depth"] = _lookup_hparam(mp, ("depth",), 3)

    # Dropout may be stored as an nn.Dropout module (read its `p`) or as a number
    dropout = getattr(mp, "dropout", None)
    if hasattr(dropout, "p"):
        hyperparameters["dropout"] = dropout.p
    elif isinstance(dropout, (int, float)):
        hyperparameters["dropout"] = float(dropout)
    else:
        hyperparameters["dropout"] = _lookup_hparam(mp, ("dropout",), 0.0, use_attributes=False)

    # Predictor (RegressionFFN or MulticlassClassificationFFN) - try several names
    ffn = loaded_model.predictor
    hyperparameters["ffn_hidden_dim"] = _lookup_hparam(ffn, ("hidden_dim", "d_h"), 300)
    hyperparameters["ffn_num_layers"] = _lookup_hparam(ffn, ("n_layers", "num_layers"), 1)

    # Training hyperparameters (use defaults matching the template)
    hyperparameters["max_epochs"] = 50
    hyperparameters["patience"] = 10

    return hyperparameters
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _trailing_dim(values: Any) -> int:
    """Return the size of the last axis of an array/tensor, or ``len()`` for plain sequences."""
    shape = getattr(values, "shape", None)
    if shape is not None:
        # A 0-d array/tensor holds a single value
        return int(shape[-1]) if len(shape) > 0 else 1
    return len(values)


def _get_n_extra_descriptors(loaded_model: Any) -> int:
    """Get the number of extra descriptors the loaded model was trained with.

    The model's X_d_transform holds the descriptor scaling statistics, and the
    last dimension of those statistics is the descriptor count. The statistics
    may carry a leading axis (ChemProp v2's ScaleTransform appears to register its
    mean buffer as shape (1, n_features)), for which ``len()`` would report 1, so
    the last dimension is used instead.

    Args:
        loaded_model: Loaded MPNN model instance

    Returns:
        Number of extra descriptors (0 if there is no recognizable scaling transform)
    """
    x_d_transform = loaded_model.X_d_transform
    if x_d_transform is None:
        return 0

    # ScaleTransform exposes a `mean` tensor; a wrapper may instead hold a fitted
    # StandardScaler whose `mean_` array has one entry per feature
    stats = getattr(x_d_transform, "mean", None)
    if stats is None:
        scaler = getattr(x_d_transform, "scaler", None)
        stats = getattr(scaler, "mean_", None)
    if stats is None:
        return 0

    return _trailing_dim(stats)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Fetch the cross-validation outputs that were recorded during model training.

    Reads ``validation_predictions.csv`` from the model's training path on S3 and
    the ``workbench_training_metrics`` entry from the model's Workbench metadata.

    Args:
        workbench_model: Workbench model object

    Returns:
        Tuple of (training metrics DataFrame, validation predictions DataFrame)

    Raises:
        ValueError: If the validation predictions or the training metrics are missing.
    """
    predictions_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
    validation_df = pull_s3_data(predictions_path)
    if validation_df is None:
        raise ValueError(f"No validation predictions found at {predictions_path}")
    log.info(f"Pulled {len(validation_df)} validation predictions from {predictions_path}")

    model_meta = workbench_model.workbench_meta()
    recorded_metrics = model_meta.get("workbench_training_metrics")
    if recorded_metrics is None:
        raise ValueError(f"No training metrics found in model metadata for {workbench_model.model_name}")

    metrics_df = pd.DataFrame.from_dict(recorded_metrics)
    log.info(f"Metrics summary:\n{metrics_df.to_string(index=False)}")
    return metrics_df, validation_df
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def cross_fold_inference(
|
|
323
|
+
workbench_model: Any,
|
|
324
|
+
nfolds: int = 5,
|
|
325
|
+
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
|
326
|
+
"""Performs K-fold cross-validation for ChemProp MPNN models.
|
|
327
|
+
|
|
328
|
+
Replicates the training setup from the original model to ensure
|
|
329
|
+
cross-validation results are comparable to the deployed model.
|
|
330
|
+
|
|
331
|
+
Args:
|
|
332
|
+
workbench_model: Workbench model object
|
|
333
|
+
nfolds: Number of folds for cross-validation (default is 5)
|
|
334
|
+
|
|
335
|
+
Returns:
|
|
336
|
+
Tuple of:
|
|
337
|
+
- DataFrame with per-class metrics (and 'all' row for overall metrics)
|
|
338
|
+
- DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
|
|
339
|
+
"""
|
|
340
|
+
import shutil
|
|
341
|
+
|
|
342
|
+
import joblib
|
|
343
|
+
import torch
|
|
344
|
+
from chemprop import data, nn
|
|
345
|
+
from lightning import pytorch as pl
|
|
346
|
+
|
|
347
|
+
from workbench.api import FeatureSet
|
|
348
|
+
|
|
349
|
+
# Create a temporary model directory
|
|
350
|
+
model_dir = tempfile.mkdtemp(prefix="chemprop_cv_")
|
|
351
|
+
log.info(f"Using model directory: {model_dir}")
|
|
352
|
+
|
|
353
|
+
try:
|
|
354
|
+
# Download and extract model artifacts to get config and artifacts
|
|
355
|
+
model_artifact_uri = workbench_model.model_data_url()
|
|
356
|
+
download_and_extract_model(model_artifact_uri, model_dir)
|
|
357
|
+
|
|
358
|
+
# Load model and artifacts
|
|
359
|
+
loaded_model, artifacts = load_chemprop_model_artifacts(model_dir)
|
|
360
|
+
feature_metadata = artifacts.get("feature_metadata", {})
|
|
361
|
+
|
|
362
|
+
# Determine if classifier from predictor type
|
|
363
|
+
from chemprop.nn import MulticlassClassificationFFN
|
|
364
|
+
|
|
365
|
+
is_classifier = isinstance(loaded_model.predictor, MulticlassClassificationFFN)
|
|
366
|
+
|
|
367
|
+
# Use saved label encoder if available, otherwise create fresh one
|
|
368
|
+
if is_classifier:
|
|
369
|
+
label_encoder = artifacts.get("label_encoder")
|
|
370
|
+
if label_encoder is None:
|
|
371
|
+
log.warning("No saved label encoder found, creating fresh one")
|
|
372
|
+
label_encoder = LabelEncoder()
|
|
373
|
+
else:
|
|
374
|
+
label_encoder = None
|
|
375
|
+
|
|
376
|
+
# Prepare data
|
|
377
|
+
fs = FeatureSet(workbench_model.get_input())
|
|
378
|
+
df = workbench_model.training_view().pull_dataframe()
|
|
379
|
+
|
|
380
|
+
# Get columns
|
|
381
|
+
id_col = fs.id_column
|
|
382
|
+
target_col = workbench_model.target()
|
|
383
|
+
feature_cols = workbench_model.features()
|
|
384
|
+
print(f"Target column: {target_col}")
|
|
385
|
+
print(f"Feature columns: {len(feature_cols)} features")
|
|
386
|
+
|
|
387
|
+
# Find SMILES column
|
|
388
|
+
smiles_column = _find_smiles_column(feature_cols)
|
|
389
|
+
|
|
390
|
+
# Determine extra feature columns:
|
|
391
|
+
# 1. First try feature_metadata (saved during training)
|
|
392
|
+
# 2. Fall back to inferring from feature_cols (exclude SMILES column)
|
|
393
|
+
# 3. Verify against model's X_d_transform dimension
|
|
394
|
+
if feature_metadata and "extra_feature_cols" in feature_metadata:
|
|
395
|
+
extra_feature_cols = feature_metadata["extra_feature_cols"]
|
|
396
|
+
else:
|
|
397
|
+
# Infer from feature list - everything except SMILES is an extra feature
|
|
398
|
+
extra_feature_cols = [f for f in feature_cols if f.lower() != "smiles"]
|
|
399
|
+
|
|
400
|
+
# Verify against model's actual extra descriptor dimension
|
|
401
|
+
n_extra_from_model = _get_n_extra_descriptors(loaded_model)
|
|
402
|
+
if n_extra_from_model > 0 and len(extra_feature_cols) != n_extra_from_model:
|
|
403
|
+
log.warning(
|
|
404
|
+
f"Inferred {len(extra_feature_cols)} extra features but model expects "
|
|
405
|
+
f"{n_extra_from_model}. Using inferred columns."
|
|
406
|
+
)
|
|
407
|
+
|
|
408
|
+
use_extra_features = len(extra_feature_cols) > 0
|
|
409
|
+
|
|
410
|
+
print(f"SMILES column: {smiles_column}")
|
|
411
|
+
print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
|
|
412
|
+
|
|
413
|
+
# Drop rows with missing SMILES or target values
|
|
414
|
+
initial_count = len(df)
|
|
415
|
+
df = df.dropna(subset=[smiles_column, target_col])
|
|
416
|
+
dropped = initial_count - len(df)
|
|
417
|
+
if dropped > 0:
|
|
418
|
+
print(f"Dropped {dropped} rows with missing SMILES or target values")
|
|
419
|
+
|
|
420
|
+
# Extract hyperparameters from loaded model
|
|
421
|
+
hyperparameters = _extract_model_hyperparameters(loaded_model)
|
|
422
|
+
print(f"Extracted hyperparameters: {hyperparameters}")
|
|
423
|
+
|
|
424
|
+
# Get number of classes for classifier
|
|
425
|
+
num_classes = None
|
|
426
|
+
if is_classifier:
|
|
427
|
+
# Try to get from loaded model's FFN first (most reliable)
|
|
428
|
+
ffn = loaded_model.predictor
|
|
429
|
+
if hasattr(ffn, "n_classes"):
|
|
430
|
+
num_classes = ffn.n_classes
|
|
431
|
+
elif label_encoder is not None and hasattr(label_encoder, "classes_"):
|
|
432
|
+
num_classes = len(label_encoder.classes_)
|
|
433
|
+
else:
|
|
434
|
+
# Fit label encoder to get classes
|
|
435
|
+
if label_encoder is None:
|
|
436
|
+
label_encoder = LabelEncoder()
|
|
437
|
+
label_encoder.fit(df[target_col])
|
|
438
|
+
num_classes = len(label_encoder.classes_)
|
|
439
|
+
print(f"Classification task with {num_classes} classes")
|
|
440
|
+
|
|
441
|
+
X = df[[smiles_column] + extra_feature_cols]
|
|
442
|
+
y = df[target_col]
|
|
443
|
+
ids = df[id_col]
|
|
444
|
+
|
|
445
|
+
# Encode target if classifier
|
|
446
|
+
if label_encoder is not None:
|
|
447
|
+
if not hasattr(label_encoder, "classes_"):
|
|
448
|
+
label_encoder.fit(y)
|
|
449
|
+
y_encoded = label_encoder.transform(y)
|
|
450
|
+
y_for_cv = pd.Series(y_encoded, index=y.index, name=target_col)
|
|
451
|
+
else:
|
|
452
|
+
y_for_cv = y
|
|
453
|
+
|
|
454
|
+
# Prepare KFold
|
|
455
|
+
kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
|
|
456
|
+
|
|
457
|
+
# Initialize results collection
|
|
458
|
+
fold_metrics = []
|
|
459
|
+
predictions_df = pd.DataFrame({id_col: ids, target_col: y})
|
|
460
|
+
if is_classifier:
|
|
461
|
+
predictions_df["pred_proba"] = [None] * len(predictions_df)
|
|
462
|
+
|
|
463
|
+
# Perform cross-validation
|
|
464
|
+
for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
|
|
465
|
+
print(f"\n{'='*50}")
|
|
466
|
+
print(f"Fold {fold_idx}/{nfolds}")
|
|
467
|
+
print(f"{'='*50}")
|
|
468
|
+
|
|
469
|
+
# Split data
|
|
470
|
+
df_train = df.iloc[train_idx].copy()
|
|
471
|
+
df_val = df.iloc[val_idx].copy()
|
|
472
|
+
|
|
473
|
+
# Encode target for this fold
|
|
474
|
+
if is_classifier:
|
|
475
|
+
df_train[target_col] = label_encoder.transform(df_train[target_col])
|
|
476
|
+
df_val[target_col] = label_encoder.transform(df_val[target_col])
|
|
477
|
+
|
|
478
|
+
# Prepare extra features if using hybrid mode
|
|
479
|
+
train_extra_features = None
|
|
480
|
+
val_extra_features = None
|
|
481
|
+
col_means = None
|
|
482
|
+
|
|
483
|
+
if use_extra_features:
|
|
484
|
+
train_extra_features = df_train[extra_feature_cols].values.astype(np.float32)
|
|
485
|
+
val_extra_features = df_val[extra_feature_cols].values.astype(np.float32)
|
|
486
|
+
|
|
487
|
+
# Fill NaN with column means from training data
|
|
488
|
+
col_means = np.nanmean(train_extra_features, axis=0)
|
|
489
|
+
for i in range(train_extra_features.shape[1]):
|
|
490
|
+
train_nan_mask = np.isnan(train_extra_features[:, i])
|
|
491
|
+
val_nan_mask = np.isnan(val_extra_features[:, i])
|
|
492
|
+
train_extra_features[train_nan_mask, i] = col_means[i]
|
|
493
|
+
val_extra_features[val_nan_mask, i] = col_means[i]
|
|
494
|
+
|
|
495
|
+
# Create ChemProp datasets
|
|
496
|
+
train_datapoints, train_valid_idx = _create_molecule_datapoints(
|
|
497
|
+
df_train[smiles_column].tolist(),
|
|
498
|
+
df_train[target_col].tolist(),
|
|
499
|
+
train_extra_features,
|
|
500
|
+
)
|
|
501
|
+
val_datapoints, val_valid_idx = _create_molecule_datapoints(
|
|
502
|
+
df_val[smiles_column].tolist(),
|
|
503
|
+
df_val[target_col].tolist(),
|
|
504
|
+
val_extra_features,
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
# Update dataframes to only include valid molecules
|
|
508
|
+
df_train_valid = df_train.iloc[train_valid_idx].reset_index(drop=True)
|
|
509
|
+
df_val_valid = df_val.iloc[val_valid_idx].reset_index(drop=True)
|
|
510
|
+
|
|
511
|
+
train_dataset = data.MoleculeDataset(train_datapoints)
|
|
512
|
+
val_dataset = data.MoleculeDataset(val_datapoints)
|
|
513
|
+
|
|
514
|
+
# Save raw validation features before scaling
|
|
515
|
+
val_extra_raw = val_extra_features[val_valid_idx] if val_extra_features is not None else None
|
|
516
|
+
|
|
517
|
+
# Scale extra descriptors
|
|
518
|
+
feature_scaler = None
|
|
519
|
+
x_d_transform = None
|
|
520
|
+
if use_extra_features:
|
|
521
|
+
feature_scaler = train_dataset.normalize_inputs("X_d")
|
|
522
|
+
val_dataset.normalize_inputs("X_d", feature_scaler)
|
|
523
|
+
x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
|
|
524
|
+
|
|
525
|
+
# Scale targets for regression
|
|
526
|
+
target_scaler = None
|
|
527
|
+
output_transform = None
|
|
528
|
+
if not is_classifier:
|
|
529
|
+
target_scaler = train_dataset.normalize_targets()
|
|
530
|
+
val_dataset.normalize_targets(target_scaler)
|
|
531
|
+
output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
|
|
532
|
+
|
|
533
|
+
# Get batch size
|
|
534
|
+
batch_size = min(64, max(16, len(df_train_valid) // 16))
|
|
535
|
+
|
|
536
|
+
train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
|
|
537
|
+
val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
|
|
538
|
+
|
|
539
|
+
# Build the model
|
|
540
|
+
n_extra = len(extra_feature_cols) if use_extra_features else 0
|
|
541
|
+
mpnn = _build_mpnn_model(
|
|
542
|
+
hyperparameters,
|
|
543
|
+
task="classification" if is_classifier else "regression",
|
|
544
|
+
num_classes=num_classes,
|
|
545
|
+
n_extra_descriptors=n_extra,
|
|
546
|
+
x_d_transform=x_d_transform,
|
|
547
|
+
output_transform=output_transform,
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
# Training configuration
|
|
551
|
+
max_epochs = hyperparameters.get("max_epochs", 50)
|
|
552
|
+
patience = hyperparameters.get("patience", 10)
|
|
553
|
+
|
|
554
|
+
# Set up trainer
|
|
555
|
+
checkpoint_dir = os.path.join(model_dir, f"fold_{fold_idx}")
|
|
556
|
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
557
|
+
|
|
558
|
+
callbacks = [
|
|
559
|
+
pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
|
|
560
|
+
pl.callbacks.ModelCheckpoint(
|
|
561
|
+
dirpath=checkpoint_dir,
|
|
562
|
+
filename="best_model",
|
|
563
|
+
monitor="val_loss",
|
|
564
|
+
mode="min",
|
|
565
|
+
save_top_k=1,
|
|
566
|
+
),
|
|
567
|
+
]
|
|
568
|
+
|
|
569
|
+
trainer = pl.Trainer(
|
|
570
|
+
accelerator="auto",
|
|
571
|
+
max_epochs=max_epochs,
|
|
572
|
+
callbacks=callbacks,
|
|
573
|
+
logger=False,
|
|
574
|
+
enable_progress_bar=True,
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
# Train the model
|
|
578
|
+
trainer.fit(mpnn, train_loader, val_loader)
|
|
579
|
+
|
|
580
|
+
# Load the best checkpoint
|
|
581
|
+
if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
|
|
582
|
+
best_ckpt_path = trainer.checkpoint_callback.best_model_path
|
|
583
|
+
checkpoint = torch.load(best_ckpt_path, weights_only=False)
|
|
584
|
+
mpnn.load_state_dict(checkpoint["state_dict"])
|
|
585
|
+
|
|
586
|
+
mpnn.eval()
|
|
587
|
+
|
|
588
|
+
# Make predictions using raw features
|
|
589
|
+
val_datapoints_raw, _ = _create_molecule_datapoints(
|
|
590
|
+
df_val_valid[smiles_column].tolist(),
|
|
591
|
+
df_val_valid[target_col].tolist(),
|
|
592
|
+
val_extra_raw,
|
|
593
|
+
)
|
|
594
|
+
val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
|
|
595
|
+
val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
|
|
596
|
+
|
|
597
|
+
with torch.inference_mode():
|
|
598
|
+
val_predictions = trainer.predict(mpnn, val_loader_pred)
|
|
599
|
+
|
|
600
|
+
preds = np.concatenate([p.numpy() for p in val_predictions], axis=0)
|
|
601
|
+
|
|
602
|
+
# ChemProp may return (n_samples, 1, n_classes) for multiclass - squeeze middle dim
|
|
603
|
+
if preds.ndim == 3 and preds.shape[1] == 1:
|
|
604
|
+
preds = preds.squeeze(axis=1)
|
|
605
|
+
|
|
606
|
+
# Map predictions back to original indices
|
|
607
|
+
original_val_indices = df.iloc[val_idx].index[val_valid_idx]
|
|
608
|
+
|
|
609
|
+
if is_classifier:
|
|
610
|
+
# Get class predictions
|
|
611
|
+
if preds.ndim == 2 and preds.shape[1] > 1:
|
|
612
|
+
class_preds = np.argmax(preds, axis=1)
|
|
613
|
+
else:
|
|
614
|
+
class_preds = (preds.flatten() > 0.5).astype(int)
|
|
615
|
+
|
|
616
|
+
preds_decoded = label_encoder.inverse_transform(class_preds)
|
|
617
|
+
predictions_df.loc[original_val_indices, "prediction"] = preds_decoded
|
|
618
|
+
|
|
619
|
+
# Store probabilities
|
|
620
|
+
if preds.ndim == 2 and preds.shape[1] > 1:
|
|
621
|
+
for i, idx in enumerate(original_val_indices):
|
|
622
|
+
predictions_df.at[idx, "pred_proba"] = preds[i].tolist()
|
|
623
|
+
else:
|
|
624
|
+
predictions_df.loc[original_val_indices, "prediction"] = preds.flatten()
|
|
625
|
+
|
|
626
|
+
# Calculate fold metrics
|
|
627
|
+
y_val = df_val_valid[target_col].values
|
|
628
|
+
|
|
629
|
+
if is_classifier:
|
|
630
|
+
y_val_orig = label_encoder.inverse_transform(y_val.astype(int))
|
|
631
|
+
preds_orig = preds_decoded
|
|
632
|
+
|
|
633
|
+
prec, rec, f1, _ = precision_recall_fscore_support(
|
|
634
|
+
y_val_orig, preds_orig, average="weighted", zero_division=0
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
|
|
638
|
+
y_val_orig, preds_orig, average=None, zero_division=0, labels=label_encoder.classes_
|
|
639
|
+
)
|
|
640
|
+
|
|
641
|
+
# ROC AUC
|
|
642
|
+
if preds.ndim == 2 and preds.shape[1] > 1:
|
|
643
|
+
roc_auc_overall = roc_auc_score(y_val, preds, multi_class="ovr", average="macro")
|
|
644
|
+
roc_auc_per_class = roc_auc_score(y_val, preds, multi_class="ovr", average=None)
|
|
645
|
+
else:
|
|
646
|
+
roc_auc_overall = roc_auc_score(y_val, preds.flatten())
|
|
647
|
+
roc_auc_per_class = [roc_auc_overall]
|
|
648
|
+
|
|
649
|
+
fold_metrics.append(
|
|
650
|
+
{
|
|
651
|
+
"fold": fold_idx,
|
|
652
|
+
"precision": prec,
|
|
653
|
+
"recall": rec,
|
|
654
|
+
"f1": f1,
|
|
655
|
+
"roc_auc": roc_auc_overall,
|
|
656
|
+
"precision_per_class": prec_per_class,
|
|
657
|
+
"recall_per_class": rec_per_class,
|
|
658
|
+
"f1_per_class": f1_per_class,
|
|
659
|
+
"roc_auc_per_class": roc_auc_per_class,
|
|
660
|
+
}
|
|
661
|
+
)
|
|
662
|
+
|
|
663
|
+
print(f"Fold {fold_idx} - F1: {f1:.4f}, ROC-AUC: {roc_auc_overall:.4f}")
|
|
664
|
+
else:
|
|
665
|
+
spearman_corr, _ = spearmanr(y_val, preds.flatten())
|
|
666
|
+
rmse = np.sqrt(mean_squared_error(y_val, preds.flatten()))
|
|
667
|
+
|
|
668
|
+
fold_metrics.append(
|
|
669
|
+
{
|
|
670
|
+
"fold": fold_idx,
|
|
671
|
+
"rmse": rmse,
|
|
672
|
+
"mae": mean_absolute_error(y_val, preds.flatten()),
|
|
673
|
+
"medae": median_absolute_error(y_val, preds.flatten()),
|
|
674
|
+
"r2": r2_score(y_val, preds.flatten()),
|
|
675
|
+
"spearmanr": spearman_corr,
|
|
676
|
+
}
|
|
677
|
+
)
|
|
678
|
+
|
|
679
|
+
print(f"Fold {fold_idx} - RMSE: {rmse:.4f}, R2: {fold_metrics[-1]['r2']:.4f}")
|
|
680
|
+
|
|
681
|
+
# Calculate summary metrics
|
|
682
|
+
fold_df = pd.DataFrame(fold_metrics)
|
|
683
|
+
|
|
684
|
+
if is_classifier:
|
|
685
|
+
if "pred_proba" in predictions_df.columns:
|
|
686
|
+
predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
|
|
687
|
+
|
|
688
|
+
metric_rows = []
|
|
689
|
+
for idx, class_name in enumerate(label_encoder.classes_):
|
|
690
|
+
prec_scores = np.array([fold["precision_per_class"][idx] for fold in fold_metrics])
|
|
691
|
+
rec_scores = np.array([fold["recall_per_class"][idx] for fold in fold_metrics])
|
|
692
|
+
f1_scores = np.array([fold["f1_per_class"][idx] for fold in fold_metrics])
|
|
693
|
+
roc_auc_scores = np.array([fold["roc_auc_per_class"][idx] for fold in fold_metrics])
|
|
694
|
+
|
|
695
|
+
y_orig = label_encoder.inverse_transform(y_for_cv)
|
|
696
|
+
support = int((y_orig == class_name).sum())
|
|
697
|
+
|
|
698
|
+
metric_rows.append(
|
|
699
|
+
{
|
|
700
|
+
"class": class_name,
|
|
701
|
+
"precision": prec_scores.mean(),
|
|
702
|
+
"recall": rec_scores.mean(),
|
|
703
|
+
"f1": f1_scores.mean(),
|
|
704
|
+
"roc_auc": roc_auc_scores.mean(),
|
|
705
|
+
"support": support,
|
|
706
|
+
}
|
|
707
|
+
)
|
|
708
|
+
|
|
709
|
+
metric_rows.append(
|
|
710
|
+
{
|
|
711
|
+
"class": "all",
|
|
712
|
+
"precision": fold_df["precision"].mean(),
|
|
713
|
+
"recall": fold_df["recall"].mean(),
|
|
714
|
+
"f1": fold_df["f1"].mean(),
|
|
715
|
+
"roc_auc": fold_df["roc_auc"].mean(),
|
|
716
|
+
"support": len(y_for_cv),
|
|
717
|
+
}
|
|
718
|
+
)
|
|
719
|
+
|
|
720
|
+
metrics_df = pd.DataFrame(metric_rows)
|
|
721
|
+
else:
|
|
722
|
+
metrics_df = pd.DataFrame(
|
|
723
|
+
[
|
|
724
|
+
{
|
|
725
|
+
"rmse": fold_df["rmse"].mean(),
|
|
726
|
+
"mae": fold_df["mae"].mean(),
|
|
727
|
+
"medae": fold_df["medae"].mean(),
|
|
728
|
+
"r2": fold_df["r2"].mean(),
|
|
729
|
+
"spearmanr": fold_df["spearmanr"].mean(),
|
|
730
|
+
"support": len(y_for_cv),
|
|
731
|
+
}
|
|
732
|
+
]
|
|
733
|
+
)
|
|
734
|
+
|
|
735
|
+
print(f"\n{'='*50}")
|
|
736
|
+
print("Cross-Validation Summary")
|
|
737
|
+
print(f"{'='*50}")
|
|
738
|
+
print(metrics_df.to_string(index=False))
|
|
739
|
+
|
|
740
|
+
return metrics_df, predictions_df
|
|
741
|
+
|
|
742
|
+
finally:
|
|
743
|
+
log.info(f"Cleaning up model directory: {model_dir}")
|
|
744
|
+
shutil.rmtree(model_dir, ignore_errors=True)
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
if __name__ == "__main__":
|
|
748
|
+
|
|
749
|
+
# Tests for the ChemProp utilities
|
|
750
|
+
from workbench.api import Endpoint, Model
|
|
751
|
+
|
|
752
|
+
# Initialize Workbench model
|
|
753
|
+
model_name = "aqsol-chemprop-reg"
|
|
754
|
+
print(f"Loading Workbench model: {model_name}")
|
|
755
|
+
model = Model(model_name)
|
|
756
|
+
print(f"Model Framework: {model.model_framework}")
|
|
757
|
+
|
|
758
|
+
# Perform cross-fold inference
|
|
759
|
+
end = Endpoint(model.endpoints()[0])
|
|
760
|
+
end.cross_fold_inference()
|