workbench 0.8.158__py3-none-any.whl → 0.8.160__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (31)
  1. workbench/api/feature_set.py +12 -4
  2. workbench/api/meta.py +1 -1
  3. workbench/cached/cached_feature_set.py +1 -0
  4. workbench/cached/cached_meta.py +10 -12
  5. workbench/core/artifacts/cached_artifact_mixin.py +6 -3
  6. workbench/core/artifacts/model_core.py +19 -7
  7. workbench/core/cloud_platform/aws/aws_meta.py +66 -45
  8. workbench/core/cloud_platform/cloud_meta.py +5 -2
  9. workbench/core/transforms/features_to_model/features_to_model.py +9 -5
  10. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +6 -0
  11. workbench/model_scripts/{custom_models/nn_models → pytorch_model}/generated_model_script.py +170 -156
  12. workbench/model_scripts/{custom_models/nn_models → pytorch_model}/pytorch.template +153 -147
  13. workbench/model_scripts/pytorch_model/requirements.txt +2 -0
  14. workbench/model_scripts/scikit_learn/generated_model_script.py +307 -0
  15. workbench/model_scripts/script_generation.py +6 -2
  16. workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
  17. workbench/repl/workbench_shell.py +11 -10
  18. workbench/utils/json_utils.py +27 -8
  19. workbench/utils/pandas_utils.py +12 -13
  20. workbench/utils/redis_cache.py +28 -13
  21. workbench/utils/workbench_cache.py +20 -14
  22. workbench/web_interface/page_views/endpoints_page_view.py +1 -1
  23. workbench/web_interface/page_views/main_page.py +1 -1
  24. {workbench-0.8.158.dist-info → workbench-0.8.160.dist-info}/METADATA +5 -8
  25. {workbench-0.8.158.dist-info → workbench-0.8.160.dist-info}/RECORD +29 -29
  26. workbench/model_scripts/custom_models/nn_models/Readme.md +0 -9
  27. workbench/model_scripts/custom_models/nn_models/requirements.txt +0 -4
  28. {workbench-0.8.158.dist-info → workbench-0.8.160.dist-info}/WHEEL +0 -0
  29. {workbench-0.8.158.dist-info → workbench-0.8.160.dist-info}/entry_points.txt +0 -0
  30. {workbench-0.8.158.dist-info → workbench-0.8.160.dist-info}/licenses/LICENSE +0 -0
  31. {workbench-0.8.158.dist-info → workbench-0.8.160.dist-info}/top_level.txt +0 -0
@@ -87,8 +87,9 @@ class FeatureSet(FeatureSetCore):
87
87
  model_import_str: str = None,
88
88
  custom_script: Union[str, Path] = None,
89
89
  custom_args: dict = None,
90
+ training_image: str = "xgb_training",
91
+ inference_image: str = "xgb_inference",
90
92
  inference_arch: str = "x86_64",
91
- inference_image: str = "inference",
92
93
  **kwargs,
93
94
  ) -> Union[Model, None]:
