workbench 0.8.158__py3-none-any.whl → 0.8.159__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. workbench/api/feature_set.py +12 -4
  2. workbench/api/meta.py +1 -1
  3. workbench/cached/cached_feature_set.py +1 -0
  4. workbench/cached/cached_meta.py +10 -12
  5. workbench/core/artifacts/cached_artifact_mixin.py +6 -3
  6. workbench/core/artifacts/model_core.py +19 -7
  7. workbench/core/cloud_platform/aws/aws_meta.py +66 -45
  8. workbench/core/cloud_platform/cloud_meta.py +5 -2
  9. workbench/core/transforms/features_to_model/features_to_model.py +9 -5
  10. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +6 -0
  11. workbench/model_scripts/{custom_models/nn_models → pytorch_model}/generated_model_script.py +170 -156
  12. workbench/model_scripts/{custom_models/nn_models → pytorch_model}/pytorch.template +153 -147
  13. workbench/model_scripts/pytorch_model/requirements.txt +2 -0
  14. workbench/model_scripts/scikit_learn/generated_model_script.py +307 -0
  15. workbench/model_scripts/script_generation.py +6 -2
  16. workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
  17. workbench/repl/workbench_shell.py +4 -9
  18. workbench/utils/json_utils.py +27 -8
  19. workbench/utils/pandas_utils.py +12 -13
  20. workbench/utils/redis_cache.py +28 -13
  21. workbench/utils/workbench_cache.py +20 -14
  22. workbench/web_interface/page_views/endpoints_page_view.py +1 -1
  23. workbench/web_interface/page_views/main_page.py +1 -1
  24. {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/METADATA +5 -8
  25. {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/RECORD +29 -29
  26. workbench/model_scripts/custom_models/nn_models/Readme.md +0 -9
  27. workbench/model_scripts/custom_models/nn_models/requirements.txt +0 -4
  28. {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/WHEEL +0 -0
  29. {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/entry_points.txt +0 -0
  30. {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/licenses/LICENSE +0 -0
  31. {workbench-0.8.158.dist-info → workbench-0.8.159.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,307 @@
+ # Model Imports (this will be replaced with the imports for the template)
+ None
+
+ # Template Placeholders
+ TEMPLATE_PARAMS = {
+ "model_type": "regressor",
+ "target_column": "solubility",
+ "feature_list": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
+ "model_class": PyTorch,
+ "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-pytorch-reg/training",
+ "train_all_data": False
+ }
+
+ import awswrangler as wr
+ from sklearn.preprocessing import LabelEncoder, StandardScaler
+ from sklearn.model_selection import train_test_split
+ from sklearn.pipeline import Pipeline
+
+ from io import StringIO
+ import json
+ import argparse
+ import joblib
+ import os
+ import pandas as pd
+ from typing import List
+
+ # Global model_type for both training and inference
+ model_type = TEMPLATE_PARAMS["model_type"]
+
+
+ # Function to check if dataframe is empty
+ def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
+ """Check if the DataFrame is empty and raise an error if so."""
+ if df.empty:
+ msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
+ print(msg)
+ raise ValueError(msg)
+
+
+ # Function to expand probability column into individual class probability columns
+ def expand_proba_column(df: pd.DataFrame, class_labels: List[str]) -> pd.DataFrame:
+ """Expand 'pred_proba' column into separate columns for each class label."""
+ proba_column = "pred_proba"
+ if proba_column not in df.columns:
+ raise ValueError('DataFrame does not contain a "pred_proba" column')
+
+ # Create new columns for each class label's probability
+ new_col_names = [f"{label}_proba" for label in class_labels]
+ proba_df = pd.DataFrame(df[proba_column].tolist(), columns=new_col_names)
+
+ # Drop the original 'pred_proba' column and reset the index
+ df = df.drop(columns=[proba_column]).reset_index(drop=True)
+
+ # Concatenate the new probability columns with the original DataFrame
+ df = pd.concat([df, proba_df], axis=1)
+ return df
+
+
+ # Function to match DataFrame columns to model features (case-insensitive)
+ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
+ """Match and rename DataFrame columns to match the model's features, case-insensitively."""
+ # Create a set of exact matches from the DataFrame columns
+ exact_match_set = set(df.columns)
+
+ # Create a case-insensitive map of DataFrame columns
+ column_map = {col.lower(): col for col in df.columns}
+ rename_dict = {}
+
+ # Build a dictionary for renaming columns based on case-insensitive matching
+ for feature in model_features:
+ if feature in exact_match_set:
+ rename_dict[feature] = feature
+ elif feature.lower() in column_map:
+ rename_dict[column_map[feature.lower()]] = feature
+
+ # Rename columns in the DataFrame to match model features
+ return df.rename(columns=rename_dict)
+
+
+ #
+ # Training Section
+ #
+ if __name__ == "__main__":
+ # Template Parameters
+ target = TEMPLATE_PARAMS["target_column"] # Can be None for unsupervised models
+ feature_list = TEMPLATE_PARAMS["feature_list"]
+ model_class = TEMPLATE_PARAMS["model_class"]
+ model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
+ train_all_data = TEMPLATE_PARAMS["train_all_data"]
+ validation_split = 0.2
+
+ # Script arguments for input/output directories
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+ parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+ parser.add_argument(
+ "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
+ )
+ args = parser.parse_args()
+
+ # Load training data from the specified directory
+ training_files = [
+ os.path.join(args.train, file)
+ for file in os.listdir(args.train) if file.endswith(".csv")
+ ]
+ all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
+
+ # Check if the DataFrame is empty
+ check_dataframe(all_df, "training_df")
+
+ # Initialize the model using the specified model class
+ model = model_class()
+
+ # Determine if standardization is needed based on the model type
+ needs_standardization = model_type in ["clusterer", "projection"]
+
+ if needs_standardization:
+ # Create a pipeline with standardization and the model
+ model = Pipeline([
+ ("scaler", StandardScaler()),
+ ("model", model)
+ ])
+
+ # Handle logic based on the model_type
+ if model_type in ["classifier", "regressor"]:
+ # Supervised Models: Prepare for training
+ if train_all_data:
+ # Use all data for both training and validation
+ print("Training on all data...")
+ df_train = all_df.copy()
+ df_val = all_df.copy()
+ elif "training" in all_df.columns:
+ # Split data based on a 'training' column if it exists
+ print("Splitting data based on 'training' column...")
+ df_train = all_df[all_df["training"]].copy()
+ df_val = all_df[~all_df["training"]].copy()
+ else:
+ # Perform a random split if no 'training' column is found
+ print("Splitting data randomly...")
+ df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
+
+ # Encode the target variable if the model is a classifier
+ label_encoder = None
+ if model_type == "classifier" and target:
+ label_encoder = LabelEncoder()
+ df_train[target] = label_encoder.fit_transform(df_train[target])
+ df_val[target] = label_encoder.transform(df_val[target])
+
+ # Prepare features and targets for training
+ X_train = df_train[feature_list]
+ X_val = df_val[feature_list]
+ y_train = df_train[target] if target else None
+ y_val = df_val[target] if target else None
+
+ # Train the model using the training data
+ model.fit(X_train, y_train)
+
+ # Make predictions and handle classification-specific logic
+ preds = model.predict(X_val)
+ if model_type == "classifier" and target:
+ # Get class probabilities and expand them into separate columns
+ probs = model.predict_proba(X_val)
+ df_val["pred_proba"] = [p.tolist() for p in probs]
+ df_val = expand_proba_column(df_val, label_encoder.classes_)
+
+ # Decode the target and prediction labels
+ df_val[target] = label_encoder.inverse_transform(df_val[target])
+ preds = label_encoder.inverse_transform(preds)
+
+ # Add predictions to the validation DataFrame
+ df_val["prediction"] = preds
+
+ # Save the validation predictions to S3
+ output_columns = [target, "prediction"] + [col for col in df_val.columns if col.endswith("_proba")]
+ wr.s3.to_csv(df_val[output_columns], path=f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
+
+ elif model_type == "clusterer":
+ # Unsupervised Clustering Models: Assign cluster labels
+ all_df["cluster"] = model.fit_predict(all_df[feature_list])
+
+ elif model_type == "projection":
+ # Projection Models: Apply transformation and label first three components as x, y, z
+ transformed_data = model.fit_transform(all_df[feature_list])
+ num_components = transformed_data.shape[1]
+
+ # Special labels for the first three components, if they exist
+ special_labels = ["x", "y", "z"]
+ for i in range(num_components):
+ if i < len(special_labels):
+ all_df[special_labels[i]] = transformed_data[:, i]
+ else:
+ all_df[f"component_{i + 1}"] = transformed_data[:, i]
+
+ elif model_type == "transformer":
+ # Transformer Models: Apply transformation and use generic component labels
+ transformed_data = model.fit_transform(all_df[feature_list])
+ for i in range(transformed_data.shape[1]):
+ all_df[f"component_{i + 1}"] = transformed_data[:, i]
+
+ # Save the trained model and any necessary assets
+ joblib.dump(model, os.path.join(args.model_dir, "model.joblib"))
+ if model_type == "classifier" and label_encoder:
+ joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
+
+ # Save the feature list to validate input during predictions
+ with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
+ json.dump(feature_list, fp)
+
+ #
+ # Inference Section
+ #
+ def model_fn(model_dir):
+ """Load and return the model from the specified directory."""
+ return joblib.load(os.path.join(model_dir, "model.joblib"))
+
+
+ def input_fn(input_data, content_type):
+ """Parse input data and return a DataFrame."""
+ if not input_data:
+ raise ValueError("Empty input data is not supported!")
+
+ # Decode bytes to string if necessary
+ if isinstance(input_data, bytes):
+ input_data = input_data.decode("utf-8")
+
+ if "text/csv" in content_type:
+ return pd.read_csv(StringIO(input_data))
+ elif "application/json" in content_type:
+ return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
+ else:
+ raise ValueError(f"{content_type} not supported!")
+
+
+ def output_fn(output_df, accept_type):
+ """Supports both CSV and JSON output formats."""
+ if "text/csv" in accept_type:
+ csv_output = output_df.fillna("N/A").to_csv(index=False) # CSV with N/A for missing values
+ return csv_output, "text/csv"
+ elif "application/json" in accept_type:
+ return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
+ else:
+ raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
+
+
+ def predict_fn(df, model):
+ """Make predictions or apply transformations using the model and return the DataFrame with results."""
+ model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+
+ # Load feature columns from the saved file
+ with open(os.path.join(model_dir, "feature_columns.json")) as fp:
+ model_features = json.load(fp)
+
+ # Load label encoder if available (for classification models)
+ label_encoder = None
+ if os.path.exists(os.path.join(model_dir, "label_encoder.joblib")):
+ label_encoder = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
+
+ # Match features in a case-insensitive manner
+ matched_df = match_features_case_insensitive(df, model_features)
+
+ # Initialize a dictionary to store the results
+ results = {}
+
+ # Determine how to handle the model based on its available methods
+ if hasattr(model, "predict"):
+ # For supervised models (classifier or regressor)
+ predictions = model.predict(matched_df[model_features])
+ results["prediction"] = predictions
+
+ elif hasattr(model, "fit_predict"):
+ # For clustering models (e.g., DBSCAN)
+ clusters = model.fit_predict(matched_df[model_features])
+ results["cluster"] = clusters
+
+ elif hasattr(model, "fit_transform") and not hasattr(model, "predict"):
+ # For transformation/projection models (e.g., t-SNE, PCA)
+ transformed_data = model.fit_transform(matched_df[model_features])
+
+ # Handle 2D projection models specifically
+ if model_type == "projection" and transformed_data.shape[1] == 2:
+ results["x"] = transformed_data[:, 0]
+ results["y"] = transformed_data[:, 1]
+ else:
+ # General case for any number of components
+ for i in range(transformed_data.shape[1]):
+ results[f"component_{i + 1}"] = transformed_data[:, i]
+
+ else:
+ # Raise an error if the model does not support the expected methods
+ raise ValueError("Model does not support predict, fit_predict, or fit_transform methods.")
+
+ # Decode predictions if using a label encoder (for classification)
+ if label_encoder and "prediction" in results:
+ results["prediction"] = label_encoder.inverse_transform(results["prediction"])
+
+ # Add the results to the DataFrame
+ for key, value in results.items():
+ df[key] = value
+
+ # Add probability columns if the model supports it (for classification)
+ if hasattr(model, "predict_proba"):
+ probs = model.predict_proba(matched_df[model_features])
+ df["pred_proba"] = [p.tolist() for p in probs]
+ df = expand_proba_column(df, label_encoder.classes_)
+
+ # Return the modified DataFrame
+ return df
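The model_fn / input_fn / predict_fn / output_fn functions above follow the standard SageMaker inference-handler contract, so an endpoint deployed from this script accepts text/csv or application/json payloads. A minimal invocation sketch with boto3 (the endpoint name and CSV file are hypothetical; the payload must contain all of the model's feature columns):

    import boto3

    # SageMaker runtime client (assumes AWS credentials are configured)
    runtime = boto3.client("sagemaker-runtime")

    # Hypothetical CSV file containing every feature column the model was trained on
    with open("sample_rows.csv") as f:
        payload = f.read()

    response = runtime.invoke_endpoint(
        EndpointName="aqsol-pytorch-reg",  # hypothetical endpoint name
        ContentType="text/csv",            # routed to input_fn above
        Accept="text/csv",                 # routed to output_fn above
        Body=payload,
    )
    print(response["Body"].read().decode("utf-8"))  # CSV echoed back with an added "prediction" column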
@@ -101,8 +101,12 @@ def generate_model_script(template_params: dict) -> str:
 
  # Determine which template to use based on model type
  if template_params.get("model_class"):
- template_name = "scikit_learn.template"
- model_script_dir = "scikit_learn"
+ if template_params["model_class"].lower() == "pytorch":
+ template_name = "pytorch.template"
+ model_script_dir = "pytorch_model"
+ else:
+ template_name = "scikit_learn.template"
+ model_script_dir = "scikit_learn"
  elif template_params["model_type"] in [ModelType.REGRESSOR, ModelType.CLASSIFIER]:
  template_name = "xgb_model.template"
  model_script_dir = "xgb_model"
@@ -28,12 +28,12 @@ from typing import List, Tuple
 
  # Template Parameters
  TEMPLATE_PARAMS = {
- "model_type": "regressor",
- "target_column": "iq_score",
- "features": ['height', 'weight', 'salary', 'age', 'likes_dogs'],
- "compressed_features": [],
- "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/abc-regression/training",
- "train_all_data": False
+ "model_type": "classifier",
+ "target_column": "solubility_class",
+ "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct', 'fingerprint'],
+ "compressed_features": ['fingerprint'],
+ "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-fingerprints-plus-class/training",
+ "train_all_data": True
  }
 
  # Function to check if dataframe is empty
@@ -1,4 +1,3 @@
- import IPython
  from IPython import start_ipython
  from IPython.terminal.prompts import Prompts
  from IPython.terminal.ipapp import load_default_config
@@ -10,7 +9,6 @@ import botocore
  import webbrowser
  import pandas as pd
  import readline # noqa
- from distutils.version import LooseVersion
 
  try:
  import matplotlib.pyplot as plt # noqa
@@ -34,6 +32,8 @@ from workbench.utils.repl_utils import cprint, Spinner
  from workbench.utils.workbench_logging import IMPORTANT_LEVEL_NUM, TRACE_LEVEL_NUM
  from workbench.utils.config_manager import ConfigManager
  from workbench.utils.log_utils import silence_logs, log_theme
+ from workbench.api import Meta
+ from workbench.cached.cached_meta import CachedMeta
 
  # If we have RDKIT/Mordred let's pull in our cheminformatics utils
  try:
@@ -196,10 +196,7 @@ class WorkbenchShell:
 
  # Start IPython with the config and commands in the namespace
  try:
- if LooseVersion(IPython.__version__) >= LooseVersion("9.0.0"):
- ipython_argv = ["--no-tip", "--theme", "linux"]
- else:
- ipython_argv = []
+ ipython_argv = ["--no-tip", "--theme", "linux"]
  start_ipython(ipython_argv, user_ns=locs, config=config)
  finally:
  spinner = self.spinner_start("Goodbye to AWS:")
@@ -255,7 +252,7 @@ class WorkbenchShell:
 
  def import_workbench(self):
  # Import all the Workbench modules
- spinner = self.spinner_start("Importing Workbench:")
+ spinner = self.spinner_start("Spinning up Workbench:")
  try:
  # These are the classes we want to expose to the REPL
  self.commands["DataSource"] = importlib.import_module("workbench.api.data_source").DataSource
@@ -475,8 +472,6 @@ class WorkbenchShell:
 
  # Helpers method to switch from direct Meta to Cached Meta
  def try_cached_meta(self):
- from workbench.api import Meta
- from workbench.cached.cached_meta import CachedMeta
 
  with silence_logs():
  self.meta = CachedMeta()
@@ -1,6 +1,7 @@
  """JSON Utilities"""
 
  import json
+ from io import StringIO
  import numpy as np
  import pandas as pd
  import logging
@@ -33,9 +34,7 @@ class CustomEncoder(json.JSONEncoder):
  elif isinstance(obj, pd.DataFrame):
  return {
  "__dataframe__": True,
- "df": obj.to_dict(),
- "index": obj.index.tolist(),
- "index_name": obj.index.name,
+ "df": obj.to_json(orient="table"),
  }
  return super().default(obj)
  except Exception as e:
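Switching the encoder to obj.to_json(orient="table") embeds a pandas table schema in the cached JSON, so dtypes, the index, and the index name survive a round trip instead of being flattened by to_dict(). A small round-trip sketch, assuming the module exports CustomEncoder and custom_decoder as its own test code uses them:

    import json
    import pandas as pd
    from workbench.utils.json_utils import CustomEncoder, custom_decoder  # import path assumed from the file above

    df = pd.DataFrame({"A": [1, 2, 3], "B": [4.0, None, 6.0]}, index=pd.Index([10, 20, 30], name="id"))
    encoded = json.dumps({"my_df": df}, cls=CustomEncoder)               # DataFrame serialized via orient="table"
    decoded = json.loads(encoded, object_hook=custom_decoder)["my_df"]   # rebuilt with pd.read_json(orient="table")
    print(decoded.dtypes)       # dtypes preserved
    print(decoded.index.name)   # index name preserved ("id")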
@@ -62,10 +61,16 @@ def custom_decoder(dct):
  if "__datetime__" in dct:
  return iso8601_to_datetime(dct["datetime"])
  elif "__dataframe__" in dct:
- df = pd.DataFrame.from_dict(dct["df"])
- if "index" in dct:
- df.index = dct["index"]
- df.index.name = dct.get("index_name")
+ df_data = dct["df"]
+ if isinstance(df_data, str):
+ df = pd.read_json(StringIO(df_data), orient="table")
+ else:
+ # Old format compatibility
+ log.warning("Decoding old dataframe format...")
+ df = pd.DataFrame.from_dict(df_data)
+ if "index" in dct:
+ df.index = dct["index"]
+ df.index.name = dct.get("index_name")
  return df
  return dct
  except Exception as e:
@@ -86,6 +91,7 @@ if __name__ == "__main__":
  "datetime": datetime.now(),
  "date": date.today(),
  "dataframe": pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}),
+ "list": [1, 2, 3],
  }
 
  # Encode the test dictionary
@@ -120,4 +126,17 @@ if __name__ == "__main__":
  decoded_df = json.loads(encoded_df, object_hook=custom_decoder)
 
  print("Original DataFrame index name:", df_with_index.index.name)
- print("Decoded DataFrame index name:", decoded_df.index.name) # Likely None
+ print("Decoded DataFrame index name:", decoded_df.index.name)
+
+ # Dataframe Testing
+ from workbench.api import DFStore
+
+ df_store = DFStore()
+ df = df_store.get("/testing/json_encoding/smart_sample_bad")
+ encoded = json.dumps(df, cls=CustomEncoder)
+ decoded_df = json.loads(encoded, object_hook=custom_decoder)
+
+ # Compare original and decoded DataFrame
+ from workbench.utils.pandas_utils import compare_dataframes
+
+ compare_dataframes(df, decoded_df)
@@ -94,14 +94,16 @@ def dataframe_delta(func_that_returns_df, previous_hash: Optional[str] = None) -
  return df, current_hash
 
 
- def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: list):
+ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: list = None):
  """Compare two DataFrames and report on differences.
 
  Args:
  df1 (pd.DataFrame): First DataFrame to compare.
  df2 (pd.DataFrame): Second DataFrame to compare.
- display_columns (list): Columns to display when differences are found.
+ display_columns (list): Columns to display when differences are found (defaults to all columns).
  """
+ if display_columns is None:
+ display_columns = df1.columns.tolist()
 
  # Check if the entire dataframes are equal
  if df1.equals(df2):
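With display_columns now optional, the common case of comparing two frames needs no extra arguments; a short usage sketch (import path taken from the json_utils test code above, sample frames hypothetical):

    import pandas as pd
    from workbench.utils.pandas_utils import compare_dataframes

    df1 = pd.DataFrame({"id": [1, 2, 3], "value": [0.10, 0.20, 0.30]})
    df2 = pd.DataFrame({"id": [1, 2, 3], "value": [0.10, 0.25, 0.30]})
    compare_dataframes(df1, df2)  # display_columns defaults to all of df1's columns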
@@ -130,7 +132,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
  # Print out the column types
  print("\nColumn Types:")
  print(f"DF1: {df1[common_columns].dtypes.value_counts()}")
- print(f"DF2: {df2[common_columns].dtypes.value_counts()}")
+ print(f"\nDF2: {df2[common_columns].dtypes.value_counts()}")
 
  # Count the NaNs in each DataFrame individually (only show columns with > 0 NaNs)
  nan_counts_df1 = df1.isna().sum()
@@ -146,6 +148,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
 
  # Define tolerance for float comparisons
  epsilon = 1e-10
+ difference_counts = {}
 
  # Check for differences in common columns
  for column in common_columns:
@@ -161,18 +164,10 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
  # Other types (e.g., int) with NaNs treated as equal
  differences = ~(df1[column].fillna(0) == df2[column].fillna(0))
 
- # Create a merged DataFrame showing values from both DataFrames
- merged_df = pd.DataFrame(
- {
- **{col: df1.loc[differences, col] for col in display_columns},
- f"{column}_1": df1.loc[differences, column],
- f"{column}_2": df2.loc[differences, column],
- }
- )
-
  # If differences exist, display them
  if differences.any():
- print(f"\nColumn {column} has differences:")
+ print(f"\nColumn {column} has {differences.sum()} differences")
+ difference_counts[column] = differences.sum()
 
  # Create a merged DataFrame showing values from both DataFrames
  merged_df = pd.DataFrame(
@@ -186,6 +181,10 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
  # Display the merged DataFrame
  print(merged_df)
 
+ # If there are no differences report that
+ if not difference_counts:
+ print(f"\nNo differences found in common columns within {epsilon}")
+
 
  def subnormal_check(df):
  """
@@ -73,12 +73,25 @@ class RedisCache:
 
  def set(self, key, value):
  """Add an item to the redis_cache, all items are JSON serialized
+
  Args:
- key: item key
- value: the value associated with this key
+ key: item key
+ value: the value associated with this key
  """
  self._set(key, json.dumps(value, cls=CustomEncoder))
 
+ def atomic_set(self, key, value) -> bool:
+ """Atomically set key to value only if key doesn't exist.
+
+ Returns:
+ True if the key was set, False if it already existed.
+ """
+ # Serialize the value to JSON
+ serialized_value = json.dumps(value, cls=CustomEncoder)
+ result = self.redis_db.set(self.prefix + str(key) + self.postfix, serialized_value, ex=self.expire, nx=True)
+ log.debug(f"Atomic Set: {key} -> {value} (Result: {result})")
+ return result is True
+
  def get(self, key):
  """Get an item from the redis_cache, all items are JSON deserialized
  Args:
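The new atomic_set maps to Redis SET with the NX flag (plus the cache's expire as EX); redis-py returns True when the key is created and None when it already exists, which is why the method checks result is True. The underlying behavior, sketched with redis-py against a local Redis (connection details and key name assumed):

    import redis

    r = redis.Redis()                                    # assumes Redis on localhost:6379
    print(r.set("claim:refresh", "1", ex=60, nx=True))   # True -> key created
    print(r.set("claim:refresh", "2", ex=60, nx=True))   # None -> key exists, value unchanged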
@@ -165,22 +178,20 @@ class RedisCache:
  def get_memory_config(self):
  """Get Redis memory usage and configuration settings as a dictionary"""
  info = {}
- try:
- memory_info = self.redis_db.info("memory")
- info["used_memory"] = memory_info.get("used_memory", "N/A")
- info["used_memory_human"] = memory_info.get("used_memory_human", "N/A")
- info["mem_fragmentation_ratio"] = memory_info.get("mem_fragmentation_ratio", "N/A")
- info["maxmemory_policy"] = memory_info.get("maxmemory_policy", "N/A")
- except redis.exceptions.RedisError as e:
- log.error(f"Error retrieving memory info from Redis: {e}")
 
+ # Memory info about the Redis database
+ memory_info = self.redis_db.info("memory")
+ info["used_memory"] = memory_info.get("used_memory", "N/A")
+ info["used_memory_human"] = memory_info.get("used_memory_human", "N/A")
+ info["mem_fragmentation_ratio"] = memory_info.get("mem_fragmentation_ratio", "N/A")
+ info["maxmemory_policy"] = memory_info.get("maxmemory_policy", "N/A")
+ # CONFIG commands are disabled in managed Redis services like ElastiCache
  try:
  max_memory = self.redis_db.config_get("maxmemory")
  info["maxmemory"] = max_memory.get("maxmemory", "N/A")
  except redis.exceptions.RedisError as e:
- log.error(f"Error retrieving config info from Redis (likely unsupported command): {e}")
- info["maxmemory"] = "Not Available - Command Restricted"
-
+ log.debug(f"CONFIG GET disabled (likely managed Redis service): {e}")
+ info["maxmemory"] = "Not Available - Managed Service"
  return info
 
  def report_memory_config(self):
@@ -244,6 +255,10 @@ if __name__ == "__main__":
  # Delete anything in the test database
  my_redis_cache.clear()
 
+ # Test the atomic set
+ assert my_redis_cache.atomic_set("foo", "bar") is True
+ assert my_redis_cache.atomic_set("foo", "baz") is False
+
  # Test storage
  my_redis_cache.set("foo", "bar")
  assert my_redis_cache.get("foo") == "bar"
@@ -3,7 +3,6 @@ use RedisCache if it's available, and fall back to Cache if it's not.
  """
 
  from pprint import pformat
- from contextlib import contextmanager
  from workbench.utils.cache import Cache
  from workbench.utils.redis_cache import RedisCache
 
@@ -12,21 +11,8 @@ import logging
  log = logging.getLogger("workbench")
 
 
- # Context manager for disabling refresh
- @contextmanager
- def disable_refresh():
- log.warning("WorkbenchCache: Disabling Refresh")
- WorkbenchCache.refresh_enabled = False
- yield
- log.warning("WorkbenchCache: Enabling Refresh")
- WorkbenchCache.refresh_enabled = True
-
-
  class WorkbenchCache:
 
- # Class attribute to control refresh treads (on/off)
- refresh_enabled = True
-
  def __init__(self, expire=None, prefix="", postfix=""):
  """WorkbenchCache Initialization
  Args:
@@ -82,6 +68,21 @@ class WorkbenchCache:
  def clear(self):
  return self._actual_cache.clear()
 
+ def atomic_set(self, key, value) -> bool:
+ """Atomically set key to value only if key doesn't exist.
+
+ Returns:
+ True if the key was set, False if it already existed.
+ """
+ if self._using_redis:
+ return self._actual_cache.atomic_set(key, value)
+
+ # In-Memory Cache does not support atomic operations, so we simulate it
+ else:
+ key_exists = self._actual_cache.get(key) is not None
+ self._actual_cache.set(key, value)
+ return not key_exists
+
  def show_size_details(self, value):
  """Print the size of the sub-parts of the value"""
  try:
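When Redis is available the delegation above is truly atomic; the in-memory fallback only simulates check-then-set and is not safe across processes. A sketch of using it as a run-once guard (module path inferred from the file name; the prefix and key are hypothetical):

    from workbench.utils.workbench_cache import WorkbenchCache  # path inferred from workbench/utils/workbench_cache.py

    cache = WorkbenchCache(prefix="jobs")
    if cache.atomic_set("nightly_refresh", "claimed"):
        print("This process won the claim, run the refresh")   # first caller gets True
    else:
        print("Another process already claimed it, skip")      # later callers get False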
@@ -118,6 +119,10 @@ if __name__ == "__main__":
  # Delete anything in the test database
  my_cache.clear()
 
+ # Test the atomic set
+ assert my_cache.atomic_set("foo", "bar") is True
+ assert my_cache.atomic_set("foo", "baz") is False # Should not overwrite
+
  # Test storage
  my_cache.set("foo", "bar")
  assert my_cache.get("foo") == "bar"
@@ -167,3 +172,4 @@ if __name__ == "__main__":
  my_cache.set("df", df)
  df = my_cache.get("df")
  print(df)
+ my_cache.clear()
@@ -25,7 +25,7 @@ class EndpointsPageView(PageView):
  def refresh(self):
  """Refresh the endpoint data from the Cloud Platform"""
  self.log.important("Calling endpoint page view refresh()..")
- self.endpoints_df = self.meta.endpoints()
+ self.endpoints_df = self.meta.endpoints(details=True)
 
  # Drop the AWS URL column
  self.endpoints_df.drop(columns=["_aws_url"], inplace=True, errors="ignore")