workbench 0.8.217__py3-none-any.whl → 0.8.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  3. workbench/algorithms/dataframe/projection_2d.py +8 -2
  4. workbench/algorithms/dataframe/proximity.py +3 -0
  5. workbench/algorithms/sql/outliers.py +3 -3
  6. workbench/api/feature_set.py +0 -1
  7. workbench/core/artifacts/endpoint_core.py +2 -2
  8. workbench/core/artifacts/feature_set_core.py +185 -230
  9. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  10. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  11. workbench/model_script_utils/model_script_utils.py +15 -11
  12. workbench/model_scripts/chemprop/chemprop.template +195 -70
  13. workbench/model_scripts/chemprop/generated_model_script.py +198 -73
  14. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  15. workbench/model_scripts/custom_models/chem_info/fingerprints.py +80 -43
  16. workbench/model_scripts/pytorch_model/generated_model_script.py +2 -2
  17. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  18. workbench/model_scripts/xgb_model/generated_model_script.py +7 -7
  19. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  20. workbench/scripts/meta_model_sim.py +35 -0
  21. workbench/scripts/ml_pipeline_sqs.py +71 -2
  22. workbench/themes/light/custom.css +7 -1
  23. workbench/themes/midnight_blue/custom.css +34 -0
  24. workbench/utils/chem_utils/fingerprints.py +80 -43
  25. workbench/utils/chem_utils/projections.py +16 -6
  26. workbench/utils/meta_model_simulator.py +41 -13
  27. workbench/utils/model_utils.py +0 -1
  28. workbench/utils/plot_utils.py +146 -28
  29. workbench/utils/shap_utils.py +1 -55
  30. workbench/utils/theme_manager.py +95 -30
  31. workbench/web_interface/components/plugins/scatter_plot.py +152 -66
  32. workbench/web_interface/components/settings_menu.py +184 -0
  33. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
  34. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/RECORD +38 -37
  35. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +1 -0
  36. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  37. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  38. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
  39. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
  40. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
@@ -61,6 +61,13 @@ class MetaModelSimulator:
61
61
  df["abs_residual"] = df["residual"].abs()
62
62
  self._dfs[name] = df
63
63
 
64
+ # Find common rows across all models
65
+ id_sets = {name: set(df[self.id_column]) for name, df in self._dfs.items()}
66
+ common_ids = set.intersection(*id_sets.values())
67
+ sizes = ", ".join(f"{name}: {len(ids)}" for name, ids in id_sets.items())
68
+ log.info(f"Row counts before alignment: {sizes} -> common: {len(common_ids)}")
69
+ self._dfs = {name: df[df[self.id_column].isin(common_ids)] for name, df in self._dfs.items()}
70
+
64
71
  # Align DataFrames by sorting on id column
65
72
  self._dfs = {name: df.sort_values(self.id_column).reset_index(drop=True) for name, df in self._dfs.items()}
66
73
  log.info(f"Loaded {len(self._dfs)} models, {len(list(self._dfs.values())[0])} samples each")
@@ -372,13 +379,13 @@ class MetaModelSimulator:
372
379
  return weight_df
373
380
 
374
381
  def ensemble_failure_analysis(self) -> dict:
375
- """Compare ensemble vs best overall model (not per-row oracle).
382
+ """Compare best ensemble strategy vs best individual model.
376
383
 
377
384
  Returns:
378
385
  Dict with comparison statistics
