workbench-0.8.212-py3-none-any.whl → workbench-0.8.217-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/api/__init__.py +3 -0
  9. workbench/api/endpoint.py +10 -5
  10. workbench/api/feature_set.py +76 -6
  11. workbench/api/meta_model.py +289 -0
  12. workbench/api/model.py +43 -4
  13. workbench/core/artifacts/endpoint_core.py +75 -129
  14. workbench/core/artifacts/feature_set_core.py +1 -1
  15. workbench/core/artifacts/model_core.py +6 -4
  16. workbench/core/pipelines/pipeline_executor.py +1 -1
  17. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  18. workbench/model_script_utils/pytorch_utils.py +11 -1
  19. workbench/model_scripts/chemprop/chemprop.template +145 -69
  20. workbench/model_scripts/chemprop/generated_model_script.py +147 -71
  21. workbench/model_scripts/custom_models/chem_info/fingerprints.py +7 -3
  22. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  23. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  24. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  25. workbench/model_scripts/custom_models/uq_models/meta_uq.template +6 -6
  26. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  27. workbench/model_scripts/meta_model/meta_model.template +209 -0
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +42 -24
  29. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  30. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  31. workbench/model_scripts/script_generation.py +4 -0
  32. workbench/model_scripts/xgb_model/generated_model_script.py +169 -158
  33. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  34. workbench/repl/workbench_shell.py +0 -5
  35. workbench/scripts/endpoint_test.py +2 -2
  36. workbench/utils/chem_utils/fingerprints.py +7 -3
  37. workbench/utils/chemprop_utils.py +23 -5
  38. workbench/utils/meta_model_simulator.py +471 -0
  39. workbench/utils/metrics_utils.py +94 -10
  40. workbench/utils/model_utils.py +91 -9
  41. workbench/utils/pytorch_utils.py +1 -1
  42. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  43. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/METADATA +2 -1
  44. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/RECORD +48 -43
  45. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  46. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  47. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/WHEEL +0 -0
  48. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/entry_points.txt +0 -0
  49. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/licenses/LICENSE +0 -0
  50. {workbench-0.8.212.dist-info → workbench-0.8.217.dist-info}/top_level.txt +0 -0
workbench/api/meta_model.py ADDED
@@ -0,0 +1,289 @@
+"""MetaModel: A Model that aggregates predictions from multiple child endpoints.
+
+MetaModels don't train on feature data - they combine predictions from existing
+endpoints using confidence-weighted voting. This provides ensemble benefits
+across different model frameworks (XGBoost, PyTorch, ChemProp, etc.).
+"""
+
+from pathlib import Path
+import time
+import logging
+
+from sagemaker.estimator import Estimator
+
+# Workbench Imports
+from workbench.api.model import Model
+from workbench.api.endpoint import Endpoint
+from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework, ModelImages
+from workbench.core.artifacts.artifact import Artifact
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.model_scripts.script_generation import generate_model_script
+from workbench.utils.config_manager import ConfigManager
+from workbench.utils.model_utils import supported_instance_types
+
+# Set up logging
+log = logging.getLogger("workbench")
+
+
+class MetaModel(Model):
+    """MetaModel: A Model that aggregates predictions from child endpoints.
+
+    Common Usage:
+        ```python
+        # Create a meta model from existing endpoints
+        meta = MetaModel.create(
+            name="my-meta-model",
+            child_endpoints=["endpoint-1", "endpoint-2", "endpoint-3"],
+            target_column="target"
+        )
+
+        # Deploy like any other model
+        endpoint = meta.to_endpoint()
+        ```
+    """
+
+    @classmethod
+    def create(
+        cls,
+        name: str,
+        child_endpoints: list[str],
+        target_column: str,
+        description: str = None,
+        tags: list[str] = None,
+    ) -> "MetaModel":
+        """Create a new MetaModel from a list of child endpoints.
+
+        Args:
+            name: Name for the meta model
+            child_endpoints: List of endpoint names to aggregate
+            target_column: Name of the target column (for metadata)
+            description: Optional description for the model
+            tags: Optional list of tags
+
+        Returns:
+            MetaModel: The created meta model
+        """
+        Artifact.is_name_valid(name, delimiter="-", lower_case=False)
+
+        # Validate endpoints and get lineage info from primary endpoint
+        feature_list, feature_set_name, model_weights = cls._validate_and_get_lineage(child_endpoints)
+
+        # Delete existing model if it exists
+        log.important(f"Trying to delete existing model {name}...")
+        ModelCore.managed_delete(name)
+
+        # Run training and register model
+        aws_clamp = AWSAccountClamp()
+        estimator = cls._run_training(name, child_endpoints, target_column, model_weights, aws_clamp)
+        cls._register_model(name, child_endpoints, description, tags, estimator, aws_clamp)
+
+        # Set metadata and onboard
+        cls._set_metadata(name, target_column, feature_list, feature_set_name, child_endpoints)
+
+        log.important(f"MetaModel {name} created successfully!")
+        return cls(name)
+
+    @classmethod
+    def _validate_and_get_lineage(cls, child_endpoints: list[str]) -> tuple[list[str], str, dict[str, float]]:
+        """Validate child endpoints exist and get lineage info from primary endpoint.
+
+        Args:
+            child_endpoints: List of endpoint names
+
+        Returns:
+            tuple: (feature_list, feature_set_name, model_weights) from the primary endpoint's model
+        """
+        log.info("Verifying child endpoints and gathering model metrics...")
+        mae_scores = {}
+
+        for ep_name in child_endpoints:
+            ep = Endpoint(ep_name)
+            if not ep.exists():
+                raise ValueError(f"Child endpoint '{ep_name}' does not exist")
+
+            # Get model MAE from full_inference metrics
+            model = Model(ep.get_input())
+            metrics = model.get_inference_metrics("full_inference")
+            if metrics is not None and "mae" in metrics.columns:
+                mae = float(metrics["mae"].iloc[0])
+                mae_scores[ep_name] = mae
+                log.info(f" {ep_name} -> {model.name}: MAE={mae:.4f}")
+            else:
+                log.warning(f" {ep_name}: No full_inference metrics found, using default weight")
+                mae_scores[ep_name] = None
+
+        # Compute inverse-MAE weights (higher weight for lower MAE)
+        valid_mae = {k: v for k, v in mae_scores.items() if v is not None}
+        if valid_mae:
+            inv_mae = {k: 1.0 / v for k, v in valid_mae.items()}
+            total = sum(inv_mae.values())
+            model_weights = {k: v / total for k, v in inv_mae.items()}
+            # Fill in missing weights with equal share of remaining weight
+            missing = [k for k in mae_scores if mae_scores[k] is None]
+            if missing:
+                equal_weight = (1.0 - sum(model_weights.values())) / len(missing)
+                for k in missing:
+                    model_weights[k] = equal_weight
+        else:
+            # No metrics available, use equal weights
+            model_weights = {k: 1.0 / len(child_endpoints) for k in child_endpoints}
+            log.warning("No MAE metrics found, using equal weights")
+
+        log.info(f"Model weights: {model_weights}")
+
+        # Use first endpoint as primary - backtrack to get model and feature set
+        primary_endpoint = Endpoint(child_endpoints[0])
+        primary_model = Model(primary_endpoint.get_input())
+        feature_list = primary_model.features()
+        feature_set_name = primary_model.get_input()
+
+        log.info(
+            f"Primary endpoint: {child_endpoints[0]} -> Model: {primary_model.name} -> FeatureSet: {feature_set_name}"
+        )
+        return feature_list, feature_set_name, model_weights
+
+    @classmethod
+    def _run_training(
+        cls,
+        name: str,
+        child_endpoints: list[str],
+        target_column: str,
+        model_weights: dict[str, float],
+        aws_clamp: AWSAccountClamp,
+    ) -> Estimator:
+        """Run the minimal training job that saves the meta model config.
+
+        Args:
+            name: Model name
+            child_endpoints: List of endpoint names
+            target_column: Target column name
+            model_weights: Dict mapping endpoint name to weight
+            aws_clamp: AWS account clamp
+
+        Returns:
+            Estimator: The fitted estimator
+        """
+        sm_session = aws_clamp.sagemaker_session()
+        cm = ConfigManager()
+        workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
+        models_s3_path = f"s3://{workbench_bucket}/models"
+
+        # Generate the model script from template
+        template_params = {
+            "model_type": ModelType.REGRESSOR,
+            "model_framework": ModelFramework.META,
+            "child_endpoints": child_endpoints,
+            "target_column": target_column,
+            "model_weights": model_weights,
+            "model_metrics_s3_path": f"{models_s3_path}/{name}/training",
+            "aws_region": sm_session.boto_region_name,
+        }
+        script_path = generate_model_script(template_params)
+
+        # Create estimator
+        training_image = ModelImages.get_image_uri(sm_session.boto_region_name, "meta_training")
+        log.info(f"Using Meta Training Image: {training_image}")
+        estimator = Estimator(
+            entry_point=Path(script_path).name,
+            source_dir=str(Path(script_path).parent),
+            role=aws_clamp.aws_session.get_workbench_execution_role_arn(),
+            instance_count=1,
+            instance_type="ml.m5.large",
+            sagemaker_session=sm_session,
+            image_uri=training_image,
+        )
+
+        # Run training (no input data needed - just saves config)
+        log.important(f"Creating MetaModel {name}...")
+        estimator.fit()
+
+        return estimator
+
+    @classmethod
+    def _register_model(
+        cls,
+        name: str,
+        child_endpoints: list[str],
+        description: str,
+        tags: list[str],
+        estimator: Estimator,
+        aws_clamp: AWSAccountClamp,
+    ):
+        """Create model group and register the model.
+
+        Args:
+            name: Model name
+            child_endpoints: List of endpoint names
+            description: Model description
+            tags: Model tags
+            estimator: Fitted estimator
+            aws_clamp: AWS account clamp
+        """
+        sm_session = aws_clamp.sagemaker_session()
+        model_description = description or f"Meta model aggregating: {', '.join(child_endpoints)}"
+
+        # Create model group
+        aws_clamp.sagemaker_client().create_model_package_group(
+            ModelPackageGroupName=name,
+            ModelPackageGroupDescription=model_description,
+            Tags=[{"Key": "workbench_tags", "Value": "::".join(tags or [name])}],
+        )
+
+        # Register the model with meta inference image
+        inference_image = ModelImages.get_image_uri(sm_session.boto_region_name, "meta_inference")
+        log.important(f"Registering model {name} with Inference Image {inference_image}...")
+        estimator.create_model(role=aws_clamp.aws_session.get_workbench_execution_role_arn()).register(
+            model_package_group_name=name,
+            image_uri=inference_image,
+            content_types=["text/csv"],
+            response_types=["text/csv"],
+            inference_instances=supported_instance_types("x86_64"),
+            transform_instances=["ml.m5.large", "ml.m5.xlarge"],
+            approval_status="Approved",
+            description=model_description,
+        )
+
+    @classmethod
+    def _set_metadata(
+        cls, name: str, target_column: str, feature_list: list[str], feature_set_name: str, child_endpoints: list[str]
+    ):
+        """Set model metadata and onboard.
+
+        Args:
+            name: Model name
+            target_column: Target column name
+            feature_list: List of feature names
+            feature_set_name: Name of the input FeatureSet
+            child_endpoints: List of child endpoint names
+        """
+        time.sleep(3)
+        output_model = ModelCore(name)
+        output_model._set_model_type(ModelType.UQ_REGRESSOR)
+        output_model._set_model_framework(ModelFramework.META)
+        output_model.set_input(feature_set_name, force=True)
+        output_model.upsert_workbench_meta({"workbench_model_target": target_column})
+        output_model.upsert_workbench_meta({"workbench_model_features": feature_list})
+        output_model.upsert_workbench_meta({"child_endpoints": child_endpoints})
+        output_model.onboard_with_args(ModelType.UQ_REGRESSOR, target_column, feature_list=feature_list)
+
+
+if __name__ == "__main__":
+    """Exercise the MetaModel Class"""
+
+    meta = MetaModel.create(
+        name="logd-meta",
+        child_endpoints=["logd-xgb", "logd-pytorch", "logd-chemprop"],
+        target_column="logd",
+        description="Meta model for LogD prediction",
+        tags=["meta", "logd", "ensemble"],
+    )
+    print(meta.summary())
+
+    # Create an endpoint for the meta model
+    end = meta.to_endpoint(tags=["meta", "logd"])
+    end.set_owner("BW")
+    end.auto_inference()
+
+    # Test loading an existing meta model
+    meta = MetaModel("logd-meta")
+    print(meta.details())
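Note: the confidence-weighted voting in `_validate_and_get_lineage` reduces to normalized inverse-MAE weights. A minimal standalone sketch of that computation, using made-up MAE values for three hypothetical endpoints:

```python
# Sketch of the inverse-MAE weighting above (illustrative MAE values, not real metrics)
mae_scores = {"logd-xgb": 0.40, "logd-pytorch": 0.50, "logd-chemprop": 0.25}

# Lower MAE -> higher weight; weights are normalized to sum to 1.0
inv_mae = {name: 1.0 / mae for name, mae in mae_scores.items()}
total = sum(inv_mae.values())
model_weights = {name: inv / total for name, inv in inv_mae.items()}

print(model_weights)
# {'logd-xgb': 0.2941..., 'logd-pytorch': 0.2352..., 'logd-chemprop': 0.4705...}
```

The strongest child model (lowest MAE) gets roughly twice the vote of the weakest one here; endpoints with no `full_inference` metrics are handled by the fallback paths shown in the diff.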
workbench/api/model.py CHANGED
@@ -10,7 +10,12 @@ from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework  # noqa: F401
 from workbench.core.transforms.model_to_endpoint.model_to_endpoint import ModelToEndpoint
 from workbench.api.endpoint import Endpoint
-from workbench.utils.model_utils import proximity_model_local, noise_model_local
+from workbench.utils.model_utils import (
+    proximity_model_local,
+    fingerprint_prox_model_local,
+    noise_model_local,
+    cleanlab_model_local,
+)


 class Model(ModelCore):
@@ -83,13 +88,38 @@ class Model(ModelCore):
         end.set_owner(self.get_owner())
         return end

-    def prox_model(self):
+    def prox_model(self, include_all_columns: bool = False):
         """Create a local Proximity Model for this Model

+        Args:
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+
+        Returns:
+            FeatureSpaceProximity: A local FeatureSpaceProximity Model
+        """
+        return proximity_model_local(self, include_all_columns=include_all_columns)
+
+    def fp_prox_model(
+        self,
+        include_all_columns: bool = False,
+        radius: int = 2,
+        n_bits: int = 1024,
+        counts: bool = False,
+    ):
+        """Create a local Fingerprint Proximity Model for this Model
+
+        Args:
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+            radius (int): Morgan fingerprint radius (default: 2)
+            n_bits (int): Number of bits for the fingerprint (default: 1024)
+            counts (bool): Use count fingerprints instead of binary (default: False)
+
         Returns:
-            Proximity: A local Proximity Model
+            FingerprintProximity: A local FingerprintProximity Model
         """
-        return proximity_model_local(self)
+        return fingerprint_prox_model_local(
+            self, include_all_columns=include_all_columns, radius=radius, n_bits=n_bits, counts=counts
+        )

     def noise_model(self):
         """Create a local Noise Model for this Model
@@ -99,6 +129,15 @@ class Model(ModelCore):
         """
         return noise_model_local(self)

+    def cleanlab_model(self):
+        """Create a CleanLearning model for this Model's training data.
+
+        Returns:
+            CleanLearning: A fitted cleanlab model. Use get_label_issues() to get
+            a DataFrame with id_column, label_quality, predicted_label, given_label, is_label_issue.
+        """
+        return cleanlab_model_local(self)
+

 if __name__ == "__main__":
     """Exercise the Model Class"""
workbench/core/artifacts/endpoint_core.py CHANGED
@@ -330,12 +330,8 @@ class EndpointCore(Artifact):
         self.details()
         return True

-    def auto_inference(self, capture: bool = False) -> pd.DataFrame:
-        """Run inference on the endpoint using FeatureSet data
-
-        Args:
-            capture (bool, optional): Capture the inference results and metrics (default=False)
-        """
+    def auto_inference(self) -> pd.DataFrame:
+        """Run inference on the endpoint using the test data from the model training view"""

         # Sanity Check that we have a model
         model = ModelCore(self.get_input())
@@ -343,22 +339,40 @@
             self.log.error("No model found for this endpoint. Returning empty DataFrame.")
             return pd.DataFrame()

-        # Now get the FeatureSet and make sure it exists
-        fs = FeatureSetCore(model.get_input())
-        if not fs.exists():
-            self.log.error("No FeatureSet found for this endpoint. Returning empty DataFrame.")
+        # Grab the evaluation data from the Model's training view
+        all_df = model.training_view().pull_dataframe()
+        eval_df = all_df[~all_df["training"]]
+
+        # Remove AWS created columns
+        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
+        eval_df = eval_df.drop(columns=aws_cols, errors="ignore")
+
+        # Run inference
+        return self.inference(eval_df, "auto_inference")
+
+    def full_inference(self) -> pd.DataFrame:
+        """Run inference on the endpoint using all the data from the model training view"""
+
+        # Sanity Check that we have a model
+        model = ModelCore(self.get_input())
+        if not model.exists():
+            self.log.error("No model found for this endpoint. Returning empty DataFrame.")
             return pd.DataFrame()

-        # Grab the evaluation data from the FeatureSet
-        table = model.training_view().table
-        eval_df = fs.query(f'SELECT * FROM "{table}" where training = FALSE')
-        capture_name = "auto_inference" if capture else None
-        return self.inference(eval_df, capture_name, id_column=fs.id_column)
+        # Grab the full data from the Model's training view
+        eval_df = model.training_view().pull_dataframe()
+
+        # Remove AWS created columns
+        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
+        eval_df = eval_df.drop(columns=aws_cols, errors="ignore")
+
+        # Run inference
+        return self.inference(eval_df, "full_inference")

     def inference(
         self, eval_df: pd.DataFrame, capture_name: str = None, id_column: str = None, drop_error_rows: bool = False
     ) -> pd.DataFrame:
-        """Run inference and compute performance metrics with optional capture
+        """Run inference on the Endpoint using the provided DataFrame

         Args:
             eval_df (pd.DataFrame): DataFrame to run predictions on (must have superset of features)
@@ -440,11 +454,14 @@
             # Drop rows with NaN target values for metrics/plots
             target_df = prediction_df.dropna(subset=[target])

+            # For multi-target models, prediction column is {target}_pred, otherwise "prediction"
+            pred_col = f"{target}_pred" if is_multi_target else "prediction"
+
             # Compute per-target metrics
             if model.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-                target_metrics = self.regression_metrics(target, target_df)
+                target_metrics = self.regression_metrics(target, target_df, prediction_col=pred_col)
             elif model.model_type == ModelType.CLASSIFIER:
-                target_metrics = self.classification_metrics(target, target_df)
+                target_metrics = self.classification_metrics(target, target_df, prediction_col=pred_col)
             else:
                 target_metrics = pd.DataFrame()

@@ -476,8 +493,8 @@
                 id_column,
             )

-        # For UQ Models we also capture the uncertainty metrics
-        if model.model_type == ModelType.UQ_REGRESSOR:
+        # Capture uncertainty metrics if prediction_std is available (UQ, ChemProp, etc.)
+        if "prediction_std" in prediction_df.columns:
             metrics = uq_metrics(prediction_df, primary_target)
             self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)

@@ -525,22 +542,20 @@
         fs = FeatureSetCore(model.get_input())
         id_column = fs.id_column

-        # For UQ models, get UQ columns from training CV results and compute metrics
-        # Note: XGBoost training now saves all UQ columns (q_*, confidence, prediction_std)
-        additional_columns = []
-        if model_type == ModelType.UQ_REGRESSOR:
-            uq_columns = [col for col in out_of_fold_df.columns if col.startswith("q_") or col == "confidence"]
-            if uq_columns:
-                additional_columns = uq_columns
-                self.log.info(f"UQ columns from training: {', '.join(uq_columns)}")
-            primary_target = targets[0] if isinstance(targets, list) else targets
-            metrics = uq_metrics(out_of_fold_df, primary_target)
-            self.param_store.upsert(f"/workbench/models/{model.name}/inference/full_cross_fold", metrics)
-
         # Normalize targets to a list for iteration
         target_list = targets if isinstance(targets, list) else [targets]
         primary_target = target_list[0]

+        # Collect UQ columns (q_*, confidence) for additional tracking
+        additional_columns = [col for col in out_of_fold_df.columns if col.startswith("q_") or col == "confidence"]
+        if additional_columns:
+            self.log.info(f"UQ columns from training: {', '.join(additional_columns)}")
+
+        # Capture uncertainty metrics if prediction_std is available (UQ, ChemProp, etc.)
+        if "prediction_std" in out_of_fold_df.columns:
+            metrics = uq_metrics(out_of_fold_df, primary_target)
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/full_cross_fold", metrics)
+
         # For single-target models (99% of cases), just save as "full_cross_fold"
         # For multi-target models, save each as cv_{target} plus primary as "full_cross_fold"
@@ -549,11 +564,14 @@
             # Drop rows with NaN target values for metrics/plots
             target_df = out_of_fold_df.dropna(subset=[target])

+            # For multi-target models, prediction column is {target}_pred, otherwise "prediction"
+            pred_col = f"{target}_pred" if is_multi_target else "prediction"
+
             # Compute per-target metrics
             if model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-                target_metrics = self.regression_metrics(target, target_df)
+                target_metrics = self.regression_metrics(target, target_df, prediction_col=pred_col)
             elif model_type == ModelType.CLASSIFIER:
-                target_metrics = self.classification_metrics(target, target_df)
+                target_metrics = self.classification_metrics(target, target_df, prediction_col=pred_col)
             else:
                 target_metrics = pd.DataFrame()

@@ -867,75 +885,39 @@
             target (str): Target column name
             id_column (str, optional): Name of the ID column
         """
-        # Start with ID column if present
+        cols = pred_results_df.columns
+
+        # Build output columns: id, target, prediction, prediction_std, UQ columns, proba columns
         output_columns = []
-        if id_column and id_column in pred_results_df.columns:
+        if id_column and id_column in cols:
             output_columns.append(id_column)
-
-        # Add target column if present
-        if target and target in pred_results_df.columns:
+        if target and target in cols:
             output_columns.append(target)

-        # Build the output DataFrame
-        output_df = pred_results_df[output_columns].copy() if output_columns else pd.DataFrame()
-
-        # For multi-task: map {target}_pred -> prediction, {target}_pred_std -> prediction_std
-        # For single-task: just grab prediction and prediction_std columns directly
-        pred_col = f"{target}_pred"
-        std_col = f"{target}_pred_std"
-        if pred_col in pred_results_df.columns:
-            # Multi-task columns exist
-            output_df["prediction"] = pred_results_df[pred_col]
-            if std_col in pred_results_df.columns:
-                output_df["prediction_std"] = pred_results_df[std_col]
-        else:
-            # Single-task: grab standard prediction columns
-            for col in ["prediction", "prediction_std"]:
-                if col in pred_results_df.columns:
-                    output_df[col] = pred_results_df[col]
-        # Also grab any _proba columns and UQ columns
-        for col in pred_results_df.columns:
-            if col.endswith("_proba") or col.startswith("q_") or col == "confidence":
-                output_df[col] = pred_results_df[col]
+        output_columns += [c for c in ["prediction", "prediction_std"] if c in cols]
+
+        # Add UQ columns (q_*, confidence) and proba columns
+        output_columns += [c for c in cols if c.startswith("q_") or c == "confidence" or c.endswith("_proba")]

         # Write the predictions to S3
         output_file = f"{inference_capture_path}/inference_predictions.csv"
         self.log.info(f"Writing predictions to {output_file}")
-        wr.s3.to_csv(output_df, output_file, index=False)
+        wr.s3.to_csv(pred_results_df[output_columns], output_file, index=False)

-    def regression_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
+    def regression_metrics(
+        self, target_column: str, prediction_df: pd.DataFrame, prediction_col: str = "prediction"
+    ) -> pd.DataFrame:
         """Compute the performance metrics for this Endpoint
+
         Args:
             target_column (str): Name of the target column
             prediction_df (pd.DataFrame): DataFrame with the prediction results
+            prediction_col (str): Name of the prediction column (default: "prediction")
+
         Returns:
             pd.DataFrame: DataFrame with the performance metrics
         """
-
-        # Sanity Check the prediction DataFrame
-        if prediction_df.empty:
-            self.log.warning("No predictions were made. Returning empty DataFrame.")
-            return pd.DataFrame()
-
-        # Check for prediction column
-        if "prediction" not in prediction_df.columns:
-            self.log.warning("No 'prediction' column found in DataFrame")
-            return pd.DataFrame()
-
-        # Check for NaN values in target or prediction columns
-        if prediction_df[target_column].isnull().any() or prediction_df["prediction"].isnull().any():
-            num_nan_target = prediction_df[target_column].isnull().sum()
-            num_nan_prediction = prediction_df["prediction"].isnull().sum()
-            self.log.warning(f"NaNs Found: {target_column} {num_nan_target} and prediction: {num_nan_prediction}.")
-            self.log.warning("Dropping NaN rows for metric computation.")
-            prediction_df = prediction_df.dropna(subset=[target_column, "prediction"])
-
-        # Compute the metrics using shared utilities
-        try:
-            return compute_regression_metrics(prediction_df, target_column)
-        except Exception as e:
-            self.log.warning(f"Error computing regression metrics: {str(e)}")
-            return pd.DataFrame()
+        return compute_regression_metrics(prediction_df, target_column, prediction_col)

     def residuals(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
         """Add the residuals to the prediction DataFrame
@@ -965,58 +947,22 @@

         return prediction_df

-    @staticmethod
-    def validate_proba_columns(prediction_df: pd.DataFrame, class_labels: list, guessing: bool = False):
-        """Ensure probability columns are correctly aligned with class labels
-
-        Args:
-            prediction_df (pd.DataFrame): DataFrame with the prediction results
-            class_labels (list): List of class labels
-            guessing (bool, optional): Whether we're guessing the class labels. Defaults to False.
-        """
-        proba_columns = [col.replace("_proba", "") for col in prediction_df.columns if col.endswith("_proba")]
-
-        if sorted(class_labels) != sorted(proba_columns):
-            if guessing:
-                raise ValueError(f"_proba columns {proba_columns} != GUESSED class_labels {class_labels}!")
-            else:
-                raise ValueError(f"_proba columns {proba_columns} != class_labels {class_labels}!")
-
-    def classification_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
+    def classification_metrics(
+        self, target_column: str, prediction_df: pd.DataFrame, prediction_col: str = "prediction"
+    ) -> pd.DataFrame:
         """Compute the performance metrics for this Endpoint

         Args:
             target_column (str): Name of the target column
             prediction_df (pd.DataFrame): DataFrame with the prediction results
+            prediction_col (str): Name of the prediction column (default: "prediction")

         Returns:
             pd.DataFrame: DataFrame with the performance metrics
         """
-        # Check for prediction column
-        if "prediction" not in prediction_df.columns:
-            self.log.warning("No 'prediction' column found in DataFrame")
-            return pd.DataFrame()
-
-        # Drop rows with NaN predictions (can't compute metrics on missing predictions)
-        nan_mask = prediction_df["prediction"].isna()
-        if nan_mask.any():
-            n_nan = nan_mask.sum()
-            self.log.warning(f"Dropping {n_nan} rows with NaN predictions for metrics calculation")
-            prediction_df = prediction_df[~nan_mask].copy()
-
-        # Get the class labels from the model
+        # Get class labels from the model (metrics_utils will infer if None)
         class_labels = ModelCore(self.model_name).class_labels()
-        if class_labels is None:
-            self.log.warning(
-                "Class labels not found in the model. Guessing class labels from the prediction DataFrame."
-            )
-            class_labels = prediction_df[target_column].unique().tolist()
-            self.validate_proba_columns(prediction_df, class_labels, guessing=True)
-        else:
-            self.validate_proba_columns(prediction_df, class_labels)
-
-        # Compute the metrics using shared utilities (returns per-class + 'all' row)
-        return compute_classification_metrics(prediction_df, target_column, class_labels)
+        return compute_classification_metrics(prediction_df, target_column, class_labels, prediction_col)

     def generate_confusion_matrix(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
         """Compute the confusion matrix for this Endpoint
workbench/core/artifacts/feature_set_core.py CHANGED
@@ -574,7 +574,7 @@ class FeatureSetCore(Artifact):

         TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)

-    @deprecated(version=0.9)
+    @deprecated(version="0.9")
     def set_training_filter(self, filter_expression: Optional[str] = None):
         """Set a filter expression for the training view for this FeatureSet
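The feature_set_core change simply quotes the decorator's version argument. A string is the safer type here: float literals lose trailing zeros (`0.10` evaluates to `0.1`) and cannot be formatted or compared as version strings. An illustrative decorator shape (an assumption for demonstration, not Workbench's actual implementation):

```python
import functools
import warnings

def deprecated(version: str):
    """Illustrative sketch: warn that the wrapped function goes away in `version`."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated and will be removed in version {version}",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator
```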