workbench-0.8.234-py3-none-any.whl → workbench-0.8.236-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. workbench/algorithms/dataframe/smart_aggregator.py +17 -12
  2. workbench/api/endpoint.py +13 -4
  3. workbench/cached/cached_model.py +2 -2
  4. workbench/core/artifacts/endpoint_core.py +30 -5
  5. workbench/model_script_utils/model_script_utils.py +225 -0
  6. workbench/model_script_utils/uq_harness.py +39 -21
  7. workbench/model_scripts/chemprop/chemprop.template +29 -14
  8. workbench/model_scripts/chemprop/generated_model_script.py +35 -18
  9. workbench/model_scripts/chemprop/model_script_utils.py +225 -0
  10. workbench/model_scripts/pytorch_model/generated_model_script.py +34 -20
  11. workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
  12. workbench/model_scripts/pytorch_model/pytorch.template +28 -14
  13. workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
  14. workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
  15. workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
  16. workbench/model_scripts/xgb_model/uq_harness.py +39 -21
  17. workbench/model_scripts/xgb_model/xgb_model.template +29 -18
  18. workbench/themes/dark/custom.css +29 -0
  19. workbench/themes/light/custom.css +29 -0
  20. workbench/themes/midnight_blue/custom.css +28 -0
  21. workbench/utils/model_utils.py +9 -0
  22. workbench/utils/theme_manager.py +95 -0
  23. workbench/web_interface/components/component_interface.py +3 -0
  24. workbench/web_interface/components/plugin_interface.py +26 -0
  25. workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
  26. workbench/web_interface/components/plugins/model_plot.py +156 -0
  27. workbench/web_interface/components/plugins/scatter_plot.py +9 -2
  28. workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
  29. workbench/web_interface/components/settings_menu.py +10 -49
  30. {workbench-0.8.234.dist-info → workbench-0.8.236.dist-info}/METADATA +1 -1
  31. {workbench-0.8.234.dist-info → workbench-0.8.236.dist-info}/RECORD +35 -35
  32. workbench/web_interface/components/model_plot.py +0 -75
  33. {workbench-0.8.234.dist-info → workbench-0.8.236.dist-info}/WHEEL +0 -0
  34. {workbench-0.8.234.dist-info → workbench-0.8.236.dist-info}/entry_points.txt +0 -0
  35. {workbench-0.8.234.dist-info → workbench-0.8.236.dist-info}/licenses/LICENSE +0 -0
  36. {workbench-0.8.234.dist-info → workbench-0.8.236.dist-info}/top_level.txt +0 -0
@@ -135,6 +135,35 @@ div:has(> [class*="ag-theme-"]) {
135
135
  --bs-border-color: rgb(60, 60, 60);
136
136
  }
137
137
 
138
+ /* React-select dropdown styling - target actual rendered elements */
139
+ .Select-control {
140
+ background-color: rgb(35, 35, 35) !important;
141
+ border-color: rgb(60, 60, 60) !important;
142
+ }
143
+
144
+ .Select-value-label, .Select-input input {
145
+ color: rgb(210, 210, 210) !important;
146
+ }
147
+
148
+ .Select-placeholder {
149
+ color: rgb(150, 150, 150) !important;
150
+ }
151
+
152
+ .Select-menu-outer {
153
+ background-color: rgb(35, 35, 35) !important;
154
+ border-color: rgb(60, 60, 60) !important;
155
+ }
156
+
157
+ .VirtualizedSelectOption {
158
+ background-color: rgb(35, 35, 35) !important;
159
+ color: rgb(210, 210, 210) !important;
160
+ }
161
+
162
+ .VirtualizedSelectFocusedOption {
163
+ background-color: rgb(60, 60, 60) !important;
164
+ color: rgb(230, 230, 230) !important;
165
+ }
166
+
138
167
  /* Bootstrap form controls (dbc components) */