379
386
  """
380
387
  print("\n" + "=" * 60)
381
- print("ENSEMBLE VS BEST MODEL COMPARISON")
388
+ print("BEST ENSEMBLE VS BEST MODEL COMPARISON")
382
389
  print("=" * 60)
383
390
 
384
391
  model_names = list(self._dfs.keys())
@@ -393,35 +400,55 @@ class MetaModelSimulator:
393
400
  combined[f"{name}_abs_err"] = df["abs_residual"].values
394
401
 
395
402
  pred_cols = [f"{name}_pred" for name in model_names]
403
+ conf_cols = [f"{name}_conf" for name in model_names]
404
+ pred_arr = combined[pred_cols].values
405
+ conf_arr = combined[conf_cols].values
396
406
 
397
- # Calculate ensemble prediction (inverse-MAE weighted)
398
407
  mae_scores = {name: self._dfs[name]["abs_residual"].mean() for name in model_names}
399
408
  inv_mae_weights = np.array([1.0 / mae_scores[name] for name in model_names])
400
409
  inv_mae_weights = inv_mae_weights / inv_mae_weights.sum()
401
- pred_arr = combined[pred_cols].values
402
- combined["ensemble_pred"] = (pred_arr * inv_mae_weights).sum(axis=1)
403
- combined["ensemble_abs_err"] = (combined["ensemble_pred"] - combined["target"]).abs()
404
410
 
405
- # Find best overall model (lowest MAE)
411
+ # Compute all ensemble strategies (true ensembles that combine multiple models)
412
+ ensemble_strategies = {}
413
+ ensemble_strategies["Simple Mean"] = combined[pred_cols].mean(axis=1)
414
+ conf_sum = conf_arr.sum(axis=1, keepdims=True) + 1e-8
415
+ ensemble_strategies["Confidence-Weighted"] = (pred_arr * (conf_arr / conf_sum)).sum(axis=1)
416
+ ensemble_strategies["Inverse-MAE Weighted"] = (pred_arr * inv_mae_weights).sum(axis=1)
417
+ scaled_conf = conf_arr * inv_mae_weights
418
+ scaled_conf_sum = scaled_conf.sum(axis=1, keepdims=True) + 1e-8
419
+ ensemble_strategies["Scaled Conf-Weighted"] = (pred_arr * (scaled_conf / scaled_conf_sum)).sum(axis=1)
420
+ worst_model = max(mae_scores, key=mae_scores.get)
421
+ remaining = [n for n in model_names if n != worst_model]
422
+ remaining_cols = [f"{n}_pred" for n in remaining]
423
+ # Only add Drop Worst if it still combines multiple models
424
+ if len(remaining) > 1:
425
+ ensemble_strategies[f"Drop Worst ({worst_model})"] = combined[remaining_cols].mean(axis=1)
426
+
427
+ # Find best individual model
406
428
  best_model = min(mae_scores, key=mae_scores.get)
407
429
  combined["best_model_abs_err"] = combined[f"{best_model}_abs_err"]
430
+ best_model_mae = mae_scores[best_model]
408
431
 
409
- # Compare ensemble vs best model
432
+ # Find best true ensemble strategy
433
+ strategy_maes = {name: (preds - combined["target"]).abs().mean() for name, preds in ensemble_strategies.items()}
434
+ best_strategy = min(strategy_maes, key=strategy_maes.get)
435
+ combined["ensemble_pred"] = ensemble_strategies[best_strategy]
436
+ combined["ensemble_abs_err"] = (combined["ensemble_pred"] - combined["target"]).abs()
437
+ ensemble_mae = strategy_maes[best_strategy]
438
+
439
+ # Compare
410
440
  combined["ensemble_better"] = combined["ensemble_abs_err"] < combined["best_model_abs_err"]
411
441
  n_better = combined["ensemble_better"].sum()
412
442
  n_total = len(combined)
413
443
 
414
- ensemble_mae = combined["ensemble_abs_err"].mean()
415
- best_model_mae = mae_scores[best_model]
416
-
417
444
  print(f"\nBest individual model: {best_model} (MAE={best_model_mae:.4f})")
418
- print(f"Ensemble MAE: {ensemble_mae:.4f}")
445
+ print(f"Best ensemble strategy: {best_strategy} (MAE={ensemble_mae:.4f})")
419
446
  if ensemble_mae < best_model_mae:
420
447
  improvement = (best_model_mae - ensemble_mae) / best_model_mae * 100
421
448
  print(f"Ensemble improves over best model by {improvement:.1f}%")
422
449
  else:
423
450
  degradation = (ensemble_mae - best_model_mae) / best_model_mae * 100
424
- print(f"Ensemble is worse than best model by {degradation:.1f}%")
451
+ print(f"No ensemble benefit: best single model outperforms all ensemble strategies by {degradation:.1f}%")
425
452
 
426
453
  print("\nPer-row comparison:")
427
454
  print(f" Ensemble wins: {n_better}/{n_total} ({100*n_better/n_total:.1f}%)")
@@ -443,6 +470,7 @@ class MetaModelSimulator:
443
470
 
444
471
  return {
445
472
  "ensemble_mae": ensemble_mae,
473
+ "best_strategy": best_strategy,
446
474
  "best_model": best_model,
447
475
  "best_model_mae": best_model_mae,
448
476
  "ensemble_win_rate": n_better / n_total,
@@ -173,7 +173,6 @@ def fingerprint_prox_model_local(
173
173
  include_all_columns=include_all_columns,
174
174
  radius=radius,
175
175
  n_bits=n_bits,
176
- counts=counts,
177
176
  )
178
177
 
179
178
 
@@ -1,14 +1,18 @@
1
1
  """Plot Utilities for Workbench"""
2
2
 
3
+ import logging
3
4
  import numpy as np
4
5
  import pandas as pd
5
6
  import plotly.graph_objects as go
7
+ from dash import html
8
+
9
+ log = logging.getLogger("workbench")
6
10
 
7
11
 
8
12
  # For approximating beeswarm effect
9
13
  def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
10
14
  """
