workbench 0.8.197__py3-none-any.whl → 0.8.201__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. workbench/algorithms/dataframe/proximity.py +19 -12
  2. workbench/api/__init__.py +2 -1
  3. workbench/api/feature_set.py +7 -4
  4. workbench/api/model.py +1 -1
  5. workbench/core/artifacts/__init__.py +11 -2
  6. workbench/core/artifacts/endpoint_core.py +84 -46
  7. workbench/core/artifacts/feature_set_core.py +69 -1
  8. workbench/core/artifacts/model_core.py +37 -7
  9. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  10. workbench/core/transforms/features_to_model/features_to_model.py +23 -20
  11. workbench/core/views/view.py +2 -2
  12. workbench/model_scripts/chemprop/chemprop.template +931 -0
  13. workbench/model_scripts/chemprop/generated_model_script.py +931 -0
  14. workbench/model_scripts/chemprop/requirements.txt +11 -0
  15. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  16. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  17. workbench/model_scripts/custom_models/proximity/proximity.py +19 -12
  18. workbench/model_scripts/custom_models/uq_models/proximity.py +19 -12
  19. workbench/model_scripts/pytorch_model/generated_model_script.py +130 -88
  20. workbench/model_scripts/pytorch_model/pytorch.template +128 -86
  21. workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
  22. workbench/model_scripts/script_generation.py +10 -7
  23. workbench/model_scripts/uq_models/generated_model_script.py +25 -18
  24. workbench/model_scripts/uq_models/mapie.template +23 -16
  25. workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
  26. workbench/model_scripts/xgb_model/xgb_model.template +2 -2
  27. workbench/repl/workbench_shell.py +14 -5
  28. workbench/scripts/endpoint_test.py +162 -0
  29. workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
  30. workbench/utils/chemprop_utils.py +724 -0
  31. workbench/utils/pytorch_utils.py +497 -0
  32. workbench/utils/xgboost_model_utils.py +12 -5
  33. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/METADATA +2 -2
  34. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/RECORD +38 -30
  35. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/entry_points.txt +2 -1
  36. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/WHEEL +0 -0
  37. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/licenses/LICENSE +0 -0
  38. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/top_level.txt +0 -0
@@ -68,7 +68,8 @@ class Proximity:
68
68
  self,
69
69
  top_percent: float = 1.0,
70
70
  min_delta: Optional[float] = None,
71
- k_neighbors: int = 5,
71
+ k_neighbors: int = 4,
72
+ only_coincident: bool = False,
72
73
  ) -> pd.DataFrame:
