workbench 0.8.158__py3-none-any.whl → 0.8.159__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/api/feature_set.py +12 -4
- workbench/api/meta.py +1 -1
- workbench/cached/cached_feature_set.py +1 -0
- workbench/cached/cached_meta.py +10 -12
- workbench/core/artifacts/cached_artifact_mixin.py +6 -3
- workbench/core/artifacts/model_core.py +19 -7
- workbench/core/cloud_platform/aws/aws_meta.py +66 -45
- workbench/core/cloud_platform/cloud_meta.py +5 -2
- workbench/core/transforms/features_to_model/features_to_model.py +9 -5
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +6 -0
- workbench/model_scripts/{custom_models/nn_models → pytorch_model}/generated_model_script.py +170 -156
- workbench/model_scripts/{custom_models/nn_models → pytorch_model}/pytorch.template +153 -147
- workbench/model_scripts/pytorch_model/requirements.txt +2 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +307 -0
- workbench/model_scripts/script_generation.py +6 -2
- workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
- workbench/repl/workbench_shell.py +4 -9
- workbench/utils/json_utils.py +27 -8
- workbench/utils/pandas_utils.py +12 -13
- workbench/utils/redis_cache.py +28 -13
- workbench/utils/workbench_cache.py +20 -14
- workbench/web_interface/page_views/endpoints_page_view.py +1 -1
- workbench/web_interface/page_views/main_page.py +1 -1
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/METADATA +5 -8
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/RECORD +29 -29
- workbench/model_scripts/custom_models/nn_models/Readme.md +0 -9
- workbench/model_scripts/custom_models/nn_models/requirements.txt +0 -4
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/WHEEL +0 -0
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
# Model Imports (this will be replaced with the imports for the template)
# NOTE(review): the substitution appears to have inserted a literal `None`
# expression instead of real import lines — confirm the script generator's
# placeholder replacement; as written, `PyTorch` below is an unresolved name.
None

# Template Placeholders
# Values filled in by the script generator for this specific model build.
TEMPLATE_PARAMS = {
    "model_type": "regressor",
    "target_column": "solubility",
    "feature_list": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
    # NOTE(review): bare name `PyTorch` — presumably supplied by the (missing)
    # model imports above; verify against the template generator.
    "model_class": PyTorch,
    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-pytorch-reg/training",
    "train_all_data": False
}

import awswrangler as wr
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

from io import StringIO
import json
import argparse
import joblib
import os
import pandas as pd
from typing import List