11
- Generate optimal beeswarm offsets with a maximum limit.
15
+ Generate beeswarm offsets using random jitter with collision avoidance.
12
16
 
13
17
  Args:
14
18
  values: Array of positions to be adjusted
@@ -22,42 +26,55 @@ def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
22
26
  values = np.asarray(values)
23
27
  rounded = np.round(values, precision)
24
28
  offsets = np.zeros_like(values, dtype=float)
25
-
26
- # Sort indices by original values
27
- sorted_idx = np.argsort(values)
29
+ rng = np.random.default_rng(42) # Fixed seed for reproducibility
28
30
 
29
31
  for val in np.unique(rounded):
30
32
  # Get indices belonging to this group
31
- group_idx = sorted_idx[np.isin(sorted_idx, np.where(rounded == val)[0])]
33
+ group_mask = rounded == val
34
+ group_idx = np.where(group_mask)[0]
32
35
 
33
36
  if len(group_idx) > 1:
34
37
  # Track occupied positions for collision detection
35
38
  occupied = []
36
39
 
37
40
  for idx in group_idx:
38
- # Find best position with no collision
39
- offset = 0
40
- direction = 1
41
- step = 0
42
-
43
- while True:
44
- # Check if current offset position is free
45
- collision = any(abs(offset - pos) < point_size for pos in occupied)
46
-
47
- if not collision or abs(offset) >= max_offset:
48
- # Accept position if no collision or max offset reached
49
- if abs(offset) > max_offset:
50
- # Clamp to maximum
51
- offset = max_offset * (1 if offset > 0 else -1)
52
- break
53
-
54
- # Switch sides with increasing distance
55
- step += 0.25
56
- direction *= -1
57
- offset = direction * step * point_size
58
-
59
- offsets[idx] = offset
60
- occupied.append(offset)
41
+ # Try random positions, starting near center and expanding outward
42
+ best_offset = 0
43
+ found = False
44
+
45
+ # First point goes to center
46
+ if not occupied:
47
+ found = True
48
+ else:
49
+ # Try random positions with increasing spread
50
+ for attempt in range(50):
51
+ # Gradually increase the range of random offsets
52
+ spread = min(max_offset, point_size * (1 + attempt * 0.5))
53
+ offset = rng.uniform(-spread, spread)
54
+
55
+ # Check for collision with occupied positions
56
+ if not any(abs(offset - pos) < point_size for pos in occupied):
57
+ best_offset = offset
58
+ found = True
59
+ break
60
+
61
+ # If no free position found after attempts, find the least crowded spot
62
+ if not found:
63
+ # Try a grid of positions and pick one with most space
64
+ candidates = np.linspace(-max_offset, max_offset, 20)
65
+ rng.shuffle(candidates)
66
+ for candidate in candidates:
67
+ if not any(abs(candidate - pos) < point_size * 0.8 for pos in occupied):
68
+ best_offset = candidate
69
+ found = True
70
+ break
71
+
72
+ # Last resort: just use a random position within bounds
73
+ if not found:
74
+ best_offset = rng.uniform(-max_offset, max_offset)
75
+
76
+ offsets[idx] = best_offset
77
+ occupied.append(best_offset)
61
78
 
62
79
  return offsets
63
80
 
@@ -179,6 +196,107 @@ def prediction_intervals(df, figure, x_col):
179
196
  return figure
180
197
 
181
198
 
