workbench-0.8.161-py3-none-any.whl → workbench-0.8.192-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +143 -102
- workbench/algorithms/graph/light/proximity_graph.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +12 -0
- workbench/api/feature_set.py +4 -4
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -12
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +168 -78
- workbench/core/artifacts/feature_set_core.py +72 -13
- workbench/core/artifacts/model_core.py +50 -15
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +9 -4
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +49 -53
- workbench/core/views/view.py +51 -1
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
- workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
- workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
- workbench/model_scripts/pytorch_model/pytorch.template +19 -20
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +7 -2
- workbench/model_scripts/uq_models/mapie.template +492 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/xgb_model.template +31 -40
- workbench/repl/workbench_shell.py +11 -6
- workbench/scripts/lambda_launcher.py +63 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +76 -30
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +283 -145
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/METADATA +4 -4
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/RECORD +81 -76
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/pytorch_model/generated_model_script.py +0 -565
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
- workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/endpoint_core.py

@@ -8,7 +8,7 @@ import pandas as pd
 import numpy as np
 from io import StringIO
 import awswrangler as wr
-from typing import Union, Optional
+from typing import Union, Optional, Tuple
 import hashlib

 # Model Performance Scores
@@ -32,11 +32,11 @@ from sagemaker import Predictor
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
 from workbench.utils.endpoint_metrics import EndpointMetrics
-from workbench.utils.fast_inference import fast_inference
 from workbench.utils.cache import Cache
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.model_utils import uq_metrics
 from workbench.utils.xgboost_model_utils import cross_fold_inference
+from workbench_bridges.endpoints.fast_inference import fast_inference


 class EndpointCore(Artifact):
@@ -164,11 +164,17 @@ class EndpointCore(Artifact):
         """
         return "Serverless" in self.endpoint_meta["InstanceType"]

-    def
+    def data_capture(self):
+        """Get the MonitorCore class for this endpoint"""
+        from workbench.core.artifacts.data_capture_core import DataCaptureCore
+
+        return DataCaptureCore(self.endpoint_name)
+
+    def enable_data_capture(self):
         """Add data capture to the endpoint"""
-        self.
+        self.data_capture().enable()

-    def
+    def monitor(self):
         """Get the MonitorCore class for this endpoint"""
         from workbench.core.artifacts.monitor_core import MonitorCore

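A minimal usage sketch of the data-capture accessor added above (the endpoint name is a placeholder; DataCaptureCore.enable() is the only call shown in this hunk):

    from workbench.core.artifacts.endpoint_core import EndpointCore

    end = EndpointCore("my-endpoint")   # placeholder endpoint name
    capture = end.data_capture()        # DataCaptureCore object for this endpoint
    end.enable_data_capture()           # equivalent to end.data_capture().enable()
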
@@ -350,7 +356,7 @@ class EndpointCore(Artifact):
             return pd.DataFrame()

         # Grab the evaluation data from the FeatureSet
-        table =
+        table = model.training_view().table
         eval_df = fs.query(f'SELECT * FROM "{table}" where training = FALSE')
         capture_name = "auto_inference" if capture else None
         return self.inference(eval_df, capture_name, id_column=fs.id_column)
@@ -378,16 +384,17 @@ class EndpointCore(Artifact):
             self.log.important("No model associated with this endpoint, running 'no frills' inference...")
             return self.fast_inference(eval_df)

+        # Grab the model features and target column
+        model = ModelCore(self.model_name)
+        features = model.features()
+        target_column = model.target()
+
         # Run predictions on the evaluation data
-        prediction_df = self._predict(eval_df, drop_error_rows)
+        prediction_df = self._predict(eval_df, features, drop_error_rows)
         if prediction_df.empty:
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return prediction_df

-        # Get the target column
-        model = ModelCore(self.model_name)
-        target_column = model.target()
-
         # Sanity Check that the target column is present
         if target_column and (target_column not in prediction_df.columns):
             self.log.important(f"Target Column {target_column} not found in prediction_df!")
@@ -413,28 +420,95 @@ class EndpointCore(Artifact):

         # Capture the inference results and metrics
         if capture_name is not None:
+
+            # If we don't have an id_column, we'll pull it from the model's FeatureSet
+            if id_column is None:
+                fs = FeatureSetCore(model.get_input())
+                id_column = fs.id_column
             description = capture_name.replace("_", " ").title()
-            features = model.features()
             self._capture_inference_results(
                 capture_name, prediction_df, target_column, model_type, metrics, description, features, id_column
             )

-            # Capture CrossFold Inference Results
-            cross_fold_metrics = cross_fold_inference(model)
-            if cross_fold_metrics:
-                # Now put into the Parameter Store Model Inference Namespace
-                self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
-
             # For UQ Models we also capture the uncertainty metrics
             if model_type in [ModelType.UQ_REGRESSOR]:
                 metrics = uq_metrics(prediction_df, target_column)
-
-                # Now put into the Parameter Store Model Inference Namespace
                 self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)

         # Return the prediction DataFrame
         return prediction_df

+    def cross_fold_inference(self, nfolds: int = 5) -> Tuple[dict, pd.DataFrame]:
+        """Run cross-fold inference (only works for XGBoost models)
+
+        Args:
+            nfolds (int): Number of folds to use for cross-fold (default: 5)
+
+        Returns:
+            Tuple[dict, pd.DataFrame]: Tuple of (cross_fold_metrics, out_of_fold_df)
+        """
+
+        # Grab our model
+        model = ModelCore(self.model_name)
+
+        # Compute CrossFold Metrics
+        cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
+        if cross_fold_metrics:
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", cross_fold_metrics)
+
+        # Capture the results
+        capture_name = "full_cross_fold"
+        description = capture_name.replace("_", " ").title()
+        target_column = model.target()
+        model_type = model.model_type
+
+        # Get the id_column from the model's FeatureSet
+        fs = FeatureSetCore(model.get_input())
+        id_column = fs.id_column
+
+        # Is this a UQ Model? If so, run full inference and merge the results
+        additional_columns = []
+        if model_type == ModelType.UQ_REGRESSOR:
+            self.log.important("UQ Regressor detected, running full inference to get uncertainty estimates...")
+
+            # Get the training view dataframe for inference
+            training_df = model.training_view().pull_dataframe()
+
+            # Run inference on the endpoint to get UQ outputs
+            uq_df = self.inference(training_df)
+
+            # Identify UQ-specific columns (quantiles and prediction_std)
+            uq_columns = [col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std"]
+
+            # Merge UQ columns with out-of-fold predictions
+            if uq_columns:
+                # Keep id_column and UQ columns, drop 'prediction' to avoid conflict when merging
+                uq_df = uq_df[[id_column] + uq_columns]
+
+                # Drop duplicates in uq_df based on id_column
+                uq_df = uq_df.drop_duplicates(subset=[id_column])
+
+                # Merge UQ columns into out_of_fold_df
+                out_of_fold_df = pd.merge(out_of_fold_df, uq_df, on=id_column, how="left")
+                additional_columns = uq_columns
+                self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")
+
+            # Also compute UQ metrics
+            metrics = uq_metrics(out_of_fold_df, target_column)
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
+
+        self._capture_inference_results(
+            capture_name,
+            out_of_fold_df,
+            target_column,
+            model_type,
+            pd.DataFrame([cross_fold_metrics["summary_metrics"]]),
+            description,
+            features=additional_columns,
+            id_column=id_column,
+        )
+        return cross_fold_metrics, out_of_fold_df
+
     def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
         """Run inference on the Endpoint using the provided DataFrame

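A hedged usage sketch of the new cross_fold_inference() method shown above (the endpoint name is a placeholder). It returns the cross-fold metrics dict plus the out-of-fold predictions, and stores results under the "full_cross_fold" capture name:

    from workbench.core.artifacts.endpoint_core import EndpointCore

    end = EndpointCore("my-regression-endpoint")   # placeholder endpoint name
    cross_fold_metrics, out_of_fold_df = end.cross_fold_inference(nfolds=5)
    print(cross_fold_metrics["summary_metrics"])   # summary metrics also pushed to the Parameter Store
    print(out_of_fold_df.head())                   # out-of-fold predictions (plus UQ columns for UQ models)
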
@@ -450,11 +524,12 @@ class EndpointCore(Artifact):
         """
         return fast_inference(self.name, eval_df, self.sm_session, threads=threads)

-    def _predict(self, eval_df: pd.DataFrame, drop_error_rows: bool = False) -> pd.DataFrame:
-        """Internal: Run prediction on
+    def _predict(self, eval_df: pd.DataFrame, features: list[str], drop_error_rows: bool = False) -> pd.DataFrame:
+        """Internal: Run prediction on observations in the given DataFrame

         Args:
             eval_df (pd.DataFrame): DataFrame to run predictions on (must have superset of features)
+            features (list[str]): List of feature column names needed for prediction
             drop_error_rows (bool): If True, drop rows that had endpoint errors/issues (default=False)
         Returns:
             pd.DataFrame: Return the DataFrame with additional columns, prediction and any _proba columns
@@ -465,19 +540,12 @@ class EndpointCore(Artifact):
             self.log.warning("Evaluation DataFrame has 0 rows. No predictions to run.")
             return pd.DataFrame(columns=eval_df.columns)  # Return empty DataFrame with same structure

-        # Sanity check: Does the
-
-
-
-
-
-        df_columns_lower = set(col.lower() for col in eval_df.columns)
-        features_lower = set(feature.lower() for feature in features)
-
-        # Check if the features are a subset of the DataFrame columns (case-insensitive)
-        if not features_lower.issubset(df_columns_lower):
-            missing_features = features_lower - df_columns_lower
-            raise ValueError(f"DataFrame does not contain required features: {missing_features}")
+        # Sanity check: Does the DataFrame have the required features?
+        df_columns_lower = set(col.lower() for col in eval_df.columns)
+        features_lower = set(feature.lower() for feature in features)
+        if not features_lower.issubset(df_columns_lower):
+            missing_features = features_lower - df_columns_lower
+            raise ValueError(f"DataFrame does not contain required features: {missing_features}")

         # Create our Endpoint Predictor Class
         predictor = Predictor(
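For illustration, a standalone sketch of the case-insensitive feature check that _predict() now performs before calling the endpoint (column and feature names here are hypothetical):

    eval_columns = ["ID", "SMILES", "LogP"]          # columns of the incoming DataFrame
    model_features = ["smiles", "logp", "tpsa"]      # features the model expects

    df_columns_lower = set(col.lower() for col in eval_columns)
    features_lower = set(feature.lower() for feature in model_features)
    if not features_lower.issubset(df_columns_lower):
        missing_features = features_lower - df_columns_lower
        raise ValueError(f"DataFrame does not contain required features: {missing_features}")  # {'tpsa'}
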
@@ -634,6 +702,10 @@ class EndpointCore(Artifact):
     @staticmethod
     def _hash_dataframe(df: pd.DataFrame, hash_length: int = 8):
         # Internal: Compute a data hash for the dataframe
+        if df.empty:
+            return "--hash--"
+
+        # Sort the dataframe by columns to ensure consistent ordering
         df = df.copy()
         df = df.sort_values(by=sorted(df.columns.tolist()))
         row_hashes = pd.util.hash_pandas_object(df, index=False)
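A small sketch of why the sort-then-hash approach above gives a row-order-independent hash. The final combining step is not visible in this hunk, so the hashlib.md5 reduction below is an assumption:

    import hashlib
    import pandas as pd

    def hash_dataframe(df: pd.DataFrame, hash_length: int = 8) -> str:
        if df.empty:
            return "--hash--"
        df = df.copy()
        df = df.sort_values(by=sorted(df.columns.tolist()))      # row order no longer matters
        row_hashes = pd.util.hash_pandas_object(df, index=False)  # one hash per row, index excluded
        return hashlib.md5(row_hashes.values.tobytes()).hexdigest()[:hash_length]  # assumed combine step

    df_a = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
    df_b = df_a.sample(frac=1, random_state=42)   # same rows, shuffled order
    assert hash_dataframe(df_a) == hash_dataframe(df_b)
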
@@ -688,8 +760,8 @@ class EndpointCore(Artifact):
         wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)

         # Grab the target column, prediction column, any _proba columns, and the ID column (if present)
-
-        output_columns
+        output_columns = [target_column]
+        output_columns += [col for col in pred_results_df.columns if "prediction" in col]

         # Add any _proba columns to the output columns
         output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
@@ -699,7 +771,7 @@ class EndpointCore(Artifact):

         # Add the ID column
         if id_column and id_column in pred_results_df.columns:
-            output_columns.
+            output_columns.insert(0, id_column)

         # Write the predictions to our S3 Model Inference Folder
         self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
@@ -713,18 +785,10 @@ class EndpointCore(Artifact):
         # Note: Unlike other dataframes here, we want to write the index (labels) to the CSV
         wr.s3.to_csv(conf_mtx, f"{inference_capture_path}/inference_cm.csv", index=True)

-        # Generate SHAP values for our Prediction Dataframe
-        # generate_shap_values(self.endpoint_name, model_type.value, pred_results_df, inference_capture_path)
-
         # Now recompute the details for our Model
-        self.log.important(f"
+        self.log.important(f"Loading inference metrics for {self.model_name}...")
         model = ModelCore(self.model_name)
         model._load_inference_metrics(capture_name)
-        model.details()
-
-        # Recompute the details so that inference model metrics are updated
-        self.log.important(f"Recomputing Details for {self.name} to show latest Inference Results...")
-        self.details()

     def regression_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
         """Compute the performance metrics for this Endpoint
@@ -876,9 +940,11 @@ class EndpointCore(Artifact):

     def generate_confusion_matrix(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
         """Compute the confusion matrix for this Endpoint
+
         Args:
             target_column (str): Name of the target column
             prediction_df (pd.DataFrame): DataFrame with the prediction results
+
         Returns:
             pd.DataFrame: DataFrame with the confusion matrix
         """
@@ -887,25 +953,20 @@ class EndpointCore(Artifact):
         prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
         y_pred = prediction_df[prediction_col]

-        #
-
-        if class_labels is None:
-            class_labels = sorted(list(set(y_true) | set(y_pred)))
+        # Get model class labels
+        model_class_labels = ModelCore(self.model_name).class_labels()

-        #
-
+        # Use model labels if available, otherwise infer from data
+        if model_class_labels:
+            self.log.important("Using model class labels for confusion matrix ordering...")
+            labels = model_class_labels
+        else:
+            labels = sorted(list(set(y_true) | set(y_pred)))

-        #
-
+        # Compute confusion matrix and create DataFrame
+        conf_mtx = confusion_matrix(y_true, y_pred, labels=labels)
+        conf_mtx_df = pd.DataFrame(conf_mtx, index=labels, columns=labels)
         conf_mtx_df.index.name = "labels"
-
-        # Check if our model has class labels. If so make the index and columns ordered
-        model_class_labels = ModelCore(self.model_name).class_labels()
-        if model_class_labels:
-            self.log.important("Reordering the confusion matrix based on model class labels...")
-            conf_mtx_df.index = pd.Categorical(conf_mtx_df.index, categories=model_class_labels, ordered=True)
-            conf_mtx_df.columns = pd.Categorical(conf_mtx_df.columns, categories=model_class_labels, ordered=True)
-            conf_mtx_df = conf_mtx_df.sort_index().sort_index(axis=1)
         return conf_mtx_df

     def endpoint_config_name(self) -> str:
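A short sketch of why passing labels matters: with an explicit class-label list the confusion-matrix rows and columns follow the model's ordering rather than the sorted order inferred from the data (class names below are made up):

    import pandas as pd
    from sklearn.metrics import confusion_matrix

    y_true = ["low", "high", "medium", "low", "high"]
    y_pred = ["low", "medium", "medium", "low", "high"]
    labels = ["low", "medium", "high"]               # e.g. ModelCore(...).class_labels()

    conf_mtx = confusion_matrix(y_true, y_pred, labels=labels)
    conf_mtx_df = pd.DataFrame(conf_mtx, index=labels, columns=labels)
    conf_mtx_df.index.name = "labels"
    print(conf_mtx_df)   # rows/columns ordered low, medium, high
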
@@ -932,9 +993,9 @@ class EndpointCore(Artifact):
         self.upsert_workbench_meta({"workbench_input": input})

     def delete(self):
-        """
+        """Delete an existing Endpoint: Underlying Models, Configuration, and Endpoint"""
         if not self.exists():
-            self.log.warning(f"Trying to delete an
+            self.log.warning(f"Trying to delete an Endpoint that doesn't exist: {self.name}")

         # Remove this endpoint from the list of registered endpoints
         self.log.info(f"Removing {self.name} from the list of registered endpoints...")
@@ -975,12 +1036,23 @@ class EndpointCore(Artifact):
             cls.log.info(f"Deleting Monitoring Schedule {schedule['MonitoringScheduleName']}...")
             cls.sm_client.delete_monitoring_schedule(MonitoringScheduleName=schedule["MonitoringScheduleName"])

-        # Recursively delete all endpoint S3 artifacts (inference,
+        # Recursively delete all endpoint S3 artifacts (inference, etc)
+        # Note: We do not want to delete the data_capture/ files since these
+        # might be used for collection and data drift analysis
         base_endpoint_path = f"{cls.endpoints_s3_path}/{endpoint_name}"
-
-
-
-
+        all_s3_objects = wr.s3.list_objects(base_endpoint_path, boto3_session=cls.boto3_session)
+
+        # Filter out objects that contain 'data_capture/' in their path
+        s3_objects_to_delete = [obj for obj in all_s3_objects if "/data_capture/" not in obj]
+        cls.log.info(f"Found {len(all_s3_objects)} total objects at {base_endpoint_path}")
+        cls.log.info(f"Filtering out data_capture files, will delete {len(s3_objects_to_delete)} objects...")
+        cls.log.info(f"Objects to delete: {s3_objects_to_delete}")
+
+        if s3_objects_to_delete:
+            wr.s3.delete_objects(s3_objects_to_delete, boto3_session=cls.boto3_session)
+            cls.log.info(f"Successfully deleted {len(s3_objects_to_delete)} objects")
+        else:
+            cls.log.info("No objects to delete (only data_capture files found)")

         # Delete any dataframes that were stored in the Dataframe Cache
         cls.log.info("Deleting Dataframe Cache...")
@@ -1031,7 +1103,7 @@ class EndpointCore(Artifact):
 if __name__ == "__main__":
     """Exercise the Endpoint Class"""
     from workbench.api import FeatureSet
-    from workbench.utils.endpoint_utils import
+    from workbench.utils.endpoint_utils import get_evaluation_data
     import random

     # Grab an EndpointCore object and pull some information from it
@@ -1039,7 +1111,7 @@ if __name__ == "__main__":

     # Test various error conditions (set row 42 length to pd.NA)
     # Note: This test should return ALL rows
-    my_eval_df =
+    my_eval_df = get_evaluation_data(my_endpoint)
     my_eval_df.at[42, "length"] = pd.NA
     pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
     print(f"Sent rows: {len(my_eval_df)}")
@@ -1047,6 +1119,9 @@ if __name__ == "__main__":
     assert len(pred_results) == len(my_eval_df), "Predictions should match the number of sent rows"

     # Now we put in an invalid value
+    print("*" * 80)
+    print("NOW TESTING ERROR CONDITIONS...")
+    print("*" * 80)
     my_eval_df.at[42, "length"] = "invalid_value"
     pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
     print(f"Sent rows: {len(my_eval_df)}")
@@ -1086,13 +1161,20 @@ if __name__ == "__main__":
     df = fs.pull_dataframe()[:100]
     cap_df = df.copy()
     cap_df.columns = [col.upper() for col in cap_df.columns]
-    my_endpoint.
+    my_endpoint.inference(cap_df)

     # Boolean Type Test
     df["bool_column"] = [random.choice([True, False]) for _ in range(len(df))]
-    result_df = my_endpoint.
+    result_df = my_endpoint.inference(df)
     assert result_df["bool_column"].dtype == bool

+    # Missing Feature Test
+    missing_df = df.drop(columns=["length"])
+    try:
+        my_endpoint.inference(missing_df)
+    except ValueError as e:
+        print(f"Expected error for missing feature: {e}")
+
     # Run Auto Inference on the Endpoint (uses the FeatureSet)
     print("Running Auto Inference...")
     my_endpoint.auto_inference()
@@ -1100,13 +1182,20 @@ if __name__ == "__main__":
     # Run Inference where we provide the data
     # Note: This dataframe could be from a FeatureSet or any other source
    print("Running Inference...")
-    my_eval_df =
+    my_eval_df = get_evaluation_data(my_endpoint)
     pred_results = my_endpoint.inference(my_eval_df)

     # Now set capture=True to save inference results and metrics
-    my_eval_df =
+    my_eval_df = get_evaluation_data(my_endpoint)
     pred_results = my_endpoint.inference(my_eval_df, capture_name="holdout_xyz")

+    # Run predictions using the fast_inference method
+    fast_results = my_endpoint.fast_inference(my_eval_df)
+
+    # Test the cross_fold_inference method
+    print("Running Cross-Fold Inference...")
+    metrics, all_results = my_endpoint.cross_fold_inference()
+
     # Run Inference and metrics for a Classification Endpoint
     class_endpoint = EndpointCore("wine-classification")
     auto_predictions = class_endpoint.auto_inference()
@@ -1115,8 +1204,9 @@ if __name__ == "__main__":
     target = "wine_class"
     print(class_endpoint.generate_confusion_matrix(target, auto_predictions))

-    #
-
+    # Test the cross_fold_inference method
+    print("Running Cross-Fold Inference...")
+    metrics, all_results = class_endpoint.cross_fold_inference()

     # Test the class method delete (commented out for now)
     # from workbench.api import Model
workbench/core/artifacts/feature_set_core.py

@@ -17,7 +17,7 @@ from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.data_source_factory import DataSourceFactory
 from workbench.core.artifacts.athena_source import AthenaSource

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, List, Union

 from workbench.utils.aws_utils import aws_throttle

@@ -194,24 +194,24 @@ class FeatureSetCore(Artifact):

         return View(self, view_name)

-    def set_display_columns(self,
+    def set_display_columns(self, display_columns: list[str]):
         """Set the display columns for this Data Source

         Args:
-
+            display_columns (list[str]): The display columns for this Data Source
         """
         # Check mismatch of display columns to computation columns
         c_view = self.view("computation")
         computation_columns = c_view.columns
-        mismatch_columns = [col for col in
+        mismatch_columns = [col for col in display_columns if col not in computation_columns]
         if mismatch_columns:
             self.log.monitor(f"Display View/Computation mismatch: {mismatch_columns}")

-        self.log.important(f"Setting Display Columns...{
+        self.log.important(f"Setting Display Columns...{display_columns}")
         from workbench.core.views import DisplayView

         # Create a NEW display view
-        DisplayView.create(self, source_table=c_view.table, column_list=
+        DisplayView.create(self, source_table=c_view.table, column_list=display_columns)

     def set_computation_columns(self, computation_columns: list[str], reset_display: bool = True):
         """Set the computation columns for this FeatureSet
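A hedged usage sketch of the set_display_columns() signature shown above (FeatureSet name and column names are placeholders); columns missing from the computation view are logged as a mismatch before the new display view is created:

    from workbench.core.artifacts.feature_set_core import FeatureSetCore

    fs = FeatureSetCore("test_features")                         # placeholder FeatureSet name
    fs.set_display_columns(["id", "name", "height", "weight"])  # placeholder column names
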
@@ -509,6 +509,48 @@ class FeatureSetCore(Artifact):
         ].tolist()
         return hold_out_ids

+    def set_training_filter(self, filter_expression: Optional[str] = None):
+        """Set a filter expression for the training view for this FeatureSet
+
+        Args:
+            filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
+                If None or empty string, will reset to training view with no filter
+                (default: None)
+        """
+        from workbench.core.views import TrainingView
+
+        # Grab the existing holdout ids
+        holdout_ids = self.get_training_holdouts()
+
+        # Create a NEW training view
+        self.log.important(f"Setting Training Filter: {filter_expression}")
+        TrainingView.create(
+            self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
+        )
+
+    def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
+        """Exclude a list of IDs from the training view
+
+        Args:
+            ids (List[Union[str, int]],): List of IDs to exclude from training
+            column_name (Optional[str]): Column name to filter on.
+                If None, uses self.id_column (default: None)
+        """
+        # Use the default id_column if not specified
+        column = column_name or self.id_column
+
+        # Handle empty list case
+        if not ids:
+            self.log.warning("No IDs provided to exclude")
+            return
+
+        # Build the filter expression with proper SQL quoting
+        quoted_ids = ", ".join([repr(id) for id in ids])
+        filter_expression = f"{column} NOT IN ({quoted_ids})"
+
+        # Apply the filter
+        self.set_training_filter(filter_expression)
+
     @classmethod
     def delete_views(cls, table: str, database: str):
         """Delete any views associated with this FeatureSet
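For reference, a small sketch of the filter expression that exclude_ids_from_training() builds (column name and IDs are examples); repr() quoting turns string IDs into SQL-style quoted literals while integer IDs stay bare:

    ids = ["A-001", "A-002", 17]
    column = "id"
    quoted_ids = ", ".join([repr(id) for id in ids])
    filter_expression = f"{column} NOT IN ({quoted_ids})"
    print(filter_expression)   # id NOT IN ('A-001', 'A-002', 17)

The expression is then applied through set_training_filter(), which also accepts hand-written filters such as "age > 25 AND status = 'active'" from the docstring above.
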
@@ -707,7 +749,7 @@ if __name__ == "__main__":

     # Test getting the holdout ids
     print("Getting the hold out ids...")
-    holdout_ids = my_features.get_training_holdouts(
+    holdout_ids = my_features.get_training_holdouts()
     print(f"Holdout IDs: {holdout_ids}")

     # Get a sample of the data
@@ -729,16 +771,33 @@ if __name__ == "__main__":
     table = my_features.view("training").table
     df = my_features.query(f'SELECT id, name FROM "{table}"')
     my_holdout_ids = [id for id in df["id"] if id < 20]
-    my_features.set_training_holdouts(
-
-    # Test the hold out set functionality with strings
-    print("Setting hold out ids (strings)...")
-    my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
-    my_features.set_training_holdouts("name", my_holdout_ids)
+    my_features.set_training_holdouts(my_holdout_ids)

     # Get the training data
     print("Getting the training data...")
     training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+
+    # Test the filter expression functionality
+    print("Setting a filter expression...")
+    my_features.set_training_filter("id < 50 AND height > 65.0")
+    training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+    print(training_data)
+
+    # Remove training filter
+    print("Removing the filter expression...")
+    my_features.set_training_filter(None)
+    training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+    print(training_data)
+
+    # Test excluding ids from training
+    print("Excluding ids from training...")
+    my_features.exclude_ids_from_training([1, 2, 3, 4, 5])
+    training_data = my_features.get_training_data()
+    print(f"Training Data: {training_data.shape}")
+    print(training_data)

     # Now delete the AWS artifacts associated with this Feature Set
     # print("Deleting Workbench Feature Set...")
|