workbench 0.8.161__py3-none-any.whl → 0.8.192__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +143 -102
- workbench/algorithms/graph/light/proximity_graph.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +12 -0
- workbench/api/feature_set.py +4 -4
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -12
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +168 -78
- workbench/core/artifacts/feature_set_core.py +72 -13
- workbench/core/artifacts/model_core.py +50 -15
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +9 -4
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +49 -53
- workbench/core/views/view.py +51 -1
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
- workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
- workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
- workbench/model_scripts/pytorch_model/pytorch.template +19 -20
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +7 -2
- workbench/model_scripts/uq_models/mapie.template +492 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/xgb_model.template +31 -40
- workbench/repl/workbench_shell.py +11 -6
- workbench/scripts/lambda_launcher.py +63 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +76 -30
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +283 -145
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/METADATA +4 -4
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/RECORD +81 -76
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/pytorch_model/generated_model_script.py +0 -565
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
- workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/model_core.py:

@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
 from workbench.utils.deprecated_utils import deprecated
+from workbench.utils.model_utils import proximity_model


 class ModelType(Enum):
@@ -42,11 +43,11 @@ class ModelImages:

     image_uris = {
         # US East 1 images
-        ("us-east-1", "
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-
+        ("us-east-1", "training", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
         ),
-        ("us-east-1", "
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-
+        ("us-east-1", "inference", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
         ),
         ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
@@ -55,11 +56,11 @@ class ModelImages:
             "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
         ),
         # US West 2 images
-        ("us-west-2", "
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-
+        ("us-west-2", "training", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
         ),
-        ("us-west-2", "
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-
+        ("us-west-2", "inference", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
         ),
         ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
@@ -68,12 +69,6 @@ class ModelImages:
             "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
         ),
         # ARM64 images
-        ("us-east-1", "xgb_inference", "0.1", "arm64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-        ),
-        ("us-west-2", "xgb_inference", "0.1", "arm64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-        ),
         # Meta Endpoint inference images
         ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
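The hunks above replace the XGBoost-specific image keys (and drop the ARM64 `xgb_inference` entries) with generic `training`/`inference` images keyed by a `(region, task, version, architecture)` tuple. A minimal lookup sketch over a mapping of this shape; the `get_image_uri` helper is hypothetical, not part of the ModelImages API shown here:

```python
# Illustrative lookup over a (region, task, version, architecture) keyed mapping
# like ModelImages.image_uris above; get_image_uri is a hypothetical helper.
image_uris = {
    ("us-east-1", "training", "0.1", "x86_64"): (
        "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
    ),
    ("us-east-1", "inference", "0.1", "x86_64"): (
        "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
    ),
}


def get_image_uri(region: str, task: str, version: str = "0.1", arch: str = "x86_64") -> str:
    """Return the registered ECR image URI or raise for an unsupported combination."""
    key = (region, task, version, arch)
    if key not in image_uris:
        raise ValueError(f"No image registered for {key}")
    return image_uris[key]


print(get_image_uri("us-east-1", "inference"))
```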
@@ -597,6 +592,24 @@ class ModelCore(Artifact):
         # Return the details
         return details

+    # Training View for this model
+    def training_view(self):
+        """Get the training view for this model"""
+        from workbench.core.artifacts.feature_set_core import FeatureSetCore
+        from workbench.core.views import View
+
+        # Grab our FeatureSet
+        fs = FeatureSetCore(self.get_input())
+
+        # See if we have a training view for this model
+        my_model_training_view = f"{self.name.replace('-', '_')}_training"
+        view = View(fs, my_model_training_view, auto_create_view=False)
+        if view.exists():
+            return view
+        else:
+            self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
+            return fs.view("training")
+
     # Pipeline for this model
     def get_pipeline(self) -> str:
         """Get the pipeline for this model"""
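The new `training_view()` prefers a model-specific `<model_name>_training` view and falls back to the FeatureSet's default `training` view. A short usage sketch; the model name is illustrative:

```python
# Usage sketch for ModelCore.training_view(); "abalone-regression" is illustrative.
from workbench.core.artifacts.model_core import ModelCore

model = ModelCore("abalone-regression")
view = model.training_view()  # "<name>_training" view if present, else the default "training" view
print(view)
```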
@@ -867,6 +880,14 @@ class ModelCore(Artifact):
             shap_data[key] = self.df_store.get(df_location)
         return shap_data or None

+    def cross_folds(self) -> dict:
+        """Retrieve the cross-fold inference results(only works for XGBoost models)
+
+        Returns:
+            dict: Dictionary with the cross-fold inference results
+        """
+        return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
+
     def supported_inference_instances(self) -> Optional[list]:
         """Retrieve the supported endpoint inference instance types

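`cross_folds()` simply reads whatever was stored under the model's `inference/cross_fold` key in the parameter store (XGBoost models only). A minimal sketch; the model name is illustrative:

```python
# Minimal sketch of the new cross_folds() accessor; the model name is illustrative.
from workbench.core.artifacts.model_core import ModelCore

model = ModelCore("abalone-regression")
results = model.cross_folds()  # contents of /workbench/models/<name>/inference/cross_fold
print(results)
```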
@@ -879,10 +900,24 @@ class ModelCore(Artifact):
         except (KeyError, IndexError, TypeError):
             return None

+    def publish_prox_model(self, prox_model_name: str = None, track_columns: list = None):
+        """Create and publish a Proximity Model for this Model
+
+        Args:
+            prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
+            track_columns (list, optional): List of columns to track in the Proximity Model.
+
+        Returns:
+            Model: The published Proximity Model
+        """
+        if prox_model_name is None:
+            prox_model_name = self.model_name + "-prox"
+        return proximity_model(self, prox_model_name, track_columns=track_columns)
+
     def delete(self):
         """Delete the Model Packages and the Model Group"""
         if not self.exists():
-            self.log.warning(f"Trying to delete
+            self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")

         # Call the Class Method to delete the Model Group
         ModelCore.managed_delete(model_group_name=self.name)
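`publish_prox_model()` wraps `workbench.utils.model_utils.proximity_model` and defaults the name to `<model_name>-prox`. A short sketch; the model name and tracked columns are illustrative:

```python
# Usage sketch for ModelCore.publish_prox_model(); names/columns are illustrative.
from workbench.core.artifacts.model_core import ModelCore

model = ModelCore("aqsol-regression")
prox_model = model.publish_prox_model(track_columns=["solubility"])  # defaults to "aqsol-regression-prox"
print(prox_model)
```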
workbench/core/artifacts/monitor_core.py:

@@ -2,12 +2,10 @@

 import logging
 import json
-from typing import Union
+from typing import Union
 import pandas as pd
-from sagemaker import Predictor
 from sagemaker.model_monitor import (
     CronExpressionGenerator,
-    DataCaptureConfig,
     DefaultModelMonitor,
     DatasetFormat,
 )
@@ -15,29 +13,32 @@ import awswrangler as wr

 # Workbench Imports
 from workbench.core.artifacts.endpoint_core import EndpointCore
+from workbench.core.artifacts.data_capture_core import DataCaptureCore
 from workbench.api import Model, FeatureSet
 from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
 from workbench.utils.s3_utils import read_content_from_s3, upload_content_to_s3
 from workbench.utils.datetime_utils import datetime_string
 from workbench.utils.monitor_utils import (
-    process_data_capture,
     get_monitor_json_data,
     parse_monitoring_results,
     preprocessing_script,
 )

-# Note:
+# Note: These resources might come in handy when doing code refactoring
 # https://github.com/aws-samples/amazon-sagemaker-from-idea-to-production/blob/master/06-monitoring.ipynb
 # https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-pre-and-post-processing.html
 # https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker_model_monitor/introduction/SageMaker-ModelMonitoring.ipynb


 class MonitorCore:
+    """Manages monitoring, baselines, and monitoring schedules for SageMaker endpoints"""
+
     def __init__(self, endpoint_name, instance_type="ml.m5.large"):
         """MonitorCore Class
+
         Args:
             endpoint_name (str): Name of the endpoint to set up monitoring for
-            instance_type (str): Instance type to use for monitoring. Defaults to "ml.
+            instance_type (str): Instance type to use for monitoring. Defaults to "ml.m5.large".
         """
         self.log = logging.getLogger("workbench")
         self.endpoint_name = endpoint_name
@@ -46,7 +47,6 @@ class MonitorCore:
         # Initialize Class Attributes
         self.sagemaker_session = self.endpoint.sm_session
         self.sagemaker_client = self.endpoint.sm_client
-        self.data_capture_path = self.endpoint.endpoint_data_capture_path
         self.monitoring_path = self.endpoint.endpoint_monitoring_path
         self.monitoring_schedule_name = f"{self.endpoint_name}-monitoring-schedule"
         self.baseline_dir = f"{self.monitoring_path}/baseline"
@@ -57,6 +57,10 @@ class MonitorCore:
         self.workbench_role_arn = AWSAccountClamp().aws_session.get_workbench_execution_role_arn()
         self.instance_type = instance_type

+        # Create DataCaptureCore instance for composition
+        self.data_capture = DataCaptureCore(endpoint_name)
+        self.data_capture_path = self.data_capture.data_capture_path
+
         # Check if a monitoring schedule already exists for this endpoint
         existing_schedule = self.monitoring_schedule_exists()

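Data capture now lives in the new `DataCaptureCore` class (see `data_capture_core.py`, +355 lines in the file list above) and `MonitorCore` holds one by composition. A hedged sketch of driving it directly, limited to the attributes and methods visible in this diff (`data_capture_path`, `is_enabled()`, `enable()`); the endpoint name is illustrative:

```python
# Hedged sketch: using DataCaptureCore directly, with only what this diff shows.
from workbench.core.artifacts.data_capture_core import DataCaptureCore

capture = DataCaptureCore("abalone-regression-rt")  # illustrative endpoint name
print(capture.data_capture_path)                    # S3 prefix for captured requests
if not capture.is_enabled():
    capture.enable(capture_percentage=100)
```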
@@ -74,23 +78,20 @@ class MonitorCore:
         self.log.info(f"Initialized new model monitor for {self.endpoint_name}")

     def summary(self) -> dict:
-        """Return the summary of
+        """Return the summary of monitoring configuration

         Returns:
-            dict: Summary of
+            dict: Summary of monitoring status
         """
         if self.endpoint.is_serverless():
             return {
                 "endpoint_type": "serverless",
-                "data_capture": "not supported",
                 "baseline": "not supported",
                 "monitoring_schedule": "not supported",
             }
         else:
             summary = {
                 "endpoint_type": "realtime",
-                "data_capture": self.data_capture_enabled(),
-                "capture_percent": self.data_capture_percent(),
                 "baseline": self.baseline_exists(),
                 "monitoring_schedule": self.monitoring_schedule_exists(),
                 "preprocessing": self.preprocessing_exists(),
@@ -103,22 +104,15 @@ class MonitorCore:
         Returns:
             dict: The monitoring details for the endpoint
         """
-        # Get the actual data capture path
-        actual_capture_path = self.data_capture_config()["DestinationS3Uri"]
-        if actual_capture_path != self.data_capture_path:
-            self.log.warning(
-                f"Data capture path mismatch: Expected {self.data_capture_path}, "
-                f"but found {actual_capture_path}. Using the actual path."
-            )
-            self.data_capture_path = actual_capture_path
         result = self.summary()
         info = {
-            "data_capture_path": self.data_capture_path if self.data_capture_enabled() else None,
-            "preprocessing_script_file": self.preprocessing_script_file if self.preprocessing_exists() else None,
             "monitoring_schedule_status": "Not Scheduled",
         }
         result.update(info)

+        if self.preprocessing_exists():
+            result["preprocessing_script_file"] = self.preprocessing_script_file
+
         if self.baseline_exists():
             result.update(
                 {
@@ -144,7 +138,6 @@ class MonitorCore:

         last_run = schedule_details.get("LastMonitoringExecutionSummary", {})
         if last_run:
-
             # If no inference was run since the last monitoring schedule, the
             # status will be "Failed" with reason "Job inputs had no data",
             # so we check for that and set the status to "No New Data"
@@ -162,187 +155,22 @@ class MonitorCore:

         return result

-    def enable_data_capture(self, capture_percentage=100
-        """
-        Enable data capture for the SageMaker endpoint.
+    def enable_data_capture(self, capture_percentage=100):
+        """Enable data capture for the endpoint

         Args:
-            capture_percentage (int): Percentage of
-            force (bool): If True, force reconfiguration even if data capture is already enabled.
+            capture_percentage (int): Percentage of requests to capture (0-100, default 100)
         """
-        # Early returns for cases where we can't/don't need to add data capture
         if self.endpoint.is_serverless():
             self.log.warning("Data capture is not supported for serverless endpoints.")
             return

-        if self.
-            self.log.
-            return
-
-        # Get the current endpoint configuration name for later deletion
-        current_endpoint_config_name = self.endpoint.endpoint_config_name()
-
-        # Log the data capture operation
-        self.log.important(f"Enabling Data Capture for {self.endpoint_name} --> {self.data_capture_path}")
-        self.log.important("This normally redeploys the endpoint...")
-
-        # Create and apply the data capture configuration
-        data_capture_config = DataCaptureConfig(
-            enable_capture=True,  # Required parameter
-            sampling_percentage=capture_percentage,
-            destination_s3_uri=self.data_capture_path,
-        )
-
-        # Update endpoint with the new capture configuration
-        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
-            data_capture_config=data_capture_config
-        )
-
-        # Clean up old endpoint configuration
-        self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
-
-    def data_capture_config(self):
-        """
-        Returns the complete data capture configuration from the endpoint config.
-        Returns:
-            dict: Complete DataCaptureConfig from AWS, or None if not configured
-        """
-        config_name = self.endpoint.endpoint_config_name()
-        response = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=config_name)
-        data_capture_config = response.get("DataCaptureConfig")
-        if not data_capture_config:
-            self.log.error(f"No data capture configuration found for endpoint config {config_name}")
-            return None
-        return data_capture_config
-
-    def disable_data_capture(self):
-        """
-        Disable data capture for the SageMaker endpoint.
-        """
-        # Early return if data capture isn't configured
-        if not self.data_capture_enabled():
-            self.log.important(f"Data capture is not currently enabled for {self.endpoint_name}.")
+        if self.data_capture.is_enabled():
+            self.log.info(f"Data capture is already enabled for {self.endpoint_name}.")
             return

-
-
-
-        # Log the operation
-        self.log.important(f"Disabling Data Capture for {self.endpoint_name}")
-        self.log.important("This normally redeploys the endpoint...")
-
-        # Create a configuration with capture disabled
-        data_capture_config = DataCaptureConfig(enable_capture=False, destination_s3_uri=self.data_capture_path)
-
-        # Update endpoint with the new configuration
-        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
-            data_capture_config=data_capture_config
-        )
-
-        # Clean up old endpoint configuration
-        self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
-
-    def data_capture_enabled(self):
-        """
-        Check if data capture is already configured on the endpoint.
-        Args:
-            capture_percentage (int): Expected data capture percentage.
-        Returns:
-            bool: True if data capture is already configured, False otherwise.
-        """
-        try:
-            endpoint_config_name = self.endpoint.endpoint_config_name()
-            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
-            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
-
-            # Check if data capture is enabled and the percentage matches
-            is_enabled = data_capture_config.get("EnableCapture", False)
-            return is_enabled
-        except Exception as e:
-            self.log.error(f"Error checking data capture configuration: {e}")
-            return False
-
-    def data_capture_percent(self):
-        """
-        Get the data capture percentage from the endpoint configuration.
-
-        Returns:
-            int: Data capture percentage if enabled, None otherwise.
-        """
-        try:
-            endpoint_config_name = self.endpoint.endpoint_config_name()
-            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
-            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
-
-            # Check if data capture is enabled and return the percentage
-            if data_capture_config.get("EnableCapture", False):
-                return data_capture_config.get("InitialSamplingPercentage", 0)
-            else:
-                return None
-        except Exception as e:
-            self.log.error(f"Error checking data capture percentage: {e}")
-            return None
-
-    def get_captured_data(self, max_files=None, add_timestamp=True) -> Tuple[pd.DataFrame, pd.DataFrame]:
-        """
-        Read and process captured data from S3.
-
-        Args:
-            max_files (int, optional): Maximum number of files to process.
-                Defaults to None to process all files.
-            add_timestamp (bool, optional): Whether to add a timestamp column to the DataFrame.
-
-        Returns:
-            Tuple[pd.DataFrame, pd.DataFrame]: Processed input and output DataFrames.
-        """
-        # List files in the specified S3 path
-        files = wr.s3.list_objects(self.data_capture_path)
-        if not files:
-            self.log.warning(f"No data capture files found in {self.data_capture_path}.")
-            return pd.DataFrame(), pd.DataFrame()
-
-        self.log.info(f"Found {len(files)} files in {self.data_capture_path}.")
-
-        # Sort files by timestamp (assuming the naming convention includes timestamp)
-        files.sort()
-
-        # Select files to process
-        if max_files is None:
-            files_to_process = files
-            self.log.info(f"Processing all {len(files)} files.")
-        else:
-            files_to_process = files[-max_files:] if files else []
-            self.log.info(f"Processing the {len(files_to_process)} most recent file(s).")
-
-        # Process each file
-        all_input_dfs = []
-        all_output_dfs = []
-        for file_path in files_to_process:
-            self.log.info(f"Processing {file_path}...")
-            try:
-                # Read the JSON lines file
-                df = wr.s3.read_json(path=file_path, lines=True)
-                if not df.empty:
-                    input_df, output_df = process_data_capture(df)
-                    # Generate a timestamp column if requested
-                    if add_timestamp:
-                        # Get file metadata to extract last modified time
-                        file_metadata = wr.s3.describe_objects(path=file_path)
-                        timestamp = file_metadata[file_path]["LastModified"]
-                        output_df["timestamp"] = timestamp
-
-                    # Append the processed DataFrames to the lists
-                    all_input_dfs.append(input_df)
-                    all_output_dfs.append(output_df)
-            except Exception as e:
-                self.log.warning(f"Error processing file {file_path}: {e}")
-
-        # Combine all DataFrames
-        if not all_input_dfs or not all_output_dfs:
-            self.log.warning("No valid data was processed from the captured files.")
-            return pd.DataFrame(), pd.DataFrame()
-
-        return pd.concat(all_input_dfs, ignore_index=True), pd.concat(all_output_dfs, ignore_index=True)
+        self.data_capture.enable(capture_percentage=capture_percentage)
+        self.log.important(f"Enabled data capture for {self.endpoint_name} at {self.data_capture_path}")

     def baseline_exists(self) -> bool:
         """
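After the refactor, `enable_data_capture()` only guards serverless endpoints, short-circuits when capture is already on, and then delegates to `DataCaptureCore.enable()`; the roughly 190 removed lines (endpoint-config rewrites via `DataCaptureConfig`/`Predictor` and capture-file parsing) moved into `DataCaptureCore`. A minimal caller-side sketch; the endpoint name is illustrative:

```python
# Caller-side sketch of the slimmed-down data capture flow; endpoint name is illustrative.
from workbench.core.artifacts.monitor_core import MonitorCore

mm = MonitorCore("abalone-regression-rt")
mm.enable_data_capture(capture_percentage=100)  # no-op for serverless or already-enabled endpoints
print(mm.summary())
```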
@@ -533,6 +361,11 @@ class MonitorCore:
             self.log.warning("If you want to create another one, delete existing schedule first.")
             return

+        # Check if data capture is enabled, if not enable it
+        if not self.data_capture.is_enabled():
+            self.log.warning("Data capture is not enabled for this endpoint. Enabling it now...")
+            self.enable_data_capture(capture_percentage=100)
+
         # Set up a NEW monitoring schedule
         schedule_args = {
             "monitor_schedule_name": self.monitoring_schedule_name,
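`create_monitoring_schedule()` now turns data capture on automatically when it is missing, so the explicit enable call becomes optional. A hedged end-to-end setup sketch that mirrors the `__main__` exercise code later in this diff; the endpoint name is illustrative and the exact ordering follows that exercise code rather than any documented requirement:

```python
# End-to-end monitoring setup sketch (mirrors the __main__ section of monitor_core.py).
from workbench.core.artifacts.monitor_core import MonitorCore

mm = MonitorCore("abalone-regression-rt")   # illustrative endpoint name
mm.create_baseline()
mm.create_monitoring_schedule()             # enables data capture itself if needed
print(mm.details())
```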
@@ -577,33 +410,6 @@ class MonitorCore:
         self.model_monitor.delete_monitoring_schedule()
         self.log.important(f"Deleted monitoring schedule for {self.endpoint_name}.")

-    # Put this functionality into this class
-    """
-    executions = my_monitor.list_executions()
-    latest_execution = executions[-1]
-
-    latest_execution.describe()['ProcessingJobStatus']
-    latest_execution.describe()['ExitMessage']
-    Here are the possible terminal states and what each of them means:
-
-    - Completed - This means the monitoring execution completed and no issues were found in the violations report.
-    - CompletedWithViolations - This means the execution completed, but constraint violations were detected.
-    - Failed - The monitoring execution failed, maybe due to client error
-      (perhaps incorrect role premissions) or infrastructure issues. Further
-      examination of the FailureReason and ExitMessage is necessary to identify what exactly happened.
-    - Stopped - job exceeded the max runtime or was manually stopped.
-    You can also get the S3 URI for the output with latest_execution.output.destination and analyze the results.
-
-    Visualize results
-    You can use the monitor object to gather reports for visualization:
-
-    suggested_constraints = my_monitor.suggested_constraints()
-    baseline_statistics = my_monitor.baseline_statistics()
-
-    latest_monitoring_violations = my_monitor.latest_monitoring_constraint_violations()
-    latest_monitoring_statistics = my_monitor.latest_monitoring_statistics()
-    """
-
     def get_monitoring_results(self, max_results=10) -> pd.DataFrame:
         """Get the results of monitoring executions

@@ -758,7 +564,7 @@ class MonitorCore:
         Returns:
             str: String representation of this MonitorCore object
         """
-        summary_dict =
+        summary_dict = self.summary()
         summary_items = [f" {repr(key)}: {repr(value)}" for key, value in summary_dict.items()]
         summary_str = f"{self.__class__.__name__}: {self.endpoint_name}\n" + ",\n".join(summary_items)
         return summary_str
@@ -775,7 +581,6 @@ if __name__ == "__main__":

     # Create the Class and test it out
     endpoint_name = "abalone-regression-rt"
-    endpoint_name = "logd-dev-reg-rt"
     my_endpoint = EndpointCore(endpoint_name)
     if not my_endpoint.exists():
         print(f"Endpoint {endpoint_name} does not exist.")
@@ -788,11 +593,10 @@ if __name__ == "__main__":
     # Check the details of the monitoring class
     pprint(mm.details())

-    # Enable data capture
-    mm.enable_data_capture()
+    # Enable data capture (if not already enabled)
+    mm.enable_data_capture(capture_percentage=100)

     # Create a baseline for monitoring
-    # mm.create_baseline(recreate=True)
     mm.create_baseline()

     # Check the monitoring outputs
@@ -804,30 +608,11 @@ if __name__ == "__main__":
     pprint(mm.get_constraints())

     print("\nStatistics...")
-    print(mm.get_statistics())
+    print(str(mm.get_statistics())[:1000])  # Print only first 1000 characters

     # Set up the monitoring schedule (if it doesn't already exist)
     mm.create_monitoring_schedule()

-    #
-    # Test the data capture by running some predictions
-    #
-
-    # Make predictions on the Endpoint using the FeatureSet evaluation data
-    # pred_df = my_endpoint.auto_inference()
-    # print(pred_df.head())
-
-    # Check that data capture is working
-    input_df, output_df = mm.get_captured_data()
-    if input_df.empty or output_df.empty:
-        print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
-    else:
-        print("Found data capture files")
-        print("Input")
-        print(input_df.head())
-        print("Output")
-        print(output_df.head())
-
     # Test update_constraints (commented out for now)
     # print("\nTesting constraint updates...")
     # custom_constraints = {"sex": {"allowed_values": ["M", "F", "I"]}, "length": {"min": 0.0, "max": 1.0}}
@@ -846,7 +631,7 @@ if __name__ == "__main__":
     print("\nTesting execution details retrieval...")
     if not results_df.empty:
         latest_execution_arn = results_df.iloc[0]["processing_job_arn"]
-        execution_details = mm.get_execution_details(latest_execution_arn)
+        execution_details = mm.get_execution_details(latest_execution_arn) if latest_execution_arn else None
         if execution_details:
             print(f"Execution details for {latest_execution_arn}:")
             pprint(execution_details)
workbench/core/cloud_platform/aws/aws_account_clamp.py:

@@ -54,7 +54,11 @@ class AWSAccountClamp:

         # Check our Assume Role
         self.log.info("Checking Workbench Assumed Role...")
-        self.aws_session.assumed_role_info()
+        role_info = self.aws_session.assumed_role_info()
+        self.log.info(f"Assumed Role: {role_info}")
+
+        # Check if we have tag write permissions (if we don't, we are read-only)
+        self.read_only = not self.check_tag_permissions()

         # Check our Workbench API Key and Load the License
         self.log.info("Checking Workbench API License...")
@@ -138,6 +142,45 @@ class AWSAccountClamp:
         """
         return self.boto3_session.client("sagemaker")

+    def check_tag_permissions(self):
+        """Check if current role has permission to add tags to SageMaker endpoints.
+
+        Returns:
+            bool: True if AddTags is allowed, False otherwise
+        """
+        try:
+            sagemaker = self.boto3_session.client("sagemaker")
+
+            # Use a non-existent endpoint name
+            fake_endpoint = "workbench-permission-check-dummy-endpoint"
+
+            # Try to add tags to the non-existent endpoint
+            sagemaker.add_tags(
+                ResourceArn=f"arn:aws:sagemaker:{self.region}:{self.account_id}:endpoint/{fake_endpoint}",
+                Tags=[{"Key": "PermissionCheck", "Value": "Test"}],
+            )
+
+            # If we get here, we have permission (but endpoint doesn't exist)
+            return True
+
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+
+            # AccessDeniedException = no permission
+            if error_code == "AccessDeniedException":
+                self.log.debug("No AddTags permission (AccessDeniedException)")
+                return False
+
+            # ResourceNotFound = we have permission, but endpoint doesn't exist
+            elif error_code in ["ResourceNotFound", "ValidationException"]:
+                self.log.debug("AddTags permission verified (resource not found)")
+                return True
+
+            # Unexpected error, assume no permission for safety
+            else:
+                self.log.debug(f"Unexpected error checking permissions: {error_code}")
+                return False
+

 if __name__ == "__main__":
     """Exercise the AWS Account Clamp Class"""
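`check_tag_permissions()` probes `AddTags` against a dummy endpoint ARN: `AccessDeniedException` means the role cannot tag (so the clamp flags the session as read-only), while a not-found/validation error proves the call was authorized. A standalone sketch of the same pattern with plain boto3; the region and account ID are illustrative:

```python
# Standalone sketch of the AddTags permission probe; region/account are illustrative.
import boto3
from botocore.exceptions import ClientError


def can_add_sagemaker_tags(region: str, account_id: str) -> bool:
    sm = boto3.client("sagemaker", region_name=region)
    arn = f"arn:aws:sagemaker:{region}:{account_id}:endpoint/permission-probe-dummy"
    try:
        sm.add_tags(ResourceArn=arn, Tags=[{"Key": "PermissionCheck", "Value": "Test"}])
        return True  # the call succeeded, so tagging is authorized
    except ClientError as e:
        code = e.response["Error"]["Code"]
        # AccessDeniedException -> tagging blocked; not-found/validation -> authorized.
        return code in ("ResourceNotFound", "ValidationException")


print(can_add_sagemaker_tags("us-west-2", "123456789012"))
```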
@@ -162,3 +205,9 @@ if __name__ == "__main__":
     print("\n\n*** AWS Sagemaker Session/Client Check ***")
     sm_client = aws_account_clamp.sagemaker_client()
     print(sm_client.list_feature_groups()["FeatureGroupSummaries"])
+
+    print("\n\n*** AWS Tag Permission Check ***")
+    if aws_account_clamp.check_tag_permissions():
+        print("Tag Permission Check Success...")
+    else:
+        print("Tag Permission Check Failed...")
workbench/core/cloud_platform/aws/aws_meta.py:

@@ -196,7 +196,9 @@ class AWSMeta:

         # Return the summary as a DataFrame
         df = pd.DataFrame(data_summary).convert_dtypes()
-
+        if not df.empty:
+            df.sort_values(by="Created", ascending=False, inplace=True)
+        return df

     def models(self, details: bool = False) -> pd.DataFrame:
         """Get a summary of the Models in AWS.
@@ -256,7 +258,9 @@ class AWSMeta:

         # Return the summary as a DataFrame
         df = pd.DataFrame(model_summary).convert_dtypes()
-
+        if not df.empty:
+            df.sort_values(by="Created", ascending=False, inplace=True)
+        return df

     def endpoints(self, details: bool = False) -> pd.DataFrame:
         """Get a summary of the Endpoints in AWS.
@@ -308,7 +312,7 @@ class AWSMeta:
                     "Status": endpoint_details.get("EndpointStatus", "-"),
                     "Config": endpoint_details.get("EndpointConfigName", "-"),
                     "Variant": endpoint_details["config"]["variant"],
-                    "Capture": str(endpoint_details.get("DataCaptureConfig", {}).get("EnableCapture", "
+                    "Capture": str(endpoint_details.get("DataCaptureConfig", {}).get("EnableCapture", "-")),
                     "Samp(%)": str(endpoint_details.get("DataCaptureConfig", {}).get("CurrentSamplingPercentage", "-")),
                     "Tags": aws_tags.get("workbench_tags", "-"),
                     "Monitored": endpoint_details["monitored"],
@@ -317,7 +321,9 @@ class AWSMeta:

         # Return the summary as a DataFrame
         df = pd.DataFrame(data_summary).convert_dtypes()
-
+        if not df.empty:
+            df.sort_values(by="Created", ascending=False, inplace=True)
+        return df

     def _endpoint_config_info(self, endpoint_config_name: str) -> dict:
         """Internal: Get the Endpoint Configuration information for the given endpoint config name.
@@ -657,7 +663,8 @@ class AWSMeta:
         df = pd.DataFrame(data_summary).convert_dtypes()

         # Sort by the Modified column
-
+        if not df.empty:
+            df = df.sort_values(by="Modified", ascending=False)
         return df

     def _aws_pipelines(self) -> pd.DataFrame:
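The repeated `if not df.empty:` guards in these AWSMeta hunks exist because sorting an empty DataFrame by a named column raises `KeyError`: a frame built from an empty summary list has no columns at all. A tiny pandas demo of the failure mode and the guard:

```python
# Why the empty-frame guard matters: an empty summary list yields a DataFrame
# with no columns, so sort_values(by="Created") raises KeyError.
import pandas as pd

df = pd.DataFrame([]).convert_dtypes()
try:
    df.sort_values(by="Created", ascending=False, inplace=True)
except KeyError as err:
    print(f"sort_values on empty frame failed: {err}")

if not df.empty:  # the guard added throughout aws_meta.py
    df.sort_values(by="Created", ascending=False, inplace=True)
print(df)
```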
|