workbench 0.8.158__py3-none-any.whl → 0.8.159__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/api/feature_set.py +12 -4
- workbench/api/meta.py +1 -1
- workbench/cached/cached_feature_set.py +1 -0
- workbench/cached/cached_meta.py +10 -12
- workbench/core/artifacts/cached_artifact_mixin.py +6 -3
- workbench/core/artifacts/model_core.py +19 -7
- workbench/core/cloud_platform/aws/aws_meta.py +66 -45
- workbench/core/cloud_platform/cloud_meta.py +5 -2
- workbench/core/transforms/features_to_model/features_to_model.py +9 -5
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +6 -0
- workbench/model_scripts/{custom_models/nn_models → pytorch_model}/generated_model_script.py +170 -156
- workbench/model_scripts/{custom_models/nn_models → pytorch_model}/pytorch.template +153 -147
- workbench/model_scripts/pytorch_model/requirements.txt +2 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +307 -0
- workbench/model_scripts/script_generation.py +6 -2
- workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
- workbench/repl/workbench_shell.py +4 -9
- workbench/utils/json_utils.py +27 -8
- workbench/utils/pandas_utils.py +12 -13
- workbench/utils/redis_cache.py +28 -13
- workbench/utils/workbench_cache.py +20 -14
- workbench/web_interface/page_views/endpoints_page_view.py +1 -1
- workbench/web_interface/page_views/main_page.py +1 -1
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/METADATA +5 -8
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/RECORD +29 -29
- workbench/model_scripts/custom_models/nn_models/Readme.md +0 -9
- workbench/model_scripts/custom_models/nn_models/requirements.txt +0 -4
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/WHEEL +0 -0
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/top_level.txt +0 -0
workbench/api/feature_set.py
CHANGED
@@ -87,8 +87,9 @@ class FeatureSet(FeatureSetCore):
         model_import_str: str = None,
         custom_script: Union[str, Path] = None,
         custom_args: dict = None,
+        training_image: str = "xgb_training",
+        inference_image: str = "xgb_inference",
         inference_arch: str = "x86_64",
-        inference_image: str = "inference",
         **kwargs,
     ) -> Union[Model, None]:
         """Create a Model from the FeatureSet
@@ -101,11 +102,12 @@ class FeatureSet(FeatureSetCore):
             description (str, optional): Set the description for the model. If not give a description is generated.
             feature_list (list, optional): Set the feature list for the model. If not given a feature list is generated.
             target_column (str, optional): The target column for the model (use None for unsupervised model)
-            model_class (str, optional):
+            model_class (str, optional): Model class to use (e.g. "KMeans", "PyTorch", default: None)
             model_import_str (str, optional): The import for the model (e.g. "from sklearn.cluster import KMeans")
             custom_script (str, optional): The custom script to use for the model (default: None)
+            training_image (str, optional): The training image to use (default: "xgb_training")
+            inference_image (str, optional): The inference image to use (default: "xgb_inference")
             inference_arch (str, optional): The architecture to use for inference (default: "x86_64")
-            inference_image (str, optional): The inference image to use (default: "inference")

         Returns:
             Model: The Model created from the FeatureSet (or None if the Model could not be created)
@@ -125,6 +127,11 @@ class FeatureSet(FeatureSetCore):
         # Create the Model Tags
         tags = [name] if tags is None else tags

+        # If the model_class is PyTorch, ensure we set the training and inference images
+        if model_class and model_class.lower() == "pytorch":
+            training_image = "pytorch_training"
+            inference_image = "pytorch_inference"
+
         # Transform the FeatureSet into a Model
         features_to_model = FeaturesToModel(
             feature_name=self.name,
@@ -134,8 +141,9 @@ class FeatureSet(FeatureSetCore):
             model_import_str=model_import_str,
             custom_script=custom_script,
             custom_args=custom_args,
-
+            training_image=training_image,
             inference_image=inference_image,
+            inference_arch=inference_arch,
         )
         features_to_model.set_output_tags(tags)
         features_to_model.transform(
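
Taken together, a caller now selects the PyTorch container pair simply by naming the model class. A minimal usage sketch, assuming this method is the public to_model API; the FeatureSet, model name, and target column below are illustrative:

    from workbench.api.feature_set import FeatureSet

    fs = FeatureSet("abalone_features")

    # model_class="PyTorch" flips the new defaults ("xgb_training"/"xgb_inference")
    # over to "pytorch_training"/"pytorch_inference" before FeaturesToModel runs
    model = fs.to_model(
        name="abalone-regression-pt",           # illustrative model name
        model_class="PyTorch",
        target_column="class_number_of_rings",  # illustrative target column
    )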
workbench/cached/cached_feature_set.py
CHANGED
@@ -79,6 +79,7 @@ if __name__ == "__main__":

     # Retrieve an existing FeatureSet
     my_features = CachedFeatureSet("abalone_features")
+    pprint(my_features.smart_sample())
     pprint(my_features.summary())
     pprint(my_features.details())
     pprint(my_features.health_check())
workbench/cached/cached_meta.py
CHANGED
@@ -13,8 +13,6 @@ from workbench.utils.workbench_cache import WorkbenchCache


 # Decorator to cache method results from the Meta class
-# Note: This has to be outside the class definition to work properly in Python 3.9
-# When we deprecated support for 3.9, move this back into the class definition
 def cache_result(method):
     """Decorator to cache method results in meta_cache"""

@@ -24,11 +22,8 @@ def cache_result(method):
         cache_key = CachedMeta._flatten_redis_key(method, *args, **kwargs)

         # Check for fresh data, spawn thread to refresh if stale
-        if
-        self.log.
-        self.fresh_cache.set(cache_key, True)  # Mark as refreshed
-
-        # Spawn a thread to refresh data without blocking
+        if self.fresh_cache.atomic_set(cache_key, True):
+            self.log.important(f"Async: Metadata for {cache_key} refresh thread started...")
             self.thread_pool.submit(self._refresh_data_in_background, cache_key, method, *args, **kwargs)

         # Return data (fresh or stale) if available
@@ -62,7 +57,7 @@ class CachedMeta(CloudMeta):
         meta.data_sources()
         meta.feature_sets(details=True/False)
         meta.models(details=True/False)
-        meta.endpoints()
+        meta.endpoints(details=True/False)
         meta.views()

         # These are 'describe' methods
@@ -91,7 +86,7 @@ class CachedMeta(CloudMeta):

         # Create both our Meta Cache and Fresh Cache (tracks if data is stale)
         self.meta_cache = WorkbenchCache(prefix="meta")
-        self.fresh_cache = WorkbenchCache(prefix="meta_fresh", expire=
+        self.fresh_cache = WorkbenchCache(prefix="meta_fresh", expire=300)  # 5-minute expiration

         # Create a ThreadPoolExecutor for refreshing stale data
         self.thread_pool = ThreadPoolExecutor(max_workers=5)
@@ -193,13 +188,16 @@ class CachedMeta(CloudMeta):
         return super().models(details=details)

     @cache_result
-    def endpoints(self) -> pd.DataFrame:
+    def endpoints(self, details: bool = False) -> pd.DataFrame:
         """Get a summary of the Endpoints deployed in the Cloud Platform

+        Args:
+            details (bool, optional): Include detailed information. Defaults to False.
+
         Returns:
             pd.DataFrame: A summary of the Endpoints in the Cloud Platform
         """
-        return super().endpoints()
+        return super().endpoints(details=details)

     @cache_result
     def glue_job(self, job_name: str) -> Union[dict, None]:
@@ -266,7 +264,7 @@ class CachedMeta(CloudMeta):
         """Background task to refresh AWS metadata."""
         result = method(self, *args, **kwargs)
         self.meta_cache.set(cache_key, result)
-        self.log.
+        self.log.important(f"Updated Metadata for {cache_key}")

     @staticmethod
     def _flatten_redis_key(method, *args, **kwargs):

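The refactor collapses the old check-then-set sequence into a single atomic_set call, so only one caller wins the right to refresh a stale key. A minimal sketch of the pattern, assuming atomic_set behaves like Redis SET NX with an expiry (returns True only for the first caller within the window); the FreshCache class below is a toy stand-in, not the real WorkbenchCache:

    import threading
    import time

    class FreshCache:
        """Toy stand-in for WorkbenchCache(prefix=..., expire=N): markers lapse after `expire` seconds."""
        def __init__(self, expire: float):
            self.expire = expire
            self._store = {}  # key -> expiry timestamp
            self._lock = threading.Lock()

        def atomic_set(self, key, value) -> bool:
            """Set key only if absent/expired; return True iff we set it (SETNX-style)."""
            with self._lock:
                now = time.monotonic()
                if self._store.get(key, 0) > now:
                    return False  # still fresh: a refresh is already scheduled
                self._store[key] = now + self.expire
                return True

    fresh = FreshCache(expire=300)  # 5-minute window, matching the new meta_fresh cache
    if fresh.atomic_set("meta:endpoints", True):
        print("first caller: spawn background refresh")
    else:
        print("fresh: serve cached data, no extra refresh thread")
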
workbench/core/artifacts/cached_artifact_mixin.py
CHANGED
@@ -13,7 +13,7 @@ class CachedArtifactMixin:
     # Class-level caches, thread pool, and shutdown flag
     log = logging.getLogger("workbench")
     artifact_cache = WorkbenchCache(prefix="artifact_cache")
-    fresh_cache = WorkbenchCache(prefix="artifact_fresh_cache", expire=
+    fresh_cache = WorkbenchCache(prefix="artifact_fresh_cache", expire=120)
     thread_pool = ThreadPoolExecutor(max_workers=5)

     @staticmethod
@@ -45,8 +45,8 @@ class CachedArtifactMixin:
                 cls.artifact_cache.set(cache_key, result)
                 return result

-            # Stale cache: Refresh in the background
-            if
+            # Stale cache: Refresh in the background
+            if cache_fresh is None:
                 self.log.debug(f"Async: Refresh thread started: {cache_key}...")
                 cls.fresh_cache.set(cache_key, True)
                 cls.thread_pool.submit(cls._refresh_data_in_background, self, cache_key, method, *args, **kwargs)
@@ -88,4 +88,7 @@ if __name__ == "__main__":
     my_model = CachedModel("abalone-regression")
     pprint(my_model.summary())
     pprint(my_model.details())
+    # Second call to demonstrate caching
+    pprint(my_model.summary())
+    pprint(my_model.details())
     CachedArtifactMixin._shutdown()

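The expire values drive refresh cadence: artifact freshness markers now lapse after 120 seconds versus 300 for metadata, so artifacts are re-fetched more aggressively. A rough usage sketch of what the new __main__ lines demonstrate; the CachedModel import path is assumed, not confirmed by the diff:

    from pprint import pprint

    from workbench.cached.cached_model import CachedModel  # import path assumed
    from workbench.core.artifacts.cached_artifact_mixin import CachedArtifactMixin

    my_model = CachedModel("abalone-regression")
    pprint(my_model.summary())  # first call: populates artifact_cache (may hit AWS)
    pprint(my_model.summary())  # within ~120s: served from cache; a background refresh
                                # only starts once the fresh-cache marker expires
    CachedArtifactMixin._shutdown()  # flush the refresh thread pool before exit
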
workbench/core/artifacts/model_core.py
CHANGED
@@ -42,24 +42,36 @@ class ModelImages:

     image_uris = {
         # US East 1 images
-        ("us-east-1", "
+        ("us-east-1", "xgb_training", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
         ),
-        ("us-east-1", "
+        ("us-east-1", "xgb_inference", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
         ),
+        ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
+        ),
+        ("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
+        ),
         # US West 2 images
-        ("us-west-2", "
+        ("us-west-2", "xgb_training", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
         ),
-        ("us-west-2", "
+        ("us-west-2", "xgb_inference", "0.1", "x86_64"): (
             "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
         ),
+        ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
+        ),
+        ("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
+            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
+        ),
         # ARM64 images
-        ("us-east-1", "
+        ("us-east-1", "xgb_inference", "0.1", "arm64"): (
             "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
         ),
-        ("us-west-2", "
+        ("us-west-2", "xgb_inference", "0.1", "arm64"): (
             "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
         ),
         # Meta Endpoint inference images
@@ -72,7 +84,7 @@ class ModelImages:
     }

     @classmethod
-    def get_image_uri(cls, region, image_type
+    def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
         key = (region, image_type, version, architecture)
         if key in cls.image_uris:
             return cls.image_uris[key]

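The registry is a plain dict keyed by (region, image_type, version, architecture), so adding the PyTorch containers is just four new entries per region and the lookup logic is untouched. A stripped-down sketch of the lookup, using one real URI from the diff; the fallback error handling is an assumption, the real method may differ:

    class ModelImages:
        image_uris = {
            ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
                "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
            ),
            # ... remaining (region, image_type, version, arch) entries elided
        }

        @classmethod
        def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
            key = (region, image_type, version, architecture)
            if key in cls.image_uris:
                return cls.image_uris[key]
            raise ValueError(f"No image registered for {key}")  # fallback behavior assumed

    print(ModelImages.get_image_uri("us-east-1", "pytorch_training"))
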
workbench/core/cloud_platform/aws/aws_meta.py
CHANGED
@@ -179,7 +179,7 @@ class AWSMeta:
         feature_set_details.update(self.sm_client.describe_feature_group(FeatureGroupName=name))

         # Retrieve Workbench metadata from tags
-        aws_tags = self.get_aws_tags(fg["FeatureGroupArn"])
+        aws_tags = self.get_aws_tags(fg["FeatureGroupArn"]) if details else {}
         summary = {
             "Feature Group": name,
             "Health": "",
@@ -258,70 +258,60 @@ class AWSMeta:
         df = pd.DataFrame(model_summary).convert_dtypes()
         return df.sort_values(by="Created", ascending=False)

-    def endpoints(self,
+    def endpoints(self, details: bool = False) -> pd.DataFrame:
         """Get a summary of the Endpoints in AWS.

         Args:
-
+            details (bool, optional): Get additional details (Defaults to False).

         Returns:
             pd.DataFrame: A summary of the Endpoints in AWS.
         """
         from workbench.utils.endpoint_utils import is_monitored  # noqa: E402

-        #
-
-        paginator = sagemaker_client.get_paginator("list_endpoints")
+        # Use our SageMaker client to list all endpoints
+        paginator = self.sm_client.get_paginator("list_endpoints")
         data_summary = []

         # Use the paginator to retrieve all endpoints
         for page in paginator.paginate():
             for endpoint in page["Endpoints"]:
                 endpoint_name = endpoint["EndpointName"]
-                endpoint_info = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)

-                #
-
-
-
-
-
-
-
-
-
-
-
-                if
-
-                    mem_size = production_variant["ServerlessConfig"]["MemorySizeInMB"]
-                    concurrency = production_variant["ServerlessConfig"]["MaxConcurrency"]
-                    instance_type = f"Serverless ({mem_size // 1024}GB/{concurrency})"
-                except sagemaker_client.exceptions.ClientError:
-                    # If the endpoint config is not found, change the config name to reflect this
-                    endpoint_config_name = f"{endpoint_config_name} (Not Found)"
-                    production_variant = {}
-                    instance_type = "Unknown"
-
-                # Check if the endpoint has monitoring enabled
-                endpoint_monitored = is_monitored(endpoint_name, sagemaker_client)
+                # Grab various endpoint details
+                endpoint_details = {"config": {"instance": "-", "variant": "-"}, "monitored": "-"}
+                aws_tags = {}
+                if details:
+                    endpoint_details = self.sm_client.describe_endpoint(EndpointName=endpoint_name)
+
+                    # Retrieve AWS Tags for this Endpoint
+                    aws_tags = self.get_aws_tags(endpoint_details["EndpointArn"])
+
+                    # Getting the endpoint configuration
+                    config_info = self._endpoint_config_info(endpoint_details["EndpointConfigName"])
+                    endpoint_details["config"] = config_info
+
+                    # Check if the endpoint has monitoring enabled
+                    endpoint_details["monitored"] = is_monitored(endpoint_name, self.sm_client)

                 # Compile endpoint summary
+                created = (
+                    datetime_string(endpoint_details["CreationTime"]) if "CreationTime" in endpoint_details else "-"
+                )
                 summary = {
                     "Name": endpoint_name,
-                    "Health":
+                    "Health": aws_tags.get("workbench_health_tags", ""),
                     "Owner": aws_tags.get("workbench_owner", "-"),
-                    "Instance":
-                    "Created":
+                    "Instance": endpoint_details["config"]["instance"],
+                    "Created": created,
                     "Input": aws_tags.get("workbench_input", "-"),
-                    "Status":
-                    "Config":
-                    "Variant":
-                    "Capture": str(
-                    "Samp(%)": str(
+                    "Status": endpoint_details.get("EndpointStatus", "-"),
+                    "Config": endpoint_details.get("EndpointConfigName", "-"),
+                    "Variant": endpoint_details["config"]["variant"],
+                    "Capture": str(endpoint_details.get("DataCaptureConfig", {}).get("EnableCapture", "False")),
+                    "Samp(%)": str(endpoint_details.get("DataCaptureConfig", {}).get("CurrentSamplingPercentage", "-")),
                 }
                 data_summary.append(summary)

@@ -329,6 +319,34 @@ class AWSMeta:
         df = pd.DataFrame(data_summary).convert_dtypes()
         return df.sort_values(by="Created", ascending=False)

+    def _endpoint_config_info(self, endpoint_config_name: str) -> dict:
+        """Internal: Get the Endpoint Configuration information for the given endpoint config name.
+
+        Args:
+            endpoint_config_name (str): The name of the endpoint configuration.
+
+        Returns:
+            dict: The endpoint configuration details.
+        """
+
+        # Retrieve the endpoint configuration
+        try:
+            endpoint_config = self.sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
+            production_variant = endpoint_config["ProductionVariants"][0]
+
+            # Determine instance type or serverless configuration
+            instance_type = production_variant.get("InstanceType")
+            if instance_type is None:
+                # If no instance type, it's a serverless configuration
+                mem_size = production_variant["ServerlessConfig"]["MemorySizeInMB"]
+                concurrency = production_variant["ServerlessConfig"]["MaxConcurrency"]
+                instance_type = f"Serverless ({mem_size // 1024}GB/{concurrency})"
+
+            return {"instance": instance_type, "variant": production_variant.get("VariantName", "-")}
+        except self.sm_client.exceptions.ClientError as e:
+            self.log.error(f"Error retrieving endpoint config {endpoint_config_name}: {e}")
+            return {"instance": "-", "variant": "-"}
+
     def pipelines(self) -> pd.DataFrame:
         """List all the Pipelines in the S3 Bucket

@@ -702,7 +720,6 @@ class AWSMeta:

 if __name__ == "__main__":
     """Exercise the Workbench AWSMeta Class"""
-    import time
     from pprint import pprint

     # Pandas Display Options
@@ -712,6 +729,7 @@ if __name__ == "__main__":
     # Create the class
     meta = AWSMeta()

+    """
     # Test the __repr__ method
     print(meta)

@@ -759,11 +777,15 @@ if __name__ == "__main__":
     start_time = time.time()
     pprint(meta.models(details=True))
     print(f"Elapsed Time Model (with details): {time.time() - start_time:.2f}")
-
+    """
     # Get the Endpoints
     print("\n\n*** Endpoints ***")
     pprint(meta.endpoints())

+    # Get the Endpoints with Details
+    print("\n\n*** Endpoints with Details ***")
+    pprint(meta.endpoints(details=True))
+
     # List Pipelines
     print("\n\n*** Workbench Pipelines ***")
     pprint(meta.pipelines())
@@ -785,7 +807,6 @@ if __name__ == "__main__":
     pprint(meta.model("abalone-regression"))
     print("\n\n*** Endpoint Details ***")
     pprint(meta.endpoint("abalone-regression"))
-    pprint(meta.endpoint("test-timing-realtime"))

     # Test out a non-existent model
     print("\n\n*** Model Doesn't Exist ***")

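The extracted _endpoint_config_info makes the serverless-vs-instance branching explicit: a ProductionVariant with no InstanceType must carry a ServerlessConfig instead. A standalone sketch of the same parsing against a raw boto3 response; the client region and config name are illustrative:

    import boto3

    sm_client = boto3.client("sagemaker", region_name="us-east-1")

    def config_instance_label(endpoint_config_name: str) -> str:
        """Return a human-readable instance label for an endpoint config (mirrors the diff's logic)."""
        config = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
        variant = config["ProductionVariants"][0]
        instance_type = variant.get("InstanceType")
        if instance_type is None:
            # Serverless endpoints have no InstanceType; report memory (GB) and max concurrency
            serverless = variant["ServerlessConfig"]
            return f"Serverless ({serverless['MemorySizeInMB'] // 1024}GB/{serverless['MaxConcurrency']})"
        return instance_type

    print(config_instance_label("abalone-regression-config"))  # hypothetical config name
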
workbench/core/cloud_platform/cloud_meta.py
CHANGED
@@ -121,13 +121,16 @@ class CloudMeta(AWSMeta):
         """
         return super().models(details=details)

-    def endpoints(self) -> pd.DataFrame:
+    def endpoints(self, details: bool = False) -> pd.DataFrame:
         """Get a summary of the Endpoints deployed in the Cloud Platform

+        Args:
+            details (bool, optional): Include detailed information. Defaults to False.
+
         Returns:
             pd.DataFrame: A summary of the Endpoints in the Cloud Platform
         """
-        return super().endpoints()
+        return super().endpoints(details=details)

     def pipelines(self) -> pd.DataFrame:
         """Get a summary of the Pipelines deployed in the Cloud Platform

workbench/core/transforms/features_to_model/features_to_model.py
CHANGED
@@ -37,8 +37,9 @@ class FeaturesToModel(Transform):
         model_import_str=None,
         custom_script=None,
         custom_args=None,
+        training_image="xgb_training",
+        inference_image="xgb_inference",
         inference_arch="x86_64",
-        inference_image="inference",
     ):
         """FeaturesToModel Initialization
         Args:
@@ -49,8 +50,9 @@ class FeaturesToModel(Transform):
             model_import_str (str, optional): The import string for the model (default None)
             custom_script (str, optional): Custom script to use for the model (default None)
             custom_args (dict, optional): Custom arguments to pass to custom model scripts (default None)
+            training_image (str, optional): Training image (default "xgb_training")
+            inference_image (str, optional): Inference image (default "xgb_inference")
             inference_arch (str, optional): Inference architecture (default "x86_64")
-            inference_image (str, optional): Inference image (default "inference")
         """

         # Make sure the model_name is a valid name
@@ -73,8 +75,9 @@ class FeaturesToModel(Transform):
         self.model_feature_list = None
         self.target_column = None
         self.class_labels = None
-        self.
+        self.training_image = training_image
         self.inference_image = inference_image
+        self.inference_arch = inference_arch

     def transform_impl(
         self, target_column: str, description: str = None, feature_list: list = None, train_all_data=False, **kwargs
@@ -229,7 +232,7 @@ class FeaturesToModel(Transform):
         source_dir = str(Path(script_path).parent)

         # Create a Sagemaker Model with our script
-        image = ModelImages.get_image_uri(self.sm_session.boto_region_name,
+        image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image, "0.1")
         self.estimator = Estimator(
             entry_point=entry_point,
             source_dir=source_dir,
@@ -246,6 +249,7 @@ class FeaturesToModel(Transform):
         training_job_name = f"{self.output_name}-{training_date_time_utc}"

         # Train the estimator
+        self.log.important(f"Training the Model {self.output_name} with Training Image {image}...")
         self.estimator.fit({"train": s3_training_path}, job_name=training_job_name)

         # Now delete the training data
@@ -297,7 +301,7 @@ class FeaturesToModel(Transform):
         image = ModelImages.get_image_uri(
             self.sm_session.boto_region_name, self.inference_image, "0.1", self.inference_arch
         )
-        self.log.important(f"Registering model {self.output_name} with
+        self.log.important(f"Registering model {self.output_name} with Inference Image {image}...")
         model = self.estimator.create_model(role=self.workbench_role_arn)
         if aws_region:
             self.log.important(f"Setting AWS Region: {aws_region} for model {self.output_name}...")

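With the split, the transform resolves two independent registry keys: the training image feeds the SageMaker Estimator, and the inference image is attached later at model registration. A rough sketch of the pairing, assuming the ModelImages import path plus placeholder role, paths, and instance settings:

    from sagemaker.estimator import Estimator
    from workbench.core.artifacts.model_core import ModelImages  # import path assumed

    region = "us-east-1"
    training_image = ModelImages.get_image_uri(region, "pytorch_training", "0.1")
    inference_image = ModelImages.get_image_uri(region, "pytorch_inference", "0.1", "x86_64")

    estimator = Estimator(
        entry_point="generated_model_script.py",   # illustrative entry point
        source_dir="model_scripts/pytorch_model",  # illustrative source dir
        image_uri=training_image,                  # training container
        role="arn:aws:iam::123456789012:role/placeholder-role",
        instance_count=1,
        instance_type="ml.m5.large",
    )
    # After estimator.fit(...), estimator.create_model(...) registers the model;
    # the separate inference_image is what the registered container runs at serve time.
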
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
CHANGED
@@ -78,6 +78,12 @@ class ModelToEndpoint(Transform):
             sagemaker_session=self.sm_session,
         )

+        # Log the image that will be used for deployment
+        inference_image = self.sm_client.describe_model_package(ModelPackageName=model_package_arn)[
+            "InferenceSpecification"
+        ]["Containers"][0]["Image"]
+        self.log.important(f"Deploying Model Package: {self.input_name} with Inference Image: {inference_image}")
+
         # Get the metadata/tags to push into AWS
         aws_tags = self.get_aws_tags()

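The same lookup works standalone for verifying which container a registered model package will deploy with. A small sketch against boto3; the package ARN below is a placeholder:

    import boto3

    sm_client = boto3.client("sagemaker")
    package_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/abalone-regression/1"  # placeholder

    # InferenceSpecification.Containers[0].Image is the ECR URI SageMaker pulls at deploy time
    package = sm_client.describe_model_package(ModelPackageName=package_arn)
    print(package["InferenceSpecification"]["Containers"][0]["Image"])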