199
def _tooltip_message(message: str, color: str, width: int, height: int) -> list:
    """Build a single-element tooltip list that shows a centered status message.

    Args:
        message: Text to display inside the tooltip
        color: CSS color used for the message text
        width: Tooltip width in pixels
        height: Tooltip height in pixels

    Returns:
        list: A list containing one html.Div with the "custom-tooltip" class
    """
    return [
        html.Div(
            message,
            className="custom-tooltip",
            style={
                "padding": "10px",
                "color": color,
                "width": f"{width}px",
                "height": f"{height}px",
                "display": "flex",
                "alignItems": "center",
                "justifyContent": "center",
            },
        )
    ]


def molecule_hover_tooltip(smiles: str, mol_id: str = None, width: int = 300, height: int = 200) -> list:
    """Generate a molecule hover tooltip from a SMILES string.

    The tooltip shows the molecule ID (when given) as a header, with the rendered
    structure below it. The image background is taken from the current theme via
    ThemeManager().background(). Intended for hovering over scatter plot points.

    Args:
        smiles: SMILES string representing the molecule
        mol_id: Optional molecule ID to display at the top of the tooltip
        width: Width of the molecule image in pixels (default: 300)
        height: Height of the molecule image in pixels (default: 200)

    Returns:
        list: A list containing an html.Div with the ID header and molecule SVG,
              or an html.Div with an error message if rendering fails
              ("RDKit not installed" on ImportError, "Invalid SMILES" when no image is produced)
    """
    # Only the imports and the rendering step can raise ImportError (e.g. RDKit missing),
    # so the try block is limited to them; building the Dash components cannot.
    try:
        from workbench.utils.chem_utils.vis import svg_from_smiles
        from workbench.utils.theme_manager import ThemeManager

        # Match the image background to the current theme
        background = ThemeManager().background()
        img = svg_from_smiles(smiles, width, height, background=background)
    except ImportError as e:
        log.error(f"RDKit not available for molecule rendering: {e}")
        return _tooltip_message("RDKit not installed", "rgb(255, 195, 140)", width, height)

    if img is None:
        log.warning(f"Could not render molecule for SMILES: {smiles}")
        return _tooltip_message("Invalid SMILES", "rgb(255, 140, 140)", width, height)

    # Build the tooltip: optional ID header, then the molecule image
    children = []
    if mol_id is not None:
        children.append(
            html.Div(
                str(mol_id),
                style={
                    "textAlign": "center",
                    "padding": "8px",
                    "color": "rgb(200, 200, 200)",
                    "fontSize": "14px",
                    "fontWeight": "bold",
                    "borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
                },
            )
        )

    children.append(
        html.Img(
            src=img,
            style={"padding": "0px", "margin": "0px", "display": "block"},
            width=str(width),
            height=str(height),
        )
    )

    return [
        html.Div(
            children,
            className="custom-tooltip",
            style={"padding": "0px", "margin": "0px"},
        )
    ]
298
+
299
+
182
300
  if __name__ == "__main__":
183
301
  """Exercise the Plot Utilities"""
184
302
  import plotly.express as px
@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple, Dict, Union
9
9
  from workbench.utils.xgboost_model_utils import xgboost_model_from_s3
10
10
  from workbench.utils.model_utils import load_category_mappings_from_s3
11
11
  from workbench.utils.pandas_utils import convert_categorical_types
12
+ from workbench.model_script_utils.model_script_utils import decompress_features
12
13
 
13
14
  # Set up the log
14
15
  log = logging.getLogger("workbench")
