workbench 0.8.161__py3-none-any.whl → 0.8.192__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +12 -0
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/meta.py +5 -2
  7. workbench/api/model.py +16 -12
  8. workbench/api/monitor.py +1 -16
  9. workbench/core/artifacts/artifact.py +11 -3
  10. workbench/core/artifacts/data_capture_core.py +355 -0
  11. workbench/core/artifacts/endpoint_core.py +168 -78
  12. workbench/core/artifacts/feature_set_core.py +72 -13
  13. workbench/core/artifacts/model_core.py +50 -15
  14. workbench/core/artifacts/monitor_core.py +33 -248
  15. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  16. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  17. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  18. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  19. workbench/core/transforms/features_to_model/features_to_model.py +9 -4
  20. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  21. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  22. workbench/core/views/training_view.py +49 -53
  23. workbench/core/views/view.py +51 -1
  24. workbench/core/views/view_utils.py +4 -4
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  26. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  27. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  28. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  29. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  30. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  31. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  32. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  33. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  34. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  35. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  36. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  37. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  38. workbench/model_scripts/pytorch_model/pytorch.template +19 -20
  39. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  40. workbench/model_scripts/script_generation.py +7 -2
  41. workbench/model_scripts/uq_models/mapie.template +492 -0
  42. workbench/model_scripts/uq_models/requirements.txt +1 -0
  43. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  44. workbench/repl/workbench_shell.py +11 -6
  45. workbench/scripts/lambda_launcher.py +63 -0
  46. workbench/scripts/ml_pipeline_batch.py +137 -0
  47. workbench/scripts/ml_pipeline_sqs.py +186 -0
  48. workbench/scripts/monitor_cloud_watch.py +20 -100
  49. workbench/utils/aws_utils.py +4 -3
  50. workbench/utils/chem_utils/__init__.py +0 -0
  51. workbench/utils/chem_utils/fingerprints.py +134 -0
  52. workbench/utils/chem_utils/misc.py +194 -0
  53. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  54. workbench/utils/chem_utils/mol_standardize.py +450 -0
  55. workbench/utils/chem_utils/mol_tagging.py +348 -0
  56. workbench/utils/chem_utils/projections.py +209 -0
  57. workbench/utils/chem_utils/salts.py +256 -0
  58. workbench/utils/chem_utils/sdf.py +292 -0
  59. workbench/utils/chem_utils/toxicity.py +250 -0
  60. workbench/utils/chem_utils/vis.py +253 -0
  61. workbench/utils/cloudwatch_handler.py +1 -1
  62. workbench/utils/cloudwatch_utils.py +137 -0
  63. workbench/utils/config_manager.py +3 -7
  64. workbench/utils/endpoint_utils.py +5 -7
  65. workbench/utils/license_manager.py +2 -6
  66. workbench/utils/model_utils.py +76 -30
  67. workbench/utils/monitor_utils.py +44 -62
  68. workbench/utils/pandas_utils.py +3 -3
  69. workbench/utils/shap_utils.py +10 -2
  70. workbench/utils/workbench_logging.py +0 -3
  71. workbench/utils/workbench_sqs.py +1 -1
  72. workbench/utils/xgboost_model_utils.py +283 -145
  73. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  74. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  75. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  76. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/METADATA +4 -4
  77. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/RECORD +81 -76
  78. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -0
  79. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  80. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  81. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  82. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  83. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  84. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  85. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -565
  86. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  87. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  88. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  89. workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
  90. workbench/utils/chem_utils.py +0 -1556
  91. workbench/utils/execution_environment.py +0 -211
  92. workbench/utils/fast_inference.py +0 -167
  93. workbench/utils/resource_utils.py +0 -39
  94. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
  95. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
  96. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,7 @@ import pandas as pd
8
8
  import numpy as np
9
9
  from io import StringIO
10
10
  import awswrangler as wr
11
- from typing import Union, Optional
11
+ from typing import Union, Optional, Tuple
12
12
  import hashlib
13
13
 
14
14
  # Model Performance Scores
@@ -32,11 +32,11 @@ from sagemaker import Predictor
32
32
  from workbench.core.artifacts.artifact import Artifact
