workbench 0.8.234__py3-none-any.whl → 0.8.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/smart_aggregator.py +17 -12
- workbench/api/endpoint.py +13 -4
- workbench/api/model.py +2 -2
- workbench/cached/cached_model.py +2 -2
- workbench/core/artifacts/athena_source.py +5 -3
- workbench/core/artifacts/endpoint_core.py +30 -5
- workbench/core/cloud_platform/aws/aws_meta.py +2 -1
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +27 -14
- workbench/model_script_utils/model_script_utils.py +225 -0
- workbench/model_script_utils/uq_harness.py +39 -21
- workbench/model_scripts/chemprop/chemprop.template +30 -15
- workbench/model_scripts/chemprop/generated_model_script.py +35 -18
- workbench/model_scripts/chemprop/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +29 -15
- workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/pytorch.template +28 -14
- workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
- workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
- workbench/model_scripts/xgb_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/xgb_model.template +29 -18
- workbench/scripts/ml_pipeline_batch.py +47 -2
- workbench/scripts/ml_pipeline_launcher.py +410 -0
- workbench/scripts/ml_pipeline_sqs.py +22 -2
- workbench/themes/dark/custom.css +29 -0
- workbench/themes/light/custom.css +29 -0
- workbench/themes/midnight_blue/custom.css +28 -0
- workbench/utils/model_utils.py +9 -0
- workbench/utils/theme_manager.py +95 -0
- workbench/web_interface/components/component_interface.py +3 -0
- workbench/web_interface/components/plugin_interface.py +26 -0
- workbench/web_interface/components/plugins/ag_table.py +4 -11
- workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
- workbench/web_interface/components/plugins/model_plot.py +156 -0
- workbench/web_interface/components/plugins/scatter_plot.py +9 -2
- workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
- workbench/web_interface/components/settings_menu.py +10 -49
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/METADATA +2 -2
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/RECORD +43 -42
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/WHEEL +1 -1
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/entry_points.txt +1 -0
- workbench/web_interface/components/model_plot.py +0 -75
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/top_level.txt +0 -0
workbench/utils/theme_manager.py
CHANGED
|
@@ -155,6 +155,101 @@ class ThemeManager:
|
|
|
155
155
|
"""Get the name of the current theme."""
|
|
156
156
|
return cls.current_theme_name
|
|
157
157
|
|
|
158
|
+
@classmethod
|
|
159
|
+
def get_theme_urls(cls) -> dict:
|
|
160
|
+
"""Get a mapping of theme names to their Bootstrap CSS URLs."""
|
|
161
|
+
return {name: theme["base_css"] for name, theme in cls.available_themes.items() if theme["base_css"]}
|
|
162
|
+
|
|
163
|
+
@classmethod
|
|
164
|
+
def get_dark_themes(cls) -> list:
|
|
165
|
+
"""Get list of theme names that are dark mode themes."""
|
|
166
|
+
dark_themes = []
|
|
167
|
+
for name, theme in cls.available_themes.items():
|
|
168
|
+
base_css = theme.get("base_css", "")
|
|
169
|
+
if base_css:
|
|
170
|
+
base_css_upper = base_css.upper()
|
|
171
|
+
for dark_theme in cls._dark_bootstrap_themes:
|
|
172
|
+
if dark_theme in base_css_upper:
|
|
173
|
+
dark_themes.append(name)
|
|
174
|
+
break
|
|
175
|
+
else:
|
|
176
|
+
# Fallback: check if 'dark' is in the theme name
|
|
177
|
+
if "dark" in name.lower():
|
|
178
|
+
dark_themes.append(name)
|
|
179
|
+
return dark_themes
|
|
180
|
+
|
|
181
|
+
@classmethod
|
|
182
|
+
def get_theme_switch_js(cls) -> str:
|
|
183
|
+
"""Get the JavaScript code for dynamic theme switching.
|
|
184
|
+
|
|
185
|
+
This code should be used in a clientside callback that updates the Bootstrap
|
|
186
|
+
stylesheet and data-bs-theme attribute when the user selects a new theme.
|
|
187
|
+
"""
|
|
188
|
+
theme_urls = cls.get_theme_urls()
|
|
189
|
+
dark_themes = cls.get_dark_themes()
|
|
190
|
+
|
|
191
|
+
return f"""
|
|
192
|
+
function(n_clicks_list, ids) {{
|
|
193
|
+
// Use callback context to find which input triggered this callback
|
|
194
|
+
const ctx = window.dash_clientside.callback_context;
|
|
195
|
+
if (!ctx || !ctx.triggered || ctx.triggered.length === 0) {{
|
|
196
|
+
return window.dash_clientside.no_update;
|
|
197
|
+
}}
|
|
198
|
+
|
|
199
|
+
// Get the triggered input's id (it's a JSON string for pattern-matching callbacks)
|
|
200
|
+
const triggeredId = ctx.triggered[0].prop_id.split('.')[0];
|
|
201
|
+
let clickedTheme = null;
|
|
202
|
+
|
|
203
|
+
try {{
|
|
204
|
+
const parsedId = JSON.parse(triggeredId);
|
|
205
|
+
clickedTheme = parsedId.theme;
|
|
206
|
+
}} catch (e) {{
|
|
207
|
+
// Not a pattern-matching callback ID
|
|
208
|
+
return window.dash_clientside.no_update;
|
|
209
|
+
}}
|
|
210
|
+
|
|
211
|
+
if (!clickedTheme) {{
|
|
212
|
+
return window.dash_clientside.no_update;
|
|
213
|
+
}}
|
|
214
|
+
|
|
215
|
+
// Store in localStorage and cookie
|
|
216
|
+
localStorage.setItem('wb_theme', clickedTheme);
|
|
217
|
+
document.cookie = `wb_theme=${{clickedTheme}}; path=/; max-age=31536000`;
|
|
218
|
+
|
|
219
|
+
// Theme URL mapping (generated from ThemeManager)
|
|
220
|
+
const themeUrls = {json.dumps(theme_urls)};
|
|
221
|
+
const darkThemes = {json.dumps(dark_themes)};
|
|
222
|
+
|
|
223
|
+
// Update Bootstrap's data-bs-theme attribute
|
|
224
|
+
const bsTheme = darkThemes.includes(clickedTheme) ? 'dark' : 'light';
|
|
225
|
+
document.documentElement.setAttribute('data-bs-theme', bsTheme);
|
|
226
|
+
|
|
227
|
+
// Find and update the Bootstrap stylesheet
|
|
228
|
+
const newUrl = themeUrls[clickedTheme];
|
|
229
|
+
if (newUrl) {{
|
|
230
|
+
let stylesheet = document.querySelector('link[href*="bootswatch"]') ||
|
|
231
|
+
document.querySelector('link[href*="bootstrap"]') ||
|
|
232
|
+
document.querySelector('link[href="/base.css"]');
|
|
233
|
+
if (stylesheet) {{
|
|
234
|
+
stylesheet.setAttribute('href', newUrl);
|
|
235
|
+
}} else {{
|
|
236
|
+
stylesheet = document.createElement('link');
|
|
237
|
+
stylesheet.rel = 'stylesheet';
|
|
238
|
+
stylesheet.href = newUrl;
|
|
239
|
+
document.head.appendChild(stylesheet);
|
|
240
|
+
}}
|
|
241
|
+
}}
|
|
242
|
+
|
|
243
|
+
// Force reload custom.css with cache-busting query param
|
|
244
|
+
let customCss = document.querySelector('link[href^="/custom.css"]');
|
|
245
|
+
if (customCss) {{
|
|
246
|
+
customCss.setAttribute('href', `/custom.css?t=${{Date.now()}}`);
|
|
247
|
+
}}
|
|
248
|
+
|
|
249
|
+
return clickedTheme;
|
|
250
|
+
}}
|
|
251
|
+
"""
|
|
252
|
+
|
|
158
253
|
@classmethod
|
|
159
254
|
def background(cls) -> list[list[float | str]]:
|
|
160
255
|
"""Get the plot background for the current theme."""
|
|
@@ -136,6 +136,7 @@ class ComponentInterface(ABC):
|
|
|
136
136
|
xref="paper",
|
|
137
137
|
yref="paper",
|
|
138
138
|
text=text_message,
|
|
139
|
+
bgcolor="rgba(0,0,0,0)",
|
|
139
140
|
showarrow=False,
|
|
140
141
|
font=dict(size=font_size, color="#9999cc"),
|
|
141
142
|
)
|
|
@@ -144,6 +145,8 @@ class ComponentInterface(ABC):
|
|
|
144
145
|
xaxis=dict(showticklabels=False, zeroline=False, showgrid=False),
|
|
145
146
|
yaxis=dict(showticklabels=False, zeroline=False, showgrid=False),
|
|
146
147
|
margin=dict(l=0, r=0, b=0, t=0),
|
|
148
|
+
paper_bgcolor="rgba(0,0,0,0)",
|
|
149
|
+
plot_bgcolor="rgba(0,0,0,0)",
|
|
147
150
|
)
|
|
148
151
|
|
|
149
152
|
if figure_height is not None:
|
|
@@ -3,10 +3,12 @@ import inspect
|
|
|
3
3
|
from typing import Union, Tuple, get_args, get_origin
|
|
4
4
|
from enum import Enum
|
|
5
5
|
import logging
|
|
6
|
+
from dash import no_update
|
|
6
7
|
from dash.development.base_component import Component
|
|
7
8
|
|
|
8
9
|
# Local Imports
|
|
9
10
|
from workbench.web_interface.components.component_interface import ComponentInterface
|
|
11
|
+
from workbench.utils.theme_manager import ThemeManager
|
|
10
12
|
|
|
11
13
|
log = logging.getLogger("workbench")
|
|
12
14
|
|
|
@@ -48,6 +50,9 @@ class PluginInterface(ComponentInterface):
|
|
|
48
50
|
- The 'register_internal_callbacks' method is optional
|
|
49
51
|
"""
|
|
50
52
|
|
|
53
|
+
# Shared ThemeManager instance for all plugins
|
|
54
|
+
theme_manager = ThemeManager()
|
|
55
|
+
|
|
51
56
|
@abstractmethod
|
|
52
57
|
def create_component(self, component_id: str) -> Component:
|
|
53
58
|
"""Create a Dash Component without any data.
|
|
@@ -83,6 +88,27 @@ class PluginInterface(ComponentInterface):
|
|
|
83
88
|
"""
|
|
84
89
|
pass
|
|
85
90
|
|
|
91
|
+
def set_theme(self, theme: str) -> list:
|
|
92
|
+
"""Called when the application theme changes. Override to re-render with new theme colors.
|
|
93
|
+
|
|
94
|
+
The default implementation returns no_update for all properties. Plugins that have
|
|
95
|
+
theme-dependent rendering (e.g., figures with colorscales) should override this method
|
|
96
|
+
to re-render their components.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
theme (str): The name of the new theme (e.g., "light", "dark", "midnight_blue").
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
list: Updated property values (same format as update_properties), or no_update list.
|
|
103
|
+
|
|
104
|
+
Example:
|
|
105
|
+
def set_theme(self, theme: str) -> list:
|
|
106
|
+
if self.model is None:
|
|
107
|
+
return [no_update] * len(self.properties)
|
|
108
|
+
return self.update_properties(self.model)
|
|
109
|
+
"""
|
|
110
|
+
return [no_update] * len(self.properties)
|
|
111
|
+
|
|
86
112
|
#
|
|
87
113
|
# Internal Methods: These methods are used to validate the plugin interface at runtime
|
|
88
114
|
#
|
|
@@ -96,18 +96,11 @@ class AGTable(PluginInterface):
|
|
|
96
96
|
|
|
97
97
|
if __name__ == "__main__":
|
|
98
98
|
# Run the Unit Test for the Plugin
|
|
99
|
+
from workbench.api import Meta
|
|
99
100
|
from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
|
|
100
101
|
|
|
101
|
-
# Test data
|
|
102
|
-
data = {
|
|
103
|
-
"ID": [f"id_{i}" for i in range(10)],
|
|
104
|
-
"feat1": [1.0, 1.0, 1.1, 3.0, 4.0, 1.0, 1.0, 1.1, 3.0, 4.0],
|
|
105
|
-
"feat2": [1.0, 1.0, 1.1, 3.0, 4.0, 1.0, 1.0, 1.1, 3.0, 4.0],
|
|
106
|
-
"feat3": [0.1, 0.15, 0.2, 0.9, 2.8, 0.25, 0.35, 0.4, 1.6, 2.5],
|
|
107
|
-
"price": [31, 60, 62, 40, 20, 31, 61, 60, 40, 20],
|
|
108
|
-
"name": ["A", "B", "C", "D", "E", "F", "G", "H", "I", "Z" * 55],
|
|
109
|
-
}
|
|
110
|
-
test_df = pd.DataFrame(data)
|
|
102
|
+
# Test on model data
|
|
103
|
+
models_df = Meta().models(details=True)
|
|
111
104
|
|
|
112
105
|
# Run the Unit Test on the Plugin
|
|
113
|
-
PluginUnitTest(AGTable, theme="dark", input_data=
|
|
106
|
+
PluginUnitTest(AGTable, theme="dark", input_data=models_df, max_height=500).run()
|
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
"""A confusion matrix plugin component"""
|
|
2
2
|
|
|
3
|
-
from dash import dcc, callback, Output, Input, State
|
|
3
|
+
from dash import dcc, callback, Output, Input, State, no_update
|
|
4
4
|
import plotly.graph_objects as go
|
|
5
5
|
|
|
6
6
|
# Workbench Imports
|
|
7
7
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
8
|
-
from workbench.utils.theme_manager import ThemeManager
|
|
9
8
|
from workbench.cached.cached_model import CachedModel
|
|
10
9
|
from workbench.utils.color_utils import add_alpha_to_first_color
|
|
11
10
|
|
|
@@ -19,10 +18,8 @@ class ConfusionMatrix(PluginInterface):
|
|
|
19
18
|
def __init__(self):
|
|
20
19
|
"""Initialize the ConfusionMatrix plugin class"""
|
|
21
20
|
self.component_id = None
|
|
22
|
-
self.
|
|
23
|
-
self.
|
|
24
|
-
|
|
25
|
-
# Call the parent class constructor
|
|
21
|
+
self.model = None # Store the model for re-rendering on theme change
|
|
22
|
+
self.inference_run = None # Store the inference run for re-rendering
|
|
26
23
|
super().__init__()
|
|
27
24
|
|
|
28
25
|
def create_component(self, component_id: str) -> dcc.Graph:
|
|
@@ -57,9 +54,12 @@ class ConfusionMatrix(PluginInterface):
|
|
|
57
54
|
Returns:
|
|
58
55
|
list: A list containing the updated Plotly figure.
|
|
59
56
|
"""
|
|
57
|
+
# Store for re-rendering on theme change
|
|
58
|
+
self.model = model
|
|
59
|
+
self.inference_run = kwargs.get("inference_run", "auto_inference")
|
|
60
|
+
|
|
60
61
|
# Retrieve the confusion matrix data
|
|
61
|
-
|
|
62
|
-
df = model.confusion_matrix(inference_run)
|
|
62
|
+
df = model.confusion_matrix(self.inference_run)
|
|
63
63
|
if df is None:
|
|
64
64
|
return [self.display_text("No Data")]
|
|
65
65
|
|
|
@@ -137,6 +137,12 @@ class ConfusionMatrix(PluginInterface):
|
|
|
137
137
|
# Return the updated figure wrapped in a list
|
|
138
138
|
return [fig]
|
|
139
139
|
|
|
140
|
+
def set_theme(self, theme: str) -> list:
|
|
141
|
+
"""Re-render the confusion matrix when the theme changes."""
|
|
142
|
+
if self.model is None:
|
|
143
|
+
return [no_update] * len(self.properties)
|
|
144
|
+
return self.update_properties(self.model, inference_run=self.inference_run)
|
|
145
|
+
|
|
140
146
|
def register_internal_callbacks(self):
|
|
141
147
|
"""Register internal callbacks for the plugin."""
|
|
142
148
|
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""A Model Plot plugin that displays the appropriate visualization based on model type.
|
|
2
|
+
|
|
3
|
+
For classifiers: Shows a Confusion Matrix
|
|
4
|
+
For regressors: Shows a Scatter Plot of predictions vs actuals
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dash import html, no_update
|
|
8
|
+
|
|
9
|
+
# Workbench Imports
|
|
10
|
+
from workbench.api import ModelType
|
|
11
|
+
from workbench.cached.cached_model import CachedModel
|
|
12
|
+
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
13
|
+
from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
|
|
14
|
+
from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ModelPlot(PluginInterface):
|
|
18
|
+
"""Model Plot Plugin - switches between ConfusionMatrix and ScatterPlot based on model type."""
|
|
19
|
+
|
|
20
|
+
auto_load_page = PluginPage.NONE
|
|
21
|
+
plugin_input_type = PluginInputType.MODEL
|
|
22
|
+
|
|
23
|
+
def __init__(self):
|
|
24
|
+
"""Initialize the ModelPlot plugin class"""
|
|
25
|
+
self.component_id = None
|
|
26
|
+
self.model = None
|
|
27
|
+
self.inference_run = None
|
|
28
|
+
|
|
29
|
+
# Internal plugins
|
|
30
|
+
self.scatter_plot = ScatterPlot()
|
|
31
|
+
self.confusion_matrix = ConfusionMatrix()
|
|
32
|
+
|
|
33
|
+
# Call the parent class constructor
|
|
34
|
+
super().__init__()
|
|
35
|
+
|
|
36
|
+
def create_component(self, component_id: str) -> html.Div:
|
|
37
|
+
"""Create a container with both ScatterPlot and ConfusionMatrix components.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
component_id (str): The ID of the web component
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
html.Div: Container with both plot types (one hidden based on model type)
|
|
44
|
+
"""
|
|
45
|
+
self.component_id = component_id
|
|
46
|
+
|
|
47
|
+
# Create internal components
|
|
48
|
+
scatter_component = self.scatter_plot.create_component(f"{component_id}-scatter")
|
|
49
|
+
confusion_component = self.confusion_matrix.create_component(f"{component_id}-confusion")
|
|
50
|
+
|
|
51
|
+
# Build properties list: visibility styles + scatter props + confusion props
|
|
52
|
+
self.properties = [
|
|
53
|
+
(f"{component_id}-scatter-container", "style"),
|
|
54
|
+
(f"{component_id}-confusion-container", "style"),
|
|
55
|
+
]
|
|
56
|
+
self.properties.extend(self.scatter_plot.properties)
|
|
57
|
+
self.properties.extend(self.confusion_matrix.properties)
|
|
58
|
+
|
|
59
|
+
# Aggregate signals from both plugins
|
|
60
|
+
self.signals = self.scatter_plot.signals + self.confusion_matrix.signals
|
|
61
|
+
|
|
62
|
+
# Create container with both components
|
|
63
|
+
# Show scatter plot by default (will display "Waiting for Data..." until model loads)
|
|
64
|
+
return html.Div(
|
|
65
|
+
id=component_id,
|
|
66
|
+
children=[
|
|
67
|
+
html.Div(
|
|
68
|
+
scatter_component,
|
|
69
|
+
id=f"{component_id}-scatter-container",
|
|
70
|
+
style={"display": "block"},
|
|
71
|
+
),
|
|
72
|
+
html.Div(
|
|
73
|
+
confusion_component,
|
|
74
|
+
id=f"{component_id}-confusion-container",
|
|
75
|
+
style={"display": "none"},
|
|
76
|
+
),
|
|
77
|
+
],
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
def update_properties(self, model: CachedModel, **kwargs) -> list:
|
|
81
|
+
"""Update the plot based on model type.
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
model (CachedModel): The model to visualize
|
|
85
|
+
**kwargs:
|
|
86
|
+
- inference_run (str): Inference capture name (default: "auto_inference")
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
list: Property values [scatter_style, confusion_style, ...scatter_props, ...confusion_props]
|
|
90
|
+
"""
|
|
91
|
+
# Cache for theme re-rendering
|
|
92
|
+
self.model = model
|
|
93
|
+
self.inference_run = kwargs.get("inference_run", "full_cross_fold")
|
|
94
|
+
|
|
95
|
+
# Determine model type and set visibility
|
|
96
|
+
is_classifier = model.model_type == ModelType.CLASSIFIER
|
|
97
|
+
scatter_style = {"display": "none"} if is_classifier else {"display": "block"}
|
|
98
|
+
confusion_style = {"display": "block"} if is_classifier else {"display": "none"}
|
|
99
|
+
|
|
100
|
+
if is_classifier:
|
|
101
|
+
# Update ConfusionMatrix, no_update for ScatterPlot
|
|
102
|
+
cm_props = self.confusion_matrix.update_properties(model, inference_run=self.inference_run)
|
|
103
|
+
scatter_props = [no_update] * len(self.scatter_plot.properties)
|
|
104
|
+
else:
|
|
105
|
+
# Update ScatterPlot with regression data
|
|
106
|
+
df = model.get_inference_predictions(self.inference_run)
|
|
107
|
+
if df is None:
|
|
108
|
+
# Still update visibility styles, but no_update for plugin properties
|
|
109
|
+
scatter_props = [no_update] * len(self.scatter_plot.properties)
|
|
110
|
+
cm_props = [no_update] * len(self.confusion_matrix.properties)
|
|
111
|
+
return [scatter_style, confusion_style] + scatter_props + cm_props
|
|
112
|
+
|
|
113
|
+
# Get target column for the x-axis
|
|
114
|
+
target = model.target()
|
|
115
|
+
if isinstance(target, list):
|
|
116
|
+
target = next((t for t in target if t in self.inference_run), target[0])
|
|
117
|
+
|
|
118
|
+
# Check if "confidence" column exists for coloring
|
|
119
|
+
color_col = "confidence" if "confidence" in df.columns else "prediction"
|
|
120
|
+
|
|
121
|
+
scatter_props = self.scatter_plot.update_properties(
|
|
122
|
+
df,
|
|
123
|
+
x=target,
|
|
124
|
+
y="prediction",
|
|
125
|
+
color=color_col,
|
|
126
|
+
regression_line=True,
|
|
127
|
+
)
|
|
128
|
+
cm_props = [no_update] * len(self.confusion_matrix.properties)
|
|
129
|
+
|
|
130
|
+
return [scatter_style, confusion_style] + scatter_props + cm_props
|
|
131
|
+
|
|
132
|
+
def set_theme(self, theme: str) -> list:
|
|
133
|
+
"""Re-render the appropriate plot when the theme changes."""
|
|
134
|
+
if self.model is None:
|
|
135
|
+
return [no_update] * len(self.properties)
|
|
136
|
+
|
|
137
|
+
# Just call update_properties which will re-render the right plot
|
|
138
|
+
return self.update_properties(self.model, inference_run=self.inference_run)
|
|
139
|
+
|
|
140
|
+
def register_internal_callbacks(self):
|
|
141
|
+
"""Register internal callbacks for both sub-plugins."""
|
|
142
|
+
self.scatter_plot.register_internal_callbacks()
|
|
143
|
+
self.confusion_matrix.register_internal_callbacks()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
if __name__ == "__main__":
|
|
147
|
+
"""Run the Unit Test for the Plugin."""
|
|
148
|
+
from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
|
|
149
|
+
|
|
150
|
+
# Test with a classifier (shows Confusion Matrix)
|
|
151
|
+
classifier_model = CachedModel("wine-classification")
|
|
152
|
+
PluginUnitTest(ModelPlot, input_data=classifier_model, theme="dark").run()
|
|
153
|
+
|
|
154
|
+
# Test with a regressor (shows Scatter Plot)
|
|
155
|
+
regressor_model = CachedModel("abalone-regression")
|
|
156
|
+
PluginUnitTest(ModelPlot, input_data=regressor_model, theme="dark").run()
|
|
@@ -8,7 +8,6 @@ from dash.exceptions import PreventUpdate
|
|
|
8
8
|
|
|
9
9
|
# Workbench Imports
|
|
10
10
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
11
|
-
from workbench.utils.theme_manager import ThemeManager
|
|
12
11
|
from workbench.utils.plot_utils import prediction_intervals
|
|
13
12
|
from workbench.utils.chem_utils.vis import molecule_hover_tooltip
|
|
14
13
|
from workbench.utils.clientside_callbacks import circle_overlay_callback
|
|
@@ -37,7 +36,6 @@ class ScatterPlot(PluginInterface):
|
|
|
37
36
|
self.hover_columns = []
|
|
38
37
|
self.df = None
|
|
39
38
|
self.show_axes = show_axes
|
|
40
|
-
self.theme_manager = ThemeManager()
|
|
41
39
|
self.has_smiles = False # Track if dataframe has smiles column for molecule hover
|
|
42
40
|
self.smiles_column = None
|
|
43
41
|
self.id_column = None
|
|
@@ -242,6 +240,15 @@ class ScatterPlot(PluginInterface):
|
|
|
242
240
|
|
|
243
241
|
return [figure, x_options, y_options, color_options, x_default, y_default, color_default, regression_line_value]
|
|
244
242
|
|
|
243
|
+
def set_theme(self, theme: str) -> list:
|
|
244
|
+
"""Re-render the scatter plot when the theme changes."""
|
|
245
|
+
# If no data yet, return no_update for all properties
|
|
246
|
+
if self.df is None or self.df.empty:
|
|
247
|
+
return [no_update] * len(self.properties)
|
|
248
|
+
|
|
249
|
+
# Re-render with defaults (user dropdown selections reset, but theme changes are rare)
|
|
250
|
+
return self.update_properties(self.df)
|
|
251
|
+
|
|
245
252
|
def create_scatter_plot(
|
|
246
253
|
self,
|
|
247
254
|
df: pd.DataFrame,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""SHAP Summary Plot visualization component for XGBoost models"""
|
|
2
2
|
|
|
3
|
-
from dash import dcc
|
|
3
|
+
from dash import dcc, no_update
|
|
4
4
|
import plotly.graph_objects as go
|
|
5
5
|
import pandas as pd
|
|
6
6
|
import numpy as np
|
|
@@ -9,7 +9,6 @@ from typing import Dict, List
|
|
|
9
9
|
# Workbench Imports
|
|
10
10
|
from workbench.cached.cached_model import CachedModel
|
|
11
11
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
12
|
-
from workbench.utils.theme_manager import ThemeManager
|
|
13
12
|
from workbench.utils.plot_utils import beeswarm_offsets
|
|
14
13
|
|
|
15
14
|
|
|
@@ -22,7 +21,7 @@ class ShapSummaryPlot(PluginInterface):
|
|
|
22
21
|
def __init__(self):
|
|
23
22
|
"""Initialize the ShapSummaryPlot plugin class"""
|
|
24
23
|
self.component_id = None
|
|
25
|
-
self.
|
|
24
|
+
self.model = None # Store the model for re-rendering on theme change
|
|
26
25
|
super().__init__()
|
|
27
26
|
|
|
28
27
|
def create_component(self, component_id: str) -> dcc.Graph:
|
|
@@ -39,6 +38,9 @@ class ShapSummaryPlot(PluginInterface):
|
|
|
39
38
|
|
|
40
39
|
def update_properties(self, model: CachedModel, **kwargs) -> list:
|
|
41
40
|
"""Create a SHAP Summary Plot for feature importance visualization."""
|
|
41
|
+
# Store for re-rendering on theme change
|
|
42
|
+
self.model = model
|
|
43
|
+
|
|
42
44
|
# Basic validation
|
|
43
45
|
shap_data = model.shap_data()
|
|
44
46
|
shap_sample_rows = model.shap_sample()
|
|
@@ -221,9 +223,15 @@ class ShapSummaryPlot(PluginInterface):
|
|
|
221
223
|
)
|
|
222
224
|
return main_fig
|
|
223
225
|
|
|
226
|
+
def set_theme(self, theme: str) -> list:
|
|
227
|
+
"""Re-render the SHAP summary plot when the theme changes."""
|
|
228
|
+
if self.model is None:
|
|
229
|
+
return [no_update] * len(self.properties)
|
|
230
|
+
return self.update_properties(self.model)
|
|
231
|
+
|
|
224
232
|
def register_internal_callbacks(self):
|
|
225
233
|
"""Register internal callbacks for the plugin."""
|
|
226
|
-
pass #
|
|
234
|
+
pass # No internal callbacks needed
|
|
227
235
|
|
|
228
236
|
|
|
229
237
|
if __name__ == "__main__":
|
|
@@ -100,74 +100,35 @@ class SettingsMenu:
|
|
|
100
100
|
caret=False,
|
|
101
101
|
align_end=True,
|
|
102
102
|
),
|
|
103
|
-
#
|
|
104
|
-
dcc.Store(id=f"{component_id}-dummy", data=None),
|
|
105
|
-
# Store to trigger checkmark update on load
|
|
103
|
+
# Store to trigger checkmark update on load and theme change
|
|
106
104
|
dcc.Store(id=f"{component_id}-init", data=True),
|
|
107
105
|
],
|
|
108
106
|
id=component_id,
|
|
109
107
|
)
|
|
110
108
|
|
|
111
|
-
|
|
112
|
-
def get_clientside_callback_code(component_id: str) -> str:
|
|
109
|
+
def get_clientside_callback_code(self) -> str:
|
|
113
110
|
"""Get the JavaScript code for the theme selection clientside callback.
|
|
114
111
|
|
|
115
|
-
Args:
|
|
116
|
-
component_id (str): The ID prefix used in create_component.
|
|
117
|
-
|
|
118
112
|
Returns:
|
|
119
113
|
str: JavaScript code for the clientside callback.
|
|
120
114
|
"""
|
|
121
|
-
return """
|
|
122
|
-
function(n_clicks_list, ids) {
|
|
123
|
-
// Find which button was clicked
|
|
124
|
-
if (!n_clicks_list || n_clicks_list.every(n => !n)) {
|
|
125
|
-
return window.dash_clientside.no_update;
|
|
126
|
-
}
|
|
127
|
-
|
|
128
|
-
// Find the clicked theme
|
|
129
|
-
let clickedTheme = null;
|
|
130
|
-
for (let i = 0; i < n_clicks_list.length; i++) {
|
|
131
|
-
if (n_clicks_list[i]) {
|
|
132
|
-
clickedTheme = ids[i].theme;
|
|
133
|
-
break;
|
|
134
|
-
}
|
|
135
|
-
}
|
|
136
|
-
|
|
137
|
-
if (clickedTheme) {
|
|
138
|
-
// Store in localStorage
|
|
139
|
-
localStorage.setItem('wb_theme', clickedTheme);
|
|
140
|
-
// Set cookie for Flask to read on reload
|
|
141
|
-
document.cookie = `wb_theme=${clickedTheme}; path=/; max-age=31536000`;
|
|
142
|
-
// Reload the page to apply the new theme
|
|
143
|
-
window.location.reload();
|
|
144
|
-
}
|
|
145
|
-
|
|
146
|
-
return window.dash_clientside.no_update;
|
|
147
|
-
}
|
|
148
|
-
"""
|
|
115
|
+
return self.tm.get_theme_switch_js()
|
|
149
116
|
|
|
150
117
|
@staticmethod
|
|
151
118
|
def get_checkmark_callback_code() -> str:
|
|
152
|
-
"""Get the JavaScript code to update checkmarks based on
|
|
119
|
+
"""Get the JavaScript code to update checkmarks based on current theme.
|
|
153
120
|
|
|
154
121
|
Returns:
|
|
155
122
|
str: JavaScript code for the checkmark update callback.
|
|
156
123
|
"""
|
|
157
124
|
return """
|
|
158
|
-
function(
|
|
159
|
-
//
|
|
160
|
-
let currentTheme =
|
|
125
|
+
function(theme, ids) {
|
|
126
|
+
// If theme is a string (from theme switch), use it directly
|
|
127
|
+
let currentTheme = (typeof theme === 'string') ? theme : null;
|
|
128
|
+
|
|
129
|
+
// Otherwise, get from localStorage
|
|
161
130
|
if (!currentTheme) {
|
|
162
|
-
|
|
163
|
-
const cookies = document.cookie.split(';');
|
|
164
|
-
for (let cookie of cookies) {
|
|
165
|
-
const [name, value] = cookie.trim().split('=');
|
|
166
|
-
if (name === 'wb_theme') {
|
|
167
|
-
currentTheme = value;
|
|
168
|
-
break;
|
|
169
|
-
}
|
|
170
|
-
}
|
|
131
|
+
currentTheme = localStorage.getItem('wb_theme');
|
|
171
132
|
}
|
|
172
133
|
|
|
173
134
|
// Return checkmarks for each theme
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: workbench
|
|
3
|
-
Version: 0.8.234
|
|
3
|
+
Version: 0.8.239
|
|
4
4
|
Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
|
|
5
5
|
Author-email: SuperCowPowers LLC <support@supercowpowers.com>
|
|
6
6
|
License: MIT License
|
|
@@ -40,7 +40,7 @@ Requires-Dist: boto3>=1.31.76
|
|
|
40
40
|
Requires-Dist: botocore>=1.31.76
|
|
41
41
|
Requires-Dist: redis>=5.0.1
|
|
42
42
|
Requires-Dist: numpy>=1.26.4
|
|
43
|
-
Requires-Dist: pandas
|
|
43
|
+
Requires-Dist: pandas<3.0,>=2.2.1
|
|
44
44
|
Requires-Dist: awswrangler>=3.4.0
|
|
45
45
|
Requires-Dist: sagemaker<3.0,>=2.143
|
|
46
46
|
Requires-Dist: cryptography>=44.0.2
|