139
168
  .form-select, .form-control {
140
169
  background-color: rgb(35, 35, 35) !important;
@@ -180,6 +180,35 @@ div:has(> [class*="ag-theme-"]) {
180
180
  --bs-border-color: var(--wb-accent);
181
181
  }
182
182
 
183
+ /* React-select dropdown styling - target actual rendered elements */
184
+ .Select-control {
185
+ background-color: var(--wb-dropdown-bg) !important;
186
+ border-color: var(--wb-accent) !important;
187
+ }
188
+
189
+ .Select-value-label, .Select-input input {
190
+ color: var(--wb-text-primary) !important;
191
+ }
192
+
193
+ .Select-placeholder {
194
+ color: var(--wb-text-muted) !important;
195
+ }
196
+
197
+ .Select-menu-outer {
198
+ background-color: var(--wb-dropdown-bg) !important;
199
+ border-color: var(--wb-accent) !important;
200
+ }
201
+
202
+ .VirtualizedSelectOption {
203
+ background-color: var(--wb-dropdown-bg) !important;
204
+ color: var(--wb-text-primary) !important;
205
+ }
206
+
207
+ .VirtualizedSelectFocusedOption {
208
+ background-color: var(--wb-dropdown-hover) !important;
209
+ color: var(--wb-text-primary) !important;
210
+ }
211
+
183
212
  /* Bootstrap form controls (dbc components) */
184
213
  .form-select, .form-control {
185
214
  background-color: var(--wb-dropdown-bg) !important;
@@ -133,6 +133,34 @@ div:has(> [class*="ag-theme-"]) {
133
133
  --bs-border-color: rgb(80, 85, 115);
134
134
  }
135
135
 
136
+ /* React-select dropdown styling - target actual rendered elements */
137
+ .Select-control {
138
+ background-color: rgb(55, 60, 90) !important;
139
+ border-color: rgb(80, 85, 115) !important;
140
+ }
141
+
142
+ .Select-value-label, .Select-input input {
143
+ color: rgb(210, 210, 210) !important;
144
+ }
145
+
146
+ .Select-placeholder {
147
+ color: rgb(150, 150, 170) !important;
148
+ }
149
+
150
+ .Select-menu-outer {
151
+ background-color: rgb(55, 60, 90) !important;
152
+ border-color: rgb(80, 85, 115) !important;
153
+ }
154
+
155
+ .VirtualizedSelectOption {
156
+ background-color: rgb(55, 60, 90) !important;
157
+ color: rgb(210, 210, 210) !important;
158
+ }
159
+
160
+ .VirtualizedSelectFocusedOption {
161
+ background-color: rgb(70, 75, 110) !important;
162
+ color: rgb(230, 230, 230) !important;
163
+ }
136
164
 
137
165
  /* Bootstrap form controls (dbc components) */
138
166
  .form-select, .form-control {
@@ -459,6 +459,12 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
459
459
  # Spearman correlation for robustness
460
460
  interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]
461
461
 
462
+ # --- Confidence to Error Correlation ---
463
+ # If confidence column exists, compute correlation (should be negative: high confidence = low error)
464
+ confidence_to_error_corr = None
465
+ if "confidence" in df.columns:
466
+ confidence_to_error_corr = spearmanr(df["confidence"], abs_residuals)[0]
467
+
462
468
  # Collect results
463
469
  results = {
464
470
  "coverage_68": coverage_68,
@@ -472,6 +478,7 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
472
478
  "median_width_90": median_width_90,
473
479
  "median_width_95": median_width_95,
474
480
  "interval_to_error_corr": interval_to_error_corr,
481
+ "confidence_to_error_corr": confidence_to_error_corr,
475
482
  "n_samples": len(df),
476
483
  }
477
484
 
@@ -489,6 +496,8 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
489
496
  print(f"CRPS: {mean_crps:.3f} (lower is better)")
490
497
  print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
491
498
  print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
499
+ if confidence_to_error_corr is not None:
500
+ print(f"Confidence/Error Corr: {confidence_to_error_corr:.3f} (lower is better, target: <-0.5)")
492
501
  print(f"Samples: {len(df)}")
493
502
  return results
494
503
 
@@ -155,6 +155,101 @@ class ThemeManager:
155
155
  """Get the name of the current theme."""
156
156
  return cls.current_theme_name
157
157
 
158
+ @classmethod
159
+ def get_theme_urls(cls) -> dict:
160
+ """Get a mapping of theme names to their Bootstrap CSS URLs."""
161
+ return {name: theme["base_css"] for name, theme in cls.available_themes.items() if theme["base_css"]}
162
+
163
+ @classmethod
164
+ def get_dark_themes(cls) -> list:
165
+ """Get list of theme names that are dark mode themes."""
166
+ dark_themes = []
167
+ for name, theme in cls.available_themes.items():
168
+ base_css = theme.get("base_css", "")
169
+ if base_css:
170
+ base_css_upper = base_css.upper()
171
+ for dark_theme in cls._dark_bootstrap_themes:
172
+ if dark_theme in base_css_upper:
173
+ dark_themes.append(name)
174
+ break
175
+ else:
176
+ # Fallback: check if 'dark' is in the theme name
177
+ if "dark" in name.lower():
178
+ dark_themes.append(name)
179
+ return dark_themes
180
+
181
+ @classmethod
182
+ def get_theme_switch_js(cls) -> str:
183
+ """Get the JavaScript code for dynamic theme switching.
184
+
185
+ This code should be used in a clientside callback that updates the Bootstrap
186
+ stylesheet and data-bs-theme attribute when the user selects a new theme.
187
+ """
188
+ theme_urls = cls.get_theme_urls()
189
+ dark_themes = cls.get_dark_themes()
190
+
191
+ return f"""
192
+ function(n_clicks_list, ids) {{
193
+ // Use callback context to find which input triggered this callback
194
+ const ctx = window.dash_clientside.callback_context;
195
+ if (!ctx || !ctx.triggered || ctx.triggered.length === 0) {{
196
+ return window.dash_clientside.no_update;
197
+ }}
198
+
199
+ // Get the triggered input's id (it's a JSON string for pattern-matching callbacks)
200
+ const triggeredId = ctx.triggered[0].prop_id.split('.')[0];
201
+ let clickedTheme = null;
202
+
203
+ try {{
204
+ const parsedId = JSON.parse(triggeredId);
205
+ clickedTheme = parsedId.theme;
206
+ }} catch (e) {{
207
+ // Not a pattern-matching callback ID
208
+ return window.dash_clientside.no_update;
209
+ }}
210
+
211
+ if (!clickedTheme) {{
212
+ return window.dash_clientside.no_update;
213
+ }}
214
+
215
+ // Store in localStorage and cookie
216
+ localStorage.setItem('wb_theme', clickedTheme);
217
+ document.cookie = `wb_theme=${{clickedTheme}}; path=/; max-age=31536000`;
218
+
219
+ // Theme URL mapping (generated from ThemeManager)
220
+ const themeUrls = {json.dumps(theme_urls)};
221
+ const darkThemes = {json.dumps(dark_themes)};
222
+
223
+ // Update Bootstrap's data-bs-theme attribute
224
+ const bsTheme = darkThemes.includes(clickedTheme) ? 'dark' : 'light';
225
+ document.documentElement.setAttribute('data-bs-theme', bsTheme);
226
+
227
+ // Find and update the Bootstrap stylesheet
228
+ const newUrl = themeUrls[clickedTheme];
229
+ if (newUrl) {{
230
+ let stylesheet = document.querySelector('link[href*="bootswatch"]') ||
231
+ document.querySelector('link[href*="bootstrap"]') ||
232
+ document.querySelector('link[href="/base.css"]');
233
+ if (stylesheet) {{
234
+ stylesheet.setAttribute('href', newUrl);
235
+ }} else {{
236
+ stylesheet = document.createElement('link');
237
+ stylesheet.rel = 'stylesheet';
238
+ stylesheet.href = newUrl;
239
+ document.head.appendChild(stylesheet);
240
+ }}
241
+ }}
242
+
243
+ // Force reload custom.css with cache-busting query param
244
+ let customCss = document.querySelector('link[href^="/custom.css"]');
245
+ if (customCss) {{
246
+ customCss.setAttribute('href', `/custom.css?t=${{Date.now()}}`);
247
+ }}
248
+
249
+ return clickedTheme;
250
+ }}
251
+ """
252
+
158
253
  @classmethod
159
254
  def background(cls) -> list[list[float | str]]:
160
255
  """Get the plot background for the current theme."""
@@ -136,6 +136,7 @@ class ComponentInterface(ABC):
136
136
  xref="paper",
137
137
  yref="paper",
138
138
  text=text_message,
139
+ bgcolor="rgba(0,0,0,0)",
139
140
  showarrow=False,
140
141
  font=dict(size=font_size, color="#9999cc"),
141
142
  )
@@ -144,6 +145,8 @@ class ComponentInterface(ABC):
144
145
  xaxis=dict(showticklabels=False, zeroline=False, showgrid=False),
145
146
  yaxis=dict(showticklabels=False, zeroline=False, showgrid=False),
146
147
  margin=dict(l=0, r=0, b=0, t=0),
148
+ paper_bgcolor="rgba(0,0,0,0)",
149
+ plot_bgcolor="rgba(0,0,0,0)",
147
150
  )