33
33
  from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
34
34
  from workbench.utils.endpoint_metrics import EndpointMetrics
35
- from workbench.utils.fast_inference import fast_inference
36
35
  from workbench.utils.cache import Cache
37
36
  from workbench.utils.s3_utils import compute_s3_object_hash
38
37
  from workbench.utils.model_utils import uq_metrics
39
38
  from workbench.utils.xgboost_model_utils import cross_fold_inference
39
+ from workbench_bridges.endpoints.fast_inference import fast_inference
40
40
 
41
41
 
42
42
  class EndpointCore(Artifact):
@@ -164,11 +164,17 @@ class EndpointCore(Artifact):
164
164
  """
165
165
  return "Serverless" in self.endpoint_meta["InstanceType"]
166
166
 
167
- def add_data_capture(self):
167
+ def data_capture(self):
168
+ """Get the DataCaptureCore class for this endpoint"""
169
+ from workbench.core.artifacts.data_capture_core import DataCaptureCore
170
+
171
+ return DataCaptureCore(self.endpoint_name)
172
+
173
+ def enable_data_capture(self):
168
174
  """Add data capture to the endpoint"""
169
- self.get_monitor().add_data_capture()
175
+ self.data_capture().enable()
170
176
 
171
- def get_monitor(self):
177
+ def monitor(self):
172
178
  """Get the MonitorCore class for this endpoint"""
173
179
  from workbench.core.artifacts.monitor_core import MonitorCore
174
180
 
@@ -350,7 +356,7 @@ class EndpointCore(Artifact):
350
356
  return pd.DataFrame()
351
357
 
352
358
  # Grab the evaluation data from the FeatureSet
353
- table = fs.view("training").table
359
+ table = model.training_view().table
354
360
  eval_df = fs.query(f'SELECT * FROM "{table}" where training = FALSE')
355
361
  capture_name = "auto_inference" if capture else None
356
362
  return self.inference(eval_df, capture_name, id_column=fs.id_column)
@@ -378,16 +384,17 @@ class EndpointCore(Artifact):
378
384
  self.log.important("No model associated with this endpoint, running 'no frills' inference...")
379
385
  return self.fast_inference(eval_df)
380
386
 
387
+ # Grab the model features and target column
388
+ model = ModelCore(self.model_name)
389
+ features = model.features()
390
+ target_column = model.target()
391
+
381
392
  # Run predictions on the evaluation data
382
- prediction_df = self._predict(eval_df, drop_error_rows)
393
+ prediction_df = self._predict(eval_df, features, drop_error_rows)
383
394
  if prediction_df.empty:
384
395
  self.log.warning("No predictions were made. Returning empty DataFrame.")
385
396
  return prediction_df
386
397
 
387
- # Get the target column
388
- model = ModelCore(self.model_name)
389
- target_column = model.target()
390
-
391
398
  # Sanity Check that the target column is present
392
399
  if target_column and (target_column not in prediction_df.columns):
393
400
  self.log.important(f"Target Column {target_column} not found in prediction_df!")
@@ -413,28 +420,95 @@ class EndpointCore(Artifact):
413
420
 
414
421
  # Capture the inference results and metrics
415
422
  if capture_name is not None:
423
+
424
+ # If we don't have an id_column, we'll pull it from the model's FeatureSet
425
+ if id_column is None:
426
+ fs = FeatureSetCore(model.get_input())
427
+ id_column = fs.id_column
416
428
  description = capture_name.replace("_", " ").title()
417
- features = model.features()
418
429
  self._capture_inference_results(
419
430
  capture_name, prediction_df, target_column, model_type, metrics, description, features, id_column
420
431
  )
421
432
 
422
- # Capture CrossFold Inference Results
423
- cross_fold_metrics = cross_fold_inference(model)
424
- if cross_fold_metrics:
425
- # Now put into the Parameter Store Model Inference Namespace
426
- self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
427
-
428
433
  # For UQ Models we also capture the uncertainty metrics
429
434
  if model_type in [ModelType.UQ_REGRESSOR]:
430
435
  metrics = uq_metrics(prediction_df, target_column)
431
-
432
- # Now put into the Parameter Store Model Inference Namespace
433
436
  self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
434
437
 
435
438
  # Return the prediction DataFrame
436
439
  return prediction_df
437
440
 
441
+ def cross_fold_inference(self, nfolds: int = 5) -> Tuple[dict, pd.DataFrame]:
442
+ """Run cross-fold inference (only works for XGBoost models)
443
+
444
+ Args:
445
+ nfolds (int): Number of folds to use for cross-fold (default: 5)
446
+
447
+ Returns:
448
+ Tuple[dict, pd.DataFrame]: Tuple of (cross_fold_metrics, out_of_fold_df)
449
+ """
450
+
451
+ # Grab our model
452
+ model = ModelCore(self.model_name)
453
+
454
+ # Compute CrossFold Metrics
455
+ cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
456
+ if cross_fold_metrics:
457
+ self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
458
+
459
+ # Capture the results
460
+ capture_name = "full_cross_fold"
461
+ description = capture_name.replace("_", " ").title()
462
+ target_column = model.target()
463
+ model_type = model.model_type
464
+
465
+ # Get the id_column from the model's FeatureSet
466
+ fs = FeatureSetCore(model.get_input())
467
+ id_column = fs.id_column
468
+
469
+ # Is this a UQ Model? If so, run full inference and merge the results
470
+ additional_columns = []
471
+ if model_type == ModelType.UQ_REGRESSOR:
472
+ self.log.important("UQ Regressor detected, running full inference to get uncertainty estimates...")
473
+
474
+ # Get the training view dataframe for inference
475
+ training_df = model.training_view().pull_dataframe()
476
+
477
+ # Run inference on the endpoint to get UQ outputs
478
+ uq_df = self.inference(training_df)
479
+
480
+ # Identify UQ-specific columns (quantiles and prediction_std)
481
+ uq_columns = [col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std"]
482
+
483
+ # Merge UQ columns with out-of-fold predictions
484
+ if uq_columns:
485
+ # Keep id_column and UQ columns, drop 'prediction' to avoid conflict when merging
486
+ uq_df = uq_df[[id_column] + uq_columns]
487
+
488
+ # Drop duplicates in uq_df based on id_column
489
+ uq_df = uq_df.drop_duplicates(subset=[id_column])
490
+
491
+ # Merge UQ columns into out_of_fold_df
492
+ out_of_fold_df = pd.merge(out_of_fold_df, uq_df, on=id_column, how="left")
493
+ additional_columns = uq_columns
494
+ self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")
495
+
496
+ # Also compute UQ metrics
497
+ metrics = uq_metrics(out_of_fold_df, target_column)
498
+ self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
499
+
500
+ self._capture_inference_results(
501
+ capture_name,
502
+ out_of_fold_df,
503
+ target_column,
504
+ model_type,
505
+ pd.DataFrame([cross_fold_metrics["summary_metrics"]]),
506
+ description,
507
+ features=additional_columns,
508
+ id_column=id_column,
509
+ )
510
+ return cross_fold_metrics, out_of_fold_df
511
+
438
512
  def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
439
513
  """Run inference on the Endpoint using the provided DataFrame
