workbench-0.8.213-py3-none-any.whl → workbench-0.8.219-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/algorithms/sql/outliers.py +3 -3
  9. workbench/api/__init__.py +3 -0
  10. workbench/api/endpoint.py +10 -5
  11. workbench/api/feature_set.py +76 -6
  12. workbench/api/meta_model.py +289 -0
  13. workbench/api/model.py +43 -4
  14. workbench/core/artifacts/endpoint_core.py +65 -117
  15. workbench/core/artifacts/feature_set_core.py +3 -3
  16. workbench/core/artifacts/model_core.py +6 -4
  17. workbench/core/pipelines/pipeline_executor.py +1 -1
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  19. workbench/model_script_utils/model_script_utils.py +15 -11
  20. workbench/model_script_utils/pytorch_utils.py +11 -1
  21. workbench/model_scripts/chemprop/chemprop.template +147 -71
  22. workbench/model_scripts/chemprop/generated_model_script.py +151 -75
  23. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  24. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  25. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  27. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  28. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  29. workbench/model_scripts/meta_model/meta_model.template +209 -0
  30. workbench/model_scripts/pytorch_model/generated_model_script.py +45 -27
  31. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  32. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  33. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  34. workbench/model_scripts/script_generation.py +4 -0
  35. workbench/model_scripts/xgb_model/generated_model_script.py +167 -156
  36. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  37. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  38. workbench/repl/workbench_shell.py +0 -5
  39. workbench/scripts/endpoint_test.py +2 -2
  40. workbench/scripts/meta_model_sim.py +35 -0
  41. workbench/utils/chem_utils/fingerprints.py +87 -46
  42. workbench/utils/chemprop_utils.py +23 -5
  43. workbench/utils/meta_model_simulator.py +499 -0
  44. workbench/utils/metrics_utils.py +94 -10
  45. workbench/utils/model_utils.py +91 -9
  46. workbench/utils/pytorch_utils.py +1 -1
  47. workbench/utils/shap_utils.py +1 -55
  48. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  49. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/METADATA +2 -1
  50. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/RECORD +54 -50
  51. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/entry_points.txt +1 -0
  52. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  53. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  54. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  55. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  56. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/WHEEL +0 -0
  57. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/licenses/LICENSE +0 -0
  58. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/top_level.txt +0 -0
workbench/api/meta_model.py ADDED
@@ -0,0 +1,289 @@
+ """MetaModel: A Model that aggregates predictions from multiple child endpoints.
+
+ MetaModels don't train on feature data - they combine predictions from existing
+ endpoints using confidence-weighted voting. This provides ensemble benefits
+ across different model frameworks (XGBoost, PyTorch, ChemProp, etc.).
+ """
+
+ from pathlib import Path
+ import time
+ import logging
+
+ from sagemaker.estimator import Estimator
+
+ # Workbench Imports
+ from workbench.api.model import Model
+ from workbench.api.endpoint import Endpoint
+ from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework, ModelImages
+ from workbench.core.artifacts.artifact import Artifact
+ from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+ from workbench.model_scripts.script_generation import generate_model_script
+ from workbench.utils.config_manager import ConfigManager
+ from workbench.utils.model_utils import supported_instance_types
+
+ # Set up logging
+ log = logging.getLogger("workbench")
+
+
+ class MetaModel(Model):
+     """MetaModel: A Model that aggregates predictions from child endpoints.
+
+     Common Usage:
+         ```python
+         # Create a meta model from existing endpoints
+         meta = MetaModel.create(
+             name="my-meta-model",
+             child_endpoints=["endpoint-1", "endpoint-2", "endpoint-3"],
+             target_column="target"
+         )
+
+         # Deploy like any other model
+         endpoint = meta.to_endpoint()
+         ```
+     """
+
+     @classmethod
+     def create(
+         cls,
+         name: str,
+         child_endpoints: list[str],
+         target_column: str,
+         description: str = None,
+         tags: list[str] = None,
+     ) -> "MetaModel":
+         """Create a new MetaModel from a list of child endpoints.
+
+         Args:
+             name: Name for the meta model
+             child_endpoints: List of endpoint names to aggregate
+             target_column: Name of the target column (for metadata)
+             description: Optional description for the model
+             tags: Optional list of tags
+
+         Returns:
+             MetaModel: The created meta model
+         """
+         Artifact.is_name_valid(name, delimiter="-", lower_case=False)
+
+         # Validate endpoints and get lineage info from primary endpoint
+         feature_list, feature_set_name, model_weights = cls._validate_and_get_lineage(child_endpoints)
+
+         # Delete existing model if it exists
+         log.important(f"Trying to delete existing model {name}...")
+         ModelCore.managed_delete(name)
+
+         # Run training and register model
+         aws_clamp = AWSAccountClamp()
+         estimator = cls._run_training(name, child_endpoints, target_column, model_weights, aws_clamp)
+         cls._register_model(name, child_endpoints, description, tags, estimator, aws_clamp)
+
+         # Set metadata and onboard
+         cls._set_metadata(name, target_column, feature_list, feature_set_name, child_endpoints)
+
+         log.important(f"MetaModel {name} created successfully!")
+         return cls(name)
+
+     @classmethod
+     def _validate_and_get_lineage(cls, child_endpoints: list[str]) -> tuple[list[str], str, dict[str, float]]:
+         """Validate child endpoints exist and get lineage info from primary endpoint.
+
+         Args:
+             child_endpoints: List of endpoint names
+
+         Returns:
+             tuple: (feature_list, feature_set_name, model_weights) from the primary endpoint's model
+         """
+         log.info("Verifying child endpoints and gathering model metrics...")
+         mae_scores = {}
+
+         for ep_name in child_endpoints:
+             ep = Endpoint(ep_name)
+             if not ep.exists():
+                 raise ValueError(f"Child endpoint '{ep_name}' does not exist")
+
+             # Get model MAE from full_inference metrics
+             model = Model(ep.get_input())
+             metrics = model.get_inference_metrics("full_inference")
+             if metrics is not None and "mae" in metrics.columns:
+                 mae = float(metrics["mae"].iloc[0])
+                 mae_scores[ep_name] = mae
+                 log.info(f" {ep_name} -> {model.name}: MAE={mae:.4f}")
+             else:
+                 log.warning(f" {ep_name}: No full_inference metrics found, using default weight")
+                 mae_scores[ep_name] = None
+
+         # Compute inverse-MAE weights (higher weight for lower MAE)
+         valid_mae = {k: v for k, v in mae_scores.items() if v is not None}
+         if valid_mae:
+             inv_mae = {k: 1.0 / v for k, v in valid_mae.items()}
+             total = sum(inv_mae.values())
+             model_weights = {k: v / total for k, v in inv_mae.items()}
+             # Fill in missing weights with equal share of remaining weight
+             missing = [k for k in mae_scores if mae_scores[k] is None]
+             if missing:
+                 equal_weight = (1.0 - sum(model_weights.values())) / len(missing)
+                 for k in missing:
+                     model_weights[k] = equal_weight
+         else:
+             # No metrics available, use equal weights
+             model_weights = {k: 1.0 / len(child_endpoints) for k in child_endpoints}
+             log.warning("No MAE metrics found, using equal weights")
+
+         log.info(f"Model weights: {model_weights}")
+
+         # Use first endpoint as primary - backtrack to get model and feature set
+         primary_endpoint = Endpoint(child_endpoints[0])
+         primary_model = Model(primary_endpoint.get_input())
+         feature_list = primary_model.features()
+         feature_set_name = primary_model.get_input()
+
+         log.info(
+             f"Primary endpoint: {child_endpoints[0]} -> Model: {primary_model.name} -> FeatureSet: {feature_set_name}"
+         )
+         return feature_list, feature_set_name, model_weights
+
+     @classmethod
+     def _run_training(
+         cls,
+         name: str,
+         child_endpoints: list[str],
+         target_column: str,
+         model_weights: dict[str, float],
+         aws_clamp: AWSAccountClamp,
+     ) -> Estimator:
+         """Run the minimal training job that saves the meta model config.
+
+         Args:
+             name: Model name
+             child_endpoints: List of endpoint names
+             target_column: Target column name
+             model_weights: Dict mapping endpoint name to weight
+             aws_clamp: AWS account clamp
+
+         Returns:
+             Estimator: The fitted estimator
+         """
+         sm_session = aws_clamp.sagemaker_session()
+         cm = ConfigManager()
+         workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
+         models_s3_path = f"s3://{workbench_bucket}/models"
+
+         # Generate the model script from template
+         template_params = {
+             "model_type": ModelType.REGRESSOR,
+             "model_framework": ModelFramework.META,
+             "child_endpoints": child_endpoints,
+             "target_column": target_column,
+             "model_weights": model_weights,
+             "model_metrics_s3_path": f"{models_s3_path}/{name}/training",
+             "aws_region": sm_session.boto_region_name,
+         }
+         script_path = generate_model_script(template_params)
+
+         # Create estimator
+         training_image = ModelImages.get_image_uri(sm_session.boto_region_name, "meta_training")
+         log.info(f"Using Meta Training Image: {training_image}")
+         estimator = Estimator(
+             entry_point=Path(script_path).name,
+             source_dir=str(Path(script_path).parent),
+             role=aws_clamp.aws_session.get_workbench_execution_role_arn(),
+             instance_count=1,
+             instance_type="ml.m5.large",
+             sagemaker_session=sm_session,
+             image_uri=training_image,
+         )
+
+         # Run training (no input data needed - just saves config)
+         log.important(f"Creating MetaModel {name}...")
+         estimator.fit()
+
+         return estimator
+
+     @classmethod
+     def _register_model(
+         cls,
+         name: str,
+         child_endpoints: list[str],
+         description: str,
+         tags: list[str],
+         estimator: Estimator,
+         aws_clamp: AWSAccountClamp,
+     ):
+         """Create model group and register the model.
+
+         Args:
+             name: Model name
+             child_endpoints: List of endpoint names
+             description: Model description
+             tags: Model tags
+             estimator: Fitted estimator
+             aws_clamp: AWS account clamp
+         """
+         sm_session = aws_clamp.sagemaker_session()
+         model_description = description or f"Meta model aggregating: {', '.join(child_endpoints)}"
+
+         # Create model group
+         aws_clamp.sagemaker_client().create_model_package_group(
+             ModelPackageGroupName=name,
+             ModelPackageGroupDescription=model_description,
+             Tags=[{"Key": "workbench_tags", "Value": "::".join(tags or [name])}],
+         )
+
+         # Register the model with meta inference image
+         inference_image = ModelImages.get_image_uri(sm_session.boto_region_name, "meta_inference")
+         log.important(f"Registering model {name} with Inference Image {inference_image}...")
+         estimator.create_model(role=aws_clamp.aws_session.get_workbench_execution_role_arn()).register(
+             model_package_group_name=name,
+             image_uri=inference_image,
+             content_types=["text/csv"],
+             response_types=["text/csv"],
+             inference_instances=supported_instance_types("x86_64"),
+             transform_instances=["ml.m5.large", "ml.m5.xlarge"],
+             approval_status="Approved",
+             description=model_description,
+         )
+
+     @classmethod
+     def _set_metadata(
+         cls, name: str, target_column: str, feature_list: list[str], feature_set_name: str, child_endpoints: list[str]
+     ):
+         """Set model metadata and onboard.
+
+         Args:
+             name: Model name
+             target_column: Target column name
+             feature_list: List of feature names
+             feature_set_name: Name of the input FeatureSet
+             child_endpoints: List of child endpoint names
+         """
+         time.sleep(3)
+         output_model = ModelCore(name)
+         output_model._set_model_type(ModelType.UQ_REGRESSOR)
+         output_model._set_model_framework(ModelFramework.META)
+         output_model.set_input(feature_set_name, force=True)
+         output_model.upsert_workbench_meta({"workbench_model_target": target_column})
+         output_model.upsert_workbench_meta({"workbench_model_features": feature_list})
+         output_model.upsert_workbench_meta({"child_endpoints": child_endpoints})
+         output_model.onboard_with_args(ModelType.UQ_REGRESSOR, target_column, feature_list=feature_list)
+
+
+ if __name__ == "__main__":
+     """Exercise the MetaModel Class"""
+
+     meta = MetaModel.create(
+         name="logd-meta",
+         child_endpoints=["logd-xgb", "logd-pytorch", "logd-chemprop"],
+         target_column="logd",
+         description="Meta model for LogD prediction",
+         tags=["meta", "logd", "ensemble"],
+     )
+     print(meta.summary())
+
+     # Create an endpoint for the meta model
+     end = meta.to_endpoint(tags=["meta", "logd"])
+     end.set_owner("BW")
+     end.auto_inference()
+
+     # Test loading an existing meta model
+     meta = MetaModel("logd-meta")
+     print(meta.details())
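
The heart of `_validate_and_get_lineage` is the inverse-MAE weighting. A minimal standalone sketch of that computation, using made-up MAE values for the three endpoints from the `__main__` block (not real metrics):

```python
# Illustrative MAE values only -- real values come from each model's "full_inference" capture
mae_scores = {"logd-xgb": 0.42, "logd-pytorch": 0.38, "logd-chemprop": 0.35}

# Lower MAE -> higher weight; weights are normalized to sum to 1.0
inv_mae = {name: 1.0 / mae for name, mae in mae_scores.items()}
total = sum(inv_mae.values())
model_weights = {name: w / total for name, w in inv_mae.items()}

print(model_weights)
# ≈ {'logd-xgb': 0.30, 'logd-pytorch': 0.33, 'logd-chemprop': 0.36}
```

The resulting weights ride along in `template_params`, so the generated meta-model script can apply them when it combines child-endpoint predictions.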
workbench/api/model.py CHANGED
@@ -10,7 +10,12 @@ from workbench.core.artifacts.artifact import Artifact
  from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework  # noqa: F401
  from workbench.core.transforms.model_to_endpoint.model_to_endpoint import ModelToEndpoint
  from workbench.api.endpoint import Endpoint
- from workbench.utils.model_utils import proximity_model_local, noise_model_local
+ from workbench.utils.model_utils import (
+     proximity_model_local,
+     fingerprint_prox_model_local,
+     noise_model_local,
+     cleanlab_model_local,
+ )


  class Model(ModelCore):
@@ -83,13 +88,38 @@ class Model(ModelCore):
          end.set_owner(self.get_owner())
          return end

-     def prox_model(self):
+     def prox_model(self, include_all_columns: bool = False):
          """Create a local Proximity Model for this Model

+         Args:
+             include_all_columns (bool): Include all DataFrame columns in results (default: False)
+
+         Returns:
+             FeatureSpaceProximity: A local FeatureSpaceProximity Model
+         """
+         return proximity_model_local(self, include_all_columns=include_all_columns)
+
+     def fp_prox_model(
+         self,
+         include_all_columns: bool = False,
+         radius: int = 2,
+         n_bits: int = 1024,
+         counts: bool = False,
+     ):
+         """Create a local Fingerprint Proximity Model for this Model
+
+         Args:
+             include_all_columns (bool): Include all DataFrame columns in results (default: False)
+             radius (int): Morgan fingerprint radius (default: 2)
+             n_bits (int): Number of bits for the fingerprint (default: 1024)
+             counts (bool): Use count fingerprints instead of binary (default: False)
+
          Returns:
-             Proximity: A local Proximity Model
+             FingerprintProximity: A local FingerprintProximity Model
          """
-         return proximity_model_local(self)
+         return fingerprint_prox_model_local(
+             self, include_all_columns=include_all_columns, radius=radius, n_bits=n_bits, counts=counts
+         )

      def noise_model(self):
          """Create a local Noise Model for this Model
@@ -99,6 +129,15 @@ class Model(ModelCore):
          """
          return noise_model_local(self)

+     def cleanlab_model(self):
+         """Create a CleanLearning model for this Model's training data.
+
+         Returns:
+             CleanLearning: A fitted cleanlab model. Use get_label_issues() to get
+                 a DataFrame with id_column, label_quality, predicted_label, given_label, is_label_issue.
+         """
+         return cleanlab_model_local(self)
+

  if __name__ == "__main__":
      """Exercise the Model Class"""
workbench/core/artifacts/endpoint_core.py CHANGED
@@ -330,12 +330,8 @@ class EndpointCore(Artifact):
          self.details()
          return True

-     def auto_inference(self, capture: bool = False) -> pd.DataFrame:
-         """Run inference on the endpoint using FeatureSet data
-
-         Args:
-             capture (bool, optional): Capture the inference results and metrics (default=False)
-         """
+     def auto_inference(self) -> pd.DataFrame:
+         """Run inference on the endpoint using the test data from the model training view"""

          # Sanity Check that we have a model
          model = ModelCore(self.get_input())
@@ -343,22 +339,40 @@ class EndpointCore(Artifact):
              self.log.error("No model found for this endpoint. Returning empty DataFrame.")
              return pd.DataFrame()

-         # Now get the FeatureSet and make sure it exists
-         fs = FeatureSetCore(model.get_input())
-         if not fs.exists():
-             self.log.error("No FeatureSet found for this endpoint. Returning empty DataFrame.")
+         # Grab the evaluation data from the Model's training view
+         all_df = model.training_view().pull_dataframe()
+         eval_df = all_df[~all_df["training"]]
+
+         # Remove AWS created columns
+         aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
+         eval_df = eval_df.drop(columns=aws_cols, errors="ignore")
+
+         # Run inference
+         return self.inference(eval_df, "auto_inference")
+
+     def full_inference(self) -> pd.DataFrame:
+         """Run inference on the endpoint using all the data from the model training view"""
+
+         # Sanity Check that we have a model
+         model = ModelCore(self.get_input())
+         if not model.exists():
+             self.log.error("No model found for this endpoint. Returning empty DataFrame.")
              return pd.DataFrame()

-         # Grab the evaluation data from the FeatureSet
-         table = model.training_view().table
-         eval_df = fs.query(f'SELECT * FROM "{table}" where training = FALSE')
-         capture_name = "auto_inference" if capture else None
-         return self.inference(eval_df, capture_name, id_column=fs.id_column)
+         # Grab the full data from the Model's training view
+         eval_df = model.training_view().pull_dataframe()
+
+         # Remove AWS created columns
+         aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
+         eval_df = eval_df.drop(columns=aws_cols, errors="ignore")
+
+         # Run inference
+         return self.inference(eval_df, "full_inference")

      def inference(
          self, eval_df: pd.DataFrame, capture_name: str = None, id_column: str = None, drop_error_rows: bool = False
      ) -> pd.DataFrame:
-         """Run inference and compute performance metrics with optional capture
+         """Run inference on the Endpoint using the provided DataFrame

          Args:
              eval_df (pd.DataFrame): DataFrame to run predictions on (must have superset of features)
@@ -396,7 +410,7 @@ class EndpointCore(Artifact):
          primary_target = targets

          # Sanity Check that the target column is present
-         if primary_target and (primary_target not in prediction_df.columns):
+         if primary_target not in prediction_df.columns:
              self.log.important(f"Target Column {primary_target} not found in prediction_df!")
              self.log.important("In order to compute metrics, the target column must be present!")
              metrics = pd.DataFrame()
@@ -418,7 +432,7 @@ class EndpointCore(Artifact):
              print(metrics.head())

          # Capture the inference results and metrics
-         if capture_name is not None:
+         if primary_target and capture_name:

              # If we don't have an id_column, we'll pull it from the model's FeatureSet

@@ -440,11 +454,14 @@ class EndpointCore(Artifact):
              # Drop rows with NaN target values for metrics/plots
              target_df = prediction_df.dropna(subset=[target])

+             # For multi-target models, prediction column is {target}_pred, otherwise "prediction"
+             pred_col = f"{target}_pred" if is_multi_target else "prediction"
+
              # Compute per-target metrics
              if model.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-                 target_metrics = self.regression_metrics(target, target_df)
+                 target_metrics = self.regression_metrics(target, target_df, prediction_col=pred_col)
              elif model.model_type == ModelType.CLASSIFIER:
-                 target_metrics = self.classification_metrics(target, target_df)
+                 target_metrics = self.classification_metrics(target, target_df, prediction_col=pred_col)
              else:
                  target_metrics = pd.DataFrame()

@@ -547,11 +564,14 @@ class EndpointCore(Artifact):
              # Drop rows with NaN target values for metrics/plots
              target_df = out_of_fold_df.dropna(subset=[target])

+             # For multi-target models, prediction column is {target}_pred, otherwise "prediction"
+             pred_col = f"{target}_pred" if is_multi_target else "prediction"
+
              # Compute per-target metrics
              if model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-                 target_metrics = self.regression_metrics(target, target_df)
+                 target_metrics = self.regression_metrics(target, target_df, prediction_col=pred_col)
              elif model_type == ModelType.CLASSIFIER:
-                 target_metrics = self.classification_metrics(target, target_df)
+                 target_metrics = self.classification_metrics(target, target_df, prediction_col=pred_col)
              else:
                  target_metrics = pd.DataFrame()

@@ -865,75 +885,39 @@ class EndpointCore(Artifact):
              target (str): Target column name
              id_column (str, optional): Name of the ID column
          """
-         # Start with ID column if present
+         cols = pred_results_df.columns
+
+         # Build output columns: id, target, prediction, prediction_std, UQ columns, proba columns
          output_columns = []
-         if id_column and id_column in pred_results_df.columns:
+         if id_column and id_column in cols:
              output_columns.append(id_column)
-
-         # Add target column if present
-         if target and target in pred_results_df.columns:
+         if target and target in cols:
              output_columns.append(target)

-         # Build the output DataFrame
-         output_df = pred_results_df[output_columns].copy() if output_columns else pd.DataFrame()
-
-         # For multi-task: map {target}_pred -> prediction, {target}_pred_std -> prediction_std
-         # For single-task: just grab prediction and prediction_std columns directly
-         pred_col = f"{target}_pred"
-         std_col = f"{target}_pred_std"
-         if pred_col in pred_results_df.columns:
-             # Multi-task columns exist
-             output_df["prediction"] = pred_results_df[pred_col]
-             if std_col in pred_results_df.columns:
-                 output_df["prediction_std"] = pred_results_df[std_col]
-         else:
-             # Single-task: grab standard prediction columns
-             for col in ["prediction", "prediction_std"]:
-                 if col in pred_results_df.columns:
-                     output_df[col] = pred_results_df[col]
-             # Also grab any _proba columns and UQ columns
-             for col in pred_results_df.columns:
-                 if col.endswith("_proba") or col.startswith("q_") or col == "confidence":
-                     output_df[col] = pred_results_df[col]
+         output_columns += [c for c in ["prediction", "prediction_std"] if c in cols]
+
+         # Add UQ columns (q_*, confidence) and proba columns
+         output_columns += [c for c in cols if c.startswith("q_") or c == "confidence" or c.endswith("_proba")]

          # Write the predictions to S3
          output_file = f"{inference_capture_path}/inference_predictions.csv"
          self.log.info(f"Writing predictions to {output_file}")
-         wr.s3.to_csv(output_df, output_file, index=False)
+         wr.s3.to_csv(pred_results_df[output_columns], output_file, index=False)

-     def regression_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
+     def regression_metrics(
+         self, target_column: str, prediction_df: pd.DataFrame, prediction_col: str = "prediction"
+     ) -> pd.DataFrame:
          """Compute the performance metrics for this Endpoint
+
          Args:
              target_column (str): Name of the target column
              prediction_df (pd.DataFrame): DataFrame with the prediction results
+             prediction_col (str): Name of the prediction column (default: "prediction")
+
          Returns:
              pd.DataFrame: DataFrame with the performance metrics
          """
-
-         # Sanity Check the prediction DataFrame
-         if prediction_df.empty:
-             self.log.warning("No predictions were made. Returning empty DataFrame.")
-             return pd.DataFrame()
-
-         # Check for prediction column
-         if "prediction" not in prediction_df.columns:
-             self.log.warning("No 'prediction' column found in DataFrame")
-             return pd.DataFrame()
-
-         # Check for NaN values in target or prediction columns
-         if prediction_df[target_column].isnull().any() or prediction_df["prediction"].isnull().any():
-             num_nan_target = prediction_df[target_column].isnull().sum()
-             num_nan_prediction = prediction_df["prediction"].isnull().sum()
-             self.log.warning(f"NaNs Found: {target_column} {num_nan_target} and prediction: {num_nan_prediction}.")
-             self.log.warning("Dropping NaN rows for metric computation.")
-             prediction_df = prediction_df.dropna(subset=[target_column, "prediction"])
-
-         # Compute the metrics using shared utilities
-         try:
-             return compute_regression_metrics(prediction_df, target_column)
-         except Exception as e:
-             self.log.warning(f"Error computing regression metrics: {str(e)}")
-             return pd.DataFrame()
+         return compute_regression_metrics(prediction_df, target_column, prediction_col)

      def residuals(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
          """Add the residuals to the prediction DataFrame
@@ -963,58 +947,22 @@ class EndpointCore(Artifact):

          return prediction_df

-     @staticmethod
-     def validate_proba_columns(prediction_df: pd.DataFrame, class_labels: list, guessing: bool = False):
-         """Ensure probability columns are correctly aligned with class labels
-
-         Args:
-             prediction_df (pd.DataFrame): DataFrame with the prediction results
-             class_labels (list): List of class labels
-             guessing (bool, optional): Whether we're guessing the class labels. Defaults to False.
-         """
-         proba_columns = [col.replace("_proba", "") for col in prediction_df.columns if col.endswith("_proba")]
-
-         if sorted(class_labels) != sorted(proba_columns):
-             if guessing:
-                 raise ValueError(f"_proba columns {proba_columns} != GUESSED class_labels {class_labels}!")
-             else:
-                 raise ValueError(f"_proba columns {proba_columns} != class_labels {class_labels}!")
-
-     def classification_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
+     def classification_metrics(
+         self, target_column: str, prediction_df: pd.DataFrame, prediction_col: str = "prediction"
+     ) -> pd.DataFrame:
          """Compute the performance metrics for this Endpoint

          Args:
              target_column (str): Name of the target column
              prediction_df (pd.DataFrame): DataFrame with the prediction results
+             prediction_col (str): Name of the prediction column (default: "prediction")

          Returns:
              pd.DataFrame: DataFrame with the performance metrics
          """
-         # Check for prediction column
-         if "prediction" not in prediction_df.columns:
-             self.log.warning("No 'prediction' column found in DataFrame")
-             return pd.DataFrame()
-
-         # Drop rows with NaN predictions (can't compute metrics on missing predictions)
-         nan_mask = prediction_df["prediction"].isna()
-         if nan_mask.any():
-             n_nan = nan_mask.sum()
-             self.log.warning(f"Dropping {n_nan} rows with NaN predictions for metrics calculation")
-             prediction_df = prediction_df[~nan_mask].copy()
-
-         # Get the class labels from the model
+         # Get class labels from the model (metrics_utils will infer if None)
          class_labels = ModelCore(self.model_name).class_labels()
-         if class_labels is None:
-             self.log.warning(
-                 "Class labels not found in the model. Guessing class labels from the prediction DataFrame."
-             )
-             class_labels = prediction_df[target_column].unique().tolist()
-             self.validate_proba_columns(prediction_df, class_labels, guessing=True)
-         else:
-             self.validate_proba_columns(prediction_df, class_labels)
-
-         # Compute the metrics using shared utilities (returns per-class + 'all' row)
-         return compute_classification_metrics(prediction_df, target_column, class_labels)
+         return compute_classification_metrics(prediction_df, target_column, class_labels, prediction_col)

      def generate_confusion_matrix(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
          """Compute the confusion matrix for this Endpoint
workbench/core/artifacts/feature_set_core.py CHANGED
@@ -247,7 +247,7 @@ class FeatureSetCore(Artifact):

          # Set the compressed features in our FeatureSet metadata
          self.log.important(f"Setting Compressed Columns...{compressed_columns}")
-         self.upsert_workbench_meta({"comp_features": compressed_columns})
+         self.upsert_workbench_meta({"compressed_features": compressed_columns})

      def get_compressed_features(self) -> list[str]:
          """Get the compressed features for this FeatureSet
@@ -256,7 +256,7 @@ class FeatureSetCore(Artifact):
              list[str]: The compressed columns for this FeatureSet
          """
          # Get the compressed features from our FeatureSet metadata
-         return self.workbench_meta().get("comp_features", [])
+         return self.workbench_meta().get("compressed_features", [])

      def num_columns(self) -> int:
          """Return the number of columns of the Feature Set"""
@@ -574,7 +574,7 @@ class FeatureSetCore(Artifact):

          TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)

-     @deprecated(version=0.9)
+     @deprecated(version="0.9")
      def set_training_filter(self, filter_expression: Optional[str] = None):
          """Set a filter expression for the training view for this FeatureSet
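
One migration note on the `comp_features` → `compressed_features` rename: nothing in this diff rewrites values stored under the old key, so FeatureSets written by 0.8.213 will return the `[]` default from `get_compressed_features()` until the metadata is re-set. A sketch of the round trip under the new key, with a hypothetical FeatureSet name:

```python
from workbench.api import FeatureSet  # FeatureSet is assumed to be exported from workbench.api

fs = FeatureSet("logd_features")  # hypothetical FeatureSet

# Store and read back under the new metadata key
fs.upsert_workbench_meta({"compressed_features": ["fingerprint"]})
assert fs.get_compressed_features() == ["fingerprint"]
```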