workbench 0.8.234__py3-none-any.whl → 0.8.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. workbench/algorithms/dataframe/smart_aggregator.py +17 -12
  2. workbench/api/endpoint.py +13 -4
  3. workbench/api/model.py +2 -2
  4. workbench/cached/cached_model.py +2 -2
  5. workbench/core/artifacts/athena_source.py +5 -3
  6. workbench/core/artifacts/endpoint_core.py +30 -5
  7. workbench/core/cloud_platform/aws/aws_meta.py +2 -1
  8. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +27 -14
  9. workbench/model_script_utils/model_script_utils.py +225 -0
  10. workbench/model_script_utils/uq_harness.py +39 -21
  11. workbench/model_scripts/chemprop/chemprop.template +30 -15
  12. workbench/model_scripts/chemprop/generated_model_script.py +35 -18
  13. workbench/model_scripts/chemprop/model_script_utils.py +225 -0
  14. workbench/model_scripts/pytorch_model/generated_model_script.py +29 -15
  15. workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
  16. workbench/model_scripts/pytorch_model/pytorch.template +28 -14
  17. workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
  18. workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
  19. workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
  20. workbench/model_scripts/xgb_model/uq_harness.py +39 -21
  21. workbench/model_scripts/xgb_model/xgb_model.template +29 -18
  22. workbench/scripts/ml_pipeline_batch.py +47 -2
  23. workbench/scripts/ml_pipeline_launcher.py +410 -0
  24. workbench/scripts/ml_pipeline_sqs.py +22 -2
  25. workbench/themes/dark/custom.css +29 -0
  26. workbench/themes/light/custom.css +29 -0
  27. workbench/themes/midnight_blue/custom.css +28 -0
  28. workbench/utils/model_utils.py +9 -0
  29. workbench/utils/theme_manager.py +95 -0
  30. workbench/web_interface/components/component_interface.py +3 -0
  31. workbench/web_interface/components/plugin_interface.py +26 -0
  32. workbench/web_interface/components/plugins/ag_table.py +4 -11
  33. workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
  34. workbench/web_interface/components/plugins/model_plot.py +156 -0
  35. workbench/web_interface/components/plugins/scatter_plot.py +9 -2
  36. workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
  37. workbench/web_interface/components/settings_menu.py +10 -49
  38. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/METADATA +2 -2
  39. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/RECORD +43 -42
  40. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/WHEEL +1 -1
  41. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/entry_points.txt +1 -0
  42. workbench/web_interface/components/model_plot.py +0 -75
  43. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/top_level.txt +0 -0
@@ -155,6 +155,101 @@ class ThemeManager:
155
155
  """Get the name of the current theme."""
156
156
  return cls.current_theme_name
157
157
 
158
@classmethod
def get_theme_urls(cls) -> dict:
    """Get a mapping of theme names to their Bootstrap CSS URLs.

    Themes whose ``base_css`` entry is missing or empty are left out, so every entry in
    the result is a theme whose stylesheet can actually be swapped in.

    Returns:
        dict: ``{theme_name: base_css_url}`` for each theme with a non-empty ``base_css``.
    """
    # .get() (as in get_dark_themes) skips a theme without a "base_css" key instead of
    # raising KeyError
    return {name: css for name, theme in cls.available_themes.items() if (css := theme.get("base_css"))}
