workbench-0.8.213-py3-none-any.whl → workbench-0.8.217-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/api/__init__.py +3 -0
  9. workbench/api/endpoint.py +10 -5
  10. workbench/api/feature_set.py +76 -6
  11. workbench/api/meta_model.py +289 -0
  12. workbench/api/model.py +43 -4
  13. workbench/core/artifacts/endpoint_core.py +63 -115
  14. workbench/core/artifacts/feature_set_core.py +1 -1
  15. workbench/core/artifacts/model_core.py +6 -4
  16. workbench/core/pipelines/pipeline_executor.py +1 -1
  17. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  18. workbench/model_script_utils/pytorch_utils.py +11 -1
  19. workbench/model_scripts/chemprop/chemprop.template +145 -69
  20. workbench/model_scripts/chemprop/generated_model_script.py +147 -71
  21. workbench/model_scripts/custom_models/chem_info/fingerprints.py +7 -3
  22. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  23. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  24. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  25. workbench/model_scripts/custom_models/uq_models/meta_uq.template +6 -6
  26. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  27. workbench/model_scripts/meta_model/meta_model.template +209 -0
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +42 -24
  29. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  30. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  31. workbench/model_scripts/script_generation.py +4 -0
  32. workbench/model_scripts/xgb_model/generated_model_script.py +169 -158
  33. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  34. workbench/repl/workbench_shell.py +0 -5
  35. workbench/scripts/endpoint_test.py +2 -2
  36. workbench/utils/chem_utils/fingerprints.py +7 -3
  37. workbench/utils/chemprop_utils.py +23 -5
  38. workbench/utils/meta_model_simulator.py +471 -0
  39. workbench/utils/metrics_utils.py +94 -10
  40. workbench/utils/model_utils.py +91 -9
  41. workbench/utils/pytorch_utils.py +1 -1
  42. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  43. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/METADATA +2 -1
  44. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/RECORD +48 -43
  45. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  46. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  47. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/WHEEL +0 -0
  48. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/entry_points.txt +0 -0
  49. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/licenses/LICENSE +0 -0
  50. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/top_level.txt +0 -0
@@ -123,7 +123,7 @@ class PipelineExecutor:
         if "model" in workbench_objects and (not subset or "endpoint" in subset):
             workbench_objects["model"].to_endpoint(**kwargs)
             endpoint = Endpoint(kwargs["name"])
-            endpoint.auto_inference(capture=True)
+            endpoint.auto_inference()

         # Found something weird
         else:
@@ -1,6 +1,7 @@
 """ModelToEndpoint: Deploy an Endpoint for a Model"""

 import time
+from botocore.exceptions import ClientError
 from sagemaker import ModelPackage
 from sagemaker.serializers import CSVSerializer
 from sagemaker.deserializers import CSVDeserializer
@@ -137,16 +138,35 @@ class ModelToEndpoint(Transform):

         # Deploy the Endpoint
         self.log.important(f"Deploying the Endpoint {self.output_name}...")
-        model_package.deploy(
-            initial_instance_count=1,
-            instance_type=self.instance_type,
-            serverless_inference_config=serverless_config,
-            endpoint_name=self.output_name,
-            serializer=CSVSerializer(),
-            deserializer=CSVDeserializer(),
-            data_capture_config=data_capture_config,
-            tags=aws_tags,
-        )
+        try:
+            model_package.deploy(
+                initial_instance_count=1,
+                instance_type=self.instance_type,
+                serverless_inference_config=serverless_config,
+                endpoint_name=self.output_name,
+                serializer=CSVSerializer(),
+                deserializer=CSVDeserializer(),
+                data_capture_config=data_capture_config,
+                tags=aws_tags,
+            )
+        except ClientError as e:
+            # Check if this is the "endpoint config already exists" error
+            if "Cannot create already existing endpoint configuration" in str(e):
+                self.log.warning("Endpoint config already exists, deleting and retrying...")
+                self.sm_client.delete_endpoint_config(EndpointConfigName=self.output_name)
+                # Retry the deploy
+                model_package.deploy(
+                    initial_instance_count=1,
+                    instance_type=self.instance_type,
+                    serverless_inference_config=serverless_config,
+                    endpoint_name=self.output_name,
+                    serializer=CSVSerializer(),
+                    deserializer=CSVDeserializer(),
+                    data_capture_config=data_capture_config,
+                    tags=aws_tags,
+                )
+            else:
+                raise

     def post_transform(self, **kwargs):
         """Post-Transform: Calling onboard() for the Endpoint"""
@@ -245,6 +245,7 @@ def train_model(
     patience: int = 20,
     batch_size: int = 128,
     learning_rate: float = 1e-3,
+    loss: str = "L1Loss",
     device: str = "cpu",
 ) -> tuple[TabularMLP, dict]:
     """Train the model with early stopping.
@@ -272,7 +273,16 @@
     if task == "classification":
         criterion = nn.CrossEntropyLoss()
     else:
-        criterion = nn.MSELoss()
+        # Map loss name to PyTorch loss class
+        loss_map = {
+            "L1Loss": nn.L1Loss,
+            "MSELoss": nn.MSELoss,
+            "HuberLoss": nn.HuberLoss,
+            "SmoothL1Loss": nn.SmoothL1Loss,
+        }
+        if loss not in loss_map:
+            raise ValueError(f"Unknown loss '{loss}'. Supported: {list(loss_map.keys())}")
+        criterion = loss_map[loss]()

     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
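Note: with this change the regression default moves from MSE to L1 (MAE). A quick sketch of how the lookup resolves, using plain torch.nn (the sample tensors are hypothetical; nn.HuberLoss requires torch >= 1.9):

    import torch
    import torch.nn as nn

    loss_map = {
        "L1Loss": nn.L1Loss,
        "MSELoss": nn.MSELoss,
        "HuberLoss": nn.HuberLoss,
        "SmoothL1Loss": nn.SmoothL1Loss,
    }
    criterion = loss_map["HuberLoss"]()  # e.g. train_model(..., loss="HuberLoss")
    preds = torch.tensor([2.5, 0.0, 2.0])
    targets = torch.tensor([3.0, -0.5, 2.0])
    print(criterion(preds, targets))  # mean Huber loss over the batch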
@@ -5,36 +5,24 @@
 # - Multi-task regression support
 # - Hybrid mode (SMILES + extra molecular descriptors)
 # - Classification (single-target only)
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (lightning, sklearn, awswrangler) are deferred to training time.

-import argparse
-import glob
 import json
 import os

-import awswrangler as wr
 import joblib
 import numpy as np
 import pandas as pd
 import torch
-from lightning import pytorch as pl
-from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
-from sklearn.preprocessing import LabelEncoder
-
-# Enable Tensor Core optimization for GPUs that support it
-torch.set_float32_matmul_precision("medium")

-from chemprop import data, models, nn
+from chemprop import data, models

 from model_script_utils import (
-    check_dataframe,
-    compute_classification_metrics,
-    compute_regression_metrics,
     expand_proba_column,
     input_fn,
     output_fn,
-    print_classification_metrics,
-    print_confusion_matrix,
-    print_regression_metrics,
 )

 # =============================================================================
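Note: the deferred-import pattern the new header comment describes looks roughly like this (a sketch with illustrative names, not the template itself): keep module scope import-light for the inference container, and pull heavy dependencies inside the code paths that actually need them:

    import json  # light: safe at module scope

    def model_fn(model_dir: str):
        # heavy dependency pulled in only when the model is actually loaded
        from lightning import pytorch as pl
        return {"trainer": pl.Trainer(logger=False, enable_progress_bar=False)}

    if __name__ == "__main__":
        # training-only path: these imports never touch the serverless endpoint
        import awswrangler as wr
        from sklearn.model_selection import KFold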
@@ -44,15 +32,17 @@ DEFAULT_HYPERPARAMETERS = {
     # Training
     "n_folds": 5,
     "max_epochs": 400,
-    "patience": 40,
-    "batch_size": 16,
+    "patience": 50,
+    "batch_size": 32,
     # Message Passing
     "hidden_dim": 700,
     "depth": 6,
-    "dropout": 0.15,
+    "dropout": 0.1,  # Lower dropout - ensemble provides regularization
     # FFN
     "ffn_hidden_dim": 2000,
     "ffn_num_layers": 2,
+    # Loss function for regression (mae, mse)
+    "criterion": "mae",
     # Random seed
     "seed": 42,
 }
@@ -71,7 +61,26 @@ TEMPLATE_PARAMS = {
 # =============================================================================
 # Helper Functions
 # =============================================================================
-def find_smiles_column(columns: list[str]) -> str:
+def _compute_std_confidence(df: pd.DataFrame, median_std: float, std_col: str = "prediction_std") -> pd.DataFrame:
+    """Compute confidence score from ensemble prediction_std.
+
+    Uses exponential decay: confidence = exp(-std / median_std)
+    - Low std (ensemble agreement) -> high confidence
+    - High std (ensemble disagreement) -> low confidence
+
+    Args:
+        df: DataFrame with prediction_std column
+        median_std: Median std from training validation set (normalization factor)
+        std_col: Name of the std column to use
+
+    Returns:
+        DataFrame with added 'confidence' column (0.0 to 1.0)
+    """
+    df["confidence"] = np.exp(-df[std_col] / median_std)
+    return df
+
+
+def _find_smiles_column(columns: list[str]) -> str:
     """Find SMILES column (case-insensitive match for 'smiles')."""
     smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
     if smiles_col is None:
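Note: plugging numbers into the decay formula makes the scaling concrete (the std values and a median_std of 0.2 are hypothetical):

    import numpy as np
    import pandas as pd

    median_std = 0.2  # normalization factor learned from the validation set
    df = pd.DataFrame({"prediction_std": [0.0, 0.2, 0.6]})
    df["confidence"] = np.exp(-df["prediction_std"] / median_std)
    # std = 0            -> confidence 1.000 (perfect ensemble agreement)
    # std = median_std   -> confidence exp(-1) ~= 0.368
    # std = 3*median_std -> confidence exp(-3) ~= 0.050
    print(df)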
@@ -79,7 +88,7 @@ def find_smiles_column(columns: list[str]) -> str:
     return smiles_col


-def create_molecule_datapoints(
+def _create_molecule_datapoints(
     smiles_list: list[str],
     targets: np.ndarray | None = None,
     extra_descriptors: np.ndarray | None = None,
@@ -101,47 +110,13 @@ def create_molecule_datapoints(
     return datapoints, valid_indices


-def build_mpnn_model(
-    hyperparameters: dict,
-    task: str = "regression",
-    num_classes: int | None = None,
-    n_targets: int = 1,
-    n_extra_descriptors: int = 0,
-    x_d_transform: nn.ScaleTransform | None = None,
-    output_transform: nn.UnscaleTransform | None = None,
-    task_weights: np.ndarray | None = None,
-) -> models.MPNN:
-    """Build an MPNN model with specified hyperparameters."""
-    hidden_dim = hyperparameters["hidden_dim"]
-    depth = hyperparameters["depth"]
-    dropout = hyperparameters["dropout"]
-    ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
-    ffn_num_layers = hyperparameters["ffn_num_layers"]
-
-    mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-    agg = nn.NormAggregation()
-    ffn_input_dim = hidden_dim + n_extra_descriptors
-
-    if task == "classification" and num_classes is not None:
-        ffn = nn.MulticlassClassificationFFN(
-            n_classes=num_classes, input_dim=ffn_input_dim,
-            hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
-        )
-    else:
-        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
-        ffn = nn.RegressionFFN(
-            input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
-            dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
-        )
-
-    return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
-
-
 # =============================================================================
 # Model Loading (for SageMaker inference)
 # =============================================================================
 def model_fn(model_dir: str) -> dict:
     """Load ChemProp MPNN ensemble from the specified directory."""
+    from lightning import pytorch as pl
+
     metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
     ensemble_models = []
     for i in range(metadata["n_ensemble"]):
@@ -149,8 +124,17 @@ def model_fn(model_dir: str) -> dict:
         model.eval()
         ensemble_models.append(model)

+    # Pre-initialize trainer once during model loading (expensive operation)
+    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
     print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
-    return {"ensemble_models": ensemble_models, "n_ensemble": metadata["n_ensemble"], "target_columns": metadata["target_columns"]}
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": metadata["n_ensemble"],
+        "target_columns": metadata["target_columns"],
+        "median_std": metadata["median_std"],
+        "trainer": trainer,
+    }


 # =============================================================================
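Note: moving the pl.Trainer construction into model_fn follows the usual SageMaker contract that model_fn runs once per container while predict_fn runs per request, so expensive objects belong in the returned model dict. A generic, runnable sketch of that shape (the sleep stands in for the real constructor; names are illustrative):

    import time

    def model_fn(model_dir: str) -> dict:
        time.sleep(2)  # stand-in for an expensive constructor such as pl.Trainer(...)
        return {"trainer": "ready"}

    def predict_fn(payload, model_dict: dict):
        trainer = model_dict["trainer"]  # reused across invocations, never rebuilt
        return payload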
@@ -163,6 +147,7 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:

     ensemble_models = model_dict["ensemble_models"]
     target_columns = model_dict["target_columns"]
+    trainer = model_dict["trainer"]  # Use pre-initialized trainer

     # Load artifacts
     label_encoder = None
@@ -177,7 +162,7 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")

     # Find SMILES column and validate
-    smiles_column = find_smiles_column(df.columns.tolist())
+    smiles_column = _find_smiles_column(df.columns.tolist())
     smiles_list = df[smiles_column].tolist()

     valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
@@ -212,13 +197,12 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
                 extra_features[:, j] = col_means[j]

     # Create datapoints and predict
-    datapoints, rdkit_valid = create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
+    datapoints, rdkit_valid = _create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
     if len(datapoints) == 0:
         return df

     dataset = data.MoleculeDataset(datapoints)
     dataloader = data.build_dataloader(dataset, shuffle=False)
-    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)

     # Ensemble predictions
     all_preds = []
@@ -259,6 +243,9 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         df["prediction"] = df[f"{target_columns[0]}_pred"]
         df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]

+    # Compute confidence from ensemble std
+    df = _compute_std_confidence(df, model_dict["median_std"])
+
     return df
@@ -266,6 +253,82 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
 # Training
 # =============================================================================
 if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+    import glob
+
+    import awswrangler as wr
+    from lightning import pytorch as pl
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    # Enable Tensor Core optimization for GPUs that support it
+    torch.set_float32_matmul_precision("medium")
+
+    from chemprop import nn
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+
+    # -------------------------------------------------------------------------
+    # Training-only helper function
+    # -------------------------------------------------------------------------
+    def build_mpnn_model(
+        hyperparameters: dict,
+        task: str = "regression",
+        num_classes: int | None = None,
+        n_targets: int = 1,
+        n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None,
+        task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build an MPNN model with specified hyperparameters."""
+        hidden_dim = hyperparameters["hidden_dim"]
+        depth = hyperparameters["depth"]
+        dropout = hyperparameters["dropout"]
+        ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
+        ffn_num_layers = hyperparameters["ffn_num_layers"]
+
+        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+        agg = nn.NormAggregation()
+        ffn_input_dim = hidden_dim + n_extra_descriptors
+
+        if task == "classification" and num_classes is not None:
+            ffn = nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=ffn_input_dim,
+                hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
+            )
+        else:
+            # Map criterion name to ChemProp metric class (must have .clone() method)
+            from chemprop.nn.metrics import MAE, MSE
+
+            criterion_map = {
+                "mae": MAE,
+                "mse": MSE,
+            }
+            criterion_name = hyperparameters.get("criterion", "mae")
+            if criterion_name not in criterion_map:
+                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+            criterion = criterion_map[criterion_name]()
+
+            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+            ffn = nn.RegressionFFN(
+                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
+                criterion=criterion,
+            )
+
+        return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
+
     # -------------------------------------------------------------------------
     # Setup: Parse arguments and load data
     # -------------------------------------------------------------------------
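Note: the criterion is now selected by name from the hyperparameters, so switching the ChemProp regression loss is a one-key override (merge semantics assumed to mirror the usual dict-update over the defaults; the values here are abridged and illustrative):

    DEFAULT_HYPERPARAMETERS = {"criterion": "mae", "depth": 6}  # abridged copy
    user_params = {"criterion": "mse"}                          # hypothetical override
    hyperparameters = {**DEFAULT_HYPERPARAMETERS, **user_params}
    print(hyperparameters["criterion"])  # -> "mse"; unsupported names raise ValueError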
@@ -287,7 +350,7 @@ if __name__ == "__main__":
         raise ValueError("'targets' must be a non-empty list of target column names")
     n_targets = len(target_columns)

-    smiles_column = find_smiles_column(feature_list)
+    smiles_column = _find_smiles_column(feature_list)
     extra_feature_cols = [f for f in feature_list if f != smiles_column]
     use_extra_features = len(extra_feature_cols) > 0
@@ -342,7 +405,7 @@ if __name__ == "__main__":
     all_targets = all_df[target_columns].values.astype(np.float32)

     # Filter invalid SMILES
-    _, valid_indices = create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
+    _, valid_indices = _create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
     all_df = all_df.iloc[valid_indices].reset_index(drop=True)
     all_targets = all_targets[valid_indices]
     if all_extra_features is not None:
@@ -401,8 +464,8 @@ if __name__ == "__main__":
         val_extra_raw = val_extra.copy() if val_extra is not None else None

         # Create datasets
-        train_dps, _ = create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
-        val_dps, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
+        train_dps, _ = _create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
+        val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
         train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)

         # Scale features/targets
@@ -447,7 +510,7 @@ if __name__ == "__main__":
         ensemble_models.append(mpnn)

         # Out-of-fold predictions (using raw features)
-        val_dps_raw, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
+        val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
         val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)

         with torch.inference_mode():
@@ -486,6 +549,7 @@ if __name__ == "__main__":
     # -------------------------------------------------------------------------
     # Compute metrics and prepare output
     # -------------------------------------------------------------------------
+    median_std = None  # Only set for regression models with ensemble
     if model_type == "classifier":
         class_preds = preds[:, 0].astype(int)
         target_name = target_columns[0]
@@ -507,7 +571,7 @@ if __name__ == "__main__":
     preds_std = None
     if len(ensemble_models) > 1:
         print("Computing prediction_std from ensemble...")
-        val_dps, _ = create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
+        val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
         val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
         trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
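Note: prediction_std is the per-sample spread across the fold models; the standard stack-then-std aggregation looks like this (shapes and values are hypothetical, and the script's exact aggregation may differ):

    import numpy as np

    # predictions from three fold models over four validation molecules
    all_preds = [
        np.array([1.0, 2.0, 3.0, 4.0]),
        np.array([1.1, 2.2, 2.9, 4.3]),
        np.array([0.9, 1.8, 3.1, 3.7]),
    ]
    stacked = np.stack(all_preds)         # shape (n_models, n_samples)
    prediction = stacked.mean(axis=0)     # ensemble mean per molecule
    prediction_std = stacked.std(axis=0)  # ensemble disagreement per molecule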
@@ -535,13 +599,19 @@ if __name__ == "__main__":
     df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
     df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]

+    # Compute confidence from ensemble std
+    median_std = float(np.median(preds_std[:, 0]))
+    print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+    df_val = _compute_std_confidence(df_val, median_std)
+    print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+
     # -------------------------------------------------------------------------
     # Save validation predictions to S3
     # -------------------------------------------------------------------------
     output_columns = [id_column] if id_column in df_val.columns else []
     output_columns += target_columns
     output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
-    output_columns += ["prediction", "prediction_std"]
+    output_columns += ["prediction", "prediction_std", "confidence"]
     output_columns += [c for c in df_val.columns if c.endswith("_proba")]
     output_columns = [c for c in output_columns if c in df_val.columns]
@@ -558,7 +628,13 @@ if __name__ == "__main__":
     for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
         os.remove(ckpt)

-    joblib.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds, "target_columns": target_columns}, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
+    ensemble_metadata = {
+        "n_ensemble": len(ensemble_models),
+        "n_folds": n_folds,
+        "target_columns": target_columns,
+        "median_std": median_std,  # For confidence calculation during inference
+    }
+    joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))

     with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
         json.dump(hyperparameters, f, indent=2)
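Note: median_std now rides along in ensemble_metadata.joblib, and model_fn above indexes metadata["median_std"] directly, so artifacts written by older script versions (which lack the key) would presumably need retraining before they load. A round-trip check (values hypothetical):

    import joblib

    ensemble_metadata = {"n_ensemble": 5, "n_folds": 5,
                         "target_columns": ["solubility"], "median_std": 0.2}
    joblib.dump(ensemble_metadata, "ensemble_metadata.joblib")
    assert joblib.load("ensemble_metadata.joblib")["median_std"] == 0.2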