workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +319 -204
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -82
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +260 -76
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
|
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
|
|
|
21
21
|
from workbench.utils.s3_utils import compute_s3_object_hash
|
|
22
22
|
from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
|
|
23
23
|
from workbench.utils.deprecated_utils import deprecated
|
|
24
|
+
from workbench.utils.model_utils import published_proximity_model, get_model_hyperparameters
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class ModelType(Enum):
|
|
@@ -29,63 +30,64 @@ class ModelType(Enum):
|
|
|
29
30
|
CLASSIFIER = "classifier"
|
|
30
31
|
REGRESSOR = "regressor"
|
|
31
32
|
CLUSTERER = "clusterer"
|
|
32
|
-
TRANSFORMER = "transformer"
|
|
33
33
|
PROXIMITY = "proximity"
|
|
34
34
|
PROJECTION = "projection"
|
|
35
35
|
UQ_REGRESSOR = "uq_regressor"
|
|
36
36
|
ENSEMBLE_REGRESSOR = "ensemble_regressor"
|
|
37
|
+
TRANSFORMER = "transformer"
|
|
38
|
+
UNKNOWN = "unknown"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ModelFramework(Enum):
|
|
42
|
+
"""Enumerated Types for Workbench Model Frameworks"""
|
|
43
|
+
|
|
44
|
+
SKLEARN = "sklearn"
|
|
45
|
+
XGBOOST = "xgboost"
|
|
46
|
+
LIGHTGBM = "lightgbm"
|
|
47
|
+
PYTORCH = "pytorch"
|
|
48
|
+
CHEMPROP = "chemprop"
|
|
49
|
+
TRANSFORMER = "transformer"
|
|
50
|
+
META = "meta"
|
|
37
51
|
UNKNOWN = "unknown"
|
|
38
52
|
|
|
39
53
|
|
|
40
54
|
class ModelImages:
|
|
41
55
|
"""Class for retrieving workbench inference images"""
|
|
42
56
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
|
|
55
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
|
|
56
|
-
),
|
|
57
|
-
# US West 2 images
|
|
58
|
-
("us-west-2", "training", "0.1", "x86_64"): (
|
|
59
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
|
|
60
|
-
),
|
|
61
|
-
("us-west-2", "inference", "0.1", "x86_64"): (
|
|
62
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
|
|
63
|
-
),
|
|
64
|
-
("us-west-2", "pytorch_training", "0.1", "x86_64"): (
|
|
65
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
|
|
66
|
-
),
|
|
67
|
-
("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
|
|
68
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
|
|
69
|
-
),
|
|
70
|
-
# ARM64 images
|
|
71
|
-
# Meta Endpoint inference images
|
|
72
|
-
("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
|
|
73
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
|
|
74
|
-
),
|
|
75
|
-
("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
|
|
76
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
|
|
77
|
-
),
|
|
57
|
+
# Account ID
|
|
58
|
+
ACCOUNT_ID = "507740646243"
|
|
59
|
+
|
|
60
|
+
# Image name mappings
|
|
61
|
+
IMAGE_NAMES = {
|
|
62
|
+
"training": "py312-general-ml-training",
|
|
63
|
+
"inference": "py312-general-ml-inference",
|
|
64
|
+
"pytorch_training": "py312-pytorch-training",
|
|
65
|
+
"pytorch_inference": "py312-pytorch-inference",
|
|
66
|
+
"meta_training": "py312-meta-training",
|
|
67
|
+
"meta_inference": "py312-meta-inference",
|
|
78
68
|
}
|
|
79
69
|
|
|
80
70
|
@classmethod
|
|
81
|
-
def get_image_uri(cls, region, image_type, version="
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
)
|
|
71
|
+
def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
|
|
72
|
+
"""
|
|
73
|
+
Dynamically construct ECR image URI.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
region: AWS region (e.g., 'us-east-1', 'us-west-2')
|
|
77
|
+
image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
|
|
78
|
+
version: Image version (e.g., '0.1', '0.2'; defaults to 'latest')
|
|
79
|
+
architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
ECR image URI string
|
|
83
|
+
"""
|
|
84
|
+
if image_type not in cls.IMAGE_NAMES:
|
|
85
|
+
raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
|
|
86
|
+
|
|
87
|
+
image_name = cls.IMAGE_NAMES[image_type]
|
|
88
|
+
uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
|
|
89
|
+
|
|
90
|
+
return uri
|
|
89
91
|
|
|
90
92
|
|
|
91
93
|
class ModelCore(Artifact):
|
|
@@ -99,11 +101,10 @@ class ModelCore(Artifact):
|
|
|
99
101
|
```
|
|
100
102
|
"""
|
|
101
103
|
|
|
102
|
-
def __init__(self, model_name: str,
|
|
104
|
+
def __init__(self, model_name: str, **kwargs):
|
|
103
105
|
"""ModelCore Initialization
|
|
104
106
|
Args:
|
|
105
107
|
model_name (str): Name of Model in Workbench.
|
|
106
|
-
model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
|
|
107
108
|
**kwargs: Additional keyword arguments
|
|
108
109
|
"""
|
|
109
110
|
|
|
@@ -137,10 +138,8 @@ class ModelCore(Artifact):
|
|
|
137
138
|
self.latest_model = self.model_meta["ModelPackageList"][0]
|
|
138
139
|
self.description = self.latest_model.get("ModelPackageDescription", "-")
|
|
139
140
|
self.training_job_name = self._extract_training_job_name()
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
else:
|
|
143
|
-
self.model_type = self._get_model_type()
|
|
141
|
+
self.model_type = self._get_model_type()
|
|
142
|
+
self.model_framework = self._get_model_framework()
|
|
144
143
|
except (IndexError, KeyError):
|
|
145
144
|
self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
|
|
146
145
|
return
|
|
@@ -266,21 +265,25 @@ class ModelCore(Artifact):
|
|
|
266
265
|
else:
|
|
267
266
|
self.log.important(f"No inference data found for {self.model_name}!")
|
|
268
267
|
|
|
269
|
-
def get_inference_metrics(self, capture_name: str = "
|
|
268
|
+
def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
|
|
270
269
|
"""Retrieve the inference performance metrics for this model
|
|
271
270
|
|
|
272
271
|
Args:
|
|
273
|
-
capture_name (str, optional): Specific capture_name
|
|
272
|
+
capture_name (str, optional): Specific capture_name (default: "auto")
|
|
274
273
|
Returns:
|
|
275
274
|
pd.DataFrame: DataFrame of the Model Metrics
|
|
276
275
|
|
|
277
276
|
Note:
|
|
278
|
-
If a capture_name isn't specified this will try to
|
|
277
|
+
If a capture_name isn't specified this will try to return the 'first' available metrics
|
|
279
278
|
"""
|
|
280
279
|
# Try to get the auto_capture 'training_holdout' or the training
|
|
281
|
-
if capture_name == "
|
|
282
|
-
|
|
283
|
-
|
|
280
|
+
if capture_name == "auto":
|
|
281
|
+
metric_list = self.list_inference_runs()
|
|
282
|
+
if metric_list:
|
|
283
|
+
return self.get_inference_metrics(metric_list[0])
|
|
284
|
+
else:
|
|
285
|
+
self.log.warning(f"No performance metrics found for {self.model_name}!")
|
|
286
|
+
return None
|
|
284
287
|
|
|
285
288
|
# Grab the metrics captured during model training (could return None)
|
|
286
289
|
if capture_name == "model_training":
|
|
@@ -302,11 +305,11 @@ class ModelCore(Artifact):
|
|
|
302
305
|
self.log.warning(f"Performance metrics {capture_name} not found for {self.model_name}!")
|
|
303
306
|
return None
|
|
304
307
|
|
|
305
|
-
def confusion_matrix(self, capture_name: str = "
|
|
308
|
+
def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
|
|
306
309
|
"""Retrieve the confusion_matrix for this model
|
|
307
310
|
|
|
308
311
|
Args:
|
|
309
|
-
capture_name (str, optional): Specific capture_name or "training" (default: "
|
|
312
|
+
capture_name (str, optional): Specific capture_name or "training" (default: "auto")
|
|
310
313
|
Returns:
|
|
311
314
|
pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
|
|
312
315
|
"""
|
|
@@ -318,7 +321,7 @@ class ModelCore(Artifact):
|
|
|
318
321
|
raise ValueError(error_msg)
|
|
319
322
|
|
|
320
323
|
# Grab the metrics from the Workbench Metadata (try inference first, then training)
|
|
321
|
-
if capture_name == "
|
|
324
|
+
if capture_name == "auto":
|
|
322
325
|
cm = self.confusion_matrix("auto_inference")
|
|
323
326
|
return cm if cm is not None else self.confusion_matrix("model_training")
|
|
324
327
|
|
|
@@ -540,6 +543,17 @@ class ModelCore(Artifact):
|
|
|
540
543
|
else:
|
|
541
544
|
self.log.error(f"Model {self.model_name} is not a classifier!")
|
|
542
545
|
|
|
546
|
+
def summary(self) -> dict:
|
|
547
|
+
"""Summary information about this Model
|
|
548
|
+
|
|
549
|
+
Returns:
|
|
550
|
+
dict: Dictionary of summary information about this Model
|
|
551
|
+
"""
|
|
552
|
+
self.log.info("Computing Model Summary...")
|
|
553
|
+
summary = super().summary()
|
|
554
|
+
summary["hyperparameters"] = get_model_hyperparameters(self)
|
|
555
|
+
return summary
|
|
556
|
+
|
|
543
557
|
def details(self) -> dict:
|
|
544
558
|
"""Additional Details about this Model
|
|
545
559
|
|
|
@@ -564,6 +578,7 @@ class ModelCore(Artifact):
|
|
|
564
578
|
details["status"] = self.latest_model["ModelPackageStatus"]
|
|
565
579
|
details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
|
|
566
580
|
details["image"] = self.container_image().split("/")[-1] # Shorten the image uri
|
|
581
|
+
details["hyperparameters"] = get_model_hyperparameters(self)
|
|
567
582
|
|
|
568
583
|
# Grab the inference and container info
|
|
569
584
|
inference_spec = self.latest_model["InferenceSpecification"]
|
|
@@ -574,16 +589,6 @@ class ModelCore(Artifact):
|
|
|
574
589
|
details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
|
|
575
590
|
details["content_types"] = inference_spec["SupportedContentTypes"]
|
|
576
591
|
details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
|
|
577
|
-
details["model_metrics"] = self.get_inference_metrics()
|
|
578
|
-
if self.model_type == ModelType.CLASSIFIER:
|
|
579
|
-
details["confusion_matrix"] = self.confusion_matrix()
|
|
580
|
-
details["predictions"] = None
|
|
581
|
-
elif self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
|
|
582
|
-
details["confusion_matrix"] = None
|
|
583
|
-
details["predictions"] = self.get_inference_predictions()
|
|
584
|
-
else:
|
|
585
|
-
details["confusion_matrix"] = None
|
|
586
|
-
details["predictions"] = None
|
|
587
592
|
|
|
588
593
|
# Grab the inference metadata
|
|
589
594
|
details["inference_meta"] = self.get_inference_metadata()
|
|
@@ -591,6 +596,24 @@ class ModelCore(Artifact):
|
|
|
591
596
|
# Return the details
|
|
592
597
|
return details
|
|
593
598
|
|
|
599
|
+
# Training View for this model
|
|
600
|
+
def training_view(self):
|
|
601
|
+
"""Get the training view for this model"""
|
|
602
|
+
from workbench.core.artifacts.feature_set_core import FeatureSetCore
|
|
603
|
+
from workbench.core.views import View
|
|
604
|
+
|
|
605
|
+
# Grab our FeatureSet
|
|
606
|
+
fs = FeatureSetCore(self.get_input())
|
|
607
|
+
|
|
608
|
+
# See if we have a training view for this model
|
|
609
|
+
my_model_training_view = f"{self.name.replace('-', '_')}_training".lower()
|
|
610
|
+
view = View(fs, my_model_training_view, auto_create_view=False)
|
|
611
|
+
if view.exists():
|
|
612
|
+
return view
|
|
613
|
+
else:
|
|
614
|
+
self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
|
|
615
|
+
return fs.view("training")
|
|
616
|
+
|
|
594
617
|
# Pipeline for this model
|
|
595
618
|
def get_pipeline(self) -> str:
|
|
596
619
|
"""Get the pipeline for this model"""
|
|
@@ -854,21 +877,13 @@ class ModelCore(Artifact):
|
|
|
854
877
|
return self.df_store.get(f"/workbench/models/{self.name}/shap_data")
|
|
855
878
|
else:
|
|
856
879
|
# Loop over the SHAP data and return a dict of DataFrames
|
|
857
|
-
shap_dfs = self.df_store.
|
|
880
|
+
shap_dfs = self.df_store.list(f"/workbench/models/{self.name}/shap_data")
|
|
858
881
|
shap_data = {}
|
|
859
882
|
for df_location in shap_dfs:
|
|
860
883
|
key = df_location.split("/")[-1]
|
|
861
884
|
shap_data[key] = self.df_store.get(df_location)
|
|
862
885
|
return shap_data or None
|
|
863
886
|
|
|
864
|
-
def cross_folds(self) -> dict:
|
|
865
|
-
"""Retrieve the cross-fold inference results(only works for XGBoost models)
|
|
866
|
-
|
|
867
|
-
Returns:
|
|
868
|
-
dict: Dictionary with the cross-fold inference results
|
|
869
|
-
"""
|
|
870
|
-
return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
|
|
871
|
-
|
|
872
887
|
def supported_inference_instances(self) -> Optional[list]:
|
|
873
888
|
"""Retrieve the supported endpoint inference instance types
|
|
874
889
|
|
|
@@ -881,10 +896,24 @@ class ModelCore(Artifact):
|
|
|
881
896
|
except (KeyError, IndexError, TypeError):
|
|
882
897
|
return None
|
|
883
898
|
|
|
899
|
+
def publish_prox_model(self, prox_model_name: str = None, include_all_columns: bool = False):
|
|
900
|
+
"""Create and publish a Proximity Model for this Model
|
|
901
|
+
|
|
902
|
+
Args:
|
|
903
|
+
prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
|
|
904
|
+
include_all_columns (bool): Include all DataFrame columns in results (default: False)
|
|
905
|
+
|
|
906
|
+
Returns:
|
|
907
|
+
Model: The published Proximity Model
|
|
908
|
+
"""
|
|
909
|
+
if prox_model_name is None:
|
|
910
|
+
prox_model_name = self.model_name + "-prox"
|
|
911
|
+
return published_proximity_model(self, prox_model_name, include_all_columns=include_all_columns)
|
|
912
|
+
|
|
884
913
|
def delete(self):
|
|
885
914
|
"""Delete the Model Packages and the Model Group"""
|
|
886
915
|
if not self.exists():
|
|
887
|
-
self.log.warning(f"Trying to delete
|
|
916
|
+
self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")
|
|
888
917
|
|
|
889
918
|
# Call the Class Method to delete the Model Group
|
|
890
919
|
ModelCore.managed_delete(model_group_name=self.name)
|
|
@@ -960,6 +989,27 @@ class ModelCore(Artifact):
|
|
|
960
989
|
self.log.warning(f"Could not determine model type for {self.model_name}!")
|
|
961
990
|
return ModelType.UNKNOWN
|
|
962
991
|
|
|
992
|
+
def _set_model_framework(self, model_framework: ModelFramework):
|
|
993
|
+
"""Internal: Set the Model Framework for this Model"""
|
|
994
|
+
self.model_framework = model_framework
|
|
995
|
+
self.upsert_workbench_meta({"workbench_model_framework": self.model_framework.value})
|
|
996
|
+
self.remove_health_tag("model_framework_unknown")
|
|
997
|
+
|
|
998
|
+
def _get_model_framework(self) -> ModelFramework:
|
|
999
|
+
"""Internal: Query the Workbench Metadata to get the model framework
|
|
1000
|
+
Returns:
|
|
1001
|
+
ModelFramework: The ModelFramework of this Model
|
|
1002
|
+
Notes:
|
|
1003
|
+
This is an internal method that should not be called directly
|
|
1004
|
+
Use the model_framework attribute instead
|
|
1005
|
+
"""
|
|
1006
|
+
model_framework = self.workbench_meta().get("workbench_model_framework")
|
|
1007
|
+
try:
|
|
1008
|
+
return ModelFramework(model_framework)
|
|
1009
|
+
except ValueError:
|
|
1010
|
+
self.log.warning(f"Could not determine model framework for {self.model_name}!")
|
|
1011
|
+
return ModelFramework.UNKNOWN
|
|
1012
|
+
|
|
963
1013
|
def _load_training_metrics(self):
|
|
964
1014
|
"""Internal: Retrieve the training metrics and Confusion Matrix for this model
|
|
965
1015
|
and load the data into the Workbench Metadata
|
|
@@ -1151,13 +1201,11 @@ if __name__ == "__main__":
|
|
|
1151
1201
|
# Grab a ModelCore object and pull some information from it
|
|
1152
1202
|
my_model = ModelCore("abalone-regression")
|
|
1153
1203
|
|
|
1154
|
-
# Call the various methods
|
|
1155
|
-
|
|
1156
1204
|
# Let's do a check/validation of the Model
|
|
1157
1205
|
print(f"Model Check: {my_model.exists()}")
|
|
1158
1206
|
|
|
1159
1207
|
# Make sure the model is 'ready'
|
|
1160
|
-
|
|
1208
|
+
my_model.onboard()
|
|
1161
1209
|
|
|
1162
1210
|
# Get the ARN of the Model Group
|
|
1163
1211
|
print(f"Model Group ARN: {my_model.group_arn()}")
|
|
@@ -1223,5 +1271,10 @@ if __name__ == "__main__":
|
|
|
1223
1271
|
# Delete the Model
|
|
1224
1272
|
# ModelCore.managed_delete("wine-classification")
|
|
1225
1273
|
|
|
1274
|
+
# Check the training view logic
|
|
1275
|
+
model = ModelCore("wine-class-test-251112-BW")
|
|
1276
|
+
training_view = model.training_view()
|
|
1277
|
+
print(f"Training View Name: {training_view.name}")
|
|
1278
|
+
|
|
1226
1279
|
# Check for a model that doesn't exist
|
|
1227
1280
|
my_model = ModelCore("empty-model-group")
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""ParameterStoreCore: Manages Workbench parameters in a Cloud Based Parameter Store."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
# Workbench Imports
|
|
6
|
+
from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
|
|
7
|
+
|
|
8
|
+
# Workbench Bridges Import
|
|
9
|
+
from workbench_bridges.api import ParameterStore as BridgesParameterStore
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ParameterStoreCore(BridgesParameterStore):
|
|
13
|
+
"""ParameterStoreCore: Manages Workbench parameters in a Cloud Based Parameter Store.
|
|
14
|
+
|
|
15
|
+
Common Usage:
|
|
16
|
+
```python
|
|
17
|
+
params = ParameterStoreCore()
|
|
18
|
+
|
|
19
|
+
# List Parameters
|
|
20
|
+
params.list()
|
|
21
|
+
|
|
22
|
+
['/workbench/abalone_info',
|
|
23
|
+
'/workbench/my_data',
|
|
24
|
+
'/workbench/test',
|
|
25
|
+
'/workbench/pipelines/my_pipeline']
|
|
26
|
+
|
|
27
|
+
# Add Key
|
|
28
|
+
params.upsert("key", "value")
|
|
29
|
+
value = params.get("key")
|
|
30
|
+
|
|
31
|
+
# Add any data (lists, dictionaries, etc..)
|
|
32
|
+
my_data = {"key": "value", "number": 4.2, "list": [1,2,3]}
|
|
33
|
+
params.upsert("my_data", my_data)
|
|
34
|
+
|
|
35
|
+
# Retrieve data
|
|
36
|
+
return_value = params.get("my_data")
|
|
37
|
+
pprint(return_value)
|
|
38
|
+
|
|
39
|
+
{'key': 'value', 'list': [1, 2, 3], 'number': 4.2}
|
|
40
|
+
|
|
41
|
+
# Delete parameters
|
|
42
|
+
params.delete("my_data")
|
|
43
|
+
```
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(self):
|
|
47
|
+
"""ParameterStoreCore Init Method"""
|
|
48
|
+
session = AWSAccountClamp().boto3_session
|
|
49
|
+
|
|
50
|
+
# Initialize parent with workbench config
|
|
51
|
+
super().__init__(boto3_session=session)
|
|
52
|
+
self.log = logging.getLogger("workbench")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
if __name__ == "__main__":
|
|
56
|
+
"""Exercise the ParameterStoreCore Class"""
|
|
57
|
+
|
|
58
|
+
# Create a ParameterStoreCore manager
|
|
59
|
+
param_store = ParameterStoreCore()
|
|
60
|
+
|
|
61
|
+
# List the parameters
|
|
62
|
+
print("Listing Parameters...")
|
|
63
|
+
print(param_store.list())
|
|
64
|
+
|
|
65
|
+
# Add a new parameter
|
|
66
|
+
param_store.upsert("/workbench/test", "value")
|
|
67
|
+
|
|
68
|
+
# Get the parameter
|
|
69
|
+
print(f"Getting parameter 'test': {param_store.get('/workbench/test')}")
|
|
70
|
+
|
|
71
|
+
# Add a dictionary as a parameter
|
|
72
|
+
sample_dict = {"key": "str_value", "awesome_value": 4.2}
|
|
73
|
+
param_store.upsert("/workbench/my_data", sample_dict)
|
|
74
|
+
|
|
75
|
+
# Retrieve the parameter as a dictionary
|
|
76
|
+
retrieved_value = param_store.get("/workbench/my_data")
|
|
77
|
+
print("Retrieved value:", retrieved_value)
|
|
78
|
+
|
|
79
|
+
# List the parameters
|
|
80
|
+
print("Listing Parameters...")
|
|
81
|
+
print(param_store.list())
|
|
82
|
+
|
|
83
|
+
# List the parameters with a prefix
|
|
84
|
+
print("Listing Parameters with prefix '/workbench':")
|
|
85
|
+
print(param_store.list("/workbench"))
|
|
86
|
+
|
|
87
|
+
# Delete the parameters
|
|
88
|
+
param_store.delete("/workbench/test")
|
|
89
|
+
param_store.delete("/workbench/my_data")
|
|
90
|
+
|
|
91
|
+
# Out of scope tests
|
|
92
|
+
param_store.upsert("test", "value")
|
|
93
|
+
param_store.delete("test")
|
|
94
|
+
|
|
95
|
+
# Recursive delete test
|
|
96
|
+
param_store.upsert("/workbench/test/test1", "value1")
|
|
97
|
+
param_store.upsert("/workbench/test/test2", "value2")
|
|
98
|
+
param_store.delete_recursive("workbench/test/")
|
|
@@ -123,7 +123,7 @@ class PipelineExecutor:
|
|
|
123
123
|
if "model" in workbench_objects and (not subset or "endpoint" in subset):
|
|
124
124
|
workbench_objects["model"].to_endpoint(**kwargs)
|
|
125
125
|
endpoint = Endpoint(kwargs["name"])
|
|
126
|
-
endpoint.auto_inference(
|
|
126
|
+
endpoint.auto_inference()
|
|
127
127
|
|
|
128
128
|
# Found something weird
|
|
129
129
|
else:
|