162
+
163
+ @classmethod
164
+ def get_dark_themes(cls) -> list:
165
+ """Get list of theme names that are dark mode themes."""
166
+ dark_themes = []  # names of themes classified as dark, in available_themes order
167
+ for name, theme in cls.available_themes.items():
168
+ base_css = theme.get("base_css", "")
169
+ if base_css:
170
+ base_css_upper = base_css.upper()  # case-insensitive match; assumes _dark_bootstrap_themes entries are uppercase — TODO confirm
171
+ for dark_theme in cls._dark_bootstrap_themes:
172
+ if dark_theme in base_css_upper:
173
+ dark_themes.append(name)
174
+ break
175
+ else:  # NOTE(review): indentation is not visible here — confirm whether this else pairs with "if base_css:" (themes with no base_css) or with the inner for loop (no dark keyword matched)
176
+ # Fallback: check if 'dark' is in the theme name
177
+ if "dark" in name.lower():
178
+ dark_themes.append(name)
179
+ return dark_themes
180
+
181
@classmethod
def get_theme_switch_js(cls) -> str:
    """Get the JavaScript code for dynamic theme switching.

    The returned source is meant for a Dash clientside callback whose inputs are the
    ``n_clicks`` of the pattern-matched theme menu items and whose output receives the
    name of the selected theme. On a real click it:

    1. stores the theme in ``localStorage`` and in a ``wb_theme`` cookie,
    2. sets ``data-bs-theme`` on the document element to ``dark`` or ``light``,
    3. points the Bootstrap stylesheet at the theme's ``base_css`` URL (appending a new
       ``<link>`` if no Bootstrap stylesheet is found), and
    4. re-requests ``/custom.css`` with a cache-busting query parameter.

    Triggers without a truthy ``n_clicks`` value (e.g. menu items being rendered with
    ``n_clicks`` null/0) are ignored and return ``no_update``.

    Returns:
        str: JavaScript function source with the theme URL map and dark-theme list inlined.
    """
    # json.dumps of plain str/list/dict data produces valid JavaScript literals
    theme_urls_js = json.dumps(cls.get_theme_urls())
    dark_themes_js = json.dumps(cls.get_dark_themes())

    return f"""
function(n_clicks_list, ids) {{
    // Use callback context to find which input triggered this callback
    const ctx = window.dash_clientside.callback_context;
    if (!ctx || !ctx.triggered || ctx.triggered.length === 0) {{
        return window.dash_clientside.no_update;
    }}

    // Only react to real clicks; a rendered-but-unclicked item has n_clicks null/0
    const trigger = ctx.triggered[0];
    if (!trigger.value) {{
        return window.dash_clientside.no_update;
    }}

    // prop_id is '<id>.n_clicks'; for pattern-matching callbacks <id> is a JSON string,
    // so cut at the LAST '.' in case the id itself contains a period
    const propId = trigger.prop_id || '';
    const dot = propId.lastIndexOf('.');
    const triggeredId = dot >= 0 ? propId.slice(0, dot) : propId;

    let clickedTheme = null;
    try {{
        clickedTheme = JSON.parse(triggeredId).theme;
    }} catch (e) {{
        // Not a pattern-matching callback ID
        return window.dash_clientside.no_update;
    }}
    if (!clickedTheme) {{
        return window.dash_clientside.no_update;
    }}

    // Store in localStorage and cookie
    localStorage.setItem('wb_theme', clickedTheme);
    document.cookie = `wb_theme=${{clickedTheme}}; path=/; max-age=31536000`;

    // Theme URL mapping (generated from ThemeManager)
    const themeUrls = {theme_urls_js};
    const darkThemes = {dark_themes_js};

    // Update Bootstrap's data-bs-theme attribute
    const bsTheme = darkThemes.includes(clickedTheme) ? 'dark' : 'light';
    document.documentElement.setAttribute('data-bs-theme', bsTheme);

    // Find and update the Bootstrap stylesheet
    const newUrl = themeUrls[clickedTheme];
    if (newUrl) {{
        let stylesheet = document.querySelector('link[href*="bootswatch"]') ||
                         document.querySelector('link[href*="bootstrap"]') ||
                         document.querySelector('link[href="/base.css"]');
        if (stylesheet) {{
            stylesheet.setAttribute('href', newUrl);
        }} else {{
            stylesheet = document.createElement('link');
            stylesheet.rel = 'stylesheet';
            stylesheet.href = newUrl;
            document.head.appendChild(stylesheet);
        }}
    }}

    // Force reload custom.css with cache-busting query param
    const customCss = document.querySelector('link[href^="/custom.css"]');
    if (customCss) {{
        customCss.setAttribute('href', `/custom.css?t=${{Date.now()}}`);
    }}

    return clickedTheme;
}}
"""
252
+
158
253
  @classmethod
159
254
  def background(cls) -> list[list[float | str]]:
160
255
  """Get the plot background for the current theme."""
@@ -136,6 +136,7 @@ class ComponentInterface(ABC):
136
136
  xref="paper",
137
137
  yref="paper",
138
138
  text=text_message,
139
+ bgcolor="rgba(0,0,0,0)",
139
140
  showarrow=False,
140
141
  font=dict(size=font_size, color="#9999cc"),
141
142
  )