440
514
 
@@ -450,11 +524,12 @@ class EndpointCore(Artifact):
450
524
  """
451
525
  return fast_inference(self.name, eval_df, self.sm_session, threads=threads)
452
526
 
453
- def _predict(self, eval_df: pd.DataFrame, drop_error_rows: bool = False) -> pd.DataFrame:
454
- """Internal: Run prediction on the given observations in the given DataFrame
527
+ def _predict(self, eval_df: pd.DataFrame, features: list[str], drop_error_rows: bool = False) -> pd.DataFrame:
528
+ """Internal: Run prediction on observations in the given DataFrame
455
529
 
456
530
  Args:
457
531
  eval_df (pd.DataFrame): DataFrame to run predictions on (must have superset of features)
532
+ features (list[str]): List of feature column names needed for prediction
458
533
  drop_error_rows (bool): If True, drop rows that had endpoint errors/issues (default=False)
459
534
  Returns:
460
535
  pd.DataFrame: Return the DataFrame with additional columns, prediction and any _proba columns
@@ -465,19 +540,12 @@ class EndpointCore(Artifact):
465
540
  self.log.warning("Evaluation DataFrame has 0 rows. No predictions to run.")
466
541
  return pd.DataFrame(columns=eval_df.columns) # Return empty DataFrame with same structure
467
542
 
468
- # Sanity check: Does the Model have Features?
469
- features = ModelCore(self.model_name).features()
470
- if not features:
471
- self.log.warning("Model does not have features defined, using all columns in the DataFrame")
472
- else:
473
- # Sanity check: Does the DataFrame have the required features?
474
- df_columns_lower = set(col.lower() for col in eval_df.columns)
475
- features_lower = set(feature.lower() for feature in features)
476
-
477
- # Check if the features are a subset of the DataFrame columns (case-insensitive)
478
- if not features_lower.issubset(df_columns_lower):
479
- missing_features = features_lower - df_columns_lower
480
- raise ValueError(f"DataFrame does not contain required features: {missing_features}")
543
+ # Sanity check: Does the DataFrame have the required features?
544
+ df_columns_lower = set(col.lower() for col in eval_df.columns)
545
+ features_lower = set(feature.lower() for feature in features)
546
+ if not features_lower.issubset(df_columns_lower):
547
+ missing_features = features_lower - df_columns_lower
548
+ raise ValueError(f"DataFrame does not contain required features: {missing_features}")
481
549
 
482
550
  # Create our Endpoint Predictor Class
483
551
  predictor = Predictor(
@@ -634,6 +702,10 @@ class EndpointCore(Artifact):
634
702
  @staticmethod
635
703
  def _hash_dataframe(df: pd.DataFrame, hash_length: int = 8):
636
704
  # Internal: Compute a data hash for the dataframe
705
+ if df.empty:
706
+ return "--hash--"
707
+
708
+ # Sort the dataframe by columns to ensure consistent ordering
637
709
  df = df.copy()
638
710
  df = df.sort_values(by=sorted(df.columns.tolist()))
639
711
  row_hashes = pd.util.hash_pandas_object(df, index=False)
@@ -688,8 +760,8 @@ class EndpointCore(Artifact):
688
760
  wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)
689
761
 
690
762
  # Grab the target column, prediction column, any _proba columns, and the ID column (if present)
691
- prediction_col = "prediction" if "prediction" in pred_results_df.columns else "predictions"
692
- output_columns = [target_column, prediction_col]
763
+ output_columns = [target_column]
764
+ output_columns += [col for col in pred_results_df.columns if "prediction" in col]
693
765
 
694
766
  # Add any _proba columns to the output columns
695
767
  output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
@@ -699,7 +771,7 @@ class EndpointCore(Artifact):
699
771
 
700
772
  # Add the ID column
701
773
  if id_column and id_column in pred_results_df.columns:
702
- output_columns.append(id_column)
774
+ output_columns.insert(0, id_column)
703
775
 
704
776
  # Write the predictions to our S3 Model Inference Folder
705
777
  self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
@@ -713,18 +785,10 @@ class EndpointCore(Artifact):
713
785
  # Note: Unlike other dataframes here, we want to write the index (labels) to the CSV
714
786
  wr.s3.to_csv(conf_mtx, f"{inference_capture_path}/inference_cm.csv", index=True)
715
787
 
716
- # Generate SHAP values for our Prediction Dataframe
717
- # generate_shap_values(self.endpoint_name, model_type.value, pred_results_df, inference_capture_path)
718
-
719
788
  # Now recompute the details for our Model
720
- self.log.important(f"Recomputing Details for {self.model_name} to show latest Inference Results...")
789
+ self.log.important(f"Loading inference metrics for {self.model_name}...")
721
790
  model = ModelCore(self.model_name)
722
791
  model._load_inference_metrics(capture_name)
723
- model.details()
724
-
725
- # Recompute the details so that inference model metrics are updated
726
- self.log.important(f"Recomputing Details for {self.name} to show latest Inference Results...")
727
- self.details()
728
792
 
729
793
  def regression_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
730
794
  """Compute the performance metrics for this Endpoint
