workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
|
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
|
|
|
21
21
|
from workbench.utils.s3_utils import compute_s3_object_hash
|
|
22
22
|
from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
|
|
23
23
|
from workbench.utils.deprecated_utils import deprecated
|
|
24
|
+
from workbench.utils.model_utils import published_proximity_model, get_model_hyperparameters
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class ModelType(Enum):
|
|
@@ -29,69 +30,64 @@ class ModelType(Enum):
|
|
|
29
30
|
CLASSIFIER = "classifier"
|
|
30
31
|
REGRESSOR = "regressor"
|
|
31
32
|
CLUSTERER = "clusterer"
|
|
32
|
-
TRANSFORMER = "transformer"
|
|
33
33
|
PROXIMITY = "proximity"
|
|
34
34
|
PROJECTION = "projection"
|
|
35
35
|
UQ_REGRESSOR = "uq_regressor"
|
|
36
36
|
ENSEMBLE_REGRESSOR = "ensemble_regressor"
|
|
37
|
+
TRANSFORMER = "transformer"
|
|
38
|
+
UNKNOWN = "unknown"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ModelFramework(Enum):
|
|
42
|
+
"""Enumerated Types for Workbench Model Frameworks"""
|
|
43
|
+
|
|
44
|
+
SKLEARN = "sklearn"
|
|
45
|
+
XGBOOST = "xgboost"
|
|
46
|
+
LIGHTGBM = "lightgbm"
|
|
47
|
+
PYTORCH = "pytorch"
|
|
48
|
+
CHEMPROP = "chemprop"
|
|
49
|
+
TRANSFORMER = "transformer"
|
|
50
|
+
META = "meta"
|
|
37
51
|
UNKNOWN = "unknown"
|
|
38
52
|
|
|
39
53
|
|
|
40
54
|
class ModelImages:
|
|
41
55
|
"""Class for retrieving workbench inference images"""
|
|
42
56
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
|
|
55
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
|
|
56
|
-
),
|
|
57
|
-
# US West 2 images
|
|
58
|
-
("us-west-2", "xgb_training", "0.1", "x86_64"): (
|
|
59
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
|
|
60
|
-
),
|
|
61
|
-
("us-west-2", "xgb_inference", "0.1", "x86_64"): (
|
|
62
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
|
|
63
|
-
),
|
|
64
|
-
("us-west-2", "pytorch_training", "0.1", "x86_64"): (
|
|
65
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
|
|
66
|
-
),
|
|
67
|
-
("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
|
|
68
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
|
|
69
|
-
),
|
|
70
|
-
# ARM64 images
|
|
71
|
-
("us-east-1", "xgb_inference", "0.1", "arm64"): (
|
|
72
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
|
|
73
|
-
),
|
|
74
|
-
("us-west-2", "xgb_inference", "0.1", "arm64"): (
|
|
75
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
|
|
76
|
-
),
|
|
77
|
-
# Meta Endpoint inference images
|
|
78
|
-
("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
|
|
79
|
-
"507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
|
|
80
|
-
),
|
|
81
|
-
("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
|
|
82
|
-
"507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
|
|
83
|
-
),
|
|
57
|
+
# Account ID
|
|
58
|
+
ACCOUNT_ID = "507740646243"
|
|
59
|
+
|
|
60
|
+
# Image name mappings
|
|
61
|
+
IMAGE_NAMES = {
|
|
62
|
+
"training": "py312-general-ml-training",
|
|
63
|
+
"inference": "py312-general-ml-inference",
|
|
64
|
+
"pytorch_training": "py312-pytorch-training",
|
|
65
|
+
"pytorch_inference": "py312-pytorch-inference",
|
|
66
|
+
"meta_training": "py312-meta-training",
|
|
67
|
+
"meta_inference": "py312-meta-inference",
|
|
84
68
|
}
|
|
85
69
|
|
|
86
70
|
@classmethod
|
|
87
|
-
def get_image_uri(cls, region, image_type, version="
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
)
|
|
71
|
+
def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
|
|
72
|
+
"""
|
|
73
|
+
Dynamically construct ECR image URI.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
region: AWS region (e.g., 'us-east-1', 'us-west-2')
|
|
77
|
+
image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
|
|
78
|
+
version: Image version (e.g., '0.1', '0.2'; defaults to 'latest')
|
|
79
|
+
architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
ECR image URI string
|
|
83
|
+
"""
|
|
84
|
+
if image_type not in cls.IMAGE_NAMES:
|
|
85
|
+
raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
|
|
86
|
+
|
|
87
|
+
image_name = cls.IMAGE_NAMES[image_type]
|
|
88
|
+
uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
|
|
89
|
+
|
|
90
|
+
return uri
|
|
95
91
|
|
|
96
92
|
|
|
97
93
|
class ModelCore(Artifact):
|
|
@@ -105,11 +101,10 @@ class ModelCore(Artifact):
|
|
|
105
101
|
```
|
|
106
102
|
"""
|
|
107
103
|
|
|
108
|
-
def __init__(self, model_name: str,
|
|
104
|
+
def __init__(self, model_name: str, **kwargs):
|
|
109
105
|
"""ModelCore Initialization
|
|
110
106
|
Args:
|
|
111
107
|
model_name (str): Name of Model in Workbench.
|
|
112
|
-
model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
|
|
113
108
|
**kwargs: Additional keyword arguments
|
|
114
109
|
"""
|
|
115
110
|
|
|
@@ -143,10 +138,8 @@ class ModelCore(Artifact):
|
|
|
143
138
|
self.latest_model = self.model_meta["ModelPackageList"][0]
|
|
144
139
|
self.description = self.latest_model.get("ModelPackageDescription", "-")
|
|
145
140
|
self.training_job_name = self._extract_training_job_name()
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
else:
|
|
149
|
-
self.model_type = self._get_model_type()
|
|
141
|
+
self.model_type = self._get_model_type()
|
|
142
|
+
self.model_framework = self._get_model_framework()
|
|
150
143
|
except (IndexError, KeyError):
|
|
151
144
|
self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
|
|
152
145
|
return
|
|
@@ -272,21 +265,25 @@ class ModelCore(Artifact):
|
|
|
272
265
|
else:
|
|
273
266
|
self.log.important(f"No inference data found for {self.model_name}!")
|
|
274
267
|
|
|
275
|
-
def get_inference_metrics(self, capture_name: str = "
|
|
268
|
+
def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
|
|
276
269
|
"""Retrieve the inference performance metrics for this model
|
|
277
270
|
|
|
278
271
|
Args:
|
|
279
|
-
capture_name (str, optional): Specific capture_name
|
|
272
|
+
capture_name (str, optional): Specific capture_name (default: "auto")
|
|
280
273
|
Returns:
|
|
281
274
|
pd.DataFrame: DataFrame of the Model Metrics
|
|
282
275
|
|
|
283
276
|
Note:
|
|
284
|
-
If a capture_name isn't specified this will try to
|
|
277
|
+
If a capture_name isn't specified this will try to return the 'first' available metrics
|
|
285
278
|
"""
|
|
286
279
|
# Try to get the auto_capture 'training_holdout' or the training
|
|
287
|
-
if capture_name == "
|
|
288
|
-
|
|
289
|
-
|
|
280
|
+
if capture_name == "auto":
|
|
281
|
+
metric_list = self.list_inference_runs()
|
|
282
|
+
if metric_list:
|
|
283
|
+
return self.get_inference_metrics(metric_list[0])
|
|
284
|
+
else:
|
|
285
|
+
self.log.warning(f"No performance metrics found for {self.model_name}!")
|
|
286
|
+
return None
|
|
290
287
|
|
|
291
288
|
# Grab the metrics captured during model training (could return None)
|
|
292
289
|
if capture_name == "model_training":
|
|
@@ -308,11 +305,11 @@ class ModelCore(Artifact):
|
|
|
308
305
|
self.log.warning(f"Performance metrics {capture_name} not found for {self.model_name}!")
|
|
309
306
|
return None
|
|
310
307
|
|
|
311
|
-
def confusion_matrix(self, capture_name: str = "
|
|
308
|
+
def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
|
|
312
309
|
"""Retrieve the confusion_matrix for this model
|
|
313
310
|
|
|
314
311
|
Args:
|
|
315
|
-
capture_name (str, optional): Specific capture_name or "training" (default: "
|
|
312
|
+
capture_name (str, optional): Specific capture_name or "training" (default: "auto")
|
|
316
313
|
Returns:
|
|
317
314
|
pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
|
|
318
315
|
"""
|
|
@@ -324,7 +321,7 @@ class ModelCore(Artifact):
|
|
|
324
321
|
raise ValueError(error_msg)
|
|
325
322
|
|
|
326
323
|
# Grab the metrics from the Workbench Metadata (try inference first, then training)
|
|
327
|
-
if capture_name == "
|
|
324
|
+
if capture_name == "auto":
|
|
328
325
|
cm = self.confusion_matrix("auto_inference")
|
|
329
326
|
return cm if cm is not None else self.confusion_matrix("model_training")
|
|
330
327
|
|
|
@@ -546,6 +543,17 @@ class ModelCore(Artifact):
|
|
|
546
543
|
else:
|
|
547
544
|
self.log.error(f"Model {self.model_name} is not a classifier!")
|
|
548
545
|
|
|
546
|
+
def summary(self) -> dict:
|
|
547
|
+
"""Summary information about this Model
|
|
548
|
+
|
|
549
|
+
Returns:
|
|
550
|
+
dict: Dictionary of summary information about this Model
|
|
551
|
+
"""
|
|
552
|
+
self.log.info("Computing Model Summary...")
|
|
553
|
+
summary = super().summary()
|
|
554
|
+
summary["hyperparameters"] = get_model_hyperparameters(self)
|
|
555
|
+
return summary
|
|
556
|
+
|
|
549
557
|
def details(self) -> dict:
|
|
550
558
|
"""Additional Details about this Model
|
|
551
559
|
|
|
@@ -570,6 +578,7 @@ class ModelCore(Artifact):
|
|
|
570
578
|
details["status"] = self.latest_model["ModelPackageStatus"]
|
|
571
579
|
details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
|
|
572
580
|
details["image"] = self.container_image().split("/")[-1] # Shorten the image uri
|
|
581
|
+
details["hyperparameters"] = get_model_hyperparameters(self)
|
|
573
582
|
|
|
574
583
|
# Grab the inference and container info
|
|
575
584
|
inference_spec = self.latest_model["InferenceSpecification"]
|
|
@@ -580,16 +589,6 @@ class ModelCore(Artifact):
|
|
|
580
589
|
details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
|
|
581
590
|
details["content_types"] = inference_spec["SupportedContentTypes"]
|
|
582
591
|
details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
|
|
583
|
-
details["model_metrics"] = self.get_inference_metrics()
|
|
584
|
-
if self.model_type == ModelType.CLASSIFIER:
|
|
585
|
-
details["confusion_matrix"] = self.confusion_matrix()
|
|
586
|
-
details["predictions"] = None
|
|
587
|
-
elif self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
|
|
588
|
-
details["confusion_matrix"] = None
|
|
589
|
-
details["predictions"] = self.get_inference_predictions()
|
|
590
|
-
else:
|
|
591
|
-
details["confusion_matrix"] = None
|
|
592
|
-
details["predictions"] = None
|
|
593
592
|
|
|
594
593
|
# Grab the inference metadata
|
|
595
594
|
details["inference_meta"] = self.get_inference_metadata()
|
|
@@ -597,6 +596,24 @@ class ModelCore(Artifact):
|
|
|
597
596
|
# Return the details
|
|
598
597
|
return details
|
|
599
598
|
|
|
599
|
+
# Training View for this model
|
|
600
|
+
def training_view(self):
|
|
601
|
+
"""Get the training view for this model"""
|
|
602
|
+
from workbench.core.artifacts.feature_set_core import FeatureSetCore
|
|
603
|
+
from workbench.core.views import View
|
|
604
|
+
|
|
605
|
+
# Grab our FeatureSet
|
|
606
|
+
fs = FeatureSetCore(self.get_input())
|
|
607
|
+
|
|
608
|
+
# See if we have a training view for this model
|
|
609
|
+
my_model_training_view = f"{self.name.replace('-', '_')}_training".lower()
|
|
610
|
+
view = View(fs, my_model_training_view, auto_create_view=False)
|
|
611
|
+
if view.exists():
|
|
612
|
+
return view
|
|
613
|
+
else:
|
|
614
|
+
self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
|
|
615
|
+
return fs.view("training")
|
|
616
|
+
|
|
600
617
|
# Pipeline for this model
|
|
601
618
|
def get_pipeline(self) -> str:
|
|
602
619
|
"""Get the pipeline for this model"""
|
|
@@ -860,7 +877,7 @@ class ModelCore(Artifact):
|
|
|
860
877
|
return self.df_store.get(f"/workbench/models/{self.name}/shap_data")
|
|
861
878
|
else:
|
|
862
879
|
# Loop over the SHAP data and return a dict of DataFrames
|
|
863
|
-
shap_dfs = self.df_store.
|
|
880
|
+
shap_dfs = self.df_store.list(f"/workbench/models/{self.name}/shap_data")
|
|
864
881
|
shap_data = {}
|
|
865
882
|
for df_location in shap_dfs:
|
|
866
883
|
key = df_location.split("/")[-1]
|
|
@@ -879,10 +896,24 @@ class ModelCore(Artifact):
|
|
|
879
896
|
except (KeyError, IndexError, TypeError):
|
|
880
897
|
return None
|
|
881
898
|
|
|
899
|
+
def publish_prox_model(self, prox_model_name: str = None, include_all_columns: bool = False):
|
|
900
|
+
"""Create and publish a Proximity Model for this Model
|
|
901
|
+
|
|
902
|
+
Args:
|
|
903
|
+
prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
|
|
904
|
+
include_all_columns (bool): Include all DataFrame columns in results (default: False)
|
|
905
|
+
|
|
906
|
+
Returns:
|
|
907
|
+
Model: The published Proximity Model
|
|
908
|
+
"""
|
|
909
|
+
if prox_model_name is None:
|
|
910
|
+
prox_model_name = self.model_name + "-prox"
|
|
911
|
+
return published_proximity_model(self, prox_model_name, include_all_columns=include_all_columns)
|
|
912
|
+
|
|
882
913
|
def delete(self):
|
|
883
914
|
"""Delete the Model Packages and the Model Group"""
|
|
884
915
|
if not self.exists():
|
|
885
|
-
self.log.warning(f"Trying to delete
|
|
916
|
+
self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")
|
|
886
917
|
|
|
887
918
|
# Call the Class Method to delete the Model Group
|
|
888
919
|
ModelCore.managed_delete(model_group_name=self.name)
|
|
@@ -958,6 +989,27 @@ class ModelCore(Artifact):
|
|
|
958
989
|
self.log.warning(f"Could not determine model type for {self.model_name}!")
|
|
959
990
|
return ModelType.UNKNOWN
|
|
960
991
|
|
|
992
|
+
def _set_model_framework(self, model_framework: ModelFramework):
|
|
993
|
+
"""Internal: Set the Model Framework for this Model"""
|
|
994
|
+
self.model_framework = model_framework
|
|
995
|
+
self.upsert_workbench_meta({"workbench_model_framework": self.model_framework.value})
|
|
996
|
+
self.remove_health_tag("model_framework_unknown")
|
|
997
|
+
|
|
998
|
+
def _get_model_framework(self) -> ModelFramework:
|
|
999
|
+
"""Internal: Query the Workbench Metadata to get the model framework
|
|
1000
|
+
Returns:
|
|
1001
|
+
ModelFramework: The ModelFramework of this Model
|
|
1002
|
+
Notes:
|
|
1003
|
+
This is an internal method that should not be called directly
|
|
1004
|
+
Use the model_framework attribute instead
|
|
1005
|
+
"""
|
|
1006
|
+
model_framework = self.workbench_meta().get("workbench_model_framework")
|
|
1007
|
+
try:
|
|
1008
|
+
return ModelFramework(model_framework)
|
|
1009
|
+
except ValueError:
|
|
1010
|
+
self.log.warning(f"Could not determine model framework for {self.model_name}!")
|
|
1011
|
+
return ModelFramework.UNKNOWN
|
|
1012
|
+
|
|
961
1013
|
def _load_training_metrics(self):
|
|
962
1014
|
"""Internal: Retrieve the training metrics and Confusion Matrix for this model
|
|
963
1015
|
and load the data into the Workbench Metadata
|
|
@@ -1149,13 +1201,11 @@ if __name__ == "__main__":
|
|
|
1149
1201
|
# Grab a ModelCore object and pull some information from it
|
|
1150
1202
|
my_model = ModelCore("abalone-regression")
|
|
1151
1203
|
|
|
1152
|
-
# Call the various methods
|
|
1153
|
-
|
|
1154
1204
|
# Let's do a check/validation of the Model
|
|
1155
1205
|
print(f"Model Check: {my_model.exists()}")
|
|
1156
1206
|
|
|
1157
1207
|
# Make sure the model is 'ready'
|
|
1158
|
-
|
|
1208
|
+
my_model.onboard()
|
|
1159
1209
|
|
|
1160
1210
|
# Get the ARN of the Model Group
|
|
1161
1211
|
print(f"Model Group ARN: {my_model.group_arn()}")
|
|
@@ -1221,5 +1271,10 @@ if __name__ == "__main__":
|
|
|
1221
1271
|
# Delete the Model
|
|
1222
1272
|
# ModelCore.managed_delete("wine-classification")
|
|
1223
1273
|
|
|
1274
|
+
# Check the training view logic
|
|
1275
|
+
model = ModelCore("wine-class-test-251112-BW")
|
|
1276
|
+
training_view = model.training_view()
|
|
1277
|
+
print(f"Training View Name: {training_view.name}")
|
|
1278
|
+
|
|
1224
1279
|
# Check for a model that doesn't exist
|
|
1225
1280
|
my_model = ModelCore("empty-model-group")
|