workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,20 @@
1
1
  """Plot Utilities for Workbench"""
2
2
 
3
+ import logging
3
4
  import numpy as np
4
5
  import pandas as pd
5
6
  import plotly.graph_objects as go
7
+ from dash import html
8
+
9
+ from workbench.utils.color_utils import is_dark
10
+
11
+ log = logging.getLogger("workbench")
6
12
 
7
13
 
8
14
  # For approximating beeswarm effect
9
15
  def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
10
16
  """
11
- Generate optimal beeswarm offsets with a maximum limit.
17
+ Generate beeswarm offsets using random jitter with collision avoidance.
12
18
 
13
19
  Args:
14
20
  values: Array of positions to be adjusted
@@ -22,42 +28,55 @@ def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
22
28
  values = np.asarray(values)
23
29
  rounded = np.round(values, precision)
24
30
  offsets = np.zeros_like(values, dtype=float)
25
-
26
- # Sort indices by original values
27
- sorted_idx = np.argsort(values)
31
+ rng = np.random.default_rng(42) # Fixed seed for reproducibility
28
32
 
29
33
  for val in np.unique(rounded):
30
34
  # Get indices belonging to this group
31
- group_idx = sorted_idx[np.isin(sorted_idx, np.where(rounded == val)[0])]
35
+ group_mask = rounded == val
36
+ group_idx = np.where(group_mask)[0]
32
37
 
33
38
  if len(group_idx) > 1:
34
39
  # Track occupied positions for collision detection
35
40
  occupied = []
36
41
 
37
42
  for idx in group_idx:
38
- # Find best position with no collision
39
- offset = 0
40
- direction = 1
41
- step = 0
42
-
43
- while True:
44
- # Check if current offset position is free
45
- collision = any(abs(offset - pos) < point_size for pos in occupied)
46
-
47
- if not collision or abs(offset) >= max_offset:
48
- # Accept position if no collision or max offset reached
49
- if abs(offset) > max_offset:
50
- # Clamp to maximum
51
- offset = max_offset * (1 if offset > 0 else -1)
52
- break
53
-
54
- # Switch sides with increasing distance
55
- step += 0.25
56
- direction *= -1
57
- offset = direction * step * point_size
58
-
59
- offsets[idx] = offset
60
- occupied.append(offset)
43
+ # Try random positions, starting near center and expanding outward
44
+ best_offset = 0
45
+ found = False
46
+
47
+ # First point goes to center
48
+ if not occupied:
49
+ found = True
50
+ else:
51
+ # Try random positions with increasing spread
52
+ for attempt in range(50):
53
+ # Gradually increase the range of random offsets
54
+ spread = min(max_offset, point_size * (1 + attempt * 0.5))
55
+ offset = rng.uniform(-spread, spread)
56
+
57
+ # Check for collision with occupied positions
58
+ if not any(abs(offset - pos) < point_size for pos in occupied):
59
+ best_offset = offset
60
+ found = True
61
+ break
62
+
63
+ # If no free position found after attempts, find the least crowded spot
64
+ if not found:
65
+ # Try a grid of positions and pick one with most space
66
+ candidates = np.linspace(-max_offset, max_offset, 20)
67
+ rng.shuffle(candidates)
68
+ for candidate in candidates:
69
+ if not any(abs(candidate - pos) < point_size * 0.8 for pos in occupied):
70
+ best_offset = candidate
71
+ found = True
72
+ break
73
+
74
+ # Last resort: just use a random position within bounds
75
+ if not found:
76
+ best_offset = rng.uniform(-max_offset, max_offset)
77
+
78
+ offsets[idx] = best_offset
79
+ occupied.append(best_offset)
61
80
 
62
81
  return offsets
63
82
 
@@ -132,7 +151,7 @@ def prediction_intervals(df, figure, x_col):
132
151
  x=sorted_df[x_col],
133
152
  y=sorted_df["q_025"],
134
153
  mode="lines",
135
- line=dict(width=1, color="rgba(99, 110, 250, 0.25)"),
154
+ line=dict(width=1, color="rgba(99, 110, 250, 0.5)"),
136
155
  name="2.5 Percentile",
137
156
  hoverinfo="skip",
138
157
  showlegend=False,
@@ -143,12 +162,12 @@ def prediction_intervals(df, figure, x_col):
143
162
  x=sorted_df[x_col],
144
163
  y=sorted_df["q_975"],
145
164
  mode="lines",
146
- line=dict(width=1, color="rgba(99, 110, 250, 0.25)"),
165
+ line=dict(width=1, color="rgba(99, 110, 250, 0.5)"),
147
166
  name="97.5 Percentile",
148
167
  hoverinfo="skip",
149
168
  showlegend=False,
150
169
  fill="tonexty",
151
- fillcolor="rgba(99, 110, 250, 0.2)",
170
+ fillcolor="rgba(99, 110, 250, 0.35)",
152
171
  )
153
172
  )
154
173
  # Add inner band (q_25 to q_75) - less transparent
@@ -157,7 +176,7 @@ def prediction_intervals(df, figure, x_col):
157
176
  x=sorted_df[x_col],
158
177
  y=sorted_df["q_25"],
159
178
  mode="lines",
160
- line=dict(width=1, color="rgba(99, 250, 110, 0.25)"),
179
+ line=dict(width=1, color="rgba(99, 250, 110, 0.5)"),
161
180
  name="25 Percentile",
162
181
  hoverinfo="skip",
163
182
  showlegend=False,
@@ -168,17 +187,123 @@ def prediction_intervals(df, figure, x_col):
168
187
  x=sorted_df[x_col],
169
188
  y=sorted_df["q_75"],
170
189
  mode="lines",
171
- line=dict(width=1, color="rgba(99, 250, 100, 0.25)"),
190
+ line=dict(width=1, color="rgba(99, 250, 110, 0.5)"),
172
191
  name="75 Percentile",
173
192
  hoverinfo="skip",
174
193
  showlegend=False,
175
194
  fill="tonexty",
176
- fillcolor="rgba(99, 250, 110, 0.2)",
195
+ fillcolor="rgba(99, 250, 110, 0.35)",
177
196
  )
178
197
  )
179
198
  return figure
180
199
 
181
200
 
201
+ def molecule_hover_tooltip(
202
+ smiles: str, mol_id: str = None, width: int = 300, height: int = 200, background: str = None
203
+ ) -> list:
204
+ """Generate a molecule hover tooltip from a SMILES string.
205
+
206
+ This function creates a visually appealing tooltip with a dark background
207
+ that displays the molecule ID at the top and structure below when hovering
208
+ over scatter plot points.
209
+
210
+ Args:
211
+ smiles: SMILES string representing the molecule
212
+ mol_id: Optional molecule ID to display at the top of the tooltip
213
+ width: Width of the molecule image in pixels (default: 300)
214
+ height: Height of the molecule image in pixels (default: 200)
215
+ background: Optional background color (if None, uses dark gray)
216
+
217
+ Returns:
218
+ list: A list containing an html.Div with the ID header and molecule SVG,
219
+ or an html.Div with an error message if rendering fails
220
+ """
221
+ try:
222
+ from workbench.utils.chem_utils.vis import svg_from_smiles
223
+
224
+ # Use provided background or default to dark gray
225
+ if background is None:
226
+ background = "rgba(64, 64, 64, 1)"
227
+
228
+ # Generate the SVG image from SMILES (base64 encoded data URI)
229
+ img = svg_from_smiles(smiles, width, height, background=background)
230
+
231
+ if img is None:
232
+ log.warning(f"Could not render molecule for SMILES: {smiles}")
233
+ return [
234
+ html.Div(
235
+ "Invalid SMILES",
236
+ className="custom-tooltip",
237
+ style={
238
+ "padding": "10px",
239
+ "color": "rgb(255, 140, 140)",
240
+ "width": f"{width}px",
241
+ "height": f"{height}px",
242
+ "display": "flex",
243
+ "alignItems": "center",
244
+ "justifyContent": "center",
245
+ },
246
+ )
247
+ ]
248
+
249
+ # Build the tooltip with ID header and molecule image
250
+ children = []
251
+
252
+ # Add ID header if provided
253
+ if mol_id is not None:
254
+ # Set text color based on background brightness
255
+ text_color = "rgb(200, 200, 200)" if is_dark(background) else "rgb(60, 60, 60)"
256
+ children.append(
257
+ html.Div(
258
+ str(mol_id),
259
+ style={
260
+ "textAlign": "center",
261
+ "padding": "8px",
262
+ "color": text_color,
263
+ "fontSize": "14px",
264
+ "fontWeight": "bold",
265
+ "borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
266
+ },
267
+ )
268
+ )
269
+
270
+ # Add molecule image
271
+ children.append(
272
+ html.Img(
273
+ src=img,
274
+ style={"padding": "0px", "margin": "0px", "display": "block"},
275
+ width=str(width),
276
+ height=str(height),
277
+ )
278
+ )
279
+
280
+ return [
281
+ html.Div(
282
+ children,
283
+ className="custom-tooltip",
284
+ style={"padding": "0px", "margin": "0px"},
285
+ )
286
+ ]
287
+
288
+ except ImportError as e:
289
+ log.error(f"RDKit not available for molecule rendering: {e}")
290
+ return [
291
+ html.Div(
292
+ "RDKit not installed",
293
+ className="custom-tooltip",
294
+ style={
295
+ "padding": "10px",
296
+ "color": "rgb(255, 195, 140)",
297
+ "width": f"{width}px",
298
+ "height": f"{height}px",
299
+ "display": "flex",
300
+ "alignItems": "center",
301
+ "justifyContent": "center",
302
+ },
303
+ )
304
+ ]
305
+
306
+
182
307
  if __name__ == "__main__":
183
308
  """Exercise the Plot Utilities"""
184
309
  import plotly.express as px
@@ -0,0 +1,87 @@
1
+ """PyTorch Tabular utilities for Workbench models."""
2
+
3
+ import logging
4
+ import os
5
+ import tarfile
6
+ import tempfile
7
+ from typing import Any, Tuple
8
+
9
+ import awswrangler as wr
10
+ import pandas as pd
11
+
12
+ from workbench.utils.aws_utils import pull_s3_data
13
+ from workbench.utils.metrics_utils import compute_metrics_from_predictions
14
+
15
+ log = logging.getLogger("workbench")
16
+
17
+
18
+ def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
19
+ """Download and extract a PyTorch model artifact from S3.
20
+
21
+ Args:
22
+ s3_uri: S3 URI of the model.tar.gz artifact
23
+ model_dir: Local directory to extract the model to
24
+ """
25
+ with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
26
+ tmp_path = tmp.name
27
+
28
+ try:
29
+ wr.s3.download(path=s3_uri, local_file=tmp_path)
30
+ with tarfile.open(tmp_path, "r:gz") as tar:
31
+ tar.extractall(model_dir)
32
+ log.info(f"Extracted model to {model_dir}")
33
+ finally:
34
+ if os.path.exists(tmp_path):
35
+ os.remove(tmp_path)
36
+
37
+
38
+ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
39
+ """Pull cross-validation results from AWS training artifacts.
40
+
41
+ This retrieves the validation predictions saved during model training and
42
+ computes metrics directly from them. For PyTorch models trained with
43
+ n_folds > 1, these are out-of-fold predictions from k-fold cross-validation.
44
+
45
+ Args:
46
+ workbench_model: Workbench model object
47
+
48
+ Returns:
49
+ Tuple of:
50
+ - DataFrame with computed metrics
51
+ - DataFrame with validation predictions
52
+ """
53
+ # Get the validation predictions from S3
54
+ s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
55
+ predictions_df = pull_s3_data(s3_path)
56
+
57
+ if predictions_df is None:
58
+ raise ValueError(f"No validation predictions found at {s3_path}")
59
+
60
+ log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
61
+
62
+ # Compute metrics from predictions
63
+ target = workbench_model.target()
64
+ class_labels = workbench_model.class_labels()
65
+
66
+ if target in predictions_df.columns and "prediction" in predictions_df.columns:
67
+ metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
68
+ else:
69
+ metrics_df = pd.DataFrame()
70
+
71
+ return metrics_df, predictions_df
72
+
73
+
74
+ if __name__ == "__main__":
75
+ from workbench.api import Model
76
+
77
+ # Test pulling CV results
78
+ model_name = "aqsol-reg-pytorch"
79
+ print(f"Loading Workbench model: {model_name}")
80
+ model = Model(model_name)
81
+ print(f"Model Framework: {model.model_framework}")
82
+
83
+ # Pull CV results from training artifacts
84
+ metrics_df, predictions_df = pull_cv_results(model)
85
+ print(f"\nMetrics:\n{metrics_df}")
86
+ print(f"\nPredictions shape: {predictions_df.shape}")
87
+ print(f"Predictions columns: {predictions_df.columns.tolist()}")
@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple, Dict, Union
9
9
  from workbench.utils.xgboost_model_utils import xgboost_model_from_s3
10
10
  from workbench.utils.model_utils import load_category_mappings_from_s3
11
11
  from workbench.utils.pandas_utils import convert_categorical_types
12
+ from workbench.model_script_utils.model_script_utils import decompress_features
12
13
 
13
14
  # Set up the log
14
15
  log = logging.getLogger("workbench")
@@ -111,61 +112,6 @@ def shap_values_data(
111
112
  return result_df, feature_df
112
113
 
113
114
 
114
- def decompress_features(
115
- df: pd.DataFrame, features: List[str], compressed_features: List[str]
116
- ) -> Tuple[pd.DataFrame, List[str]]:
117
- """Prepare features for the XGBoost model
118
-
119
- Args:
120
- df (pd.DataFrame): The features DataFrame
121
- features (List[str]): Full list of feature names
122
- compressed_features (List[str]): List of feature names to decompress (bitstrings)
123
-
124
- Returns:
125
- pd.DataFrame: DataFrame with the decompressed features
126
- List[str]: Updated list of feature names after decompression
127
-
128
- Raises:
129
- ValueError: If any missing values are found in the specified features
130
- """
131
-
132
- # Check for any missing values in the required features
133
- missing_counts = df[features].isna().sum()
134
- if missing_counts.any():
135
- missing_features = missing_counts[missing_counts > 0]
136
- print(
137
- f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
138
- "WARNING: You might want to remove/replace all NaN values before processing."
139
- )
140
-
141
- # Decompress the specified compressed features
142
- decompressed_features = features
143
- for feature in compressed_features:
144
- if (feature not in df.columns) or (feature not in features):
145
- print(f"Feature '{feature}' not in the features list, skipping decompression.")
146
- continue
147
-
148
- # Remove the feature from the list of features to avoid duplication
149
- decompressed_features.remove(feature)
150
-
151
- # Handle all compressed features as bitstrings
152
- bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
153
- prefix = feature[:3]
154
-
155
- # Create all new columns at once - avoids fragmentation
156
- new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
157
- new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
158
-
159
- # Add to features list
160
- decompressed_features.extend(new_col_names)
161
-
162
- # Drop original column and concatenate new ones
163
- df = df.drop(columns=[feature])
164
- df = pd.concat([df, new_df], axis=1)
165
-
166
- return df, decompressed_features
167
-
168
-
169
115
  def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
170
116
  """