@@ -876,9 +940,11 @@ class EndpointCore(Artifact):
876
940
 
877
941
  def generate_confusion_matrix(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
878
942
  """Compute the confusion matrix for this Endpoint
943
+
879
944
  Args:
880
945
  target_column (str): Name of the target column
881
946
  prediction_df (pd.DataFrame): DataFrame with the prediction results
947
+
882
948
  Returns:
883
949
  pd.DataFrame: DataFrame with the confusion matrix
884
950
  """
@@ -887,25 +953,20 @@ class EndpointCore(Artifact):
887
953
  prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
888
954
  y_pred = prediction_df[prediction_col]
889
955
 
890
- # Check if our model has class labels, if not we'll use the unique labels in the prediction
891
- class_labels = ModelCore(self.model_name).class_labels()
892
- if class_labels is None:
893
- class_labels = sorted(list(set(y_true) | set(y_pred)))
956
+ # Get model class labels
957
+ model_class_labels = ModelCore(self.model_name).class_labels()
894
958
 
895
- # Compute the confusion matrix (sklearn confusion_matrix)
896
- conf_mtx = confusion_matrix(y_true, y_pred, labels=class_labels)
959
+ # Use model labels if available, otherwise infer from data
960
+ if model_class_labels:
961
+ self.log.important("Using model class labels for confusion matrix ordering...")
962
+ labels = model_class_labels
963
+ else:
964
+ labels = sorted(list(set(y_true) | set(y_pred)))
897
965
 
898
- # Create a DataFrame
899
- conf_mtx_df = pd.DataFrame(conf_mtx, index=class_labels, columns=class_labels)
966
+ # Compute confusion matrix and create DataFrame
967
+ conf_mtx = confusion_matrix(y_true, y_pred, labels=labels)
968
+ conf_mtx_df = pd.DataFrame(conf_mtx, index=labels, columns=labels)
900
969
  conf_mtx_df.index.name = "labels"
901
-
902
- # Check if our model has class labels. If so make the index and columns ordered
903
- model_class_labels = ModelCore(self.model_name).class_labels()
904
- if model_class_labels:
905
- self.log.important("Reordering the confusion matrix based on model class labels...")
906
- conf_mtx_df.index = pd.Categorical(conf_mtx_df.index, categories=model_class_labels, ordered=True)
907
- conf_mtx_df.columns = pd.Categorical(conf_mtx_df.columns, categories=model_class_labels, ordered=True)
908
- conf_mtx_df = conf_mtx_df.sort_index().sort_index(axis=1)
909
970
  return conf_mtx_df
910
971
 
911
972
  def endpoint_config_name(self) -> str:
@@ -932,9 +993,9 @@ class EndpointCore(Artifact):
932
993
  self.upsert_workbench_meta({"workbench_input": input})
933
994
 
934
995
  def delete(self):
935
- """ "Delete an existing Endpoint: Underlying Models, Configuration, and Endpoint"""
996
+ """Delete an existing Endpoint: Underlying Models, Configuration, and Endpoint"""
936
997
  if not self.exists():
937
- self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
998
+ self.log.warning(f"Trying to delete an Endpoint that doesn't exist: {self.name}")
938
999
 
939
1000
  # Remove this endpoint from the list of registered endpoints
940
1001
  self.log.info(f"Removing {self.name} from the list of registered endpoints...")
@@ -975,12 +1036,23 @@ class EndpointCore(Artifact):
975
1036
  cls.log.info(f"Deleting Monitoring Schedule {schedule['MonitoringScheduleName']}...")
976
1037
  cls.sm_client.delete_monitoring_schedule(MonitoringScheduleName=schedule["MonitoringScheduleName"])
977
1038
 
978
- # Recursively delete all endpoint S3 artifacts (inference, data capture, monitoring, etc)
1039
+ # Recursively delete all endpoint S3 artifacts (inference, etc)
1040
+ # Note: We do not want to delete the data_capture/ files since these
1041
+ # might be used for collection and data drift analysis
979
1042
  base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}"