@@ -144,6 +145,8 @@ class ComponentInterface(ABC):
144
145
  xaxis=dict(showticklabels=False, zeroline=False, showgrid=False),
145
146
  yaxis=dict(showticklabels=False, zeroline=False, showgrid=False),
146
147
  margin=dict(l=0, r=0, b=0, t=0),
148
+ paper_bgcolor="rgba(0,0,0,0)",
149
+ plot_bgcolor="rgba(0,0,0,0)",
147
150
  )
148
151
 
149
152
  if figure_height is not None:
@@ -3,10 +3,12 @@ import inspect
3
3
  from typing import Union, Tuple, get_args, get_origin
4
4
  from enum import Enum
5
5
  import logging
6
+ from dash import no_update
6
7
  from dash.development.base_component import Component
7
8
 
8
9
  # Local Imports
9
10
  from workbench.web_interface.components.component_interface import ComponentInterface
11
+ from workbench.utils.theme_manager import ThemeManager
10
12
 
11
13
  log = logging.getLogger("workbench")
12
14
 
@@ -48,6 +50,9 @@ class PluginInterface(ComponentInterface):
48
50
  - The 'register_internal_callbacks' method is optional
49
51
  """
50
52
 
53
+ # Shared ThemeManager instance for all plugins
54
+ theme_manager = ThemeManager()
55
+
51
56
  @abstractmethod
52
57
  def create_component(self, component_id: str) -> Component:
53
58
  """Create a Dash Component without any data.
@@ -83,6 +88,27 @@ class PluginInterface(ComponentInterface):
83
88
  """
84
89
  pass
85
90
 
91
def set_theme(self, theme: str) -> list:
    """Hook invoked after the application theme has been switched.

    The base implementation changes nothing: it answers with one ``no_update`` per entry
    in ``self.properties``. Plugins whose rendering depends on the theme (for example,
    figures built with colorscales) should override this and rebuild their outputs.

    Args:
        theme (str): The newly selected theme name (e.g., "light", "dark", "midnight_blue").

    Returns:
        list: Values laid out like update_properties() output, or all ``no_update``.

    Example:
        def set_theme(self, theme: str) -> list:
            if self.model is None:
                return [no_update] * len(self.properties)
            return self.update_properties(self.model)
    """
    # One placeholder per declared output property, so Dash leaves every output untouched
    return [no_update for _ in self.properties]
111
+
86
112
  #
87
113
  # Internal Methods: These methods are used to validate the plugin interface at runtime
88
114
  #
@@ -96,18 +96,11 @@ class AGTable(PluginInterface):
96
96
 
97
97
  if __name__ == "__main__":
98
98
  # Run the Unit Test for the Plugin
99
+ from workbench.api import Meta
99
100
  from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
100
101
 
101
- # Test data
102
- data = {
103
- "ID": [f"id_{i}" for i in range(10)],
104
- "feat1": [1.0, 1.0, 1.1, 3.0, 4.0, 1.0, 1.0, 1.1, 3.0, 4.0],
105
- "feat2": [1.0, 1.0, 1.1, 3.0, 4.0, 1.0, 1.0, 1.1, 3.0, 4.0],
106
- "feat3": [0.1, 0.15, 0.2, 0.9, 2.8, 0.25, 0.35, 0.4, 1.6, 2.5],
107
- "price": [31, 60, 62, 40, 20, 31, 61, 60, 40, 20],
108
- "name": ["A", "B", "C", "D", "E", "F", "G", "H", "I", "Z" * 55],
109
- }
110
- test_df = pd.DataFrame(data)
102
+ # Test on model data (presumably needs access to the Workbench metadata backend via Meta, unlike the removed inline test_df — confirm)
103
+ models_df = Meta().models(details=True)
111
104
 
112
105
  # Run the Unit Test on the Plugin
113
- PluginUnitTest(AGTable, theme="dark", input_data=test_df, max_height=500).run()
106
+ PluginUnitTest(AGTable, theme="dark", input_data=models_df, max_height=500).run()
@@ -1,11 +1,10 @@
1
1
  """A confusion matrix plugin component"""
2
2
 
3
- from dash import dcc, callback, Output, Input, State
3
+ from dash import dcc, callback, Output, Input, State, no_update
4
4
  import plotly.graph_objects as go
5
5
 
6
6
  # Workbench Imports
7
7
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
8
- from workbench.utils.theme_manager import ThemeManager
9
8
  from workbench.cached.cached_model import CachedModel
10
9
  from workbench.utils.color_utils import add_alpha_to_first_color
11
10
 
@@ -19,10 +18,8 @@ class ConfusionMatrix(PluginInterface):
19
18
  def __init__(self):
20
19
  """Initialize the ConfusionMatrix plugin class"""
21
20
  self.component_id = None
22
- self.current_highlight = None # Store the currently highlighted cell
23
- self.theme_manager = ThemeManager()
24
-
25
- # Call the parent class constructor
21
+ self.model = None # Store the model for re-rendering on theme change
22
+ self.inference_run = None # Store the inference run for re-rendering; NOTE(review): current_highlight was removed in this change — confirm no other method of this class still reads it
26
23
  super().__init__()
27
24
 
28
25
  def create_component(self, component_id: str) -> dcc.Graph:
@@ -57,9 +54,12 @@ class ConfusionMatrix(PluginInterface):
57
54
  Returns:
58
55
  list: A list containing the updated Plotly figure.
59
56
  """