171
117
  Internal function to calculate SHAP values for Workbench Models.
@@ -212,6 +158,14 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
212
158
  log.error("No XGBoost model found in the artifact.")
213
159
  return None, None, None, None
214
160
 
161
+ # Get the booster (SHAP requires the booster, not the sklearn wrapper)
162
+ if hasattr(xgb_model, "get_booster"):
163
+ # Full sklearn model - extract the booster
164
+ booster = xgb_model.get_booster()
165
+ else:
166
+ # Already a booster
167
+ booster = xgb_model
168
+
215
169
  # Load category mappings if available
216
170
  category_mappings = load_category_mappings_from_s3(model_artifact_uri)
217
171
 
@@ -229,8 +183,8 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
229
183
  # Create a DMatrix with categorical support
230
184
  dmatrix = xgb.DMatrix(X, enable_categorical=True)
231
185
 
232
- # Use XGBoost's built-in SHAP calculation
233
- shap_values = xgb_model.predict(dmatrix, pred_contribs=True, strict_shape=True)
186
+ # Use XGBoost's built-in SHAP calculation (booster method, not sklearn)
187
+ shap_values = booster.predict(dmatrix, pred_contribs=True, strict_shape=True)
234
188
  features_with_bias = features + ["bias"]
235
189
 
236
190
  # Now we need to subset the columns based on top 10 SHAP values
