workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench has been flagged as potentially problematic; consult the registry's advisory page for this release for more details.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
|
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
|
|
|
21
21
|
from workbench.utils.s3_utils import compute_s3_object_hash
|
|
22
22
|
from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
|
|
23
23
|
from workbench.utils.deprecated_utils import deprecated
|
|
24
|
+
from workbench.utils.model_utils import published_proximity_model, get_model_hyperparameters
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class ModelType(Enum):
|
|
@@ -29,92 +30,64 @@ class ModelType(Enum):
|
|
|
29
30
|
CLASSIFIER = "classifier"
|
|
30
31
|
REGRESSOR = "regressor"
|
|
31
32
|
CLUSTERER = "clusterer"
|
|
32
|
-
TRANSFORMER = "transformer"
|
|
33
33
|
PROXIMITY = "proximity"
|
|
34
34
|
PROJECTION = "projection"
|
|
35
35
|
UQ_REGRESSOR = "uq_regressor"
|
|
36
36
|
ENSEMBLE_REGRESSOR = "ensemble_regressor"
|
|
37
|
+
TRANSFORMER = "transformer"
|
|
37
38
|
UNKNOWN = "unknown"
|
|
38
39
|
|
|
39
40
|
|
|
40
|
-
|
|
41
|
-
"""
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
# US West 2 images
|
|
52
|
-
("us-west-2", "training", "0.1", "x86_64"): (
|
|
53
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
|
|
54
|
-
),
|
|
55
|
-
("us-west-2", "inference", "0.1", "x86_64"): (
|
|
56
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
|
|
57
|
-
),
|
|
58
|
-
|
|
59
|
-
# ARM64 images
|
|
60
|
-
("us-east-1", "inference", "0.1", "arm64"): (
|
|
61
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
|
|
62
|
-
),
|
|
63
|
-
("us-west-2", "inference", "0.1", "arm64"): (
|
|
64
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
|
|
65
|
-
),
|
|
66
|
-
"""
|
|
41
|
+
class ModelFramework(Enum):
|
|
42
|
+
"""Enumerated Types for Workbench Model Frameworks"""
|
|
43
|
+
|
|
44
|
+
SKLEARN = "sklearn"
|
|
45
|
+
XGBOOST = "xgboost"
|
|
46
|
+
LIGHTGBM = "lightgbm"
|
|
47
|
+
PYTORCH = "pytorch"
|
|
48
|
+
CHEMPROP = "chemprop"
|
|
49
|
+
TRANSFORMER = "transformer"
|
|
50
|
+
META = "meta"
|
|
51
|
+
UNKNOWN = "unknown"
|
|
67
52
|
|
|
68
53
|
|
|
69
54
|
class ModelImages:
|
|
70
55
|
"""Class for retrieving workbench inference images"""
|
|
71
56
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
|
|
84
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
|
|
85
|
-
),
|
|
86
|
-
# US West 2 images
|
|
87
|
-
("us-west-2", "training", "0.1", "x86_64"): (
|
|
88
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
|
|
89
|
-
),
|
|
90
|
-
("us-west-2", "inference", "0.1", "x86_64"): (
|
|
91
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
|
|
92
|
-
),
|
|
93
|
-
("us-west-2", "pytorch_training", "0.1", "x86_64"): (
|
|
94
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
|
|
95
|
-
),
|
|
96
|
-
("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
|
|
97
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
|
|
98
|
-
),
|
|
99
|
-
# ARM64 images
|
|
100
|
-
# Meta Endpoint inference images
|
|
101
|
-
("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
|
|
102
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
|
|
103
|
-
),
|
|
104
|
-
("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
|
|
105
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
|
|
106
|
-
),
|
|
57
|
+
# Account ID
|
|
58
|
+
ACCOUNT_ID = "507740646243"
|
|
59
|
+
|
|
60
|
+
# Image name mappings
|
|
61
|
+
IMAGE_NAMES = {
|
|
62
|
+
"training": "py312-general-ml-training",
|
|
63
|
+
"inference": "py312-general-ml-inference",
|
|
64
|
+
"pytorch_training": "py312-pytorch-training",
|
|
65
|
+
"pytorch_inference": "py312-pytorch-inference",
|
|
66
|
+
"meta_training": "py312-meta-training",
|
|
67
|
+
"meta_inference": "py312-meta-inference",
|
|
107
68
|
}
|
|
108
69
|
|
|
109
70
|
@classmethod
|
|
110
|
-
def get_image_uri(cls, region, image_type, version="
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
)
|
|
71
|
+
def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
|
|
72
|
+
"""
|
|
73
|
+
Dynamically construct ECR image URI.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
region: AWS region (e.g., 'us-east-1', 'us-west-2')
|
|
77
|
+
image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
|
|
78
|
+
version: Image version (e.g., '0.1', '0.2' defaults to 'latest')
|
|
79
|
+
architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
ECR image URI string
|
|
83
|
+
"""
|
|
84
|
+
if image_type not in cls.IMAGE_NAMES:
|
|
85
|
+
raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
|
|
86
|
+
|
|
87
|
+
image_name = cls.IMAGE_NAMES[image_type]
|
|
88
|
+
uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
|
|
89
|
+
|
|
90
|
+
return uri
|
|
118
91
|
|
|
119
92
|
|
|
120
93
|
class ModelCore(Artifact):
|
|
@@ -128,11 +101,10 @@ class ModelCore(Artifact):
|
|
|
128
101
|
```
|
|
129
102
|
"""
|
|
130
103
|
|
|
131
|
-
def __init__(self, model_name: str,
|
|
104
|
+
def __init__(self, model_name: str, **kwargs):
|
|
132
105
|
"""ModelCore Initialization
|
|
133
106
|
Args:
|
|
134
107
|
model_name (str): Name of Model in Workbench.
|
|
135
|
-
model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
|
|
136
108
|
**kwargs: Additional keyword arguments
|
|
137
109
|
"""
|
|
138
110
|
|
|
@@ -166,10 +138,8 @@ class ModelCore(Artifact):
|
|
|
166
138
|
self.latest_model = self.model_meta["ModelPackageList"][0]
|
|
167
139
|
self.description = self.latest_model.get("ModelPackageDescription", "-")
|
|
168
140
|
self.training_job_name = self._extract_training_job_name()
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
else:
|
|
172
|
-
self.model_type = self._get_model_type()
|
|
141
|
+
self.model_type = self._get_model_type()
|
|
142
|
+
self.model_framework = self._get_model_framework()
|
|
173
143
|
except (IndexError, KeyError):
|
|
174
144
|
self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
|
|
175
145
|
return
|
|
@@ -295,21 +265,25 @@ class ModelCore(Artifact):
|
|
|
295
265
|
else:
|
|
296
266
|
self.log.important(f"No inference data found for {self.model_name}!")
|
|
297
267
|
|
|
298
|
-
def get_inference_metrics(self, capture_name: str = "
|
|
268
|
+
def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
|
|
299
269
|
"""Retrieve the inference performance metrics for this model
|
|
300
270
|
|
|
301
271
|
Args:
|
|
302
|
-
capture_name (str, optional): Specific capture_name
|
|
272
|
+
capture_name (str, optional): Specific capture_name (default: "auto")
|
|
303
273
|
Returns:
|
|
304
274
|
pd.DataFrame: DataFrame of the Model Metrics
|
|
305
275
|
|
|
306
276
|
Note:
|
|
307
|
-
If a capture_name isn't specified this will try to
|
|
277
|
+
If a capture_name isn't specified this will try to the 'first' available metrics
|
|
308
278
|
"""
|
|
309
279
|
# Try to get the auto_capture 'training_holdout' or the training
|
|
310
|
-
if capture_name == "
|
|
311
|
-
|
|
312
|
-
|
|
280
|
+
if capture_name == "auto":
|
|
281
|
+
metric_list = self.list_inference_runs()
|
|
282
|
+
if metric_list:
|
|
283
|
+
return self.get_inference_metrics(metric_list[0])
|
|
284
|
+
else:
|
|
285
|
+
self.log.warning(f"No performance metrics found for {self.model_name}!")
|
|
286
|
+
return None
|
|
313
287
|
|
|
314
288
|
# Grab the metrics captured during model training (could return None)
|
|
315
289
|
if capture_name == "model_training":
|
|
@@ -331,11 +305,11 @@ class ModelCore(Artifact):
|
|
|
331
305
|
self.log.warning(f"Performance metrics {capture_name} not found for {self.model_name}!")
|
|
332
306
|
return None
|
|
333
307
|
|
|
334
|
-
def confusion_matrix(self, capture_name: str = "
|
|
308
|
+
def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
|
|
335
309
|
"""Retrieve the confusion_matrix for this model
|
|
336
310
|
|
|
337
311
|
Args:
|
|
338
|
-
capture_name (str, optional): Specific capture_name or "training" (default: "
|
|
312
|
+
capture_name (str, optional): Specific capture_name or "training" (default: "auto")
|
|
339
313
|
Returns:
|
|
340
314
|
pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
|
|
341
315
|
"""
|
|
@@ -347,7 +321,7 @@ class ModelCore(Artifact):
|
|
|
347
321
|
raise ValueError(error_msg)
|
|
348
322
|
|
|
349
323
|
# Grab the metrics from the Workbench Metadata (try inference first, then training)
|
|
350
|
-
if capture_name == "
|
|
324
|
+
if capture_name == "auto":
|
|
351
325
|
cm = self.confusion_matrix("auto_inference")
|
|
352
326
|
return cm if cm is not None else self.confusion_matrix("model_training")
|
|
353
327
|
|
|
@@ -569,6 +543,17 @@ class ModelCore(Artifact):
|
|
|
569
543
|
else:
|
|
570
544
|
self.log.error(f"Model {self.model_name} is not a classifier!")
|
|
571
545
|
|
|
546
|
+
def summary(self) -> dict:
|
|
547
|
+
"""Summary information about this Model
|
|
548
|
+
|
|
549
|
+
Returns:
|
|
550
|
+
dict: Dictionary of summary information about this Model
|
|
551
|
+
"""
|
|
552
|
+
self.log.info("Computing Model Summary...")
|
|
553
|
+
summary = super().summary()
|
|
554
|
+
summary["hyperparameters"] = get_model_hyperparameters(self)
|
|
555
|
+
return summary
|
|
556
|
+
|
|
572
557
|
def details(self) -> dict:
|
|
573
558
|
"""Additional Details about this Model
|
|
574
559
|
|
|
@@ -593,6 +578,7 @@ class ModelCore(Artifact):
|
|
|
593
578
|
details["status"] = self.latest_model["ModelPackageStatus"]
|
|
594
579
|
details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
|
|
595
580
|
details["image"] = self.container_image().split("/")[-1] # Shorten the image uri
|
|
581
|
+
details["hyperparameters"] = get_model_hyperparameters(self)
|
|
596
582
|
|
|
597
583
|
# Grab the inference and container info
|
|
598
584
|
inference_spec = self.latest_model["InferenceSpecification"]
|
|
@@ -603,16 +589,6 @@ class ModelCore(Artifact):
|
|
|
603
589
|
details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
|
|
604
590
|
details["content_types"] = inference_spec["SupportedContentTypes"]
|
|
605
591
|
details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
|
|
606
|
-
details["model_metrics"] = self.get_inference_metrics()
|
|
607
|
-
if self.model_type == ModelType.CLASSIFIER:
|
|
608
|
-
details["confusion_matrix"] = self.confusion_matrix()
|
|
609
|
-
details["predictions"] = None
|
|
610
|
-
elif self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
|
|
611
|
-
details["confusion_matrix"] = None
|
|
612
|
-
details["predictions"] = self.get_inference_predictions()
|
|
613
|
-
else:
|
|
614
|
-
details["confusion_matrix"] = None
|
|
615
|
-
details["predictions"] = None
|
|
616
592
|
|
|
617
593
|
# Grab the inference metadata
|
|
618
594
|
details["inference_meta"] = self.get_inference_metadata()
|
|
@@ -620,6 +596,24 @@ class ModelCore(Artifact):
|
|
|
620
596
|
# Return the details
|
|
621
597
|
return details
|
|
622
598
|
|
|
599
|
+
# Training View for this model
|
|
600
|
+
def training_view(self):
|
|
601
|
+
"""Get the training view for this model"""
|
|
602
|
+
from workbench.core.artifacts.feature_set_core import FeatureSetCore
|
|
603
|
+
from workbench.core.views import View
|
|
604
|
+
|
|
605
|
+
# Grab our FeatureSet
|
|
606
|
+
fs = FeatureSetCore(self.get_input())
|
|
607
|
+
|
|
608
|
+
# See if we have a training view for this model
|
|
609
|
+
my_model_training_view = f"{self.name.replace('-', '_')}_training".lower()
|
|
610
|
+
view = View(fs, my_model_training_view, auto_create_view=False)
|
|
611
|
+
if view.exists():
|
|
612
|
+
return view
|
|
613
|
+
else:
|
|
614
|
+
self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
|
|
615
|
+
return fs.view("training")
|
|
616
|
+
|
|
623
617
|
# Pipeline for this model
|
|
624
618
|
def get_pipeline(self) -> str:
|
|
625
619
|
"""Get the pipeline for this model"""
|
|
@@ -883,7 +877,7 @@ class ModelCore(Artifact):
|
|
|
883
877
|
return self.df_store.get(f"/workbench/models/{self.name}/shap_data")
|
|
884
878
|
else:
|
|
885
879
|
# Loop over the SHAP data and return a dict of DataFrames
|
|
886
|
-
shap_dfs = self.df_store.
|
|
880
|
+
shap_dfs = self.df_store.list(f"/workbench/models/{self.name}/shap_data")
|
|
887
881
|
shap_data = {}
|
|
888
882
|
for df_location in shap_dfs:
|
|
889
883
|
key = df_location.split("/")[-1]
|
|
@@ -902,10 +896,24 @@ class ModelCore(Artifact):
|
|
|
902
896
|
except (KeyError, IndexError, TypeError):
|
|
903
897
|
return None
|
|
904
898
|
|
|
899
|
+
def publish_prox_model(self, prox_model_name: str = None, include_all_columns: bool = False):
|
|
900
|
+
"""Create and publish a Proximity Model for this Model
|
|
901
|
+
|
|
902
|
+
Args:
|
|
903
|
+
prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
|
|
904
|
+
include_all_columns (bool): Include all DataFrame columns in results (default: False)
|
|
905
|
+
|
|
906
|
+
Returns:
|
|
907
|
+
Model: The published Proximity Model
|
|
908
|
+
"""
|
|
909
|
+
if prox_model_name is None:
|
|
910
|
+
prox_model_name = self.model_name + "-prox"
|
|
911
|
+
return published_proximity_model(self, prox_model_name, include_all_columns=include_all_columns)
|
|
912
|
+
|
|
905
913
|
def delete(self):
|
|
906
914
|
"""Delete the Model Packages and the Model Group"""
|
|
907
915
|
if not self.exists():
|
|
908
|
-
self.log.warning(f"Trying to delete
|
|
916
|
+
self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")
|
|
909
917
|
|
|
910
918
|
# Call the Class Method to delete the Model Group
|
|
911
919
|
ModelCore.managed_delete(model_group_name=self.name)
|
|
@@ -981,6 +989,27 @@ class ModelCore(Artifact):
|
|
|
981
989
|
self.log.warning(f"Could not determine model type for {self.model_name}!")
|
|
982
990
|
return ModelType.UNKNOWN
|
|
983
991
|
|
|
992
|
+
def _set_model_framework(self, model_framework: ModelFramework):
|
|
993
|
+
"""Internal: Set the Model Framework for this Model"""
|
|
994
|
+
self.model_framework = model_framework
|
|
995
|
+
self.upsert_workbench_meta({"workbench_model_framework": self.model_framework.value})
|
|
996
|
+
self.remove_health_tag("model_framework_unknown")
|
|
997
|
+
|
|
998
|
+
def _get_model_framework(self) -> ModelFramework:
|
|
999
|
+
"""Internal: Query the Workbench Metadata to get the model framework
|
|
1000
|
+
Returns:
|
|
1001
|
+
ModelFramework: The ModelFramework of this Model
|
|
1002
|
+
Notes:
|
|
1003
|
+
This is an internal method that should not be called directly
|
|
1004
|
+
Use the model_framework attribute instead
|
|
1005
|
+
"""
|
|
1006
|
+
model_framework = self.workbench_meta().get("workbench_model_framework")
|
|
1007
|
+
try:
|
|
1008
|
+
return ModelFramework(model_framework)
|
|
1009
|
+
except ValueError:
|
|
1010
|
+
self.log.warning(f"Could not determine model framework for {self.model_name}!")
|
|
1011
|
+
return ModelFramework.UNKNOWN
|
|
1012
|
+
|
|
984
1013
|
def _load_training_metrics(self):
|
|
985
1014
|
"""Internal: Retrieve the training metrics and Confusion Matrix for this model
|
|
986
1015
|
and load the data into the Workbench Metadata
|
|
@@ -1172,13 +1201,11 @@ if __name__ == "__main__":
|
|
|
1172
1201
|
# Grab a ModelCore object and pull some information from it
|
|
1173
1202
|
my_model = ModelCore("abalone-regression")
|
|
1174
1203
|
|
|
1175
|
-
# Call the various methods
|
|
1176
|
-
|
|
1177
1204
|
# Let's do a check/validation of the Model
|
|
1178
1205
|
print(f"Model Check: {my_model.exists()}")
|
|
1179
1206
|
|
|
1180
1207
|
# Make sure the model is 'ready'
|
|
1181
|
-
|
|
1208
|
+
my_model.onboard()
|
|
1182
1209
|
|
|
1183
1210
|
# Get the ARN of the Model Group
|
|
1184
1211
|
print(f"Model Group ARN: {my_model.group_arn()}")
|
|
@@ -1244,5 +1271,10 @@ if __name__ == "__main__":
|
|
|
1244
1271
|
# Delete the Model
|
|
1245
1272
|
# ModelCore.managed_delete("wine-classification")
|
|
1246
1273
|
|
|
1274
|
+
# Check the training view logic
|
|
1275
|
+
model = ModelCore("wine-class-test-251112-BW")
|
|
1276
|
+
training_view = model.training_view()
|
|
1277
|
+
print(f"Training View Name: {training_view.name}")
|
|
1278
|
+
|
|
1247
1279
|
# Check for a model that doesn't exist
|
|
1248
1280
|
my_model = ModelCore("empty-model-group")
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""ParameterStoreCore: Manages Workbench parameters in a Cloud Based Parameter Store."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
# Workbench Imports
|
|
6
|
+
from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
|
|
7
|
+
|
|
8
|
+
# Workbench Bridges Import
|
|
9
|
+
from workbench_bridges.api import ParameterStore as BridgesParameterStore
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ParameterStoreCore(BridgesParameterStore):
|
|
13
|
+
"""ParameterStoreCore: Manages Workbench parameters in a Cloud Based Parameter Store.
|
|
14
|
+
|
|
15
|
+
Common Usage:
|
|
16
|
+
```python
|
|
17
|
+
params = ParameterStoreCore()
|
|
18
|
+
|
|
19
|
+
# List Parameters
|
|
20
|
+
params.list()
|
|
21
|
+
|
|
22
|
+
['/workbench/abalone_info',
|
|
23
|
+
'/workbench/my_data',
|
|
24
|
+
'/workbench/test',
|
|
25
|
+
'/workbench/pipelines/my_pipeline']
|
|
26
|
+
|
|
27
|
+
# Add Key
|
|
28
|
+
params.upsert("key", "value")
|
|
29
|
+
value = params.get("key")
|
|
30
|
+
|
|
31
|
+
# Add any data (lists, dictionaries, etc..)
|
|
32
|
+
my_data = {"key": "value", "number": 4.2, "list": [1,2,3]}
|
|
33
|
+
params.upsert("my_data", my_data)
|
|
34
|
+
|
|
35
|
+
# Retrieve data
|
|
36
|
+
return_value = params.get("my_data")
|
|
37
|
+
pprint(return_value)
|
|
38
|
+
|
|
39
|
+
{'key': 'value', 'list': [1, 2, 3], 'number': 4.2}
|
|
40
|
+
|
|
41
|
+
# Delete parameters
|
|
42
|
+
param_store.delete("my_data")
|
|
43
|
+
```
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(self):
|
|
47
|
+
"""ParameterStoreCore Init Method"""
|
|
48
|
+
session = AWSAccountClamp().boto3_session
|
|
49
|
+
|
|
50
|
+
# Initialize parent with workbench config
|
|
51
|
+
super().__init__(boto3_session=session)
|
|
52
|
+
self.log = logging.getLogger("workbench")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
if __name__ == "__main__":
|
|
56
|
+
"""Exercise the ParameterStoreCore Class"""
|
|
57
|
+
|
|
58
|
+
# Create a ParameterStoreCore manager
|
|
59
|
+
param_store = ParameterStoreCore()
|
|
60
|
+
|
|
61
|
+
# List the parameters
|
|
62
|
+
print("Listing Parameters...")
|
|
63
|
+
print(param_store.list())
|
|
64
|
+
|
|
65
|
+
# Add a new parameter
|
|
66
|
+
param_store.upsert("/workbench/test", "value")
|
|
67
|
+
|
|
68
|
+
# Get the parameter
|
|
69
|
+
print(f"Getting parameter 'test': {param_store.get('/workbench/test')}")
|
|
70
|
+
|
|
71
|
+
# Add a dictionary as a parameter
|
|
72
|
+
sample_dict = {"key": "str_value", "awesome_value": 4.2}
|
|
73
|
+
param_store.upsert("/workbench/my_data", sample_dict)
|
|
74
|
+
|
|
75
|
+
# Retrieve the parameter as a dictionary
|
|
76
|
+
retrieved_value = param_store.get("/workbench/my_data")
|
|
77
|
+
print("Retrieved value:", retrieved_value)
|
|
78
|
+
|
|
79
|
+
# List the parameters
|
|
80
|
+
print("Listing Parameters...")
|
|
81
|
+
print(param_store.list())
|
|
82
|
+
|
|
83
|
+
# List the parameters with a prefix
|
|
84
|
+
print("Listing Parameters with prefix '/workbench':")
|
|
85
|
+
print(param_store.list("/workbench"))
|
|
86
|
+
|
|
87
|
+
# Delete the parameters
|
|
88
|
+
param_store.delete("/workbench/test")
|
|
89
|
+
param_store.delete("/workbench/my_data")
|
|
90
|
+
|
|
91
|
+
# Out of scope tests
|
|
92
|
+
param_store.upsert("test", "value")
|
|
93
|
+
param_store.delete("test")
|
|
94
|
+
|
|
95
|
+
# Recursive delete test
|
|
96
|
+
param_store.upsert("/workbench/test/test1", "value1")
|
|
97
|
+
param_store.upsert("/workbench/test/test2", "value2")
|
|
98
|
+
param_store.delete_recursive("workbench/test/")
|
|
@@ -55,9 +55,10 @@ class AWSAccountClamp:
|
|
|
55
55
|
# Check our Assume Role
|
|
56
56
|
self.log.info("Checking Workbench Assumed Role...")
|
|
57
57
|
role_info = self.aws_session.assumed_role_info()
|
|
58
|
+
self.log.info(f"Assumed Role: {role_info}")
|
|
58
59
|
|
|
59
|
-
# Check if
|
|
60
|
-
self.
|
|
60
|
+
# Check if we have tag write permissions (if we don't, we are read-only)
|
|
61
|
+
self.read_only = not self.check_tag_permissions()
|
|
61
62
|
|
|
62
63
|
# Check our Workbench API Key and Load the License
|
|
63
64
|
self.log.info("Checking Workbench API License...")
|
|
@@ -141,6 +142,45 @@ class AWSAccountClamp:
|
|
|
141
142
|
"""
|
|
142
143
|
return self.boto3_session.client("sagemaker")
|
|
143
144
|
|
|
145
|
+
def check_tag_permissions(self):
|
|
146
|
+
"""Check if current role has permission to add tags to SageMaker endpoints.
|
|
147
|
+
|
|
148
|
+
Returns:
|
|
149
|
+
bool: True if AddTags is allowed, False otherwise
|
|
150
|
+
"""
|
|
151
|
+
try:
|
|
152
|
+
sagemaker = self.boto3_session.client("sagemaker")
|
|
153
|
+
|
|
154
|
+
# Use a non-existent endpoint name
|
|
155
|
+
fake_endpoint = "workbench-permission-check-dummy-endpoint"
|
|
156
|
+
|
|
157
|
+
# Try to add tags to the non-existent endpoint
|
|
158
|
+
sagemaker.add_tags(
|
|
159
|
+
ResourceArn=f"arn:aws:sagemaker:{self.region}:{self.account_id}:endpoint/{fake_endpoint}",
|
|
160
|
+
Tags=[{"Key": "PermissionCheck", "Value": "Test"}],
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
# If we get here, we have permission (but endpoint doesn't exist)
|
|
164
|
+
return True
|
|
165
|
+
|
|
166
|
+
except ClientError as e:
|
|
167
|
+
error_code = e.response["Error"]["Code"]
|
|
168
|
+
|
|
169
|
+
# AccessDeniedException = no permission
|
|
170
|
+
if error_code == "AccessDeniedException":
|
|
171
|
+
self.log.debug("No AddTags permission (AccessDeniedException)")
|
|
172
|
+
return False
|
|
173
|
+
|
|
174
|
+
# ResourceNotFound = we have permission, but endpoint doesn't exist
|
|
175
|
+
elif error_code in ["ResourceNotFound", "ValidationException"]:
|
|
176
|
+
self.log.debug("AddTags permission verified (resource not found)")
|
|
177
|
+
return True
|
|
178
|
+
|
|
179
|
+
# Unexpected error, assume no permission for safety
|
|
180
|
+
else:
|
|
181
|
+
self.log.debug(f"Unexpected error checking permissions: {error_code}")
|
|
182
|
+
return False
|
|
183
|
+
|
|
144
184
|
|
|
145
185
|
if __name__ == "__main__":
|
|
146
186
|
"""Exercise the AWS Account Clamp Class"""
|
|
@@ -165,3 +205,9 @@ if __name__ == "__main__":
|
|
|
165
205
|
print("\n\n*** AWS Sagemaker Session/Client Check ***")
|
|
166
206
|
sm_client = aws_account_clamp.sagemaker_client()
|
|
167
207
|
print(sm_client.list_feature_groups()["FeatureGroupSummaries"])
|
|
208
|
+
|
|
209
|
+
print("\n\n*** AWS Tag Permission Check ***")
|
|
210
|
+
if aws_account_clamp.check_tag_permissions():
|
|
211
|
+
print("Tag Permission Check Success...")
|
|
212
|
+
else:
|
|
213
|
+
print("Tag Permission Check Failed...")
|
|
@@ -123,7 +123,7 @@ class PipelineExecutor:
|
|
|
123
123
|
if "model" in workbench_objects and (not subset or "endpoint" in subset):
|
|
124
124
|
workbench_objects["model"].to_endpoint(**kwargs)
|
|
125
125
|
endpoint = Endpoint(kwargs["name"])
|
|
126
|
-
endpoint.auto_inference(
|
|
126
|
+
endpoint.auto_inference()
|
|
127
127
|
|
|
128
128
|
# Found something weird
|
|
129
129
|
else:
|