workbench 0.8.161__py3-none-any.whl → 0.8.192__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +12 -0
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/meta.py +5 -2
  7. workbench/api/model.py +16 -12
  8. workbench/api/monitor.py +1 -16
  9. workbench/core/artifacts/artifact.py +11 -3
  10. workbench/core/artifacts/data_capture_core.py +355 -0
  11. workbench/core/artifacts/endpoint_core.py +168 -78
  12. workbench/core/artifacts/feature_set_core.py +72 -13
  13. workbench/core/artifacts/model_core.py +50 -15
  14. workbench/core/artifacts/monitor_core.py +33 -248
  15. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  16. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  17. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  18. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  19. workbench/core/transforms/features_to_model/features_to_model.py +9 -4
  20. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  21. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  22. workbench/core/views/training_view.py +49 -53
  23. workbench/core/views/view.py +51 -1
  24. workbench/core/views/view_utils.py +4 -4
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  26. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  27. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  28. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  29. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  30. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  31. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  32. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  33. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  34. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  35. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  36. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  37. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  38. workbench/model_scripts/pytorch_model/pytorch.template +19 -20
  39. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  40. workbench/model_scripts/script_generation.py +7 -2
  41. workbench/model_scripts/uq_models/mapie.template +492 -0
  42. workbench/model_scripts/uq_models/requirements.txt +1 -0
  43. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  44. workbench/repl/workbench_shell.py +11 -6
  45. workbench/scripts/lambda_launcher.py +63 -0
  46. workbench/scripts/ml_pipeline_batch.py +137 -0
  47. workbench/scripts/ml_pipeline_sqs.py +186 -0
  48. workbench/scripts/monitor_cloud_watch.py +20 -100
  49. workbench/utils/aws_utils.py +4 -3
  50. workbench/utils/chem_utils/__init__.py +0 -0
  51. workbench/utils/chem_utils/fingerprints.py +134 -0
  52. workbench/utils/chem_utils/misc.py +194 -0
  53. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  54. workbench/utils/chem_utils/mol_standardize.py +450 -0
  55. workbench/utils/chem_utils/mol_tagging.py +348 -0
  56. workbench/utils/chem_utils/projections.py +209 -0
  57. workbench/utils/chem_utils/salts.py +256 -0
  58. workbench/utils/chem_utils/sdf.py +292 -0
  59. workbench/utils/chem_utils/toxicity.py +250 -0
  60. workbench/utils/chem_utils/vis.py +253 -0
  61. workbench/utils/cloudwatch_handler.py +1 -1
  62. workbench/utils/cloudwatch_utils.py +137 -0
  63. workbench/utils/config_manager.py +3 -7
  64. workbench/utils/endpoint_utils.py +5 -7
  65. workbench/utils/license_manager.py +2 -6
  66. workbench/utils/model_utils.py +76 -30
  67. workbench/utils/monitor_utils.py +44 -62
  68. workbench/utils/pandas_utils.py +3 -3
  69. workbench/utils/shap_utils.py +10 -2
  70. workbench/utils/workbench_logging.py +0 -3
  71. workbench/utils/workbench_sqs.py +1 -1
  72. workbench/utils/xgboost_model_utils.py +283 -145
  73. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  74. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  75. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  76. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/METADATA +4 -4
  77. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/RECORD +81 -76
  78. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -0
  79. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  80. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  81. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  82. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  83. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  84. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  85. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -565
  86. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  87. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  88. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  89. workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
  90. workbench/utils/chem_utils.py +0 -1556
  91. workbench/utils/execution_environment.py +0 -211
  92. workbench/utils/fast_inference.py +0 -167
  93. workbench/utils/resource_utils.py +0 -39
  94. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
  95. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
  96. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/model_core.py
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
  from workbench.utils.s3_utils import compute_s3_object_hash
  from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
  from workbench.utils.deprecated_utils import deprecated
+ from workbench.utils.model_utils import proximity_model


  class ModelType(Enum):
@@ -42,11 +43,11 @@ class ModelImages:

  image_uris = {
  # US East 1 images
- ("us-east-1", "xgb_training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
+ ("us-east-1", "training", "0.1", "x86_64"): (
+ "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
  ),
- ("us-east-1", "xgb_inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
+ ("us-east-1", "inference", "0.1", "x86_64"): (
+ "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
  ),
  ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
  "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
@@ -55,11 +56,11 @@ class ModelImages:
  "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
  ),
  # US West 2 images
- ("us-west-2", "xgb_training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
+ ("us-west-2", "training", "0.1", "x86_64"): (
+ "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
  ),
- ("us-west-2", "xgb_inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
+ ("us-west-2", "inference", "0.1", "x86_64"): (
+ "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
  ),
  ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
  "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
@@ -68,12 +69,6 @@ class ModelImages:
  "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
  ),
  # ARM64 images
- ("us-east-1", "xgb_inference", "0.1", "arm64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
- ),
- ("us-west-2", "xgb_inference", "0.1", "arm64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
- ),
  # Meta Endpoint inference images
  ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
  "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
@@ -597,6 +592,24 @@ class ModelCore(Artifact):
  # Return the details
  return details

+ # Training View for this model
+ def training_view(self):
+ """Get the training view for this model"""
+ from workbench.core.artifacts.feature_set_core import FeatureSetCore
+ from workbench.core.views import View
+
+ # Grab our FeatureSet
+ fs = FeatureSetCore(self.get_input())
+
+ # See if we have a training view for this model
+ my_model_training_view = f"{self.name.replace('-', '_')}_training"
+ view = View(fs, my_model_training_view, auto_create_view=False)
+ if view.exists():
+ return view
+ else:
+ self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
+ return fs.view("training")
+
  # Pipeline for this model
  def get_pipeline(self) -> str:
  """Get the pipeline for this model"""
@@ -867,6 +880,14 @@ class ModelCore(Artifact):
  shap_data[key] = self.df_store.get(df_location)
  return shap_data or None

+ def cross_folds(self) -> dict:
+ """Retrieve the cross-fold inference results(only works for XGBoost models)
+
+ Returns:
+ dict: Dictionary with the cross-fold inference results
+ """
+ return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
+
  def supported_inference_instances(self) -> Optional[list]:
  """Retrieve the supported endpoint inference instance types

@@ -879,10 +900,24 @@ class ModelCore(Artifact):
  except (KeyError, IndexError, TypeError):
  return None

+ def publish_prox_model(self, prox_model_name: str = None, track_columns: list = None):
+ """Create and publish a Proximity Model for this Model
+
+ Args:
+ prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
+ track_columns (list, optional): List of columns to track in the Proximity Model.
+
+ Returns:
+ Model: The published Proximity Model
+ """
+ if prox_model_name is None:
+ prox_model_name = self.model_name + "-prox"
+ return proximity_model(self, prox_model_name, track_columns=track_columns)
+
  def delete(self):
  """Delete the Model Packages and the Model Group"""
  if not self.exists():
- self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
+ self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")

  # Call the Class Method to delete the Model Group
  ModelCore.managed_delete(model_group_name=self.name)
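
Taken together, these ModelCore hunks add three conveniences: training_view() resolves a model-specific training view (named <model_name>_training) with a fallback to the FeatureSet's default "training" view, cross_folds() pulls cross-fold inference results from the parameter store (XGBoost models only), and publish_prox_model() builds a companion proximity model via the newly imported proximity_model helper. A minimal usage sketch; the model name and tracked columns are illustrative, and constructing ModelCore from a bare model name is an assumption:

    from workbench.core.artifacts.model_core import ModelCore

    model = ModelCore("aqsol-regression")                   # hypothetical existing model
    view = model.training_view()                            # model-specific view, else the default training view
    folds = model.cross_folds()                             # cross-fold inference results from the parameter store
    prox = model.publish_prox_model(track_columns=["id"])   # name defaults to "aqsol-regression-prox"
    print(prox)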
workbench/core/artifacts/monitor_core.py
@@ -2,12 +2,10 @@

  import logging
  import json
- from typing import Union, Tuple
+ from typing import Union
  import pandas as pd
- from sagemaker import Predictor
  from sagemaker.model_monitor import (
  CronExpressionGenerator,
- DataCaptureConfig,
  DefaultModelMonitor,
  DatasetFormat,
  )
@@ -15,29 +13,32 @@ import awswrangler as wr

  # Workbench Imports
  from workbench.core.artifacts.endpoint_core import EndpointCore
+ from workbench.core.artifacts.data_capture_core import DataCaptureCore
  from workbench.api import Model, FeatureSet
  from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
  from workbench.utils.s3_utils import read_content_from_s3, upload_content_to_s3
  from workbench.utils.datetime_utils import datetime_string
  from workbench.utils.monitor_utils import (
- process_data_capture,
  get_monitor_json_data,
  parse_monitoring_results,
  preprocessing_script,
  )

- # Note: This resource might come in handy when doing code refactoring
+ # Note: These resources might come in handy when doing code refactoring
  # https://github.com/aws-samples/amazon-sagemaker-from-idea-to-production/blob/master/06-monitoring.ipynb
  # https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-pre-and-post-processing.html
  # https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker_model_monitor/introduction/SageMaker-ModelMonitoring.ipynb


  class MonitorCore:
+ """Manages monitoring, baselines, and monitoring schedules for SageMaker endpoints"""
+
  def __init__(self, endpoint_name, instance_type="ml.m5.large"):
  """MonitorCore Class
+
  Args:
  endpoint_name (str): Name of the endpoint to set up monitoring for
- instance_type (str): Instance type to use for monitoring. Defaults to "ml.t3.medium".
+ instance_type (str): Instance type to use for monitoring. Defaults to "ml.m5.large".
  """
  self.log = logging.getLogger("workbench")
  self.endpoint_name = endpoint_name
@@ -46,7 +47,6 @@ class MonitorCore:
  # Initialize Class Attributes
  self.sagemaker_session = self.endpoint.sm_session
  self.sagemaker_client = self.endpoint.sm_client
- self.data_capture_path = self.endpoint.endpoint_data_capture_path
  self.monitoring_path = self.endpoint.endpoint_monitoring_path
  self.monitoring_schedule_name = f"{self.endpoint_name}-monitoring-schedule"
  self.baseline_dir = f"{self.monitoring_path}/baseline"
@@ -57,6 +57,10 @@ class MonitorCore:
  self.workbench_role_arn = AWSAccountClamp().aws_session.get_workbench_execution_role_arn()
  self.instance_type = instance_type

+ # Create DataCaptureCore instance for composition
+ self.data_capture = DataCaptureCore(endpoint_name)
+ self.data_capture_path = self.data_capture.data_capture_path
+
  # Check if a monitoring schedule already exists for this endpoint
  existing_schedule = self.monitoring_schedule_exists()

@@ -74,23 +78,20 @@ class MonitorCore:
  self.log.info(f"Initialized new model monitor for {self.endpoint_name}")

  def summary(self) -> dict:
- """Return the summary of information about the endpoint monitor
+ """Return the summary of monitoring configuration

  Returns:
- dict: Summary of information about the endpoint monitor
+ dict: Summary of monitoring status
  """
  if self.endpoint.is_serverless():
  return {
  "endpoint_type": "serverless",
- "data_capture": "not supported",
  "baseline": "not supported",
  "monitoring_schedule": "not supported",
  }
  else:
  summary = {
  "endpoint_type": "realtime",
- "data_capture": self.data_capture_enabled(),
- "capture_percent": self.data_capture_percent(),
  "baseline": self.baseline_exists(),
  "monitoring_schedule": self.monitoring_schedule_exists(),
  "preprocessing": self.preprocessing_exists(),
@@ -103,22 +104,15 @@ class MonitorCore:
  Returns:
  dict: The monitoring details for the endpoint
  """
- # Get the actual data capture path
- actual_capture_path = self.data_capture_config()["DestinationS3Uri"]
- if actual_capture_path != self.data_capture_path:
- self.log.warning(
- f"Data capture path mismatch: Expected {self.data_capture_path}, "
- f"but found {actual_capture_path}. Using the actual path."
- )
- self.data_capture_path = actual_capture_path
  result = self.summary()
  info = {
- "data_capture_path": self.data_capture_path if self.data_capture_enabled() else None,
- "preprocessing_script_file": self.preprocessing_script_file if self.preprocessing_exists() else None,
  "monitoring_schedule_status": "Not Scheduled",
  }
  result.update(info)

+ if self.preprocessing_exists():
+ result["preprocessing_script_file"] = self.preprocessing_script_file
+
  if self.baseline_exists():
  result.update(
  {
@@ -144,7 +138,6 @@ class MonitorCore:

  last_run = schedule_details.get("LastMonitoringExecutionSummary", {})
  if last_run:
-
  # If no inference was run since the last monitoring schedule, the
  # status will be "Failed" with reason "Job inputs had no data",
  # so we check for that and set the status to "No New Data"
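
The next hunk strips MonitorCore's in-house data capture plumbing (the Predictor/DataCaptureConfig calls, capture-status checks, and get_captured_data) and delegates to the DataCaptureCore object created in __init__. A minimal sketch of driving that capture object directly, using only the attributes and methods visible in this diff (is_enabled(), enable(), data_capture_path); the endpoint name is illustrative:

    from workbench.core.artifacts.data_capture_core import DataCaptureCore

    capture = DataCaptureCore("abalone-regression-rt")   # hypothetical realtime endpoint
    if not capture.is_enabled():
        capture.enable(capture_percentage=100)           # turn on request/response capture
    print(capture.data_capture_path)                     # S3 prefix where captured data lands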
@@ -162,187 +155,22 @@ class MonitorCore:

  return result

- def enable_data_capture(self, capture_percentage=100, force=False):
- """
- Enable data capture for the SageMaker endpoint.
+ def enable_data_capture(self, capture_percentage=100):
+ """Enable data capture for the endpoint

  Args:
- capture_percentage (int): Percentage of data to capture. Defaults to 100.
- force (bool): If True, force reconfiguration even if data capture is already enabled.
+ capture_percentage (int): Percentage of requests to capture (0-100, default 100)
  """
- # Early returns for cases where we can't/don't need to add data capture
  if self.endpoint.is_serverless():
  self.log.warning("Data capture is not supported for serverless endpoints.")
  return

- if self.data_capture_enabled() and not force:
- self.log.important(f"Data capture already configured for {self.endpoint_name}.")
- return
-
- # Get the current endpoint configuration name for later deletion
- current_endpoint_config_name = self.endpoint.endpoint_config_name()
-
- # Log the data capture operation
- self.log.important(f"Enabling Data Capture for {self.endpoint_name} --> {self.data_capture_path}")
- self.log.important("This normally redeploys the endpoint...")
-
- # Create and apply the data capture configuration
- data_capture_config = DataCaptureConfig(
- enable_capture=True, # Required parameter
- sampling_percentage=capture_percentage,
- destination_s3_uri=self.data_capture_path,
- )
-
- # Update endpoint with the new capture configuration
- Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
- data_capture_config=data_capture_config
- )
-
- # Clean up old endpoint configuration
- self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
-
- def data_capture_config(self):
- """
- Returns the complete data capture configuration from the endpoint config.
- Returns:
- dict: Complete DataCaptureConfig from AWS, or None if not configured
- """
- config_name = self.endpoint.endpoint_config_name()
- response = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=config_name)
- data_capture_config = response.get("DataCaptureConfig")
- if not data_capture_config:
- self.log.error(f"No data capture configuration found for endpoint config {config_name}")
- return None
- return data_capture_config
-
- def disable_data_capture(self):
- """
- Disable data capture for the SageMaker endpoint.
- """
- # Early return if data capture isn't configured
- if not self.data_capture_enabled():
- self.log.important(f"Data capture is not currently enabled for {self.endpoint_name}.")
+ if self.data_capture.is_enabled():
+ self.log.info(f"Data capture is already enabled for {self.endpoint_name}.")
  return

- # Get the current endpoint configuration name for later deletion
- current_endpoint_config_name = self.endpoint.endpoint_config_name()
-
- # Log the operation
- self.log.important(f"Disabling Data Capture for {self.endpoint_name}")
- self.log.important("This normally redeploys the endpoint...")
-
- # Create a configuration with capture disabled
- data_capture_config = DataCaptureConfig(enable_capture=False, destination_s3_uri=self.data_capture_path)
-
- # Update endpoint with the new configuration
- Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
- data_capture_config=data_capture_config
- )
-
- # Clean up old endpoint configuration
- self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
-
- def data_capture_enabled(self):
- """
- Check if data capture is already configured on the endpoint.
- Args:
- capture_percentage (int): Expected data capture percentage.
- Returns:
- bool: True if data capture is already configured, False otherwise.
- """
- try:
- endpoint_config_name = self.endpoint.endpoint_config_name()
- endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
- data_capture_config = endpoint_config.get("DataCaptureConfig", {})
-
- # Check if data capture is enabled and the percentage matches
- is_enabled = data_capture_config.get("EnableCapture", False)
- return is_enabled
- except Exception as e:
- self.log.error(f"Error checking data capture configuration: {e}")
- return False
-
- def data_capture_percent(self):
- """
- Get the data capture percentage from the endpoint configuration.
-
- Returns:
- int: Data capture percentage if enabled, None otherwise.
- """
- try:
- endpoint_config_name = self.endpoint.endpoint_config_name()
- endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
- data_capture_config = endpoint_config.get("DataCaptureConfig", {})
-
- # Check if data capture is enabled and return the percentage
- if data_capture_config.get("EnableCapture", False):
- return data_capture_config.get("InitialSamplingPercentage", 0)
- else:
- return None
- except Exception as e:
- self.log.error(f"Error checking data capture percentage: {e}")
- return None
-
- def get_captured_data(self, max_files=None, add_timestamp=True) -> Tuple[pd.DataFrame, pd.DataFrame]:
- """
- Read and process captured data from S3.
-
- Args:
- max_files (int, optional): Maximum number of files to process.
- Defaults to None to process all files.
- add_timestamp (bool, optional): Whether to add a timestamp column to the DataFrame.
-
- Returns:
- Tuple[pd.DataFrame, pd.DataFrame]: Processed input and output DataFrames.
- """
- # List files in the specified S3 path
- files = wr.s3.list_objects(self.data_capture_path)
- if not files:
- self.log.warning(f"No data capture files found in {self.data_capture_path}.")
- return pd.DataFrame(), pd.DataFrame()
-
- self.log.info(f"Found {len(files)} files in {self.data_capture_path}.")
-
- # Sort files by timestamp (assuming the naming convention includes timestamp)
- files.sort()
-
- # Select files to process
- if max_files is None:
- files_to_process = files
- self.log.info(f"Processing all {len(files)} files.")
- else:
- files_to_process = files[-max_files:] if files else []
- self.log.info(f"Processing the {len(files_to_process)} most recent file(s).")
-
- # Process each file
- all_input_dfs = []
- all_output_dfs = []
- for file_path in files_to_process:
- self.log.info(f"Processing {file_path}...")
- try:
- # Read the JSON lines file
- df = wr.s3.read_json(path=file_path, lines=True)
- if not df.empty:
- input_df, output_df = process_data_capture(df)
- # Generate a timestamp column if requested
- if add_timestamp:
- # Get file metadata to extract last modified time
- file_metadata = wr.s3.describe_objects(path=file_path)
- timestamp = file_metadata[file_path]["LastModified"]
- output_df["timestamp"] = timestamp
-
- # Append the processed DataFrames to the lists
- all_input_dfs.append(input_df)
- all_output_dfs.append(output_df)
- except Exception as e:
- self.log.warning(f"Error processing file {file_path}: {e}")
-
- # Combine all DataFrames
- if not all_input_dfs or not all_output_dfs:
- self.log.warning("No valid data was processed from the captured files.")
- return pd.DataFrame(), pd.DataFrame()
-
- return pd.concat(all_input_dfs, ignore_index=True), pd.concat(all_output_dfs, ignore_index=True)
+ self.data_capture.enable(capture_percentage=capture_percentage)
+ self.log.important(f"Enabled data capture for {self.endpoint_name} at {self.data_capture_path}")

  def baseline_exists(self) -> bool:
  """
@@ -533,6 +361,11 @@ class MonitorCore:
  self.log.warning("If you want to create another one, delete existing schedule first.")
  return

+ # Check if data capture is enabled, if not enable it
+ if not self.data_capture.is_enabled():
+ self.log.warning("Data capture is not enabled for this endpoint. Enabling it now...")
+ self.enable_data_capture(capture_percentage=100)
+
  # Set up a NEW monitoring schedule
  schedule_args = {
  "monitor_schedule_name": self.monitoring_schedule_name,
@@ -577,33 +410,6 @@ class MonitorCore:
  self.model_monitor.delete_monitoring_schedule()
  self.log.important(f"Deleted monitoring schedule for {self.endpoint_name}.")

- # Put this functionality into this class
- """
- executions = my_monitor.list_executions()
- latest_execution = executions[-1]
-
- latest_execution.describe()['ProcessingJobStatus']
- latest_execution.describe()['ExitMessage']
- Here are the possible terminal states and what each of them means:
-
- - Completed - This means the monitoring execution completed and no issues were found in the violations report.
- - CompletedWithViolations - This means the execution completed, but constraint violations were detected.
- - Failed - The monitoring execution failed, maybe due to client error
- (perhaps incorrect role premissions) or infrastructure issues. Further
- examination of the FailureReason and ExitMessage is necessary to identify what exactly happened.
- - Stopped - job exceeded the max runtime or was manually stopped.
- You can also get the S3 URI for the output with latest_execution.output.destination and analyze the results.
-
- Visualize results
- You can use the monitor object to gather reports for visualization:
-
- suggested_constraints = my_monitor.suggested_constraints()
- baseline_statistics = my_monitor.baseline_statistics()
-
- latest_monitoring_violations = my_monitor.latest_monitoring_constraint_violations()
- latest_monitoring_statistics = my_monitor.latest_monitoring_statistics()
- """
-
  def get_monitoring_results(self, max_results=10) -> pd.DataFrame:
  """Get the results of monitoring executions

@@ -758,7 +564,7 @@ class MonitorCore:
  Returns:
  str: String representation of this MonitorCore object
  """
- summary_dict = {} # Disabling for now self.summary()
+ summary_dict = self.summary()
  summary_items = [f" {repr(key)}: {repr(value)}" for key, value in summary_dict.items()]
  summary_str = f"{self.__class__.__name__}: {self.endpoint_name}\n" + ",\n".join(summary_items)
  return summary_str
@@ -775,7 +581,6 @@ if __name__ == "__main__":

  # Create the Class and test it out
  endpoint_name = "abalone-regression-rt"
- endpoint_name = "logd-dev-reg-rt"
  my_endpoint = EndpointCore(endpoint_name)
  if not my_endpoint.exists():
  print(f"Endpoint {endpoint_name} does not exist.")
@@ -788,11 +593,10 @@ if __name__ == "__main__":
  # Check the details of the monitoring class
  pprint(mm.details())

- # Enable data capture on the endpoint
- mm.enable_data_capture()
+ # Enable data capture (if not already enabled)
+ mm.enable_data_capture(capture_percentage=100)

  # Create a baseline for monitoring
- # mm.create_baseline(recreate=True)
  mm.create_baseline()

  # Check the monitoring outputs
@@ -804,30 +608,11 @@ if __name__ == "__main__":
  pprint(mm.get_constraints())

  print("\nStatistics...")
- print(mm.get_statistics())
+ print(str(mm.get_statistics())[:1000]) # Print only first 1000 characters

  # Set up the monitoring schedule (if it doesn't already exist)
  mm.create_monitoring_schedule()

- #
- # Test the data capture by running some predictions
- #
-
- # Make predictions on the Endpoint using the FeatureSet evaluation data
- # pred_df = my_endpoint.auto_inference()
- # print(pred_df.head())
-
- # Check that data capture is working
- input_df, output_df = mm.get_captured_data()
- if input_df.empty or output_df.empty:
- print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
- else:
- print("Found data capture files")
- print("Input")
- print(input_df.head())
- print("Output")
- print(output_df.head())
-
  # Test update_constraints (commented out for now)
  # print("\nTesting constraint updates...")
  # custom_constraints = {"sex": {"allowed_values": ["M", "F", "I"]}, "length": {"min": 0.0, "max": 1.0}}
@@ -846,7 +631,7 @@ if __name__ == "__main__":
  print("\nTesting execution details retrieval...")
  if not results_df.empty:
  latest_execution_arn = results_df.iloc[0]["processing_job_arn"]
- execution_details = mm.get_execution_details(latest_execution_arn)
+ execution_details = mm.get_execution_details(latest_execution_arn) if latest_execution_arn else None
  if execution_details:
  print(f"Execution details for {latest_execution_arn}:")
  pprint(execution_details)
workbench/core/cloud_platform/aws/aws_account_clamp.py
@@ -54,7 +54,11 @@ class AWSAccountClamp:

  # Check our Assume Role
  self.log.info("Checking Workbench Assumed Role...")
- self.aws_session.assumed_role_info()
+ role_info = self.aws_session.assumed_role_info()
+ self.log.info(f"Assumed Role: {role_info}")
+
+ # Check if we have tag write permissions (if we don't, we are read-only)
+ self.read_only = not self.check_tag_permissions()

  # Check our Workbench API Key and Load the License
  self.log.info("Checking Workbench API License...")
@@ -138,6 +142,45 @@ class AWSAccountClamp:
  """
  return self.boto3_session.client("sagemaker")

+ def check_tag_permissions(self):
+ """Check if current role has permission to add tags to SageMaker endpoints.
+
+ Returns:
+ bool: True if AddTags is allowed, False otherwise
+ """
+ try:
+ sagemaker = self.boto3_session.client("sagemaker")
+
+ # Use a non-existent endpoint name
+ fake_endpoint = "workbench-permission-check-dummy-endpoint"
+
+ # Try to add tags to the non-existent endpoint
+ sagemaker.add_tags(
+ ResourceArn=f"arn:aws:sagemaker:{self.region}:{self.account_id}:endpoint/{fake_endpoint}",
+ Tags=[{"Key": "PermissionCheck", "Value": "Test"}],
+ )
+
+ # If we get here, we have permission (but endpoint doesn't exist)
+ return True
+
+ except ClientError as e:
+ error_code = e.response["Error"]["Code"]
+
+ # AccessDeniedException = no permission
+ if error_code == "AccessDeniedException":
+ self.log.debug("No AddTags permission (AccessDeniedException)")
+ return False
+
+ # ResourceNotFound = we have permission, but endpoint doesn't exist
+ elif error_code in ["ResourceNotFound", "ValidationException"]:
+ self.log.debug("AddTags permission verified (resource not found)")
+ return True
+
+ # Unexpected error, assume no permission for safety
+ else:
+ self.log.debug(f"Unexpected error checking permissions: {error_code}")
+ return False
+

  if __name__ == "__main__":
  """Exercise the AWS Account Clamp Class"""
@@ -162,3 +205,9 @@ if __name__ == "__main__":
  print("\n\n*** AWS Sagemaker Session/Client Check ***")
  sm_client = aws_account_clamp.sagemaker_client()
  print(sm_client.list_feature_groups()["FeatureGroupSummaries"])
+
+ print("\n\n*** AWS Tag Permission Check ***")
+ if aws_account_clamp.check_tag_permissions():
+ print("Tag Permission Check Success...")
+ else:
+ print("Tag Permission Check Failed...")
workbench/core/cloud_platform/aws/aws_meta.py
@@ -196,7 +196,9 @@ class AWSMeta:

  # Return the summary as a DataFrame
  df = pd.DataFrame(data_summary).convert_dtypes()
- return df.sort_values(by="Created", ascending=False)
+ if not df.empty:
+ df.sort_values(by="Created", ascending=False, inplace=True)
+ return df

  def models(self, details: bool = False) -> pd.DataFrame:
  """Get a summary of the Models in AWS.
@@ -256,7 +258,9 @@
  # Return the summary as a DataFrame
  df = pd.DataFrame(model_summary).convert_dtypes()
- return df.sort_values(by="Created", ascending=False)
+ if not df.empty:
+ df.sort_values(by="Created", ascending=False, inplace=True)
+ return df

  def endpoints(self, details: bool = False) -> pd.DataFrame:
  """Get a summary of the Endpoints in AWS.
@@ -308,7 +312,7 @@
  "Status": endpoint_details.get("EndpointStatus", "-"),
  "Config": endpoint_details.get("EndpointConfigName", "-"),
  "Variant": endpoint_details["config"]["variant"],
- "Capture": str(endpoint_details.get("DataCaptureConfig", {}).get("EnableCapture", "False")),
+ "Capture": str(endpoint_details.get("DataCaptureConfig", {}).get("EnableCapture", "-")),
  "Samp(%)": str(endpoint_details.get("DataCaptureConfig", {}).get("CurrentSamplingPercentage", "-")),
  "Tags": aws_tags.get("workbench_tags", "-"),
  "Monitored": endpoint_details["monitored"],
@@ -317,7 +321,9 @@
  # Return the summary as a DataFrame
  df = pd.DataFrame(data_summary).convert_dtypes()
- return df.sort_values(by="Created", ascending=False)
+ if not df.empty:
+ df.sort_values(by="Created", ascending=False, inplace=True)
+ return df

  def _endpoint_config_info(self, endpoint_config_name: str) -> dict:
  """Internal: Get the Endpoint Configuration information for the given endpoint config name.
@@ -657,7 +663,8 @@ class AWSMeta:
  df = pd.DataFrame(data_summary).convert_dtypes()

  # Sort by the Modified column
- df = df.sort_values(by="Modified", ascending=False)
+ if not df.empty:
+ df = df.sort_values(by="Modified", ascending=False)
  return df

  def _aws_pipelines(self) -> pd.DataFrame:
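
The guards added throughout these AWSMeta hunks exist because a DataFrame built from an empty summary list has no columns at all, so sorting by "Created" or "Modified" raises a KeyError. A minimal, workbench-independent sketch of the failure mode and the fix:

    import pandas as pd

    data_summary = []                                  # e.g. an account with no models or endpoints yet
    df = pd.DataFrame(data_summary).convert_dtypes()   # empty frame with zero columns

    # df.sort_values(by="Created", ascending=False)    # would raise KeyError: 'Created'

    if not df.empty:
        df = df.sort_values(by="Created", ascending=False)
    print(df)                                          # prints an empty DataFrame instead of crashing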