148
151
 
149
152
  if figure_height is not None:
@@ -3,10 +3,12 @@ import inspect
3
3
  from typing import Union, Tuple, get_args, get_origin
4
4
  from enum import Enum
5
5
  import logging
6
+ from dash import no_update
6
7
  from dash.development.base_component import Component
7
8
 
8
9
  # Local Imports
9
10
  from workbench.web_interface.components.component_interface import ComponentInterface
11
+ from workbench.utils.theme_manager import ThemeManager
10
12
 
11
13
  log = logging.getLogger("workbench")
12
14
 
@@ -48,6 +50,9 @@ class PluginInterface(ComponentInterface):
48
50
  - The 'register_internal_callbacks' method is optional
49
51
  """
50
52
 
53
+ # Shared ThemeManager instance for all plugins
54
+ theme_manager = ThemeManager()
55
+
51
56
  @abstractmethod
52
57
  def create_component(self, component_id: str) -> Component:
53
58
  """Create a Dash Component without any data.
@@ -83,6 +88,27 @@ class PluginInterface(ComponentInterface):
83
88
  """
84
89
  pass
85
90
 
91
+ def set_theme(self, theme: str) -> list:
92
+ """Called when the application theme changes. Override to re-render with new theme colors.
93
+
94
+ The default implementation returns no_update for all properties. Plugins that have
95
+ theme-dependent rendering (e.g., figures with colorscales) should override this method
96
+ to re-render their components.
97
+
98
+ Args:
99
+ theme (str): The name of the new theme (e.g., "light", "dark", "midnight_blue").
100
+
101
+ Returns:
102
+ list: Updated property values (same format as update_properties), or no_update list.
103
+
104
+ Example:
105
+ def set_theme(self, theme: str) -> list:
106
+ if self.model is None:
107
+ return [no_update] * len(self.properties)
108
+ return self.update_properties(self.model)
109
+ """
110
+ return [no_update] * len(self.properties)
111
+
86
112
  #
87
113
  # Internal Methods: These methods are used to validate the plugin interface at runtime
88
114
  #
@@ -1,11 +1,10 @@
1
1
  """A confusion matrix plugin component"""
2
2
 
3
- from dash import dcc, callback, Output, Input, State
3
+ from dash import dcc, callback, Output, Input, State, no_update
4
4
  import plotly.graph_objects as go
5
5
 
6
6
  # Workbench Imports
7
7
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
8
- from workbench.utils.theme_manager import ThemeManager
9
8
  from workbench.cached.cached_model import CachedModel
10
9
  from workbench.utils.color_utils import add_alpha_to_first_color
11
10
 
@@ -19,10 +18,8 @@ class ConfusionMatrix(PluginInterface):
19
18
  def __init__(self):
20
19
  """Initialize the ConfusionMatrix plugin class"""
21
20
  self.component_id = None
22
- self.current_highlight = None # Store the currently highlighted cell
23
- self.theme_manager = ThemeManager()
24
-
25
- # Call the parent class constructor
21
+ self.model = None # Store the model for re-rendering on theme change
22
+ self.inference_run = None # Store the inference run for re-rendering
26
23
  super().__init__()
27
24
 
28
25
  def create_component(self, component_id: str) -> dcc.Graph:
@@ -57,9 +54,12 @@ class ConfusionMatrix(PluginInterface):
57
54
  Returns:
58
55
  list: A list containing the updated Plotly figure.
59
56
  """