57
+ # Store for re-rendering on theme change
58
+ self.model = model
59
+ self.inference_run = kwargs.get("inference_run", "auto_inference")
60
+
60
61
  # Retrieve the confusion matrix data
61
- inference_run = kwargs.get("inference_run", "auto_inference")
62
- df = model.confusion_matrix(inference_run)
62
+ df = model.confusion_matrix(self.inference_run)
63
63
  if df is None:
64
64
  return [self.display_text("No Data")]
65
65
 
@@ -137,6 +137,12 @@ class ConfusionMatrix(PluginInterface):
137
137
  # Return the updated figure wrapped in a list
138
138
  return [fig]
139
139
 
140
def set_theme(self, theme: str) -> list:
    """Refresh the confusion-matrix figure after an application theme switch.

    Args:
        theme (str): Name of the newly selected theme. It is not read here; the figure is
            rebuilt so that it can pick up the active theme's styling.

    Returns:
        list: Output of update_properties() for the last model and inference run shown,
        or one ``no_update`` per property when nothing has been rendered yet.
    """
    cached_model = self.model
    if cached_model is not None:
        # Rebuild from the same inference run that produced the current figure
        return self.update_properties(cached_model, inference_run=self.inference_run)
    return [no_update] * len(self.properties)
145
+
140
146
  def register_internal_callbacks(self):
141
147
  """Register internal callbacks for the plugin."""
142
148
 