# Global model_type for both training and inference
model_type = TEMPLATE_PARAMS["model_type"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Function to check if dataframe is empty
|
|
32
|
+
def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
    """Abort (print a message and raise ValueError) when *df* has zero rows."""
    if not df.empty:
        return
    msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
    print(msg)
    raise ValueError(msg)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Function to expand probability column into individual class probability columns
|
|
41
|
+
def expand_proba_column(df: pd.DataFrame, class_labels: List[str]) -> pd.DataFrame:
    """Replace the list-valued 'pred_proba' column with one '<label>_proba' column per class.

    Raises:
        ValueError: when *df* has no 'pred_proba' column.
    """
    if "pred_proba" not in df.columns:
        raise ValueError('DataFrame does not contain a "pred_proba" column')

    # One probability column per class label, in class_labels order
    expanded = pd.DataFrame(
        df["pred_proba"].tolist(),
        columns=[f"{label}_proba" for label in class_labels],
    )

    # Drop the packed column, reset the index so concat aligns row-for-row
    base = df.drop(columns=["pred_proba"]).reset_index(drop=True)
    return pd.concat([base, expanded], axis=1)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Function to match DataFrame columns to model features (case-insensitive)
|
|
60
|
+
def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
    """Rename *df* columns to the model's feature names, matching case-insensitively.

    Exact-case matches are kept as-is; otherwise a column whose lowercased name
    equals a feature's lowercased name is renamed to the feature's spelling.
    Columns with no match are left untouched.
    """
    exact_names = set(df.columns)
    lower_to_actual = {col.lower(): col for col in df.columns}

    renames = {}
    for feature in model_features:
        if feature in exact_names:
            renames[feature] = feature  # already an exact match
        else:
            actual = lower_to_actual.get(feature.lower())
            if actual is not None:
                renames[actual] = feature

    return df.rename(columns=renames)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
#
# Training Section
#
# Runs inside a SageMaker training container: reads CSVs from the train
# channel, fits the configured model, writes validation predictions to S3,
# and persists the model + assets to the model directory.
if __name__ == "__main__":
    # Template Parameters
    target = TEMPLATE_PARAMS["target_column"]  # Can be None for unsupervised models
    feature_list = TEMPLATE_PARAMS["feature_list"]
    model_class = TEMPLATE_PARAMS["model_class"]
    model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
    train_all_data = TEMPLATE_PARAMS["train_all_data"]
    validation_split = 0.2  # hold-out fraction used only for the random-split fallback

    # Script arguments for input/output directories
    # Defaults follow the SageMaker container conventions (SM_* env vars).
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument(
        "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
    )
    args = parser.parse_args()

    # Load training data from the specified directory
    training_files = [
        os.path.join(args.train, file)
        for file in os.listdir(args.train) if file.endswith(".csv")
    ]
    all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

    # Check if the DataFrame is empty
    check_dataframe(all_df, "training_df")

    # Initialize the model using the specified model class
    model = model_class()

    # Determine if standardization is needed based on the model type
    needs_standardization = model_type in ["clusterer", "projection"]

    if needs_standardization:
        # Create a pipeline with standardization and the model
        model = Pipeline([
            ("scaler", StandardScaler()),
            ("model", model)
        ])

    # Handle logic based on the model_type
    if model_type in ["classifier", "regressor"]:
        # Supervised Models: Prepare for training
        if train_all_data:
            # Use all data for both training and validation
            print("Training on all data...")
            df_train = all_df.copy()
            df_val = all_df.copy()
        elif "training" in all_df.columns:
            # Split data based on a 'training' column if it exists
            # (rows where the boolean 'training' column is True go to train)
            print("Splitting data based on 'training' column...")
            df_train = all_df[all_df["training"]].copy()
            df_val = all_df[~all_df["training"]].copy()
        else:
            # Perform a random split if no 'training' column is found
            print("Splitting data randomly...")
            df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)

        # Encode the target variable if the model is a classifier
        label_encoder = None
        if model_type == "classifier" and target:
            # NOTE(review): transform() will raise on validation labels unseen
            # during fit — presumably acceptable here; confirm.
            label_encoder = LabelEncoder()
            df_train[target] = label_encoder.fit_transform(df_train[target])
            df_val[target] = label_encoder.transform(df_val[target])

        # Prepare features and targets for training
        X_train = df_train[feature_list]
        X_val = df_val[feature_list]
        y_train = df_train[target] if target else None
        y_val = df_val[target] if target else None

        # Train the model using the training data
        model.fit(X_train, y_train)

        # Make predictions and handle classification-specific logic
        preds = model.predict(X_val)
        if model_type == "classifier" and target:
            # Get class probabilities and expand them into separate columns
            probs = model.predict_proba(X_val)
            df_val["pred_proba"] = [p.tolist() for p in probs]
            df_val = expand_proba_column(df_val, label_encoder.classes_)

            # Decode the target and prediction labels
            df_val[target] = label_encoder.inverse_transform(df_val[target])
            preds = label_encoder.inverse_transform(preds)

        # Add predictions to the validation DataFrame
        df_val["prediction"] = preds

        # Save the validation predictions to S3
        output_columns = [target, "prediction"] + [col for col in df_val.columns if col.endswith("_proba")]
        wr.s3.to_csv(df_val[output_columns], path=f"{model_metrics_s3_path}/validation_predictions.csv", index=False)

    elif model_type == "clusterer":
        # Unsupervised Clustering Models: Assign cluster labels
        all_df["cluster"] = model.fit_predict(all_df[feature_list])

    elif model_type == "projection":
        # Projection Models: Apply transformation and label first three components as x, y, z
        transformed_data = model.fit_transform(all_df[feature_list])
        num_components = transformed_data.shape[1]

        # Special labels for the first three components, if they exist
        special_labels = ["x", "y", "z"]
        for i in range(num_components):
            if i < len(special_labels):
                all_df[special_labels[i]] = transformed_data[:, i]
            else:
                all_df[f"component_{i + 1}"] = transformed_data[:, i]

    elif model_type == "transformer":
        # Transformer Models: Apply transformation and use generic component labels
        transformed_data = model.fit_transform(all_df[feature_list])
        for i in range(transformed_data.shape[1]):
            all_df[f"component_{i + 1}"] = transformed_data[:, i]

    # Save the trained model and any necessary assets
    joblib.dump(model, os.path.join(args.model_dir, "model.joblib"))
    if model_type == "classifier" and label_encoder:
        joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

    # Save the feature list to validate input during predictions
    with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
        json.dump(feature_list, fp)
