workbench 0.8.213__py3-none-any.whl → 0.8.217__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
- workbench/algorithms/dataframe/projection_2d.py +38 -21
- workbench/algorithms/dataframe/proximity.py +75 -150
- workbench/algorithms/graph/light/proximity_graph.py +5 -5
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +2 -2
- workbench/api/__init__.py +3 -0
- workbench/api/endpoint.py +10 -5
- workbench/api/feature_set.py +76 -6
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +43 -4
- workbench/core/artifacts/endpoint_core.py +63 -115
- workbench/core/artifacts/feature_set_core.py +1 -1
- workbench/core/artifacts/model_core.py +6 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
- workbench/model_script_utils/pytorch_utils.py +11 -1
- workbench/model_scripts/chemprop/chemprop.template +145 -69
- workbench/model_scripts/chemprop/generated_model_script.py +147 -71
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +7 -3
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +6 -6
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +42 -24
- workbench/model_scripts/pytorch_model/pytorch.template +42 -24
- workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
- workbench/model_scripts/script_generation.py +4 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +169 -158
- workbench/model_scripts/xgb_model/xgb_model.template +163 -152
- workbench/repl/workbench_shell.py +0 -5
- workbench/scripts/endpoint_test.py +2 -2
- workbench/utils/chem_utils/fingerprints.py +7 -3
- workbench/utils/chemprop_utils.py +23 -5
- workbench/utils/meta_model_simulator.py +471 -0
- workbench/utils/metrics_utils.py +94 -10
- workbench/utils/model_utils.py +91 -9
- workbench/utils/pytorch_utils.py +1 -1
- workbench/web_interface/components/plugins/scatter_plot.py +4 -8
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/METADATA +2 -1
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/RECORD +48 -43
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/WHEEL +0 -0
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/top_level.txt +0 -0

workbench/core/pipelines/pipeline_executor.py:

```diff
@@ -123,7 +123,7 @@ class PipelineExecutor:
         if "model" in workbench_objects and (not subset or "endpoint" in subset):
             workbench_objects["model"].to_endpoint(**kwargs)
             endpoint = Endpoint(kwargs["name"])
-            endpoint.auto_inference(
+            endpoint.auto_inference()
 
         # Found something weird
         else:
```
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py:

```diff
@@ -1,6 +1,7 @@
 """ModelToEndpoint: Deploy an Endpoint for a Model"""
 
 import time
+from botocore.exceptions import ClientError
 from sagemaker import ModelPackage
 from sagemaker.serializers import CSVSerializer
 from sagemaker.deserializers import CSVDeserializer
```
```diff
@@ -137,16 +138,35 @@ class ModelToEndpoint(Transform):
 
         # Deploy the Endpoint
         self.log.important(f"Deploying the Endpoint {self.output_name}...")
-        model_package.deploy(
-            initial_instance_count=1,
-            instance_type=self.instance_type,
-            serverless_inference_config=serverless_config,
-            endpoint_name=self.output_name,
-            serializer=CSVSerializer(),
-            deserializer=CSVDeserializer(),
-            data_capture_config=data_capture_config,
-            tags=aws_tags,
-        )
+        try:
+            model_package.deploy(
+                initial_instance_count=1,
+                instance_type=self.instance_type,
+                serverless_inference_config=serverless_config,
+                endpoint_name=self.output_name,
+                serializer=CSVSerializer(),
+                deserializer=CSVDeserializer(),
+                data_capture_config=data_capture_config,
+                tags=aws_tags,
+            )
+        except ClientError as e:
+            # Check if this is the "endpoint config already exists" error
+            if "Cannot create already existing endpoint configuration" in str(e):
+                self.log.warning("Endpoint config already exists, deleting and retrying...")
+                self.sm_client.delete_endpoint_config(EndpointConfigName=self.output_name)
+                # Retry the deploy
+                model_package.deploy(
+                    initial_instance_count=1,
+                    instance_type=self.instance_type,
+                    serverless_inference_config=serverless_config,
+                    endpoint_name=self.output_name,
+                    serializer=CSVSerializer(),
+                    deserializer=CSVDeserializer(),
+                    data_capture_config=data_capture_config,
+                    tags=aws_tags,
+                )
+            else:
+                raise
 
     def post_transform(self, **kwargs):
         """Post-Transform: Calling onboard() for the Endpoint"""
```
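Note: the try/except added here works around a SageMaker quirk — endpoint-config creation fails when a configuration with the endpoint's name already exists (typically left behind by an earlier failed or deleted deploy). A minimal standalone sketch of the same delete-and-retry pattern; `deploy_fn`, `sm_client`, and `endpoint_name` are illustrative stand-ins, not names from this package:

```python
# Sketch: retry a SageMaker deploy once after clearing a stale endpoint config.
# Assumes sm_client = boto3.client("sagemaker") and deploy_fn wraps
# model_package.deploy(...) with the desired arguments.
from botocore.exceptions import ClientError


def deploy_with_config_retry(deploy_fn, sm_client, endpoint_name: str) -> None:
    try:
        deploy_fn()
    except ClientError as e:
        # SageMaker names the endpoint config after the endpoint, so a stale
        # config from a previous deploy blocks re-creation with this message.
        if "Cannot create already existing endpoint configuration" in str(e):
            sm_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
            deploy_fn()  # retry once with the stale config removed
        else:
            raise
```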
workbench/model_script_utils/pytorch_utils.py (mirrored in the pytorch model template and generated script):

```diff
@@ -245,6 +245,7 @@ def train_model(
     patience: int = 20,
     batch_size: int = 128,
     learning_rate: float = 1e-3,
+    loss: str = "L1Loss",
     device: str = "cpu",
 ) -> tuple[TabularMLP, dict]:
     """Train the model with early stopping.
```
```diff
@@ -272,7 +273,16 @@ def train_model(
     if task == "classification":
         criterion = nn.CrossEntropyLoss()
     else:
-        [one removed line; content not captured in this diff view]
+        # Map loss name to PyTorch loss class
+        loss_map = {
+            "L1Loss": nn.L1Loss,
+            "MSELoss": nn.MSELoss,
+            "HuberLoss": nn.HuberLoss,
+            "SmoothL1Loss": nn.SmoothL1Loss,
+        }
+        if loss not in loss_map:
+            raise ValueError(f"Unknown loss '{loss}'. Supported: {list(loss_map.keys())}")
+        criterion = loss_map[loss]()
 
     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
```
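Note: the new `loss` parameter simply selects among four standard PyTorch regression criteria. A quick sketch of how the mapping resolves (values illustrative; the rest of `train_model`'s arguments are omitted):

```python
import torch
import torch.nn as nn

# Same mapping as the diff above: string hyperparameter -> loss class
loss_map = {
    "L1Loss": nn.L1Loss,
    "MSELoss": nn.MSELoss,
    "HuberLoss": nn.HuberLoss,
    "SmoothL1Loss": nn.SmoothL1Loss,
}
criterion = loss_map["HuberLoss"]()  # e.g. train_model(..., loss="HuberLoss")

pred = torch.tensor([0.0, 10.0])
target = torch.tensor([0.0, 0.0])
print(criterion(pred, target))     # Huber penalizes the outlier linearly: 4.75
print(nn.MSELoss()(pred, target))  # MSE penalizes it quadratically: 50.0
```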
workbench/model_scripts/chemprop/chemprop.template (mirrored in generated_model_script.py):

```diff
@@ -5,36 +5,24 @@
 # - Multi-task regression support
 # - Hybrid mode (SMILES + extra molecular descriptors)
 # - Classification (single-target only)
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (lightning, sklearn, awswrangler) are deferred to training time.
 
-import argparse
-import glob
 import json
 import os
 
-import awswrangler as wr
 import joblib
 import numpy as np
 import pandas as pd
 import torch
-from lightning import pytorch as pl
-from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
-from sklearn.preprocessing import LabelEncoder
-
-# Enable Tensor Core optimization for GPUs that support it
-torch.set_float32_matmul_precision("medium")
 
-from chemprop import data, models, nn
+from chemprop import data, models
 
 from model_script_utils import (
-    check_dataframe,
-    compute_classification_metrics,
-    compute_regression_metrics,
     expand_proba_column,
     input_fn,
     output_fn,
-    print_classification_metrics,
-    print_confusion_matrix,
-    print_regression_metrics,
 )
 
 # =============================================================================
```
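Note: this import restructuring is a cold-start optimization. On a serverless endpoint only `model_fn`/`predict_fn` execute, so training-only dependencies (argparse, glob, awswrangler, lightning, sklearn) move inside the `__main__` block instead of module scope. A generic sketch of the pattern (names illustrative, not from this package):

```python
# Module scope: only what inference needs -> cheap import on serverless cold start
import json
import os


def model_fn(model_dir: str) -> dict:
    # Heavy-but-required inference deps can still be deferred to load time,
    # paying their import cost once per container rather than per module import.
    from lightning import pytorch as pl
    ...


if __name__ == "__main__":
    # Training path: these imports only run on the training container
    import awswrangler as wr
    from sklearn.model_selection import train_test_split
    ...
```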
```diff
@@ -44,15 +32,17 @@ DEFAULT_HYPERPARAMETERS = {
     # Training
     "n_folds": 5,
     "max_epochs": 400,
-    "patience":
-    "batch_size":
+    "patience": 50,
+    "batch_size": 32,
     # Message Passing
     "hidden_dim": 700,
     "depth": 6,
-    "dropout": 0.
+    "dropout": 0.1,  # Lower dropout - ensemble provides regularization
     # FFN
     "ffn_hidden_dim": 2000,
     "ffn_num_layers": 2,
+    # Loss function for regression (mae, mse)
+    "criterion": "mae",
     # Random seed
     "seed": 42,
 }
```
```diff
@@ -71,7 +61,26 @@ TEMPLATE_PARAMS = {
 # =============================================================================
 # Helper Functions
 # =============================================================================
-def find_smiles_column(columns: list[str]) -> str:
+def _compute_std_confidence(df: pd.DataFrame, median_std: float, std_col: str = "prediction_std") -> pd.DataFrame:
+    """Compute confidence score from ensemble prediction_std.
+
+    Uses exponential decay: confidence = exp(-std / median_std)
+    - Low std (ensemble agreement) -> high confidence
+    - High std (ensemble disagreement) -> low confidence
+
+    Args:
+        df: DataFrame with prediction_std column
+        median_std: Median std from training validation set (normalization factor)
+        std_col: Name of the std column to use
+
+    Returns:
+        DataFrame with added 'confidence' column (0.0 to 1.0)
+    """
+    df["confidence"] = np.exp(-df[std_col] / median_std)
+    return df
+
+
+def _find_smiles_column(columns: list[str]) -> str:
     """Find SMILES column (case-insensitive match for 'smiles')."""
     smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
     if smiles_col is None:
```
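Note: the exponential-decay mapping is easy to sanity-check by hand — a row whose ensemble std equals the training median gets confidence exp(-1) ≈ 0.37, perfect ensemble agreement gives 1.0, and 3× the median decays to about 0.05. A small numeric check (values illustrative):

```python
import numpy as np
import pandas as pd

median_std = 0.2  # illustrative: median ensemble std from the validation set
df = pd.DataFrame({"prediction_std": [0.0, 0.1, 0.2, 0.6]})

# Same formula as _compute_std_confidence above
df["confidence"] = np.exp(-df["prediction_std"] / median_std)
print(df)
#    prediction_std  confidence
# 0             0.0    1.000000
# 1             0.1    0.606531
# 2             0.2    0.367879
# 3             0.6    0.049787
```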
```diff
@@ -79,7 +88,7 @@ def find_smiles_column(columns: list[str]) -> str:
     return smiles_col
 
 
-def create_molecule_datapoints(
+def _create_molecule_datapoints(
     smiles_list: list[str],
     targets: np.ndarray | None = None,
     extra_descriptors: np.ndarray | None = None,
```
```diff
@@ -101,47 +110,13 @@ def create_molecule_datapoints(
     return datapoints, valid_indices
 
 
-def build_mpnn_model(
-    hyperparameters: dict,
-    task: str = "regression",
-    num_classes: int | None = None,
-    n_targets: int = 1,
-    n_extra_descriptors: int = 0,
-    x_d_transform: nn.ScaleTransform | None = None,
-    output_transform: nn.UnscaleTransform | None = None,
-    task_weights: np.ndarray | None = None,
-) -> models.MPNN:
-    """Build an MPNN model with specified hyperparameters."""
-    hidden_dim = hyperparameters["hidden_dim"]
-    depth = hyperparameters["depth"]
-    dropout = hyperparameters["dropout"]
-    ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
-    ffn_num_layers = hyperparameters["ffn_num_layers"]
-
-    mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-    agg = nn.NormAggregation()
-    ffn_input_dim = hidden_dim + n_extra_descriptors
-
-    if task == "classification" and num_classes is not None:
-        ffn = nn.MulticlassClassificationFFN(
-            n_classes=num_classes, input_dim=ffn_input_dim,
-            hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
-        )
-    else:
-        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
-        ffn = nn.RegressionFFN(
-            input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
-            dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
-        )
-
-    return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
-
-
 # =============================================================================
 # Model Loading (for SageMaker inference)
 # =============================================================================
 def model_fn(model_dir: str) -> dict:
     """Load ChemProp MPNN ensemble from the specified directory."""
+    from lightning import pytorch as pl
+
     metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
     ensemble_models = []
     for i in range(metadata["n_ensemble"]):
```
```diff
@@ -149,8 +124,17 @@ def model_fn(model_dir: str) -> dict:
         model.eval()
         ensemble_models.append(model)
 
+    # Pre-initialize trainer once during model loading (expensive operation)
+    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
     print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
-    return {
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": metadata["n_ensemble"],
+        "target_columns": metadata["target_columns"],
+        "median_std": metadata["median_std"],
+        "trainer": trainer,
+    }
 
 
 # =============================================================================
```
```diff
@@ -163,6 +147,7 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
 
     ensemble_models = model_dict["ensemble_models"]
     target_columns = model_dict["target_columns"]
+    trainer = model_dict["trainer"]  # Use pre-initialized trainer
 
     # Load artifacts
     label_encoder = None
```
```diff
@@ -177,7 +162,7 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")
 
     # Find SMILES column and validate
-    smiles_column = find_smiles_column(df.columns.tolist())
+    smiles_column = _find_smiles_column(df.columns.tolist())
    smiles_list = df[smiles_column].tolist()
 
     valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
```
```diff
@@ -212,13 +197,12 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
                 extra_features[:, j] = col_means[j]
 
     # Create datapoints and predict
-    datapoints, rdkit_valid = create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
+    datapoints, rdkit_valid = _create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
     if len(datapoints) == 0:
         return df
 
     dataset = data.MoleculeDataset(datapoints)
     dataloader = data.build_dataloader(dataset, shuffle=False)
-    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
 
     # Ensemble predictions
     all_preds = []
```
```diff
@@ -259,6 +243,9 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
     df["prediction"] = df[f"{target_columns[0]}_pred"]
     df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
 
+    # Compute confidence from ensemble std
+    df = _compute_std_confidence(df, model_dict["median_std"])
+
     return df
 
 
```
```diff
@@ -266,6 +253,82 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
 # Training
 # =============================================================================
 if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+    import glob
+
+    import awswrangler as wr
+    from lightning import pytorch as pl
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    # Enable Tensor Core optimization for GPUs that support it
+    torch.set_float32_matmul_precision("medium")
+
+    from chemprop import nn
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+
+    # -------------------------------------------------------------------------
+    # Training-only helper function
+    # -------------------------------------------------------------------------
+    def build_mpnn_model(
+        hyperparameters: dict,
+        task: str = "regression",
+        num_classes: int | None = None,
+        n_targets: int = 1,
+        n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None,
+        task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build an MPNN model with specified hyperparameters."""
+        hidden_dim = hyperparameters["hidden_dim"]
+        depth = hyperparameters["depth"]
+        dropout = hyperparameters["dropout"]
+        ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
+        ffn_num_layers = hyperparameters["ffn_num_layers"]
+
+        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+        agg = nn.NormAggregation()
+        ffn_input_dim = hidden_dim + n_extra_descriptors
+
+        if task == "classification" and num_classes is not None:
+            ffn = nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=ffn_input_dim,
+                hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
+            )
+        else:
+            # Map criterion name to ChemProp metric class (must have .clone() method)
+            from chemprop.nn.metrics import MAE, MSE
+
+            criterion_map = {
+                "mae": MAE,
+                "mse": MSE,
+            }
+            criterion_name = hyperparameters.get("criterion", "mae")
+            if criterion_name not in criterion_map:
+                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+            criterion = criterion_map[criterion_name]()
+
+            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+            ffn = nn.RegressionFFN(
+                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
+                criterion=criterion,
+            )
+
+        return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
+
     # -------------------------------------------------------------------------
     # Setup: Parse arguments and load data
     # -------------------------------------------------------------------------
```
```diff
@@ -287,7 +350,7 @@ if __name__ == "__main__":
         raise ValueError("'targets' must be a non-empty list of target column names")
     n_targets = len(target_columns)
 
-    smiles_column = find_smiles_column(feature_list)
+    smiles_column = _find_smiles_column(feature_list)
     extra_feature_cols = [f for f in feature_list if f != smiles_column]
     use_extra_features = len(extra_feature_cols) > 0
 
```
```diff
@@ -342,7 +405,7 @@ if __name__ == "__main__":
     all_targets = all_df[target_columns].values.astype(np.float32)
 
     # Filter invalid SMILES
-    _, valid_indices = create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
+    _, valid_indices = _create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
     all_df = all_df.iloc[valid_indices].reset_index(drop=True)
     all_targets = all_targets[valid_indices]
     if all_extra_features is not None:
```
|
|
|
401
464
|
val_extra_raw = val_extra.copy() if val_extra is not None else None
|
|
402
465
|
|
|
403
466
|
# Create datasets
|
|
404
|
-
train_dps, _ =
|
|
405
|
-
val_dps, _ =
|
|
467
|
+
train_dps, _ = _create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
|
|
468
|
+
val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
|
|
406
469
|
train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)
|
|
407
470
|
|
|
408
471
|
# Scale features/targets
|
|
@@ -447,7 +510,7 @@ if __name__ == "__main__":
|
|
|
447
510
|
ensemble_models.append(mpnn)
|
|
448
511
|
|
|
449
512
|
# Out-of-fold predictions (using raw features)
|
|
450
|
-
val_dps_raw, _ =
|
|
513
|
+
val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
|
|
451
514
|
val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
|
|
452
515
|
|
|
453
516
|
with torch.inference_mode():
|
|
```diff
@@ -486,6 +549,7 @@ if __name__ == "__main__":
     # -------------------------------------------------------------------------
     # Compute metrics and prepare output
     # -------------------------------------------------------------------------
+    median_std = None  # Only set for regression models with ensemble
     if model_type == "classifier":
         class_preds = preds[:, 0].astype(int)
         target_name = target_columns[0]
```
```diff
@@ -507,7 +571,7 @@ if __name__ == "__main__":
     preds_std = None
     if len(ensemble_models) > 1:
         print("Computing prediction_std from ensemble...")
-        val_dps, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
+        val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
         val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
         trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
 
```
```diff
@@ -535,13 +599,19 @@ if __name__ == "__main__":
     df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
     df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
 
+    # Compute confidence from ensemble std
+    median_std = float(np.median(preds_std[:, 0]))
+    print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+    df_val = _compute_std_confidence(df_val, median_std)
+    print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+
     # -------------------------------------------------------------------------
     # Save validation predictions to S3
     # -------------------------------------------------------------------------
     output_columns = [id_column] if id_column in df_val.columns else []
     output_columns += target_columns
     output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
-    output_columns += ["prediction", "prediction_std"]
+    output_columns += ["prediction", "prediction_std", "confidence"]
     output_columns += [c for c in df_val.columns if c.endswith("_proba")]
     output_columns = [c for c in output_columns if c in df_val.columns]
 
```
```diff
@@ -558,7 +628,13 @@ if __name__ == "__main__":
     for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
         os.remove(ckpt)
 
-    [one removed line; content not captured in this diff view]
+    ensemble_metadata = {
+        "n_ensemble": len(ensemble_models),
+        "n_folds": n_folds,
+        "target_columns": target_columns,
+        "median_std": median_std,  # For confidence calculation during inference
+    }
+    joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
 
     with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
         json.dump(hyperparameters, f, indent=2)
```