980
- s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
981
- cls.log.info(f"Deleting S3 Objects at {base_endpoint_path}...")
982
- cls.log.info(f"{s3_objects}")
983
- wr.s3.delete_objects(s3_objects, boto3_session=cls.boto3_session)
1043
+ all_s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
1044
+
1045
+ # Filter out objects that contain 'data_capture/' in their path
1046
+ s3_objects_to_delete = [obj for obj in all_s3_objects if "/data_capture/" not in obj]
1047
+ cls.log.info(f"Found {len(all_s3_objects)} total objects at {base_endpoint_path}")
1048
+ cls.log.info(f"Filtering out data_capture files, will delete {len(s3_objects_to_delete)} objects...")
1049
+ cls.log.info(f"Objects to delete: {s3_objects_to_delete}")
1050
+
1051
+ if s3_objects_to_delete:
1052
+ wr.s3.delete_objects(s3_objects_to_delete, boto3_session=cls.boto3_session)
1053
+ cls.log.info(f"Successfully deleted {len(s3_objects_to_delete)} objects")
1054
+ else:
1055
+ cls.log.info("No objects to delete (only data_capture files found)")
984
1056
 
985
1057
  # Delete any dataframes that were stored in the Dataframe Cache
986
1058
  cls.log.info("Deleting Dataframe Cache...")