94
95
  """Create a Model from the FeatureSet
@@ -101,11 +102,12 @@ class FeatureSet(FeatureSetCore):
101
102
  description (str, optional): Set the description for the model. If not given, a description is generated.
102
103
  feature_list (list, optional): Set the feature list for the model. If not given a feature list is generated.
103
104
  target_column (str, optional): The target column for the model (use None for unsupervised model)
104
- model_class (str, optional): Scikit model class to use (e.g. "KMeans", default: None)
105
+ model_class (str, optional): Model class to use (e.g. "KMeans", "PyTorch", default: None)
105
106
  model_import_str (str, optional): The import for the model (e.g. "from sklearn.cluster import KMeans")
106
107
  custom_script (str, optional): The custom script to use for the model (default: None)
108
+ training_image (str, optional): The training image to use (default: "xgb_training")
109
+ inference_image (str, optional): The inference image to use (default: "xgb_inference")
107
110
  inference_arch (str, optional): The architecture to use for inference (default: "x86_64")
108
- inference_image (str, optional): The inference image to use (default: "inference")
109
111
 
110
112
  Returns:
111
113
  Model: The Model created from the FeatureSet (or None if the Model could not be created)
@@ -125,6 +127,11 @@ class FeatureSet(FeatureSetCore):
125
127
  # Create the Model Tags
126
128
  tags = [name] if tags is None else tags
127
129
 
130
+ # If the model_class is PyTorch, ensure we set the training and inference images
131
+ if model_class and model_class.lower() == "pytorch":
132
+ training_image = "pytorch_training"
133
+ inference_image = "pytorch_inference"
134
+
128
135
  # Transform the FeatureSet into a Model
129
136
  features_to_model = FeaturesToModel(
130
137
  feature_name=self.name,
@@ -134,8 +141,9 @@ class FeatureSet(FeatureSetCore):
134
141
  model_import_str=model_import_str,
135
142
  custom_script=custom_script,
136
143
  custom_args=custom_args,
137
- inference_arch=inference_arch,
144
+ training_image=training_image,
138
145
  inference_image=inference_image,
146
+ inference_arch=inference_arch,
139
147
  )
140
148
  features_to_model.set_output_tags(tags)
141
149
  features_to_model.transform(
workbench/api/meta.py CHANGED
@@ -28,7 +28,7 @@ class Meta(CloudMeta):
28
28
  meta.data_sources()
29
29
  meta.feature_sets(details=True/False)
30
30
  meta.models(details=True/False)
31
- meta.endpoints()
31
+ meta.endpoints(details=True/False)
32
32
  meta.views()
33
33
  meta.pipelines()
34
34
 
@@ -79,6 +79,7 @@ if __name__ == "__main__":
79
79
 
80
80
  # Retrieve an existing FeatureSet
81
81
  my_features = CachedFeatureSet("abalone_features")
82
+ pprint(my_features.smart_sample())
82
83
  pprint(my_features.summary())
83
84
  pprint(my_features.details())
84
85
  pprint(my_features.health_check())
@@ -13,8 +13,6 @@ from workbench.utils.workbench_cache import WorkbenchCache
13
13
 
14
14
 
15
15
  # Decorator to cache method results from the Meta class
16
- # Note: This has to be outside the class definition to work properly in Python 3.9
17
- # When we deprecated support for 3.9, move this back into the class definition
18
16
  def cache_result(method):
19
17
  """Decorator to cache method results in meta_cache"""
20
18
 
@@ -24,11 +22,8 @@ def cache_result(method):
24
22
  cache_key = CachedMeta._flatten_redis_key(method, *args, **kwargs)
25
23
 
26
24
  # Check for fresh data, spawn thread to refresh if stale
27
- if WorkbenchCache.refresh_enabled and self.fresh_cache.get(cache_key) is None:
28
- self.log.debug(f"Async: Metadata for {cache_key} refresh thread started...")
29
- self.fresh_cache.set(cache_key, True) # Mark as refreshed
30
-
31
- # Spawn a thread to refresh data without blocking
25
+ if self.fresh_cache.atomic_set(cache_key, True):
26
+ self.log.important(f"Async: Metadata for {cache_key} refresh thread started...")
32
27
  self.thread_pool.submit(self._refresh_data_in_background, cache_key, method, *args, **kwargs)
33
28
 
34
29
  # Return data (fresh or stale) if available
@@ -62,7 +57,7 @@ class CachedMeta(CloudMeta):
62
57
  meta.data_sources()
63
58
  meta.feature_sets(details=True/False)
64
59
  meta.models(details=True/False)
65
- meta.endpoints()
60
+ meta.endpoints(details=True/False)
66
61
  meta.views()
67
62
 
68
63
  # These are 'describe' methods
@@ -91,7 +86,7 @@ class CachedMeta(CloudMeta):
91
86
 
92
87
  # Create both our Meta Cache and Fresh Cache (tracks if data is stale)
93
88
  self.meta_cache = WorkbenchCache(prefix="meta")
94
- self.fresh_cache = WorkbenchCache(prefix="meta_fresh", expire=90) # 90-second expiration
89
+ self.fresh_cache = WorkbenchCache(prefix="meta_fresh", expire=300) # 5-minute expiration
95
90
 
96
91
  # Create a ThreadPoolExecutor for refreshing stale data
97
92
  self.thread_pool = ThreadPoolExecutor(max_workers=5)
@@ -193,13 +188,16 @@ class CachedMeta(CloudMeta):
193
188
  return super().models(details=details)
194
189
 
195
190
  @cache_result
196
- def endpoints(self) -> pd.DataFrame:
191
+ def endpoints(self, details: bool = False) -> pd.DataFrame:
197
192
  """Get a summary of the Endpoints deployed in the Cloud Platform
