workbench 0.8.163__py3-none-any.whl → 0.8.164__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

workbench/api/endpoint.py CHANGED
@@ -70,6 +70,17 @@ class Endpoint(EndpointCore):
70
70
  """
71
71
  return super().fast_inference(eval_df, threads=threads)
72
72
 
73
+ def cross_fold_inference(self, nfolds: int = 5) -> dict:
74
+ """Run cross-fold inference (only works for XGBoost models)
75
+
76
+ Args:
77
+ nfolds (int): The number of folds to use for cross-validation (default: 5)
78
+
79
+ Returns:
80
+ dict: A dictionary with fold results
81
+ """
82
+ return super().cross_fold_inference(nfolds)
83
+
73
84
 
74
85
  if __name__ == "__main__":
75
86
  """Exercise the Endpoint Class"""
@@ -378,16 +378,17 @@ class EndpointCore(Artifact):
378
378
  self.log.important("No model associated with this endpoint, running 'no frills' inference...")
379
379
  return self.fast_inference(eval_df)
380
380
 
381
+ # Grab the model features and target column
382
+ model = ModelCore(self.model_name)
383
+ features = model.features()
384
+ target_column = model.target()
385
+
381
386
  # Run predictions on the evaluation data
382
- prediction_df = self._predict(eval_df, drop_error_rows)
387
+ prediction_df = self._predict(eval_df, features, drop_error_rows)
383
388
  if prediction_df.empty:
384
389
  self.log.warning("No predictions were made. Returning empty DataFrame.")
385
390
  return prediction_df
386
391
 
387
- # Get the target column
388
- model = ModelCore(self.model_name)
389
- target_column = model.target()
390
-
391
392
  # Sanity Check that the target column is present
392
393
  if target_column and (target_column not in prediction_df.columns):
393
394
  self.log.important(f"Target Column {target_column} not found in prediction_df!")
@@ -419,12 +420,6 @@ class EndpointCore(Artifact):
419
420
  capture_name, prediction_df, target_column, model_type, metrics, description, features, id_column
420
421
  )
421
422
 
422
- # Capture CrossFold Inference Results
423
- cross_fold_metrics = cross_fold_inference(model)
424
- if cross_fold_metrics:
425
- # Now put into the Parameter Store Model Inference Namespace
426
- self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
427
-
428
423
  # For UQ Models we also capture the uncertainty metrics
429
424
  if model_type in [ModelType.UQ_REGRESSOR]:
430
425
  metrics = uq_metrics(prediction_df, target_column)
@@ -435,6 +430,25 @@ class EndpointCore(Artifact):
435
430
  # Return the prediction DataFrame
436
431
  return prediction_df
437
432
 
433
+ def cross_fold_inference(self, nfolds: int = 5) -> dict:
434
+ """Run cross-fold inference (only works for XGBoost models)
435
+
436
+ Args:
437
+ nfolds (int): Number of folds to use for cross-fold (default: 5)
438
+
439
+ Returns:
440
+ dict: Dictionary with the cross-fold inference results
441
+ """
442
+
443
+ # Grab our model
444
+ model = ModelCore(self.model_name)
445
+
446
+ # Compute CrossFold Metrics
447
+ cross_fold_metrics = cross_fold_inference(model, nfolds=nfolds)
448
+ if cross_fold_metrics:
449
+ self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
450
+ return cross_fold_metrics
451
+
438
452
  def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
439
453
  """Run inference on the Endpoint using the provided DataFrame
440
454
 