@@ -1031,7 +1103,7 @@ class EndpointCore(Artifact):
1031
1103
  if __name__ == "__main__":
1032
1104
  """Exercise the Endpoint Class"""
1033
1105
  from workbench.api import FeatureSet
1034
- from workbench.utils.endpoint_utils import fs_evaluation_data
1106
+ from workbench.utils.endpoint_utils import get_evaluation_data
1035
1107
  import random
1036
1108
 
1037
1109
  # Grab an EndpointCore object and pull some information from it
@@ -1039,7 +1111,7 @@ if __name__ == "__main__":
1039
1111
 
1040
1112
  # Test various error conditions (set row 42 length to pd.NA)
1041
1113
  # Note: This test should return ALL rows
1042
- my_eval_df = fs_evaluation_data(my_endpoint)
1114
+ my_eval_df = get_evaluation_data(my_endpoint)
1043
1115
  my_eval_df.at[42, "length"] = pd.NA
1044
1116
  pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
1045
1117
  print(f"Sent rows: {len(my_eval_df)}")
@@ -1047,6 +1119,9 @@ if __name__ == "__main__":
1047
1119
  assert len(pred_results) == len(my_eval_df), "Predictions should match the number of sent rows"
1048
1120
 
1049
1121
  # Now we put in an invalid value
1122
+ print("*" * 80)
1123
+ print("NOW TESTING ERROR CONDITIONS...")
1124
+ print("*" * 80)
1050
1125
  my_eval_df.at[42, "length"] = "invalid_value"
1051
1126
  pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
1052
1127
  print(f"Sent rows: {len(my_eval_df)}")
@@ -1086,13 +1161,20 @@ if __name__ == "__main__":
1086
1161
  df = fs.pull_dataframe()[:100]
1087
1162
  cap_df = df.copy()
1088
1163
  cap_df.columns = [col.upper() for col in cap_df.columns]
1089
- my_endpoint._predict(cap_df)
1164
+ my_endpoint.inference(cap_df)
1090
1165
 
1091
1166
  # Boolean Type Test
1092
1167
  df["bool_column"] = [random.choice([True, False]) for _ in range(len(df))]
1093
- result_df = my_endpoint._predict(df)
1168
+ result_df = my_endpoint.inference(df)
1094
1169
  assert result_df["bool_column"].dtype == bool
1095
1170
 
1171
+ # Missing Feature Test
1172
+ missing_df = df.drop(columns=["length"])
1173
+ try:
1174
+ my_endpoint.inference(missing_df)
1175
+ except ValueError as e:
1176
+ print(f"Expected error for missing feature: {e}")
1177
+
1096
1178
  # Run Auto Inference on the Endpoint (uses the FeatureSet)
