workbench 0.8.198__py3-none-any.whl → 0.8.203__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- workbench/algorithms/dataframe/proximity.py +11 -4
- workbench/api/__init__.py +2 -1
- workbench/api/df_store.py +17 -108
- workbench/api/feature_set.py +48 -11
- workbench/api/model.py +1 -1
- workbench/api/parameter_store.py +3 -52
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +261 -78
- workbench/core/artifacts/feature_set_core.py +69 -1
- workbench/core/artifacts/model_core.py +48 -14
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/transforms/features_to_model/features_to_model.py +50 -33
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
- workbench/core/views/view.py +2 -2
- workbench/model_scripts/chemprop/chemprop.template +933 -0
- workbench/model_scripts/chemprop/generated_model_script.py +933 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/proximity.py +11 -4
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +11 -5
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +11 -5
- workbench/model_scripts/custom_models/uq_models/ngboost.template +11 -5
- workbench/model_scripts/custom_models/uq_models/proximity.py +11 -4
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +11 -5
- workbench/model_scripts/pytorch_model/generated_model_script.py +365 -173
- workbench/model_scripts/pytorch_model/pytorch.template +362 -170
- workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
- workbench/model_scripts/script_generation.py +10 -7
- workbench/model_scripts/uq_models/generated_model_script.py +43 -27
- workbench/model_scripts/uq_models/mapie.template +40 -24
- workbench/model_scripts/xgb_model/generated_model_script.py +36 -7
- workbench/model_scripts/xgb_model/xgb_model.template +36 -7
- workbench/repl/workbench_shell.py +14 -5
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
- workbench/utils/chemprop_utils.py +761 -0
- workbench/utils/pytorch_utils.py +527 -0
- workbench/utils/xgboost_model_utils.py +10 -5
- workbench/web_interface/components/model_plot.py +7 -1
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/METADATA +3 -3
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/RECORD +49 -43
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/entry_points.txt +2 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
- workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/WHEEL +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/top_level.txt +0 -0
--- workbench/core/artifacts/endpoint_core.py (0.8.198)
+++ workbench/core/artifacts/endpoint_core.py (0.8.203)
@@ -30,12 +30,14 @@ from sagemaker import Predictor
 
 # Workbench Imports
 from workbench.core.artifacts.artifact import Artifact
-from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
+from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType, ModelFramework
 from workbench.utils.endpoint_metrics import EndpointMetrics
 from workbench.utils.cache import Cache
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.model_utils import uq_metrics
-from workbench.utils.xgboost_model_utils import cross_fold_inference
+from workbench.utils.xgboost_model_utils import cross_fold_inference as xgboost_cross_fold
+from workbench.utils.pytorch_utils import pull_cv_results as pytorch_pull_cv
+from workbench.utils.chemprop_utils import pull_cv_results as chemprop_pull_cv
 from workbench_bridges.endpoints.fast_inference import fast_inference
 
 
@@ -387,7 +389,7 @@ class EndpointCore(Artifact):
         # Grab the model features and target column
         model = ModelCore(self.model_name)
         features = model.features()
-        target_column = model.target()
+        targets = model.target()  # Note: We have multi-target models (so this could be a list)
 
         # Run predictions on the evaluation data
         prediction_df = self._predict(eval_df, features, drop_error_rows)
@@ -395,45 +397,84 @@ class EndpointCore(Artifact):
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return prediction_df
 
+        # FIXME: Multi-target support - currently uses first target for metrics
+        # Normalize targets to handle both string and list formats
+        if isinstance(targets, list):
+            primary_target = targets[0] if targets else None
+        else:
+            primary_target = targets
+
         # Sanity Check that the target column is present
-        if target_column not in prediction_df.columns:
-            self.log.important(f"Target Column {target_column} not found in prediction_df!")
+        if primary_target and (primary_target not in prediction_df.columns):
+            self.log.important(f"Target Column {primary_target} not found in prediction_df!")
             self.log.important("In order to compute metrics, the target column must be present!")
-            …
+            metrics = pd.DataFrame()
 
         # Compute the standard performance metrics for this model
-        model_type = model.model_type
-        if model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-            prediction_df = self.residuals(target_column, prediction_df)
-            metrics = self.regression_metrics(target_column, prediction_df)
-        elif model_type == ModelType.CLASSIFIER:
-            metrics = self.classification_metrics(target_column, prediction_df)
         else:
-            … (3 lines)
+            if model.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
+                prediction_df = self.residuals(primary_target, prediction_df)
+                metrics = self.regression_metrics(primary_target, prediction_df)
+            elif model.model_type == ModelType.CLASSIFIER:
+                metrics = self.classification_metrics(primary_target, prediction_df)
+            else:
+                # For other model types, we don't compute metrics
+                self.log.info(f"Model Type: {model.model_type} doesn't have metrics...")
+                metrics = pd.DataFrame()
 
         # Print out the metrics
-        … (12 lines)
+        print(f"Performance Metrics for {self.model_name} on {self.name}")
+        print(metrics.head())
+
+        # Capture the inference results and metrics
+        if capture_name is not None:
+
+            # If we don't have an id_column, we'll pull it from the model's FeatureSet
+            if id_column is None:
+                fs = FeatureSetCore(model.get_input())
+                id_column = fs.id_column
+
+            # Normalize targets to a list for iteration
+            target_list = targets if isinstance(targets, list) else [targets]
+
+            # For multi-target models, use target-specific capture names (e.g., auto_target1, auto_target2)
+            # For single-target models, use the original capture name for backward compatibility
+            for target in target_list:
+                # Determine capture name: use prefix for multi-target, original name for single-target
+                if len(target_list) > 1:
+                    prefix = "auto" if "auto" in capture_name else capture_name
+                    target_capture_name = f"{prefix}_{target}"
+                else:
+                    target_capture_name = capture_name
+
+                description = target_capture_name.replace("_", " ").title()
+
+                # Drop rows with NaN target values for metrics/plots
+                target_df = prediction_df.dropna(subset=[target])
+
+                # Compute per-target metrics
+                if model.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
+                    target_metrics = self.regression_metrics(target, target_df)
+                elif model.model_type == ModelType.CLASSIFIER:
+                    target_metrics = self.classification_metrics(target, target_df)
+                else:
+                    target_metrics = pd.DataFrame()
+
                 self._capture_inference_results(
-                    …
+                    target_capture_name,
+                    target_df,
+                    target,
+                    model.model_type,
+                    target_metrics,
+                    description,
+                    features,
+                    id_column,
                 )
 
-        … (4 lines)
+        # For UQ Models we also capture the uncertainty metrics
+        if model.model_type in [ModelType.UQ_REGRESSOR]:
+            metrics = uq_metrics(prediction_df, primary_target)
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
 
         # Return the prediction DataFrame
        return prediction_df
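A quick way to see what the new per-target capture naming does: the helper below mirrors the loop's naming rule. The helper function and the example target names are illustrative only, not part of the workbench API.

    def capture_names(targets, capture_name="auto"):
        """Mirror of the naming rule in the new inference() capture loop."""
        target_list = targets if isinstance(targets, list) else [targets]
        if len(target_list) == 1:
            return [capture_name]  # single-target keeps the caller's name (backward compatible)
        prefix = "auto" if "auto" in capture_name else capture_name
        return [f"{prefix}_{target}" for target in target_list]

    print(capture_names("solubility"))                 # ['auto']
    print(capture_names(["logd", "logs"]))             # ['auto_logd', 'auto_logs']
    print(capture_names(["logd", "logs"], "my_eval"))  # ['my_eval_logd', 'my_eval_logs']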
@@ -452,7 +493,16 @@ class EndpointCore(Artifact):
         model = ModelCore(self.model_name)
 
         # Compute CrossFold (Metrics and Prediction Dataframe)
-        cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
+        # For PyTorch and ChemProp, pull pre-computed CV results from training
+        if model.model_framework in [ModelFramework.UNKNOWN, ModelFramework.XGBOOST]:
+            cross_fold_metrics, out_of_fold_df = xgboost_cross_fold(model, nfolds=nfolds)
+        elif model.model_framework == ModelFramework.PYTORCH_TABULAR:
+            cross_fold_metrics, out_of_fold_df = pytorch_pull_cv(model)
+        elif model.model_framework == ModelFramework.CHEMPROP:
+            cross_fold_metrics, out_of_fold_df = chemprop_pull_cv(model)
+        else:
+            self.log.error(f"Cross-Fold Inference not supported for Model Framework: {model.model_framework}.")
+            return pd.DataFrame()
 
         # If the metrics dataframe isn't empty save to the param store
        if not cross_fold_metrics.empty:
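The hunk above replaces the single XGBoost call with a dispatch on model.model_framework. Restated as a standalone sketch (the imports are real in this release, per the import hunk at the top; the helper function itself is illustrative):

    import pandas as pd

    from workbench.core.artifacts import ModelFramework
    from workbench.utils.xgboost_model_utils import cross_fold_inference as xgboost_cross_fold
    from workbench.utils.pytorch_utils import pull_cv_results as pytorch_pull_cv
    from workbench.utils.chemprop_utils import pull_cv_results as chemprop_pull_cv

    def pull_cross_fold(model, nfolds: int = 5):
        """Route cross-fold results by framework, as the diff above does."""
        framework = model.model_framework
        if framework in (ModelFramework.UNKNOWN, ModelFramework.XGBOOST):
            return xgboost_cross_fold(model, nfolds=nfolds)  # computed on the fly
        if framework == ModelFramework.PYTORCH_TABULAR:
            return pytorch_pull_cv(model)  # pre-computed during training
        if framework == ModelFramework.CHEMPROP:
            return chemprop_pull_cv(model)  # pre-computed during training
        # The shipped code logs an error and returns a single empty DataFrame here
        return pd.DataFrame(), pd.DataFrame()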
@@ -460,10 +510,13 @@ class EndpointCore(Artifact):
             metrics = cross_fold_metrics.to_dict(orient="records")
             self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", metrics)
 
+        # If the out_of_fold_df is empty return it
+        if out_of_fold_df.empty:
+            self.log.warning("No out-of-fold predictions were made. Returning empty DataFrame.")
+            return out_of_fold_df
+
         # Capture the results
-        …
-        description = capture_name.replace("_", " ").title()
-        target_column = model.target()
+        targets = model.target()  # Note: We have multi-target models (so this could be a list)
         model_type = model.model_type
 
         # Get the id_column from the model's FeatureSet
@@ -472,7 +525,7 @@ class EndpointCore(Artifact):
 
         # Is this a UQ Model? If so, run full inference and merge the results
         additional_columns = []
-        if model_type == ModelType.UQ_REGRESSOR:
+        if model.model_framework == ModelFramework.XGBOOST and model_type == ModelType.UQ_REGRESSOR:
             self.log.important("UQ Regressor detected, running full inference to get uncertainty estimates...")
 
             # Get the training view dataframe for inference
@@ -481,9 +534,11 @@ class EndpointCore(Artifact):
             # Run inference on the endpoint to get UQ outputs
             uq_df = self.inference(training_df)
 
-            # Identify UQ-specific columns (quantiles …
+            # Identify UQ-specific columns (quantiles, prediction_std, *_pred_std)
             uq_columns = [
-                col …
+                col
+                for col in uq_df.columns
+                if col.startswith("q_") or col == "prediction_std" or col.endswith("_pred_std") or col == "confidence"
             ]
 
             # Merge UQ columns with out-of-fold predictions
@@ -499,20 +554,42 @@ class EndpointCore(Artifact):
             additional_columns = uq_columns
             self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")
 
-            # Also compute UQ metrics
-            … (2 lines)
+            # Also compute UQ metrics (use first target for multi-target models)
+            primary_target = targets[0] if isinstance(targets, list) else targets
+            metrics = uq_metrics(out_of_fold_df, primary_target)
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/full_cross_fold", metrics)
+
+        # Normalize targets to a list for iteration
+        target_list = targets if isinstance(targets, list) else [targets]
+
+        # For multi-target models, use target-specific capture names (e.g., cv_target1, cv_target2)
+        # For single-target models, use "full_cross_fold" for backward compatibility
+        for target in target_list:
+            capture_name = f"cv_{target}"
+            description = capture_name.replace("_", " ").title()
+
+            # Drop rows with NaN target values for metrics/plots
+            target_df = out_of_fold_df.dropna(subset=[target])
+
+            # Compute per-target metrics
+            if model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
+                target_metrics = self.regression_metrics(target, target_df)
+            elif model_type == ModelType.CLASSIFIER:
+                target_metrics = self.classification_metrics(target, target_df)
+            else:
+                target_metrics = pd.DataFrame()
+
+            self._capture_inference_results(
+                capture_name,
+                target_df,
+                target,
+                model_type,
+                target_metrics,
+                description,
+                features=additional_columns,
+                id_column=id_column,
+            )
 
-        self._capture_inference_results(
-            capture_name,
-            out_of_fold_df,
-            target_column,
-            model_type,
-            cross_fold_metrics,
-            description,
-            features=additional_columns,
-            id_column=id_column,
-        )
         return out_of_fold_df
 
     def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
@@ -718,23 +795,47 @@ class EndpointCore(Artifact):
         combined = row_hashes.values.tobytes()
         return hashlib.md5(combined).hexdigest()[:hash_length]
 
+    @staticmethod
+    def _find_prediction_column(df: pd.DataFrame, target_column: str) -> Optional[str]:
+        """Find the prediction column in a DataFrame.
+
+        Looks for 'prediction' column first, then '{target}_pred' pattern.
+
+        Args:
+            df: DataFrame to search
+            target_column: Name of the target column (used for {target}_pred pattern)
+
+        Returns:
+            Name of the prediction column, or None if not found
+        """
+        # Check for 'prediction' column first (legacy/standard format)
+        if "prediction" in df.columns:
+            return "prediction"
+
+        # Check for '{target}_pred' format (multi-target format)
+        target_pred_col = f"{target_column}_pred"
+        if target_pred_col in df.columns:
+            return target_pred_col
+
+        return None
+
     def _capture_inference_results(
         self,
         capture_name: str,
         pred_results_df: pd.DataFrame,
-        target_column: str,
+        target: str,
         model_type: ModelType,
         metrics: pd.DataFrame,
         description: str,
         features: list,
         id_column: str = None,
     ):
-        """Internal: Capture the inference results and metrics to S3
+        """Internal: Capture the inference results and metrics to S3 for a single target
 
         Args:
             capture_name (str): Name of the inference capture
             pred_results_df (pd.DataFrame): DataFrame with the prediction results
-            …
+            target (str): Target column name
             model_type (ModelType): Type of the model (e.g. REGRESSOR, CLASSIFIER)
             metrics (pd.DataFrame): DataFrame with the performance metrics
             description (str): Description of the inference results
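Because _find_prediction_column is a staticmethod, the resolution order above is easy to exercise directly. The DataFrames and target names below are invented for illustration:

    import pandas as pd
    from workbench.core.artifacts.endpoint_core import EndpointCore

    single = pd.DataFrame(columns=["logd", "prediction"])
    multi = pd.DataFrame(columns=["logd", "logd_pred", "logs", "logs_pred"])

    EndpointCore._find_prediction_column(single, "logd")  # -> 'prediction' (legacy format wins)
    EndpointCore._find_prediction_column(multi, "logs")   # -> 'logs_pred' (multi-target fallback)
    EndpointCore._find_prediction_column(multi, "ic50")   # -> None (callers warn and bail out)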
@@ -765,28 +866,12 @@ class EndpointCore(Artifact):
         self.log.info(f"Writing metrics to {inference_capture_path}/inference_metrics.csv")
         wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)
 
-        # …
-        …
-        output_columns += [col for col in pred_results_df.columns if "prediction" in col]
-
-        # Add any _proba columns to the output columns
-        output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
-
-        # Add any Uncertainty Quantile columns to the output columns
-        output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col == "confidence"]
-
-        # Add the ID column
-        if id_column and id_column in pred_results_df.columns:
-            output_columns.insert(0, id_column)
-
-        # Write the predictions to our S3 Model Inference Folder
-        self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
-        subset_df = pred_results_df[output_columns]
-        wr.s3.to_csv(subset_df, f"{inference_capture_path}/inference_predictions.csv", index=False)
+        # Save the inference predictions for this target
+        self._save_target_inference(inference_capture_path, pred_results_df, target, id_column)
 
         # CLASSIFIER: Write the confusion matrix to our S3 Model Inference Folder
         if model_type == ModelType.CLASSIFIER:
-            conf_mtx = self.generate_confusion_matrix(target_column, pred_results_df)
+            conf_mtx = self.generate_confusion_matrix(target, pred_results_df)
             self.log.info(f"Writing confusion matrix to {inference_capture_path}/inference_cm.csv")
             # Note: Unlike other dataframes here, we want to write the index (labels) to the CSV
             wr.s3.to_csv(conf_mtx, f"{inference_capture_path}/inference_cm.csv", index=True)
@@ -796,6 +881,57 @@ class EndpointCore(Artifact):
         model = ModelCore(self.model_name)
         model._load_inference_metrics(capture_name)
 
+    def _save_target_inference(
+        self,
+        inference_capture_path: str,
+        pred_results_df: pd.DataFrame,
+        target: str,
+        id_column: str = None,
+    ):
+        """Save inference results for a single target.
+
+        Args:
+            inference_capture_path (str): S3 path for inference capture
+            pred_results_df (pd.DataFrame): DataFrame with prediction results
+            target (str): Target column name
+            id_column (str, optional): Name of the ID column
+        """
+        # Start with ID column if present
+        output_columns = []
+        if id_column and id_column in pred_results_df.columns:
+            output_columns.append(id_column)
+
+        # Add target column if present
+        if target and target in pred_results_df.columns:
+            output_columns.append(target)
+
+        # Build the output DataFrame
+        output_df = pred_results_df[output_columns].copy() if output_columns else pd.DataFrame()
+
+        # For multi-task: map {target}_pred -> prediction, {target}_pred_std -> prediction_std
+        # For single-task: just grab prediction and prediction_std columns directly
+        pred_col = f"{target}_pred"
+        std_col = f"{target}_pred_std"
+        if pred_col in pred_results_df.columns:
+            # Multi-task columns exist
+            output_df["prediction"] = pred_results_df[pred_col]
+            if std_col in pred_results_df.columns:
+                output_df["prediction_std"] = pred_results_df[std_col]
+        else:
+            # Single-task: grab standard prediction columns
+            for col in ["prediction", "prediction_std"]:
+                if col in pred_results_df.columns:
+                    output_df[col] = pred_results_df[col]
+            # Also grab any _proba columns and UQ columns
+            for col in pred_results_df.columns:
+                if col.endswith("_proba") or col.startswith("q_") or col == "confidence":
+                    output_df[col] = pred_results_df[col]
+
+        # Write the predictions to S3
+        output_file = f"{inference_capture_path}/inference_predictions.csv"
+        self.log.info(f"Writing predictions to {output_file}")
+        wr.s3.to_csv(output_df, output_file, index=False)
+
     def regression_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
         """Compute the performance metrics for this Endpoint
         Args:
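The multi-task branch of _save_target_inference maps per-target columns back to the single-target schema before writing the CSV. The core of that mapping, as a few runnable pandas lines (column names invented):

    import pandas as pd

    df = pd.DataFrame({
        "id": [1, 2],
        "logd": [1.1, 2.2],
        "logd_pred": [1.0, 2.3],
        "logd_pred_std": [0.1, 0.2],
    })

    target, id_column = "logd", "id"
    out = df[[id_column, target]].copy()
    out["prediction"] = df[f"{target}_pred"]          # {target}_pred -> prediction
    out["prediction_std"] = df[f"{target}_pred_std"]  # {target}_pred_std -> prediction_std
    print(list(out.columns))  # ['id', 'logd', 'prediction', 'prediction_std']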
@@ -810,10 +946,28 @@ class EndpointCore(Artifact):
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return pd.DataFrame()
 
+        # Find the prediction column: "prediction" or "{target}_pred"
+        prediction_col = self._find_prediction_column(prediction_df, target_column)
+        if prediction_col is None:
+            self.log.warning(f"No prediction column found for target '{target_column}'")
+            return pd.DataFrame()
+
+        # Check for NaN values in target or prediction columns
+        if prediction_df[target_column].isnull().any() or prediction_df[prediction_col].isnull().any():
+            # Compute the number of NaN values in each column
+            num_nan_target = prediction_df[target_column].isnull().sum()
+            num_nan_prediction = prediction_df[prediction_col].isnull().sum()
+            self.log.warning(
+                f"NaNs Found: {target_column} {num_nan_target} and {prediction_col}: {num_nan_prediction}."
+            )
+            self.log.warning(
+                "NaN values found in target or prediction columns. Dropping NaN rows for metric computation."
+            )
+            prediction_df = prediction_df.dropna(subset=[target_column, prediction_col])
+
         # Compute the metrics
         try:
             y_true = prediction_df[target_column]
-            prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
             y_pred = prediction_df[prediction_col]
 
             mae = mean_absolute_error(y_true, y_pred)
@@ -849,7 +1003,13 @@ class EndpointCore(Artifact):
 
         # Compute the residuals
         y_true = prediction_df[target_column]
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
+
+        # Find the prediction column: "prediction" or "{target}_pred"
+        prediction_col = self._find_prediction_column(prediction_df, target_column)
+        if prediction_col is None:
+            self.log.warning(f"No prediction column found for target '{target_column}'. Cannot compute residuals.")
+            return prediction_df
+
         y_pred = prediction_df[prediction_col]
 
         # Check for classification scenario
@@ -891,6 +1051,19 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with the performance metrics
         """
+        # Find the prediction column: "prediction" or "{target}_pred"
+        prediction_col = self._find_prediction_column(prediction_df, target_column)
+        if prediction_col is None:
+            self.log.warning(f"No prediction column found for target '{target_column}'")
+            return pd.DataFrame()
+
+        # Drop rows with NaN predictions (can't compute metrics on missing predictions)
+        nan_mask = prediction_df[prediction_col].isna()
+        if nan_mask.any():
+            n_nan = nan_mask.sum()
+            self.log.warning(f"Dropping {n_nan} rows with NaN predictions for metrics calculation")
+            prediction_df = prediction_df[~nan_mask].copy()
+
         # Get the class labels from the model
         class_labels = ModelCore(self.model_name).class_labels()
         if class_labels is None:
@@ -903,7 +1076,6 @@ class EndpointCore(Artifact):
         self.validate_proba_columns(prediction_df, class_labels)
 
         # Calculate precision, recall, f1, and support, handling zero division
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
         scores = precision_recall_fscore_support(
             prediction_df[target_column],
             prediction_df[prediction_col],
@@ -954,9 +1126,20 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with the confusion matrix
         """
+        # Find the prediction column: "prediction" or "{target}_pred"
+        prediction_col = self._find_prediction_column(prediction_df, target_column)
+        if prediction_col is None:
+            self.log.warning(f"No prediction column found for target '{target_column}'")
+            return pd.DataFrame()
+
+        # Drop rows with NaN predictions (can't include in confusion matrix)
+        nan_mask = prediction_df[prediction_col].isna()
+        if nan_mask.any():
+            n_nan = nan_mask.sum()
+            self.log.warning(f"Dropping {n_nan} rows with NaN predictions for confusion matrix")
+            prediction_df = prediction_df[~nan_mask].copy()
 
         y_true = prediction_df[target_column]
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
         y_pred = prediction_df[prediction_col]
 
         # Get model class labels
--- workbench/core/artifacts/feature_set_core.py (0.8.198)
+++ workbench/core/artifacts/feature_set_core.py (0.8.203)
@@ -16,8 +16,9 @@ from sagemaker.feature_store.feature_store import FeatureStore
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.data_source_factory import DataSourceFactory
 from workbench.core.artifacts.athena_source import AthenaSource
+from workbench.utils.deprecated_utils import deprecated
 
-from typing import TYPE_CHECKING, Optional, List, Union
+from typing import TYPE_CHECKING, Optional, List, Dict, Union
 
 from workbench.utils.aws_utils import aws_throttle
 
@@ -509,6 +510,71 @@ class FeatureSetCore(Artifact):
         ].tolist()
         return hold_out_ids
 
+    def set_sample_weights(
+        self,
+        weight_dict: Dict[Union[str, int], float],
+        default_weight: float = 1.0,
+        exclude_zero_weights: bool = True,
+    ):
+        """Configure training view with sample weights for each ID.
+
+        Args:
+            weight_dict: Mapping of ID to sample weight
+                - weight > 1.0: oversample/emphasize
+                - weight = 1.0: normal (default)
+                - 0 < weight < 1.0: downweight/de-emphasize
+                - weight = 0.0: exclude from training
+            default_weight: Weight for IDs not in weight_dict (default: 1.0)
+            exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
+
+        Example:
+            weights = {
+                'compound_42': 3.0,   # oversample 3x
+                'compound_99': 0.1,   # noisy, downweight
+                'compound_123': 0.0,  # exclude from training
+            }
+            model.set_sample_weights(weights)  # zeros automatically excluded
+            model.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
+        """
+        from workbench.core.views import TrainingView
+
+        if not weight_dict:
+            self.log.important("Empty weight_dict, creating standard training view")
+            TrainingView.create(self, id_column=self.id_column)
+            return
+
+        self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
+
+        # Helper to format IDs for SQL
+        def format_id(id_val):
+            return repr(id_val)
+
+        # Build CASE statement for sample_weight
+        case_conditions = [
+            f"WHEN {self.id_column} = {format_id(id_val)} THEN {weight}" for id_val, weight in weight_dict.items()
+        ]
+        case_statement = "\n ".join(case_conditions)
+
+        # Build inner query with sample weights
+        inner_sql = f"""SELECT
+            *,
+            CASE
+                {case_statement}
+                ELSE {default_weight}
+            END AS sample_weight
+        FROM {self.table}"""
+
+        # Optionally filter out zero weights
+        if exclude_zero_weights:
+            zero_count = sum(1 for weight in weight_dict.values() if weight == 0.0)
+            custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
+            self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+        else:
+            custom_sql = inner_sql
+
+        TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
+
+    @deprecated(version=0.9)
     def set_training_filter(self, filter_expression: Optional[str] = None):
         """Set a filter expression for the training view for this FeatureSet
 
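A usage sketch for set_sample_weights, assuming the API-level FeatureSet exposes this core method. The FeatureSet name, id column, and table name below are invented, and the generated SQL is approximate:

    from workbench.api import FeatureSet

    fs = FeatureSet("aqsol_features")
    fs.set_sample_weights({"compound_42": 3.0, "compound_99": 0.1, "compound_123": 0.0})

    # Roughly the training-view query that gets created (zero weights filtered out by default):
    # SELECT * FROM (
    #     SELECT *,
    #         CASE
    #             WHEN id = 'compound_42' THEN 3.0
    #             WHEN id = 'compound_99' THEN 0.1
    #             WHEN id = 'compound_123' THEN 0.0
    #             ELSE 1.0
    #         END AS sample_weight
    #     FROM aqsol_features
    # ) WHERE sample_weight > 0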
@@ -528,6 +594,7 @@ class FeatureSetCore(Artifact):
             self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
         )
 
+    @deprecated(version="0.9")
     def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
         """Exclude a list of IDs from the training view
 
@@ -551,6 +618,7 @@ class FeatureSetCore(Artifact):
         # Apply the filter
         self.set_training_filter(filter_expression)
 
+    @deprecated(version="0.9")
     def set_training_sampling(
         self,
         exclude_ids: Optional[List[Union[str, int]]] = None,