workbench 0.8.198__py3-none-any.whl → 0.8.201__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry; it is provided for informational purposes only.
- workbench/algorithms/dataframe/proximity.py +11 -4
- workbench/api/__init__.py +2 -1
- workbench/api/feature_set.py +7 -4
- workbench/api/model.py +1 -1
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/endpoint_core.py +84 -46
- workbench/core/artifacts/feature_set_core.py +69 -1
- workbench/core/artifacts/model_core.py +37 -7
- workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
- workbench/core/transforms/features_to_model/features_to_model.py +23 -20
- workbench/core/views/view.py +2 -2
- workbench/model_scripts/chemprop/chemprop.template +931 -0
- workbench/model_scripts/chemprop/generated_model_script.py +931 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/proximity.py +11 -4
- workbench/model_scripts/custom_models/uq_models/proximity.py +11 -4
- workbench/model_scripts/pytorch_model/generated_model_script.py +130 -88
- workbench/model_scripts/pytorch_model/pytorch.template +128 -86
- workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
- workbench/model_scripts/script_generation.py +10 -7
- workbench/model_scripts/uq_models/generated_model_script.py +25 -18
- workbench/model_scripts/uq_models/mapie.template +23 -16
- workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
- workbench/model_scripts/xgb_model/xgb_model.template +2 -2
- workbench/repl/workbench_shell.py +14 -5
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
- workbench/utils/chemprop_utils.py +724 -0
- workbench/utils/pytorch_utils.py +497 -0
- workbench/utils/xgboost_model_utils.py +10 -5
- {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/METADATA +2 -2
- {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/RECORD +38 -32
- {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/entry_points.txt +2 -1
- workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
- workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/WHEEL +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.198.dist-info → workbench-0.8.201.dist-info}/top_level.txt +0 -0
workbench/algorithms/dataframe/proximity.py
CHANGED
@@ -69,6 +69,7 @@ class Proximity:
         top_percent: float = 1.0,
         min_delta: Optional[float] = None,
         k_neighbors: int = 4,
+        only_coincident: bool = False,
     ) -> pd.DataFrame:
         """
         Find compounds with steep target gradients (data quality issues and activity cliffs).
@@ -81,6 +82,7 @@ class Proximity:
             top_percent: Percentage of compounds with steepest gradients to return (e.g., 1.0 = top 1%)
             min_delta: Minimum absolute target difference to consider. If None, defaults to target_range/100
             k_neighbors: Number of neighbors to use for median calculation (default: 4)
+            only_coincident: If True, only consider compounds that are coincident (default: False)

         Returns:
             DataFrame of compounds with steepest gradients, sorted by gradient (descending)
@@ -99,10 +101,15 @@ class Proximity:
             min_delta = self.target_range / 100.0 if self.target_range > 0 else 0.0
         candidates = candidates[candidates["nn_target_diff"] >= min_delta]

-        # Get top X% by initial gradient
-        percentile = 100 - top_percent
-        threshold = np.percentile(candidates["gradient"], percentile)
-        candidates = candidates[candidates["gradient"] >= threshold].copy()
+        # Filter based on mode
+        if only_coincident:
+            # Only keep coincident points (nn_distance ~= 0)
+            candidates = candidates[candidates["nn_distance"] < epsilon].copy()
+        else:
+            # Get top X% by initial gradient
+            percentile = 100 - top_percent
+            threshold = np.percentile(candidates["gradient"], percentile)
+            candidates = candidates[candidates["gradient"] >= threshold].copy()

         # Phase 2: Verify with k-neighbor median to filter out cases where nearest neighbor is the outlier
         results = []
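A minimal usage sketch for the new flag. The enclosing method's name and the Proximity constructor signature sit outside the hunk, so `target_gradients`, the constructor arguments, and the column names below are illustrative assumptions, not confirmed API:

    from workbench.algorithms.dataframe.proximity import Proximity

    # Hypothetical setup: `df` holds features plus a target column;
    # constructor args are illustrative, not the confirmed signature
    prox = Proximity(df, id_column="compound_id", features=feature_cols, target="solubility")

    # Default mode: keep the steepest top_percent of gradients (activity cliffs)
    cliffs = prox.target_gradients(top_percent=1.0, k_neighbors=4)  # method name assumed

    # New mode: keep only coincident points (nn_distance ~= 0), i.e. near-identical
    # feature vectors whose targets still differ by at least min_delta
    suspects = prox.target_gradients(only_coincident=True)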
workbench/api/__init__.py
CHANGED
@@ -14,7 +14,7 @@ These class provide high-level APIs for the Workbench package, offering easy acc

 from .data_source import DataSource
 from .feature_set import FeatureSet
-from .model import Model, ModelType
+from .model import Model, ModelType, ModelFramework
 from .endpoint import Endpoint
 from .meta import Meta
 from .parameter_store import ParameterStore
@@ -25,6 +25,7 @@ __all__ = [
     "FeatureSet",
     "Model",
     "ModelType",
+    "ModelFramework",
     "Endpoint",
     "Meta",
     "ParameterStore",
workbench/api/feature_set.py
CHANGED
@@ -12,7 +12,7 @@ import pandas as pd
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.feature_set_core import FeatureSetCore
 from workbench.core.transforms.features_to_model.features_to_model import FeaturesToModel
-from workbench.api.model import Model, ModelType
+from workbench.api.model import Model, ModelType, ModelFramework


 class FeatureSet(FeatureSetCore):
@@ -79,6 +79,7 @@ class FeatureSet(FeatureSetCore):
         self,
         name: str,
         model_type: ModelType,
+        model_framework: ModelFramework = ModelFramework.XGBOOST,
         tags: list = None,
         description: str = None,
         feature_list: list = None,
@@ -98,11 +99,12 @@ class FeatureSet(FeatureSetCore):

             name (str): The name of the Model to create
             model_type (ModelType): The type of model to create (See workbench.model.ModelType)
+            model_framework (ModelFramework, optional): The framework to use for the model (default: XGBOOST)
             tags (list, optional): Set the tags for the model. If not given tags will be generated.
             description (str, optional): Set the description for the model. If not give a description is generated.
             feature_list (list, optional): Set the feature list for the model. If not given a feature list is generated.
             target_column (str, optional): The target column for the model (use None for unsupervised model)
-            model_class (str, optional): Model class to use (e.g. "KMeans",
+            model_class (str, optional): Model class to use (e.g. "KMeans", default: None)
             model_import_str (str, optional): The import for the model (e.g. "from sklearn.cluster import KMeans")
             custom_script (str, optional): The custom script to use for the model (default: None)
             training_image (str, optional): The training image to use (default: "training")
@@ -128,8 +130,8 @@ class FeatureSet(FeatureSetCore):
         # Create the Model Tags
         tags = [name] if tags is None else tags

-        # If the …
-        if …
+        # If the model framework is PyTorch or ChemProp, ensure we set the training and inference images
+        if model_framework in (ModelFramework.PYTORCH_TABULAR, ModelFramework.CHEMPROP):
             training_image = "pytorch_training"
             inference_image = "pytorch_inference"

@@ -138,6 +140,7 @@ class FeatureSet(FeatureSetCore):
             feature_name=self.name,
             model_name=name,
             model_type=model_type,
+            model_framework=model_framework,
             model_class=model_class,
             model_import_str=model_import_str,
             custom_script=custom_script,
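The new parameter threads through from the public API. A short sketch of creating a framework-specific model; it assumes the method shown above is `FeatureSet.to_model` (consistent with its docstring) and uses illustrative FeatureSet, model, and column names:

    from workbench.api import FeatureSet, ModelType, ModelFramework

    fs = FeatureSet("aqsol_features")  # illustrative FeatureSet name

    # model_framework defaults to ModelFramework.XGBOOST; PYTORCH_TABULAR and
    # CHEMPROP automatically switch to the pytorch training/inference images
    model = fs.to_model(
        name="aqsol-chemprop",
        model_type=ModelType.REGRESSOR,
        model_framework=ModelFramework.CHEMPROP,
        target_column="solubility",
    )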
workbench/api/model.py
CHANGED
@@ -7,7 +7,7 @@ Dashboard UI, which provides additional model details and performance metrics

 # Workbench Imports
 from workbench.core.artifacts.artifact import Artifact
-from workbench.core.artifacts.model_core import ModelCore, ModelType  # noqa: F401
+from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework  # noqa: F401
 from workbench.core.transforms.model_to_endpoint.model_to_endpoint import ModelToEndpoint
 from workbench.api.endpoint import Endpoint
 from workbench.utils.model_utils import proximity_model_local, uq_model
workbench/core/artifacts/__init__.py
CHANGED
@@ -15,7 +15,16 @@ from .artifact import Artifact
 from .athena_source import AthenaSource
 from .data_source_abstract import DataSourceAbstract
 from .feature_set_core import FeatureSetCore
-from .model_core import ModelCore, ModelType
+from .model_core import ModelCore, ModelType, ModelFramework
 from .endpoint_core import EndpointCore

-__all__ = [ …
+__all__ = [
+    "Artifact",
+    "AthenaSource",
+    "DataSourceAbstract",
+    "FeatureSetCore",
+    "ModelCore",
+    "ModelType",
+    "ModelFramework",
+    "EndpointCore",
+]
workbench/core/artifacts/endpoint_core.py
CHANGED
@@ -30,12 +30,14 @@ from sagemaker import Predictor

 # Workbench Imports
 from workbench.core.artifacts.artifact import Artifact
-from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
+from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType, ModelFramework
 from workbench.utils.endpoint_metrics import EndpointMetrics
 from workbench.utils.cache import Cache
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.model_utils import uq_metrics
-from workbench.utils.xgboost_model_utils import cross_fold_inference
+from workbench.utils.xgboost_model_utils import cross_fold_inference as xgboost_cross_fold
+from workbench.utils.pytorch_utils import cross_fold_inference as pytorch_cross_fold
+from workbench.utils.chemprop_utils import cross_fold_inference as chemprop_cross_fold
 from workbench_bridges.endpoints.fast_inference import fast_inference


@@ -399,41 +401,40 @@ class EndpointCore(Artifact):
         if target_column and (target_column not in prediction_df.columns):
             self.log.important(f"Target Column {target_column} not found in prediction_df!")
             self.log.important("In order to compute metrics, the target column must be present!")
-            …
+            metrics = pd.DataFrame()

         # Compute the standard performance metrics for this model
-        model_type = model.model_type
-        if model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-            prediction_df = self.residuals(target_column, prediction_df)
-            metrics = self.regression_metrics(target_column, prediction_df)
-        elif model_type == ModelType.CLASSIFIER:
-            metrics = self.classification_metrics(target_column, prediction_df)
         else:
-            … (3 lines not shown)
+            if model.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
+                prediction_df = self.residuals(target_column, prediction_df)
+                metrics = self.regression_metrics(target_column, prediction_df)
+            elif model.model_type == ModelType.CLASSIFIER:
+                metrics = self.classification_metrics(target_column, prediction_df)
+            else:
+                # For other model types, we don't compute metrics
+                self.log.info(f"Model Type: {model.model_type} doesn't have metrics...")
+                metrics = pd.DataFrame()

         # Print out the metrics
-        … (19 lines not shown)
-        self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
+        print(f"Performance Metrics for {self.model_name} on {self.name}")
+        print(metrics.head())
+
+        # Capture the inference results and metrics
+        if capture_name is not None:
+
+            # If we don't have an id_column, we'll pull it from the model's FeatureSet
+            if id_column is None:
+                fs = FeatureSetCore(model.get_input())
+                id_column = fs.id_column
+            description = capture_name.replace("_", " ").title()
+            self._capture_inference_results(
+                capture_name, prediction_df, target_column, model.model_type, metrics, description, features, id_column
+            )
+
+        # For UQ Models we also capture the uncertainty metrics
+        if model.model_type in [ModelType.UQ_REGRESSOR]:
+            metrics = uq_metrics(prediction_df, target_column)
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)

         # Return the prediction DataFrame
         return prediction_df
@@ -452,7 +453,15 @@ class EndpointCore(Artifact):
         model = ModelCore(self.model_name)

         # Compute CrossFold (Metrics and Prediction Dataframe)
-        cross_fold_metrics, out_of_fold_df = cross_fold_inference(model, nfolds=nfolds)
+        if model.model_framework in [ModelFramework.UNKNOWN, ModelFramework.XGBOOST]:
+            cross_fold_metrics, out_of_fold_df = xgboost_cross_fold(model, nfolds=nfolds)
+        elif model.model_framework == ModelFramework.PYTORCH_TABULAR:
+            cross_fold_metrics, out_of_fold_df = pytorch_cross_fold(model, nfolds=nfolds)
+        elif model.model_framework == ModelFramework.CHEMPROP:
+            cross_fold_metrics, out_of_fold_df = chemprop_cross_fold(model, nfolds=nfolds)
+        else:
+            self.log.error(f"Cross-Fold Inference not supported for Model Framework: {model.model_framework}.")
+            return pd.DataFrame()

         # If the metrics dataframe isn't empty save to the param store
         if not cross_fold_metrics.empty:
@@ -460,6 +469,11 @@ class EndpointCore(Artifact):
             metrics = cross_fold_metrics.to_dict(orient="records")
             self.param_store.upsert(f"/workbench/models/{model.name}/inference/cross_fold", metrics)

+        # If the out_of_fold_df is empty return it
+        if out_of_fold_df.empty:
+            self.log.warning("No out-of-fold predictions were made. Returning empty DataFrame.")
+            return out_of_fold_df
+
         # Capture the results
         capture_name = "full_cross_fold"
         description = capture_name.replace("_", " ").title()
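The dispatch above selects a framework-specific `cross_fold_inference` implementation, with `UNKNOWN` treated the same as XGBoost (covering models created before the framework metadata existed). A caller's-eye sketch; the public method name wrapping this dispatch is outside the hunk and assumed here, and the endpoint name is illustrative:

    from workbench.api import Endpoint

    end = Endpoint("aqsol-chemprop")  # illustrative endpoint name

    # Assumed public entry point for the dispatch above; returns the
    # out-of-fold predictions DataFrame (empty if the framework is unsupported)
    out_of_fold_df = end.cross_fold_inference(nfolds=5)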
@@ -765,20 +779,18 @@ class EndpointCore(Artifact):
         self.log.info(f"Writing metrics to {inference_capture_path}/inference_metrics.csv")
         wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)

-        # Grab the …
-        output_columns = [ …
-        …
+        # Grab the ID column and target column if they are present
+        output_columns = []
+        if id_column and id_column in pred_results_df.columns:
+            output_columns.append(id_column)
+        if target_column in pred_results_df.columns:
+            output_columns.append(target_column)

-        # …
+        # Grab the prediction column, any _proba columns, and UQ columns
+        output_columns += [col for col in pred_results_df.columns if "prediction" in col]
         output_columns += [col for col in pred_results_df.columns if col.endswith("_proba")]
-
-        # Add any Uncertainty Quantile columns to the output columns
         output_columns += [col for col in pred_results_df.columns if col.startswith("q_") or col == "confidence"]

-        # Add the ID column
-        if id_column and id_column in pred_results_df.columns:
-            output_columns.insert(0, id_column)
-
         # Write the predictions to our S3 Model Inference Folder
         self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
         subset_df = pred_results_df[output_columns]
@@ -810,10 +822,23 @@ class EndpointCore(Artifact):
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return pd.DataFrame()

+        # Check for NaN values in target or prediction columns
+        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
+        if prediction_df[target_column].isnull().any() or prediction_df[prediction_col].isnull().any():
+            # Compute the number of NaN values in each column
+            num_nan_target = prediction_df[target_column].isnull().sum()
+            num_nan_prediction = prediction_df[prediction_col].isnull().sum()
+            self.log.warning(
+                f"NaNs Found: {target_column} {num_nan_target} and {prediction_col}: {num_nan_prediction}."
+            )
+            self.log.warning(
+                "NaN values found in target or prediction columns. Dropping NaN rows for metric computation."
+            )
+            prediction_df = prediction_df.dropna(subset=[target_column, prediction_col])
+
         # Compute the metrics
         try:
             y_true = prediction_df[target_column]
-            prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
             y_pred = prediction_df[prediction_col]

             mae = mean_absolute_error(y_true, y_pred)
@@ -891,6 +916,14 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with the performance metrics
         """
+        # Drop rows with NaN predictions (can't compute metrics on missing predictions)
+        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
+        nan_mask = prediction_df[prediction_col].isna()
+        if nan_mask.any():
+            n_nan = nan_mask.sum()
+            self.log.warning(f"Dropping {n_nan} rows with NaN predictions for metrics calculation")
+            prediction_df = prediction_df[~nan_mask].copy()
+
         # Get the class labels from the model
         class_labels = ModelCore(self.model_name).class_labels()
         if class_labels is None:
@@ -903,7 +936,6 @@ class EndpointCore(Artifact):
         self.validate_proba_columns(prediction_df, class_labels)

         # Calculate precision, recall, f1, and support, handling zero division
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
         scores = precision_recall_fscore_support(
             prediction_df[target_column],
             prediction_df[prediction_col],
@@ -954,9 +986,15 @@ class EndpointCore(Artifact):
         Returns:
             pd.DataFrame: DataFrame with the confusion matrix
         """
+        # Drop rows with NaN predictions (can't include in confusion matrix)
+        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
+        nan_mask = prediction_df[prediction_col].isna()
+        if nan_mask.any():
+            n_nan = nan_mask.sum()
+            self.log.warning(f"Dropping {n_nan} rows with NaN predictions for confusion matrix")
+            prediction_df = prediction_df[~nan_mask].copy()

         y_true = prediction_df[target_column]
-        prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
         y_pred = prediction_df[prediction_col]

         # Get model class labels
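All three metric paths above share the same defensive pattern: tolerate either prediction column name and drop NaN rows before scoring. A self-contained illustration of that pattern (column values are made-up sample data):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame(
        {"solubility": [1.0, np.nan, 3.0, 4.0], "prediction": [1.1, 2.0, np.nan, 3.9]}
    )

    # Tolerate either column name ("prediction" or "predictions")
    prediction_col = "prediction" if "prediction" in df.columns else "predictions"

    # Drop rows where either the target or the prediction is NaN
    clean = df.dropna(subset=["solubility", prediction_col])
    assert len(clean) == 2  # only fully populated rows are scored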
workbench/core/artifacts/feature_set_core.py
CHANGED
@@ -16,8 +16,9 @@ from sagemaker.feature_store.feature_store import FeatureStore
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.data_source_factory import DataSourceFactory
 from workbench.core.artifacts.athena_source import AthenaSource
+from workbench.utils.deprecated_utils import deprecated

-from typing import TYPE_CHECKING, Optional, List, Union
+from typing import TYPE_CHECKING, Optional, List, Dict, Union

 from workbench.utils.aws_utils import aws_throttle

@@ -509,6 +510,71 @@ class FeatureSetCore(Artifact):
         ].tolist()
         return hold_out_ids

+    def set_sample_weights(
+        self,
+        weight_dict: Dict[Union[str, int], float],
+        default_weight: float = 1.0,
+        exclude_zero_weights: bool = True,
+    ):
+        """Configure training view with sample weights for each ID.
+
+        Args:
+            weight_dict: Mapping of ID to sample weight
+                - weight > 1.0: oversample/emphasize
+                - weight = 1.0: normal (default)
+                - 0 < weight < 1.0: downweight/de-emphasize
+                - weight = 0.0: exclude from training
+            default_weight: Weight for IDs not in weight_dict (default: 1.0)
+            exclude_zero_weights: If True, filter out rows with sample_weight=0 (default: True)
+
+        Example:
+            weights = {
+                'compound_42': 3.0,   # oversample 3x
+                'compound_99': 0.1,   # noisy, downweight
+                'compound_123': 0.0,  # exclude from training
+            }
+            model.set_sample_weights(weights)  # zeros automatically excluded
+            model.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
+        """
+        from workbench.core.views import TrainingView
+
+        if not weight_dict:
+            self.log.important("Empty weight_dict, creating standard training view")
+            TrainingView.create(self, id_column=self.id_column)
+            return
+
+        self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
+
+        # Helper to format IDs for SQL
+        def format_id(id_val):
+            return repr(id_val)
+
+        # Build CASE statement for sample_weight
+        case_conditions = [
+            f"WHEN {self.id_column} = {format_id(id_val)} THEN {weight}" for id_val, weight in weight_dict.items()
+        ]
+        case_statement = "\n ".join(case_conditions)
+
+        # Build inner query with sample weights
+        inner_sql = f"""SELECT
+            *,
+            CASE
+            {case_statement}
+            ELSE {default_weight}
+            END AS sample_weight
+        FROM {self.table}"""
+
+        # Optionally filter out zero weights
+        if exclude_zero_weights:
+            zero_count = sum(1 for weight in weight_dict.values() if weight == 0.0)
+            custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
+            self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+        else:
+            custom_sql = inner_sql
+
+        TrainingView.create_with_sql(self, sql_query=custom_sql, id_column=self.id_column)
+
+    @deprecated(version=0.9)
     def set_training_filter(self, filter_expression: Optional[str] = None):
         """Set a filter expression for the training view for this FeatureSet

@@ -528,6 +594,7 @@ class FeatureSetCore(Artifact):
             self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
         )

+    @deprecated(version="0.9")
     def exclude_ids_from_training(self, ids: List[Union[str, int]], column_name: Optional[str] = None):
         """Exclude a list of IDs from the training view

@@ -551,6 +618,7 @@ class FeatureSetCore(Artifact):
         # Apply the filter
         self.set_training_filter(filter_expression)

+    @deprecated(version="0.9")
     def set_training_sampling(
         self,
         exclude_ids: Optional[List[Union[str, int]]] = None,
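To make the generated training view concrete, here is the CASE-statement construction from `set_sample_weights` replayed standalone; the id column and table name are illustrative stand-ins for `self.id_column` and `self.table`:

    weight_dict = {"compound_42": 3.0, "compound_99": 0.1, "compound_123": 0.0}
    id_column, table, default_weight = "compound_id", "aqsol_features", 1.0  # illustrative

    # Mirrors the method body above: one WHEN clause per weighted ID
    case_conditions = [f"WHEN {id_column} = {id_val!r} THEN {w}" for id_val, w in weight_dict.items()]
    case_statement = "\n            ".join(case_conditions)

    inner_sql = f"""SELECT
                *,
                CASE
                {case_statement}
                ELSE {default_weight}
                END AS sample_weight
            FROM {table}"""

    # exclude_zero_weights=True wraps the query so compound_123 never reaches training
    custom_sql = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
    print(custom_sql)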
workbench/core/artifacts/model_core.py
CHANGED
@@ -30,11 +30,23 @@ class ModelType(Enum):
     CLASSIFIER = "classifier"
     REGRESSOR = "regressor"
     CLUSTERER = "clusterer"
-    TRANSFORMER = "transformer"
     PROXIMITY = "proximity"
     PROJECTION = "projection"
     UQ_REGRESSOR = "uq_regressor"
     ENSEMBLE_REGRESSOR = "ensemble_regressor"
+    TRANSFORMER = "transformer"
+    UNKNOWN = "unknown"
+
+
+class ModelFramework(Enum):
+    """Enumerated Types for Workbench Model Frameworks"""
+
+    SKLEARN = "sklearn"
+    XGBOOST = "xgboost"
+    LIGHTGBM = "lightgbm"
+    PYTORCH_TABULAR = "pytorch_tabular"
+    CHEMPROP = "chemprop"
+    TRANSFORMER = "transformer"
     UNKNOWN = "unknown"


@@ -87,11 +99,10 @@ class ModelCore(Artifact):
     ```
     """

-    def __init__(self, model_name: str, …
+    def __init__(self, model_name: str, **kwargs):
         """ModelCore Initialization
         Args:
             model_name (str): Name of Model in Workbench.
-            model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
             **kwargs: Additional keyword arguments
         """

@@ -125,10 +136,8 @@ class ModelCore(Artifact):
             self.latest_model = self.model_meta["ModelPackageList"][0]
             self.description = self.latest_model.get("ModelPackageDescription", "-")
             self.training_job_name = self._extract_training_job_name()
-            … (2 lines not shown)
-            else:
-                self.model_type = self._get_model_type()
+            self.model_type = self._get_model_type()
+            self.model_framework = self._get_model_framework()
         except (IndexError, KeyError):
             self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
             return
@@ -972,6 +981,27 @@ class ModelCore(Artifact):
         self.log.warning(f"Could not determine model type for {self.model_name}!")
         return ModelType.UNKNOWN

+    def _set_model_framework(self, model_framework: ModelFramework):
+        """Internal: Set the Model Framework for this Model"""
+        self.model_framework = model_framework
+        self.upsert_workbench_meta({"workbench_model_framework": self.model_framework.value})
+        self.remove_health_tag("model_framework_unknown")
+
+    def _get_model_framework(self) -> ModelFramework:
+        """Internal: Query the Workbench Metadata to get the model framework
+        Returns:
+            ModelFramework: The ModelFramework of this Model
+        Notes:
+            This is an internal method that should not be called directly
+            Use the model_framework attribute instead
+        """
+        model_framework = self.workbench_meta().get("workbench_model_framework")
+        try:
+            return ModelFramework(model_framework)
+        except ValueError:
+            self.log.warning(f"Could not determine model framework for {self.model_name}!")
+            return ModelFramework.UNKNOWN
+
     def _load_training_metrics(self):
         """Internal: Retrieve the training metrics and Confusion Matrix for this model
         and load the data into the Workbench Metadata
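`_get_model_framework` relies on Enum's ValueError fallback: metadata written by an older workbench version (a missing key yields `None`) or an unrecognized string both map to `UNKNOWN`, which the cross-fold dispatch in endpoint_core.py then treats as XGBoost. A tiny standalone demonstration of the pattern (trimmed enum for brevity):

    from enum import Enum

    class ModelFramework(Enum):
        XGBOOST = "xgboost"
        UNKNOWN = "unknown"

    def framework_from_meta(value) -> ModelFramework:
        # Same shape as _get_model_framework above: None or an
        # unrecognized string raises ValueError -> fall back to UNKNOWN
        try:
            return ModelFramework(value)
        except ValueError:
            return ModelFramework.UNKNOWN

    assert framework_from_meta("xgboost") is ModelFramework.XGBOOST
    assert framework_from_meta(None) is ModelFramework.UNKNOWN
    assert framework_from_meta("caffe") is ModelFramework.UNKNOWN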
workbench/core/cloud_platform/aws/aws_parameter_store.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Union
 import logging
 import json
 import zlib
+import time
 import base64
 from botocore.exceptions import ClientError

@@ -77,7 +78,7 @@ class AWSParameterStore:
         all_parameters = []

         # Make the initial call to describe parameters
-        response = self.ssm_client.describe_parameters(**params)
+        response = self._call_with_retry(self.ssm_client.describe_parameters, **params)

         # Aggregate the names from the initial response
         all_parameters.extend(param["Name"] for param in response["Parameters"])
@@ -86,7 +87,7 @@ class AWSParameterStore:
         while "NextToken" in response:
             # Update the parameters with the NextToken for subsequent calls
             params["NextToken"] = response["NextToken"]
-            response = self.ssm_client.describe_parameters(**params)
+            response = self._call_with_retry(self.ssm_client.describe_parameters, **params)

             # Aggregate the names from the subsequent responses
             all_parameters.extend(param["Name"] for param in response["Parameters"])
@@ -183,6 +184,21 @@ class AWSParameterStore:
             self.log.critical(f"Failed to add/update parameter '{name}': {e}")
             raise

+    def _call_with_retry(self, func, **kwargs):
+        """Call AWS API with exponential backoff on throttling."""
+        max_retries = 5
+        base_delay = 1
+        for attempt in range(max_retries):
+            try:
+                return func(**kwargs)
+            except ClientError as e:
+                if e.response["Error"]["Code"] == "ThrottlingException" and attempt < max_retries - 1:
+                    delay = base_delay * (2**attempt)
+                    self.log.warning(f"Throttled, retrying in {delay}s...")
+                    time.sleep(delay)
+                else:
+                    raise
+
     @staticmethod
     def _compress_value(value) -> str:
         """Compress a value with precision reduction."""