workbench 0.8.201__py3-none-any.whl → 0.8.204__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/api/df_store.py +17 -108
- workbench/api/feature_set.py +41 -7
- workbench/api/parameter_store.py +3 -52
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +184 -75
- workbench/core/artifacts/model_core.py +11 -7
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/transforms/features_to_model/features_to_model.py +27 -13
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +11 -0
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
- workbench/model_scripts/chemprop/chemprop.template +312 -293
- workbench/model_scripts/chemprop/generated_model_script.py +316 -297
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +11 -5
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +11 -5
- workbench/model_scripts/custom_models/uq_models/ngboost.template +11 -5
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +11 -5
- workbench/model_scripts/pytorch_model/generated_model_script.py +278 -128
- workbench/model_scripts/pytorch_model/pytorch.template +273 -123
- workbench/model_scripts/uq_models/generated_model_script.py +20 -11
- workbench/model_scripts/uq_models/mapie.template +17 -8
- workbench/model_scripts/xgb_model/generated_model_script.py +38 -9
- workbench/model_scripts/xgb_model/xgb_model.template +34 -5
- workbench/resources/open_source_api.key +1 -1
- workbench/utils/chemprop_utils.py +38 -1
- workbench/utils/pytorch_utils.py +38 -8
- workbench/web_interface/components/model_plot.py +7 -1
- {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/METADATA +2 -2
- {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/RECORD +33 -33
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
- {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/WHEEL +0 -0
- {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/top_level.txt +0 -0
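The per-file line counts above come from the registry's own comparison. As a rough, unofficial sketch (the local wheel filenames and the .py-only filter are assumptions), a similar wheel-to-wheel summary can be reproduced by treating each wheel as a zip archive and diffing its Python sources:

```python
import difflib
import zipfile

OLD_WHEEL = "workbench-0.8.201-py3-none-any.whl"  # assumed local filename
NEW_WHEEL = "workbench-0.8.204-py3-none-any.whl"  # assumed local filename


def wheel_sources(path: str) -> dict:
    """Map archive member name -> decoded text for every .py file in a wheel (a wheel is a zip)."""
    with zipfile.ZipFile(path) as zf:
        return {n: zf.read(n).decode("utf-8", "replace") for n in zf.namelist() if n.endswith(".py")}


old, new = wheel_sources(OLD_WHEEL), wheel_sources(NEW_WHEEL)
for name in sorted(set(old) | set(new)):
    a = old.get(name, "").splitlines(keepends=True)
    b = new.get(name, "").splitlines(keepends=True)
    diff = list(difflib.unified_diff(a, b, fromfile=f"old/{name}", tofile=f"new/{name}"))
    if diff:
        plus = sum(1 for ln in diff if ln.startswith("+") and not ln.startswith("+++"))
        minus = sum(1 for ln in diff if ln.startswith("-") and not ln.startswith("---"))
        print(f"{name} +{plus} -{minus}")
```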
workbench/core/cloud_platform/aws/aws_df_store.py (removed)
@@ -1,404 +0,0 @@
-"""AWSDFStore: Fast/efficient storage of DataFrames using AWS S3/Parquet/Snappy"""
-
-from datetime import datetime
-from typing import Union
-import logging
-import awswrangler as wr
-import pandas as pd
-import re
-from urllib.parse import urlparse
-
-# Workbench Imports
-from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
-from workbench.utils.config_manager import ConfigManager
-from workbench.utils.aws_utils import not_found_returns_none
-
-
-class AWSDFStore:
-    """AWSDFStore: Fast/efficient storage of DataFrames using AWS S3/Parquet/Snappy
-
-    Common Usage:
-        ```python
-        df_store = AWSDFStore()
-
-        # List Data
-        df_store.list()
-
-        # Add DataFrame
-        df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
-        df_store.upsert("/test/my_data", df)
-
-        # Retrieve DataFrame
-        df = df_store.get("/test/my_data")
-        print(df)
-
-        # Delete Data
-        df_store.delete("/test/my_data")
-        ```
-    """
-
-    def __init__(self, path_prefix: Union[str, None] = None):
-        """AWSDFStore Init Method
-
-        Args:
-            path_prefix (Union[str, None], optional): Path prefix for storage locations (Defaults to None)
-        """
-        self.log = logging.getLogger("workbench")
-        self._base_prefix = "df_store/"
-        self.path_prefix = self._base_prefix + path_prefix if path_prefix else self._base_prefix
-        self.path_prefix = re.sub(r"/+", "/", self.path_prefix)  # Collapse slashes
-
-        # Get the Workbench Bucket
-        config = ConfigManager()
-        self.workbench_bucket = config.get_config("WORKBENCH_BUCKET")
-
-        # Get the S3 Client
-        self.boto3_session = AWSAccountClamp().boto3_session
-        self.s3_client = self.boto3_session.client("s3")
-
-    def list(self, include_cache: bool = False) -> list:
-        """List all objects in the data_store prefix
-
-        Args:
-            include_cache (bool, optional): Include cache objects in the list (Defaults to False)
-
-        Returns:
-            list: A list of all the objects in the data_store prefix.
-        """
-        df = self.summary(include_cache=include_cache)
-        return df["location"].tolist()
-
-    def last_modified(self, location: str) -> Union[datetime, None]:
-        """Return the last modified date of a graph.
-
-        Args:
-            location (str): Logical location of the graph.
-
-        Returns:
-            Union[datetime, None]: Last modified datetime or None if not found.
-        """
-        s3_uri = self._generate_s3_uri(location)
-        bucket, key = self._parse_s3_uri(s3_uri)
-
-        try:
-            response = self.s3_client.head_object(Bucket=bucket, Key=key)
-            return response["LastModified"]
-        except self.s3_client.exceptions.ClientError:
-            return None
-
-    def summary(self, include_cache: bool = False) -> pd.DataFrame:
-        """Return a nicely formatted summary of object locations, sizes (in MB), and modified dates.
-
-        Args:
-            include_cache (bool, optional): Include cache objects in the summary (Defaults to False)
-        """
-        df = self.details(include_cache=include_cache)
-
-        # Create a formatted DataFrame
-        formatted_df = pd.DataFrame(
-            {
-                "location": df["location"],
-                "size (MB)": (df["size"] / (1024 * 1024)).round(2),  # Convert size to MB
-                "modified": pd.to_datetime(df["modified"]).dt.strftime("%Y-%m-%d %H:%M:%S"),  # Format date
-            }
-        )
-        return formatted_df
-
-    def details(self, include_cache: bool = False) -> pd.DataFrame:
-        """Return detailed metadata for all objects, optionally excluding the specified prefix.
-
-        Args:
-            include_cache (bool, optional): Include cache objects in the details (Defaults to False)
-        """
-        try:
-            response = self.s3_client.list_objects_v2(Bucket=self.workbench_bucket, Prefix=self.path_prefix)
-            if "Contents" not in response:
-                return pd.DataFrame(columns=["location", "s3_file", "size", "modified"])
-
-            # Collect details for each object
-            data = []
-            for obj in response["Contents"]:
-                full_key = obj["Key"]
-
-                # Reverse logic: Strip the bucket/prefix in the front and .parquet in the end
-                location = full_key.replace(f"{self.path_prefix}", "/").split(".parquet")[0]
-                s3_file = f"s3://{self.workbench_bucket}/{full_key}"
-                size = obj["Size"]
-                modified = obj["LastModified"]
-                data.append([location, s3_file, size, modified])
-
-            # Create the DataFrame
-            df = pd.DataFrame(data, columns=["location", "s3_file", "size", "modified"])
-
-            # Apply the exclude_prefix filter if set
-            cache_prefix = "/workbench/dataframe_cache/"
-            if not include_cache:
-                df = df[~df["location"].str.startswith(cache_prefix)]
-
-            return df
-
-        except Exception as e:
-            self.log.error(f"Failed to get object details: {e}")
-            return pd.DataFrame(columns=["location", "s3_file", "size", "created", "modified"])
-
-    def check(self, location: str) -> bool:
-        """Check if a DataFrame exists at the specified location
-
-        Args:
-            location (str): The location of the data to check.
-
-        Returns:
-            bool: True if the data exists, False otherwise.
-        """
-        # Generate the specific S3 prefix for the target location
-        s3_prefix = f"{self.path_prefix}/{location}.parquet/"
-        s3_prefix = re.sub(r"/+", "/", s3_prefix)  # Collapse slashes
-
-        # Use list_objects_v2 to check if any objects exist under this specific prefix
-        response = self.s3_client.list_objects_v2(Bucket=self.workbench_bucket, Prefix=s3_prefix, MaxKeys=1)
-        return "Contents" in response
-
-    @not_found_returns_none
-    def get(self, location: str) -> Union[pd.DataFrame, None]:
-        """Retrieve a DataFrame from AWS S3.
-
-        Args:
-            location (str): The location of the data to retrieve.
-
-        Returns:
-            pd.DataFrame: The retrieved DataFrame or None if not found.
-        """
-        s3_uri = self._generate_s3_uri(location)
-        return wr.s3.read_parquet(s3_uri)
-
-    def upsert(self, location: str, data: Union[pd.DataFrame, pd.Series]):
-        """Insert or update a DataFrame or Series in the AWS S3.
-
-        Args:
-            location (str): The location of the data.
-            data (Union[pd.DataFrame, pd.Series]): The data to be stored.
-        """
-        # Check if the data is a Pandas Series, convert it to a DataFrame
-        if isinstance(data, pd.Series):
-            data = data.to_frame()
-
-        # Ensure data is a DataFrame
-        if not isinstance(data, pd.DataFrame):
-            raise ValueError("Only Pandas DataFrame or Series objects are supported.")
-
-        # Convert object columns to string type to avoid PyArrow type inference issues.
-        data = self.type_convert_before_parquet(data)
-
-        # Update/Insert the DataFrame to S3
-        s3_uri = self._generate_s3_uri(location)
-        try:
-            wr.s3.to_parquet(df=data, path=s3_uri, dataset=True, mode="overwrite", index=True)
-            self.log.info(f"Dataframe cached {s3_uri}...")
-        except Exception as e:
-            self.log.error(f"Failed to cache dataframe '{s3_uri}': {e}")
-            raise
-
-    @staticmethod
-    def type_convert_before_parquet(df: pd.DataFrame) -> pd.DataFrame:
-        # Convert object columns to string type to avoid PyArrow type inference issues.
-        df = df.copy()
-        object_cols = df.select_dtypes(include=["object"]).columns
-        df[object_cols] = df[object_cols].astype("str")
-        return df
-
-    def delete(self, location: str):
-        """Delete a DataFrame from the AWS S3.
-
-        Args:
-            location (str): The location of the data to delete.
-        """
-        s3_uri = self._generate_s3_uri(location)
-
-        # Check if the folder (prefix) exists in S3
-        if not wr.s3.list_objects(s3_uri):
-            self.log.info(f"Data '{location}' does not exist in S3...")
-            return
-
-        # Delete the data from S3
-        try:
-            wr.s3.delete_objects(s3_uri)
-            self.log.info(f"Data '{location}' deleted successfully from S3.")
-        except Exception as e:
-            self.log.error(f"Failed to delete data '{location}': {e}")
-
-    def delete_recursive(self, location: str):
-        """Recursively delete all data under the specified location in AWS S3.
-
-        Args:
-            location (str): The location prefix of the data to delete.
-        """
-        # Construct the full prefix for S3
-        s3_prefix = re.sub(r"/+", "/", f"{self.path_prefix}/{location}")  # Collapse slashes
-        s3_prefix = s3_prefix.rstrip("/") + "/"  # Ensure the prefix ends with a slash
-
-        # List all objects under the given prefix
-        try:
-            response = self.s3_client.list_objects_v2(Bucket=self.workbench_bucket, Prefix=s3_prefix)
-            if "Contents" not in response:
-                self.log.info(f"No data found under '{s3_prefix}' to delete.")
-                return
-
-            # Gather all keys to delete
-            keys = [{"Key": obj["Key"]} for obj in response["Contents"]]
-            response = self.s3_client.delete_objects(Bucket=self.workbench_bucket, Delete={"Objects": keys})
-            for response in response.get("Deleted", []):
-                self.log.info(f"Deleted: {response['Key']}")
-
-        except Exception as e:
-            self.log.error(f"Failed to delete data recursively at '{location}': {e}")
-
-    def list_subfiles(self, prefix: str) -> list:
-        """Return a list of file locations with the given prefix.
-
-        Args:
-            prefix (str, optional): Only include files with the given prefix
-
-        Returns:
-            list: List of file locations (paths)
-        """
-        try:
-            full_prefix = f"{self.path_prefix}{prefix.lstrip('/')}"
-            response = self.s3_client.list_objects_v2(Bucket=self.workbench_bucket, Prefix=full_prefix)
-            if "Contents" not in response:
-                return []
-
-            locations = []
-            for obj in response["Contents"]:
-                full_key = obj["Key"]
-                location = full_key.replace(f"{self.path_prefix}", "/").split(".parquet")[0]
-                locations.append(location)
-            return locations
-
-        except Exception as e:
-            self.log.error(f"Failed to list subfiles: {e}")
-            return []
-
-    def _generate_s3_uri(self, location: str) -> str:
-        """Generate the S3 URI for the given location."""
-        s3_path = f"{self.workbench_bucket}/{self.path_prefix}/{location}.parquet"
-        return f"s3://{re.sub(r'/+', '/', s3_path)}"
-
-    def _parse_s3_uri(self, s3_uri: str) -> tuple:
-        """Parse an S3 URI into bucket and key."""
-        parsed = urlparse(s3_uri)
-        if parsed.scheme != "s3":
-            raise ValueError(f"Invalid S3 URI: {s3_uri}")
-        return parsed.netloc, parsed.path.lstrip("/")
-
-    def __repr__(self):
-        """Return a string representation of the AWSDFStore object."""
-        # Use the summary() method and format it to align columns for printing
-        summary_df = self.summary()
-
-        # Sanity check: If there are no objects, return a message
-        if summary_df.empty:
-            return "AWSDFStore: No data objects found in the store."
-
-        # Dynamically compute the max length of the 'location' column and add 5 spaces for padding
-        max_location_len = summary_df["location"].str.len().max() + 2
-        summary_df["location"] = summary_df["location"].str.ljust(max_location_len)
-
-        # Format the size column to include (MB) and ensure 3 spaces between size and date
-        summary_df["size (MB)"] = summary_df["size (MB)"].apply(lambda x: f"{x:.2f} MB")
-
-        # Enclose the modified date in parentheses and ensure 3 spaces between size and date
-        summary_df["modified"] = summary_df["modified"].apply(lambda x: f" ({x})")
-
-        # Convert the DataFrame to a string, remove headers, and return
-        return summary_df.to_string(index=False, header=False)
-
-
-if __name__ == "__main__":
-    """Exercise the AWSDFStore Class"""
-    import time
-
-    # Create a AWSDFStore manager
-    df_store = AWSDFStore()
-
-    # Details of the Dataframe Store
-    print("Detailed Data...")
-    print(df_store.details())
-
-    # List all objects
-    print("List Data...")
-    print(df_store.list())
-
-    # Add a new DataFrame
-    my_df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
-    df_store.upsert("/testing/test_data", my_df)
-
-    # Check the last modified date
-    print("Last Modified Date:")
-    print(df_store.last_modified("/testing/test_data"))
-
-    # Get the DataFrame
-    print(f"Getting data 'test_data':\n{df_store.get('/testing/test_data')}")
-
-    # Now let's test adding a Series
-    series = pd.Series([1, 2, 3, 4], name="Series")
-    df_store.upsert("/testing/test_series", series)
-    print(f"Getting data 'test_series':\n{df_store.get('/testing/test_series')}")
-
-    # Summary of the data
-    print("Summary Data...")
-    print(df_store.summary())
-
-    # Repr of the AWSDFStore object
-    print("AWSDFStore Object:")
-    print(df_store)
-
-    # Check if the data exists
-    print("Check if data exists...")
-    print(df_store.check("/testing/test_data"))
-    print(df_store.check("/testing/test_series"))
-
-    # Time the check
-    start_time = time.time()
-    print(df_store.check("/testing/test_data"))
-    print("--- Check %s seconds ---" % (time.time() - start_time))
-
-    # Test list_subfiles
-    print("List Subfiles:")
-    print(df_store.list_subfiles("/testing"))
-
-    # Now delete the test data
-    df_store.delete("/testing/test_data")
-    df_store.delete("/testing/test_series")
-
-    # Check if the data exists
-    print("Check if data exists...")
-    print(df_store.check("/testing/test_data"))
-    print(df_store.check("/testing/test_series"))
-
-    # Add a bunch of dataframes and then test recursive delete
-    for i in range(10):
-        df_store.upsert(f"/testing/data_{i}", pd.DataFrame({"A": [1, 2], "B": [3, 4]}))
-    print("Before Recursive Delete:")
-    print(df_store.summary())
-    df_store.delete_recursive("/testing")
-    print("After Recursive Delete:")
-    print(df_store.summary())
-
-    # Get a non-existent DataFrame
-    print("Getting non-existent data...")
-    print(df_store.get("/testing/no_where"))
-
-    # Test path_prefix
-    df_store = AWSDFStore(path_prefix="/super/test")
-    print(df_store.path_prefix)
-    df_store.upsert("test_data", my_df)
-    print(df_store.get("test_data"))
-    print(df_store.summary())
-    df_store.delete("test_data")
-    print(df_store.summary())
-
-    # Test columns with Spaces in them
-    my_df = pd.DataFrame({"My A": [1, 2], "My B": [3, 4]})
-    df_store.upsert("/testing/test_data", my_df)
-    my_df = df_store.get("/testing/test_data")
-    print(my_df)
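The get() method above returns None for missing data via the not_found_returns_none decorator imported from workbench.utils.aws_utils, which is not part of this diff. A minimal sketch of what such a decorator typically does (the exception types handled by the real workbench helper may differ):

```python
# Hypothetical sketch of a not_found_returns_none-style decorator; the actual
# workbench.utils.aws_utils helper may trap a different set of exceptions.
import functools
import logging

import awswrangler as wr
import botocore.exceptions


def not_found_returns_none(func):
    """Return None instead of raising when the wrapped S3 read hits a missing object."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except wr.exceptions.NoFilesFound:
            # wr.s3.read_parquet raises NoFilesFound when nothing exists at the S3 prefix
            logging.getLogger("workbench").warning("No parquet data found, returning None")
            return None
        except botocore.exceptions.ClientError as e:
            if e.response.get("Error", {}).get("Code") in ("404", "NoSuchKey", "NoSuchBucket"):
                return None
            raise

    return wrapper
```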
workbench/core/cloud_platform/aws/aws_parameter_store.py (removed)
@@ -1,296 +0,0 @@
-"""AWSParameterStore: Manages Workbench parameters in AWS Systems Manager Parameter Store."""
-
-from typing import Union
-import logging
-import json
-import zlib
-import time
-import base64
-from botocore.exceptions import ClientError
-
-# Workbench Imports
-from workbench.core.cloud_platform.aws.aws_session import AWSSession
-from workbench.utils.json_utils import CustomEncoder
-
-
-class AWSParameterStore:
-    """AWSParameterStore: Manages Workbench parameters in AWS Systems Manager Parameter Store.
-
-    Common Usage:
-        ```python
-        params = AWSParameterStore()
-
-        # List Parameters
-        params.list()
-
-        ['/workbench/abalone_info',
-         '/workbench/my_data',
-         '/workbench/test',
-         '/workbench/pipelines/my_pipeline']
-
-        # Add Key
-        params.upsert("key", "value")
-        value = params.get("key")
-
-        # Add any data (lists, dictionaries, etc..)
-        my_data = {"key": "value", "number": 4.2, "list": [1,2,3]}
-        params.upsert("my_data", my_data)
-
-        # Retrieve data
-        return_value = params.get("my_data")
-        pprint(return_value)
-
-        {'key': 'value', 'list': [1, 2, 3], 'number': 4.2}
-
-        # Delete parameters
-        param_store.delete("my_data")
-        ```
-    """
-
-    def __init__(self):
-        """AWSParameterStore Init Method"""
-        self.log = logging.getLogger("workbench")
-
-        # Initialize a Workbench Session (to assume the Workbench ExecutionRole)
-        self.boto3_session = AWSSession().boto3_session
-
-        # Create a Systems Manager (SSM) client for Parameter Store operations
-        self.ssm_client = self.boto3_session.client("ssm")
-
-    def list(self, prefix: str = None) -> list:
-        """List all parameters in the AWS Parameter Store, optionally filtering by a prefix.
-
-        Args:
-            prefix (str, optional): A prefix to filter the parameters by. Defaults to None.
-
-        Returns:
-            list: A list of parameter names and details.
-        """
-        try:
-            # Set up parameters for the query
-            params = {"MaxResults": 50}
-
-            # If a prefix is provided, add the 'ParameterFilters' for optimization
-            if prefix:
-                params["ParameterFilters"] = [{"Key": "Name", "Option": "BeginsWith", "Values": [prefix]}]
-
-            # Initialize the list to collect parameter names
-            all_parameters = []
-
-            # Make the initial call to describe parameters
-            response = self._call_with_retry(self.ssm_client.describe_parameters, **params)
-
-            # Aggregate the names from the initial response
-            all_parameters.extend(param["Name"] for param in response["Parameters"])
-
-            # Continue to paginate if there's a NextToken
-            while "NextToken" in response:
-                # Update the parameters with the NextToken for subsequent calls
-                params["NextToken"] = response["NextToken"]
-                response = self._call_with_retry(self.ssm_client.describe_parameters, **params)
-
-                # Aggregate the names from the subsequent responses
-                all_parameters.extend(param["Name"] for param in response["Parameters"])
-
-        except Exception as e:
-            self.log.error(f"Failed to list parameters: {e}")
-            return []
-
-        # Return the aggregated list of parameter names
-        return all_parameters
-
-    def get(self, name: str, warn: bool = True, decrypt: bool = True) -> Union[str, list, dict, None]:
-        """Retrieve a parameter value from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to retrieve.
-            warn (bool): Whether to log a warning if the parameter is not found.
-            decrypt (bool): Whether to decrypt secure string parameters.
-
-        Returns:
-            Union[str, list, dict, None]: The value of the parameter or None if not found.
-        """
-        try:
-            # Retrieve the parameter from Parameter Store
-            response = self.ssm_client.get_parameter(Name=name, WithDecryption=decrypt)
-            value = response["Parameter"]["Value"]
-
-            # Auto-detect and decompress if needed
-            if value.startswith("COMPRESSED:"):
-                # Base64 decode and decompress
-                self.log.important(f"Decompressing parameter '{name}'...")
-                compressed_value = base64.b64decode(value[len("COMPRESSED:") :])
-                value = zlib.decompress(compressed_value).decode("utf-8")
-
-            # Attempt to parse the value back to its original type
-            try:
-                parsed_value = json.loads(value)
-                return parsed_value
-            except (json.JSONDecodeError, TypeError):
-                # If parsing fails, return the value as is "hope for the best"
-                return value
-
-        except ClientError as e:
-            if e.response["Error"]["Code"] == "ParameterNotFound":
-                if warn:
-                    self.log.warning(f"Parameter '{name}' not found")
-            else:
-                self.log.error(f"Failed to get parameter '{name}': {e}")
-            return None
-
-    def upsert(self, name: str, value, precision: int = 3):
-        """Insert or update a parameter in the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter.
-            value (str | list | dict): The value of the parameter.
-            precision (int): The precision for float values in the JSON encoding.
-        """
-        try:
-            # Convert to JSON and check if compression is needed
-            json_value = json.dumps(value, cls=CustomEncoder, precision=precision)
-            if len(json_value) <= 4096:
-                # Store normally if under 4KB
-                self._store_parameter(name, json_value)
-                return
-
-            # Need compression - log warning
-            self.log.important(
-                f"Parameter {name} exceeds 4KB ({len(json_value)} bytes): compressing and reducing precision..."
-            )
-
-            # Try compression with precision reduction
-            compressed_value = self._compress_value(value)
-
-            if len(compressed_value) <= 4096:
-                self._store_parameter(name, compressed_value)
-                return
-
-            # Try clipping the data
-            clipped_value = self._clip_data(value)
-            compressed_clipped = self._compress_value(clipped_value)
-
-            if len(compressed_clipped) <= 4096:
-                self.log.warning(
-                    f"Parameter {name} data clipped to 100 items/elements: ({len(compressed_clipped)} bytes)"
-                )
-                self._store_parameter(name, compressed_clipped)
-                return
-
-            # Still too large - give up
-            self._handle_oversized_data(name, len(compressed_clipped))
-
-        except Exception as e:
-            self.log.critical(f"Failed to add/update parameter '{name}': {e}")
-            raise
-
-    def _call_with_retry(self, func, **kwargs):
-        """Call AWS API with exponential backoff on throttling."""
-        max_retries = 5
-        base_delay = 1
-        for attempt in range(max_retries):
-            try:
-                return func(**kwargs)
-            except ClientError as e:
-                if e.response["Error"]["Code"] == "ThrottlingException" and attempt < max_retries - 1:
-                    delay = base_delay * (2**attempt)
-                    self.log.warning(f"Throttled, retrying in {delay}s...")
-                    time.sleep(delay)
-                else:
-                    raise
-
-    @staticmethod
-    def _compress_value(value) -> str:
-        """Compress a value with precision reduction."""
-        json_value = json.dumps(value, cls=CustomEncoder, precision=3)
-        compressed = zlib.compress(json_value.encode("utf-8"), level=9)
-        return "COMPRESSED:" + base64.b64encode(compressed).decode("utf-8")
-
-    @staticmethod
-    def _clip_data(value):
-        """Clip data to reduce size, clip to first 100 items/elements."""
-        if isinstance(value, dict):
-            return dict(list(value.items())[:100])
-        elif isinstance(value, list):
-            return value[:100]
-        return value
-
-    def _store_parameter(self, name: str, value: str):
-        """Store parameter in AWS Parameter Store."""
-        self.ssm_client.put_parameter(Name=name, Value=value, Type="String", Overwrite=True)
-        self.log.info(f"Parameter '{name}' added/updated successfully.")
-
-    def _handle_oversized_data(self, name: str, size: int):
-        """Handle data that's too large even after compression and clipping."""
-        doc_link = "https://supercowpowers.github.io/workbench/api_classes/df_store"
-        self.log.error(f"Compressed size {size} bytes, cannot store > 4KB")
-        self.log.error(f"For larger data use the DFStore() class ({doc_link})")
-
-    def delete(self, name: str):
-        """Delete a parameter from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to delete.
-        """
-        try:
-            # Delete the parameter from Parameter Store
-            self.ssm_client.delete_parameter(Name=name)
-            self.log.info(f"Parameter '{name}' deleted successfully.")
-        except Exception as e:
-            self.log.error(f"Failed to delete parameter '{name}': {e}")
-
-    def delete_recursive(self, prefix: str):
-        """Delete all parameters with a given prefix from the AWS Parameter Store.
-
-        Args:
-            prefix (str): The prefix of the parameters to delete.
-        """
-        # List all parameters with the given prefix
-        parameters = self.list(prefix=prefix)
-        for param in parameters:
-            self.delete(param)
-
-    def __repr__(self):
-        """Return a string representation of the AWSParameterStore object."""
-        return "\n".join(self.list())
-
-
-if __name__ == "__main__":
-    """Exercise the AWSParameterStore Class"""
-
-    # Create a AWSParameterStore manager
-    param_store = AWSParameterStore()
-
-    # List the parameters
-    print("Listing Parameters...")
-    print(param_store.list())
-
-    # Add a new parameter
-    param_store.upsert("/workbench/test", "value")
-
-    # Get the parameter
-    print(f"Getting parameter 'test': {param_store.get('/workbench/test')}")
-
-    # Add a dictionary as a parameter
-    sample_dict = {"key": "str_value", "awesome_value": 4.2}
-    param_store.upsert("/workbench/my_data", sample_dict)
-
-    # Retrieve the parameter as a dictionary
-    retrieved_value = param_store.get("/workbench/my_data")
-    print("Retrieved value:", retrieved_value)
-
-    # List the parameters
-    print("Listing Parameters...")
-    print(param_store.list())
-
-    # List the parameters with a prefix
-    print("Listing Parameters with prefix '/workbench':")
-    print(param_store.list("/workbench"))
-
-    # Delete the parameters
-    param_store.delete("/workbench/test")
-    param_store.delete("/workbench/my_data")
-
-    # Out of scope tests
-    param_store.upsert("test", "value")
-    param_store.delete("test")