|
|
208
|
+
|
|
209
|
+
#
|
|
210
|
+
# Inference Section
|
|
211
|
+
#
|
|
212
|
+
def model_fn(model_dir):
    """Load and return the model from the specified directory."""
    model_path = os.path.join(model_dir, "model.joblib")
    return joblib.load(model_path)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def input_fn(input_data, content_type):
    """Deserialize a request payload (CSV text or a JSON array of records) into a DataFrame.

    Raises:
        ValueError: on an empty payload or an unsupported content type.
    """
    if not input_data:
        raise ValueError("Empty input data is not supported!")

    # Normalize bytes payloads to str before parsing
    payload = input_data.decode("utf-8") if isinstance(input_data, bytes) else input_data

    if "text/csv" in content_type:
        return pd.read_csv(StringIO(payload))
    if "application/json" in content_type:
        return pd.DataFrame(json.loads(payload))  # Assumes JSON array of records
    raise ValueError(f"{content_type} not supported!")
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def output_fn(output_df, accept_type):
    """Serialize *output_df* per the requested Accept type (CSV or JSON).

    Returns:
        (payload, content_type) tuple.

    Raises:
        RuntimeError: when the accept type is neither CSV nor JSON.
    """
    if "text/csv" in accept_type:
        # CSV with N/A for missing values
        return output_df.fillna("N/A").to_csv(index=False), "text/csv"
    if "application/json" in accept_type:
        # JSON array of records (NaNs -> null)
        return output_df.to_json(orient="records"), "application/json"
    raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def predict_fn(df, model):
    """Make predictions or apply transformations using the model and return the DataFrame with results."""
    # SageMaker convention: model assets live under SM_MODEL_DIR
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

    # Load feature columns from the saved file
    with open(os.path.join(model_dir, "feature_columns.json")) as fp:
        model_features = json.load(fp)

    # Load label encoder if available (for classification models)
    label_encoder = None
    if os.path.exists(os.path.join(model_dir, "label_encoder.joblib")):
        label_encoder = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))

    # Match features in a case-insensitive manner
    matched_df = match_features_case_insensitive(df, model_features)

    # Initialize a dictionary to store the results
    results = {}

    # Determine how to handle the model based on its available methods
    if hasattr(model, "predict"):
        # For supervised models (classifier or regressor)
        predictions = model.predict(matched_df[model_features])
        results["prediction"] = predictions

    elif hasattr(model, "fit_predict"):
        # For clustering models (e.g., DBSCAN)
        # NOTE(review): this re-fits on the inference request's rows rather
        # than using the training fit — confirm that is intended.
        clusters = model.fit_predict(matched_df[model_features])
        results["cluster"] = clusters

    elif hasattr(model, "fit_transform") and not hasattr(model, "predict"):
        # For transformation/projection models (e.g., t-SNE, PCA)
        transformed_data = model.fit_transform(matched_df[model_features])

        # Handle 2D projection models specifically
        if model_type == "projection" and transformed_data.shape[1] == 2:
            results["x"] = transformed_data[:, 0]
            results["y"] = transformed_data[:, 1]
        else:
            # General case for any number of components
            for i in range(transformed_data.shape[1]):
                results[f"component_{i + 1}"] = transformed_data[:, i]

    else:
        # Raise an error if the model does not support the expected methods
        raise ValueError("Model does not support predict, fit_predict, or fit_transform methods.")

    # Decode predictions if using a label encoder (for classification)
    if label_encoder and "prediction" in results:
        results["prediction"] = label_encoder.inverse_transform(results["prediction"])

    # Add the results to the DataFrame
    for key, value in results.items():
        df[key] = value

    # Add probability columns if the model supports it (for classification)
    # NOTE(review): this branch reads label_encoder.classes_ without a None
    # check — if a model exposes predict_proba but no label_encoder.joblib was
    # saved, this raises AttributeError; confirm that combination cannot occur.
    if hasattr(model, "predict_proba"):
        probs = model.predict_proba(matched_df[model_features])
        df["pred_proba"] = [p.tolist() for p in probs]
        df = expand_proba_column(df, label_encoder.classes_)

    # Return the modified DataFrame
    return df