@@ -450,11 +464,12 @@ class EndpointCore(Artifact):
450
464
  """
451
465
  return fast_inference(self.name, eval_df, self.sm_session, threads=threads)
452
466
 
453
- def _predict(self, eval_df: pd.DataFrame, drop_error_rows: bool = False) -> pd.DataFrame:
454
- """Internal: Run prediction on the given observations in the given DataFrame
467
+ def _predict(self, eval_df: pd.DataFrame, features: list[str], drop_error_rows: bool = False) -> pd.DataFrame:
468
+ """Internal: Run prediction on observations in the given DataFrame
455
469
 
456
470
  Args:
457
471
  eval_df (pd.DataFrame): DataFrame to run predictions on (must have superset of features)
472
+ features (list[str]): List of feature column names needed for prediction
458
473
  drop_error_rows (bool): If True, drop rows that had endpoint errors/issues (default=False)
459
474
  Returns:
460
475
  pd.DataFrame: Return the DataFrame with additional columns, prediction and any _proba columns
@@ -465,19 +480,12 @@ class EndpointCore(Artifact):
465
480
  self.log.warning("Evaluation DataFrame has 0 rows. No predictions to run.")
466
481
  return pd.DataFrame(columns=eval_df.columns) # Return empty DataFrame with same structure
467
482
 
468
- # Sanity check: Does the Model have Features?
469
- features = ModelCore(self.model_name).features()
470
- if not features:
471
- self.log.warning("Model does not have features defined, using all columns in the DataFrame")
472
- else:
473
- # Sanity check: Does the DataFrame have the required features?
474
- df_columns_lower = set(col.lower() for col in eval_df.columns)
475
- features_lower = set(feature.lower() for feature in features)
476
-
477
- # Check if the features are a subset of the DataFrame columns (case-insensitive)
478
- if not features_lower.issubset(df_columns_lower):
479
- missing_features = features_lower - df_columns_lower
480
- raise ValueError(f"DataFrame does not contain required features: {missing_features}")
483
+ # Sanity check: Does the DataFrame have the required features?
484
+ df_columns_lower = set(col.lower() for col in eval_df.columns)
485
+ features_lower = set(feature.lower() for feature in features)
486
+ if not features_lower.issubset(df_columns_lower):
487
+ missing_features = features_lower - df_columns_lower
488
+ raise ValueError(f"DataFrame does not contain required features: {missing_features}")
481
489
 
482
490
  # Create our Endpoint Predictor Class
483
491
  predictor = Predictor(
@@ -713,18 +721,10 @@ class EndpointCore(Artifact):
713
721
  # Note: Unlike other dataframes here, we want to write the index (labels) to the CSV
714
722
  wr.s3.to_csv(conf_mtx, f"{inference_capture_path}/inference_cm.csv", index=True)
715
723
 
716
- # Generate SHAP values for our Prediction Dataframe
717
- # generate_shap_values(self.endpoint_name, model_type.value, pred_results_df, inference_capture_path)
718
-
719
724
  # Now recompute the details for our Model
720
- self.log.important(f"Recomputing Details for {self.model_name} to show latest Inference Results...")
725
+ self.log.important(f"Loading inference metrics for {self.model_name}...")
721
726
  model = ModelCore(self.model_name)
722
727
  model._load_inference_metrics(capture_name)
723
- model.details()
724
-
725
- # Recompute the details so that inference model metrics are updated
726
- self.log.important(f"Recomputing Details for {self.name} to show latest Inference Results...")
727
- self.details()
728
728
 
729
729
  def regression_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
730
730
  """Compute the performance metrics for this Endpoint
@@ -876,9 +876,11 @@ class EndpointCore(Artifact):
876
876
 
