workbench-0.8.171-py3-none-any.whl → workbench-0.8.173-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (49)
  1. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  2. workbench/api/compound.py +1 -1
  3. workbench/api/feature_set.py +4 -4
  4. workbench/api/monitor.py +1 -16
  5. workbench/core/artifacts/artifact.py +11 -3
  6. workbench/core/artifacts/data_capture_core.py +315 -0
  7. workbench/core/artifacts/endpoint_core.py +9 -3
  8. workbench/core/artifacts/model_core.py +37 -14
  9. workbench/core/artifacts/monitor_core.py +33 -249
  10. workbench/core/cloud_platform/aws/aws_account_clamp.py +4 -1
  11. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  12. workbench/core/transforms/features_to_model/features_to_model.py +4 -4
  13. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +471 -0
  14. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +428 -0
  15. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  16. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +19 -9
  17. workbench/model_scripts/custom_models/uq_models/mapie.template +502 -0
  18. workbench/model_scripts/custom_models/uq_models/meta_uq.template +8 -5
  19. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  20. workbench/model_scripts/script_generation.py +5 -0
  21. workbench/model_scripts/xgb_model/generated_model_script.py +5 -5
  22. workbench/repl/workbench_shell.py +3 -3
  23. workbench/utils/chem_utils/__init__.py +0 -0
  24. workbench/utils/chem_utils/fingerprints.py +134 -0
  25. workbench/utils/chem_utils/misc.py +194 -0
  26. workbench/utils/chem_utils/mol_descriptors.py +471 -0
  27. workbench/utils/chem_utils/mol_standardize.py +428 -0
  28. workbench/utils/chem_utils/mol_tagging.py +348 -0
  29. workbench/utils/chem_utils/projections.py +209 -0
  30. workbench/utils/chem_utils/salts.py +256 -0
  31. workbench/utils/chem_utils/sdf.py +292 -0
  32. workbench/utils/chem_utils/toxicity.py +250 -0
  33. workbench/utils/chem_utils/vis.py +253 -0
  34. workbench/utils/model_utils.py +1 -1
  35. workbench/utils/monitor_utils.py +49 -56
  36. workbench/utils/pandas_utils.py +3 -3
  37. workbench/utils/workbench_sqs.py +1 -1
  38. workbench/utils/xgboost_model_utils.py +1 -0
  39. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  40. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/METADATA +1 -1
  41. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/RECORD +45 -34
  42. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  43. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  44. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  45. workbench/utils/chem_utils.py +0 -1556
  46. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/WHEEL +0 -0
  47. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/entry_points.txt +0 -0
  48. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/licenses/LICENSE +0 -0
  49. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/top_level.txt +0 -0
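
The headline change in this diff is the removal of the monolithic workbench/utils/chem_utils.py (-1556 lines) in favor of a workbench/utils/chem_utils/ package split into focused submodules (fingerprints, misc, mol_descriptors, mol_standardize, mol_tagging, projections, salts, sdf, toxicity, vis). Callers update their imports accordingly; a minimal migration sketch, using only functions that appear in the hunks below:

    # 0.8.171: everything lived in one module
    # from workbench.utils.chem_utils import compute_morgan_fingerprints, project_fingerprints, svg_from_smiles

    # 0.8.173: each function now lives in a focused submodule
    from workbench.utils.chem_utils.fingerprints import compute_morgan_fingerprints
    from workbench.utils.chem_utils.projections import project_fingerprints
    from workbench.utils.chem_utils.vis import svg_from_smiles
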
workbench/algorithms/graph/light/proximity_graph.py CHANGED
@@ -135,7 +135,8 @@ if __name__ == "__main__":
     from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity
     from workbench.web_interface.components.plugins.graph_plot import GraphPlot
     from workbench.api import DFStore
-    from workbench.utils.chem_utils import compute_morgan_fingerprints, project_fingerprints
+    from workbench.utils.chem_utils.fingerprints import compute_morgan_fingerprints
+    from workbench.utils.chem_utils.projections import project_fingerprints
     from workbench.utils.graph_utils import connected_sample, graph_layout


     def show_graph(graph, id_column):
workbench/api/compound.py CHANGED
@@ -3,7 +3,7 @@ import logging
 from typing import List

 # Workbench Imports
-from workbench.utils.chem_utils import svg_from_smiles
+from workbench.utils.chem_utils.vis import svg_from_smiles


 @dataclass
workbench/api/feature_set.py CHANGED
@@ -87,8 +87,8 @@ class FeatureSet(FeatureSetCore):
         model_import_str: str = None,
         custom_script: Union[str, Path] = None,
         custom_args: dict = None,
-        training_image: str = "xgb_training",
-        inference_image: str = "xgb_inference",
+        training_image: str = "training",
+        inference_image: str = "inference",
         inference_arch: str = "x86_64",
         **kwargs,
     ) -> Union[Model, None]:
@@ -105,8 +105,8 @@ class FeatureSet(FeatureSetCore):
             model_class (str, optional): Model class to use (e.g. "KMeans", "PyTorch", default: None)
             model_import_str (str, optional): The import for the model (e.g. "from sklearn.cluster import KMeans")
             custom_script (str, optional): The custom script to use for the model (default: None)
-            training_image (str, optional): The training image to use (default: "xgb_training")
-            inference_image (str, optional): The inference image to use (default: "xgb_inference")
+            training_image (str, optional): The training image to use (default: "training")
+            inference_image (str, optional): The inference image to use (default: "inference")
             inference_arch (str, optional): The architecture to use for inference (default: "x86_64")
             kwargs (dict, optional): Additional keyword arguments to pass to the model

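The default image names drop their xgb_ prefix, matching the renamed keys in ModelImages.image_uris (see the model_core.py hunks below). Code that passed the old names explicitly must update them; a hedged sketch, assuming this signature belongs to FeatureSet.to_model (the method name is not visible in the hunk) and using hypothetical artifact names:

    fs = FeatureSet("my_features")        # hypothetical FeatureSet name
    model = fs.to_model(                  # assumed method; only its signature appears above
        name="my-model",                  # hypothetical model name
        training_image="training",        # was "xgb_training" in 0.8.171
        inference_image="inference",      # was "xgb_inference" in 0.8.171
    )
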
workbench/api/monitor.py CHANGED
@@ -15,7 +15,7 @@ class Monitor(MonitorCore):

     Common Usage:
         ```
-        mon = Endpoint(name).get_monitor() # Pull from endpoint OR
+        mon = Endpoint(name).monitor() # Pull from endpoint OR
         mon = Monitor(name) # Create using Endpoint Name
         mon.summary()
         mon.details()
@@ -29,7 +29,6 @@ class Monitor(MonitorCore):
         baseline_df = mon.get_baseline()
         constraints_df = mon.get_constraints()
         stats_df = mon.get_statistics()
-        input_df, output_df = mon.get_captured_data()
         ```
     """

@@ -81,15 +80,6 @@ class Monitor(MonitorCore):
         """
         super().create_monitoring_schedule(schedule)

-    def get_captured_data(self) -> (pd.DataFrame, pd.DataFrame):
-        """
-        Get the latest data capture input and output from S3.
-
-        Returns:
-            DataFrame (input), DataFrame(output): Flattened and processed DataFrames for input and output data.
-        """
-        return super().get_captured_data()
-
     def get_baseline(self) -> Union[pd.DataFrame, None]:
         """Code to get the baseline CSV from the S3 baseline directory

@@ -155,8 +145,3 @@ if __name__ == "__main__":

     print("\nStatistics...")
     print(mm.get_statistics())
-
-    # Get the latest data capture
-    input_df, output_df = mm.get_captured_data()
-    print(input_df.head())
-    print(output_df.head())
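
Monitor.get_captured_data() is removed rather than renamed: captured-data retrieval now lives on the new DataCaptureCore class (added below in this diff) and is reached through the endpoint. A minimal migration sketch, assuming a hypothetical endpoint name:

    from workbench.api import Endpoint

    end = Endpoint("my-endpoint")                   # hypothetical endpoint name
    dc = end.data_capture()                         # new accessor (see endpoint_core.py below)
    input_df, output_df = dc.get_captured_data()    # moved from Monitor to DataCaptureCore
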
workbench/core/artifacts/artifact.py CHANGED
@@ -236,6 +236,12 @@ class Artifact(ABC):
         This functionality will work for FeatureSets, Models, and Endpoints
         but not for DataSources. The DataSource class overrides this method.
         """
+
+        # Check for ReadOnly Role
+        if self.aws_account_clamp.read_only_role:
+            self.log.info("Cannot add metadata with a ReadOnly Role...")
+            return
+
         # Sanity check
         aws_arn = self.arn()
         if aws_arn is None:
@@ -444,10 +450,12 @@ class Artifact(ABC):

 if __name__ == "__main__":
     """Exercise the Artifact Class"""
-    from workbench.api.data_source import DataSource
-    from workbench.api.feature_set import FeatureSet
+    from workbench.api import DataSource, FeatureSet, Endpoint
+
+    # Grab an Endpoint (which is a subclass of Artifact)
+    end = Endpoint("wine-classification")

-    # Create a DataSource (which is a subclass of Artifact)
+    # Grab a DataSource (which is a subclass of Artifact)
     data_source = DataSource("test_data")

     # Just some random tests
workbench/core/artifacts/data_capture_core.py ADDED
@@ -0,0 +1,315 @@
+"""DataCaptureCore class for managing SageMaker endpoint data capture"""
+
+import logging
+import re
+from datetime import datetime
+from typing import Tuple
+import pandas as pd
+from sagemaker import Predictor
+from sagemaker.model_monitor import DataCaptureConfig
+import awswrangler as wr
+
+# Workbench Imports
+from workbench.core.artifacts.endpoint_core import EndpointCore
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.utils.monitor_utils import process_data_capture
+
+
+class DataCaptureCore:
+    """Manages data capture configuration and retrieval for SageMaker endpoints"""
+
+    def __init__(self, endpoint_name: str):
+        """DataCaptureCore Class
+
+        Args:
+            endpoint_name (str): Name of the endpoint to manage data capture for
+        """
+        self.log = logging.getLogger("workbench")
+        self.endpoint_name = endpoint_name
+        self.endpoint = EndpointCore(self.endpoint_name)
+
+        # Initialize Class Attributes
+        self.sagemaker_session = self.endpoint.sm_session
+        self.sagemaker_client = self.endpoint.sm_client
+        self.data_capture_path = self.endpoint.endpoint_data_capture_path
+        self.workbench_role_arn = AWSAccountClamp().aws_session.get_workbench_execution_role_arn()
+
+    def summary(self) -> dict:
+        """Return the summary of data capture configuration
+
+        Returns:
+            dict: Summary of data capture status
+        """
+        if self.endpoint.is_serverless():
+            return {"endpoint_type": "serverless", "data_capture": "not supported"}
+        else:
+            return {
+                "endpoint_type": "realtime",
+                "data_capture_enabled": self.is_enabled(),
+                "capture_percentage": self.capture_percentage(),
+                "capture_modes": self.capture_modes() if self.is_enabled() else [],
+                "data_capture_path": self.data_capture_path if self.is_enabled() else None,
+            }
+
+    def enable(self, capture_percentage=100, capture_options=None, force_redeploy=False):
+        """
+        Enable data capture for the SageMaker endpoint.
+
+        Args:
+            capture_percentage (int): Percentage of data to capture. Defaults to 100.
+            capture_options (list): List of what to capture - ["REQUEST"], ["RESPONSE"], or ["REQUEST", "RESPONSE"].
+                Defaults to ["REQUEST", "RESPONSE"] to capture both.
+            force_redeploy (bool): If True, force redeployment even if data capture is already enabled.
+        """
+        # Early returns for cases where we can't/don't need to add data capture
+        if self.endpoint.is_serverless():
+            self.log.warning("Data capture is not supported for serverless endpoints.")
+            return
+
+        # Default to capturing both if not specified
+        if capture_options is None:
+            capture_options = ["REQUEST", "RESPONSE"]
+
+        # Validate capture_options
+        valid_options = {"REQUEST", "RESPONSE"}
+        if not all(opt in valid_options for opt in capture_options):
+            self.log.error("Invalid capture_options. Must be a list containing 'REQUEST' and/or 'RESPONSE'")
+            return
+
+        if self.is_enabled() and not force_redeploy:
+            self.log.important(f"Data capture already configured for {self.endpoint_name}.")
+            return
+
+        # Get the current endpoint configuration name for later deletion
+        current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+        # Log the data capture operation
+        self.log.important(f"Enabling Data Capture for {self.endpoint_name} --> {self.data_capture_path}")
+        self.log.important(f"Capturing: {', '.join(capture_options)} at {capture_percentage}% sampling")
+        self.log.important("This will redeploy the endpoint...")
+
+        # Create and apply the data capture configuration
+        data_capture_config = DataCaptureConfig(
+            enable_capture=True,
+            sampling_percentage=capture_percentage,
+            destination_s3_uri=self.data_capture_path,
+            capture_options=capture_options,
+        )
+
+        # Update endpoint with the new capture configuration
+        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+            data_capture_config=data_capture_config
+        )
+
+        # Clean up old endpoint configuration
+        try:
+            self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+            self.log.info(f"Deleted old endpoint configuration: {current_endpoint_config_name}")
+        except Exception as e:
+            self.log.warning(f"Could not delete old endpoint configuration {current_endpoint_config_name}: {e}")
+
+    def disable(self):
+        """
+        Disable data capture for the SageMaker endpoint.
+        """
+        # Early return if data capture isn't configured
+        if not self.is_enabled():
+            self.log.important(f"Data capture is not currently enabled for {self.endpoint_name}.")
+            return
+
+        # Get the current endpoint configuration name for later deletion
+        current_endpoint_config_name = self.endpoint.endpoint_config_name()
+
+        # Log the operation
+        self.log.important(f"Disabling Data Capture for {self.endpoint_name}")
+        self.log.important("This normally redeploys the endpoint...")
+
+        # Create a configuration with capture disabled
+        data_capture_config = DataCaptureConfig(enable_capture=False, destination_s3_uri=self.data_capture_path)
+
+        # Update endpoint with the new configuration
+        Predictor(self.endpoint_name, sagemaker_session=self.sagemaker_session).update_data_capture_config(
+            data_capture_config=data_capture_config
+        )
+
+        # Clean up old endpoint configuration
+        self.sagemaker_client.delete_endpoint_config(EndpointConfigName=current_endpoint_config_name)
+
+    def is_enabled(self) -> bool:
+        """
+        Check if data capture is enabled on the endpoint.
+
+        Returns:
+            bool: True if data capture is enabled, False otherwise.
+        """
+        try:
+            endpoint_config_name = self.endpoint.endpoint_config_name()
+            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+            # Check if data capture is enabled
+            is_enabled = data_capture_config.get("EnableCapture", False)
+            return is_enabled
+        except Exception as e:
+            self.log.error(f"Error checking data capture configuration: {e}")
+            return False
+
+    def capture_percentage(self) -> int:
+        """
+        Get the data capture percentage from the endpoint configuration.
+
+        Returns:
+            int: Data capture percentage if enabled, None otherwise.
+        """
+        try:
+            endpoint_config_name = self.endpoint.endpoint_config_name()
+            endpoint_config = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+            data_capture_config = endpoint_config.get("DataCaptureConfig", {})
+
+            # Check if data capture is enabled and return the percentage
+            if data_capture_config.get("EnableCapture", False):
+                return data_capture_config.get("InitialSamplingPercentage", 0)
+            else:
+                return None
+        except Exception as e:
+            self.log.error(f"Error checking data capture percentage: {e}")
+            return None
+
+    def get_config(self) -> dict:
+        """
+        Returns the complete data capture configuration from the endpoint config.
+
+        Returns:
+            dict: Complete DataCaptureConfig from AWS, or None if not configured
+        """
+        config_name = self.endpoint.endpoint_config_name()
+        response = self.sagemaker_client.describe_endpoint_config(EndpointConfigName=config_name)
+        data_capture_config = response.get("DataCaptureConfig")
+        if not data_capture_config:
+            self.log.error(f"No data capture configuration found for endpoint config {config_name}")
+            return None
+        return data_capture_config
+
+    def capture_modes(self) -> list:
+        """Get the current capture modes (REQUEST/RESPONSE)"""
+        if not self.is_enabled():
+            return []
+
+        config = self.get_config()
+        if not config:
+            return []
+
+        capture_options = config.get("CaptureOptions", [])
+        modes = [opt.get("CaptureMode") for opt in capture_options]
+        return ["REQUEST" if m == "Input" else "RESPONSE" for m in modes if m]
+
+    def get_captured_data(self, from_date=None, add_timestamp=True) -> Tuple[pd.DataFrame, pd.DataFrame]:
+        """
+        Read and process captured data from S3.
+
+        Args:
+            from_date (str, optional): Only process files from this date onwards (YYYY-MM-DD format).
+                Defaults to None to process all files.
+            add_timestamp (bool, optional): Whether to add a timestamp column to the DataFrame.
+
+        Returns:
+            Tuple[pd.DataFrame, pd.DataFrame]: Processed input and output DataFrames.
+        """
+        files = wr.s3.list_objects(self.data_capture_path)
+        if not files:
+            self.log.warning(f"No data capture files found in {self.data_capture_path}.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        # Filter by date if specified
+        if from_date:
+            from_date_obj = datetime.strptime(from_date, "%Y-%m-%d").date()
+            files = [f for f in files if self._file_date_filter(f, from_date_obj)]
+            self.log.info(f"Processing {len(files)} files from {from_date} onwards.")
+        else:
+            self.log.info(f"Processing all {len(files)} files.")
+        files.sort()
+
+        # Process files
+        all_input_dfs, all_output_dfs = [], []
+        for file_path in files:
+            try:
+                df = wr.s3.read_json(path=file_path, lines=True)
+                if not df.empty:
+                    input_df, output_df = process_data_capture(df)
+                    if add_timestamp:
+                        timestamp = wr.s3.describe_objects(path=file_path)[file_path]["LastModified"]
+                        output_df["timestamp"] = timestamp
+                    all_input_dfs.append(input_df)
+                    all_output_dfs.append(output_df)
+            except Exception as e:
+                self.log.warning(f"Error processing {file_path}: {e}")
+
+        if not all_input_dfs:
+            self.log.warning("No valid data was processed.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        return pd.concat(all_input_dfs, ignore_index=True), pd.concat(all_output_dfs, ignore_index=True)
+
+    def _file_date_filter(self, file_path, from_date_obj):
+        """Extract date from S3 path and compare with from_date."""
+        try:
+            # Match YYYY/MM/DD pattern in the path
+            date_match = re.search(r"/(\d{4})/(\d{2})/(\d{2})/", file_path)
+            if date_match:
+                year, month, day = date_match.groups()
+                file_date = datetime(int(year), int(month), int(day)).date()
+                return file_date >= from_date_obj
+            return False  # No date pattern found
+        except ValueError:
+            return False
+
+    def __repr__(self) -> str:
+        """String representation of this DataCaptureCore object
+
+        Returns:
+            str: String representation of this DataCaptureCore object
+        """
+        summary_dict = self.summary()
+        summary_items = [f" {repr(key)}: {repr(value)}" for key, value in summary_dict.items()]
+        summary_str = f"{self.__class__.__name__}: {self.endpoint_name}\n" + ",\n".join(summary_items)
+        return summary_str
+
+
+# Test function for the class
+if __name__ == "__main__":
+    """Exercise the DataCaptureCore class"""
+    from pprint import pprint
+
+    # Set options for actually seeing the dataframe
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", None)
+
+    # Create the Class and test it out
+    endpoint_name = "abalone-regression-rt"
+    my_endpoint = EndpointCore(endpoint_name)
+    if not my_endpoint.exists():
+        print(f"Endpoint {endpoint_name} does not exist.")
+        exit(1)
+    dc = my_endpoint.data_capture()
+
+    # Check the summary of the data capture class
+    pprint(dc.summary())
+
+    # Enable data capture on the endpoint
+    # dc.enable(force_redeploy=True)
+    my_endpoint.enable_data_capture()
+
+    # Test the data capture by running some predictions
+    # pred_df = my_endpoint.auto_inference()
+    # print(pred_df.head())
+
+    # Check that data capture is working
+    input_df, output_df = dc.get_captured_data()
+    if input_df.empty and output_df.empty:
+        print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
+    else:
+        print("Found data capture files")
+        print("Input")
+        print(input_df.head())
+        print("Output")
+        print(output_df.head())
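
Beyond the __main__ exercise above, the new class supports date-filtered retrieval and explicit enable/disable. A short usage sketch, assuming an existing real-time (non-serverless) endpoint; the date is illustrative:

    from workbench.core.artifacts.data_capture_core import DataCaptureCore

    dc = DataCaptureCore("abalone-regression-rt")                  # endpoint name from the test block above
    dc.enable(capture_percentage=50, capture_options=["REQUEST"])  # 50% sampling, requests only (redeploys)
    input_df, output_df = dc.get_captured_data(from_date="2025-06-01", add_timestamp=False)
    dc.disable()                                                   # disabling also redeploys the endpoint
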
workbench/core/artifacts/endpoint_core.py CHANGED
@@ -164,11 +164,17 @@ class EndpointCore(Artifact):
         """
         return "Serverless" in self.endpoint_meta["InstanceType"]

-    def add_data_capture(self):
+    def data_capture(self):
+        """Get the DataCaptureCore class for this endpoint"""
+        from workbench.core.artifacts.data_capture_core import DataCaptureCore
+
+        return DataCaptureCore(self.endpoint_name)
+
+    def enable_data_capture(self):
         """Add data capture to the endpoint"""
-        self.get_monitor().add_data_capture()
+        self.data_capture().enable()

-    def get_monitor(self):
+    def monitor(self):
         """Get the MonitorCore class for this endpoint"""
         from workbench.core.artifacts.monitor_core import MonitorCore

workbench/core/artifacts/model_core.py CHANGED
@@ -37,16 +37,45 @@ class ModelType(Enum):
     UNKNOWN = "unknown"


+# Deprecated Images
+"""
+# US East 1 images
+"py312-general-ml-training"
+("us-east-1", "training", "0.1", "x86_64"): (
+    "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
+),
+("us-east-1", "inference", "0.1", "x86_64"): (
+    "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
+),
+
+# US West 2 images
+("us-west-2", "training", "0.1", "x86_64"): (
+    "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
+),
+("us-west-2", "inference", "0.1", "x86_64"): (
+    "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
+),
+
+# ARM64 images
+("us-east-1", "inference", "0.1", "arm64"): (
+    "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
+),
+("us-west-2", "inference", "0.1", "arm64"): (
+    "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
+),
+"""
+
+
 class ModelImages:
     """Class for retrieving workbench inference images"""

     image_uris = {
         # US East 1 images
-        ("us-east-1", "xgb_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
+        ("us-east-1", "training", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
         ),
-        ("us-east-1", "xgb_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
+        ("us-east-1", "inference", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
         ),
         ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
@@ -55,11 +84,11 @@ class ModelImages:
             "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
         ),
         # US West 2 images
-        ("us-west-2", "xgb_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
+        ("us-west-2", "training", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
         ),
-        ("us-west-2", "xgb_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
+        ("us-west-2", "inference", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
         ),
         ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
@@ -68,12 +97,6 @@ class ModelImages:
             "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
         ),
         # ARM64 images
-        ("us-east-1", "xgb_inference", "0.1", "arm64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-        ),
-        ("us-west-2", "xgb_inference", "0.1", "arm64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-        ),
         # Meta Endpoint inference images
         ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"