@@ -0,0 +1,156 @@
1
+ """A Model Plot plugin that displays the appropriate visualization based on model type.
2
+
3
+ For classifiers: Shows a Confusion Matrix
4
+ For regressors: Shows a Scatter Plot of predictions vs actuals
5
+ """
6
+
7
+ from dash import html, no_update
8
+
9
+ # Workbench Imports
10
+ from workbench.api import ModelType
11
+ from workbench.cached.cached_model import CachedModel
12
+ from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
13
+ from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
14
+ from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
15
+
16
+
17
class ModelPlot(PluginInterface):
    """Model Plot Plugin - switches between ConfusionMatrix and ScatterPlot based on model type.

    create_component() places both sub-plugin components in the layout. update_properties()
    then toggles their container visibility and forwards data only to the visible one:
    classifiers get the ConfusionMatrix, and every other model type gets a ScatterPlot of
    target vs. prediction.
    """

    auto_load_page = PluginPage.NONE
    plugin_input_type = PluginInputType.MODEL

    def __init__(self):
        """Initialize the ModelPlot plugin class"""
        self.component_id = None
        self.model = None  # Last model passed to update_properties (re-used by set_theme)
        self.inference_run = None  # Inference run used for that last render

        # Internal plugins
        self.scatter_plot = ScatterPlot()
        self.confusion_matrix = ConfusionMatrix()

        # Call the parent class constructor
        super().__init__()

    @staticmethod
    def _no_updates(plugin) -> list:
        """Return one no_update placeholder per entry in ``plugin.properties``."""
        return [no_update] * len(plugin.properties)

    def create_component(self, component_id: str) -> html.Div:
        """Create a container with both ScatterPlot and ConfusionMatrix components.

        Args:
            component_id (str): The ID of the web component

        Returns:
            html.Div: Container with both plot types (one hidden based on model type)
        """
        self.component_id = component_id

        # Create internal components
        scatter_component = self.scatter_plot.create_component(f"{component_id}-scatter")
        confusion_component = self.confusion_matrix.create_component(f"{component_id}-confusion")

        # Build properties list: visibility styles + scatter props + confusion props
        self.properties = [
            (f"{component_id}-scatter-container", "style"),
            (f"{component_id}-confusion-container", "style"),
        ]
        self.properties.extend(self.scatter_plot.properties)
        self.properties.extend(self.confusion_matrix.properties)

        # Aggregate signals from both plugins
        self.signals = self.scatter_plot.signals + self.confusion_matrix.signals

        # Create container with both components
        # Show scatter plot by default (will display "Waiting for Data..." until model loads)
        return html.Div(
            id=component_id,
            children=[
                html.Div(
                    scatter_component,
                    id=f"{component_id}-scatter-container",
                    style={"display": "block"},
                ),
                html.Div(
                    confusion_component,
                    id=f"{component_id}-confusion-container",
                    style={"display": "none"},
                ),
            ],
        )

    def update_properties(self, model: CachedModel, **kwargs) -> list:
        """Update the plot based on model type.

        Args:
            model (CachedModel): The model to visualize
            **kwargs:
                - inference_run (str): Inference capture name (default: "full_cross_fold")

        Returns:
            list: Property values [scatter_style, confusion_style, ...scatter_props, ...confusion_props]
        """
        # Cache for theme re-rendering
        self.model = model
        self.inference_run = kwargs.get("inference_run", "full_cross_fold")

        # Determine model type and set visibility
        is_classifier = model.model_type == ModelType.CLASSIFIER
        scatter_style = {"display": "none"} if is_classifier else {"display": "block"}
        confusion_style = {"display": "block"} if is_classifier else {"display": "none"}

        if is_classifier:
            # Update ConfusionMatrix, no_update for ScatterPlot
            cm_props = self.confusion_matrix.update_properties(model, inference_run=self.inference_run)
            scatter_props = self._no_updates(self.scatter_plot)
        else:
            # Update ScatterPlot with regression data, no_update for ConfusionMatrix
            scatter_props = self._regression_scatter_props(model)
            cm_props = self._no_updates(self.confusion_matrix)

        return [scatter_style, confusion_style] + scatter_props + cm_props

    def _regression_scatter_props(self, model: CachedModel) -> list:
        """Build ScatterPlot property values (target vs. prediction) for a non-classifier model."""
        df = model.get_inference_predictions(self.inference_run)
        if df is None or df.empty:
            # Replace the figure (first ScatterPlot property, matching ScatterPlot.update_properties'
            # return order) with a "No Data" message so the previous model's plot is not left on
            # screen; the dropdown/option properties are left unchanged.
            no_data_figure = self.display_text("No Data")
            return [no_data_figure] + [no_update] * (len(self.scatter_plot.properties) - 1)

        # Get target column for the x-axis
        target = model.target()
        if isinstance(target, list):
            # Multi-target: prefer the target named in the inference run, else the first one
            # (None for an empty list, the same as a model without a target)
            target = next((t for t in target if t in self.inference_run), target[0] if target else None)

        # Check if "confidence" column exists for coloring
        color_col = "confidence" if "confidence" in df.columns else "prediction"

        return self.scatter_plot.update_properties(
            df,
            x=target,
            y="prediction",
            color=color_col,
            regression_line=True,
        )

    def set_theme(self, theme: str) -> list:
        """Re-render the appropriate plot when the theme changes."""
        if self.model is None:
            return self._no_updates(self)

        # Just call update_properties which will re-render the right plot
        return self.update_properties(self.model, inference_run=self.inference_run)

    def register_internal_callbacks(self):
        """Register internal callbacks for both sub-plugins."""
        self.scatter_plot.register_internal_callbacks()
        self.confusion_matrix.register_internal_callbacks()
144
+
145
+
146
if __name__ == "__main__":
    """Run the Unit Test for the Plugin."""
    from workbench.web_interface.components.plugin_unit_test import PluginUnitTest

    # One run per model type: the classifier should show the Confusion Matrix and the
    # regressor the Scatter Plot. NOTE(review): if run() blocks while serving the test app,
    # the second run only starts after the first one exits.
    for demo_model_name in ("wine-classification", "abalone-regression"):
        demo_model = CachedModel(demo_model_name)
        PluginUnitTest(ModelPlot, input_data=demo_model, theme="dark").run()
