workbench 0.8.217__py3-none-any.whl → 0.8.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/feature_set.py +0 -1
- workbench/core/artifacts/endpoint_core.py +2 -2
- workbench/core/artifacts/feature_set_core.py +185 -230
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +15 -11
- workbench/model_scripts/chemprop/chemprop.template +195 -70
- workbench/model_scripts/chemprop/generated_model_script.py +198 -73
- workbench/model_scripts/chemprop/model_script_utils.py +15 -11
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +80 -43
- workbench/model_scripts/pytorch_model/generated_model_script.py +2 -2
- workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
- workbench/model_scripts/xgb_model/generated_model_script.py +7 -7
- workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/light/custom.css +7 -1
- workbench/themes/midnight_blue/custom.css +34 -0
- workbench/utils/chem_utils/fingerprints.py +80 -43
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/meta_model_simulator.py +41 -13
- workbench/utils/model_utils.py +0 -1
- workbench/utils/plot_utils.py +146 -28
- workbench/utils/shap_utils.py +1 -55
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/plugins/scatter_plot.py +152 -66
- workbench/web_interface/components/settings_menu.py +184 -0
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/RECORD +38 -37
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +1 -0
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
|
@@ -61,6 +61,13 @@ class MetaModelSimulator:
|
|
|
61
61
|
df["abs_residual"] = df["residual"].abs()
|
|
62
62
|
self._dfs[name] = df
|
|
63
63
|
|
|
64
|
+
# Find common rows across all models
|
|
65
|
+
id_sets = {name: set(df[self.id_column]) for name, df in self._dfs.items()}
|
|
66
|
+
common_ids = set.intersection(*id_sets.values())
|
|
67
|
+
sizes = ", ".join(f"{name}: {len(ids)}" for name, ids in id_sets.items())
|
|
68
|
+
log.info(f"Row counts before alignment: {sizes} -> common: {len(common_ids)}")
|
|
69
|
+
self._dfs = {name: df[df[self.id_column].isin(common_ids)] for name, df in self._dfs.items()}
|
|
70
|
+
|
|
64
71
|
# Align DataFrames by sorting on id column
|
|
65
72
|
self._dfs = {name: df.sort_values(self.id_column).reset_index(drop=True) for name, df in self._dfs.items()}
|
|
66
73
|
log.info(f"Loaded {len(self._dfs)} models, {len(list(self._dfs.values())[0])} samples each")
|
|
@@ -372,13 +379,13 @@ class MetaModelSimulator:
|
|
|
372
379
|
return weight_df
|
|
373
380
|
|
|
374
381
|
def ensemble_failure_analysis(self) -> dict:
|
|
375
|
-
"""Compare ensemble vs best
|
|
382
|
+
"""Compare best ensemble strategy vs best individual model.
|
|
376
383
|
|
|
377
384
|
Returns:
|
|
378
385
|
Dict with comparison statistics
|
|
379
386
|
"""
|
|
380
387
|
print("\n" + "=" * 60)
|
|
381
|
-
print("ENSEMBLE VS BEST MODEL COMPARISON")
|
|
388
|
+
print("BEST ENSEMBLE VS BEST MODEL COMPARISON")
|
|
382
389
|
print("=" * 60)
|
|
383
390
|
|
|
384
391
|
model_names = list(self._dfs.keys())
|
|
@@ -393,35 +400,55 @@ class MetaModelSimulator:
|
|
|
393
400
|
combined[f"{name}_abs_err"] = df["abs_residual"].values
|
|
394
401
|
|
|
395
402
|
pred_cols = [f"{name}_pred" for name in model_names]
|
|
403
|
+
conf_cols = [f"{name}_conf" for name in model_names]
|
|
404
|
+
pred_arr = combined[pred_cols].values
|
|
405
|
+
conf_arr = combined[conf_cols].values
|
|
396
406
|
|
|
397
|
-
# Calculate ensemble prediction (inverse-MAE weighted)
|
|
398
407
|
mae_scores = {name: self._dfs[name]["abs_residual"].mean() for name in model_names}
|
|
399
408
|
inv_mae_weights = np.array([1.0 / mae_scores[name] for name in model_names])
|
|
400
409
|
inv_mae_weights = inv_mae_weights / inv_mae_weights.sum()
|
|
401
|
-
pred_arr = combined[pred_cols].values
|
|
402
|
-
combined["ensemble_pred"] = (pred_arr * inv_mae_weights).sum(axis=1)
|
|
403
|
-
combined["ensemble_abs_err"] = (combined["ensemble_pred"] - combined["target"]).abs()
|
|
404
410
|
|
|
405
|
-
#
|
|
411
|
+
# Compute all ensemble strategies (true ensembles that combine multiple models)
|
|
412
|
+
ensemble_strategies = {}
|
|
413
|
+
ensemble_strategies["Simple Mean"] = combined[pred_cols].mean(axis=1)
|
|
414
|
+
conf_sum = conf_arr.sum(axis=1, keepdims=True) + 1e-8
|
|
415
|
+
ensemble_strategies["Confidence-Weighted"] = (pred_arr * (conf_arr / conf_sum)).sum(axis=1)
|
|
416
|
+
ensemble_strategies["Inverse-MAE Weighted"] = (pred_arr * inv_mae_weights).sum(axis=1)
|
|
417
|
+
scaled_conf = conf_arr * inv_mae_weights
|
|
418
|
+
scaled_conf_sum = scaled_conf.sum(axis=1, keepdims=True) + 1e-8
|
|
419
|
+
ensemble_strategies["Scaled Conf-Weighted"] = (pred_arr * (scaled_conf / scaled_conf_sum)).sum(axis=1)
|
|
420
|
+
worst_model = max(mae_scores, key=mae_scores.get)
|
|
421
|
+
remaining = [n for n in model_names if n != worst_model]
|
|
422
|
+
remaining_cols = [f"{n}_pred" for n in remaining]
|
|
423
|
+
# Only add Drop Worst if it still combines multiple models
|
|
424
|
+
if len(remaining) > 1:
|
|
425
|
+
ensemble_strategies[f"Drop Worst ({worst_model})"] = combined[remaining_cols].mean(axis=1)
|
|
426
|
+
|
|
427
|
+
# Find best individual model
|
|
406
428
|
best_model = min(mae_scores, key=mae_scores.get)
|
|
407
429
|
combined["best_model_abs_err"] = combined[f"{best_model}_abs_err"]
|
|
430
|
+
best_model_mae = mae_scores[best_model]
|
|
408
431
|
|
|
409
|
-
#
|
|
432
|
+
# Find best true ensemble strategy
|
|
433
|
+
strategy_maes = {name: (preds - combined["target"]).abs().mean() for name, preds in ensemble_strategies.items()}
|
|
434
|
+
best_strategy = min(strategy_maes, key=strategy_maes.get)
|
|
435
|
+
combined["ensemble_pred"] = ensemble_strategies[best_strategy]
|
|
436
|
+
combined["ensemble_abs_err"] = (combined["ensemble_pred"] - combined["target"]).abs()
|
|
437
|
+
ensemble_mae = strategy_maes[best_strategy]
|
|
438
|
+
|
|
439
|
+
# Compare
|
|
410
440
|
combined["ensemble_better"] = combined["ensemble_abs_err"] < combined["best_model_abs_err"]
|
|
411
441
|
n_better = combined["ensemble_better"].sum()
|
|
412
442
|
n_total = len(combined)
|
|
413
443
|
|
|
414
|
-
ensemble_mae = combined["ensemble_abs_err"].mean()
|
|
415
|
-
best_model_mae = mae_scores[best_model]
|
|
416
|
-
|
|
417
444
|
print(f"\nBest individual model: {best_model} (MAE={best_model_mae:.4f})")
|
|
418
|
-
print(f"
|
|
445
|
+
print(f"Best ensemble strategy: {best_strategy} (MAE={ensemble_mae:.4f})")
|
|
419
446
|
if ensemble_mae < best_model_mae:
|
|
420
447
|
improvement = (best_model_mae - ensemble_mae) / best_model_mae * 100
|
|
421
448
|
print(f"Ensemble improves over best model by {improvement:.1f}%")
|
|
422
449
|
else:
|
|
423
450
|
degradation = (ensemble_mae - best_model_mae) / best_model_mae * 100
|
|
424
|
-
print(f"
|
|
451
|
+
print(f"No ensemble benefit: best single model outperforms all ensemble strategies by {degradation:.1f}%")
|
|
425
452
|
|
|
426
453
|
print("\nPer-row comparison:")
|
|
427
454
|
print(f" Ensemble wins: {n_better}/{n_total} ({100*n_better/n_total:.1f}%)")
|
|
@@ -443,6 +470,7 @@ class MetaModelSimulator:
|
|
|
443
470
|
|
|
444
471
|
return {
|
|
445
472
|
"ensemble_mae": ensemble_mae,
|
|
473
|
+
"best_strategy": best_strategy,
|
|
446
474
|
"best_model": best_model,
|
|
447
475
|
"best_model_mae": best_model_mae,
|
|
448
476
|
"ensemble_win_rate": n_better / n_total,
|
workbench/utils/model_utils.py
CHANGED
workbench/utils/plot_utils.py
CHANGED
|
@@ -1,14 +1,18 @@
|
|
|
1
1
|
"""Plot Utilities for Workbench"""
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
import numpy as np
|
|
4
5
|
import pandas as pd
|
|
5
6
|
import plotly.graph_objects as go
|
|
7
|
+
from dash import html
|
|
8
|
+
|
|
9
|
+
log = logging.getLogger("workbench")
|
|
6
10
|
|
|
7
11
|
|
|
8
12
|
# For approximating beeswarm effect
|
|
9
13
|
def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
|
|
10
14
|
"""
|
|
11
|
-
Generate
|
|
15
|
+
Generate beeswarm offsets using random jitter with collision avoidance.
|
|
12
16
|
|
|
13
17
|
Args:
|
|
14
18
|
values: Array of positions to be adjusted
|
|
@@ -22,42 +26,55 @@ def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
|
|
|
22
26
|
values = np.asarray(values)
|
|
23
27
|
rounded = np.round(values, precision)
|
|
24
28
|
offsets = np.zeros_like(values, dtype=float)
|
|
25
|
-
|
|
26
|
-
# Sort indices by original values
|
|
27
|
-
sorted_idx = np.argsort(values)
|
|
29
|
+
rng = np.random.default_rng(42) # Fixed seed for reproducibility
|
|
28
30
|
|
|
29
31
|
for val in np.unique(rounded):
|
|
30
32
|
# Get indices belonging to this group
|
|
31
|
-
|
|
33
|
+
group_mask = rounded == val
|
|
34
|
+
group_idx = np.where(group_mask)[0]
|
|
32
35
|
|
|
33
36
|
if len(group_idx) > 1:
|
|
34
37
|
# Track occupied positions for collision detection
|
|
35
38
|
occupied = []
|
|
36
39
|
|
|
37
40
|
for idx in group_idx:
|
|
38
|
-
#
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
#
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
41
|
+
# Try random positions, starting near center and expanding outward
|
|
42
|
+
best_offset = 0
|
|
43
|
+
found = False
|
|
44
|
+
|
|
45
|
+
# First point goes to center
|
|
46
|
+
if not occupied:
|
|
47
|
+
found = True
|
|
48
|
+
else:
|
|
49
|
+
# Try random positions with increasing spread
|
|
50
|
+
for attempt in range(50):
|
|
51
|
+
# Gradually increase the range of random offsets
|
|
52
|
+
spread = min(max_offset, point_size * (1 + attempt * 0.5))
|
|
53
|
+
offset = rng.uniform(-spread, spread)
|
|
54
|
+
|
|
55
|
+
# Check for collision with occupied positions
|
|
56
|
+
if not any(abs(offset - pos) < point_size for pos in occupied):
|
|
57
|
+
best_offset = offset
|
|
58
|
+
found = True
|
|
59
|
+
break
|
|
60
|
+
|
|
61
|
+
# If no free position found after attempts, find the least crowded spot
|
|
62
|
+
if not found:
|
|
63
|
+
# Try a grid of positions and pick one with most space
|
|
64
|
+
candidates = np.linspace(-max_offset, max_offset, 20)
|
|
65
|
+
rng.shuffle(candidates)
|
|
66
|
+
for candidate in candidates:
|
|
67
|
+
if not any(abs(candidate - pos) < point_size * 0.8 for pos in occupied):
|
|
68
|
+
best_offset = candidate
|
|
69
|
+
found = True
|
|
70
|
+
break
|
|
71
|
+
|
|
72
|
+
# Last resort: just use a random position within bounds
|
|
73
|
+
if not found:
|
|
74
|
+
best_offset = rng.uniform(-max_offset, max_offset)
|
|
75
|
+
|
|
76
|
+
offsets[idx] = best_offset
|
|
77
|
+
occupied.append(best_offset)
|
|
61
78
|
|
|
62
79
|
return offsets
|
|
63
80
|
|
|
@@ -179,6 +196,107 @@ def prediction_intervals(df, figure, x_col):
|
|
|
179
196
|
return figure
|
|
180
197
|
|
|
181
198
|
|
|
199
|
+
def molecule_hover_tooltip(smiles: str, mol_id: str = None, width: int = 300, height: int = 200) -> list:
|
|
200
|
+
"""Generate a molecule hover tooltip from a SMILES string.
|
|
201
|
+
|
|
202
|
+
This function creates a visually appealing tooltip with a dark background
|
|
203
|
+
that displays the molecule ID at the top and structure below when hovering
|
|
204
|
+
over scatter plot points.
|
|
205
|
+
|
|
206
|
+
Args:
|
|
207
|
+
smiles: SMILES string representing the molecule
|
|
208
|
+
mol_id: Optional molecule ID to display at the top of the tooltip
|
|
209
|
+
width: Width of the molecule image in pixels (default: 300)
|
|
210
|
+
height: Height of the molecule image in pixels (default: 200)
|
|
211
|
+
|
|
212
|
+
Returns:
|
|
213
|
+
list: A list containing an html.Div with the ID header and molecule SVG,
|
|
214
|
+
or an html.Div with an error message if rendering fails
|
|
215
|
+
"""
|
|
216
|
+
try:
|
|
217
|
+
from workbench.utils.chem_utils.vis import svg_from_smiles
|
|
218
|
+
from workbench.utils.theme_manager import ThemeManager
|
|
219
|
+
|
|
220
|
+
# Get the background color from the current theme
|
|
221
|
+
background = ThemeManager().background()
|
|
222
|
+
|
|
223
|
+
# Generate the SVG image from SMILES
|
|
224
|
+
img = svg_from_smiles(smiles, width, height, background=background)
|
|
225
|
+
|
|
226
|
+
if img is None:
|
|
227
|
+
log.warning(f"Could not render molecule for SMILES: {smiles}")
|
|
228
|
+
return [
|
|
229
|
+
html.Div(
|
|
230
|
+
"Invalid SMILES",
|
|
231
|
+
className="custom-tooltip",
|
|
232
|
+
style={
|
|
233
|
+
"padding": "10px",
|
|
234
|
+
"color": "rgb(255, 140, 140)",
|
|
235
|
+
"width": f"{width}px",
|
|
236
|
+
"height": f"{height}px",
|
|
237
|
+
"display": "flex",
|
|
238
|
+
"alignItems": "center",
|
|
239
|
+
"justifyContent": "center",
|
|
240
|
+
},
|
|
241
|
+
)
|
|
242
|
+
]
|
|
243
|
+
|
|
244
|
+
# Build the tooltip with ID header and molecule image
|
|
245
|
+
children = []
|
|
246
|
+
|
|
247
|
+
# Add ID header if provided
|
|
248
|
+
if mol_id is not None:
|
|
249
|
+
children.append(
|
|
250
|
+
html.Div(
|
|
251
|
+
str(mol_id),
|
|
252
|
+
style={
|
|
253
|
+
"textAlign": "center",
|
|
254
|
+
"padding": "8px",
|
|
255
|
+
"color": "rgb(200, 200, 200)",
|
|
256
|
+
"fontSize": "14px",
|
|
257
|
+
"fontWeight": "bold",
|
|
258
|
+
"borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
|
|
259
|
+
},
|
|
260
|
+
)
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
# Add molecule image
|
|
264
|
+
children.append(
|
|
265
|
+
html.Img(
|
|
266
|
+
src=img,
|
|
267
|
+
style={"padding": "0px", "margin": "0px", "display": "block"},
|
|
268
|
+
width=str(width),
|
|
269
|
+
height=str(height),
|
|
270
|
+
)
|
|
271
|
+
)
|
|
272
|
+
|
|
273
|
+
return [
|
|
274
|
+
html.Div(
|
|
275
|
+
children,
|
|
276
|
+
className="custom-tooltip",
|
|
277
|
+
style={"padding": "0px", "margin": "0px"},
|
|
278
|
+
)
|
|
279
|
+
]
|
|
280
|
+
|
|
281
|
+
except ImportError as e:
|
|
282
|
+
log.error(f"RDKit not available for molecule rendering: {e}")
|
|
283
|
+
return [
|
|
284
|
+
html.Div(
|
|
285
|
+
"RDKit not installed",
|
|
286
|
+
className="custom-tooltip",
|
|
287
|
+
style={
|
|
288
|
+
"padding": "10px",
|
|
289
|
+
"color": "rgb(255, 195, 140)",
|
|
290
|
+
"width": f"{width}px",
|
|
291
|
+
"height": f"{height}px",
|
|
292
|
+
"display": "flex",
|
|
293
|
+
"alignItems": "center",
|
|
294
|
+
"justifyContent": "center",
|
|
295
|
+
},
|
|
296
|
+
)
|
|
297
|
+
]
|
|
298
|
+
|
|
299
|
+
|
|
182
300
|
if __name__ == "__main__":
|
|
183
301
|
"""Exercise the Plot Utilities"""
|
|
184
302
|
import plotly.express as px
|
workbench/utils/shap_utils.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple, Dict, Union
|
|
|
9
9
|
from workbench.utils.xgboost_model_utils import xgboost_model_from_s3
|
|
10
10
|
from workbench.utils.model_utils import load_category_mappings_from_s3
|
|
11
11
|
from workbench.utils.pandas_utils import convert_categorical_types
|
|
12
|
+
from workbench.model_script_utils.model_script_utils import decompress_features
|
|
12
13
|
|
|
13
14
|
# Set up the log
|
|
14
15
|
log = logging.getLogger("workbench")
|
|
@@ -111,61 +112,6 @@ def shap_values_data(
|
|
|
111
112
|
return result_df, feature_df
|
|
112
113
|
|
|
113
114
|
|
|
114
|
-
def decompress_features(
|
|
115
|
-
df: pd.DataFrame, features: List[str], compressed_features: List[str]
|
|
116
|
-
) -> Tuple[pd.DataFrame, List[str]]:
|
|
117
|
-
"""Prepare features for the XGBoost model
|
|
118
|
-
|
|
119
|
-
Args:
|
|
120
|
-
df (pd.DataFrame): The features DataFrame
|
|
121
|
-
features (List[str]): Full list of feature names
|
|
122
|
-
compressed_features (List[str]): List of feature names to decompress (bitstrings)
|
|
123
|
-
|
|
124
|
-
Returns:
|
|
125
|
-
pd.DataFrame: DataFrame with the decompressed features
|
|
126
|
-
List[str]: Updated list of feature names after decompression
|
|
127
|
-
|
|
128
|
-
Raises:
|
|
129
|
-
ValueError: If any missing values are found in the specified features
|
|
130
|
-
"""
|
|
131
|
-
|
|
132
|
-
# Check for any missing values in the required features
|
|
133
|
-
missing_counts = df[features].isna().sum()
|
|
134
|
-
if missing_counts.any():
|
|
135
|
-
missing_features = missing_counts[missing_counts > 0]
|
|
136
|
-
print(
|
|
137
|
-
f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
|
|
138
|
-
"WARNING: You might want to remove/replace all NaN values before processing."
|
|
139
|
-
)
|
|
140
|
-
|
|
141
|
-
# Decompress the specified compressed features
|
|
142
|
-
decompressed_features = features
|
|
143
|
-
for feature in compressed_features:
|
|
144
|
-
if (feature not in df.columns) or (feature not in features):
|
|
145
|
-
print(f"Feature '{feature}' not in the features list, skipping decompression.")
|
|
146
|
-
continue
|
|
147
|
-
|
|
148
|
-
# Remove the feature from the list of features to avoid duplication
|
|
149
|
-
decompressed_features.remove(feature)
|
|
150
|
-
|
|
151
|
-
# Handle all compressed features as bitstrings
|
|
152
|
-
bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
|
|
153
|
-
prefix = feature[:3]
|
|
154
|
-
|
|
155
|
-
# Create all new columns at once - avoids fragmentation
|
|
156
|
-
new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
|
|
157
|
-
new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
|
|
158
|
-
|
|
159
|
-
# Add to features list
|
|
160
|
-
decompressed_features.extend(new_col_names)
|
|
161
|
-
|
|
162
|
-
# Drop original column and concatenate new ones
|
|
163
|
-
df = df.drop(columns=[feature])
|
|
164
|
-
df = pd.concat([df, new_df], axis=1)
|
|
165
|
-
|
|
166
|
-
return df, decompressed_features
|
|
167
|
-
|
|
168
|
-
|
|
169
115
|
def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
|
|
170
116
|
"""
|
|
171
117
|
Internal function to calculate SHAP values for Workbench Models.
|
workbench/utils/theme_manager.py
CHANGED
|
@@ -76,10 +76,28 @@ class ThemeManager:
|
|
|
76
76
|
def set_theme(cls, theme_name: str):
|
|
77
77
|
"""Set the current theme."""
|
|
78
78
|
|
|
79
|
-
# For 'auto', we
|
|
80
|
-
#
|
|
79
|
+
# For 'auto', we check multiple sources in priority order:
|
|
80
|
+
# 1. Browser cookie (from localStorage, for per-user preference)
|
|
81
|
+
# 2. Parameter Store (for org-wide default)
|
|
82
|
+
# 3. Default theme
|
|
81
83
|
if theme_name == "auto":
|
|
82
|
-
theme_name =
|
|
84
|
+
theme_name = None
|
|
85
|
+
|
|
86
|
+
# 1. Check Flask request cookie (set from localStorage)
|
|
87
|
+
try:
|
|
88
|
+
from flask import request, has_request_context
|
|
89
|
+
|
|
90
|
+
if has_request_context():
|
|
91
|
+
theme_name = request.cookies.get("wb_theme")
|
|
92
|
+
except Exception:
|
|
93
|
+
pass
|
|
94
|
+
|
|
95
|
+
# 2. Fall back to ParameterStore
|
|
96
|
+
if not theme_name:
|
|
97
|
+
theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False)
|
|
98
|
+
|
|
99
|
+
# 3. Fall back to default
|
|
100
|
+
theme_name = theme_name or cls.default_theme
|
|
83
101
|
|
|
84
102
|
# Check if the theme is in our available themes
|
|
85
103
|
if theme_name not in cls.available_themes:
|
|
@@ -104,9 +122,27 @@ class ThemeManager:
|
|
|
104
122
|
cls.current_theme_name = theme_name
|
|
105
123
|
cls.log.info(f"Theme set to '{theme_name}'")
|
|
106
124
|
|
|
125
|
+
# Bootstrap themes that are dark mode (from Bootswatch)
|
|
126
|
+
_dark_bootstrap_themes = {"DARKLY", "CYBORG", "SLATE", "SOLAR", "SUPERHERO", "VAPOR"}
|
|
127
|
+
|
|
107
128
|
@classmethod
|
|
108
129
|
def dark_mode(cls) -> bool:
|
|
109
|
-
"""Check if the current theme is a dark mode theme.
|
|
130
|
+
"""Check if the current theme is a dark mode theme.
|
|
131
|
+
|
|
132
|
+
Determines dark mode by checking if the Bootstrap base theme is a known dark theme.
|
|
133
|
+
Falls back to checking if 'dark' is in the theme name.
|
|
134
|
+
"""
|
|
135
|
+
theme = cls.available_themes.get(cls.current_theme_name, {})
|
|
136
|
+
base_css = theme.get("base_css", "")
|
|
137
|
+
|
|
138
|
+
# Check if the base CSS URL contains a known dark Bootstrap theme
|
|
139
|
+
if base_css:
|
|
140
|
+
base_css_upper = base_css.upper()
|
|
141
|
+
for dark_theme in cls._dark_bootstrap_themes:
|
|
142
|
+
if dark_theme in base_css_upper:
|
|
143
|
+
return True
|
|
144
|
+
|
|
145
|
+
# Fallback: check if 'dark' is in the theme name
|
|
110
146
|
return "dark" in cls.current_theme().lower()
|
|
111
147
|
|
|
112
148
|
@classmethod
|
|
@@ -184,30 +220,57 @@ class ThemeManager:
|
|
|
184
220
|
|
|
185
221
|
@classmethod
|
|
186
222
|
def css_files(cls) -> list[str]:
|
|
187
|
-
"""Get the list of CSS files for the current theme.
|
|
188
|
-
|
|
223
|
+
"""Get the list of CSS files for the current theme.
|
|
224
|
+
|
|
225
|
+
Note: Uses /base.css route for dynamic theme switching instead of CDN URLs.
|
|
226
|
+
"""
|
|
189
227
|
css_files = []
|
|
190
228
|
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
css_files.append(theme["base_css"])
|
|
229
|
+
# Use Flask route for base CSS (allows dynamic theme switching)
|
|
230
|
+
css_files.append("/base.css")
|
|
194
231
|
|
|
195
232
|
# Add the DBC template CSS
|
|
196
233
|
css_files.append(cls.dbc_css)
|
|
197
234
|
|
|
198
235
|
# Add custom.css if it exists
|
|
199
|
-
|
|
200
|
-
css_files.append("/custom.css")
|
|
236
|
+
css_files.append("/custom.css")
|
|
201
237
|
|
|
202
238
|
return css_files
|
|
203
239
|
|
|
240
|
+
@classmethod
|
|
241
|
+
def _get_theme_from_cookie(cls):
|
|
242
|
+
"""Get the theme dict based on the wb_theme cookie, falling back to current theme."""
|
|
243
|
+
from flask import request
|
|
244
|
+
|
|
245
|
+
theme_name = request.cookies.get("wb_theme")
|
|
246
|
+
if theme_name and theme_name in cls.available_themes:
|
|
247
|
+
return cls.available_themes[theme_name], theme_name
|
|
248
|
+
return cls.available_themes[cls.current_theme_name], cls.current_theme_name
|
|
249
|
+
|
|
204
250
|
@classmethod
|
|
205
251
|
def register_css_route(cls, app):
|
|
206
|
-
"""Register Flask
|
|
252
|
+
"""Register Flask routes for CSS and before_request hook for theme switching."""
|
|
253
|
+
from flask import redirect
|
|
254
|
+
|
|
255
|
+
@app.server.before_request
|
|
256
|
+
def check_theme_cookie():
|
|
257
|
+
"""Check for theme cookie on each request and update theme if needed."""
|
|
258
|
+
_, theme_name = cls._get_theme_from_cookie()
|
|
259
|
+
if theme_name != cls.current_theme_name:
|
|
260
|
+
cls.set_theme(theme_name)
|
|
261
|
+
|
|
262
|
+
@app.server.route("/base.css")
|
|
263
|
+
def serve_base_css():
|
|
264
|
+
"""Redirect to the appropriate Bootstrap theme CSS based on cookie."""
|
|
265
|
+
theme, _ = cls._get_theme_from_cookie()
|
|
266
|
+
if theme["base_css"]:
|
|
267
|
+
return redirect(theme["base_css"])
|
|
268
|
+
return "", 404
|
|
207
269
|
|
|
208
270
|
@app.server.route("/custom.css")
|
|
209
271
|
def serve_custom_css():
|
|
210
|
-
|
|
272
|
+
"""Serve the custom.css file based on cookie."""
|
|
273
|
+
theme, _ = cls._get_theme_from_cookie()
|
|
211
274
|
if theme["custom_css"]:
|
|
212
275
|
return send_from_directory(theme["custom_css"].parent, theme["custom_css"].name)
|
|
213
276
|
return "", 404
|
|
@@ -250,23 +313,25 @@ class ThemeManager:
|
|
|
250
313
|
# Loop over each path in the theme path
|
|
251
314
|
for theme_path in cls.theme_path_list:
|
|
252
315
|
for theme_dir in theme_path.iterdir():
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
316
|
+
# Skip hidden directories (e.g., .idea, .git)
|
|
317
|
+
if not theme_dir.is_dir() or theme_dir.name.startswith("."):
|
|
318
|
+
continue
|
|
319
|
+
theme_name = theme_dir.name
|
|
320
|
+
|
|
321
|
+
# Grab the base.css URL
|
|
322
|
+
base_css_url = cls._get_base_css_url(theme_dir)
|
|
323
|
+
|
|
324
|
+
# Grab the plotly template json, custom.css, and branding json
|
|
325
|
+
plotly_template = theme_dir / "plotly.json"
|
|
326
|
+
custom_css = theme_dir / "custom.css"
|
|
327
|
+
branding = theme_dir / "branding.json"
|
|
328
|
+
|
|
329
|
+
cls.available_themes[theme_name] = {
|
|
330
|
+
"base_css": base_css_url,
|
|
331
|
+
"plotly_template": plotly_template,
|
|
332
|
+
"custom_css": custom_css if custom_css.exists() else None,
|
|
333
|
+
"branding": branding if branding.exists() else None,
|
|
334
|
+
}
|
|
270
335
|
|
|
271
336
|
if not cls.available_themes:
|
|
272
337
|
cls.log.warning(f"No themes found in '{cls.theme_path_list}'...")
|