|
|
@@ -101,8 +101,12 @@ def generate_model_script(template_params: dict) -> str:
|
|
|
101
101
|
|
|
102
102
|
# Determine which template to use based on model type
|
|
103
103
|
if template_params.get("model_class"):
|
|
104
|
-
|
|
105
|
-
|
|
104
|
+
if template_params["model_class"].lower() == "pytorch":
|
|
105
|
+
template_name = "pytorch.template"
|
|
106
|
+
model_script_dir = "pytorch_model"
|
|
107
|
+
else:
|
|
108
|
+
template_name = "scikit_learn.template"
|
|
109
|
+
model_script_dir = "scikit_learn"
|
|
106
110
|
elif template_params["model_type"] in [ModelType.REGRESSOR, ModelType.CLASSIFIER]:
|
|
107
111
|
template_name = "xgb_model.template"
|
|
108
112
|
model_script_dir = "xgb_model"
|
|
@@ -28,12 +28,12 @@ from typing import List, Tuple
|
|
|
28
28
|
|
|
29
29
|
# Template Parameters
|
|
30
30
|
TEMPLATE_PARAMS = {
|
|
31
|
-
"model_type": "
|
|
32
|
-
"target_column": "
|
|
33
|
-
"features": ['
|
|
34
|
-
"compressed_features": [],
|
|
35
|
-
"model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/
|
|
36
|
-
"train_all_data":
|
|
31
|
+
"model_type": "classifier",
|
|
32
|
+
"target_column": "solubility_class",
|
|
33
|
+
"features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct', 'fingerprint'],
|
|
34
|
+
"compressed_features": ['fingerprint'],
|
|
35
|
+
"model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-fingerprints-plus-class/training",
|
|
36
|
+
"train_all_data": True
|
|
37
37
|
}
|
|
38
38
|
|
|
39
39
|
# Function to check if dataframe is empty
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import IPython
|
|
2
1
|
from IPython import start_ipython
|
|
3
2
|
from IPython.terminal.prompts import Prompts
|
|
4
3
|
from IPython.terminal.ipapp import load_default_config
|
|
@@ -10,7 +9,6 @@ import botocore
|
|
|
10
9
|
import webbrowser
|
|
11
10
|
import pandas as pd
|
|
12
11
|
import readline # noqa
|
|
13
|
-
from distutils.version import LooseVersion
|
|
14
12
|
|
|
15
13
|
try:
|
|
16
14
|
import matplotlib.pyplot as plt # noqa
|
|
@@ -34,6 +32,8 @@ from workbench.utils.repl_utils import cprint, Spinner
|
|
|
34
32
|
from workbench.utils.workbench_logging import IMPORTANT_LEVEL_NUM, TRACE_LEVEL_NUM
|
|
35
33
|
from workbench.utils.config_manager import ConfigManager
|
|
36
34
|
from workbench.utils.log_utils import silence_logs, log_theme
|
|
35
|
+
from workbench.api import Meta
|
|
36
|
+
from workbench.cached.cached_meta import CachedMeta
|
|
37
37
|
|
|
38
38
|
# If we have RDKIT/Mordred let's pull in our cheminformatics utils
|
|
39
39
|
try:
|
|
@@ -196,10 +196,7 @@ class WorkbenchShell:
|
|
|
196
196
|
|
|
197
197
|
# Start IPython with the config and commands in the namespace
|
|
198
198
|
try:
|
|
199
|
-
|
|
200
|
-
ipython_argv = ["--no-tip", "--theme", "linux"]
|
|
201
|
-
else:
|
|
202
|
-
ipython_argv = []
|
|
199
|
+
ipython_argv = ["--no-tip", "--theme", "linux"]
|
|
203
200
|
start_ipython(ipython_argv, user_ns=locs, config=config)
|
|
204
201
|
finally:
|
|
205
202
|
spinner = self.spinner_start("Goodbye to AWS:")
|
|
@@ -255,7 +252,7 @@ class WorkbenchShell:
|
|
|
255
252
|
|
|
256
253
|
def import_workbench(self):
|
|
257
254
|
# Import all the Workbench modules
|
|
258
|
-
spinner = self.spinner_start("
|
|
255
|
+
spinner = self.spinner_start("Spinning up Workbench:")
|
|
259
256
|
try:
|
|
260
257
|
# These are the classes we want to expose to the REPL
|
|
261
258
|
self.commands["DataSource"] = importlib.import_module("workbench.api.data_source").DataSource
|
|
@@ -475,8 +472,6 @@ class WorkbenchShell:
|
|
|
475
472
|
|
|
476
473
|
# Helpers method to switch from direct Meta to Cached Meta
|
|
477
474
|
def try_cached_meta(self):
|
|
478
|
-
from workbench.api import Meta
|
|
479
|
-
from workbench.cached.cached_meta import CachedMeta
|
|
480
475
|
|
|
481
476
|
with silence_logs():
|
|
482
477
|
self.meta = CachedMeta()
|
workbench/utils/json_utils.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""JSON Utilities"""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
from io import StringIO
|
|
4
5
|
import numpy as np
|
|
5
6
|
import pandas as pd
|
|
6
7
|
import logging
|
|
@@ -33,9 +34,7 @@ class CustomEncoder(json.JSONEncoder):
|
|
|
33
34
|
elif isinstance(obj, pd.DataFrame):
|
|
34
35
|
return {
|
|
35
36
|
"__dataframe__": True,
|
|
36
|
-
"df": obj.
|
|
37
|
-
"index": obj.index.tolist(),
|
|
38
|
-
"index_name": obj.index.name,
|
|
37
|
+
"df": obj.to_json(orient="table"),
|
|
39
38
|
}
|
|
40
39
|
return super().default(obj)
|
|
41
40
|
except Exception as e:
|
|
@@ -62,10 +61,16 @@ def custom_decoder(dct):
|
|
|
62
61
|
if "__datetime__" in dct:
|
|
63
62
|
return iso8601_to_datetime(dct["datetime"])
|
|
64
63
|
elif "__dataframe__" in dct:
|
|
65
|
-
|
|
66
|
-
if
|
|
67
|
-
df.
|
|
68
|
-
|
|
64
|
+
df_data = dct["df"]
|
|
65
|
+
if isinstance(df_data, str):
|
|
66
|
+
df = pd.read_json(StringIO(df_data), orient="table")
|
|
67
|
+
else:
|
|
68
|
+
# Old format compatibility
|
|
69
|
+
log.warning("Decoding old dataframe format...")
|
|
70
|
+
df = pd.DataFrame.from_dict(df_data)
|
|
71
|
+
if "index" in dct:
|
|
72
|
+
df.index = dct["index"]
|
|
73
|
+
df.index.name = dct.get("index_name")
|
|
69
74
|
return df
|
|
70
75
|
return dct
|
|
71
76
|
except Exception as e:
|
|
@@ -86,6 +91,7 @@ if __name__ == "__main__":
|
|
|
86
91
|
"datetime": datetime.now(),
|
|
87
92
|
"date": date.today(),
|
|
88
93
|
"dataframe": pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}),
|
|
94
|
+
"list": [1, 2, 3],
|
|
89
95
|
}
|
|
90
96
|
|
|
91
97
|
# Encode the test dictionary
|
|
@@ -120,4 +126,17 @@ if __name__ == "__main__":
|
|
|
120
126
|
decoded_df = json.loads(encoded_df, object_hook=custom_decoder)
|
|
121
127
|
|
|
122
128
|
print("Original DataFrame index name:", df_with_index.index.name)
|
|
123
|
-
print("Decoded DataFrame index name:", decoded_df.index.name)
|
|
129
|
+
print("Decoded DataFrame index name:", decoded_df.index.name)
|
|
130
|
+
|
|
131
|
+
# Dataframe Testing
|
|
132
|
+
from workbench.api import DFStore
|
|
133
|
+
|
|
134
|
+
df_store = DFStore()
|
|
135
|
+
df = df_store.get("/testing/json_encoding/smart_sample_bad")
|
|
136
|
+
encoded = json.dumps(df, cls=CustomEncoder)
|
|
137
|
+
decoded_df = json.loads(encoded, object_hook=custom_decoder)
|
|
138
|
+
|
|
139
|
+
# Compare original and decoded DataFrame
|
|
140
|
+
from workbench.utils.pandas_utils import compare_dataframes
|
|
141
|
+
|
|
142
|
+
compare_dataframes(df, decoded_df)
|
workbench/utils/pandas_utils.py
CHANGED
|
@@ -94,14 +94,16 @@ def dataframe_delta(func_that_returns_df, previous_hash: Optional[str] = None) -
|
|
|
94
94
|
return df, current_hash
|
|
95
95
|
|
|
96
96
|
|
|
97
|
-
def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: list):
|
|
97
|
+
def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: list = None):
|
|
98
98
|
"""Compare two DataFrames and report on differences.
|
|
99
99
|
|
|
100
100
|
Args:
|
|
101
101
|
df1 (pd.DataFrame): First DataFrame to compare.
|
|
102
102
|
df2 (pd.DataFrame): Second DataFrame to compare.
|
|
103
|
-
display_columns (list): Columns to display when differences are found.
|
|
103
|
+
display_columns (list): Columns to display when differences are found (defaults to all columns).
|
|
104
104
|
"""
|
|
105
|
+
if display_columns is None:
|
|
106
|
+
display_columns = df1.columns.tolist()
|
|
105
107
|
|
|
106
108
|
# Check if the entire dataframes are equal
|
|
107
109
|
if df1.equals(df2):
|
|
@@ -130,7 +132,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
|
|
|
130
132
|
# Print out the column types
|
|
131
133
|
print("\nColumn Types:")
|
|
132
134
|
print(f"DF1: {df1[common_columns].dtypes.value_counts()}")
|
|
133
|
-
print(f"
|
|
135
|
+
print(f"\nDF2: {df2[common_columns].dtypes.value_counts()}")
|
|
134
136
|
|
|
135
137
|
# Count the NaNs in each DataFrame individually (only show columns with > 0 NaNs)
|
|
136
138
|
nan_counts_df1 = df1.isna().sum()
|
|
@@ -146,6 +148,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
|
|
|
146
148
|
|
|
147
149
|
# Define tolerance for float comparisons
|
|
148
150
|
epsilon = 1e-10
|
|
151
|
+
difference_counts = {}
|
|
149
152
|
|
|
150
153
|
# Check for differences in common columns
|
|
151
154
|
for column in common_columns:
|
|
@@ -161,18 +164,10 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
|
|
|
161
164
|
# Other types (e.g., int) with NaNs treated as equal
|
|
162
165
|
differences = ~(df1[column].fillna(0) == df2[column].fillna(0))
|
|
163
166
|
|
|
164
|
-
# Create a merged DataFrame showing values from both DataFrames
|
|
165
|
-
merged_df = pd.DataFrame(
|
|
166
|
-
{
|
|
167
|
-
**{col: df1.loc[differences, col] for col in display_columns},
|
|
168
|
-
f"{column}_1": df1.loc[differences, column],
|
|
169
|
-
f"{column}_2": df2.loc[differences, column],
|
|
170
|
-
}
|
|
171
|
-
)
|
|
172
|
-
|
|
173
167
|
# If differences exist, display them
|
|
174
168
|
if differences.any():
|
|
175
|
-
print(f"\nColumn {column} has differences
|
|
169
|
+
print(f"\nColumn {column} has {differences.sum()} differences")
|
|
170
|
+
difference_counts[column] = differences.sum()
|
|
176
171
|
|
|
177
172
|
# Create a merged DataFrame showing values from both DataFrames
|
|
178
173
|
merged_df = pd.DataFrame(
|
|
@@ -186,6 +181,10 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
|
|
|
186
181
|
# Display the merged DataFrame
|
|
187
182
|
print(merged_df)
|
|
188
183
|
|
|
184
|
+
# If there are no differences report that
|
|
185
|
+
if not difference_counts:
|
|
186
|
+
print(f"\nNo differences found in common columns within {epsilon}")
|
|
187
|
+
|
|
189
188
|
|
|
190
189
|
def subnormal_check(df):
|
|
191
190
|
"""
|
workbench/utils/redis_cache.py
CHANGED
|
@@ -73,12 +73,25 @@ class RedisCache:
|
|
|
73
73
|
|
|
74
74
|
def set(self, key, value):
|
|
75
75
|
"""Add an item to the redis_cache, all items are JSON serialized
|
|
76
|
+
|
|
76
77
|
Args:
|
|
77
|
-
|
|
78
|
-
|
|
78
|
+
key: item key
|
|
79
|
+
value: the value associated with this key
|
|
79
80
|
"""
|
|
80
81
|
self._set(key, json.dumps(value, cls=CustomEncoder))
|
|
81
82
|
|
|
83
|
+
def atomic_set(self, key, value) -> bool:
|
|
84
|
+
"""Atomically set key to value only if key doesn't exist.
|
|
85
|
+
|
|
86
|
+
Returns:
|
|
87
|
+
True if the key was set, False if it already existed.
|
|
88
|
+
"""
|
|
89
|
+
# Serialize the value to JSON
|
|
90
|
+
serialized_value = json.dumps(value, cls=CustomEncoder)
|
|
91
|
+
result = self.redis_db.set(self.prefix + str(key) + self.postfix, serialized_value, ex=self.expire, nx=True)
|
|
92
|
+
log.debug(f"Atomic Set: {key} -> {value} (Result: {result})")
|
|
93
|
+
return result is True
|
|
94
|
+
|
|
82
95
|
def get(self, key):
|
|
83
96
|
"""Get an item from the redis_cache, all items are JSON deserialized
|
|
84
97
|
Args:
|
|
@@ -165,22 +178,20 @@ class RedisCache:
|
|
|
165
178
|
def get_memory_config(self):
|
|
166
179
|
"""Get Redis memory usage and configuration settings as a dictionary"""
|
|
167
180
|
info = {}
|
|
168
|
-
try:
|
|
169
|
-
memory_info = self.redis_db.info("memory")
|
|
170
|
-
info["used_memory"] = memory_info.get("used_memory", "N/A")
|
|
171
|
-
info["used_memory_human"] = memory_info.get("used_memory_human", "N/A")
|
|
172
|
-
info["mem_fragmentation_ratio"] = memory_info.get("mem_fragmentation_ratio", "N/A")
|
|
173
|
-
info["maxmemory_policy"] = memory_info.get("maxmemory_policy", "N/A")
|
|
174
|
-
except redis.exceptions.RedisError as e:
|
|
175
|
-
log.error(f"Error retrieving memory info from Redis: {e}")
|
|
176
181
|
|
|
182
|
+
# Memory info about the Redis database
|
|
183
|
+
memory_info = self.redis_db.info("memory")
|
|
184
|
+
info["used_memory"] = memory_info.get("used_memory", "N/A")
|
|
185
|
+
info["used_memory_human"] = memory_info.get("used_memory_human", "N/A")
|
|
186
|
+
info["mem_fragmentation_ratio"] = memory_info.get("mem_fragmentation_ratio", "N/A")
|
|
187
|
+
info["maxmemory_policy"] = memory_info.get("maxmemory_policy", "N/A")
|
|
188
|
+
# CONFIG commands are disabled in managed Redis services like ElastiCache
|
|
177
189
|
try:
|
|
178
190
|
max_memory = self.redis_db.config_get("maxmemory")
|
|
179
191
|
info["maxmemory"] = max_memory.get("maxmemory", "N/A")
|
|
180
192
|
except redis.exceptions.RedisError as e:
|
|
181
|
-
log.
|
|
182
|
-
info["maxmemory"] = "Not Available -
|
|
183
|
-
|
|
193
|
+
log.debug(f"CONFIG GET disabled (likely managed Redis service): {e}")
|
|
194
|
+
info["maxmemory"] = "Not Available - Managed Service"
|
|
184
195
|
return info
|
|
185
196
|
|
|
186
197
|
def report_memory_config(self):
|
|
@@ -244,6 +255,10 @@ if __name__ == "__main__":
|
|
|
244
255
|
# Delete anything in the test database
|
|
245
256
|
my_redis_cache.clear()
|
|
246
257
|
|
|
258
|
+
# Test the atomic set
|
|
259
|
+
assert my_redis_cache.atomic_set("foo", "bar") is True
|
|
260
|
+
assert my_redis_cache.atomic_set("foo", "baz") is False
|
|
261
|
+
|
|
247
262
|
# Test storage
|
|
248
263
|
my_redis_cache.set("foo", "bar")
|
|
249
264
|
assert my_redis_cache.get("foo") == "bar"
|
|
@@ -3,7 +3,6 @@ use RedisCache if it's available, and fall back to Cache if it's not.
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from contextlib import contextmanager
|
|
7
6
|
from workbench.utils.cache import Cache
|
|
8
7
|
from workbench.utils.redis_cache import RedisCache
|
|
9
8
|
|
|
@@ -12,21 +11,8 @@ import logging
|
|
|
12
11
|
log = logging.getLogger("workbench")
|
|
13
12
|
|
|
14
13
|
|
|
15
|
-
# Context manager for disabling refresh
|
|
16
|
-
@contextmanager
|
|
17
|
-
def disable_refresh():
|
|
18
|
-
log.warning("WorkbenchCache: Disabling Refresh")
|
|
19
|
-
WorkbenchCache.refresh_enabled = False
|
|
20
|
-
yield
|
|
21
|
-
log.warning("WorkbenchCache: Enabling Refresh")
|
|
22
|
-
WorkbenchCache.refresh_enabled = True
|
|
23
|
-
|
|
24
|
-
|
|
25
14
|
class WorkbenchCache:
|
|
26
15
|
|
|
27
|
-
# Class attribute to control refresh treads (on/off)
|
|
28
|
-
refresh_enabled = True
|
|
29
|
-
|
|
30
16
|
def __init__(self, expire=None, prefix="", postfix=""):
|
|
31
17
|
"""WorkbenchCache Initialization
|
|
32
18
|
Args:
|
|
@@ -82,6 +68,21 @@ class WorkbenchCache:
|
|
|
82
68
|
def clear(self):
|
|
83
69
|
return self._actual_cache.clear()
|
|
84
70
|
|
|
71
|
+
def atomic_set(self, key, value) -> bool:
|
|
72
|
+
"""Atomically set key to value only if key doesn't exist.
|
|
73
|
+
|
|
74
|
+
Returns:
|
|
75
|
+
True if the key was set, False if it already existed.
|
|
76
|
+
"""
|
|
77
|
+
if self._using_redis:
|
|
78
|
+
return self._actual_cache.atomic_set(key, value)
|
|
79
|
+
|
|
80
|
+
# In-Memory Cache does not support atomic operations, so we simulate it
|
|
81
|
+
else:
|
|
82
|
+
key_exists = self._actual_cache.get(key) is not None
|
|
83
|
+
self._actual_cache.set(key, value)
|
|
84
|
+
return not key_exists
|
|
85
|
+
|
|
85
86
|
def show_size_details(self, value):
|
|
86
87
|
"""Print the size of the sub-parts of the value"""
|
|
87
88
|
try:
|
|
@@ -118,6 +119,10 @@ if __name__ == "__main__":
|
|
|
118
119
|
# Delete anything in the test database
|
|
119
120
|
my_cache.clear()
|
|
120
121
|
|
|
122
|
+
# Test the atomic set
|
|
123
|
+
assert my_cache.atomic_set("foo", "bar") is True
|
|
124
|
+
assert my_cache.atomic_set("foo", "baz") is False # Should not overwrite
|
|
125
|
+
|
|
121
126
|
# Test storage
|
|
122
127
|
my_cache.set("foo", "bar")
|
|
123
128
|
assert my_cache.get("foo") == "bar"
|
|
@@ -167,3 +172,4 @@ if __name__ == "__main__":
|
|
|
167
172
|
my_cache.set("df", df)
|
|
168
173
|
df = my_cache.get("df")
|
|
169
174
|
print(df)
|
|
175
|
+
my_cache.clear()
|
|
@@ -25,7 +25,7 @@ class EndpointsPageView(PageView):
|
|
|
25
25
|
def refresh(self):
|
|
26
26
|
"""Refresh the endpoint data from the Cloud Platform"""
|
|
27
27
|
self.log.important("Calling endpoint page view refresh()..")
|
|
28
|
-
self.endpoints_df = self.meta.endpoints()
|
|
28
|
+
self.endpoints_df = self.meta.endpoints(details=True)
|
|
29
29
|
|
|
30
30
|
# Drop the AWS URL column
|
|
31
31
|
self.endpoints_df.drop(columns=["_aws_url"], inplace=True, errors="ignore")
|