57
+ # Store for re-rendering on theme change
58
+ self.model = model
59
+ self.inference_run = kwargs.get("inference_run", "auto_inference")
60
+
60
61
  # Retrieve the confusion matrix data
61
- inference_run = kwargs.get("inference_run", "auto_inference")
62
- df = model.confusion_matrix(inference_run)
62
+ df = model.confusion_matrix(self.inference_run)
63
63
  if df is None:
64
64
  return [self.display_text("No Data")]
65
65
 
@@ -137,6 +137,12 @@ class ConfusionMatrix(PluginInterface):
137
137
  # Return the updated figure wrapped in a list
138
138
  return [fig]
139
139
 
140
+ def set_theme(self, theme: str) -> list:
141
+ """Re-render the confusion matrix when the theme changes."""
142
+ if self.model is None:
143
+ return [no_update] * len(self.properties)
144
+ return self.update_properties(self.model, inference_run=self.inference_run)
145
+
140
146
  def register_internal_callbacks(self):
141
147
  """Register internal callbacks for the plugin."""
142
148
 
@@ -0,0 +1,156 @@
1
+ """A Model Plot plugin that displays the appropriate visualization based on model type.
2
+
3
+ For classifiers: Shows a Confusion Matrix
4
+ For regressors: Shows a Scatter Plot of predictions vs actuals
5
+ """
6
+
7
+ from dash import html, no_update
8
+
9
+ # Workbench Imports
10
+ from workbench.api import ModelType
11
+ from workbench.cached.cached_model import CachedModel
12
+ from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
13
+ from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
14
+ from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
15
+
16
+
17
+ class ModelPlot(PluginInterface):
18
+ """Model Plot Plugin - switches between ConfusionMatrix and ScatterPlot based on model type."""
19
+
20
+ auto_load_page = PluginPage.NONE
21
+ plugin_input_type = PluginInputType.MODEL
22
+
23
+ def __init__(self):
24
+ """Initialize the ModelPlot plugin class"""
25
+ self.component_id = None
26
+ self.model = None
27
+ self.inference_run = None
28
+
29
+ # Internal plugins
30
+ self.scatter_plot = ScatterPlot()
31
+ self.confusion_matrix = ConfusionMatrix()
32
+
33
+ # Call the parent class constructor
34
+ super().__init__()
35
+
36
+ def create_component(self, component_id: str) -> html.Div:
37
+ """Create a container with both ScatterPlot and ConfusionMatrix components.
38
+
39
+ Args:
40
+ component_id (str): The ID of the web component
41
+
42
+ Returns:
43
+ html.Div: Container with both plot types (one hidden based on model type)
44
+ """
45
+ self.component_id = component_id
46
+
47
+ # Create internal components
48
+ scatter_component = self.scatter_plot.create_component(f"{component_id}-scatter")
49
+ confusion_component = self.confusion_matrix.create_component(f"{component_id}-confusion")
50
+
51
+ # Build properties list: visibility styles + scatter props + confusion props
52
+ self.properties = [
53
+ (f"{component_id}-scatter-container", "style"),
54
+ (f"{component_id}-confusion-container", "style"),
55
+ ]
56
+ self.properties.extend(self.scatter_plot.properties)
57
+ self.properties.extend(self.confusion_matrix.properties)
58
+
59
+ # Aggregate signals from both plugins
60
+ self.signals = self.scatter_plot.signals + self.confusion_matrix.signals
61
+
62
+ # Create container with both components
63
+ # Show scatter plot by default (will display "Waiting for Data..." until model loads)
64
+ return html.Div(
65
+ id=component_id,
66
+ children=[
67
+ html.Div(
68
+ scatter_component,
69
+ id=f"{component_id}-scatter-container",
70
+ style={"display": "block"},
71
+ ),
72
+ html.Div(
73
+ confusion_component,
74
+ id=f"{component_id}-confusion-container",
75
+ style={"display": "none"},
76
+ ),
77
+ ],
78
+ )
79
+
80
+ def update_properties(self, model: CachedModel, **kwargs) -> list:
81
+ """Update the plot based on model type.
82
+
83
+ Args:
84
+ model (CachedModel): The model to visualize
85
+ **kwargs:
86
+ - inference_run (str): Inference capture name (default: "auto_inference")
87
+
88
+ Returns:
89
+ list: Property values [scatter_style, confusion_style, ...scatter_props, ...confusion_props]
90
+ """
91
+ # Cache for theme re-rendering
92
+ self.model = model
93
+ self.inference_run = kwargs.get("inference_run", "full_cross_fold")
94
+
95
+ # Determine model type and set visibility
96
+ is_classifier = model.model_type == ModelType.CLASSIFIER
97
+ scatter_style = {"display": "none"} if is_classifier else {"display": "block"}
98
+ confusion_style = {"display": "block"} if is_classifier else {"display": "none"}
99
+
100
+ if is_classifier:
101
+ # Update ConfusionMatrix, no_update for ScatterPlot
102
+ cm_props = self.confusion_matrix.update_properties(model, inference_run=self.inference_run)
103
+ scatter_props = [no_update] * len(self.scatter_plot.properties)
104
+ else:
105
+ # Update ScatterPlot with regression data
106
+ df = model.get_inference_predictions(self.inference_run)
107
+ if df is None:
108
+ # Still update visibility styles, but no_update for plugin properties
109
+ scatter_props = [no_update] * len(self.scatter_plot.properties)
110
+ cm_props = [no_update] * len(self.confusion_matrix.properties)
111
+ return [scatter_style, confusion_style] + scatter_props + cm_props
112
+
113
+ # Get target column for the x-axis
114
+ target = model.target()
115
+ if isinstance(target, list):
116
+ target = next((t for t in target if t in self.inference_run), target[0])
117
+
118
+ # Check if "confidence" column exists for coloring
119
+ color_col = "confidence" if "confidence" in df.columns else "prediction"
120
+
121
+ scatter_props = self.scatter_plot.update_properties(
122
+ df,
123
+ x=target,
124
+ y="prediction",
125
+ color=color_col,
126
+ regression_line=True,
127
+ )
128
+ cm_props = [no_update] * len(self.confusion_matrix.properties)
129
+
130
+ return [scatter_style, confusion_style] + scatter_props + cm_props
131
+
132
+ def set_theme(self, theme: str) -> list:
133
+ """Re-render the appropriate plot when the theme changes."""
134
+ if self.model is None:
135
+ return [no_update] * len(self.properties)
136
+
137
+ # Just call update_properties which will re-render the right plot
138
+ return self.update_properties(self.model, inference_run=self.inference_run)
139
+
140
+ def register_internal_callbacks(self):
141
+ """Register internal callbacks for both sub-plugins."""
142
+ self.scatter_plot.register_internal_callbacks()
143
+ self.confusion_matrix.register_internal_callbacks()
144
+
145
+
146
+ if __name__ == "__main__":
147
+ """Run the Unit Test for the Plugin."""
148
+ from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
149
+
150
+ # Test with a classifier (shows Confusion Matrix)
151
+ classifier_model = CachedModel("wine-classification")
152
+ PluginUnitTest(ModelPlot, input_data=classifier_model, theme="dark").run()
153
+
154
+ # Test with a regressor (shows Scatter Plot)
155
+ regressor_model = CachedModel("abalone-regression")
156
+ PluginUnitTest(ModelPlot, input_data=regressor_model, theme="dark").run()
@@ -8,7 +8,6 @@ from dash.exceptions import PreventUpdate
8
8
 