73
74
  """
74
75
  Find compounds with steep target gradients (data quality issues and activity cliffs).
@@ -80,7 +81,8 @@ class Proximity:
80
81
  Args:
81
82
  top_percent: Percentage of compounds with steepest gradients to return (e.g., 1.0 = top 1%)
82
83
  min_delta: Minimum absolute target difference to consider. If None, defaults to target_range/100
83
- k_neighbors: Number of neighbors to use for median calculation (default: 5)
84
+ k_neighbors: Number of neighbors to use for median calculation (default: 4)
85
+ only_coincident: If True, only consider compounds that are coincident (default: False)
84
86
 
85
87
  Returns:
86
88
  DataFrame of compounds with steepest gradients, sorted by gradient (descending)
@@ -99,10 +101,15 @@ class Proximity:
99
101
  min_delta = self.target_range / 100.0 if self.target_range > 0 else 0.0
100
102
  candidates = candidates[candidates["nn_target_diff"] >= min_delta]
101
103
 
102
- # Get top X% by initial gradient
103
- percentile = 100 - top_percent
104
- threshold = np.percentile(candidates["gradient"], percentile)
105
- candidates = candidates[candidates["gradient"] >= threshold].copy()
104
+ # Filter based on mode
105
+ if only_coincident:
106
+ # Only keep coincident points (nn_distance ~= 0)
107
+ candidates = candidates[candidates["nn_distance"] < epsilon].copy()
108
+ else:
109
+ # Get top X% by initial gradient
110
+ percentile = 100 - top_percent
111
+ threshold = np.percentile(candidates["gradient"], percentile)
112
+ candidates = candidates[candidates["gradient"] >= threshold].copy()
106
113
 
107
114
  # Phase 2: Verify with k-neighbor median to filter out cases where nearest neighbor is the outlier
108
115
  results = []
@@ -113,23 +120,23 @@ class Proximity:
113
120
  # Get k nearest neighbors (excluding self)
114
121
  nbrs = self.neighbors(cmpd_id, n_neighbors=k_neighbors, include_self=False)
115
122
 
116
- # Calculate median target of k nearest neighbors
117
- neighbor_median = nbrs.head(k_neighbors)[self.target].median()
123
+ # Calculate median target of k neighbors, excluding the nearest neighbor (index 0)
124
+ neighbor_median = nbrs.iloc[1:k_neighbors][self.target].median()
118
125
  median_diff = abs(cmpd_target - neighbor_median)
119
126
 
120
127
  # Only keep if compound differs from neighborhood median
121
128
  # This filters out cases where the nearest neighbor is the outlier
122
129
  if median_diff >= min_delta:
123
- mean_distance = nbrs.head(k_neighbors)["distance"].mean()
124
-
125
130
  results.append(
126
131
  {
127
132
  self.id_column: cmpd_id,
128
133
  self.target: cmpd_target,
134
+ "nn_target": row["nn_target"],
135
+ "nn_target_diff": row["nn_target_diff"],
136
+ "nn_distance": row["nn_distance"],
137
+ "gradient": row["gradient"], # Keep Phase 1 gradient
129
138
  "neighbor_median": neighbor_median,
130
139
  "neighbor_median_diff": median_diff,
131
- "mean_distance": mean_distance,
132
- "gradient": median_diff / (mean_distance + epsilon),
133
140
  }
134
141
  )
135
142
 
workbench/api/__init__.py CHANGED
@@ -14,7 +14,7 @@ These class provide high-level APIs for the Workbench package, offering easy acc
14
14
 
15
15
  from .data_source import DataSource
16
16
  from .feature_set import FeatureSet
17
- from .model import Model, ModelType
17
+ from .model import Model, ModelType, ModelFramework
18
18
  from .endpoint import Endpoint
19
19
  from .meta import Meta
20
20
  from .parameter_store import ParameterStore
@@ -25,6 +25,7 @@ __all__ = [
25
25
  "FeatureSet",
26
26
  "Model",
27
27
  "ModelType",
28
+ "ModelFramework",
28
29
  "Endpoint",
29
30
  "Meta",
30
31
  "ParameterStore",
@@ -12,7 +12,7 @@ import pandas as pd
12
12
  from workbench.core.artifacts.artifact import Artifact
13
13
  from workbench.core.artifacts.feature_set_core import FeatureSetCore
14
14
  from workbench.core.transforms.features_to_model.features_to_model import FeaturesToModel
15
- from workbench.api.model import Model, ModelType
15
+ from workbench.api.model import Model, ModelType, ModelFramework
16
16
 
17
17
 
18
18
  class FeatureSet(FeatureSetCore):
@@ -79,6 +79,7 @@ class FeatureSet(FeatureSetCore):
79
79
  self,
80
80
  name: str,
81
81
  model_type: ModelType,
82
+ model_framework: ModelFramework = ModelFramework.XGBOOST,
82
83
  tags: list = None,
83
84
  description: str = None,
84
85
  feature_list: list = None,
@@ -98,11 +99,12 @@ class FeatureSet(FeatureSetCore):
98
99
 
99
100
  name (str): The name of the Model to create
100
101
  model_type (ModelType): The type of model to create (See workbench.model.ModelType)
102
+ model_framework (ModelFramework, optional): The framework to use for the model (default: XGBOOST)
101
103
  tags (list, optional): Set the tags for the model. If not given tags will be generated.
102
104
  description (str, optional): Set the description for the model. If not given, a description is generated.
103
105
  feature_list (list, optional): Set the feature list for the model. If not given a feature list is generated.
104
106
  target_column (str, optional): The target column for the model (use None for unsupervised model)
105
- model_class (str, optional): Model class to use (e.g. "KMeans", "PyTorch", default: None)
107
+ model_class (str, optional): Model class to use (e.g. "KMeans", default: None)
106
108
  model_import_str (str, optional): The import for the model (e.g. "from sklearn.cluster import KMeans")
107
109
  custom_script (str, optional): The custom script to use for the model (default: None)
108
110
  training_image (str, optional): The training image to use (default: "training")
@@ -128,8 +130,8 @@ class FeatureSet(FeatureSetCore):
128
130
  # Create the Model Tags
129
131
  tags = [name] if tags is None else tags
130
132
 
131
- # If the model_class is PyTorch, ensure we set the training and inference images
132
- if model_class and model_class.lower() == "pytorch":
133
+ # If the model framework is PyTorch or ChemProp, ensure we set the training and inference images
134
+ if model_framework in (ModelFramework.PYTORCH_TABULAR, ModelFramework.CHEMPROP):
133
135
  training_image = "pytorch_training"
134
136
  inference_image = "pytorch_inference"
135
137
 
@@ -138,6 +140,7 @@ class FeatureSet(FeatureSetCore):
138
140
  feature_name=self.name,
139
141
  model_name=name,
140
142
  model_type=model_type,
143
+ model_framework=model_framework,
141
144
  model_class=model_class,
142
145
  model_import_str=model_import_str,
143
146
  custom_script=custom_script,
workbench/api/model.py CHANGED
@@ -7,7 +7,7 @@ Dashboard UI, which provides additional model details and performance metrics
7
7
 
8
8
  # Workbench Imports
9
9
  from workbench.core.artifacts.artifact import Artifact
10
- from workbench.core.artifacts.model_core import ModelCore, ModelType # noqa: F401
10
+ from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework # noqa: F401
11
11
  from workbench.core.transforms.model_to_endpoint.model_to_endpoint import ModelToEndpoint
12
12
  from workbench.api.endpoint import Endpoint
13
13
  from workbench.utils.model_utils import proximity_model_local, uq_model
@@ -15,7 +15,16 @@ from .artifact import Artifact
15
15
  from .athena_source import AthenaSource
16
16
  from .data_source_abstract import DataSourceAbstract
17
17
  from .feature_set_core import FeatureSetCore
18
- from .model_core import ModelCore, ModelType
18
+ from .model_core import ModelCore, ModelType, ModelFramework
19
19
  from .endpoint_core import EndpointCore
20
20
 
21
- __all__ = ["Artifact", "AthenaSource", "DataSourceAbstract", "FeatureSetCore", "ModelCore", "ModelType", "EndpointCore"]
21
+ __all__ = [
22
+ "Artifact",
23
+ "AthenaSource",
24
+ "DataSourceAbstract",
25
+ "FeatureSetCore",
26
+ "ModelCore",
27
+ "ModelType",
28
+ "ModelFramework",
29
+ "EndpointCore",
30
+ ]
@@ -30,12 +30,14 @@ from sagemaker import Predictor
30
30
 
31
31
  # Workbench Imports
32
32
  from workbench.core.artifacts.artifact import Artifact
33
- from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
33
+ from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType, ModelFramework
34
34
  from workbench.utils.endpoint_metrics import EndpointMetrics
35
35
  from workbench.utils.cache import Cache
36
36
  from workbench.utils.s3_utils import compute_s3_object_hash
37
37
  from workbench.utils.model_utils import uq_metrics
38
- from workbench.utils.xgboost_model_utils import cross_fold_inference
38
+ from workbench.utils.xgboost_model_utils import cross_fold_inference as xgboost_cross_fold
39
+ from workbench.utils.pytorch_utils import cross_fold_inference as pytorch_cross_fold
40
+ from workbench.utils.chemprop_utils import cross_fold_inference as chemprop_cross_fold
39
41
  from workbench_bridges.endpoints.fast_inference import fast_inference
40
42
 
41
43
 
@@ -399,41 +401,40 @@ class EndpointCore(Artifact):
399
401
  if target_column and (target_column not in prediction_df.columns):
400
402
  self.log.important(f"Target Column {target_column} not found in prediction_df!")
401
403
  self.log.important("In order to compute metrics, the target column must be present!")
402
- return prediction_df
404
+ metrics = pd.DataFrame()
403
405
 
404
406
  # Compute the standard performance metrics for this model
405
- model_type = model.model_type
406
- if model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
407
- prediction_df = self.residuals(target_column, prediction_df)
408
- metrics = self.regression_metrics(target_column, prediction_df)
409
- elif model_type == ModelType.CLASSIFIER:
410
- metrics = self.classification_metrics(target_column, prediction_df)
411
407
  else:
412
- # For other model types, we don't compute metrics
413
- self.log.info(f"Model Type: {model_type} doesn't have metrics...")
414
- metrics = pd.DataFrame()
408
+ if model.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
409
+ prediction_df = self.residuals(target_column, prediction_df)
410
+ metrics = self.regression_metrics(target_column, prediction_df)
411
+ elif model.model_type == ModelType.CLASSIFIER:
412
+ metrics = self.classification_metrics(target_column, prediction_df)
413
+ else:
414
+ # For other model types, we don't compute metrics
415
+ self.log.info(f"Model Type: {model.model_type} doesn't have metrics...")
416
+ metrics = pd.DataFrame()
415
417
 
416
418
  # Print out the metrics
417
- if not metrics.empty:
418
- print(f"Performance Metrics for {self.model_name} on {self.name}")
419
- print(metrics.head())
420
-
421
- # Capture the inference results and metrics
422
- if capture_name is not None:
423
-
424
- # If we don't have an id_column, we'll pull it from the model's FeatureSet
425
- if id_column is None:
426
- fs = FeatureSetCore(model.get_input())
427
- id_column = fs.id_column
428
- description = capture_name.replace("_", " ").title()
429
- self._capture_inference_results(
430
- capture_name, prediction_df, target_column, model_type, metrics, description, features, id_column
431
- )
432
-
433
- # For UQ Models we also capture the uncertainty metrics
434
- if model_type in [ModelType.UQ_REGRESSOR]:
435
- metrics = uq_metrics(prediction_df, target_column)
436
- self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
419
+ print(f"Performance Metrics for {self.model_name} on {self.name}")
420
+ print(metrics.head())
421
+
422
+ # Capture the inference results and metrics
423
+ if capture_name is not None:
424
+
425
+ # If we don't have an id_column, we'll pull it from the model's FeatureSet
426
+ if id_column is None:
427
+ fs = FeatureSetCore(model.get_input())
428
+ id_column = fs.id_column
429
+ description = capture_name.replace("_", " ").title()
430
+ self._capture_inference_results(
431
+ capture_name, prediction_df, target_column, model.model_type, metrics, description, features, id_column
432
+ )
433
+
434
+ # For UQ Models we also capture the uncertainty metrics
435
+ if model.model_type in [ModelType.UQ_REGRESSOR]:
436
+ metrics = uq_metrics(prediction_df, target_column)
437
+ self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
437
438
 
438
439
  # Return the prediction DataFrame
439
440
  return prediction_df
@@ -452,7 +453,15 @@ class EndpointCore(Artifact):
452
453
  model = ModelCore(self.model_name)
453
454
 
454
455
  # Compute CrossFold (Metrics and Prediction Dataframe)
455
- cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
456
+ if model.model_framework in [ModelFramework.UNKNOWN, ModelFramework.XGBOOST]:
457
+ cross_fold_metrics, out_of_fold_df = xgboost_cross_fold(model, nfolds=nfolds)
458
+ elif model.model_framework == ModelFramework.PYTORCH_TABULAR:
459
+ cross_fold_metrics, out_of_fold_df = pytorch_cross_fold(model, nfolds=nfolds)
460
+ elif model.model_framework == ModelFramework.CHEMPROP:
461
+ cross_fold_metrics, out_of_fold_df = chemprop_cross_fold(model, nfolds=nfolds)
462
+ else:
463
+ self.log.error(f"Cross-Fold Inference not supported for Model Framework: {model.model_framework}.")
464
+ return pd.DataFrame()
456
465
 
457
466
  # If the metrics dataframe isn't empty save to the param store
458
467
  if not cross_fold_metrics.empty:
@@ -460,6 +469,11 @@ class EndpointCore(Artifact):
460
469
  metrics = cross_fold_metrics.to_dict(orient="records")
461
470
  self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", metrics)
462
471
 
472
+ # If the out_of_fold_df is empty return it
473
+ if out_of_fold_df.empty:
474
+ self.log.warning("No out-of-fold predictions were made. Returning empty DataFrame.")
475
+ return out_of_fold_df
476
+
463
477
  # Capture the results
464
478
  capture_name = "full_cross_fold"
465
479
  description = capture_name.replace("_", " ").title()
@@ -765,20 +779,18 @@ class EndpointCore(Artifact):
765
779
  self.log.info(f"Writing metrics to {inference_capture_path}/inference_metrics.csv")
766
780
  wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)
767
781
 
768
- # Grab the target column, prediction column, any _proba columns, and the ID column (if present)
769
- output_columns = [target_column]
770
- output_columns += [col for col in pred_results_df.columns if "prediction" in col]
782
+ # Grab the ID column and target column if they are present
783
+ output_columns = []
784
+ if id_column and id_column in pred_results_df.columns:
785
+ output_columns.append(id_column)
786
+ if target_column in pred_results_df.columns:
787
+ output_columns.append(target_column)
771
788
 
772
- # Add any _proba columns to the output columns
789
+ # Grab the prediction column, any _proba columns, and UQ columns
790
+ output_columns += [col for col in pred_results_df.columns if "prediction" in col]
773
791
  output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
774
-
775
- # Add any Uncertainty Quantile columns to the output columns
776
792
  output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col == "confidence"]
777
793
 
778
- # Add the ID column
779
- if id_column and id_column in pred_results_df.columns:
780
- output_columns.insert(0, id_column)
781
-
782
794
  # Write the predictions to our S3 Model Inference Folder
783
795
  self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
784
796
  subset_df = pred_results_df[output_columns]
@@ -810,10 +822,23 @@ class EndpointCore(Artifact):
810
822
  self.log.warning("No predictions were made. Returning empty DataFrame.")
811
823
  return pd.DataFrame()
812
824
 
825
+ # Check for NaN values in target or prediction columns
826
+ prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
827
+ if prediction_df[target_column].isnull().any() or prediction_df[prediction_col].isnull().any():
828
+ # Compute the number of NaN values in each column
829
+ num_nan_target = prediction_df[target_column].isnull().sum()
830
+ num_nan_prediction = prediction_df[prediction_col].isnull().sum()
831
+ self.log.warning(
832
+ f"NaNs Found: {target_column} {num_nan_target} and {prediction_col}: {num_nan_prediction}."
833
+ )
834
+ self.log.warning(
835
+ "NaN values found in target or prediction columns. Dropping NaN rows for metric computation."
836
+ )
837
+ prediction_df = prediction_df.dropna(subset=[target_column, prediction_col])
838
+
813
839
  # Compute the metrics
814
840
  try:
815
841
  y_true = prediction_df[target_column]
816
- prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
817
842
  y_pred = prediction_df[prediction_col]
818
843
 
819
844
  mae = mean_absolute_error(y_true, y_pred)
@@ -891,6 +916,14 @@ class EndpointCore(Artifact):
891
916
  Returns:
892
917
  pd.DataFrame: DataFrame with the performance metrics
893
918
  """