@@ -8,7 +8,6 @@ from dash.exceptions import PreventUpdate
8
8
 
9
9
  # Workbench Imports
10
10
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
11
- from workbench.utils.theme_manager import ThemeManager
12
11
  from workbench.utils.plot_utils import prediction_intervals
13
12
  from workbench.utils.chem_utils.vis import molecule_hover_tooltip
14
13
  from workbench.utils.clientside_callbacks import circle_overlay_callback
@@ -37,7 +36,6 @@ class ScatterPlot(PluginInterface):
37
36
  self.hover_columns = []
38
37
  self.df = None
39
38
  self.show_axes = show_axes
40
- self.theme_manager = ThemeManager()
41
39
  self.has_smiles = False # Track if dataframe has smiles column for molecule hover
42
40
  self.smiles_column = None
43
41
  self.id_column = None
@@ -242,6 +240,15 @@ class ScatterPlot(PluginInterface):
242
240
 
243
241
  return [figure, x_options, y_options, color_options, x_default, y_default, color_default, regression_line_value]
244
242
 
243
def set_theme(self, theme: str) -> list:
    """Redraw the scatter plot so it reflects a newly selected theme.

    Args:
        theme (str): Name of the new theme. It is not read here; the plot is rebuilt from
            the stored dataframe.

    Returns:
        list: Output of update_properties() for the stored dataframe, or one ``no_update``
        per property when there is no (or only an empty) dataframe.
    """
    data = self.df
    has_data = data is not None and not data.empty
    if has_data:
        # Rebuild using default selections: any dropdown choices are reset, which is
        # acceptable because theme switches are infrequent
        return self.update_properties(data)
    return [no_update] * len(self.properties)
251
+
245
252
  def create_scatter_plot(
246
253
  self,
247
254
  df: pd.DataFrame,
@@ -1,6 +1,6 @@
1
1
  """SHAP Summary Plot visualization component for XGBoost models"""
2
2
 
3
- from dash import dcc
3
+ from dash import dcc, no_update
4
4
  import plotly.graph_objects as go
5
5
  import pandas as pd
6
6
  import numpy as np
@@ -9,7 +9,6 @@ from typing import Dict, List
9
9
  # Workbench Imports
10
10
  from workbench.cached.cached_model import CachedModel
11
11
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
12
- from workbench.utils.theme_manager import ThemeManager
13
12
  from workbench.utils.plot_utils import beeswarm_offsets
14
13
 
15
14
 
@@ -22,7 +21,7 @@ class ShapSummaryPlot(PluginInterface):
22
21
  def __init__(self):
23
22
  """Initialize the ShapSummaryPlot plugin class"""
24
23
  self.component_id = None
25
- self.theme_manager = ThemeManager()
24
+ self.model = None # Store the model for re-rendering on theme change (assigned in update_properties, read by set_theme)
26
25
  super().__init__()
27
26
 
28
27
  def create_component(self, component_id: str) -> dcc.Graph:
@@ -39,6 +38,9 @@ class ShapSummaryPlot(PluginInterface):
39
38
 
40
39
  def update_properties(self, model: CachedModel, **kwargs) -> list:
41
40
  """Create a SHAP Summary Plot for feature importance visualization."""
41
+ # Store for re-rendering on theme change
42
+ self.model = model
43
+
42
44
  # Basic validation
43
45
  shap_data = model.shap_data()
44
46
  shap_sample_rows = model.shap_sample()
@@ -221,9 +223,15 @@ class ShapSummaryPlot(PluginInterface):
221
223
  )
222
224
  return main_fig
223
225
 
226
def set_theme(self, theme: str) -> list:
    """Rebuild the SHAP summary figure for a newly selected theme.

    Args:
        theme (str): Name of the new theme. It is not read here; the figure is regenerated.

    Returns:
        list: Output of update_properties() for the last model shown, or one ``no_update``
        per property when no model has been rendered yet.
    """
    last_model = self.model
    if last_model is not None:
        return self.update_properties(last_model)
    return [no_update] * len(self.properties)
231
+
224
232
  def register_internal_callbacks(self):
225
233
  """Register internal callbacks for the plugin."""
