workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of workbench has been flagged as a potentially problematic release.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/api/parameter_store.py
@@ -1,13 +1,10 @@
 """ParameterStore: Manages Workbench parameters in a Cloud Based Parameter Store."""

-from typing import Union
-import logging
-
 # Workbench Imports
-from workbench.core.cloud_platform.aws.aws_parameter_store import AWSParameterStore
+from workbench.core.artifacts.parameter_store_core import ParameterStoreCore


-class ParameterStore(AWSParameterStore):
+class ParameterStore(ParameterStoreCore):
     """ParameterStore: Manages Workbench parameters in a Cloud Based Parameter Store.

     Common Usage:
@@ -43,56 +40,10 @@ class ParameterStore(AWSParameterStore):

     def __init__(self):
         """ParameterStore Init Method"""
-        self.log = logging.getLogger("workbench")

-        # Initialize the SuperClass
+        # Initialize parent class
         super().__init__()

-    def list(self, prefix: str = None) -> list:
-        """List all parameters in the AWS Parameter Store, optionally filtering by a prefix.
-
-        Args:
-            prefix (str, optional): A prefix to filter the parameters by. Defaults to None.
-
-        Returns:
-            list: A list of parameter names and details.
-        """
-        return super().list(prefix=prefix)
-
-    def get(self, name: str, warn: bool = True, decrypt: bool = True) -> Union[str, list, dict, None]:
-        """Retrieve a parameter value from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to retrieve.
-            warn (bool): Whether to log a warning if the parameter is not found.
-            decrypt (bool): Whether to decrypt secure string parameters.
-
-        Returns:
-            Union[str, list, dict, None]: The value of the parameter or None if not found.
-        """
-        return super().get(name=name, warn=warn, decrypt=decrypt)
-
-    def upsert(self, name: str, value):
-        """Insert or update a parameter in the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter.
-            value (str | list | dict): The value of the parameter.
-        """
-        super().upsert(name=name, value=value)
-
-    def delete(self, name: str):
-        """Delete a parameter from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to delete.
-        """
-        super().delete(name=name)
-
-    def __repr__(self):
-        """Return a string representation of the ParameterStore object."""
-        return super().__repr__()
-

 if __name__ == "__main__":
     """Exercise the ParameterStore Class"""
workbench/cached/cached_model.py
@@ -72,11 +72,11 @@ class CachedModel(CachedArtifactMixin, ModelCore):
         return super().list_inference_runs()

     @CachedArtifactMixin.cache_result
-    def get_inference_metrics(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
+    def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the captured prediction results for this model

         Args:
-            capture_name (str, optional): Specific capture_name (default: latest)
+            capture_name (str, optional): Specific capture_name (default: auto)

         Returns:
             pd.DataFrame: DataFrame of the Captured Metrics (might be None)
@@ -101,11 +101,11 @@ class CachedModel(CachedArtifactMixin, ModelCore):
         return df

     @CachedArtifactMixin.cache_result
-    def confusion_matrix(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
+    def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the confusion matrix for the model

         Args:
-            capture_name (str, optional): Specific capture_name (default: latest)
+            capture_name (str, optional): Specific capture_name (default: auto)

         Returns:
             pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
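The only behavioral change in these hunks is the default capture_name switching from "latest" to "auto"; callers that pass an explicit name are unaffected. A short sketch (model and capture names are illustrative):

    from workbench.cached.cached_model import CachedModel

    model = CachedModel("abalone-regression")  # illustrative model name

    # Uses the new default capture_name="auto"
    metrics_df = model.get_inference_metrics()

    # Passing an explicit capture name behaves exactly as before
    cm_df = model.confusion_matrix(capture_name="holdout_2025_01")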
workbench/core/artifacts/__init__.py
@@ -15,7 +15,16 @@ from .artifact import Artifact
 from .athena_source import AthenaSource
 from .data_source_abstract import DataSourceAbstract
 from .feature_set_core import FeatureSetCore
-from .model_core import ModelCore, ModelType
+from .model_core import ModelCore, ModelType, ModelFramework
 from .endpoint_core import EndpointCore

-__all__ = ["Artifact", "AthenaSource", "DataSourceAbstract", "FeatureSetCore", "ModelCore", "ModelType", "EndpointCore"]
+__all__ = [
+    "Artifact",
+    "AthenaSource",
+    "DataSourceAbstract",
+    "FeatureSetCore",
+    "ModelCore",
+    "ModelType",
+    "ModelFramework",
+    "EndpointCore",
+]
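The user-visible consequence of the expanded __all__ is that ModelFramework is now part of the package's public API (a sketch; the enum's members live in model_core.py and are not shown in this diff):

    # ModelFramework joins ModelCore/ModelType in the artifacts package exports
    from workbench.core.artifacts import ModelCore, ModelType, ModelFramework

    # Wildcard imports also pick it up now that it is listed in __all__
    # from workbench.core.artifacts import *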
workbench/core/artifacts/artifact.py
@@ -8,8 +8,8 @@ from typing import Union

 # Workbench Imports
 from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
-from workbench.core.cloud_platform.aws.aws_parameter_store import AWSParameterStore as ParameterStore
-from workbench.core.cloud_platform.aws.aws_df_store import AWSDFStore as DFStore
+from workbench.core.artifacts.parameter_store_core import ParameterStoreCore
+from workbench.core.artifacts.df_store_core import DFStoreCore
 from workbench.utils.aws_utils import dict_to_aws_tags
 from workbench.utils.config_manager import ConfigManager, FatalConfigError
 from workbench.core.cloud_platform.cloud_meta import CloudMeta
@@ -48,11 +48,11 @@ class Artifact(ABC):
     tag_delimiter = "::"

     # Grab our Dataframe Cache Storage
-    df_cache = DFStore(path_prefix="/workbench/dataframe_cache")
+    df_cache = DFStoreCore(path_prefix="/workbench/dataframe_cache")

     # Artifact may want to use the Parameter Store or Dataframe Store
-    param_store = ParameterStore()
-    df_store = DFStore()
+    param_store = ParameterStoreCore()
+    df_store = DFStoreCore()

     def __init__(self, name: str, use_cached_meta: bool = False):
         """Initialize the Artifact Base Class
@@ -236,6 +236,12 @@
         This functionality will work for FeatureSets, Models, and Endpoints
         but not for DataSources. The DataSource class overrides this method.
         """
+
+        # Check for ReadOnly Role
+        if self.aws_account_clamp.read_only:
+            self.log.info("Cannot add metadata with a ReadOnly Permissions...")
+            return
+
         # Sanity check
         aws_arn = self.arn()
         if aws_arn is None:
@@ -444,10 +450,12 @@

 if __name__ == "__main__":
     """Exercise the Artifact Class"""
-    from workbench.api.data_source import DataSource
-    from workbench.api.feature_set import FeatureSet
+    from workbench.api import DataSource, FeatureSet, Endpoint
+
+    # Grab an Endpoint (which is a subclass of Artifact)
+    end = Endpoint("wine-classification")

-    # Create a DataSource (which is a subclass of Artifact)
+    # Grab a DataSource (which is a subclass of Artifact)
     data_source = DataSource("test_data")

     # Just some random tests
workbench/core/artifacts/data_capture_core.py (new file)
@@ -0,0 +1,355 @@
+"""DataCaptureCore class for managing SageMaker endpoint data capture"""
+
+import logging
+import re
+import time
+from datetime import datetime
+from typing import Tuple
+import pandas as pd
+from sagemaker import Predictor
+from sagemaker.model_monitor import DataCaptureConfig
+import awswrangler as wr
+
+# Workbench Imports
+from workbench.core.artifacts.endpoint_core import EndpointCore
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.utils.monitor_utils import process_data_capture
+
+# Setup logging
+log = logging.getLogger("workbench")
+
+
+class DataCaptureCore:
+    """Manages data capture configuration and retrieval for SageMaker endpoints"""
+
+    def __init__(self, endpoint_name: str):
+        """DataCaptureCore Class
+
+        Args:
+            endpoint_name (str): Name of the endpoint to manage data capture for
+        """
+        self.log = logging.getLogger("workbench")
+        self.endpoint_name = endpoint_name
+        self.endpoint = EndpointCore(self.endpoint_name)
+
+        # Initialize Class Attributes
+        self.sagemaker_session = self.endpoint.sm_session
+        self.sagemaker_client = self.endpoint.sm_client
+        self.data_capture_path = self.endpoint.endpoint_data_capture_path
+        self.workbench_role_arn = AWSAccountClamp().aws_session.get_workbench_execution_role_arn()
+
+    def summary(self) -> dict:
+        """Return the summary of data capture configuration
+
+        Returns:
+            dict: Summary of data capture status
+        """
+        if self.endpoint.is_serverless():
+            return {"endpoint_type": "serverless", "data_capture": "not supported"}
+        else:
+            return {
+                "endpoint_type": "realtime",
+                "data_capture_enabled": self.is_enabled(),
+                "capture_percentage": self.capture_percentage(),
+                "capture_modes": self.capture_modes() if self.is_enabled() else [],
+                "data_capture_path": self.data_capture_path if self.is_enabled() else None,
+            }
+
+    def enable(self, capture_percentage=100, capture_options=None, force_redeploy=False):
+        """
+        Enable data capture for the SageMaker endpoint.
+
+        Args:
+            capture_percentage (int): Percentage of data to capture. Defaults to 100.
+            capture_options (list): List of what to capture - ["REQUEST"], ["RESPONSE"], or ["REQUEST", "RESPONSE"].
+                Defaults to ["REQUEST", "RESPONSE"] to capture both.
+            force_redeploy (bool): If True, force redeployment even if data capture is already enabled.
+        """
+        # Early returns for cases where we can't/don't need to add data capture
+        if self.endpoint.is_serverless():
+            self.log.warning("Data capture is not supported for serverless endpoints.")
+            return
+
+        # Default to capturing both if not specified
+        if capture_options is None:
+            capture_options = ["REQUEST", "RESPONSE"]
+
+        # Validate capture_options
+        valid_options = {"REQUEST", "RESPONSE"}
+        if not all(opt in valid_options for opt in capture_options):
+            self.log.error("Invalid capture_options. Must be a list containing 'REQUEST' and/or 'RESPONSE'")
+            return
+
+        if self.is_enabled() and not force_redeploy:
+            self.log.important(f"Data capture already configured for {self.endpoint_name}.")
+            return
+
+        # Get the current endpoint configuration name for later deletion
+        current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+        # Log the data capture operation
+        self.log.important(f"Enabling Data Capture for {self.endpoint_name} --> {self.data_capture_path}")
+        self.log.important(f"Capturing: {', '.join(capture_options)} at {capture_percentage}% sampling")
+        self.log.important("This will redeploy the endpoint...")
+
+        # Create and apply the data capture configuration
+        data_capture_config = DataCaptureConfig(
+            enable_capture=True,
+            sampling_percentage=capture_percentage,
+            destination_s3_uri=self.data_capture_path,
+            capture_options=capture_options,
+        )
+
+        # Update endpoint with the new capture configuration
+        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+            data_capture_config=data_capture_config
+        )
+
+        # Clean up old endpoint configuration
+        try:
+            self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+            self.log.info(f"Deleted old endpoint configuration: {current_endpoint_config_name}")
+        except Exception as e:
+            self.log.warning(f"Could not delete old endpoint configuration {current_endpoint_config_name}: {e}")
+
+    def disable(self):
+        """
+        Disable data capture for the SageMaker endpoint.
+        """
+        # Early return if data capture isn't configured
+        if not self.is_enabled():
+            self.log.important(f"Data capture is not currently enabled for {self.endpoint_name}.")
+            return
+
+        # Get the current endpoint configuration name for later deletion
+        current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+        # Log the operation
+        self.log.important(f"Disabling Data Capture for {self.endpoint_name}")
+        self.log.important("This normally redeploys the endpoint...")
+
+        # Create a configuration with capture disabled
+        data_capture_config = DataCaptureConfig(enable_capture=False, destination_s3_uri=self.data_capture_path)
+
+        # Update endpoint with the new configuration
+        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+            data_capture_config=data_capture_config
+        )
+
+        # Clean up old endpoint configuration
+        self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+
+    def is_enabled(self) -> bool:
+        """
+        Check if data capture is enabled on the endpoint.
+
+        Returns:
+            bool: True if data capture is enabled, False otherwise.
+        """
+        try:
+            endpoint_config_name = self.endpoint.endpoint_config_name()
+            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+            # Check if data capture is enabled
+            is_enabled = data_capture_config.get("EnableCapture", False)
+            return is_enabled
+        except Exception as e:
+            self.log.error(f"Error checking data capture configuration: {e}")
+            return False
+
+    def capture_percentage(self) -> int:
+        """
+        Get the data capture percentage from the endpoint configuration.
+
+        Returns:
+            int: Data capture percentage if enabled, None otherwise.
+        """
+        try:
+            endpoint_config_name = self.endpoint.endpoint_config_name()
+            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+            # Check if data capture is enabled and return the percentage
+            if data_capture_config.get("EnableCapture", False):
+                return data_capture_config.get("InitialSamplingPercentage", 0)
+            else:
+                return None
+        except Exception as e:
+            self.log.error(f"Error checking data capture percentage: {e}")
+            return None
+
+    def get_config(self) -> dict:
+        """
+        Returns the complete data capture configuration from the endpoint config.
+
+        Returns:
+            dict: Complete DataCaptureConfig from AWS, or None if not configured
+        """
+        config_name = self.endpoint.endpoint_config_name()
+        response = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=config_name)
+        data_capture_config = response.get("DataCaptureConfig")
+        if not data_capture_config:
+            self.log.error(f"No data capture configuration found for endpoint config {config_name}")
+            return None
+        return data_capture_config
+
+    def capture_modes(self) -> list:
+        """Get the current capture modes (REQUEST/RESPONSE)"""
+        if not self.is_enabled():
+            return []
+
+        config = self.get_config()
+        if not config:
+            return []
+
+        capture_options = config.get("CaptureOptions", [])
+        modes = [opt.get("CaptureMode") for opt in capture_options]
+        return ["REQUEST" if m == "Input" else "RESPONSE" for m in modes if m]
+
+    def get_captured_data(self, from_date: str = None, add_timestamp: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
+        """
+        Read and process captured data from S3.
+
+        Args:
+            from_date (str, optional): Only process files from this date onwards (YYYY-MM-DD format).
+                Defaults to None to process all files.
+            add_timestamp (bool, optional): Whether to add a timestamp column to the DataFrame.
+
+        Returns:
+            Tuple[pd.DataFrame, pd.DataFrame]: Processed input and output DataFrames.
+        """
+        files = wr.s3.list_objects(self.data_capture_path)
+        if not files:
+            self.log.warning(f"No data capture files found in {self.data_capture_path}.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        # Filter by date if specified
+        if from_date:
+            from_date_obj = datetime.strptime(from_date, "%Y-%m-%d").date()
+            files = [f for f in files if self._file_date_filter(f, from_date_obj)]
+            self.log.info(f"Processing {len(files)} files from {from_date} onwards.")
+        else:
+            self.log.info(f"Processing all {len(files)} files...")
+
+        # Check if any files remain after filtering
+        if not files:
+            self.log.info("No files to process after date filtering.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        # Sort files by name (assumed to include timestamp)
+        files.sort()
+
+        # Get all timestamps in one batch if needed
+        timestamps = {}
+        if add_timestamp:
+            # Batch describe operation - much more efficient than per-file calls
+            timestamps = wr.s3.describe_objects(path=files)
+
+        # Process files using concurrent.futures
+        start_time = time.time()
+
+        def process_single_file(file_path):
+            """Process a single file and return input/output DataFrames."""
+            try:
+                log.debug(f"Processing file: {file_path}...")
+                df = wr.s3.read_json(path=file_path, lines=True)
+                if not df.empty:
+                    input_df, output_df = process_data_capture(df)
+                    if add_timestamp and file_path in timestamps:
+                        output_df["timestamp"] = timestamps[file_path]["LastModified"]
+                    return input_df, output_df
+                return pd.DataFrame(), pd.DataFrame()
+            except Exception as e:
+                self.log.warning(f"Error processing {file_path}: {e}")
+                return pd.DataFrame(), pd.DataFrame()
+
+        # Use ThreadPoolExecutor for I/O-bound operations
+        from concurrent.futures import ThreadPoolExecutor
+
+        max_workers = min(32, len(files))  # Cap at 32 threads or number of files
+
+        all_input_dfs, all_output_dfs = [], []
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [executor.submit(process_single_file, file_path) for file_path in files]
+            for future in futures:
+                input_df, output_df = future.result()
+                if not input_df.empty:
+                    all_input_dfs.append(input_df)
+                if not output_df.empty:
+                    all_output_dfs.append(output_df)
+
+        if not all_input_dfs:
+            self.log.warning("No valid data was processed.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        input_df = pd.concat(all_input_dfs, ignore_index=True)
+        output_df = pd.concat(all_output_dfs, ignore_index=True)
+
+        elapsed_time = time.time() - start_time
+        self.log.info(f"Processed {len(files)} files in {elapsed_time:.2f} seconds.")
+        return input_df, output_df
+
+    def _file_date_filter(self, file_path, from_date_obj):
+        """Extract date from S3 path and compare with from_date."""
+        try:
+            # Match YYYY/MM/DD pattern in the path
+            date_match = re.search(r"/(\d{4})/(\d{2})/(\d{2})/", file_path)
+            if date_match:
+                year, month, day = date_match.groups()
+                file_date = datetime(int(year), int(month), int(day)).date()
+                return file_date >= from_date_obj
+            return False  # No date pattern found
+        except ValueError:
+            return False
+
+    def __repr__(self) -> str:
+        """String representation of this DataCaptureCore object
+
+        Returns:
+            str: String representation of this DataCaptureCore object
+        """
+        summary_dict = self.summary()
+        summary_items = [f" {repr(key)}: {repr(value)}" for key, value in summary_dict.items()]
+        summary_str = f"{self.__class__.__name__}: {self.endpoint_name}\n" + ",\n".join(summary_items)
+        return summary_str
+
+
+# Test function for the class
+if __name__ == "__main__":
+    """Exercise the MonitorCore class"""
+    from pprint import pprint
+
+    # Set options for actually seeing the dataframe
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", None)
+
+    # Create the Class and test it out
+    endpoint_name = "abalone-regression-rt"
+    my_endpoint = EndpointCore(endpoint_name)
+    if not my_endpoint.exists():
+        print(f"Endpoint {endpoint_name} does not exist.")
+        exit(1)
+    dc = my_endpoint.data_capture()
+
+    # Check the summary of the data capture class
+    pprint(dc.summary())
+
+    # Enable data capture on the endpoint
+    # dc.enable(force_redeploy=True)
+    my_endpoint.enable_data_capture()
+
+    # Test the data capture by running some predictions
+    # pred_df = my_endpoint.auto_inference()
+    # print(pred_df.head())
+
+    # Check that data capture is working
+    input_df, output_df = dc.get_captured_data(from_date="2025-09-01")
+    if input_df.empty and output_df.empty:
+        print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
+    else:
+        print("Found data capture files")
+        print("Input")
+        print(input_df.head())
+        print("Output")
+        print(output_df.head())
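Data capture is now reached through the endpoint itself rather than through MonitorCore (note monitor_core.py loses roughly 248 lines in this release). A short usage sketch based on the __main__ block above; the endpoint name is illustrative, and enabling capture redeploys a realtime endpoint:

    from workbench.core.artifacts.endpoint_core import EndpointCore

    end = EndpointCore("abalone-regression-rt")  # illustrative realtime endpoint

    dc = end.data_capture()                      # DataCaptureCore bound to this endpoint
    print(dc.summary())                          # enabled?, sampling %, capture modes, S3 path

    end.enable_data_capture()                    # note: this redeploys the endpoint

    # Later, pull captured request/response rows back from S3 (date filter optional)
    input_df, output_df = dc.get_captured_data(from_date="2025-09-01")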
workbench/core/artifacts/df_store_core.py (new file)
@@ -0,0 +1,114 @@
+"""DFStoreCore: Fast/efficient storage of DataFrames using AWS S3/Parquet/Snappy"""
+
+import logging
+from typing import Union
+
+# Workbench Imports
+from workbench.utils.config_manager import ConfigManager
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+
+# Workbench Bridges Import
+from workbench_bridges.api import DFStore as BridgesDFStore
+
+
+class DFStoreCore(BridgesDFStore):
+    """DFStoreCore: Fast/efficient storage of DataFrames using AWS S3/Parquet/Snappy
+
+    Common Usage:
+        ```python
+        df_store = DFStoreCore()
+
+        # List Data
+        df_store.list()
+
+        # Add DataFrame
+        df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
+        df_store.upsert("/test/my_data", df)
+
+        # Retrieve DataFrame
+        df = df_store.get("/test/my_data")
+        print(df)
+
+        # Delete Data
+        df_store.delete("/test/my_data")
+        ```
+    """
+
+    def __init__(self, path_prefix: Union[str, None] = None):
+        """DFStoreCore Init Method
+
+        Args:
+            path_prefix (Union[str, None], optional): Add a path prefix to storage locations (Defaults to None)
+        """
+        # Get config from workbench's systems
+        bucket = ConfigManager().get_config("WORKBENCH_BUCKET")
+        session = AWSAccountClamp().boto3_session
+
+        # Initialize parent with workbench config
+        super().__init__(path_prefix=path_prefix, s3_bucket=bucket, boto3_session=session)
+        self.log = logging.getLogger("workbench")
+
+
+if __name__ == "__main__":
+    """Exercise the DFStoreCore Class"""
+    import time
+    import pandas as pd
+
+    # Create a DFStoreCore manager
+    df_store = DFStoreCore()
+
+    # Details of the Dataframe Store
+    print("Detailed Data...")
+    print(df_store.details())
+
+    # Add a new DataFrame
+    my_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
+    df_store.upsert("/testing/test_data", my_df)
+
+    # Get the DataFrame
+    print(f"Getting data 'test_data':\n{df_store.get('/testing/test_data')}")
+
+    # Now let's test adding a Series
+    series = pd.Series([1, 2, 3, 4], name="Series")
+    df_store.upsert("/testing/test_series", series)
+    print(f"Getting data 'test_series':\n{df_store.get('/testing/test_series')}")
+
+    # Summary of the data
+    print("Summary Data...")
+    print(df_store.summary())
+
+    # Repr of the DFStoreCore object
+    print("DFStoreCore Object:")
+    print(df_store)
+
+    # Check if the data exists
+    print("Check if data exists...")
+    print(df_store.check("/testing/test_data"))
+    print(df_store.check("/testing/test_series"))
+
+    # Time the check
+    start_time = time.time()
+    print(df_store.check("/testing/test_data"))
+    print("--- Check %s seconds ---" % (time.time() - start_time))
+
+    # Now delete the test data
+    df_store.delete("/testing/test_data")
+    df_store.delete("/testing/test_series")
+
+    # Check if the data exists
+    print("Check if data exists...")
+    print(df_store.check("/testing/test_data"))
+    print(df_store.check("/testing/test_series"))
+
+    # Add a bunch of dataframes and then test recursive delete
+    for i in range(10):
+        df_store.upsert(f"/testing/data_{i}", pd.DataFrame({"A": [1, 2], "B": [3, 4]}))
+    print("Before Recursive Delete:")
+    print(df_store.summary())
+    df_store.delete_recursive("/testing")
+    print("After Recursive Delete:")
+    print(df_store.summary())
+
+    # Get a non-existent DataFrame
+    print("Getting non-existent data...")
+    print(df_store.get("/testing/no_where"))
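The path_prefix argument is how the Artifact base class (see the artifact.py hunk above) namespaces its dataframe cache. A minimal sketch of that pattern, using an illustrative key:

    import pandas as pd
    from workbench.core.artifacts.df_store_core import DFStoreCore

    # Namespaced store, mirroring Artifact.df_cache = DFStoreCore(path_prefix="/workbench/dataframe_cache")
    cache = DFStoreCore(path_prefix="/workbench/dataframe_cache")

    # Keys are plain slash-delimited paths under the prefix (illustrative key name)
    cache.upsert("/my-model/inference/latest", pd.DataFrame({"pred": [0.1, 0.2]}))
    df = cache.get("/my-model/inference/latest")
    cache.delete("/my-model/inference/latest")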