919
+ # Drop rows with NaN predictions (can't compute metrics on missing predictions)
920
+ prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
921
+ nan_mask = prediction_df[prediction_col].isna()
922
+ if nan_mask.any():
923
+ n_nan = nan_mask.sum()
924
+ self.log.warning(f"Dropping {n_nan} rows with NaN predictions for metrics calculation")
925
+ prediction_df = prediction_df[~nan_mask].copy()
926
+
894
927
  # Get the class labels from the model
895
928
  class_labels = ModelCore(self.model_name).class_labels()
896
929
  if class_labels is None:
@@ -903,7 +936,6 @@ class EndpointCore(Artifact):
903
936
  self.validate_proba_columns(prediction_df, class_labels)
904
937
 
905
938
  # Calculate precision, recall, f1, and support, handling zero division
906
- prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
907
939
  scores = precision_recall_fscore_support(
908
940
  prediction_df[target_column],
909
941
  prediction_df[prediction_col],
@@ -954,9 +986,15 @@ class EndpointCore(Artifact):
954
986
  Returns:
955
987
  pd.DataFrame: DataFrame with the confusion matrix
956
988
  """
989
+ # Drop rows with NaN predictions (can't include in confusion matrix)
990
+ prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
991
+ nan_mask = prediction_df[prediction_col].isna()
992
+ if nan_mask.any():
993
+ n_nan = nan_mask.sum()
994
+ self.log.warning(f"Dropping {n_nan} rows with NaN predictions for confusion matrix")
995
+ prediction_df = prediction_df[~nan_mask].copy()
957
996
 
958
997
  y_true = prediction_df[target_column]
959
- prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
960
998
  y_pred = prediction_df[prediction_col]
961
999
 
962
1000
  # Get model class labels
@@ -16,8 +16,9 @@ from sagemaker.feature_store.feature_store import FeatureStore
16
16
  from workbench.core.artifacts.artifact import Artifact
17
17
  from workbench.core.artifacts.data_source_factory import DataSourceFactory
18
18
  from workbench.core.artifacts.athena_source import AthenaSource
19
+ from workbench.utils.deprecated_utils import deprecated
19
20
 
20
- from typing import TYPE_CHECKING, Optional, List, Union
21
+ from typing import TYPE_CHECKING, Optional, List, Dict, Union
21
22
 
22
23
  from workbench.utils.aws_utils import aws_throttle
23
24
 
@@ -509,6 +510,71 @@ class FeatureSetCore(Artifact):
509
510
  ].tolist()
510
511
  return hold_out_ids
511
512
 
513
+ def set_sample_weights(
514
+ self,
515
+ weight_dict: Dict[Union[str, int], float],
516
+ default_weight: float = 1.0,
517
+ exclude_zero_weights: bool = True,
518
+ ):
519
+ """Configure training view with sample weights for each ID.
520
+
521
+ Args:
522
+ weight_dict: Mapping of ID to sample weight
523
+ - weight > 1.0: oversample/emphasize
524
+ - weight = 1.0: normal (default)
525
+ - 0 < weight < 1.0: downweight/de-emphasize
526
+ - weight = 0.0: exclude from training
527
+ default_weight: Weight for IDs not in weight_dict (default: 1.0)
528
+ exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
529
+
530
+ Example:
531
+ weights = {
532
+ 'compound_42': 3.0, # oversample 3x
533
+ 'compound_99': 0.1, # noisy, downweight
534
+ 'compound_123': 0.0, # exclude from training
535
+ }
536
+ feature_set.set_sample_weights(weights) # zeros automatically excluded
537
+ feature_set.set_sample_weights(weights, exclude_zero_weights=False) # keep zeros
538
+ """
539
+ from workbench.core.views import TrainingView
540
+
541
+ if not weight_dict:
542
+ self.log.important("Empty weight_dict, creating standard training view")
543
+ TrainingView.create(self, id_column=self.id_column)
544
+ return
545
+
546
+ self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
547
+
548
+ # Helper to format IDs for SQL
549
+ def format_id(id_val):
550
+ return repr(id_val)
551
+
552
+ # Build CASE statement for sample_weight
553
+ case_conditions = [
554
+ f"WHEN {self.id_column} = {format_id(id_val)} THEN {weight}" for id_val, weight in weight_dict.items()
555
+ ]
556
+ case_statement = "\n ".join(case_conditions)
557
+
558
+ # Build inner query with sample weights
559
+ inner_sql = f"""SELECT
560
+ *,
561
+ CASE
562
+ {case_statement}
563
+ ELSE {default_weight}
564
+ END AS sample_weight
565
+ FROM {self.table}"""
566
+
567
+ # Optionally filter out zero weights
568
+ if exclude_zero_weights:
569
+ zero_count = sum(1 for weight in weight_dict.values() if weight == 0.0)
570
+ custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
571
+ self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
572
+ else:
573
+ custom_sql = inner_sql
574
+
575
+ TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
576
+
577
+ @deprecated(version=0.9)
512
578
  def set_training_filter(self, filter_expression: Optional[str] = None):
513
579
  """Set a filter expression for the training view for this FeatureSet
514
580
 
@@ -528,6 +594,7 @@ class FeatureSetCore(Artifact):
528
594
  self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
529
595
  )