1097
1179
  print("Running Auto Inference...")
1098
1180
  my_endpoint.auto_inference()
@@ -1100,13 +1182,20 @@ if __name__ == "__main__":
1100
1182
  # Run Inference where we provide the data
1101
1183
  # Note: This dataframe could be from a FeatureSet or any other source
1102
1184
  print("Running Inference...")
1103
- my_eval_df = fs_evaluation_data(my_endpoint)
1185
+ my_eval_df = get_evaluation_data(my_endpoint)
1104
1186
  pred_results = my_endpoint.inference(my_eval_df)
1105
1187
 
1106
1188
  # Now set capture=True to save inference results and metrics
1107
- my_eval_df = fs_evaluation_data(my_endpoint)
1189
+ my_eval_df = get_evaluation_data(my_endpoint)
1108
1190
  pred_results = my_endpoint.inference(my_eval_df, capture_name="holdout_xyz")
1109
1191
 
1192
+ # Run predictions using the fast_inference method
1193
+ fast_results = my_endpoint.fast_inference(my_eval_df)
1194
+
1195
+ # Test the cross_fold_inference method
1196
+ print("Running Cross-Fold Inference...")
1197
+ metrics, all_results = my_endpoint.cross_fold_inference()
1198
+
1110
1199
  # Run Inference and metrics for a Classification Endpoint
1111
1200
  class_endpoint = EndpointCore("wine-classification")
1112
1201
  auto_predictions = class_endpoint.auto_inference()
@@ -1115,8 +1204,9 @@ if __name__ == "__main__":
1115
1204
  target = "wine_class"
1116
1205
  print(class_endpoint.generate_confusion_matrix(target, auto_predictions))
1117
1206
 
1118
- # Run predictions using the fast_inference method
1119
- fast_results = my_endpoint.fast_inference(my_eval_df)
1207
+ # Test the cross_fold_inference method
1208
+ print("Running Cross-Fold Inference...")
1209
+ metrics, all_results = class_endpoint.cross_fold_inference()
1120
1210
 
1121
1211
  # Test the class method delete (commented out for now)
1122
1212
  # from workbench.api import Model
@@ -17,7 +17,7 @@ from workbench.core.artifacts.artifact import Artifact
17
17
  from workbench.core.artifacts.data_source_factory import DataSourceFactory
18
18
  from workbench.core.artifacts.athena_source import AthenaSource
19
19
 
20
- from typing import TYPE_CHECKING
20
+ from typing import TYPE_CHECKING, Optional, List, Union
21
21
 
22
22
  from workbench.utils.aws_utils import aws_throttle
23
23
 
@@ -194,24 +194,24 @@ class FeatureSetCore(Artifact):
194
194
 
195
195
  return View(self, view_name)
196
196
 
197
- def set_display_columns(self, diplay_columns: list[str]):
197
+ def set_display_columns(self, display_columns: list[str]):
198
198
  """Set the display columns for this Data Source
199
199
 
200
200
  Args:
201
- diplay_columns (list[str]): The display columns for this Data Source
201
+ display_columns (list[str]): The display columns for this Data Source
202
202
  """
203
203
  # Check mismatch of display columns to computation columns
204
204
  c_view = self.view("computation")
205
205
  computation_columns = c_view.columns
206
- mismatch_columns = [col for col in diplay_columns if col not in computation_columns]
206
+ mismatch_columns = [col for col in display_columns if col not in computation_columns]
207
207
  if mismatch_columns:
208
208
  self.log.monitor(f"Display View/Computation mismatch: {mismatch_columns}")
209
209
 
210
- self.log.important(f"Setting Display Columns...{diplay_columns}")
210
+ self.log.important(f"Setting Display Columns...{display_columns}")
211
211
  from workbench.core.views import DisplayView
212
212
 
213
213
  # Create a NEW display view
214
- DisplayView.create(self, source_table=c_view.table, column_list=diplay_columns)
214
+ DisplayView.create(self, source_table=c_view.table, column_list=display_columns)
215
215
 
