workbench 0.8.168__py3-none-any.whl → 0.8.192__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +3 -2
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/model.py +16 -12
  7. workbench/api/monitor.py +1 -16
  8. workbench/core/artifacts/artifact.py +11 -3
  9. workbench/core/artifacts/data_capture_core.py +355 -0
  10. workbench/core/artifacts/endpoint_core.py +113 -27
  11. workbench/core/artifacts/feature_set_core.py +72 -13
  12. workbench/core/artifacts/model_core.py +50 -15
  13. workbench/core/artifacts/monitor_core.py +33 -249
  14. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  15. workbench/core/cloud_platform/aws/aws_meta.py +11 -4
  16. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  17. workbench/core/transforms/features_to_model/features_to_model.py +9 -4
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  19. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  20. workbench/core/views/training_view.py +49 -53
  21. workbench/core/views/view.py +51 -1
  22. workbench/core/views/view_utils.py +4 -4
  23. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  24. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  25. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  27. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  28. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  29. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  30. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  31. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  32. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  33. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  34. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  35. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  36. workbench/model_scripts/pytorch_model/pytorch.template +9 -18
  37. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  38. workbench/model_scripts/script_generation.py +7 -2
  39. workbench/model_scripts/uq_models/mapie.template +492 -0
  40. workbench/model_scripts/uq_models/requirements.txt +1 -0
  41. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  42. workbench/repl/workbench_shell.py +4 -4
  43. workbench/scripts/lambda_launcher.py +63 -0
  44. workbench/scripts/{ml_pipeline_launcher.py → ml_pipeline_batch.py} +49 -51
  45. workbench/scripts/ml_pipeline_sqs.py +186 -0
  46. workbench/utils/chem_utils/__init__.py +0 -0
  47. workbench/utils/chem_utils/fingerprints.py +134 -0
  48. workbench/utils/chem_utils/misc.py +194 -0
  49. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  50. workbench/utils/chem_utils/mol_standardize.py +450 -0
  51. workbench/utils/chem_utils/mol_tagging.py +348 -0
  52. workbench/utils/chem_utils/projections.py +209 -0
  53. workbench/utils/chem_utils/salts.py +256 -0
  54. workbench/utils/chem_utils/sdf.py +292 -0
  55. workbench/utils/chem_utils/toxicity.py +250 -0
  56. workbench/utils/chem_utils/vis.py +253 -0
  57. workbench/utils/config_manager.py +2 -6
  58. workbench/utils/endpoint_utils.py +5 -7
  59. workbench/utils/license_manager.py +2 -6
  60. workbench/utils/model_utils.py +76 -30
  61. workbench/utils/monitor_utils.py +44 -62
  62. workbench/utils/pandas_utils.py +3 -3
  63. workbench/utils/shap_utils.py +10 -2
  64. workbench/utils/workbench_sqs.py +1 -1
  65. workbench/utils/xgboost_model_utils.py +283 -145
  66. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  67. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  68. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  69. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/METADATA +2 -1
  70. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/RECORD +74 -70
  71. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -1
  72. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  73. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  74. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  75. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  76. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  77. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  78. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -576
  79. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  80. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  81. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  82. workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
  83. workbench/utils/chem_utils.py +0 -1556
  84. workbench/utils/fast_inference.py +0 -167
  85. workbench/utils/resource_utils.py +0 -39
  86. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
  87. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
  88. {workbench-0.8.168.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/monitor_core.py

@@ -2,14 +2,10 @@
 
 import logging
 import json
-import re
-from datetime import datetime
-from typing import Union, Tuple
+from typing import Union
 import pandas as pd
-from sagemaker import Predictor
 from sagemaker.model_monitor import (
     CronExpressionGenerator,
-    DataCaptureConfig,
     DefaultModelMonitor,
     DatasetFormat,
 )
@@ -17,29 +13,32 @@ import awswrangler as wr
 
 # Workbench Imports
 from workbench.core.artifacts.endpoint_core import EndpointCore
+from workbench.core.artifacts.data_capture_core import DataCaptureCore
 from workbench.api import Model, FeatureSet
 from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
 from workbench.utils.s3_utils import read_content_from_s3, upload_content_to_s3
 from workbench.utils.datetime_utils import datetime_string
 from workbench.utils.monitor_utils import (
-    process_data_capture,
     get_monitor_json_data,
     parse_monitoring_results,
     preprocessing_script,
 )
 
-# Note: This resource might come in handy when doing code refactoring
+# Note: These resources might come in handy when doing code refactoring
 # https://github.com/aws-samples/amazon-sagemaker-from-idea-to-production/blob/master/06-monitoring.ipynb
 # https://docs.aws.amazon.com/sagemaker/latest/dg/model-monitor-pre-and-post-processing.html
 # https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker_model_monitor/introduction/SageMaker-ModelMonitoring.ipynb
 
 
 class MonitorCore:
+    """Manages monitoring, baselines, and monitoring schedules for SageMaker endpoints"""
+
     def __init__(self, endpoint_name, instance_type="ml.m5.large"):
         """MonitorCore Class
+
         Args:
             endpoint_name (str): Name of the endpoint to set up monitoring for
-            instance_type (str): Instance type to use for monitoring. Defaults to "ml.t3.medium".
+            instance_type (str): Instance type to use for monitoring. Defaults to "ml.m5.large".
         """
         self.log = logging.getLogger("workbench")
         self.endpoint_name = endpoint_name
@@ -48,7 +47,6 @@ class MonitorCore:
         # Initialize Class Attributes
        self.sagemaker_session = self.endpoint.sm_session
        self.sagemaker_client = self.endpoint.sm_client
-        self.data_capture_path = self.endpoint.endpoint_data_capture_path
        self.monitoring_path = self.endpoint.endpoint_monitoring_path
        self.monitoring_schedule_name = f"{self.endpoint_name}-monitoring-schedule"
        self.baseline_dir = f"{self.monitoring_path}/baseline"
@@ -59,6 +57,10 @@ class MonitorCore:
        self.workbench_role_arn = AWSAccountClamp().aws_session.get_workbench_execution_role_arn()
        self.instance_type = instance_type
 
+        # Create DataCaptureCore instance for composition
+        self.data_capture = DataCaptureCore(endpoint_name)
+        self.data_capture_path = self.data_capture.data_capture_path
+
        # Check if a monitoring schedule already exists for this endpoint
        existing_schedule = self.monitoring_schedule_exists()
 
@@ -76,23 +78,20 @@
            self.log.info(f"Initialized new model monitor for {self.endpoint_name}")
 
    def summary(self) -> dict:
-        """Return the summary of information about the endpoint monitor
+        """Return the summary of monitoring configuration
 
        Returns:
-            dict: Summary of information about the endpoint monitor
+            dict: Summary of monitoring status
        """
        if self.endpoint.is_serverless():
            return {
                "endpoint_type": "serverless",
-                "data_capture": "not supported",
                "baseline": "not supported",
                "monitoring_schedule": "not supported",
            }
        else:
            summary = {
                "endpoint_type": "realtime",
-                "data_capture": self.data_capture_enabled(),
-                "capture_percent": self.data_capture_percent(),
                "baseline": self.baseline_exists(),
                "monitoring_schedule": self.monitoring_schedule_exists(),
                "preprocessing": self.preprocessing_exists(),
@@ -105,22 +104,15 @@
        Returns:
            dict: The monitoring details for the endpoint
        """
-        # Get the actual data capture path
-        actual_capture_path = self.data_capture_config()["DestinationS3Uri"]
-        if actual_capture_path != self.data_capture_path:
-            self.log.warning(
-                f"Data capture path mismatch: Expected {self.data_capture_path}, "
-                f"but found {actual_capture_path}. Using the actual path."
-            )
-            self.data_capture_path = actual_capture_path
        result = self.summary()
        info = {
-            "data_capture_path": self.data_capture_path if self.data_capture_enabled() else None,
-            "preprocessing_script_file": self.preprocessing_script_file if self.preprocessing_exists() else None,
            "monitoring_schedule_status": "Not Scheduled",
        }
        result.update(info)
 
+        if self.preprocessing_exists():
+            result["preprocessing_script_file"] = self.preprocessing_script_file
+
        if self.baseline_exists():
            result.update(
                {
@@ -146,7 +138,6 @@
 
        last_run = schedule_details.get("LastMonitoringExecutionSummary", {})
        if last_run:
-
            # If no inference was run since the last monitoring schedule, the
            # status will be "Failed" with reason "Job inputs had no data",
            # so we check for that and set the status to "No New Data"
@@ -164,186 +155,22 @@
 
        return result
 
-    def enable_data_capture(self, capture_percentage=100, force=False):
-        """
-        Enable data capture for the SageMaker endpoint.
+    def enable_data_capture(self, capture_percentage=100):
+        """Enable data capture for the endpoint
 
        Args:
-            capture_percentage (int): Percentage of data to capture. Defaults to 100.
-            force (bool): If True, force reconfiguration even if data capture is already enabled.
+            capture_percentage (int): Percentage of requests to capture (0-100, default 100)
        """
-        # Early returns for cases where we can't/don't need to add data capture
        if self.endpoint.is_serverless():
            self.log.warning("Data capture is not supported for serverless endpoints.")
            return
 
-        if self.data_capture_enabled() and not force:
-            self.log.important(f"Data capture already configured for {self.endpoint_name}.")
+        if self.data_capture.is_enabled():
+            self.log.info(f"Data capture is already enabled for {self.endpoint_name}.")
            return
 
-        # Get the current endpoint configuration name for later deletion
-        current_endpoint_config_name = self.endpoint.endpoint_config_name()
-
-        # Log the data capture operation
-        self.log.important(f"Enabling Data Capture for {self.endpoint_name} --> {self.data_capture_path}")
-        self.log.important("This normally redeploys the endpoint...")
-
-        # Create and apply the data capture configuration
-        data_capture_config = DataCaptureConfig(
-            enable_capture=True,  # Required parameter
-            sampling_percentage=capture_percentage,
-            destination_s3_uri=self.data_capture_path,
-        )
-
-        # Update endpoint with the new capture configuration
-        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
-            data_capture_config=data_capture_config
-        )
-
-        # Clean up old endpoint configuration
-        self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
-
-    def data_capture_config(self):
-        """
-        Returns the complete data capture configuration from the endpoint config.
-        Returns:
-            dict: Complete DataCaptureConfig from AWS, or None if not configured
-        """
-        config_name = self.endpoint.endpoint_config_name()
-        response = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=config_name)
-        data_capture_config = response.get("DataCaptureConfig")
-        if not data_capture_config:
-            self.log.error(f"No data capture configuration found for endpoint config {config_name}")
-            return None
-        return data_capture_config
-
-    def disable_data_capture(self):
-        """
-        Disable data capture for the SageMaker endpoint.
-        """
-        # Early return if data capture isn't configured
-        if not self.data_capture_enabled():
-            self.log.important(f"Data capture is not currently enabled for {self.endpoint_name}.")
-            return
-
-        # Get the current endpoint configuration name for later deletion
-        current_endpoint_config_name = self.endpoint.endpoint_config_name()
-
-        # Log the operation
-        self.log.important(f"Disabling Data Capture for {self.endpoint_name}")
-        self.log.important("This normally redeploys the endpoint...")
-
-        # Create a configuration with capture disabled
-        data_capture_config = DataCaptureConfig(enable_capture=False, destination_s3_uri=self.data_capture_path)
-
-        # Update endpoint with the new configuration
-        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
-            data_capture_config=data_capture_config
-        )
-
-        # Clean up old endpoint configuration
-        self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
-
-    def data_capture_enabled(self):
-        """
-        Check if data capture is already configured on the endpoint.
-        Args:
-            capture_percentage (int): Expected data capture percentage.
-        Returns:
-            bool: True if data capture is already configured, False otherwise.
-        """
-        try:
-            endpoint_config_name = self.endpoint.endpoint_config_name()
-            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
-            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
-
-            # Check if data capture is enabled and the percentage matches
-            is_enabled = data_capture_config.get("EnableCapture", False)
-            return is_enabled
-        except Exception as e:
-            self.log.error(f"Error checking data capture configuration: {e}")
-            return False
-
-    def data_capture_percent(self):
-        """
-        Get the data capture percentage from the endpoint configuration.
-
-        Returns:
-            int: Data capture percentage if enabled, None otherwise.
-        """
-        try:
-            endpoint_config_name = self.endpoint.endpoint_config_name()
-            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
-            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
-
-            # Check if data capture is enabled and return the percentage
-            if data_capture_config.get("EnableCapture", False):
-                return data_capture_config.get("InitialSamplingPercentage", 0)
-            else:
-                return None
-        except Exception as e:
-            self.log.error(f"Error checking data capture percentage: {e}")
-            return None
-
-    def get_captured_data(self, from_date=None, add_timestamp=True) -> Tuple[pd.DataFrame, pd.DataFrame]:
-        """
-        Read and process captured data from S3.
-
-        Args:
-            from_date (str, optional): Only process files from this date onwards (YYYY-MM-DD format).
-                Defaults to None to process all files.
-            add_timestamp (bool, optional): Whether to add a timestamp column to the DataFrame.
-
-        Returns:
-            Tuple[pd.DataFrame, pd.DataFrame]: Processed input and output DataFrames.
-        """
-        files = wr.s3.list_objects(self.data_capture_path)
-        if not files:
-            self.log.warning(f"No data capture files found in {self.data_capture_path}.")
-            return pd.DataFrame(), pd.DataFrame()
-
-        # Filter by date if specified
-        if from_date:
-            from_date_obj = datetime.strptime(from_date, "%Y-%m-%d").date()
-            files = [f for f in files if self._file_date_filter(f, from_date_obj)]
-            self.log.info(f"Processing {len(files)} files from {from_date} onwards.")
-        else:
-            self.log.info(f"Processing all {len(files)} files.")
-        files.sort()
-
-        # Process files
-        all_input_dfs, all_output_dfs = [], []
-        for file_path in files:
-            try:
-                df = wr.s3.read_json(path=file_path, lines=True)
-                if not df.empty:
-                    input_df, output_df = process_data_capture(df)
-                    if add_timestamp:
-                        timestamp = wr.s3.describe_objects(path=file_path)[file_path]["LastModified"]
-                        output_df["timestamp"] = timestamp
-                    all_input_dfs.append(input_df)
-                    all_output_dfs.append(output_df)
-            except Exception as e:
-                self.log.warning(f"Error processing {file_path}: {e}")
-
-        if not all_input_dfs:
-            self.log.warning("No valid data was processed.")
-            return pd.DataFrame(), pd.DataFrame()
-
-        return pd.concat(all_input_dfs, ignore_index=True), pd.concat(all_output_dfs, ignore_index=True)
-
-    def _file_date_filter(self, file_path, from_date_obj):
-        """Extract date from S3 path and compare with from_date."""
-        try:
-            # Match YYYY/MM/DD pattern in the path
-            date_match = re.search(r"/(\d{4})/(\d{2})/(\d{2})/", file_path)
-            if date_match:
-                year, month, day = date_match.groups()
-                file_date = datetime(int(year), int(month), int(day)).date()
-                return file_date >= from_date_obj
-            return False  # No date pattern found
-        except ValueError:
-            return False
+        self.data_capture.enable(capture_percentage=capture_percentage)
+        self.log.important(f"Enabled data capture for {self.endpoint_name} at {self.data_capture_path}")
 
    def baseline_exists(self) -> bool:
        """
534
361
  self.log.warning("If you want to create another one, delete existing schedule first.")
535
362
  return
536
363
 
364
+ # Check if data capture is enabled, if not enable it
365
+ if not self.data_capture.is_enabled():
366
+ self.log.warning("Data capture is not enabled for this endpoint. Enabling it now...")
367
+ self.enable_data_capture(capture_percentage=100)
368
+
537
369
  # Set up a NEW monitoring schedule
538
370
  schedule_args = {
539
371
  "monitor_schedule_name": self.monitoring_schedule_name,
@@ -578,33 +410,6 @@ class MonitorCore:
578
410
  self.model_monitor.delete_monitoring_schedule()
579
411
  self.log.important(f"Deleted monitoring schedule for {self.endpoint_name}.")
580
412
 
581
- # Put this functionality into this class
582
- """
583
- executions = my_monitor.list_executions()
584
- latest_execution = executions[-1]
585
-
586
- latest_execution.describe()['ProcessingJobStatus']
587
- latest_execution.describe()['ExitMessage']
588
- Here are the possible terminal states and what each of them means:
589
-
590
- - Completed - This means the monitoring execution completed and no issues were found in the violations report.
591
- - CompletedWithViolations - This means the execution completed, but constraint violations were detected.
592
- - Failed - The monitoring execution failed, maybe due to client error
593
- (perhaps incorrect role premissions) or infrastructure issues. Further
594
- examination of the FailureReason and ExitMessage is necessary to identify what exactly happened.
595
- - Stopped - job exceeded the max runtime or was manually stopped.
596
- You can also get the S3 URI for the output with latest_execution.output.destination and analyze the results.
597
-
598
- Visualize results
599
- You can use the monitor object to gather reports for visualization:
600
-
601
- suggested_constraints = my_monitor.suggested_constraints()
602
- baseline_statistics = my_monitor.baseline_statistics()
603
-
604
- latest_monitoring_violations = my_monitor.latest_monitoring_constraint_violations()
605
- latest_monitoring_statistics = my_monitor.latest_monitoring_statistics()
606
- """
607
-
608
413
  def get_monitoring_results(self, max_results=10) -> pd.DataFrame:
609
414
  """Get the results of monitoring executions
610
415
 
@@ -759,7 +564,7 @@ class MonitorCore:
759
564
  Returns:
760
565
  str: String representation of this MonitorCore object
761
566
  """
762
- summary_dict = {} # Disabling for now self.summary()
567
+ summary_dict = self.summary()
763
568
  summary_items = [f" {repr(key)}: {repr(value)}" for key, value in summary_dict.items()]
764
569
  summary_str = f"{self.__class__.__name__}: {self.endpoint_name}\n" + ",\n".join(summary_items)
765
570
  return summary_str
@@ -776,7 +581,6 @@ if __name__ == "__main__":
776
581
 
777
582
  # Create the Class and test it out
778
583
  endpoint_name = "abalone-regression-rt"
779
- endpoint_name = "logd-dev-reg-rt"
780
584
  my_endpoint = EndpointCore(endpoint_name)
781
585
  if not my_endpoint.exists():
782
586
  print(f"Endpoint {endpoint_name} does not exist.")
@@ -789,11 +593,10 @@ if __name__ == "__main__":
789
593
  # Check the details of the monitoring class
790
594
  pprint(mm.details())
791
595
 
792
- # Enable data capture on the endpoint
793
- mm.enable_data_capture()
596
+ # Enable data capture (if not already enabled)
597
+ mm.enable_data_capture(capture_percentage=100)
794
598
 
795
599
  # Create a baseline for monitoring
796
- # mm.create_baseline(recreate=True)
797
600
  mm.create_baseline()
798
601
 
799
602
  # Check the monitoring outputs
@@ -805,30 +608,11 @@ if __name__ == "__main__":
805
608
  pprint(mm.get_constraints())
806
609
 
807
610
  print("\nStatistics...")
808
- print(mm.get_statistics())
611
+ print(str(mm.get_statistics())[:1000]) # Print only first 1000 characters
809
612
 
810
613
  # Set up the monitoring schedule (if it doesn't already exist)
811
614
  mm.create_monitoring_schedule()
812
615
 
813
- #
814
- # Test the data capture by running some predictions
815
- #
816
-
817
- # Make predictions on the Endpoint using the FeatureSet evaluation data
818
- # pred_df = my_endpoint.auto_inference()
819
- # print(pred_df.head())
820
-
821
- # Check that data capture is working
822
- input_df, output_df = mm.get_captured_data()
823
- if input_df.empty or output_df.empty:
824
- print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
825
- else:
826
- print("Found data capture files")
827
- print("Input")
828
- print(input_df.head())
829
- print("Output")
830
- print(output_df.head())
831
-
832
616
  # Test update_constraints (commented out for now)
833
617
  # print("\nTesting constraint updates...")
834
618
  # custom_constraints = {"sex": {"allowed_values": ["M", "F", "I"]}, "length": {"min": 0.0, "max": 1.0}}
@@ -847,7 +631,7 @@ if __name__ == "__main__":
847
631
  print("\nTesting execution details retrieval...")
848
632
  if not results_df.empty:
849
633
  latest_execution_arn = results_df.iloc[0]["processing_job_arn"]
850
- execution_details = mm.get_execution_details(latest_execution_arn)
634
+ execution_details = mm.get_execution_details(latest_execution_arn) if latest_execution_arn else None
851
635
  if execution_details:
852
636
  print(f"Execution details for {latest_execution_arn}:")
853
637
  pprint(execution_details)
workbench/core/cloud_platform/aws/aws_account_clamp.py

@@ -54,7 +54,11 @@ class AWSAccountClamp:
 
        # Check our Assume Role
        self.log.info("Checking Workbench Assumed Role...")
-        self.aws_session.assumed_role_info()
+        role_info = self.aws_session.assumed_role_info()
+        self.log.info(f"Assumed Role: {role_info}")
+
+        # Check if we have tag write permissions (if we don't, we are read-only)
+        self.read_only = not self.check_tag_permissions()
 
        # Check our Workbench API Key and Load the License
        self.log.info("Checking Workbench API License...")
@@ -138,6 +142,45 @@ class AWSAccountClamp:
        """
        return self.boto3_session.client("sagemaker")
 
+    def check_tag_permissions(self):
+        """Check if current role has permission to add tags to SageMaker endpoints.
+
+        Returns:
+            bool: True if AddTags is allowed, False otherwise
+        """
+        try:
+            sagemaker = self.boto3_session.client("sagemaker")
+
+            # Use a non-existent endpoint name
+            fake_endpoint = "workbench-permission-check-dummy-endpoint"
+
+            # Try to add tags to the non-existent endpoint
+            sagemaker.add_tags(
+                ResourceArn=f"arn:aws:sagemaker:{self.region}:{self.account_id}:endpoint/{fake_endpoint}",
+                Tags=[{"Key": "PermissionCheck", "Value": "Test"}],
+            )
+
+            # If we get here, we have permission (but endpoint doesn't exist)
+            return True
+
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+
+            # AccessDeniedException = no permission
+            if error_code == "AccessDeniedException":
+                self.log.debug("No AddTags permission (AccessDeniedException)")
+                return False
+
+            # ResourceNotFound = we have permission, but endpoint doesn't exist
+            elif error_code in ["ResourceNotFound", "ValidationException"]:
+                self.log.debug("AddTags permission verified (resource not found)")
+                return True
+
+            # Unexpected error, assume no permission for safety
+            else:
+                self.log.debug(f"Unexpected error checking permissions: {error_code}")
+                return False
+
 
 if __name__ == "__main__":
    """Exercise the AWS Account Clamp Class"""
@@ -162,3 +205,9 @@
    print("\n\n*** AWS Sagemaker Session/Client Check ***")
    sm_client = aws_account_clamp.sagemaker_client()
    print(sm_client.list_feature_groups()["FeatureGroupSummaries"])
+
+    print("\n\n*** AWS Tag Permission Check ***")
+    if aws_account_clamp.check_tag_permissions():
+        print("Tag Permission Check Success...")
+    else:
+        print("Tag Permission Check Failed...")
workbench/core/cloud_platform/aws/aws_meta.py

@@ -196,7 +196,9 @@ class AWSMeta:
 
        # Return the summary as a DataFrame
        df = pd.DataFrame(data_summary).convert_dtypes()
-        return df.sort_values(by="Created", ascending=False)
+        if not df.empty:
+            df.sort_values(by="Created", ascending=False, inplace=True)
+        return df
 
    def models(self, details: bool = False) -> pd.DataFrame:
        """Get a summary of the Models in AWS.
@@ -256,7 +258,9 @@
 
        # Return the summary as a DataFrame
        df = pd.DataFrame(model_summary).convert_dtypes()
-        return df.sort_values(by="Created", ascending=False)
+        if not df.empty:
+            df.sort_values(by="Created", ascending=False, inplace=True)
+        return df
 
    def endpoints(self, details: bool = False) -> pd.DataFrame:
        """Get a summary of the Endpoints in AWS.
@@ -317,7 +321,9 @@
 
        # Return the summary as a DataFrame
        df = pd.DataFrame(data_summary).convert_dtypes()
-        return df.sort_values(by="Created", ascending=False)
+        if not df.empty:
+            df.sort_values(by="Created", ascending=False, inplace=True)
+        return df
 
    def _endpoint_config_info(self, endpoint_config_name: str) -> dict:
        """Internal: Get the Endpoint Configuration information for the given endpoint config name.
@@ -657,7 +663,8 @@
        df = pd.DataFrame(data_summary).convert_dtypes()
 
        # Sort by the Modified column
-        df = df.sort_values(by="Modified", ascending=False)
+        if not df.empty:
+            df = df.sort_values(by="Modified", ascending=False)
        return df
 
    def _aws_pipelines(self) -> pd.DataFrame:
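Aside: the emptiness guard matters because a summary built from zero artifacts produces a DataFrame with no columns, so sorting by "Created" or "Modified" would raise. A quick illustration in plain pandas (not Workbench code):

    import pandas as pd

    df = pd.DataFrame([]).convert_dtypes()   # no rows and no columns
    try:
        df.sort_values(by="Created", ascending=False)
    except KeyError as err:
        print(f"KeyError: {err}")            # 'Created'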
workbench/core/transforms/data_to_features/light/molecular_descriptors.py

@@ -1,7 +1,7 @@
 """MolecularDescriptors: Compute a Feature Set based on RDKit Descriptors
 
-Note: An alternative to using this class is to use the `compute_molecular_descriptors` function directly.
-    df_features = compute_molecular_descriptors(df)
+Note: An alternative to using this class is to use the `compute_descriptors` function directly.
+    df_features = compute_descriptors(df)
    to_features = PandasToFeatures("my_feature_set")
    to_features.set_input(df_features, id_column="id")
    to_features.set_output_tags(["blah", "whatever"])
@@ -10,7 +10,7 @@ Note: An alternative to using this class is to use the `compute_molecular_descri
 
 # Local Imports
 from workbench.core.transforms.data_to_features.light.data_to_features_light import DataToFeaturesLight
-from workbench.utils.chem_utils import compute_molecular_descriptors
+from workbench.utils.chem_utils.mol_descriptors import compute_descriptors
 
 
 class MolecularDescriptors(DataToFeaturesLight):
@@ -39,7 +39,7 @@ class MolecularDescriptors(DataToFeaturesLight):
        """Compute a Feature Set based on RDKit Descriptors"""
 
        # Compute/add all the Molecular Descriptors
-        self.output_df = compute_molecular_descriptors(self.input_df)
+        self.output_df = compute_descriptors(self.input_df)
 
 
 if __name__ == "__main__":
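Aside: downstream code that imported the old helper needs the same rename. A migration sketch based only on the import paths shown in the hunks above:

    # Before (0.8.168):
    #   from workbench.utils.chem_utils import compute_molecular_descriptors
    #   df_features = compute_molecular_descriptors(df)

    # After (0.8.192):
    from workbench.utils.chem_utils.mol_descriptors import compute_descriptors
    # df_features = compute_descriptors(df)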
workbench/core/transforms/features_to_model/features_to_model.py

@@ -37,8 +37,8 @@ class FeaturesToModel(Transform):
        model_import_str=None,
        custom_script=None,
        custom_args=None,
-        training_image="xgb_training",
-        inference_image="xgb_inference",
+        training_image="training",
+        inference_image="inference",
        inference_arch="x86_64",
    ):
        """FeaturesToModel Initialization
@@ -50,8 +50,8 @@
            model_import_str (str, optional): The import string for the model (default None)
            custom_script (str, optional): Custom script to use for the model (default None)
            custom_args (dict, optional): Custom arguments to pass to custom model scripts (default None)
-            training_image (str, optional): Training image (default "xgb_training")
-            inference_image (str, optional): Inference image (default "xgb_inference")
+            training_image (str, optional): Training image (default "training")
+            inference_image (str, optional): Inference image (default "inference")
            inference_arch (str, optional): Inference architecture (default "x86_64")
        """
 
@@ -264,6 +264,11 @@
        self.log.important(f"Creating new model {self.output_name}...")
        self.create_and_register_model(**kwargs)
 
+        # Make a copy of the training view, to lock-in the training data used for this model
+        model_training_view_name = f"{self.output_name.replace('-', '_')}_training"
+        self.log.important(f"Creating Model Training View: {model_training_view_name}...")
+        feature_set.view("training").copy(f"{model_training_view_name}")
+
    def post_transform(self, **kwargs):
        """Post-Transform: Calling onboard() on the Model"""
        self.log.info("Post-Transform: Calling onboard() on the Model...")