@@ -76,10 +76,28 @@ class ThemeManager:
76
76
  def set_theme(cls, theme_name: str):
77
77
  """Set the current theme."""
78
78
 
79
- # For 'auto', we try to grab a theme from the Parameter Store
80
- # if we can't find one, we'll set the theme to the default
79
+ # For 'auto', we check multiple sources in priority order:
80
+ # 1. Browser cookie (from localStorage, for per-user preference)
81
+ # 2. Parameter Store (for org-wide default)
82
+ # 3. Default theme
81
83
  if theme_name == "auto":
82
- theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False) or cls.default_theme
84
+ theme_name = None
85
+
86
+ # 1. Check Flask request cookie (set from localStorage)
87
+ try:
88
+ from flask import request, has_request_context
89
+
90
+ if has_request_context():
91
+ theme_name = request.cookies.get("wb_theme")
92
+ except Exception:
93
+ pass
94
+
95
+ # 2. Fall back to ParameterStore
96
+ if not theme_name:
97
+ theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False)
98
+
99
+ # 3. Fall back to default
100
+ theme_name = theme_name or cls.default_theme
83
101
 
84
102
  # Check if the theme is in our available themes
85
103
  if theme_name not in cls.available_themes:
@@ -104,9 +122,27 @@ class ThemeManager:
104
122
  cls.current_theme_name = theme_name