@@ -111,61 +112,6 @@ def shap_values_data(
111
112
  return result_df, feature_df
112
113
 
113
114
 
114
- def decompress_features(
115
- df: pd.DataFrame, features: List[str], compressed_features: List[str]
116
- ) -> Tuple[pd.DataFrame, List[str]]:
117
- """Prepare features for the XGBoost model
118
-
119
- Args:
120
- df (pd.DataFrame): The features DataFrame
121
- features (List[str]): Full list of feature names
122
- compressed_features (List[str]): List of feature names to decompress (bitstrings)
123
-
124
- Returns:
125
- pd.DataFrame: DataFrame with the decompressed features
126
- List[str]: Updated list of feature names after decompression
127
-
128
- Raises:
129
- ValueError: If any missing values are found in the specified features
130
- """
131
-
132
- # Check for any missing values in the required features
133
- missing_counts = df[features].isna().sum()
134
- if missing_counts.any():
135
- missing_features = missing_counts[missing_counts > 0]
136
- print(
137
- f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
138
- "WARNING: You might want to remove/replace all NaN values before processing."
139
- )
140
-
141
- # Decompress the specified compressed features
142
- decompressed_features = features
143
- for feature in compressed_features:
144
- if (feature not in df.columns) or (feature not in features):
145
- print(f"Feature '{feature}' not in the features list, skipping decompression.")
146
- continue
147
-
148
- # Remove the feature from the list of features to avoid duplication
149
- decompressed_features.remove(feature)
150
-
151
- # Handle all compressed features as bitstrings
152
- bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
153
- prefix = feature[:3]
154
-
155
- # Create all new columns at once - avoids fragmentation
156
- new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
157
- new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
158
-
159
- # Add to features list
160
- decompressed_features.extend(new_col_names)
161
-
162
- # Drop original column and concatenate new ones
163
- df = df.drop(columns=[feature])
164
- df = pd.concat([df, new_df], axis=1)
165
-
166
- return df, decompressed_features
167
-
168
-
169
115
  def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
170
116
  """
171
117
  Internal function to calculate SHAP values for Workbench Models.
@@ -76,10 +76,28 @@ class ThemeManager:
76
76
  def set_theme(cls, theme_name: str):
77
77
  """Set the current theme."""
78
78
 
79
- # For 'auto', we try to grab a theme from the Parameter Store
80
- # if we can't find one, we'll set the theme to the default
79
+ # For 'auto', we check multiple sources in priority order:
80
+ # 1. Browser cookie (from localStorage, for per-user preference)
81
+ # 2. Parameter Store (for org-wide default)
82
+ # 3. Default theme
81
83
  if theme_name == "auto":
82
- theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False) or cls.default_theme
84
+ theme_name = None
85
+
86
+ # 1. Check Flask request cookie (set from localStorage)
87
+ try:
88
+ from flask import request, has_request_context
89
+
90
+ if has_request_context():
91
+ theme_name = request.cookies.get("wb_theme")
92
+ except Exception:
93
+ pass
94
+
95
+ # 2. Fall back to ParameterStore
96
+ if not theme_name:
97
+ theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False)
98
+
99
+ # 3. Fall back to default
100
+ theme_name = theme_name or cls.default_theme
83
101
 
84
102
  # Check if the theme is in our available themes
85
103
  if theme_name not in cls.available_themes:
@@ -104,9 +122,27 @@ class ThemeManager:
104
122
  cls.current_theme_name = theme_name
105
123
  cls.log.info(f"Theme set to '{theme_name}'")
106
124
 
125
# Bootstrap themes that are dark mode (from Bootswatch); upper-cased for comparison
_dark_bootstrap_themes = frozenset({"DARKLY", "CYBORG", "SLATE", "SOLAR", "SUPERHERO", "VAPOR"})

@classmethod
def dark_mode(cls) -> bool:
    """Check if the current theme is a dark mode theme.

    Determines dark mode by checking whether the theme's Bootstrap base CSS URL
    names a known dark Bootswatch theme as a whole word (e.g. '.../dist/darkly/bootstrap.min.css').
    Falls back to checking if 'dark' is in the theme name.

    Returns:
        bool: True if the current theme is considered a dark theme
    """
    import re

    theme = cls.available_themes.get(cls.current_theme_name, {})
    base_css = theme.get("base_css") or ""

    # Compare whole alphabetic tokens rather than raw substrings, so URLs that merely
    # contain a dark theme name (e.g. 'solarized-light' contains 'SOLAR') don't match.
    css_tokens = {token.upper() for token in re.split(r"[^A-Za-z]+", base_css) if token}
    if css_tokens & cls._dark_bootstrap_themes:
        return True

    # Fallback: check if 'dark' is in the theme name
    return "dark" in cls.current_theme().lower()
111
147
 
112
148
  @classmethod
@@ -184,30 +220,57 @@ class ThemeManager:
184
220
 
185
221
@classmethod
def css_files(cls) -> list[str]:
    """Get the list of CSS files for the current theme.

    The base and custom stylesheets are referenced through the Flask routes
    "/base.css" and "/custom.css" instead of direct URLs. Those routes resolve the
    theme per request, which allows dynamic theme switching. Both routes are always
    included; each returns a 404 when the active theme has no such stylesheet.

    Returns:
        list[str]: ["/base.css", <DBC template CSS>, "/custom.css"]
    """
    return [
        "/base.css",  # Flask route that redirects to the active theme's Bootstrap CSS
        cls.dbc_css,  # DBC template CSS
        "/custom.css",  # Flask route that serves the active theme's custom.css (404 if none)
    ]
203
239
 
240
@classmethod
def _get_theme_from_cookie(cls) -> tuple[dict, str]:
    """Resolve the theme selected by the ``wb_theme`` request cookie.

    Reads ``flask.request``, so it is meant to be called while handling a request
    (it is used by the before_request hook and the CSS routes).

    Returns:
        tuple[dict, str]: The theme settings dict and the theme name. Falls back to the
        current theme when the cookie is missing or names an unknown theme.
    """
    from flask import request

    theme_name = request.cookies.get("wb_theme")
    if theme_name not in cls.available_themes:
        theme_name = cls.current_theme_name
    return cls.available_themes[theme_name], theme_name
249
+
204
250
  @classmethod
205
251
  def register_css_route(cls, app):
206
- """Register Flask route for custom.css."""
252
+ """Register Flask routes for CSS and before_request hook for theme switching."""
253
+ from flask import redirect
254
+
255
+ @app.server.before_request
256
+ def check_theme_cookie():
257
+ """Check for theme cookie on each request and update theme if needed."""
258
+ _, theme_name = cls._get_theme_from_cookie()
259
+ if theme_name != cls.current_theme_name:
260
+ cls.set_theme(theme_name)
261
+
262
+ @app.server.route("/base.css")
263
+ def serve_base_css():
264
+ """Redirect to the appropriate Bootstrap theme CSS based on cookie."""
265
+ theme, _ = cls._get_theme_from_cookie()
266
+ if theme["base_css"]:
267
+ return redirect(theme["base_css"])
268
+ return "", 404
207
269
 
208
270
  @app.server.route("/custom.css")
209
271
  def serve_custom_css():
210
- theme = cls.available_themes[cls.current_theme_name]
272
+ """Serve the custom.css file based on cookie."""
273
+ theme, _ = cls._get_theme_from_cookie()
211
274
  if theme["custom_css"]:
212
275
  return send_from_directory(theme["custom_css"].parent, theme["custom_css"].name)
213
276
  return "", 404
@@ -250,23 +313,25 @@ class ThemeManager:
250
313
  # Loop over each path in the theme path
251
314
  for theme_path in cls.theme_path_list:
252
315
  for theme_dir in theme_path.iterdir():
253
- if theme_dir.is_dir():
254
- theme_name = theme_dir.name
255
-
256
- # Grab the base.css URL
257
- base_css_url = cls._get_base_css_url(theme_dir)
258
-
259
- # Grab the plotly template json, custom.css, and branding json
260
- plotly_template = theme_dir / "plotly.json"
261
- custom_css = theme_dir / "custom.css"
262
- branding = theme_dir / "branding.json"
263
-
264
- cls.available_themes[theme_name] = {
265
- "base_css": base_css_url,
266
- "plotly_template": plotly_template,
267
- "custom_css": custom_css if custom_css.exists() else None,
268
- "branding": branding if branding.exists() else None,
269
- }
316
+ # Skip hidden directories (e.g., .idea, .git)
317
+ if not theme_dir.is_dir() or theme_dir.name.startswith("."):
318
+ continue
319
+ theme_name = theme_dir.name
320
+
321
+ # Grab the base.css URL
322
+ base_css_url = cls._get_base_css_url(theme_dir)
323
+
324
+ # Grab the plotly template json, custom.css, and branding json
325
+ plotly_template = theme_dir / "plotly.json"
326
+ custom_css = theme_dir / "custom.css"
327
+ branding = theme_dir / "branding.json"
328
+
329
+ cls.available_themes[theme_name] = {
330
+ "base_css": base_css_url,
331
+ "plotly_template": plotly_template,
332
+ "custom_css": custom_css if custom_css.exists() else None,
333
+ "branding": branding if branding.exists() else None,
334
+ }
270
335
 
271
336
  if not cls.available_themes:
272
337
  cls.log.warning(f"No themes found in '{cls.theme_path_list}'...")