198
193
 
194
+ Args:
195
+ details (bool, optional): Include detailed information. Defaults to False.
196
+
199
197
  Returns:
200
198
  pd.DataFrame: A summary of the Endpoints in the Cloud Platform
201
199
  """
202
- return super().endpoints()
200
+ return super().endpoints(details=details)
203
201
 
204
202
  @cache_result
205
203
  def glue_job(self, job_name: str) -> Union[dict, None]:
@@ -266,7 +264,7 @@ class CachedMeta(CloudMeta):
266
264
  """Background task to refresh AWS metadata."""
267
265
  result = method(self, *args, **kwargs)
268
266
  self.meta_cache.set(cache_key, result)
269
- self.log.debug(f"Updated Metadata for {cache_key}")
267
+ self.log.important(f"Updated Metadata for {cache_key}")
270
268
 
271
269
  @staticmethod
272
270
  def _flatten_redis_key(method, *args, **kwargs):
@@ -13,7 +13,7 @@ class CachedArtifactMixin:
13
13
  # Class-level caches, thread pool, and shutdown flag
14
14
  log = logging.getLogger("workbench")
15
15
  artifact_cache = WorkbenchCache(prefix="artifact_cache")
16
- fresh_cache = WorkbenchCache(prefix="artifact_fresh_cache", expire=10)
16
+ fresh_cache = WorkbenchCache(prefix="artifact_fresh_cache", expire=120)
17
17
  thread_pool = ThreadPoolExecutor(max_workers=5)
18
18
 
19
19
  @staticmethod
@@ -45,8 +45,8 @@ class CachedArtifactMixin:
45
45
  cls.artifact_cache.set(cache_key, result)
46
46
  return result
47
47
 
48
- # Stale cache: Refresh in the background if enabled and no refresh is already in progress
49
- if WorkbenchCache.refresh_enabled and cache_fresh is None:
48
+ # Stale cache: Refresh in the background
49
+ if cache_fresh is None:
50
50
  self.log.debug(f"Async: Refresh thread started: {cache_key}...")
51
51
  cls.fresh_cache.set(cache_key, True)
52
52
  cls.thread_pool.submit(cls._refresh_data_in_background, self, cache_key, method, *args, **kwargs)
@@ -88,4 +88,7 @@ if __name__ == "__main__":
88
88
  my_model = CachedModel("abalone-regression")
89
89
  pprint(my_model.summary())
90
90
  pprint(my_model.details())
91
+ # Second call to demonstrate caching
92
+ pprint(my_model.summary())
93
+ pprint(my_model.details())
91
94
  CachedArtifactMixin._shutdown()
@@ -42,24 +42,36 @@ class ModelImages:
42
42
 
43
43
  image_uris = {
44
44
  # US East 1 images
45
- ("us-east-1", "training", "0.1", "x86_64"): (
45
+ ("us-east-1", "xgb_training", "0.1", "x86_64"): (
46
46
  "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
47
47
  ),
48
- ("us-east-1", "inference", "0.1", "x86_64"): (
48
+ ("us-east-1", "xgb_inference", "0.1", "x86_64"): (
49
49
  "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
50
50
  ),
51
+ ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
52
+ "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
53
+ ),
54
+ ("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
55
+ "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
56
+ ),
51
57
  # US West 2 images
52
- ("us-west-2", "training", "0.1", "x86_64"): (
58
+ ("us-west-2", "xgb_training", "0.1", "x86_64"): (
53
59
  "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
54
60
  ),
55
- ("us-west-2", "inference", "0.1", "x86_64"): (
61
+ ("us-west-2", "xgb_inference", "0.1", "x86_64"): (
56
62
  "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
57
63
  ),
64
+ ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
65
+ "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
66
+ ),
67
+ ("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
68
+ "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
69
+ ),
58
70
  # ARM64 images
59
- ("us-east-1", "inference", "0.1", "arm64"): (
71
+ ("us-east-1", "xgb_inference", "0.1", "arm64"): (
60
72
  "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
61
73
  ),
62
- ("us-west-2", "inference", "0.1", "arm64"): (
74
+ ("us-west-2", "xgb_inference", "0.1", "arm64"): (
63
75
  "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
64
76
  ),
65
77
  # Meta Endpoint inference images
@@ -72,7 +84,7 @@ class ModelImages:
72
84
  }
73
85
 
74
86
  @classmethod
75
- def get_image_uri(cls, region, image_type="training", version="0.1", architecture="x86_64"):
87
+ def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
76
88
  key = (region, image_type, version, architecture)
77
89
  if key in cls.image_uris:
78
90
  return cls.image_uris[key]
@@ -179,7 +179,7 @@ class AWSMeta:
179
179
  feature_set_details.update(self.sm_client.describe_feature_group(FeatureGroupName=name))
180
180
 
181
181
  # Retrieve Workbench metadata from tags
182
- aws_tags = self.get_aws_tags(fg["FeatureGroupArn"])
182
+ aws_tags = self.get_aws_tags(fg["FeatureGroupArn"]) if details else {}
183
183
  summary = {
184
184
  "Feature Group": name,
185
185
  "Health": "",
@@ -258,70 +258,60 @@ class AWSMeta:
258
258
  df = pd.DataFrame(model_summary).convert_dtypes()
259
259
  return df.sort_values(by="Created", ascending=False)
260
260
 
261
- def endpoints(self, refresh: bool = False) -> pd.DataFrame:
261
+ def endpoints(self, details: bool = False) -> pd.DataFrame:
262
262
  """Get a summary of the Endpoints in AWS.