105
123
  cls.log.info(f"Theme set to '{theme_name}'")
106
124
 
125
+ # Bootstrap themes that are dark mode (from Bootswatch)
126
+ _dark_bootstrap_themes = {"DARKLY", "CYBORG", "SLATE", "SOLAR", "SUPERHERO", "VAPOR"}
127
+
107
128
  @classmethod
108
129
  def dark_mode(cls) -> bool:
109
- """Check if the current theme is a dark mode theme."""
130
+ """Check if the current theme is a dark mode theme.
131
+
132
+ Determines dark mode by checking if the Bootstrap base theme is a known dark theme.
133
+ Falls back to checking if 'dark' is in the theme name.
134
+ """
135
+ theme = cls.available_themes.get(cls.current_theme_name, {})
136
+ base_css = theme.get("base_css", "")
137
+
138
+ # Check if the base CSS URL contains a known dark Bootstrap theme
139
+ if base_css:
140
+ base_css_upper = base_css.upper()
141
+ for dark_theme in cls._dark_bootstrap_themes:
142
+ if dark_theme in base_css_upper:
143
+ return True
144
+
145
+ # Fallback: check if 'dark' is in the theme name
110
146
  return "dark" in cls.current_theme().lower()
111
147
 
112
148
  @classmethod
@@ -184,30 +220,57 @@ class ThemeManager:
184
220
 