9
9
  # Workbench Imports
10
10
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
11
- from workbench.utils.theme_manager import ThemeManager
12
11
  from workbench.utils.plot_utils import prediction_intervals
13
12
  from workbench.utils.chem_utils.vis import molecule_hover_tooltip
14
13
  from workbench.utils.clientside_callbacks import circle_overlay_callback
@@ -37,7 +36,6 @@ class ScatterPlot(PluginInterface):
37
36
  self.hover_columns = []
38
37
  self.df = None
39
38
  self.show_axes = show_axes
40
- self.theme_manager = ThemeManager()
41
39
  self.has_smiles = False # Track if dataframe has smiles column for molecule hover
42
40
  self.smiles_column = None
43
41
  self.id_column = None
@@ -242,6 +240,15 @@ class ScatterPlot(PluginInterface):
242
240
 
243
241
  return [figure, x_options, y_options, color_options, x_default, y_default, color_default, regression_line_value]
244
242
 
243
+ def set_theme(self, theme: str) -> list:
244
+ """Re-render the scatter plot when the theme changes."""
245
+ # If no data yet, return no_update for all properties
246
+ if self.df is None or self.df.empty:
247
+ return [no_update] * len(self.properties)
248
+
249
+ # Re-render with defaults (user dropdown selections reset, but theme changes are rare)
250
+ return self.update_properties(self.df)
251
+
245
252
  def create_scatter_plot(
246
253
  self,
247
254
  df: pd.DataFrame,
@@ -1,6 +1,6 @@
1
1
  """SHAP Summary Plot visualization component for XGBoost models"""
2
2
 
3
- from dash import dcc
3
+ from dash import dcc, no_update
4
4
  import plotly.graph_objects as go
5
5
  import pandas as pd
6
6
  import numpy as np
@@ -9,7 +9,6 @@ from typing import Dict, List
9
9
  # Workbench Imports
10
10
  from workbench.cached.cached_model import CachedModel
11
11
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
12
- from workbench.utils.theme_manager import ThemeManager
13
12
  from workbench.utils.plot_utils import beeswarm_offsets
14
13
 
15
14
 
@@ -22,7 +21,7 @@ class ShapSummaryPlot(PluginInterface):
22
21
  def __init__(self):
23
22
  """Initialize the ShapSummaryPlot plugin class"""
24
23
  self.component_id = None
25
- self.theme_manager = ThemeManager()
24
+ self.model = None # Store the model for re-rendering on theme change
26
25
  super().__init__()
27
26
 
28
27
  def create_component(self, component_id: str) -> dcc.Graph:
@@ -39,6 +38,9 @@ class ShapSummaryPlot(PluginInterface):
39
38
 
40
39
  def update_properties(self, model: CachedModel, **kwargs) -> list:
41
40
  """Create a SHAP Summary Plot for feature importance visualization."""
41
+ # Store for re-rendering on theme change
42
+ self.model = model
43
+
42
44
  # Basic validation
43
45
  shap_data = model.shap_data()
44
46
  shap_sample_rows = model.shap_sample()
@@ -221,9 +223,15 @@ class ShapSummaryPlot(PluginInterface):
221
223
  )
222
224
  return main_fig
223
225
 
226
+ def set_theme(self, theme: str) -> list:
227
+ """Re-render the SHAP summary plot when the theme changes."""
228
+ if self.model is None:
229
+ return [no_update] * len(self.properties)
230
+ return self.update_properties(self.model)
231
+
224
232
  def register_internal_callbacks(self):
225
233
  """Register internal callbacks for the plugin."""
226
- pass # Implement if needed
234
+ pass # No internal callbacks needed
227
235
 
228
236
 
229
237
  if __name__ == "__main__":