263
263
 
264
264
  Args:
265
- refresh (bool, optional): Force a refresh of the metadata. Defaults to False.
265
+ details (bool, optional): Get additional details (Defaults to False).
266
266
 
267
267
  Returns:
268
268
  pd.DataFrame: A summary of the Endpoints in AWS.
269
269
  """
270
270
  from workbench.utils.endpoint_utils import is_monitored # noqa: E402
271
271
 
272
- # Initialize the SageMaker client and list all endpoints
273
- sagemaker_client = self.boto3_session.client("sagemaker")
274
- paginator = sagemaker_client.get_paginator("list_endpoints")
272
+ # Use our SageMaker client to list all endpoints
273
+ paginator = self.sm_client.get_paginator("list_endpoints")
275
274
  data_summary = []
276
275
 
277
276
  # Use the paginator to retrieve all endpoints
278
277
  for page in paginator.paginate():
279
278
  for endpoint in page["Endpoints"]:
280
279
  endpoint_name = endpoint["EndpointName"]
281
- endpoint_info = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
282
280
 
283
- # Retrieve Workbench metadata from tags
284
- aws_tags = self.get_aws_tags(endpoint_info["EndpointArn"])
285
- health_tags = aws_tags.get("workbench_health_tags", "")
286
-
287
- # Retrieve endpoint configuration to determine instance type or serverless info
288
- endpoint_config_name = endpoint_info["EndpointConfigName"]
289
-
290
- # Getting the endpoint configuration can fail so account for that
291
- try:
292
- endpoint_config = sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
293
- production_variant = endpoint_config["ProductionVariants"][0]
294
- # Determine instance type or serverless configuration
295
- instance_type = production_variant.get("InstanceType")
296
- if instance_type is None:
297
- # If no instance type, it's a serverless configuration
298
- mem_size = production_variant["ServerlessConfig"]["MemorySizeInMB"]
299
- concurrency = production_variant["ServerlessConfig"]["MaxConcurrency"]
300
- instance_type = f"Serverless ({mem_size // 1024}GB/{concurrency})"
301
- except sagemaker_client.exceptions.ClientError:
302
- # If the endpoint config is not found, change the config name to reflect this
303
- endpoint_config_name = f"{endpoint_config_name} (Not Found)"
304
- production_variant = {}
305
- instance_type = "Unknown"
306
-
307
- # Check if the endpoint has monitoring enabled
308
- endpoint_monitored = is_monitored(endpoint_name, sagemaker_client)
281
+ # Grab various endpoint details
282
+ endpoint_details = {"config": {"instance": "-", "variant": "-"}, "monitored": "-"}
283
+ aws_tags = {}
284
+ if details:
285
+ endpoint_details = self.sm_client.describe_endpoint(EndpointName=endpoint_name)
286
+
287
+ # Retrieve AWS Tags for this Endpoint
288
+ aws_tags = self.get_aws_tags(endpoint_details["EndpointArn"])
289
+
290
+ # Getting the endpoint configuration
291
+ config_info = self._endpoint_config_info(endpoint_details["EndpointConfigName"])
292
+ endpoint_details["config"] = config_info
293
+
294
+ # Check if the endpoint has monitoring enabled
295
+ endpoint_details["monitored"] = is_monitored(endpoint_name, self.sm_client)
309
296
 
310
297
  # Compile endpoint summary
298
+ created = (
299
+ datetime_string(endpoint_details["CreationTime"]) if "CreationTime" in endpoint_details else "-"
300
+ )
311
301
  summary = {
312
302
  "Name": endpoint_name,
313
- "Health": health_tags,
303
+ "Health": aws_tags.get("workbench_health_tags", ""),
314
304
  "Owner": aws_tags.get("workbench_owner", "-"),
315
- "Instance": instance_type,
316
- "Created": datetime_string(endpoint_info.get("CreationTime")),
305
+ "Instance": endpoint_details["config"]["instance"],
306
+ "Created": created,
317
307
  "Input": aws_tags.get("workbench_input", "-"),
318
- "Status": endpoint_info["EndpointStatus"],
319
- "Config": endpoint_config_name,
320
- "Variant": production_variant.get("VariantName", "-"),
321
- "Capture": str(endpoint_info.get("DataCaptureConfig", {}).get("EnableCapture", "False")),
322
- "Samp(%)": str(endpoint_info.get("DataCaptureConfig", {}).get("CurrentSamplingPercentage", "-")),
308
+ "Status": endpoint_details.get("EndpointStatus", "-"),
309
+ "Config": endpoint_details.get("EndpointConfigName", "-"),
310
+ "Variant": endpoint_details["config"]["variant"],
311
+ "Capture": str(endpoint_details.get("DataCaptureConfig", {}).get("EnableCapture", "False")),
312
+ "Samp(%)": str(endpoint_details.get("DataCaptureConfig", {}).get("CurrentSamplingPercentage", "-")),
323
313
  "Tags": aws_tags.get("workbench_tags", "-"),
324
- "Monitored": endpoint_monitored,
314
+ "Monitored": endpoint_details["monitored"],
325
315
  }
326
316
  data_summary.append(summary)
327
317
 
@@ -329,6 +319,34 @@ class AWSMeta:
329
319
  df = pd.DataFrame(data_summary).convert_dtypes()
330
320
  return df.sort_values(by="Created", ascending=False)
331
321
 
322
+ def _endpoint_config_info(self, endpoint_config_name: str) -> dict:
323
+ """Internal: Get the Endpoint Configuration information for the given endpoint config name.
324
+
325
+ Args:
326
+ endpoint_config_name (str): The name of the endpoint configuration.
327
+
328
+ Returns:
329
+ dict: The endpoint configuration details.
330
+ """
331
+
332
+ # Retrieve the endpoint configuration
333
+ try:
334
+ endpoint_config = self.sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
335
+ production_variant = endpoint_config["ProductionVariants"][0]
336
+
337
+ # Determine instance type or serverless configuration
338
+ instance_type = production_variant.get("InstanceType")
339
+ if instance_type is None:
340
+ # If no instance type, it's a serverless configuration
341
+ mem_size = production_variant["ServerlessConfig"]["MemorySizeInMB"]
342
+ concurrency = production_variant["ServerlessConfig"]["MaxConcurrency"]
343
+ instance_type = f"Serverless ({mem_size // 1024}GB/{concurrency})"
344
+
345
+ return {"instance": instance_type, "variant": production_variant.get("VariantName", "-")}
346
+ except self.sm_client.exceptions.ClientError as e:
347
+ self.log.error(f"Error retrieving endpoint config {endpoint_config_name}: {e}")
348
+ return {"instance": "-", "variant": "-"}
349
+
332
350
  def pipelines(self) -> pd.DataFrame:
333
351
  """List all the Pipelines in the S3 Bucket
