workbench 0.8.162-py3-none-any.whl → 0.8.220-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/api/parameter_store.py
CHANGED
@@ -1,13 +1,10 @@
 """ParameterStore: Manages Workbench parameters in a Cloud Based Parameter Store."""

-from typing import Union
-import logging
-
 # Workbench Imports
-from workbench.core.
+from workbench.core.artifacts.parameter_store_core import ParameterStoreCore


-class ParameterStore(
+class ParameterStore(ParameterStoreCore):
     """ParameterStore: Manages Workbench parameters in a Cloud Based Parameter Store.

     Common Usage:
@@ -43,56 +40,10 @@ class ParameterStore(AWSParameterStore):

     def __init__(self):
         """ParameterStore Init Method"""
-        self.log = logging.getLogger("workbench")

-        # Initialize
+        # Initialize parent class
         super().__init__()

-    def list(self, prefix: str = None) -> list:
-        """List all parameters in the AWS Parameter Store, optionally filtering by a prefix.
-
-        Args:
-            prefix (str, optional): A prefix to filter the parameters by. Defaults to None.
-
-        Returns:
-            list: A list of parameter names and details.
-        """
-        return super().list(prefix=prefix)
-
-    def get(self, name: str, warn: bool = True, decrypt: bool = True) -> Union[str, list, dict, None]:
-        """Retrieve a parameter value from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to retrieve.
-            warn (bool): Whether to log a warning if the parameter is not found.
-            decrypt (bool): Whether to decrypt secure string parameters.
-
-        Returns:
-            Union[str, list, dict, None]: The value of the parameter or None if not found.
-        """
-        return super().get(name=name, warn=warn, decrypt=decrypt)
-
-    def upsert(self, name: str, value):
-        """Insert or update a parameter in the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter.
-            value (str | list | dict): The value of the parameter.
-        """
-        super().upsert(name=name, value=value)
-
-    def delete(self, name: str):
-        """Delete a parameter from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to delete.
-        """
-        super().delete(name=name)
-
-    def __repr__(self):
-        """Return a string representation of the ParameterStore object."""
-        return super().__repr__()
-

 if __name__ == "__main__":
     """Exercise the ParameterStore Class"""
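With the hand-written pass-through methods removed, `ParameterStore` now inherits `list()`, `get()`, `upsert()`, and `delete()` directly from `ParameterStoreCore`. A minimal usage sketch, assuming the `workbench.api` import path and a configured Workbench AWS account (the parameter name is illustrative):

```python
from workbench.api.parameter_store import ParameterStore

params = ParameterStore()

# Store a value (str, list, or dict per the upsert docstring), read it back, then clean up
params.upsert("/examples/my_param", {"threshold": 0.5})  # illustrative parameter name
print(params.list(prefix="/examples"))
print(params.get("/examples/my_param"))
params.delete("/examples/my_param")
```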
workbench/cached/cached_model.py
CHANGED
@@ -72,11 +72,11 @@ class CachedModel(CachedArtifactMixin, ModelCore):
         return super().list_inference_runs()

     @CachedArtifactMixin.cache_result
-    def get_inference_metrics(self, capture_name: str = "
+    def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the captured prediction results for this model

         Args:
-            capture_name (str, optional): Specific capture_name (default:
+            capture_name (str, optional): Specific capture_name (default: auto)

         Returns:
             pd.DataFrame: DataFrame of the Captured Metrics (might be None)
@@ -101,11 +101,11 @@ class CachedModel(CachedArtifactMixin, ModelCore):
         return df

     @CachedArtifactMixin.cache_result
-    def confusion_matrix(self, capture_name: str = "
+    def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the confusion matrix for the model

         Args:
-            capture_name (str, optional): Specific capture_name (default:
+            capture_name (str, optional): Specific capture_name (default: auto)

         Returns:
             pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
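Both cached lookups now default `capture_name` to `"auto"`. A hedged sketch of the call sites (the import path and model name are illustrative):

```python
from workbench.cached.cached_model import CachedModel

model = CachedModel("wine-classification")  # hypothetical model name

# Relies on the new capture_name="auto" default; either call may return None if nothing was captured
metrics_df = model.get_inference_metrics()
cm_df = model.confusion_matrix()
```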
workbench/core/artifacts/__init__.py
CHANGED
@@ -15,7 +15,16 @@ from .artifact import Artifact
 from .athena_source import AthenaSource
 from .data_source_abstract import DataSourceAbstract
 from .feature_set_core import FeatureSetCore
-from .model_core import ModelCore, ModelType
+from .model_core import ModelCore, ModelType, ModelFramework
 from .endpoint_core import EndpointCore

-__all__ = [
+__all__ = [
+    "Artifact",
+    "AthenaSource",
+    "DataSourceAbstract",
+    "FeatureSetCore",
+    "ModelCore",
+    "ModelType",
+    "ModelFramework",
+    "EndpointCore",
+]
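The expanded `__all__` spells out the subpackage's public surface, including the new `ModelFramework` export. A sketch of the corresponding import:

```python
# Everything listed in __all__ is importable from the subpackage root
from workbench.core.artifacts import ModelCore, ModelType, ModelFramework, EndpointCore
```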
workbench/core/artifacts/artifact.py
CHANGED
@@ -8,8 +8,8 @@ from typing import Union

 # Workbench Imports
 from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
-from workbench.core.
-from workbench.core.
+from workbench.core.artifacts.parameter_store_core import ParameterStoreCore
+from workbench.core.artifacts.df_store_core import DFStoreCore
 from workbench.utils.aws_utils import dict_to_aws_tags
 from workbench.utils.config_manager import ConfigManager, FatalConfigError
 from workbench.core.cloud_platform.cloud_meta import CloudMeta
@@ -48,11 +48,11 @@ class Artifact(ABC):
     tag_delimiter = "::"

     # Grab our Dataframe Cache Storage
-    df_cache =
+    df_cache = DFStoreCore(path_prefix="/workbench/dataframe_cache")

     # Artifact may want to use the Parameter Store or Dataframe Store
-    param_store =
-    df_store =
+    param_store = ParameterStoreCore()
+    df_store = DFStoreCore()

     def __init__(self, name: str, use_cached_meta: bool = False):
         """Initialize the Artifact Base Class
@@ -236,6 +236,12 @@
         This functionality will work for FeatureSets, Models, and Endpoints
         but not for DataSources. The DataSource class overrides this method.
         """
+
+        # Check for ReadOnly Role
+        if self.aws_account_clamp.read_only:
+            self.log.info("Cannot add metadata with a ReadOnly Permissions...")
+            return
+
         # Sanity check
         aws_arn = self.arn()
         if aws_arn is None:
@@ -444,10 +450,12 @@

 if __name__ == "__main__":
     """Exercise the Artifact Class"""
-    from workbench.api
-
+    from workbench.api import DataSource, FeatureSet, Endpoint
+
+    # Grab an Endpoint (which is a subclass of Artifact)
+    end = Endpoint("wine-classification")

-    #
+    # Grab a DataSource (which is a subclass of Artifact)
     data_source = DataSource("test_data")

     # Just some random tests
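The new ReadOnly check short-circuits metadata writes before any AWS call is made. A standalone sketch of the same guard-clause pattern (the stub below stands in for `AWSAccountClamp`, whose `read_only` flag is only referenced, not defined, in this diff):

```python
import logging

log = logging.getLogger("workbench")


class ReadOnlyGuardSketch:
    """Illustrates the guard added to Artifact's metadata upsert path."""

    def __init__(self, read_only: bool):
        # Stand-in for self.aws_account_clamp.read_only on Artifact
        self.read_only = read_only

    def upsert_metadata(self, key: str, value: str) -> None:
        # Check for ReadOnly Role (mirrors the added Artifact check)
        if self.read_only:
            log.info("Cannot add metadata with ReadOnly permissions...")
            return
        # Real Artifact code would continue on to tag the AWS resource (ARN lookup, etc.)
        print(f"would write {key}={value}")


ReadOnlyGuardSketch(read_only=True).upsert_metadata("owner", "workbench")
```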
workbench/core/artifacts/data_capture_core.py
ADDED
@@ -0,0 +1,355 @@
+"""DataCaptureCore class for managing SageMaker endpoint data capture"""
+
+import logging
+import re
+import time
+from datetime import datetime
+from typing import Tuple
+import pandas as pd
+from sagemaker import Predictor
+from sagemaker.model_monitor import DataCaptureConfig
+import awswrangler as wr
+
+# Workbench Imports
+from workbench.core.artifacts.endpoint_core import EndpointCore
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.utils.monitor_utils import process_data_capture
+
+# Setup logging
+log = logging.getLogger("workbench")
+
+
+class DataCaptureCore:
+    """Manages data capture configuration and retrieval for SageMaker endpoints"""
+
+    def __init__(self, endpoint_name: str):
+        """DataCaptureCore Class
+
+        Args:
+            endpoint_name (str): Name of the endpoint to manage data capture for
+        """
+        self.log = logging.getLogger("workbench")
+        self.endpoint_name = endpoint_name
+        self.endpoint = EndpointCore(self.endpoint_name)
+
+        # Initialize Class Attributes
+        self.sagemaker_session = self.endpoint.sm_session
+        self.sagemaker_client = self.endpoint.sm_client
+        self.data_capture_path = self.endpoint.endpoint_data_capture_path
+        self.workbench_role_arn = AWSAccountClamp().aws_session.get_workbench_execution_role_arn()
+
+    def summary(self) -> dict:
+        """Return the summary of data capture configuration
+
+        Returns:
+            dict: Summary of data capture status
+        """
+        if self.endpoint.is_serverless():
+            return {"endpoint_type": "serverless", "data_capture": "not supported"}
+        else:
+            return {
+                "endpoint_type": "realtime",
+                "data_capture_enabled": self.is_enabled(),
+                "capture_percentage": self.capture_percentage(),
+                "capture_modes": self.capture_modes() if self.is_enabled() else [],
+                "data_capture_path": self.data_capture_path if self.is_enabled() else None,
+            }
+
+    def enable(self, capture_percentage=100, capture_options=None, force_redeploy=False):
+        """
+        Enable data capture for the SageMaker endpoint.
+
+        Args:
+            capture_percentage (int): Percentage of data to capture. Defaults to 100.
+            capture_options (list): List of what to capture - ["REQUEST"], ["RESPONSE"], or ["REQUEST", "RESPONSE"].
+                Defaults to ["REQUEST", "RESPONSE"] to capture both.
+            force_redeploy (bool): If True, force redeployment even if data capture is already enabled.
+        """
+        # Early returns for cases where we can't/don't need to add data capture
+        if self.endpoint.is_serverless():
+            self.log.warning("Data capture is not supported for serverless endpoints.")
+            return
+
+        # Default to capturing both if not specified
+        if capture_options is None:
+            capture_options = ["REQUEST", "RESPONSE"]
+
+        # Validate capture_options
+        valid_options = {"REQUEST", "RESPONSE"}
+        if not all(opt in valid_options for opt in capture_options):
+            self.log.error("Invalid capture_options. Must be a list containing 'REQUEST' and/or 'RESPONSE'")
+            return
+
+        if self.is_enabled() and not force_redeploy:
+            self.log.important(f"Data capture already configured for {self.endpoint_name}.")
+            return
+
+        # Get the current endpoint configuration name for later deletion
+        current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+        # Log the data capture operation
+        self.log.important(f"Enabling Data Capture for {self.endpoint_name} --> {self.data_capture_path}")
+        self.log.important(f"Capturing: {', '.join(capture_options)} at {capture_percentage}% sampling")
+        self.log.important("This will redeploy the endpoint...")
+
+        # Create and apply the data capture configuration
+        data_capture_config = DataCaptureConfig(
+            enable_capture=True,
+            sampling_percentage=capture_percentage,
+            destination_s3_uri=self.data_capture_path,
+            capture_options=capture_options,
+        )
+
+        # Update endpoint with the new capture configuration
+        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+            data_capture_config=data_capture_config
+        )
+
+        # Clean up old endpoint configuration
+        try:
+            self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+            self.log.info(f"Deleted old endpoint configuration: {current_endpoint_config_name}")
+        except Exception as e:
+            self.log.warning(f"Could not delete old endpoint configuration {current_endpoint_config_name}: {e}")
+
+    def disable(self):
+        """
+        Disable data capture for the SageMaker endpoint.
+        """
+        # Early return if data capture isn't configured
+        if not self.is_enabled():
+            self.log.important(f"Data capture is not currently enabled for {self.endpoint_name}.")
+            return
+
+        # Get the current endpoint configuration name for later deletion
+        current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+        # Log the operation
+        self.log.important(f"Disabling Data Capture for {self.endpoint_name}")
+        self.log.important("This normally redeploys the endpoint...")
+
+        # Create a configuration with capture disabled
+        data_capture_config = DataCaptureConfig(enable_capture=False, destination_s3_uri=self.data_capture_path)
+
+        # Update endpoint with the new configuration
+        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+            data_capture_config=data_capture_config
+        )
+
+        # Clean up old endpoint configuration
+        self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+
+    def is_enabled(self) -> bool:
+        """
+        Check if data capture is enabled on the endpoint.
+
+        Returns:
+            bool: True if data capture is enabled, False otherwise.
+        """
+        try:
+            endpoint_config_name = self.endpoint.endpoint_config_name()
+            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+            # Check if data capture is enabled
+            is_enabled = data_capture_config.get("EnableCapture", False)
+            return is_enabled
+        except Exception as e:
+            self.log.error(f"Error checking data capture configuration: {e}")
+            return False
+
+    def capture_percentage(self) -> int:
+        """
+        Get the data capture percentage from the endpoint configuration.
+
+        Returns:
+            int: Data capture percentage if enabled, None otherwise.
+        """
+        try:
+            endpoint_config_name = self.endpoint.endpoint_config_name()
+            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+            # Check if data capture is enabled and return the percentage
+            if data_capture_config.get("EnableCapture", False):
+                return data_capture_config.get("InitialSamplingPercentage", 0)
+            else:
+                return None
+        except Exception as e:
+            self.log.error(f"Error checking data capture percentage: {e}")
+            return None
+
+    def get_config(self) -> dict:
+        """
+        Returns the complete data capture configuration from the endpoint config.
+
+        Returns:
+            dict: Complete DataCaptureConfig from AWS, or None if not configured
+        """
+        config_name = self.endpoint.endpoint_config_name()
+        response = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=config_name)
+        data_capture_config = response.get("DataCaptureConfig")
+        if not data_capture_config:
+            self.log.error(f"No data capture configuration found for endpoint config {config_name}")
+            return None
+        return data_capture_config
+
+    def capture_modes(self) -> list:
+        """Get the current capture modes (REQUEST/RESPONSE)"""
+        if not self.is_enabled():
+            return []
+
+        config = self.get_config()
+        if not config:
+            return []
+
+        capture_options = config.get("CaptureOptions", [])
+        modes = [opt.get("CaptureMode") for opt in capture_options]
+        return ["REQUEST" if m == "Input" else "RESPONSE" for m in modes if m]
+
+    def get_captured_data(self, from_date: str = None, add_timestamp: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
+        """
+        Read and process captured data from S3.
+
+        Args:
+            from_date (str, optional): Only process files from this date onwards (YYYY-MM-DD format).
+                Defaults to None to process all files.
+            add_timestamp (bool, optional): Whether to add a timestamp column to the DataFrame.
+
+        Returns:
+            Tuple[pd.DataFrame, pd.DataFrame]: Processed input and output DataFrames.
+        """
+        files = wr.s3.list_objects(self.data_capture_path)
+        if not files:
+            self.log.warning(f"No data capture files found in {self.data_capture_path}.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        # Filter by date if specified
+        if from_date:
+            from_date_obj = datetime.strptime(from_date, "%Y-%m-%d").date()
+            files = [f for f in files if self._file_date_filter(f, from_date_obj)]
+            self.log.info(f"Processing {len(files)} files from {from_date} onwards.")
+        else:
+            self.log.info(f"Processing all {len(files)} files...")
+
+        # Check if any files remain after filtering
+        if not files:
+            self.log.info("No files to process after date filtering.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        # Sort files by name (assumed to include timestamp)
+        files.sort()
+
+        # Get all timestamps in one batch if needed
+        timestamps = {}
+        if add_timestamp:
+            # Batch describe operation - much more efficient than per-file calls
+            timestamps = wr.s3.describe_objects(path=files)
+
+        # Process files using concurrent.futures
+        start_time = time.time()
+
+        def process_single_file(file_path):
+            """Process a single file and return input/output DataFrames."""
+            try:
+                log.debug(f"Processing file: {file_path}...")
+                df = wr.s3.read_json(path=file_path, lines=True)
+                if not df.empty:
+                    input_df, output_df = process_data_capture(df)
+                    if add_timestamp and file_path in timestamps:
+                        output_df["timestamp"] = timestamps[file_path]["LastModified"]
+                    return input_df, output_df
+                return pd.DataFrame(), pd.DataFrame()
+            except Exception as e:
+                self.log.warning(f"Error processing {file_path}: {e}")
+                return pd.DataFrame(), pd.DataFrame()
+
+        # Use ThreadPoolExecutor for I/O-bound operations
+        from concurrent.futures import ThreadPoolExecutor
+
+        max_workers = min(32, len(files))  # Cap at 32 threads or number of files
+
+        all_input_dfs, all_output_dfs = [], []
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [executor.submit(process_single_file, file_path) for file_path in files]
+            for future in futures:
+                input_df, output_df = future.result()
+                if not input_df.empty:
+                    all_input_dfs.append(input_df)
+                if not output_df.empty:
+                    all_output_dfs.append(output_df)
+
+        if not all_input_dfs:
+            self.log.warning("No valid data was processed.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        input_df = pd.concat(all_input_dfs, ignore_index=True)
+        output_df = pd.concat(all_output_dfs, ignore_index=True)
+
+        elapsed_time = time.time() - start_time
+        self.log.info(f"Processed {len(files)} files in {elapsed_time:.2f} seconds.")
+        return input_df, output_df
+
+    def _file_date_filter(self, file_path, from_date_obj):
+        """Extract date from S3 path and compare with from_date."""
+        try:
+            # Match YYYY/MM/DD pattern in the path
+            date_match = re.search(r"/(\d{4})/(\d{2})/(\d{2})/", file_path)
+            if date_match:
+                year, month, day = date_match.groups()
+                file_date = datetime(int(year), int(month), int(day)).date()
+                return file_date >= from_date_obj
+            return False  # No date pattern found
+        except ValueError:
+            return False
+
+    def __repr__(self) -> str:
+        """String representation of this DataCaptureCore object
+
+        Returns:
+            str: String representation of this DataCaptureCore object
+        """
+        summary_dict = self.summary()
+        summary_items = [f"  {repr(key)}: {repr(value)}" for key, value in summary_dict.items()]
+        summary_str = f"{self.__class__.__name__}: {self.endpoint_name}\n" + ",\n".join(summary_items)
+        return summary_str
+
+
+# Test function for the class
+if __name__ == "__main__":
+    """Exercise the MonitorCore class"""
+    from pprint import pprint
+
+    # Set options for actually seeing the dataframe
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", None)
+
+    # Create the Class and test it out
+    endpoint_name = "abalone-regression-rt"
+    my_endpoint = EndpointCore(endpoint_name)
+    if not my_endpoint.exists():
+        print(f"Endpoint {endpoint_name} does not exist.")
+        exit(1)
+    dc = my_endpoint.data_capture()
+
+    # Check the summary of the data capture class
+    pprint(dc.summary())
+
+    # Enable data capture on the endpoint
+    # dc.enable(force_redeploy=True)
+    my_endpoint.enable_data_capture()
+
+    # Test the data capture by running some predictions
+    # pred_df = my_endpoint.auto_inference()
+    # print(pred_df.head())
+
+    # Check that data capture is working
+    input_df, output_df = dc.get_captured_data(from_date="2025-09-01")
+    if input_df.empty and output_df.empty:
+        print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
+    else:
+        print("Found data capture files")
+        print("Input")
+        print(input_df.head())
+        print("Output")
+        print(output_df.head())
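The `__main__` block above reaches `DataCaptureCore` through `EndpointCore.data_capture()`; the class can also be used directly. A hedged usage sketch (the endpoint name is illustrative, and captured files for a freshly enabled endpoint can take a few minutes to appear):

```python
from workbench.core.artifacts.data_capture_core import DataCaptureCore

dc = DataCaptureCore("abalone-regression-rt")  # name of an existing realtime endpoint

# Inspect the current configuration; enabling capture redeploys the endpoint
print(dc.summary())
if not dc.is_enabled():
    dc.enable(capture_percentage=100, capture_options=["REQUEST", "RESPONSE"])

# Pull back everything captured since a given date
input_df, output_df = dc.get_captured_data(from_date="2025-09-01")
print(input_df.head())
print(output_df.head())
```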
workbench/core/artifacts/df_store_core.py
ADDED
@@ -0,0 +1,114 @@
+"""DFStoreCore: Fast/efficient storage of DataFrames using AWS S3/Parquet/Snappy"""
+
+import logging
+from typing import Union
+
+# Workbench Imports
+from workbench.utils.config_manager import ConfigManager
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+
+# Workbench Bridges Import
+from workbench_bridges.api import DFStore as BridgesDFStore
+
+
+class DFStoreCore(BridgesDFStore):
+    """DFStoreCore: Fast/efficient storage of DataFrames using AWS S3/Parquet/Snappy
+
+    Common Usage:
+        ```python
+        df_store = DFStoreCore()
+
+        # List Data
+        df_store.list()
+
+        # Add DataFrame
+        df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
+        df_store.upsert("/test/my_data", df)
+
+        # Retrieve DataFrame
+        df = df_store.get("/test/my_data")
+        print(df)
+
+        # Delete Data
+        df_store.delete("/test/my_data")
+        ```
+    """
+
+    def __init__(self, path_prefix: Union[str, None] = None):
+        """DFStoreCore Init Method
+
+        Args:
+            path_prefix (Union[str, None], optional): Add a path prefix to storage locations (Defaults to None)
+        """
+        # Get config from workbench's systems
+        bucket = ConfigManager().get_config("WORKBENCH_BUCKET")
+        session = AWSAccountClamp().boto3_session
+
+        # Initialize parent with workbench config
+        super().__init__(path_prefix=path_prefix, s3_bucket=bucket, boto3_session=session)
+        self.log = logging.getLogger("workbench")
+
+
+if __name__ == "__main__":
+    """Exercise the DFStoreCore Class"""
+    import time
+    import pandas as pd
+
+    # Create a DFStoreCore manager
+    df_store = DFStoreCore()
+
+    # Details of the Dataframe Store
+    print("Detailed Data...")
+    print(df_store.details())
+
+    # Add a new DataFrame
+    my_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
+    df_store.upsert("/testing/test_data", my_df)
+
+    # Get the DataFrame
+    print(f"Getting data 'test_data':\n{df_store.get('/testing/test_data')}")
+
+    # Now let's test adding a Series
+    series = pd.Series([1, 2, 3, 4], name="Series")
+    df_store.upsert("/testing/test_series", series)
+    print(f"Getting data 'test_series':\n{df_store.get('/testing/test_series')}")
+
+    # Summary of the data
+    print("Summary Data...")
+    print(df_store.summary())
+
+    # Repr of the DFStoreCore object
+    print("DFStoreCore Object:")
+    print(df_store)
+
+    # Check if the data exists
+    print("Check if data exists...")
+    print(df_store.check("/testing/test_data"))
+    print(df_store.check("/testing/test_series"))
+
+    # Time the check
+    start_time = time.time()
+    print(df_store.check("/testing/test_data"))
+    print("--- Check %s seconds ---" % (time.time() - start_time))
+
+    # Now delete the test data
+    df_store.delete("/testing/test_data")
+    df_store.delete("/testing/test_series")
+
+    # Check if the data exists
+    print("Check if data exists...")
+    print(df_store.check("/testing/test_data"))
+    print(df_store.check("/testing/test_series"))
+
+    # Add a bunch of dataframes and then test recursive delete
+    for i in range(10):
+        df_store.upsert(f"/testing/data_{i}", pd.DataFrame({"A": [1, 2], "B": [3, 4]}))
+    print("Before Recursive Delete:")
+    print(df_store.summary())
+    df_store.delete_recursive("/testing")
+    print("After Recursive Delete:")
+    print(df_store.summary())
+
+    # Get a non-existent DataFrame
+    print("Getting non-existent data...")
+    print(df_store.get("/testing/no_where"))