226
- pass # Implement if needed
234
+ pass # No internal callbacks needed; theme changes are handled by set_theme()
227
235
 
228
236
 
229
237
  if __name__ == "__main__":
@@ -100,74 +100,35 @@ class SettingsMenu:
100
100
  caret=False,
101
101
  align_end=True,
102
102
  ),
103
- # Dummy store for the clientside callback output
104
- dcc.Store(id=f"{component_id}-dummy", data=None),
105
- # Store to trigger checkmark update on load
103
+ # Store to trigger checkmark update on load and theme change
106
104
  dcc.Store(id=f"{component_id}-init", data=True),
107
105
  ],
108
106
  id=component_id,
109
107
  )
110
108
 
111
- @staticmethod
112
- def get_clientside_callback_code(component_id: str) -> str:
109
+ def get_clientside_callback_code(self) -> str:  # NOTE(review): previously a @staticmethod taking component_id — callers must now use an instance and drop that argument
113
110
  """Get the JavaScript code for the theme selection clientside callback.
114
111
 
115
- Args:
116
- component_id (str): The ID prefix used in create_component.
117
-
118
112
  Returns:
119
113
  str: JavaScript code for the clientside callback.
120
114
  """
121
- return """
122
- function(n_clicks_list, ids) {
123
- // Find which button was clicked
124
- if (!n_clicks_list || n_clicks_list.every(n => !n)) {
125
- return window.dash_clientside.no_update;
126
- }
127
-
128
- // Find the clicked theme
129
- let clickedTheme = null;
130
- for (let i = 0; i < n_clicks_list.length; i++) {
131
- if (n_clicks_list[i]) {
132
- clickedTheme = ids[i].theme;
133
- break;
134
- }
135
- }
136
-
137
- if (clickedTheme) {
138
- // Store in localStorage
139
- localStorage.setItem('wb_theme', clickedTheme);
140
- // Set cookie for Flask to read on reload
141
- document.cookie = `wb_theme=${clickedTheme}; path=/; max-age=31536000`;
142
- // Reload the page to apply the new theme
143
- window.location.reload();
144
- }
145
-
146
- return window.dash_clientside.no_update;
147
- }
148
- """
115
+ return self.tm.get_theme_switch_js()  # self.tm is assigned outside this view — presumably a ThemeManager; confirm
149
116
 
150
117
  @staticmethod
151
118
  def get_checkmark_callback_code() -> str:
152
- """Get the JavaScript code to update checkmarks based on localStorage.
119
+ """Get the JavaScript code to update checkmarks based on current theme.
153
120
 
154
121
  Returns:
155
122
  str: JavaScript code for the checkmark update callback.
156
123
  """
157
124
  return """
158
- function(init, ids) {
159
- // Get current theme from localStorage (or cookie as fallback)
160
- let currentTheme = localStorage.getItem('wb_theme');
125
+ function(theme, ids) {
126
+ // If theme is a string (from theme switch), use it directly
127
+ let currentTheme = (typeof theme === 'string') ? theme : null;
128
+
129
+ // Otherwise, get from localStorage
161
130
  if (!currentTheme) {
162
- // Try to read from cookie
163
- const cookies = document.cookie.split(';');
164
- for (let cookie of cookies) {
165
- const [name, value] = cookie.trim().split('=');
166
- if (name === 'wb_theme') {
167
- currentTheme = value;
168
- break;
169
- }
170
- }
131
+ currentTheme = localStorage.getItem('wb_theme');
171
132
  }
172
133
 
173
134
  // Return checkmarks for each theme
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: workbench
3
- Version: 0.8.234
3
+ Version: 0.8.239
4
4
  Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
5
5
  Author-email: SuperCowPowers LLC <support@supercowpowers.com>
6
6
  License: MIT License
@@ -40,7 +40,7 @@ Requires-Dist: boto3>=1.31.76
40
40
  Requires-Dist: botocore>=1.31.76
41
41
  Requires-Dist: redis>=5.0.1
42
42
  Requires-Dist: numpy>=1.26.4
43
- Requires-Dist: pandas>=2.2.1
43
+ Requires-Dist: pandas<3.0,>=2.2.1
44
44
  Requires-Dist: awswrangler>=3.4.0
45
45
  Requires-Dist: sagemaker<3.0,>=2.143
46
46
  Requires-Dist: cryptography>=44.0.2