334
352
 
@@ -702,7 +720,6 @@ class AWSMeta:
702
720
 
703
721
  if __name__ == "__main__":
704
722
  """Exercise the Workbench AWSMeta Class"""
705
- import time
706
723
  from pprint import pprint
707
724
 
708
725
  # Pandas Display Options
@@ -712,6 +729,7 @@ if __name__ == "__main__":
712
729
  # Create the class
713
730
  meta = AWSMeta()
714
731
 
732
+ """
715
733
  # Test the __repr__ method
716
734
  print(meta)
717
735
 
@@ -759,11 +777,15 @@ if __name__ == "__main__":
759
777
  start_time = time.time()
760
778
  pprint(meta.models(details=True))
761
779
  print(f"Elapsed Time Model (with details): {time.time() - start_time:.2f}")
762
-
780
+ """
763
781
  # Get the Endpoints
764
782
  print("\n\n*** Endpoints ***")
765
783
  pprint(meta.endpoints())
766
784
 
785
+ # Get the Endpoints with Details
786
+ print("\n\n*** Endpoints with Details ***")
787
+ pprint(meta.endpoints(details=True))
788
+
767
789
  # List Pipelines
768
790
  print("\n\n*** Workbench Pipelines ***")
769
791
  pprint(meta.pipelines())
