workbench 0.8.198-py3-none-any.whl → 0.8.203-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. workbench/algorithms/dataframe/proximity.py +11 -4
  2. workbench/api/__init__.py +2 -1
  3. workbench/api/df_store.py +17 -108
  4. workbench/api/feature_set.py +48 -11
  5. workbench/api/model.py +1 -1
  6. workbench/api/parameter_store.py +3 -52
  7. workbench/core/artifacts/__init__.py +11 -2
  8. workbench/core/artifacts/artifact.py +5 -5
  9. workbench/core/artifacts/df_store_core.py +114 -0
  10. workbench/core/artifacts/endpoint_core.py +261 -78
  11. workbench/core/artifacts/feature_set_core.py +69 -1
  12. workbench/core/artifacts/model_core.py +48 -14
  13. workbench/core/artifacts/parameter_store_core.py +98 -0
  14. workbench/core/transforms/features_to_model/features_to_model.py +50 -33
  15. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  16. workbench/core/views/view.py +2 -2
  17. workbench/model_scripts/chemprop/chemprop.template +933 -0
  18. workbench/model_scripts/chemprop/generated_model_script.py +933 -0
  19. workbench/model_scripts/chemprop/requirements.txt +11 -0
  20. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  21. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  22. workbench/model_scripts/custom_models/proximity/proximity.py +11 -4
  23. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +11 -5
  24. workbench/model_scripts/custom_models/uq_models/meta_uq.template +11 -5
  25. workbench/model_scripts/custom_models/uq_models/ngboost.template +11 -5
  26. workbench/model_scripts/custom_models/uq_models/proximity.py +11 -4
  27. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +11 -5
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +365 -173
  29. workbench/model_scripts/pytorch_model/pytorch.template +362 -170
  30. workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
  31. workbench/model_scripts/script_generation.py +10 -7
  32. workbench/model_scripts/uq_models/generated_model_script.py +43 -27
  33. workbench/model_scripts/uq_models/mapie.template +40 -24
  34. workbench/model_scripts/xgb_model/generated_model_script.py +36 -7
  35. workbench/model_scripts/xgb_model/xgb_model.template +36 -7
  36. workbench/repl/workbench_shell.py +14 -5
  37. workbench/resources/open_source_api.key +1 -1
  38. workbench/scripts/endpoint_test.py +162 -0
  39. workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
  40. workbench/utils/chemprop_utils.py +761 -0
  41. workbench/utils/pytorch_utils.py +527 -0
  42. workbench/utils/xgboost_model_utils.py +10 -5
  43. workbench/web_interface/components/model_plot.py +7 -1
  44. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/METADATA +3 -3
  45. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/RECORD +49 -43
  46. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/entry_points.txt +2 -1
  47. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  48. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  49. workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
  50. workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
  51. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/WHEEL +0 -0
  52. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/licenses/LICENSE +0 -0
  53. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/endpoint_core.py
@@ -30,12 +30,14 @@ from sagemaker import Predictor
 
 # Workbench Imports
 from workbench.core.artifacts.artifact import Artifact
-from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
+from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType, ModelFramework
 from workbench.utils.endpoint_metrics import EndpointMetrics
 from workbench.utils.cache import Cache
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.model_utils import uq_metrics
-from workbench.utils.xgboost_model_utils import cross_fold_inference
+from workbench.utils.xgboost_model_utils import cross_fold_inference as xgboost_cross_fold
+from workbench.utils.pytorch_utils import pull_cv_results as pytorch_pull_cv
+from workbench.utils.chemprop_utils import pull_cv_results as chemprop_pull_cv
 from workbench_bridges.endpoints.fast_inference import fast_inference
 
 
@@ -387,7 +389,7 @@ class EndpointCore(Artifact):
         # Grab the model features and target column
         model = ModelCore(self.model_name)
         features = model.features()
-        target_column = model.target()
+        targets = model.target()  # Note: We have multi-target models (so this could be a list)
 
         # Run predictions on the evaluation data
         prediction_df = self._predict(eval_df, features, drop_error_rows)
@@ -395,45 +397,84 @@ class EndpointCore(Artifact):
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return prediction_df
 
+        # FIXME: Multi-target support - currently uses first target for metrics
+        # Normalize targets to handle both string and list formats
+        if isinstance(targets, list):
+            primary_target = targets[0] if targets else None
+        else:
+            primary_target = targets
+
         # Sanity Check that the target column is present
-        if target_column and (target_column not in prediction_df.columns):
-            self.log.important(f"Target Column {target_column} not found in prediction_df!")
+        if primary_target and (primary_target not in prediction_df.columns):
+            self.log.important(f"Target Column {primary_target} not found in prediction_df!")
             self.log.important("In order to compute metrics, the target column must be present!")
-            return prediction_df
+            metrics = pd.DataFrame()
 
         # Compute the standard performance metrics for this model
-        model_type = model.model_type
-        if model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-            prediction_df = self.residuals(target_column, prediction_df)
-            metrics = self.regression_metrics(target_column, prediction_df)
-        elif model_type == ModelType.CLASSIFIER:
-            metrics = self.classification_metrics(target_column, prediction_df)
         else:
-            # For other model types, we don't compute metrics
-            self.log.info(f"Model Type: {model_type} doesn't have metrics...")
-            metrics = pd.DataFrame()
+            if model.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
+                prediction_df = self.residuals(primary_target, prediction_df)
+                metrics = self.regression_metrics(primary_target, prediction_df)
+            elif model.model_type == ModelType.CLASSIFIER:
+                metrics = self.classification_metrics(primary_target, prediction_df)
+            else:
+                # For other model types, we don't compute metrics
+                self.log.info(f"Model Type: {model.model_type} doesn't have metrics...")
+                metrics = pd.DataFrame()
 
         # Print out the metrics
-        if not metrics.empty:
-            print(f"Performance Metrics for {self.model_name} on {self.name}")
-            print(metrics.head())
-
-        # Capture the inference results and metrics
-        if capture_name is not None:
-
-            # If we don't have an id_column, we'll pull it from the model's FeatureSet
-            if id_column is None:
-                fs = FeatureSetCore(model.get_input())
-                id_column = fs.id_column
-            description = capture_name.replace("_", " ").title()
+        print(f"Performance Metrics for {self.model_name} on {self.name}")
+        print(metrics.head())
+
+        # Capture the inference results and metrics
+        if capture_name is not None:
+
+            # If we don't have an id_column, we'll pull it from the model's FeatureSet
+            if id_column is None:
+                fs = FeatureSetCore(model.get_input())
+                id_column = fs.id_column
+
+            # Normalize targets to a list for iteration
+            target_list = targets if isinstance(targets, list) else [targets]
+
+            # For multi-target models, use target-specific capture names (e.g., auto_target1, auto_target2)
+            # For single-target models, use the original capture name for backward compatibility
+            for target in target_list:
+                # Determine capture name: use prefix for multi-target, original name for single-target
+                if len(target_list) > 1:
+                    prefix = "auto" if "auto" in capture_name else capture_name
+                    target_capture_name = f"{prefix}_{target}"
+                else:
+                    target_capture_name = capture_name
+
+                description = target_capture_name.replace("_", " ").title()
+
+                # Drop rows with NaN target values for metrics/plots
+                target_df = prediction_df.dropna(subset=[target])
+
+                # Compute per-target metrics
+                if model.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
+                    target_metrics = self.regression_metrics(target, target_df)
+                elif model.model_type == ModelType.CLASSIFIER:
+                    target_metrics = self.classification_metrics(target, target_df)
+                else:
+                    target_metrics = pd.DataFrame()
+
                 self._capture_inference_results(
-                capture_name, prediction_df, target_column, model_type, metrics, description, features, id_column
+                    target_capture_name,
+                    target_df,
+                    target,
+                    model.model_type,
+                    target_metrics,
+                    description,
+                    features,
+                    id_column,
                 )
 
-        # For UQ Models we also capture the uncertainty metrics
-        if model_type in [ModelType.UQ_REGRESSOR]:
-            metrics = uq_metrics(prediction_df, target_column)
-            self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
+            # For UQ Models we also capture the uncertainty metrics
+            if model.model_type in [ModelType.UQ_REGRESSOR]:
+                metrics = uq_metrics(prediction_df, primary_target)
+                self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
 
         # Return the prediction DataFrame
         return prediction_df
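Note: model.target() can now return either a string or a list. A minimal sketch (not part of the diff) of the capture-name fan-out the loop above performs; the target names and capture name here are hypothetical:

    targets = ["logS", "logP"]            # model.target() for a multi-target model
    capture_name = "auto_inference"
    target_list = targets if isinstance(targets, list) else [targets]
    for target in target_list:
        if len(target_list) > 1:
            prefix = "auto" if "auto" in capture_name else capture_name
            print(f"{prefix}_{target}")   # -> auto_logS, auto_logP
        else:
            print(capture_name)           # single-target keeps the original name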
@@ -452,7 +493,16 @@ class EndpointCore(Artifact):
         model = ModelCore(self.model_name)
 
         # Compute CrossFold (Metrics and Prediction Dataframe)
-        cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
+        # For PyTorch and ChemProp, pull pre-computed CV results from training
+        if model.model_framework in [ModelFramework.UNKNOWN, ModelFramework.XGBOOST]:
+            cross_fold_metrics, out_of_fold_df = xgboost_cross_fold(model, nfolds=nfolds)
+        elif model.model_framework == ModelFramework.PYTORCH_TABULAR:
+            cross_fold_metrics, out_of_fold_df = pytorch_pull_cv(model)
+        elif model.model_framework == ModelFramework.CHEMPROP:
+            cross_fold_metrics, out_of_fold_df = chemprop_pull_cv(model)
+        else:
+            self.log.error(f"Cross-Fold Inference not supported for Model Framework: {model.model_framework}.")
+            return pd.DataFrame()
 
         # If the metrics dataframe isn't empty save to the param store
         if not cross_fold_metrics.empty:
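Note: cross-fold results are now sourced per framework. XGBoost (and UNKNOWN, the legacy default) recomputes k-fold inference on demand, while the PyTorch and ChemProp paths pull CV results that were pre-computed during training, so nfolds is not passed to them. A rough sketch of the dispatch (pull_cv is a hypothetical standalone wrapper, not part of the diff):

    def pull_cv(model, nfolds=5):
        fw = model.model_framework
        if fw in (ModelFramework.UNKNOWN, ModelFramework.XGBOOST):
            return xgboost_cross_fold(model, nfolds=nfolds)  # recompute k-fold
        if fw == ModelFramework.PYTORCH_TABULAR:
            return pytorch_pull_cv(model)    # stored at training time
        if fw == ModelFramework.CHEMPROP:
            return chemprop_pull_cv(model)   # stored at training time
        raise ValueError(f"Unsupported framework: {fw}")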
@@ -460,10 +510,13 @@
             metrics = cross_fold_metrics.to_dict(orient="records")
             self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", metrics)
 
+        # If the out_of_fold_df is empty return it
+        if out_of_fold_df.empty:
+            self.log.warning("No out-of-fold predictions were made. Returning empty DataFrame.")
+            return out_of_fold_df
+
         # Capture the results
-        capture_name = "full_cross_fold"
-        description = capture_name.replace("_", " ").title()
-        target_column = model.target()
+        targets = model.target()  # Note: We have multi-target models (so this could be a list)
         model_type = model.model_type
 
         # Get the id_column from the model's FeatureSet
@@ -472,7 +525,7 @@
 
         # Is this a UQ Model? If so, run full inference and merge the results
         additional_columns = []
-        if model_type == ModelType.UQ_REGRESSOR:
+        if model.model_framework == ModelFramework.XGBOOST and model_type == ModelType.UQ_REGRESSOR:
             self.log.important("UQ Regressor detected, running full inference to get uncertainty estimates...")
 
             # Get the training view dataframe for inference
@@ -481,9 +534,11 @@
             # Run inference on the endpoint to get UQ outputs
             uq_df = self.inference(training_df)
 
-            # Identify UQ-specific columns (quantiles and prediction_std)
+            # Identify UQ-specific columns (quantiles, prediction_std, *_pred_std)
             uq_columns = [
-                col for col in uq_df.columns if col.startswith("q_") or col == "prediction_std" or col == "confidence"
+                col
+                for col in uq_df.columns
+                if col.startswith("q_") or col == "prediction_std" or col.endswith("_pred_std") or col == "confidence"
             ]
 
             # Merge UQ columns with out-of-fold predictions
@@ -499,20 +554,42 @@
             additional_columns = uq_columns
             self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")
 
-            # Also compute UQ metrics
-            metrics = uq_metrics(out_of_fold_df, target_column)
-            self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
+            # Also compute UQ metrics (use first target for multi-target models)
+            primary_target = targets[0] if isinstance(targets, list) else targets
+            metrics = uq_metrics(out_of_fold_df, primary_target)
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/full_cross_fold", metrics)
+
+        # Normalize targets to a list for iteration
+        target_list = targets if isinstance(targets, list) else [targets]
+
+        # For multi-target models, use target-specific capture names (e.g., cv_target1, cv_target2)
+        # For single-target models, use "full_cross_fold" for backward compatibility
+        for target in target_list:
+            capture_name = f"cv_{target}"
+            description = capture_name.replace("_", " ").title()
+
+            # Drop rows with NaN target values for metrics/plots
+            target_df = out_of_fold_df.dropna(subset=[target])
+
+            # Compute per-target metrics
+            if model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
+                target_metrics = self.regression_metrics(target, target_df)
+            elif model_type == ModelType.CLASSIFIER:
+                target_metrics = self.classification_metrics(target, target_df)
+            else:
+                target_metrics = pd.DataFrame()
+
+            self._capture_inference_results(
+                capture_name,
+                target_df,
+                target,
+                model_type,
+                target_metrics,
+                description,
+                features=additional_columns,
+                id_column=id_column,
+            )
 
-        self._capture_inference_results(
-            capture_name,
-            out_of_fold_df,
-            target_column,
-            model_type,
-            cross_fold_metrics,
-            description,
-            features=additional_columns,
-            id_column=id_column,
-        )
         return out_of_fold_df
 
     def fast_inference(self, eval_df: pd.DataFrame, threads: int = 4) -> pd.DataFrame:
@@ -718,23 +795,47 @@ class EndpointCore(Artifact):
         combined = row_hashes.values.tobytes()
         return hashlib.md5(combined).hexdigest()[:hash_length]
 
+    @staticmethod
+    def _find_prediction_column(df: pd.DataFrame, target_column: str) -> Optional[str]:
+        """Find the prediction column in a DataFrame.
+
+        Looks for 'prediction' column first, then '{target}_pred' pattern.
+
+        Args:
+            df: DataFrame to search
+            target_column: Name of the target column (used for {target}_pred pattern)
+
+        Returns:
+            Name of the prediction column, or None if not found
+        """
+        # Check for 'prediction' column first (legacy/standard format)
+        if "prediction" in df.columns:
+            return "prediction"
+
+        # Check for '{target}_pred' format (multi-target format)
+        target_pred_col = f"{target_column}_pred"
+        if target_pred_col in df.columns:
+            return target_pred_col
+
+        return None
+
     def _capture_inference_results(
         self,
         capture_name: str,
         pred_results_df: pd.DataFrame,
-        target_column: str,
+        target: str,
         model_type: ModelType,
         metrics: pd.DataFrame,
         description: str,
         features: list,
         id_column: str = None,
     ):
-        """Internal: Capture the inference results and metrics to S3
+        """Internal: Capture the inference results and metrics to S3 for a single target
 
         Args:
             capture_name (str): Name of the inference capture
             pred_results_df (pd.DataFrame): DataFrame with the prediction results
-            target_column (str): Name of the target column
+            target (str): Target column name
             model_type (ModelType): Type of the model (e.g. REGRESSOR, CLASSIFIER)
             metrics (pd.DataFrame): DataFrame with the performance metrics
             description (str): Description of the inference results
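Note: _find_prediction_column is the single place where both prediction-column layouts get resolved, with "prediction" taking precedence over "{target}_pred". An illustrative call (hypothetical data, not from the diff):

    import pandas as pd

    single = pd.DataFrame({"logS": [1.2], "prediction": [1.1]})
    multi = pd.DataFrame({"logS": [1.2], "logS_pred": [1.1], "logP_pred": [3.4]})

    EndpointCore._find_prediction_column(single, "logS")  # -> "prediction"
    EndpointCore._find_prediction_column(multi, "logS")   # -> "logS_pred"
    EndpointCore._find_prediction_column(multi, "logD")   # -> None (no match)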
@@ -765,28 +866,12 @@ class EndpointCore(Artifact):
         self.log.info(f"Writing metrics to {inference_capture_path}/inference_metrics.csv")
         wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)
 
-        # Grab the target column, prediction column, any _proba columns, and the ID column (if present)
-        output_columns = [target_column]
-        output_columns += [col for col in pred_results_df.columns if "prediction" in col]
-
-        # Add any _proba columns to the output columns
-        output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
-
-        # Add any Uncertainty Quantile columns to the output columns
-        output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col == "confidence"]
-
-        # Add the ID column
-        if id_column and id_column in pred_results_df.columns:
-            output_columns.insert(0, id_column)
-
-        # Write the predictions to our S3 Model Inference Folder
-        self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
-        subset_df = pred_results_df[output_columns]
-        wr.s3.to_csv(subset_df, f"{inference_capture_path}/inference_predictions.csv", index=False)
+        # Save the inference predictions for this target
+        self._save_target_inference(inference_capture_path, pred_results_df, target, id_column)
 
         # CLASSIFIER: Write the confusion matrix to our S3 Model Inference Folder
         if model_type == ModelType.CLASSIFIER:
-            conf_mtx = self.generate_confusion_matrix(target_column, pred_results_df)
+            conf_mtx = self.generate_confusion_matrix(target, pred_results_df)
             self.log.info(f"Writing confusion matrix to {inference_capture_path}/inference_cm.csv")
             # Note: Unlike other dataframes here, we want to write the index (labels) to the CSV
             wr.s3.to_csv(conf_mtx, f"{inference_capture_path}/inference_cm.csv", index=True)
@@ -796,6 +881,57 @@ class EndpointCore(Artifact):
         model = ModelCore(self.model_name)
         model._load_inference_metrics(capture_name)
 
+    def _save_target_inference(
+        self,
+        inference_capture_path: str,
+        pred_results_df: pd.DataFrame,
+        target: str,
+        id_column: str = None,
+    ):
+        """Save inference results for a single target.
+
+        Args:
+            inference_capture_path (str): S3 path for inference capture
+            pred_results_df (pd.DataFrame): DataFrame with prediction results
+            target (str): Target column name
+            id_column (str, optional): Name of the ID column
+        """
+        # Start with ID column if present
+        output_columns = []
+        if id_column and id_column in pred_results_df.columns:
+            output_columns.append(id_column)
+
+        # Add target column if present
+        if target and target in pred_results_df.columns:
+            output_columns.append(target)
+
+        # Build the output DataFrame
+        output_df = pred_results_df[output_columns].copy() if output_columns else pd.DataFrame()
+
+        # For multi-task: map {target}_pred -> prediction, {target}_pred_std -> prediction_std
+        # For single-task: just grab prediction and prediction_std columns directly
+        pred_col = f"{target}_pred"
+        std_col = f"{target}_pred_std"
+        if pred_col in pred_results_df.columns:
+            # Multi-task columns exist
+            output_df["prediction"] = pred_results_df[pred_col]
+            if std_col in pred_results_df.columns:
+                output_df["prediction_std"] = pred_results_df[std_col]
+        else:
+            # Single-task: grab standard prediction columns
+            for col in ["prediction", "prediction_std"]:
+                if col in pred_results_df.columns:
+                    output_df[col] = pred_results_df[col]
+            # Also grab any _proba columns and UQ columns
+            for col in pred_results_df.columns:
+                if col.endswith("_proba") or col.startswith("q_") or col == "confidence":
+                    output_df[col] = pred_results_df[col]
+
+        # Write the predictions to S3
+        output_file = f"{inference_capture_path}/inference_predictions.csv"
+        self.log.info(f"Writing predictions to {output_file}")
+        wr.s3.to_csv(output_df, output_file, index=False)
+
     def regression_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
         """Compute the performance metrics for this Endpoint
         Args:
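Note: _save_target_inference maps the multi-task columns {target}_pred / {target}_pred_std back to the legacy prediction / prediction_std names before writing the CSV. An illustrative input (hypothetical data and column names):

    import pandas as pd

    pred_results_df = pd.DataFrame({
        "id": [1, 2],
        "logS": [1.2, 0.8],
        "logS_pred": [1.1, 0.9],        # multi-task prediction for target "logS"
        "logS_pred_std": [0.05, 0.07],  # multi-task uncertainty for target "logS"
        "logP_pred": [3.4, 2.9],        # another target's column; not saved here
    })
    # For target="logS" the CSV written to S3 contains:
    #   id, logS, prediction (from logS_pred), prediction_std (from logS_pred_std)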
@@ -810,10 +946,28 @@ class EndpointCore(Artifact):
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return pd.DataFrame()
 
+        # Find the prediction column: "prediction" or "{target}_pred"
+        prediction_col = self._find_prediction_column(prediction_df, target_column)
+        if prediction_col is None:
+            self.log.warning(f"No prediction column found for target '{target_column}'")
+            return pd.DataFrame()
+
+        # Check for NaN values in target or prediction columns
+        if prediction_df[target_column].isnull().any() or prediction_df[prediction_col].isnull().any():
+            # Compute the number of NaN values in each column
+            num_nan_target = prediction_df[target_column].isnull().sum()
+            num_nan_prediction = prediction_df[prediction_col].isnull().sum()
+            self.log.warning(
+                f"NaNs Found: {target_column} {num_nan_target} and {prediction_col}: {num_nan_prediction}."
+            )
+            self.log.warning(
+                "NaN values found in target or prediction columns. Dropping NaN rows for metric computation."
+            )
+            prediction_df = prediction_df.dropna(subset=[target_column, prediction_col])
+
         # Compute the metrics
         try:
             y_true = prediction_df[target_column]
-            prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
             y_pred = prediction_df[prediction_col]
 
             mae = mean_absolute_error(y_true, y_pred)
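Note: the NaN guard above is presumably aimed at multi-target frames, where a row may carry a label for one target but not another. A small illustration of the drop (hypothetical values):

    import pandas as pd

    df = pd.DataFrame({"logS": [1.0, None, 2.0], "logS_pred": [1.1, 0.5, None]})
    df.dropna(subset=["logS", "logS_pred"])  # only the first row survives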
@@ -849,7 +1003,13 @@ class EndpointCore(Artifact):
 
         # Compute the residuals
         y_true = prediction_df[target_column]
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
+
+        # Find the prediction column: "prediction" or "{target}_pred"
+        prediction_col = self._find_prediction_column(prediction_df, target_column)
+        if prediction_col is None:
+            self.log.warning(f"No prediction column found for target '{target_column}'. Cannot compute residuals.")
+            return prediction_df
+
         y_pred = prediction_df[prediction_col]
 
         # Check for classification scenario
@@ -891,6 +1051,19 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with the performance metrics
         """
+        # Find the prediction column: "prediction" or "{target}_pred"
+        prediction_col = self._find_prediction_column(prediction_df, target_column)
+        if prediction_col is None:
+            self.log.warning(f"No prediction column found for target '{target_column}'")
+            return pd.DataFrame()
+
+        # Drop rows with NaN predictions (can't compute metrics on missing predictions)
+        nan_mask = prediction_df[prediction_col].isna()
+        if nan_mask.any():
+            n_nan = nan_mask.sum()
+            self.log.warning(f"Dropping {n_nan} rows with NaN predictions for metrics calculation")
+            prediction_df = prediction_df[~nan_mask].copy()
+
         # Get the class labels from the model
         class_labels = ModelCore(self.model_name).class_labels()
         if class_labels is None:
@@ -903,7 +1076,6 @@ class EndpointCore(Artifact):
         self.validate_proba_columns(prediction_df, class_labels)
 
         # Calculate precision, recall, f1, and support, handling zero division
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
         scores = precision_recall_fscore_support(
             prediction_df[target_column],
             prediction_df[prediction_col],
@@ -954,9 +1126,20 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with the confusion matrix
         """
+        # Find the prediction column: "prediction" or "{target}_pred"
+        prediction_col = self._find_prediction_column(prediction_df, target_column)
+        if prediction_col is None:
+            self.log.warning(f"No prediction column found for target '{target_column}'")
+            return pd.DataFrame()
+
+        # Drop rows with NaN predictions (can't include in confusion matrix)
+        nan_mask = prediction_df[prediction_col].isna()
+        if nan_mask.any():
+            n_nan = nan_mask.sum()
+            self.log.warning(f"Dropping {n_nan} rows with NaN predictions for confusion matrix")
+            prediction_df = prediction_df[~nan_mask].copy()
 
         y_true = prediction_df[target_column]
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
         y_pred = prediction_df[prediction_col]
 
         # Get model class labels
workbench/core/artifacts/feature_set_core.py
@@ -16,8 +16,9 @@ from sagemaker.feature_store.feature_store import FeatureStore
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.data_source_factory import DataSourceFactory
 from workbench.core.artifacts.athena_source import AthenaSource
+from workbench.utils.deprecated_utils import deprecated
 
-from typing import TYPE_CHECKING, Optional, List, Union
+from typing import TYPE_CHECKING, Optional, List, Dict, Union
 
 from workbench.utils.aws_utils import aws_throttle
 
@@ -509,6 +510,71 @@
         ].tolist()
         return hold_out_ids
 
+    def set_sample_weights(
+        self,
+        weight_dict: Dict[Union[str, int], float],
+        default_weight: float = 1.0,
+        exclude_zero_weights: bool = True,
+    ):
+        """Configure training view with sample weights for each ID.
+
+        Args:
+            weight_dict: Mapping of ID to sample weight
+                - weight > 1.0: oversample/emphasize
+                - weight = 1.0: normal (default)
+                - 0 < weight < 1.0: downweight/de-emphasize
+                - weight = 0.0: exclude from training
+            default_weight: Weight for IDs not in weight_dict (default: 1.0)
+            exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
+
+        Example:
+            weights = {
+                'compound_42': 3.0,   # oversample 3x
+                'compound_99': 0.1,   # noisy, downweight
+                'compound_123': 0.0,  # exclude from training
+            }
+            model.set_sample_weights(weights)  # zeros automatically excluded
+            model.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
+        """
+        from workbench.core.views import TrainingView
+
+        if not weight_dict:
+            self.log.important("Empty weight_dict, creating standard training view")
+            TrainingView.create(self, id_column=self.id_column)
+            return
+
+        self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
+
+        # Helper to format IDs for SQL
+        def format_id(id_val):
+            return repr(id_val)
+
+        # Build CASE statement for sample_weight
+        case_conditions = [
+            f"WHEN {self.id_column} = {format_id(id_val)} THEN {weight}" for id_val, weight in weight_dict.items()
+        ]
+        case_statement = "\n ".join(case_conditions)
+
+        # Build inner query with sample weights
+        inner_sql = f"""SELECT
+            *,
+            CASE
+                {case_statement}
+                ELSE {default_weight}
+            END AS sample_weight
+            FROM {self.table}"""
+
+        # Optionally filter out zero weights
+        if exclude_zero_weights:
+            zero_count = sum(1 for weight in weight_dict.values() if weight == 0.0)
+            custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
+            self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+        else:
+            custom_sql = inner_sql
+
+        TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
+
+    @deprecated(version=0.9)
     def set_training_filter(self, filter_expression: Optional[str] = None):
         """Set a filter expression for the training view for this FeatureSet
 
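Note: set_sample_weights rebuilds the training view via TrainingView.create_with_sql rather than modifying the stored feature data. A usage sketch (the FeatureSet name, ID column, and table name are hypothetical), with the rough shape of the generated SQL in comments:

    from workbench.api import FeatureSet  # assumes the api-level FeatureSet exposes this core method

    fs = FeatureSet("solubility_features")
    fs.set_sample_weights({"compound_42": 3.0, "compound_99": 0.0})
    # Generated view SQL (roughly):
    #   SELECT * FROM (
    #       SELECT *, CASE
    #           WHEN compound_id = 'compound_42' THEN 3.0
    #           WHEN compound_id = 'compound_99' THEN 0.0
    #           ELSE 1.0
    #       END AS sample_weight
    #       FROM <feature table>
    #   ) WHERE sample_weight > 0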
@@ -528,6 +594,7 @@
             self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
         )
 
+    @deprecated(version="0.9")
     def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
         """Exclude a list of IDs from the training view
 
@@ -551,6 +618,7 @@
         # Apply the filter
         self.set_training_filter(filter_expression)
 
+    @deprecated(version="0.9")
     def set_training_sampling(
         self,
         exclude_ids: Optional[List[Union[str, int]]] = None,