workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -30,13 +30,15 @@ from sagemaker import Predictor
 
 # Workbench Imports
 from workbench.core.artifacts.artifact import Artifact
-from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
+from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType, ModelFramework
 from workbench.utils.endpoint_metrics import EndpointMetrics
-from workbench.utils.fast_inference import fast_inference
 from workbench.utils.cache import Cache
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.model_utils import uq_metrics
-from workbench.utils.xgboost_model_utils import cross_fold_inference
+from workbench.utils.xgboost_model_utils import cross_fold_inference as xgboost_cross_fold
+from workbench.utils.pytorch_utils import pull_cv_results as pytorch_pull_cv
+from workbench.utils.chemprop_utils import pull_cv_results as chemprop_pull_cv
+from workbench_bridges.endpoints.fast_inference import fast_inference
 
 
 class EndpointCore(Artifact):
@@ -164,11 +166,17 @@ class EndpointCore(Artifact):
         """
         return "Serverless" in self.endpoint_meta["InstanceType"]
 
-    def add_data_capture(self):
+    def data_capture(self):
+        """Get the DataCaptureCore class for this endpoint"""
+        from workbench.core.artifacts.data_capture_core import DataCaptureCore
+
+        return DataCaptureCore(self.endpoint_name)
+
+    def enable_data_capture(self):
         """Add data capture to the endpoint"""
-        self.get_monitor().add_data_capture()
+        self.data_capture().enable()
 
-    def get_monitor(self):
+    def monitor(self):
         """Get the MonitorCore class for this endpoint"""
         from workbench.core.artifacts.monitor_core import MonitorCore
 
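
Note: the hunk above splits data capture and monitoring into separate accessors. A minimal usage sketch of the new API (the endpoint name is hypothetical; the methods are the ones added above):

from workbench.core.artifacts.endpoint_core import EndpointCore

# Hypothetical endpoint name -- substitute one of your own registered endpoints
end = EndpointCore("abalone-regression-end")

# 0.8.162 called end.get_monitor() and end.add_data_capture()
end.enable_data_capture()      # shorthand for end.data_capture().enable()
capture = end.data_capture()   # DataCaptureCore for this endpoint
monitor = end.monitor()        # MonitorCore for this endpoint
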
@@ -350,7 +358,7 @@ class EndpointCore(Artifact):
             return pd.DataFrame()
 
         # Grab the evaluation data from the FeatureSet
-        table = fs.view("training").table
+        table = model.training_view().table
         eval_df = fs.query(f'SELECT * FROM "{table}" where training = FALSE')
         capture_name = "auto_inference" if capture else None
         return self.inference(eval_df, capture_name, id_column=fs.id_column)
@@ -378,63 +386,150 @@ class EndpointCore(Artifact):
             self.log.important("No model associated with this endpoint, running 'no frills' inference...")
             return self.fast_inference(eval_df)
 
+        # Grab the model features and target column
+        model = ModelCore(self.model_name)
+        features = model.features()
+        target_column = model.target()
+
         # Run predictions on the evaluation data
-        prediction_df = self._predict(eval_df, drop_error_rows)
+        prediction_df = self._predict(eval_df, features, drop_error_rows)
         if prediction_df.empty:
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return prediction_df
 
-        # Get the target column
-        model = ModelCore(self.model_name)
-        target_column = model.target()
-
         # Sanity Check that the target column is present
         if target_column and (target_column not in prediction_df.columns):
             self.log.important(f"Target Column {target_column} not found in prediction_df!")
             self.log.important("In order to compute metrics, the target column must be present!")
-            return prediction_df
+            metrics = pd.DataFrame()
 
         # Compute the standard performance metrics for this model
-        model_type = model.model_type
-        if model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-            prediction_df = self.residuals(target_column, prediction_df)
-            metrics = self.regression_metrics(target_column, prediction_df)
-        elif model_type == ModelType.CLASSIFIER:
-            metrics = self.classification_metrics(target_column, prediction_df)
         else:
-            # For other model types, we don't compute metrics
-            self.log.info(f"Model Type: {model_type} doesn't have metrics...")
-            metrics = pd.DataFrame()
+            if model.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
+                prediction_df = self.residuals(target_column, prediction_df)
+                metrics = self.regression_metrics(target_column, prediction_df)
+            elif model.model_type == ModelType.CLASSIFIER:
+                metrics = self.classification_metrics(target_column, prediction_df)
+            else:
+                # For other model types, we don't compute metrics
+                self.log.info(f"Model Type: {model.model_type} doesn't have metrics...")
+                metrics = pd.DataFrame()
 
         # Print out the metrics
-        if not metrics.empty:
-            print(f"Performance Metrics for {self.model_name} on {self.name}")
-            print(metrics.head())
-
-        # Capture the inference results and metrics
-        if capture_name is not None:
-            description = capture_name.replace("_", " ").title()
-            features = model.features()
-            self._capture_inference_results(
-                capture_name, prediction_df, target_column, model_type, metrics, description, features, id_column
-            )
-
-        # Capture CrossFold Inference Results
-        cross_fold_metrics = cross_fold_inference(model)
-        if cross_fold_metrics:
-            # Now put into the Parameter Store Model Inference Namespace
-            self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
-
-        # For UQ Models we also capture the uncertainty metrics
-        if model_type in [ModelType.UQ_REGRESSOR]:
-            metrics = uq_metrics(prediction_df, target_column)
-
-            # Now put into the Parameter Store Model Inference Namespace
-            self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
+        print(f"Performance Metrics for {self.model_name} on {self.name}")
+        print(metrics.head())
+
+        # Capture the inference results and metrics
+        if capture_name is not None:
+
+            # If we don't have an id_column, we'll pull it from the model's FeatureSet
+            if id_column is None:
+                fs = FeatureSetCore(model.get_input())
+                id_column = fs.id_column
+            description = capture_name.replace("_", " ").title()
+            self._capture_inference_results(
+                capture_name, prediction_df, target_column, model.model_type, metrics, description, features, id_column
+            )
+
+            # For UQ Models we also capture the uncertainty metrics
+            if model.model_type in [ModelType.UQ_REGRESSOR]:
+                metrics = uq_metrics(prediction_df, target_column)
+                self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
 
         # Return the prediction DataFrame
         return prediction_df
 
+    def cross_fold_inference(self, nfolds: int = 5) -> pd.DataFrame:
+        """Run cross-fold inference (XGBoost computes folds; PyTorch/ChemProp pull pre-computed CV results)
+
+        Args:
+            nfolds (int): Number of folds to use for cross-fold (default: 5)
+
+        Returns:
+            pd.DataFrame: A DataFrame with cross fold predictions
+        """
+
+        # Grab our model
+        model = ModelCore(self.model_name)
+
+        # Compute CrossFold (Metrics and Prediction Dataframe)
+        # For PyTorch and ChemProp, pull pre-computed CV results from training
+        if model.model_framework in [ModelFramework.UNKNOWN, ModelFramework.XGBOOST]:
+            cross_fold_metrics, out_of_fold_df = xgboost_cross_fold(model, nfolds=nfolds)
+        elif model.model_framework == ModelFramework.PYTORCH_TABULAR:
+            cross_fold_metrics, out_of_fold_df = pytorch_pull_cv(model)
+        elif model.model_framework == ModelFramework.CHEMPROP:
+            cross_fold_metrics, out_of_fold_df = chemprop_pull_cv(model)
+        else:
+            self.log.error(f"Cross-Fold Inference not supported for Model Framework: {model.model_framework}.")
+            return pd.DataFrame()
+
+        # If the metrics dataframe isn't empty save to the param store
+        if not cross_fold_metrics.empty:
+            # Convert to list of dictionaries
+            metrics = cross_fold_metrics.to_dict(orient="records")
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", metrics)
+
+        # If the out_of_fold_df is empty return it
+        if out_of_fold_df.empty:
+            self.log.warning("No out-of-fold predictions were made. Returning empty DataFrame.")
+            return out_of_fold_df
+
+        # Capture the results
+        capture_name = "full_cross_fold"
+        description = capture_name.replace("_", " ").title()
+        target_column = model.target()
+        model_type = model.model_type
+
+        # Get the id_column from the model's FeatureSet
+        fs = FeatureSetCore(model.get_input())
+        id_column = fs.id_column
+
+        # Is this a UQ Model? If so, run full inference and merge the results
+        additional_columns = []
+        if model.model_framework == ModelFramework.XGBOOST and model_type == ModelType.UQ_REGRESSOR:
+            self.log.important("UQ Regressor detected, running full inference to get uncertainty estimates...")
+
+            # Get the training view dataframe for inference
+            training_df = model.training_view().pull_dataframe()
+
+            # Run inference on the endpoint to get UQ outputs
+            uq_df = self.inference(training_df)
+
+            # Identify UQ-specific columns (quantiles and prediction_std)
+            uq_columns = [
+                col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std" or col == "confidence"
+            ]
+
+            # Merge UQ columns with out-of-fold predictions
+            if uq_columns:
+                # Keep id_column and UQ columns, drop 'prediction' to avoid conflict when merging
+                uq_df = uq_df[[id_column] + uq_columns]
+
+                # Drop duplicates in uq_df based on id_column
+                uq_df = uq_df.drop_duplicates(subset=[id_column])
+
+                # Merge UQ columns into out_of_fold_df
+                out_of_fold_df = pd.merge(out_of_fold_df, uq_df, on=id_column, how="left")
+                additional_columns = uq_columns
+                self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")
+
+            # Also compute UQ metrics
+            metrics = uq_metrics(out_of_fold_df, target_column)
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
+
+        self._capture_inference_results(
+            capture_name,
+            out_of_fold_df,
+            target_column,
+            model_type,
+            cross_fold_metrics,
+            description,
+            features=additional_columns,
+            id_column=id_column,
+        )
+        return out_of_fold_df
+
     def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
         """Run inference on the Endpoint using the provided DataFrame
 
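
Note: the new cross_fold_inference() entry point dispatches on model framework. A minimal usage sketch (the endpoint name is hypothetical; the method and the Parameter Store path come from the hunk above):

from workbench.core.artifacts.endpoint_core import EndpointCore

# Hypothetical regression endpoint
end = EndpointCore("abalone-regression-end")

# XGBoost: folds are computed on the fly (nfolds applies here);
# PyTorch/ChemProp: pre-computed CV results from training are pulled instead
oof_df = end.cross_fold_inference(nfolds=5)
print(oof_df.head())  # out-of-fold predictions

# Fold metrics also land in the Parameter Store under:
#   /workbench/models/<model_name>/inference/cross_fold
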
@@ -450,11 +545,12 @@ class EndpointCore(Artifact):
         """
         return fast_inference(self.name, eval_df, self.sm_session, threads=threads)
 
-    def _predict(self, eval_df: pd.DataFrame, drop_error_rows: bool = False) -> pd.DataFrame:
-        """Internal: Run prediction on the given observations in the given DataFrame
+    def _predict(self, eval_df: pd.DataFrame, features: list[str], drop_error_rows: bool = False) -> pd.DataFrame:
+        """Internal: Run prediction on observations in the given DataFrame
 
         Args:
             eval_df (pd.DataFrame): DataFrame to run predictions on (must have superset of features)
+            features (list[str]): List of feature column names needed for prediction
             drop_error_rows (bool): If True, drop rows that had endpoint errors/issues (default=False)
         Returns:
             pd.DataFrame: Return the DataFrame with additional columns, prediction and any _proba columns
@@ -465,19 +561,12 @@ class EndpointCore(Artifact):
             self.log.warning("Evaluation DataFrame has 0 rows. No predictions to run.")
             return pd.DataFrame(columns=eval_df.columns)  # Return empty DataFrame with same structure
 
-        # Sanity check: Does the Model have Features?
-        features = ModelCore(self.model_name).features()
-        if not features:
-            self.log.warning("Model does not have features defined, using all columns in the DataFrame")
-        else:
-            # Sanity check: Does the DataFrame have the required features?
-            df_columns_lower = set(col.lower() for col in eval_df.columns)
-            features_lower = set(feature.lower() for feature in features)
-
-            # Check if the features are a subset of the DataFrame columns (case-insensitive)
-            if not features_lower.issubset(df_columns_lower):
-                missing_features = features_lower - df_columns_lower
-                raise ValueError(f"DataFrame does not contain required features: {missing_features}")
+        # Sanity check: Does the DataFrame have the required features?
+        df_columns_lower = set(col.lower() for col in eval_df.columns)
+        features_lower = set(feature.lower() for feature in features)
+        if not features_lower.issubset(df_columns_lower):
+            missing_features = features_lower - df_columns_lower
+            raise ValueError(f"DataFrame does not contain required features: {missing_features}")
 
         # Create our Endpoint Predictor Class
         predictor = Predictor(
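
Note: _predict() now takes the feature list from the caller and always enforces a case-insensitive subset check. A standalone sketch of that check (the helper name and sample columns are hypothetical; the logic is copied from the hunk above):

import pandas as pd

def check_required_features(eval_df: pd.DataFrame, features: list[str]) -> None:
    # Same case-insensitive subset check as _predict() above
    df_columns_lower = set(col.lower() for col in eval_df.columns)
    features_lower = set(feature.lower() for feature in features)
    if not features_lower.issubset(df_columns_lower):
        missing_features = features_lower - df_columns_lower
        raise ValueError(f"DataFrame does not contain required features: {missing_features}")

df = pd.DataFrame({"LENGTH": [1.0], "DIAMETER": [0.5]})
check_required_features(df, ["length", "diameter"])  # passes: match is case-insensitive
check_required_features(df, ["length", "rings"])     # raises ValueError: {'rings'} missing
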
@@ -634,6 +723,10 @@ class EndpointCore(Artifact):
     @staticmethod
     def _hash_dataframe(df: pd.DataFrame, hash_length: int = 8):
         # Internal: Compute a data hash for the dataframe
+        if df.empty:
+            return "--hash--"
+
+        # Sort the dataframe by columns to ensure consistent ordering
         df = df.copy()
         df = df.sort_values(by=sorted(df.columns.tolist()))
         row_hashes = pd.util.hash_pandas_object(df, index=False)
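
Note: the hunk cuts off before the hash is finalized, so the combine step below (an MD5 digest of the row hashes, truncated to hash_length) is an assumption; the empty-frame guard and row sorting mirror the code above:

import hashlib
import pandas as pd

def hash_dataframe(df: pd.DataFrame, hash_length: int = 8) -> str:
    if df.empty:
        return "--hash--"
    # Sort rows by all columns so the hash is independent of row order
    df = df.copy()
    df = df.sort_values(by=sorted(df.columns.tolist()))
    row_hashes = pd.util.hash_pandas_object(df, index=False)
    # Assumed combine step: digest the row hashes and truncate
    return hashlib.md5(row_hashes.values.tobytes()).hexdigest()[:hash_length]

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
assert hash_dataframe(df) == hash_dataframe(df.iloc[::-1])  # row order doesn't matter
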
@@ -687,19 +780,17 @@ class EndpointCore(Artifact):
         self.log.info(f"Writing metrics to {inference_capture_path}/inference_metrics.csv")
         wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)
 
-        # Grab the target column, prediction column, any _proba columns, and the ID column (if present)
-        prediction_col = "prediction" if "prediction" in pred_results_df.columns else "predictions"
-        output_columns = [target_column, prediction_col]
-
-        # Add any _proba columns to the output columns
-        output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
-
-        # Add any quantile columns to the output columns
-        output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col.startswith("qr_")]
-
-        # Add the ID column
+        # Grab the ID column and target column if they are present
+        output_columns = []
         if id_column and id_column in pred_results_df.columns:
             output_columns.append(id_column)
+        if target_column in pred_results_df.columns:
+            output_columns.append(target_column)
+
+        # Grab the prediction column, any _proba columns, and UQ columns
+        output_columns += [col for col in pred_results_df.columns if "prediction" in col]
+        output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
+        output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col == "confidence"]
 
         # Write the predictions to our S3 Model Inference Folder
         self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
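
Note: the capture now selects columns by name pattern, so prediction_std and confidence flow through to the saved CSV. A small sketch with a hypothetical prediction frame (the selection logic is copied from the hunk above):

import pandas as pd

# Hypothetical captured predictions for a UQ regressor
pred_results_df = pd.DataFrame({
    "id": [1, 2], "solubility": [0.50, 0.70], "prediction": [0.48, 0.66],
    "prediction_std": [0.05, 0.04], "q_05": [0.40, 0.60], "q_95": [0.56, 0.72],
    "confidence": [0.9, 0.8], "extra_feature": [3.1, 4.2],
})
id_column, target_column = "id", "solubility"

output_columns = []
if id_column and id_column in pred_results_df.columns:
    output_columns.append(id_column)
if target_column in pred_results_df.columns:
    output_columns.append(target_column)
output_columns += [col for col in pred_results_df.columns if "prediction" in col]
output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col == "confidence"]

print(pred_results_df[output_columns].columns.tolist())
# ['id', 'solubility', 'prediction', 'prediction_std', 'q_05', 'q_95', 'confidence']
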
@@ -713,18 +804,10 @@ class EndpointCore(Artifact):
             # Note: Unlike other dataframes here, we want to write the index (labels) to the CSV
             wr.s3.to_csv(conf_mtx, f"{inference_capture_path}/inference_cm.csv", index=True)
 
-        # Generate SHAP values for our Prediction Dataframe
-        # generate_shap_values(self.endpoint_name, model_type.value, pred_results_df, inference_capture_path)
-
         # Now recompute the details for our Model
-        self.log.important(f"Recomputing Details for {self.model_name} to show latest Inference Results...")
+        self.log.important(f"Loading inference metrics for {self.model_name}...")
         model = ModelCore(self.model_name)
         model._load_inference_metrics(capture_name)
-        model.details()
-
-        # Recompute the details so that inference model metrics are updated
-        self.log.important(f"Recomputing Details for {self.name} to show latest Inference Results...")
-        self.details()
 
     def regression_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
         """Compute the performance metrics for this Endpoint
@@ -740,10 +823,23 @@ class EndpointCore(Artifact):
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return pd.DataFrame()
 
+        # Check for NaN values in target or prediction columns
+        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
+        if prediction_df[target_column].isnull().any() or prediction_df[prediction_col].isnull().any():
+            # Compute the number of NaN values in each column
+            num_nan_target = prediction_df[target_column].isnull().sum()
+            num_nan_prediction = prediction_df[prediction_col].isnull().sum()
+            self.log.warning(
+                f"NaNs Found: {target_column} {num_nan_target} and {prediction_col}: {num_nan_prediction}."
+            )
+            self.log.warning(
+                "NaN values found in target or prediction columns. Dropping NaN rows for metric computation."
+            )
+            prediction_df = prediction_df.dropna(subset=[target_column, prediction_col])
+
         # Compute the metrics
         try:
             y_true = prediction_df[target_column]
-            prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
             y_pred = prediction_df[prediction_col]
 
             mae = mean_absolute_error(y_true, y_pred)
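
Note: regression_metrics() now drops NaN rows before scoring. A self-contained sketch of the same handling (the sample frame is hypothetical; mean_squared_error and r2_score are standard scikit-learn calls used here alongside the mean_absolute_error shown above):

import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Hypothetical predictions with a NaN in each column
df = pd.DataFrame({"target": [1.0, 2.0, np.nan, 4.0, 5.0],
                   "prediction": [1.1, np.nan, 3.2, 3.8, 5.3]})

# Mirror the new NaN handling: drop incomplete rows before computing metrics
df = df.dropna(subset=["target", "prediction"])

y_true, y_pred = df["target"], df["prediction"]
mae = mean_absolute_error(y_true, y_pred)
rmse = np.sqrt(mean_squared_error(y_true, y_pred))
r2 = r2_score(y_true, y_pred)
print(f"MAE: {mae:.3f}  RMSE: {rmse:.3f}  R2: {r2:.3f}")  # computed on the 3 complete rows
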
@@ -821,6 +917,14 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with the performance metrics
         """
+        # Drop rows with NaN predictions (can't compute metrics on missing predictions)
+        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
+        nan_mask = prediction_df[prediction_col].isna()
+        if nan_mask.any():
+            n_nan = nan_mask.sum()
+            self.log.warning(f"Dropping {n_nan} rows with NaN predictions for metrics calculation")
+            prediction_df = prediction_df[~nan_mask].copy()
+
         # Get the class labels from the model
         class_labels = ModelCore(self.model_name).class_labels()
         if class_labels is None:
@@ -832,8 +936,7 @@ class EndpointCore(Artifact):
         else:
             self.validate_proba_columns(prediction_df, class_labels)
 
-        # Calculate precision, recall, fscore, and support, handling zero division
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
+        # Calculate precision, recall, f1, and support, handling zero division
         scores = precision_recall_fscore_support(
             prediction_df[target_column],
             prediction_df[prediction_col],
@@ -867,7 +970,7 @@ class EndpointCore(Artifact):
                 target_column: class_labels,
                 "precision": scores[0],
                 "recall": scores[1],
-                "fscore": scores[2],
+                "f1": scores[2],
                 "roc_auc": roc_auc_per_label,
                 "support": scores[3],
             }
@@ -876,36 +979,39 @@ class EndpointCore(Artifact):
 
     def generate_confusion_matrix(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
         """Compute the confusion matrix for this Endpoint
+
         Args:
             target_column (str): Name of the target column
             prediction_df (pd.DataFrame): DataFrame with the prediction results
+
         Returns:
             pd.DataFrame: DataFrame with the confusion matrix
         """
+        # Drop rows with NaN predictions (can't include in confusion matrix)
+        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
+        nan_mask = prediction_df[prediction_col].isna()
+        if nan_mask.any():
+            n_nan = nan_mask.sum()
+            self.log.warning(f"Dropping {n_nan} rows with NaN predictions for confusion matrix")
+            prediction_df = prediction_df[~nan_mask].copy()
 
         y_true = prediction_df[target_column]
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
         y_pred = prediction_df[prediction_col]
 
-        # Check if our model has class labels, if not we'll use the unique labels in the prediction
-        class_labels = ModelCore(self.model_name).class_labels()
-        if class_labels is None:
-            class_labels = sorted(list(set(y_true) | set(y_pred)))
+        # Get model class labels
+        model_class_labels = ModelCore(self.model_name).class_labels()
 
-        # Compute the confusion matrix (sklearn confusion_matrix)
-        conf_mtx = confusion_matrix(y_true, y_pred, labels=class_labels)
+        # Use model labels if available, otherwise infer from data
+        if model_class_labels:
+            self.log.important("Using model class labels for confusion matrix ordering...")
+            labels = model_class_labels
+        else:
+            labels = sorted(list(set(y_true) | set(y_pred)))
 
-        # Create a DataFrame
-        conf_mtx_df = pd.DataFrame(conf_mtx, index=class_labels, columns=class_labels)
+        # Compute confusion matrix and create DataFrame
+        conf_mtx = confusion_matrix(y_true, y_pred, labels=labels)
+        conf_mtx_df = pd.DataFrame(conf_mtx, index=labels, columns=labels)
         conf_mtx_df.index.name = "labels"
-
-        # Check if our model has class labels. If so make the index and columns ordered
-        model_class_labels = ModelCore(self.model_name).class_labels()
-        if model_class_labels:
-            self.log.important("Reordering the confusion matrix based on model class labels...")
-            conf_mtx_df.index = pd.Categorical(conf_mtx_df.index, categories=model_class_labels, ordered=True)
-            conf_mtx_df.columns = pd.Categorical(conf_mtx_df.columns, categories=model_class_labels, ordered=True)
-            conf_mtx_df = conf_mtx_df.sort_index().sort_index(axis=1)
         return conf_mtx_df
 
     def endpoint_config_name(self) -> str:
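
Note: generate_confusion_matrix() now passes the label order straight into scikit-learn instead of reordering afterwards with pd.Categorical. A minimal sketch with hypothetical labels and data:

import pandas as pd
from sklearn.metrics import confusion_matrix

y_true = ["low", "high", "medium", "low", "high"]
y_pred = ["low", "medium", "medium", "low", "high"]

# e.g. model.class_labels(); falls back to sorted(set(y_true) | set(y_pred))
labels = ["low", "medium", "high"]

# labels= fixes the row/column order up front, so no post-hoc reordering is needed
conf_mtx = confusion_matrix(y_true, y_pred, labels=labels)
conf_mtx_df = pd.DataFrame(conf_mtx, index=labels, columns=labels)
conf_mtx_df.index.name = "labels"
print(conf_mtx_df)
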
@@ -932,9 +1038,9 @@ class EndpointCore(Artifact):
         self.upsert_workbench_meta({"workbench_input": input})
 
     def delete(self):
-        """ "Delete an existing Endpoint: Underlying Models, Configuration, and Endpoint"""
+        """Delete an existing Endpoint: Underlying Models, Configuration, and Endpoint"""
         if not self.exists():
-            self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
+            self.log.warning(f"Trying to delete an Endpoint that doesn't exist: {self.name}")
 
         # Remove this endpoint from the list of registered endpoints
         self.log.info(f"Removing {self.name} from the list of registered endpoints...")
@@ -975,12 +1081,23 @@ class EndpointCore(Artifact):
             cls.log.info(f"Deleting Monitoring Schedule {schedule['MonitoringScheduleName']}...")
             cls.sm_client.delete_monitoring_schedule(MonitoringScheduleName=schedule["MonitoringScheduleName"])
 
-        # Recursively delete all endpoint S3 artifacts (inference, data capture, monitoring, etc)
-        base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}"
-        s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
-        cls.log.info(f"Deleting S3 Objects at {base_endpoint_path}...")
-        cls.log.info(f"{s3_objects}")
-        wr.s3.delete_objects(s3_objects, boto3_session=cls.boto3_session)
+        # Recursively delete all endpoint S3 artifacts (inference, etc)
+        # Note: We do not want to delete the data_capture/ files since these
+        #       might be used for collection and data drift analysis
+        base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}/"
+        all_s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
+
+        # Filter out objects that contain 'data_capture/' in their path
+        s3_objects_to_delete = [obj for obj in all_s3_objects if "/data_capture/" not in obj]
+        cls.log.info(f"Found {len(all_s3_objects)} total objects at {base_endpoint_path}")
+        cls.log.info(f"Filtering out data_capture files, will delete {len(s3_objects_to_delete)} objects...")
+        cls.log.info(f"Objects to delete: {s3_objects_to_delete}")
+
+        if s3_objects_to_delete:
+            wr.s3.delete_objects(s3_objects_to_delete, boto3_session=cls.boto3_session)
+            cls.log.info(f"Successfully deleted {len(s3_objects_to_delete)} objects")
+        else:
+            cls.log.info("No objects to delete (only data_capture files found)")
 
         # Delete any dataframes that were stored in the Dataframe Cache
         cls.log.info("Deleting Dataframe Cache...")
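
Note: delete() now preserves data_capture/ files so drift analysis can continue after an endpoint is torn down. A standalone sketch of the filter (the bucket and keys are hypothetical; the predicate is copied from the hunk above):

def objects_to_delete(all_s3_objects: list[str]) -> list[str]:
    # Same filter as delete() above: anything under data_capture/ survives
    return [obj for obj in all_s3_objects if "/data_capture/" not in obj]

objects = [
    "s3://my-bucket/endpoints/my-end/inference/auto_inference/inference_metrics.csv",
    "s3://my-bucket/endpoints/my-end/data_capture/2025/01/01/capture.jsonl",
]
print(objects_to_delete(objects))  # only the inference_metrics.csv is deleted
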
@@ -1031,7 +1148,7 @@ class EndpointCore(Artifact):
 if __name__ == "__main__":
     """Exercise the Endpoint Class"""
     from workbench.api import FeatureSet
-    from workbench.utils.endpoint_utils import fs_evaluation_data
+    from workbench.utils.endpoint_utils import get_evaluation_data
     import random
 
     # Grab an EndpointCore object and pull some information from it
@@ -1039,7 +1156,7 @@ if __name__ == "__main__":
 
     # Test various error conditions (set row 42 length to pd.NA)
     # Note: This test should return ALL rows
-    my_eval_df = fs_evaluation_data(my_endpoint)
+    my_eval_df = get_evaluation_data(my_endpoint)
     my_eval_df.at[42, "length"] = pd.NA
     pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
     print(f"Sent rows: {len(my_eval_df)}")
@@ -1047,6 +1164,9 @@ if __name__ == "__main__":
     assert len(pred_results) == len(my_eval_df), "Predictions should match the number of sent rows"
 
     # Now we put in an invalid value
+    print("*" * 80)
+    print("NOW TESTING ERROR CONDITIONS...")
+    print("*" * 80)
     my_eval_df.at[42, "length"] = "invalid_value"
     pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
     print(f"Sent rows: {len(my_eval_df)}")
@@ -1086,13 +1206,20 @@ if __name__ == "__main__":
     df = fs.pull_dataframe()[:100]
     cap_df = df.copy()
     cap_df.columns = [col.upper() for col in cap_df.columns]
-    my_endpoint._predict(cap_df)
+    my_endpoint.inference(cap_df)
 
     # Boolean Type Test
     df["bool_column"] = [random.choice([True, False]) for _ in range(len(df))]
-    result_df = my_endpoint._predict(df)
+    result_df = my_endpoint.inference(df)
     assert result_df["bool_column"].dtype == bool
 
+    # Missing Feature Test
+    missing_df = df.drop(columns=["length"])
+    try:
+        my_endpoint.inference(missing_df)
+    except ValueError as e:
+        print(f"Expected error for missing feature: {e}")
+
     # Run Auto Inference on the Endpoint (uses the FeatureSet)
     print("Running Auto Inference...")
     my_endpoint.auto_inference()
@@ -1100,13 +1227,21 @@ if __name__ == "__main__":
     # Run Inference where we provide the data
     # Note: This dataframe could be from a FeatureSet or any other source
     print("Running Inference...")
-    my_eval_df = fs_evaluation_data(my_endpoint)
+    my_eval_df = get_evaluation_data(my_endpoint)
     pred_results = my_endpoint.inference(my_eval_df)
 
     # Now set capture=True to save inference results and metrics
-    my_eval_df = fs_evaluation_data(my_endpoint)
+    my_eval_df = get_evaluation_data(my_endpoint)
     pred_results = my_endpoint.inference(my_eval_df, capture_name="holdout_xyz")
 
+    # Run predictions using the fast_inference method
+    fast_results = my_endpoint.fast_inference(my_eval_df)
+
+    # Test the cross_fold_inference method
+    print("Running Cross-Fold Inference...")
+    all_results = my_endpoint.cross_fold_inference()
+    print(all_results)
+
     # Run Inference and metrics for a Classification Endpoint
     class_endpoint = EndpointCore("wine-classification")
     auto_predictions = class_endpoint.auto_inference()
@@ -1115,8 +1250,11 @@ if __name__ == "__main__":
     target = "wine_class"
     print(class_endpoint.generate_confusion_matrix(target, auto_predictions))
 
-    # Run predictions using the fast_inference method
-    fast_results = my_endpoint.fast_inference(my_eval_df)
+    # Test the cross_fold_inference method
+    print("Running Cross-Fold Inference...")
+    all_results = class_endpoint.cross_fold_inference()
+    print(all_results)
+    print("All done...")
 
     # Test the class method delete (commented out for now)
     # from workbench.api import Model