@@ -785,7 +807,6 @@ if __name__ == "__main__":
785
807
  pprint(meta.model("abalone-regression"))
786
808
  print("\n\n*** Endpoint Details ***")
787
809
  pprint(meta.endpoint("abalone-regression"))
788
- pprint(meta.endpoint("test-timing-realtime"))
789
810
 
790
811
  # Test out a non-existent model
791
812
  print("\n\n*** Model Doesn't Exist ***")
@@ -121,13 +121,16 @@ class CloudMeta(AWSMeta):
121
121
  """
122
122
  return super().models(details=details)
123
123
 
124
- def endpoints(self) -> pd.DataFrame:
124
+ def endpoints(self, details: bool = False) -> pd.DataFrame:
125
125
  """Get a summary of the Endpoints deployed in the Cloud Platform
126
126
 
127
+ Args:
128
+ details (bool, optional): Include detailed information. Defaults to False.
129
+
127
130
  Returns:
128
131
  pd.DataFrame: A summary of the Endpoints in the Cloud Platform
129
132
  """
130
- return super().endpoints()
133
+ return super().endpoints(details=details)
131
134
 
132
135
  def pipelines(self) -> pd.DataFrame:
133
136
  """Get a summary of the Pipelines deployed in the Cloud Platform
@@ -37,8 +37,9 @@ class FeaturesToModel(Transform):
37
37
  model_import_str=None,