216
216
  def set_computation_columns(self, computation_columns: list[str], reset_display: bool = True):
217
217
  """Set the computation columns for this FeatureSet
@@ -509,6 +509,48 @@ class FeatureSetCore(Artifact):
509
509
  ].tolist()
510
510
  return hold_out_ids
511
511
 
512
+ def set_training_filter(self, filter_expression: Optional[str] = None):
513
+ """Set a filter expression for the training view for this FeatureSet
514
+
515
+ Args:
516
+ filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
517
+ If None or empty string, will reset to training view with no filter
518
+ (default: None)
519
+ """
520
+ from workbench.core.views import TrainingView
521
+
522
+ # Grab the existing holdout ids
523
+ holdout_ids = self.get_training_holdouts()
524
+
525
+ # Create a NEW training view
526
+ self.log.important(f"Setting Training Filter: {filter_expression}")
527
+ TrainingView.create(
528
+ self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
529
+ )
530
+
531
+ def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
532
+ """Exclude a list of IDs from the training view
533
+
534
+ Args:
535
+ ids (List[Union[str, int]]): List of IDs to exclude from training
536
+ column_name (Optional[str]): Column name to filter on.
537
+ If None, uses self.id_column (default: None)
538
+ """
539
+ # Use the default id_column if not specified
540
+ column = column_name or self.id_column
541
+
542
+ # Handle empty list case
543
+ if not ids:
544
+ self.log.warning("No IDs provided to exclude")
545
+ return
546
+
547
+ # Build the filter expression with proper SQL quoting
548
+ quoted_ids = ", ".join([repr(id) for id in ids])
549
+ filter_expression = f"{column} NOT IN ({quoted_ids})"
550
+
551
+ # Apply the filter
552
+ self.set_training_filter(filter_expression)
553
+
512
554
  @classmethod
513
555
  def delete_views(cls, table: str, database: str):
514
556
  """Delete any views associated with this FeatureSet
@@ -707,7 +749,7 @@ if __name__ == "__main__":
707
749
 
708
750
  # Test getting the holdout ids
709
751
  print("Getting the hold out ids...")
710
- holdout_ids = my_features.get_training_holdouts("id")
752
+ holdout_ids = my_features.get_training_holdouts()
711
753
  print(f"Holdout IDs: {holdout_ids}")
712
754
 
713
755
  # Get a sample of the data
@@ -729,16 +771,33 @@ if __name__ == "__main__":
729
771
  table = my_features.view("training").table
730
772
  df = my_features.query(f'SELECT id, name FROM "{table}"')
731
773
  my_holdout_ids = [id for id in df["id"] if id < 20]
732
- my_features.set_training_holdouts("id", my_holdout_ids)
733
-
734
- # Test the hold out set functionality with strings
735
- print("Setting hold out ids (strings)...")
736
- my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
737
- my_features.set_training_holdouts("name", my_holdout_ids)
774
+ my_features.set_training_holdouts(my_holdout_ids)
738
775
 
739
776
  # Get the training data
740
777
  print("Getting the training data...")
741
778
  training_data = my_features.get_training_data()
779
+ print(f"Training Data: {training_data.shape}")
780
+
781
+ # Test the filter expression functionality
782
+ print("Setting a filter expression...")
783
+ my_features.set_training_filter("id < 50 AND height > 65.0")
784
+ training_data = my_features.get_training_data()
785
+ print(f"Training Data: {training_data.shape}")
786
+ print(training_data)
787
+
788
+ # Remove training filter
789
+ print("Removing the filter expression...")
790
+ my_features.set_training_filter(None)
791
+ training_data = my_features.get_training_data()
792
+ print(f"Training Data: {training_data.shape}")
793
+ print(training_data)
794
+
795
+ # Test excluding ids from training
796
+ print("Excluding ids from training...")
797
+ my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
798
+ training_data = my_features.get_training_data()
799
+ print(f"Training Data: {training_data.shape}")
800
+ print(training_data)
742
801
 
743
802
  # Now delete the AWS artifacts associated with this Feature Set
744
803
  # print("Deleting Workbench Feature Set...")