877
877
  def generate_confusion_matrix(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
878
878
  """Compute the confusion matrix for this Endpoint
879
+
879
880
  Args:
880
881
  target_column (str): Name of the target column
881
882
  prediction_df (pd.DataFrame): DataFrame with the prediction results
883
+
882
884
  Returns:
883
885
  pd.DataFrame: DataFrame with the confusion matrix
884
886
  """
@@ -887,25 +889,20 @@ class EndpointCore(Artifact):
887
889
  prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
888
890
  y_pred = prediction_df[prediction_col]
889
891
 
890
- # Check if our model has class labels, if not we'll use the unique labels in the prediction
891
- class_labels = ModelCore(self.model_name).class_labels()
892
- if class_labels is None:
893
- class_labels = sorted(list(set(y_true) | set(y_pred)))
892
+ # Get model class labels
893
+ model_class_labels = ModelCore(self.model_name).class_labels()
894
894
 
895
- # Compute the confusion matrix (sklearn confusion_matrix)
896
- conf_mtx = confusion_matrix(y_true, y_pred, labels=class_labels)
895
+ # Use model labels if available, otherwise infer from data
896
+ if model_class_labels:
897
+ self.log.important("Using model class labels for confusion matrix ordering...")
898
+ labels = model_class_labels
899
+ else:
900
+ labels = sorted(list(set(y_true) | set(y_pred)))
897
901
 
898
- # Create a DataFrame
899
- conf_mtx_df = pd.DataFrame(conf_mtx, index=class_labels, columns=class_labels)
902
+ # Compute confusion matrix and create DataFrame
903
+ conf_mtx = confusion_matrix(y_true, y_pred, labels=labels)
904
+ conf_mtx_df = pd.DataFrame(conf_mtx, index=labels, columns=labels)
900
905
  conf_mtx_df.index.name = "labels"
901
-
902
- # Check if our model has class labels. If so make the index and columns ordered
903
- model_class_labels = ModelCore(self.model_name).class_labels()
904
- if model_class_labels:
905
- self.log.important("Reordering the confusion matrix based on model class labels...")
906
- conf_mtx_df.index = pd.Categorical(conf_mtx_df.index, categories=model_class_labels, ordered=True)
907
- conf_mtx_df.columns = pd.Categorical(conf_mtx_df.columns, categories=model_class_labels, ordered=True)
908
- conf_mtx_df = conf_mtx_df.sort_index().sort_index(axis=1)
909
906
  return conf_mtx_df
910
907
 
911
908
  def endpoint_config_name(self) -> str:
@@ -1086,13 +1083,20 @@ if __name__ == "__main__":
1086
1083
  df = fs.pull_dataframe()[:100]
1087
1084
  cap_df = df.copy()
1088
1085
  cap_df.columns = [col.upper() for col in cap_df.columns]
1089
- my_endpoint._predict(cap_df)
1086
+ my_endpoint.inference(cap_df)
1090
1087
 
1091
1088
  # Boolean Type Test
1092
1089
  df["bool_column"] = [random.choice([True, False]) for _ in range(len(df))]
1093
- result_df = my_endpoint._predict(df)
1090
+ result_df = my_endpoint.inference(df)
1094
1091
  assert result_df["bool_column"].dtype == bool
1095
1092
 
1093
+ # Missing Feature Test
1094
+ missing_df = df.drop(columns=["length"])
1095
+ try:
1096
+ my_endpoint.inference(missing_df)
1097
+ except ValueError as e:
1098
+ print(f"Expected error for missing feature: {e}")
1099
+
1096
1100
  # Run Auto Inference on the Endpoint (uses the FeatureSet)
1097
1101
  print("Running Auto Inference...")
1098
1102
  my_endpoint.auto_inference()
@@ -1107,6 +1111,9 @@ if __name__ == "__main__":
1107
1111
  my_eval_df = fs_evaluation_data(my_endpoint)
1108
1112
  pred_results = my_endpoint.inference(my_eval_df, capture_name="holdout_xyz")
1109
1113
 
1114
+ # Run predictions using the fast_inference method
1115
+ fast_results = my_endpoint.fast_inference(my_eval_df)
1116
+
1110
1117
  # Run Inference and metrics for a Classification Endpoint
1111
1118
  class_endpoint = EndpointCore("wine-classification")
1112
1119
  auto_predictions = class_endpoint.auto_inference()
@@ -1115,9 +1122,6 @@ if __name__ == "__main__":
1115
1122
  target = "wine_class"
1116
1123
  print(class_endpoint.generate_confusion_matrix(target, auto_predictions))
1117
1124
 
1118
- # Run predictions using the fast_inference method
1119
- fast_results = my_endpoint.fast_inference(my_eval_df)
1120
-
1121
1125
  # Test the class method delete (commented out for now)
1122
1126
  # from workbench.api import Model
1123
1127
  # model = Model("abalone-regression")
@@ -181,9 +181,6 @@ def logging_setup(color_logs=True):
181
181
  log.debug("Debugging enabled via WORKBENCH_DEBUG environment variable.")
182
182
  else:
183
183
  log.setLevel(logging.INFO)
184
- # Note: Not using the ThrottlingFilter for now
185
- # throttle_filter = ThrottlingFilter(rate_seconds=5)
186
- # handler.addFilter(throttle_filter)
187
184
 
188
185
  # Suppress specific logger
189
186
  logging.getLogger("sagemaker.config").setLevel(logging.WARNING)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: workbench
3
- Version: 0.8.163
3
+ Version: 0.8.164
4
4
  Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
5
5
  Author-email: SuperCowPowers LLC <support@supercowpowers.com>
6
6
  License-Expression: MIT
@@ -31,7 +31,7 @@ workbench/api/__init__.py,sha256=kvrP70ypDOMdPGj_Eeftdh8J0lu_1qQVne6GXMkD4_E,102
31
31
  workbench/api/compound.py,sha256=BHd3Qu4Ra45FEuwiowhFfGMI_HKRRB10XMmoS6ljKrM,2541
32
32
  workbench/api/data_source.py,sha256=Ngz36YZWxFfpJbmURhM1LQPYjh5kdpZNGo6_fCRePbA,8321
33
33
  workbench/api/df_store.py,sha256=Wybb3zO-jPpAi2Ns8Ks1-lagvXAaBlRpBZHhnnl3Lms,6131
34
- workbench/api/endpoint.py,sha256=ejDnfBBgNYMZB-bOA5nX7C6CtBlAjmtrF8M_zpri9Io,3451
34
+ workbench/api/endpoint.py,sha256=RWGqxsCW_pMiENMb_XZlm2ZCldMS4suEBM3F5gT3hYI,3814
35
35
  workbench/api/feature_set.py,sha256=wzNxNjN0K2FaIC7QUIogMnoHqw2vo0iAHYlGk6fWLCw,6649
36
36
  workbench/api/graph_store.py,sha256=LremJyPrQFgsHb7hxsctuCsoxx3p7TKtaY5qALHe6pc,4372
37
37
  workbench/api/meta.py,sha256=fCOtZMfAHWaerzcsTeFnimXfgV8STe9JDiB7QBogktc,8456
@@ -53,7 +53,7 @@ workbench/core/artifacts/athena_source.py,sha256=RNmCe7s6uH4gVHpcdJcL84aSbF5Q1ah
53
53
  workbench/core/artifacts/cached_artifact_mixin.py,sha256=ngqFLZ4cQx_TFouXZgXZQsv_7W6XCvxVGXXSfzzaft8,3775
54
54
  workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
55
55
  workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
56
- workbench/core/artifacts/endpoint_core.py,sha256=L6uWOxHKItjbpRS2rFrAbxAqDyZIv2CO9dnZpohKrUI,48768
56
+ workbench/core/artifacts/endpoint_core.py,sha256=6uDOl-VKrTbLMlHZEYFY80XwrCP5H0W36JoHySjhl7M,48163
57
57
  workbench/core/artifacts/feature_set_core.py,sha256=055VdSYR09HP4ygAuYvIYtHQ7Ec4XxsZygpgEl5H5jQ,29136
58
58
  workbench/core/artifacts/model_core.py,sha256=U0dSkpZMrsIgbUglVkPwAgN0gji7Oa7glOjqMQJDAzE,50927
59
59
  workbench/core/artifacts/monitor_core.py,sha256=BP6UuCyBI4zB2wwcIXvUw6RC0EktTcQd5Rv0x73qzio,37670
@@ -239,7 +239,7 @@ workbench/utils/trace_calls.py,sha256=tY4DOVMGXBh-mbUWzo1l-X9XjD0ux_qR9I1ypkjWNI
239
239
  workbench/utils/type_abbrev.py,sha256=3ai7ZbE8BgvdotOSb48w_BmgrEGVYvLoyzoNYH8ZuOs,1470
240
240
  workbench/utils/workbench_cache.py,sha256=IQchxB81iR4eVggHBxUJdXxUCRkqWz1jKe5gxN3z6yc,5657
241
241
  workbench/utils/workbench_event_bridge.py,sha256=z1GmXOB-Qs7VOgC6Hjnp2DI9nSEWepaSXejACxTIR7o,4150
242
- workbench/utils/workbench_logging.py,sha256=aOUjMZeKqrK03z5mwuVAAwwjIjVxyTA7g-brr85oxY8,10424
242
+ workbench/utils/workbench_logging.py,sha256=WCuMWhQwibrvcGAyj96h2wowh6dH7zNlDJ7sWUzdCeI,10263
243
243
  workbench/utils/workbench_sqs.py,sha256=WFQTqOxoEdOzPEMmTVZcdPzylmkynZ5aKtvRrOAO06w,2127
244
244
  workbench/utils/xgboost_model_utils.py,sha256=AEBSyIXYFk6vI3u89w7J4VdI1dgNJOgQe6XZv4pUhOM,15501
245
245
  workbench/web_interface/components/component_interface.py,sha256=QCPWqiZLkVsAEzQFEQxFelk7H0UF5uI2dVvJNf0lRV4,7980
@@ -275,9 +275,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
275
275
  workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
276
276
  workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
277
277
  workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
278
- workbench-0.8.163.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
279
- workbench-0.8.163.dist-info/METADATA,sha256=TwnUicLddrHeMkx_gDGiUR6uQD7TR6mRjNG0XY3kh1E,9209
280
- workbench-0.8.163.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
281
- workbench-0.8.163.dist-info/entry_points.txt,sha256=oZykkheWiiIBjRE8cS5SdcxwmZKSFaQEGwMBjNh-eNM,238
282
- workbench-0.8.163.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
283
- workbench-0.8.163.dist-info/RECORD,,
278
+ workbench-0.8.164.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
279
+ workbench-0.8.164.dist-info/METADATA,sha256=qZKnCu_6ahD4lz6rmMk2VW4RyI--YnZafIZaUGMiOHI,9209
280
+ workbench-0.8.164.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
281
+ workbench-0.8.164.dist-info/entry_points.txt,sha256=oZykkheWiiIBjRE8cS5SdcxwmZKSFaQEGwMBjNh-eNM,238
282
+ workbench-0.8.164.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
283
+ workbench-0.8.164.dist-info/RECORD,,