530
596
 
597
+ @deprecated(version="0.9")
531
598
  def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
532
599
  """Exclude a list of IDs from the training view
533
600
 
@@ -551,6 +618,7 @@ class FeatureSetCore(Artifact):
551
618
  # Apply the filter
552
619
  self.set_training_filter(filter_expression)
553
620
 
621
+ @deprecated(version="0.9")
554
622
  def set_training_sampling(
555
623
  self,
556
624
  exclude_ids: Optional[List[Union[str, int]]] = None,
@@ -30,11 +30,23 @@ class ModelType(Enum):
30
30
  CLASSIFIER = "classifier"
31
31
  REGRESSOR = "regressor"
32
32
  CLUSTERER = "clusterer"
33
- TRANSFORMER = "transformer"
34
33
  PROXIMITY = "proximity"
35
34
  PROJECTION = "projection"
36
35
  UQ_REGRESSOR = "uq_regressor"
37
36
  ENSEMBLE_REGRESSOR = "ensemble_regressor"
37
+ TRANSFORMER = "transformer"
38
+ UNKNOWN = "unknown"
39
+
40
+
41
+ class ModelFramework(Enum):
42
+ """Enumerated Types for Workbench Model Frameworks"""
43
+
44
+ SKLEARN = "sklearn"
45
+ XGBOOST = "xgboost"
46
+ LIGHTGBM = "lightgbm"
47
+ PYTORCH_TABULAR = "pytorch_tabular"
48
+ CHEMPROP = "chemprop"
49
+ TRANSFORMER = "transformer"
38
50
  UNKNOWN = "unknown"
39
51
 
40
52
 
@@ -87,11 +99,10 @@ class ModelCore(Artifact):
87
99
  ```
88
100
  """