185
221
  @classmethod
186
222
  def css_files(cls) -> list[str]:
187
- """Get the list of CSS files for the current theme."""
188
- theme = cls.available_themes[cls.current_theme_name]
223
+ """Get the list of CSS files for the current theme.
224
+
225
+ Note: Uses /base.css route for dynamic theme switching instead of CDN URLs.
226
+ """
189
227
  css_files = []
190
228
 
191
- # Add base.css or its CDN URL
192
- if theme["base_css"]:
193
- css_files.append(theme["base_css"])
229
+ # Use Flask route for base CSS (allows dynamic theme switching)
230
+ css_files.append("/base.css")
194
231
 
195
232
  # Add the DBC template CSS
196
233
  css_files.append(cls.dbc_css)
197
234
 
198
235
  # Add custom.css if it exists
199
- if theme["custom_css"]:
200
- css_files.append("/custom.css")
236
+ css_files.append("/custom.css")
201
237
 
202
238
  return css_files
203
239
 
240
+ @classmethod
241
+ def _get_theme_from_cookie(cls):
242
+ """Get the theme dict based on the wb_theme cookie, falling back to current theme."""
243
+ from flask import request
244
+
245
+ theme_name = request.cookies.get("wb_theme")
246
+ if theme_name and theme_name in cls.available_themes:
247
+ return cls.available_themes[theme_name], theme_name
248
+ return cls.available_themes[cls.current_theme_name], cls.current_theme_name
249
+
204
250
  @classmethod