38
38
  custom_script=None,
39
39
  custom_args=None,
40
+ training_image="xgb_training",
41
+ inference_image="xgb_inference",
40
42
  inference_arch="x86_64",
41
- inference_image="inference",
42
43
  ):
43
44
  """FeaturesToModel Initialization
44
45
  Args:
@@ -49,8 +50,9 @@ class FeaturesToModel(Transform):
49
50
  model_import_str (str, optional): The import string for the model (default None)
50
51
  custom_script (str, optional): Custom script to use for the model (default None)
51
52
  custom_args (dict, optional): Custom arguments to pass to custom model scripts (default None)
53
+ training_image (str, optional): Training image (default "xgb_training")
54
+ inference_image (str, optional): Inference image (default "xgb_inference")
52
55
  inference_arch (str, optional): Inference architecture (default "x86_64")
53
- inference_image (str, optional): Inference image (default "inference")
54
56
  """
55
57
 
56
58
  # Make sure the model_name is a valid name
@@ -73,8 +75,9 @@ class FeaturesToModel(Transform):
73
75
  self.model_feature_list = None
74
76
  self.target_column = None
75
77
  self.class_labels = None
76
- self.inference_arch = inference_arch
78
+ self.training_image = training_image
77
79
  self.inference_image = inference_image
80
+ self.inference_arch = inference_arch
78
81
 
79
82
  def transform_impl(
80
83
  self, target_column: str, description: str = None, feature_list: list = None, train_all_data=False, **kwargs
@@ -229,7 +232,7 @@ class FeaturesToModel(Transform):
229
232
  source_dir = str(Path(script_path).parent)
230
233
 
231
234
  # Create a Sagemaker Model with our script
232
- image = ModelImages.get_image_uri(self.sm_session.boto_region_name, "training", "0.1")
235
+ image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image, "0.1")
233
236
  self.estimator = Estimator(
234
237
  entry_point=entry_point,
235
238
  source_dir=source_dir,
@@ -246,6 +249,7 @@ class FeaturesToModel(Transform):
246
249
  training_job_name = f"{self.output_name}-{training_date_time_utc}"
247
250
 
248
251
  # Train the estimator
252
+ self.log.important(f"Training the Model {self.output_name} with Training Image {image}...")
249
253
  self.estimator.fit({"train": s3_training_path}, job_name=training_job_name)
250
254
 
251
255
  # Now delete the training data
@@ -297,7 +301,7 @@ class FeaturesToModel(Transform):
297
301
  image = ModelImages.get_image_uri(
298
302
  self.sm_session.boto_region_name, self.inference_image, "0.1", self.inference_arch
299
303
  )
300
- self.log.important(f"Registering model {self.output_name} with image {image}...")
304
+ self.log.important(f"Registering model {self.output_name} with Inference Image {image}...")
301
305
  model = self.estimator.create_model(role=self.workbench_role_arn)
302
306
  if aws_region:
303
307
  self.log.important(f"Setting AWS Region: {aws_region} for model {self.output_name}...")
@@ -78,6 +78,12 @@ class ModelToEndpoint(Transform):
78
78
  sagemaker_session=self.sm_session,
79
79
  )
80
80
 
81
+ # Log the image that will be used for deployment
82
+ inference_image = self.sm_client.describe_model_package(ModelPackageName=model_package_arn)[
83
+ "InferenceSpecification"
84
+ ]["Containers"][0]["Image"]
85
+ self.log.important(f"Deploying Model Package: {self.input_name} with Inference Image: {inference_image}")
86
+
81
87
  # Get the metadata/tags to push into AWS
82
88
  aws_tags = self.get_aws_tags()
83
89