89
101
 
90
- def __init__(self, model_name: str, model_type: ModelType = None, **kwargs):
102
+ def __init__(self, model_name: str, **kwargs):
91
103
  """ModelCore Initialization
92
104
  Args:
93
105
  model_name (str): Name of Model in Workbench.
94
- model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
95
106
  **kwargs: Additional keyword arguments
96
107
  """
97
108
 
@@ -125,10 +136,8 @@ class ModelCore(Artifact):
125
136
  self.latest_model = self.model_meta["ModelPackageList"][0]
126
137
  self.description = self.latest_model.get("ModelPackageDescription", "-")
127
138
  self.training_job_name = self._extract_training_job_name()
128
- if model_type:
129
- self._set_model_type(model_type)
130
- else:
131
- self.model_type = self._get_model_type()
139
+ self.model_type = self._get_model_type()
140
+ self.model_framework = self._get_model_framework()
132
141
  except (IndexError, KeyError):
133
142
  self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
134
143
  return
@@ -972,6 +981,27 @@ class ModelCore(Artifact):
972
981
  self.log.warning(f"Could not determine model type for {self.model_name}!")
973
982
  return ModelType.UNKNOWN
974
983
 
984
+ def _set_model_framework(self, model_framework: ModelFramework):
985
+ """Internal: Set the Model Framework for this Model"""
986
+ self.model_framework = model_framework
987
+ self.upsert_workbench_meta({"workbench_model_framework": self.model_framework.value})
988
+ self.remove_health_tag("model_framework_unknown")
989
+
990
+ def _get_model_framework(self) -> ModelFramework:
991
+ """Internal: Query the Workbench Metadata to get the model framework
992
+ Returns:
993
+ ModelFramework: The ModelFramework of this Model
994
+ Notes:
995
+ This is an internal method that should not be called directly
996
+ Use the model_framework attribute instead
997
+ """
998
+ model_framework = self.workbench_meta().get("workbench_model_framework")
999
+ try:
1000
+ return ModelFramework(model_framework)
1001
+ except ValueError:
1002
+ self.log.warning(f"Could not determine model framework for {self.model_name}!")
1003
+ return ModelFramework.UNKNOWN
1004
+
975
1005
  def _load_training_metrics(self):