205
251
  def register_css_route(cls, app):
206
- """Register Flask route for custom.css."""
252
+ """Register Flask routes for CSS and before_request hook for theme switching."""
253
+ from flask import redirect
254
+
255
+ @app.server.before_request
256
+ def check_theme_cookie():
257
+ """Check for theme cookie on each request and update theme if needed."""
258
+ _, theme_name = cls._get_theme_from_cookie()
259
+ if theme_name != cls.current_theme_name:
260
+ cls.set_theme(theme_name)
261
+
262
+ @app.server.route("/base.css")
263
+ def serve_base_css():
264
+ """Redirect to the appropriate Bootstrap theme CSS based on cookie."""
265
+ theme, _ = cls._get_theme_from_cookie()
266
+ if theme["base_css"]:
267
+ return redirect(theme["base_css"])
268
+ return "", 404
207
269
 
208
270
  @app.server.route("/custom.css")
209
271
  def serve_custom_css():
210
- theme = cls.available_themes[cls.current_theme_name]
272
+ """Serve the custom.css file based on cookie."""
273
+ theme, _ = cls._get_theme_from_cookie()
211
274
  if theme["custom_css"]:
212
275
  return send_from_directory(theme["custom_css"].parent, theme["custom_css"].name)
213
276
  return "", 404
@@ -250,23 +313,25 @@ class ThemeManager:
250
313
  # Loop over each path in the theme path
251
314
  for theme_path in cls.theme_path_list:
252
315
  for theme_dir in theme_path.iterdir():
253
- if theme_dir.is_dir():
254
- theme_name = theme_dir.name
255
-
256
- # Grab the base.css URL
257
- base_css_url = cls._get_base_css_url(theme_dir)
258
-
259
- # Grab the plotly template json, custom.css, and branding json
260
- plotly_template = theme_dir / "plotly.json"
261
- custom_css = theme_dir / "custom.css"
262
- branding = theme_dir / "branding.json"
263
-
264
- cls.available_themes[theme_name] = {
265
- "base_css": base_css_url,
266
- "plotly_template": plotly_template,
267
- "custom_css": custom_css if custom_css.exists() else None,
268
- "branding": branding if branding.exists() else None,
269
- }
316
+ # Skip hidden directories (e.g., .idea, .git)
317
+ if not theme_dir.is_dir() or theme_dir.name.startswith("."):
318
+ continue
319
+ theme_name = theme_dir.name
320
+
321
+ # Grab the base.css URL
322
+ base_css_url = cls._get_base_css_url(theme_dir)
323
+
324
+ # Grab the plotly template json, custom.css, and branding json
325
+ plotly_template = theme_dir / "plotly.json"
326
+ custom_css = theme_dir / "custom.css"
327
+ branding = theme_dir / "branding.json"
328
+
329
+ cls.available_themes[theme_name] = {
330
+ "base_css": base_css_url,
331
+ "plotly_template": plotly_template,
332
+ "custom_css": custom_css if custom_css.exists() else None,
333
+ "branding": branding if branding.exists() else None,
334
+ }
270
335
 
271
336
  if not cls.available_themes:
272
337
  cls.log.warning(f"No themes found in '{cls.theme_path_list}'...")