976
1006
  """Internal: Retrieve the training metrics and Confusion Matrix for this model
977
1007
  and load the data into the Workbench Metadata
@@ -4,6 +4,7 @@ from typing import Union
4
4
  import logging
5
5
  import json
6
6
  import zlib
7
+ import time
7
8
  import base64
8
9
  from botocore.exceptions import ClientError
9
10
 
@@ -77,7 +78,7 @@ class AWSParameterStore:
77
78
  all_parameters = []
78
79
 
79
80
  # Make the initial call to describe parameters
80
- response = self.ssm_client.describe_parameters(**params)
81
+ response = self._call_with_retry(self.ssm_client.describe_parameters, **params)
81
82
 
82
83
  # Aggregate the names from the initial response
83
84
  all_parameters.extend(param["Name"] for param in response["Parameters"])
@@ -86,7 +87,7 @@ class AWSParameterStore:
86
87
  while "NextToken" in response:
87
88
  # Update the parameters with the NextToken for subsequent calls
88
89
  params["NextToken"] = response["NextToken"]
89
- response = self.ssm_client.describe_parameters(**params)
90
+ response = self._call_with_retry(self.ssm_client.describe_parameters, **params)
90
91
 
91
92
  # Aggregate the names from the subsequent responses
92
93
  all_parameters.extend(param["Name"] for param in response["Parameters"])
@@ -183,6 +184,21 @@ class AWSParameterStore:
183
184
  self.log.critical(f"Failed to add/update parameter '{name}': {e}")
184
185
  raise
185
186
 
187
+ def _call_with_retry(self, func, **kwargs):
188
+ """Call AWS API with exponential backoff on throttling."""
189
+ max_retries = 5
190
+ base_delay = 1
191
+ for attempt in range(max_retries):
192
+ try:
193
+ return func(**kwargs)
194
+ except ClientError as e:
195
+ if e.response["Error"]["Code"] == "ThrottlingException" and attempt < max_retries - 1:
196
+ delay = base_delay * (2**attempt)
197
+ self.log.warning(f"Throttled, retrying in {delay}s...")
198
+ time.sleep(delay)
199
+ else:
200
+ raise
201
+
186
202
  @staticmethod
187
203
  def _compress_value(value) -> str:
188
204
  """Compress a value with precision reduction."""