workbench 0.8.231__py3-none-any.whl → 0.8.236__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/smart_aggregator.py +17 -12
- workbench/api/endpoint.py +13 -4
- workbench/cached/cached_model.py +2 -2
- workbench/core/artifacts/endpoint_core.py +30 -5
- workbench/model_script_utils/model_script_utils.py +225 -0
- workbench/model_script_utils/uq_harness.py +39 -21
- workbench/model_scripts/chemprop/chemprop.template +29 -14
- workbench/model_scripts/chemprop/generated_model_script.py +35 -18
- workbench/model_scripts/chemprop/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +34 -20
- workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
- workbench/model_scripts/pytorch_model/pytorch.template +28 -14
- workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
- workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
- workbench/model_scripts/xgb_model/uq_harness.py +39 -21
- workbench/model_scripts/xgb_model/xgb_model.template +29 -18
- workbench/themes/dark/custom.css +29 -0
- workbench/themes/light/custom.css +29 -0
- workbench/themes/midnight_blue/custom.css +28 -0
- workbench/utils/markdown_utils.py +5 -1
- workbench/utils/model_utils.py +9 -0
- workbench/utils/theme_manager.py +95 -0
- workbench/web_interface/components/component_interface.py +3 -0
- workbench/web_interface/components/plugin_interface.py +26 -0
- workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
- workbench/web_interface/components/plugins/model_details.py +18 -5
- workbench/web_interface/components/plugins/model_plot.py +156 -0
- workbench/web_interface/components/plugins/scatter_plot.py +9 -2
- workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
- workbench/web_interface/components/settings_menu.py +10 -49
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/METADATA +1 -1
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/RECORD +37 -37
- workbench/web_interface/components/model_plot.py +0 -75
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/WHEEL +0 -0
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.231.dist-info → workbench-0.8.236.dist-info}/top_level.txt +0 -0
workbench/themes/dark/custom.css
CHANGED
|
@@ -135,6 +135,35 @@ div:has(> [class*="ag-theme-"]) {
|
|
|
135
135
|
--bs-border-color: rgb(60, 60, 60);
|
|
136
136
|
}
|
|
137
137
|
|
|
138
|
+
/* React-select dropdown styling - target actual rendered elements */
|
|
139
|
+
.Select-control {
|
|
140
|
+
background-color: rgb(35, 35, 35) !important;
|
|
141
|
+
border-color: rgb(60, 60, 60) !important;
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
.Select-value-label, .Select-input input {
|
|
145
|
+
color: rgb(210, 210, 210) !important;
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
.Select-placeholder {
|
|
149
|
+
color: rgb(150, 150, 150) !important;
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
.Select-menu-outer {
|
|
153
|
+
background-color: rgb(35, 35, 35) !important;
|
|
154
|
+
border-color: rgb(60, 60, 60) !important;
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
.VirtualizedSelectOption {
|
|
158
|
+
background-color: rgb(35, 35, 35) !important;
|
|
159
|
+
color: rgb(210, 210, 210) !important;
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
.VirtualizedSelectFocusedOption {
|
|
163
|
+
background-color: rgb(60, 60, 60) !important;
|
|
164
|
+
color: rgb(230, 230, 230) !important;
|
|
165
|
+
}
|
|
166
|
+
|
|
138
167
|
/* Bootstrap form controls (dbc components) */
|
|
139
168
|
.form-select, .form-control {
|
|
140
169
|
background-color: rgb(35, 35, 35) !important;
|
|
@@ -180,6 +180,35 @@ div:has(> [class*="ag-theme-"]) {
|
|
|
180
180
|
--bs-border-color: var(--wb-accent);
|
|
181
181
|
}
|
|
182
182
|
|
|
183
|
+
/* React-select dropdown styling - target actual rendered elements */
|
|
184
|
+
.Select-control {
|
|
185
|
+
background-color: var(--wb-dropdown-bg) !important;
|
|
186
|
+
border-color: var(--wb-accent) !important;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
.Select-value-label, .Select-input input {
|
|
190
|
+
color: var(--wb-text-primary) !important;
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
.Select-placeholder {
|
|
194
|
+
color: var(--wb-text-muted) !important;
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
.Select-menu-outer {
|
|
198
|
+
background-color: var(--wb-dropdown-bg) !important;
|
|
199
|
+
border-color: var(--wb-accent) !important;
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
.VirtualizedSelectOption {
|
|
203
|
+
background-color: var(--wb-dropdown-bg) !important;
|
|
204
|
+
color: var(--wb-text-primary) !important;
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
.VirtualizedSelectFocusedOption {
|
|
208
|
+
background-color: var(--wb-dropdown-hover) !important;
|
|
209
|
+
color: var(--wb-text-primary) !important;
|
|
210
|
+
}
|
|
211
|
+
|
|
183
212
|
/* Bootstrap form controls (dbc components) */
|
|
184
213
|
.form-select, .form-control {
|
|
185
214
|
background-color: var(--wb-dropdown-bg) !important;
|
|
@@ -133,6 +133,34 @@ div:has(> [class*="ag-theme-"]) {
|
|
|
133
133
|
--bs-border-color: rgb(80, 85, 115);
|
|
134
134
|
}
|
|
135
135
|
|
|
136
|
+
/* React-select dropdown styling - target actual rendered elements */
|
|
137
|
+
.Select-control {
|
|
138
|
+
background-color: rgb(55, 60, 90) !important;
|
|
139
|
+
border-color: rgb(80, 85, 115) !important;
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
.Select-value-label, .Select-input input {
|
|
143
|
+
color: rgb(210, 210, 210) !important;
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
.Select-placeholder {
|
|
147
|
+
color: rgb(150, 150, 170) !important;
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
.Select-menu-outer {
|
|
151
|
+
background-color: rgb(55, 60, 90) !important;
|
|
152
|
+
border-color: rgb(80, 85, 115) !important;
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
.VirtualizedSelectOption {
|
|
156
|
+
background-color: rgb(55, 60, 90) !important;
|
|
157
|
+
color: rgb(210, 210, 210) !important;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
.VirtualizedSelectFocusedOption {
|
|
161
|
+
background-color: rgb(70, 75, 110) !important;
|
|
162
|
+
color: rgb(230, 230, 230) !important;
|
|
163
|
+
}
|
|
136
164
|
|
|
137
165
|
/* Bootstrap form controls (dbc components) */
|
|
138
166
|
.form-select, .form-control {
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Markdown Utility/helper methods"""
|
|
2
2
|
|
|
3
|
+
import math
|
|
4
|
+
|
|
3
5
|
from workbench.utils.symbols import health_icons
|
|
4
6
|
|
|
5
7
|
|
|
@@ -229,7 +231,9 @@ def df_to_html_table(df, round_digits: int = 2, margin_bottom: int = 30) -> str:
|
|
|
229
231
|
for val in row:
|
|
230
232
|
# Format value: integers without decimal, floats rounded
|
|
231
233
|
if isinstance(val, float):
|
|
232
|
-
if
|
|
234
|
+
if math.isnan(val):
|
|
235
|
+
formatted_val = "NaN"
|
|
236
|
+
elif val == int(val):
|
|
233
237
|
formatted_val = int(val)
|
|
234
238
|
else:
|
|
235
239
|
formatted_val = round(val, round_digits)
|
workbench/utils/model_utils.py
CHANGED
|
@@ -459,6 +459,12 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
|
|
|
459
459
|
# Spearman correlation for robustness
|
|
460
460
|
interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]
|
|
461
461
|
|
|
462
|
+
# --- Confidence to Error Correlation ---
|
|
463
|
+
# If confidence column exists, compute correlation (should be negative: high confidence = low error)
|
|
464
|
+
confidence_to_error_corr = None
|
|
465
|
+
if "confidence" in df.columns:
|
|
466
|
+
confidence_to_error_corr = spearmanr(df["confidence"], abs_residuals)[0]
|
|
467
|
+
|
|
462
468
|
# Collect results
|
|
463
469
|
results = {
|
|
464
470
|
"coverage_68": coverage_68,
|
|
@@ -472,6 +478,7 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
|
|
|
472
478
|
"median_width_90": median_width_90,
|
|
473
479
|
"median_width_95": median_width_95,
|
|
474
480
|
"interval_to_error_corr": interval_to_error_corr,
|
|
481
|
+
"confidence_to_error_corr": confidence_to_error_corr,
|
|
475
482
|
"n_samples": len(df),
|
|
476
483
|
}
|
|
477
484
|
|
|
@@ -489,6 +496,8 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
|
|
|
489
496
|
print(f"CRPS: {mean_crps:.3f} (lower is better)")
|
|
490
497
|
print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
|
|
491
498
|
print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
|
|
499
|
+
if confidence_to_error_corr is not None:
|
|
500
|
+
print(f"Confidence/Error Corr: {confidence_to_error_corr:.3f} (lower is better, target: <-0.5)")
|
|
492
501
|
print(f"Samples: {len(df)}")
|
|
493
502
|
return results
|
|
494
503
|
|
workbench/utils/theme_manager.py
CHANGED
|
@@ -155,6 +155,101 @@ class ThemeManager:
|
|
|
155
155
|
"""Get the name of the current theme."""
|
|
156
156
|
return cls.current_theme_name
|
|
157
157
|
|
|
158
|
+
@classmethod
|
|
159
|
+
def get_theme_urls(cls) -> dict:
|
|
160
|
+
"""Get a mapping of theme names to their Bootstrap CSS URLs."""
|
|
161
|
+
return {name: theme["base_css"] for name, theme in cls.available_themes.items() if theme["base_css"]}
|
|
162
|
+
|
|
163
|
+
@classmethod
|
|
164
|
+
def get_dark_themes(cls) -> list:
|
|
165
|
+
"""Get list of theme names that are dark mode themes."""
|
|
166
|
+
dark_themes = []
|
|
167
|
+
for name, theme in cls.available_themes.items():
|
|
168
|
+
base_css = theme.get("base_css", "")
|
|
169
|
+
if base_css:
|
|
170
|
+
base_css_upper = base_css.upper()
|
|
171
|
+
for dark_theme in cls._dark_bootstrap_themes:
|
|
172
|
+
if dark_theme in base_css_upper:
|
|
173
|
+
dark_themes.append(name)
|
|
174
|
+
break
|
|
175
|
+
else:
|
|
176
|
+
# Fallback: check if 'dark' is in the theme name
|
|
177
|
+
if "dark" in name.lower():
|
|
178
|
+
dark_themes.append(name)
|
|
179
|
+
return dark_themes
|
|
180
|
+
|
|
181
|
+
@classmethod
|
|
182
|
+
def get_theme_switch_js(cls) -> str:
|
|
183
|
+
"""Get the JavaScript code for dynamic theme switching.
|
|
184
|
+
|
|
185
|
+
This code should be used in a clientside callback that updates the Bootstrap
|
|
186
|
+
stylesheet and data-bs-theme attribute when the user selects a new theme.
|
|
187
|
+
"""
|
|
188
|
+
theme_urls = cls.get_theme_urls()
|
|
189
|
+
dark_themes = cls.get_dark_themes()
|
|
190
|
+
|
|
191
|
+
return f"""
|
|
192
|
+
function(n_clicks_list, ids) {{
|
|
193
|
+
// Use callback context to find which input triggered this callback
|
|
194
|
+
const ctx = window.dash_clientside.callback_context;
|
|
195
|
+
if (!ctx || !ctx.triggered || ctx.triggered.length === 0) {{
|
|
196
|
+
return window.dash_clientside.no_update;
|
|
197
|
+
}}
|
|
198
|
+
|
|
199
|
+
// Get the triggered input's id (it's a JSON string for pattern-matching callbacks)
|
|
200
|
+
const triggeredId = ctx.triggered[0].prop_id.split('.')[0];
|
|
201
|
+
let clickedTheme = null;
|
|
202
|
+
|
|
203
|
+
try {{
|
|
204
|
+
const parsedId = JSON.parse(triggeredId);
|
|
205
|
+
clickedTheme = parsedId.theme;
|
|
206
|
+
}} catch (e) {{
|
|
207
|
+
// Not a pattern-matching callback ID
|
|
208
|
+
return window.dash_clientside.no_update;
|
|
209
|
+
}}
|
|
210
|
+
|
|
211
|
+
if (!clickedTheme) {{
|
|
212
|
+
return window.dash_clientside.no_update;
|
|
213
|
+
}}
|
|
214
|
+
|
|
215
|
+
// Store in localStorage and cookie
|
|
216
|
+
localStorage.setItem('wb_theme', clickedTheme);
|
|
217
|
+
document.cookie = `wb_theme=${{clickedTheme}}; path=/; max-age=31536000`;
|
|
218
|
+
|
|
219
|
+
// Theme URL mapping (generated from ThemeManager)
|
|
220
|
+
const themeUrls = {json.dumps(theme_urls)};
|
|
221
|
+
const darkThemes = {json.dumps(dark_themes)};
|
|
222
|
+
|
|
223
|
+
// Update Bootstrap's data-bs-theme attribute
|
|
224
|
+
const bsTheme = darkThemes.includes(clickedTheme) ? 'dark' : 'light';
|
|
225
|
+
document.documentElement.setAttribute('data-bs-theme', bsTheme);
|
|
226
|
+
|
|
227
|
+
// Find and update the Bootstrap stylesheet
|
|
228
|
+
const newUrl = themeUrls[clickedTheme];
|
|
229
|
+
if (newUrl) {{
|
|
230
|
+
let stylesheet = document.querySelector('link[href*="bootswatch"]') ||
|
|
231
|
+
document.querySelector('link[href*="bootstrap"]') ||
|
|
232
|
+
document.querySelector('link[href="/base.css"]');
|
|
233
|
+
if (stylesheet) {{
|
|
234
|
+
stylesheet.setAttribute('href', newUrl);
|
|
235
|
+
}} else {{
|
|
236
|
+
stylesheet = document.createElement('link');
|
|
237
|
+
stylesheet.rel = 'stylesheet';
|
|
238
|
+
stylesheet.href = newUrl;
|
|
239
|
+
document.head.appendChild(stylesheet);
|
|
240
|
+
}}
|
|
241
|
+
}}
|
|
242
|
+
|
|
243
|
+
// Force reload custom.css with cache-busting query param
|
|
244
|
+
let customCss = document.querySelector('link[href^="/custom.css"]');
|
|
245
|
+
if (customCss) {{
|
|
246
|
+
customCss.setAttribute('href', `/custom.css?t=${{Date.now()}}`);
|
|
247
|
+
}}
|
|
248
|
+
|
|
249
|
+
return clickedTheme;
|
|
250
|
+
}}
|
|
251
|
+
"""
|
|
252
|
+
|
|
158
253
|
@classmethod
|
|
159
254
|
def background(cls) -> list[list[float | str]]:
|
|
160
255
|
"""Get the plot background for the current theme."""
|
|
@@ -136,6 +136,7 @@ class ComponentInterface(ABC):
|
|
|
136
136
|
xref="paper",
|
|
137
137
|
yref="paper",
|
|
138
138
|
text=text_message,
|
|
139
|
+
bgcolor="rgba(0,0,0,0)",
|
|
139
140
|
showarrow=False,
|
|
140
141
|
font=dict(size=font_size, color="#9999cc"),
|
|
141
142
|
)
|
|
@@ -144,6 +145,8 @@ class ComponentInterface(ABC):
|
|
|
144
145
|
xaxis=dict(showticklabels=False, zeroline=False, showgrid=False),
|
|
145
146
|
yaxis=dict(showticklabels=False, zeroline=False, showgrid=False),
|
|
146
147
|
margin=dict(l=0, r=0, b=0, t=0),
|
|
148
|
+
paper_bgcolor="rgba(0,0,0,0)",
|
|
149
|
+
plot_bgcolor="rgba(0,0,0,0)",
|
|
147
150
|
)
|
|
148
151
|
|
|
149
152
|
if figure_height is not None:
|
|
@@ -3,10 +3,12 @@ import inspect
|
|
|
3
3
|
from typing import Union, Tuple, get_args, get_origin
|
|
4
4
|
from enum import Enum
|
|
5
5
|
import logging
|
|
6
|
+
from dash import no_update
|
|
6
7
|
from dash.development.base_component import Component
|
|
7
8
|
|
|
8
9
|
# Local Imports
|
|
9
10
|
from workbench.web_interface.components.component_interface import ComponentInterface
|
|
11
|
+
from workbench.utils.theme_manager import ThemeManager
|
|
10
12
|
|
|
11
13
|
log = logging.getLogger("workbench")
|
|
12
14
|
|
|
@@ -48,6 +50,9 @@ class PluginInterface(ComponentInterface):
|
|
|
48
50
|
- The 'register_internal_callbacks' method is optional
|
|
49
51
|
"""
|
|
50
52
|
|
|
53
|
+
# Shared ThemeManager instance for all plugins
|
|
54
|
+
theme_manager = ThemeManager()
|
|
55
|
+
|
|
51
56
|
@abstractmethod
|
|
52
57
|
def create_component(self, component_id: str) -> Component:
|
|
53
58
|
"""Create a Dash Component without any data.
|
|
@@ -83,6 +88,27 @@ class PluginInterface(ComponentInterface):
|
|
|
83
88
|
"""
|
|
84
89
|
pass
|
|
85
90
|
|
|
91
|
+
def set_theme(self, theme: str) -> list:
|
|
92
|
+
"""Called when the application theme changes. Override to re-render with new theme colors.
|
|
93
|
+
|
|
94
|
+
The default implementation returns no_update for all properties. Plugins that have
|
|
95
|
+
theme-dependent rendering (e.g., figures with colorscales) should override this method
|
|
96
|
+
to re-render their components.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
theme (str): The name of the new theme (e.g., "light", "dark", "midnight_blue").
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
list: Updated property values (same format as update_properties), or no_update list.
|
|
103
|
+
|
|
104
|
+
Example:
|
|
105
|
+
def set_theme(self, theme: str) -> list:
|
|
106
|
+
if self.model is None:
|
|
107
|
+
return [no_update] * len(self.properties)
|
|
108
|
+
return self.update_properties(self.model)
|
|
109
|
+
"""
|
|
110
|
+
return [no_update] * len(self.properties)
|
|
111
|
+
|
|
86
112
|
#
|
|
87
113
|
# Internal Methods: These methods are used to validate the plugin interface at runtime
|
|
88
114
|
#
|
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
"""A confusion matrix plugin component"""
|
|
2
2
|
|
|
3
|
-
from dash import dcc, callback, Output, Input, State
|
|
3
|
+
from dash import dcc, callback, Output, Input, State, no_update
|
|
4
4
|
import plotly.graph_objects as go
|
|
5
5
|
|
|
6
6
|
# Workbench Imports
|
|
7
7
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
8
|
-
from workbench.utils.theme_manager import ThemeManager
|
|
9
8
|
from workbench.cached.cached_model import CachedModel
|
|
10
9
|
from workbench.utils.color_utils import add_alpha_to_first_color
|
|
11
10
|
|
|
@@ -19,10 +18,8 @@ class ConfusionMatrix(PluginInterface):
|
|
|
19
18
|
def __init__(self):
|
|
20
19
|
"""Initialize the ConfusionMatrix plugin class"""
|
|
21
20
|
self.component_id = None
|
|
22
|
-
self.
|
|
23
|
-
self.
|
|
24
|
-
|
|
25
|
-
# Call the parent class constructor
|
|
21
|
+
self.model = None # Store the model for re-rendering on theme change
|
|
22
|
+
self.inference_run = None # Store the inference run for re-rendering
|
|
26
23
|
super().__init__()
|
|
27
24
|
|
|
28
25
|
def create_component(self, component_id: str) -> dcc.Graph:
|
|
@@ -57,9 +54,12 @@ class ConfusionMatrix(PluginInterface):
|
|
|
57
54
|
Returns:
|
|
58
55
|
list: A list containing the updated Plotly figure.
|
|
59
56
|
"""
|
|
57
|
+
# Store for re-rendering on theme change
|
|
58
|
+
self.model = model
|
|
59
|
+
self.inference_run = kwargs.get("inference_run", "auto_inference")
|
|
60
|
+
|
|
60
61
|
# Retrieve the confusion matrix data
|
|
61
|
-
|
|
62
|
-
df = model.confusion_matrix(inference_run)
|
|
62
|
+
df = model.confusion_matrix(self.inference_run)
|
|
63
63
|
if df is None:
|
|
64
64
|
return [self.display_text("No Data")]
|
|
65
65
|
|
|
@@ -137,6 +137,12 @@ class ConfusionMatrix(PluginInterface):
|
|
|
137
137
|
# Return the updated figure wrapped in a list
|
|
138
138
|
return [fig]
|
|
139
139
|
|
|
140
|
+
def set_theme(self, theme: str) -> list:
|
|
141
|
+
"""Re-render the confusion matrix when the theme changes."""
|
|
142
|
+
if self.model is None:
|
|
143
|
+
return [no_update] * len(self.properties)
|
|
144
|
+
return self.update_properties(self.model, inference_run=self.inference_run)
|
|
145
|
+
|
|
140
146
|
def register_internal_callbacks(self):
|
|
141
147
|
"""Register internal callbacks for the plugin."""
|
|
142
148
|
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from typing import Union
|
|
4
4
|
|
|
5
5
|
# Dash Imports
|
|
6
|
-
from dash import html, callback, dcc, Input, Output, State
|
|
6
|
+
from dash import html, callback, dcc, no_update, Input, Output, State
|
|
7
7
|
|
|
8
8
|
# Workbench Imports
|
|
9
9
|
from workbench.api import ModelType, ParameterStore
|
|
@@ -71,7 +71,9 @@ class ModelDetails(PluginInterface):
|
|
|
71
71
|
|
|
72
72
|
Args:
|
|
73
73
|
model (CachedModel): An instantiated CachedModel object
|
|
74
|
-
**kwargs: Additional keyword arguments
|
|
74
|
+
**kwargs: Additional keyword arguments
|
|
75
|
+
- inference_run: Current inference run selection (to preserve user's choice)
|
|
76
|
+
- previous_model_name: Name of the previously selected model
|
|
75
77
|
|
|
76
78
|
Returns:
|
|
77
79
|
list: A list of the updated property values for the plugin
|
|
@@ -85,10 +87,21 @@ class ModelDetails(PluginInterface):
|
|
|
85
87
|
|
|
86
88
|
# Populate the inference runs dropdown
|
|
87
89
|
inference_runs, default_run = self.get_inference_runs()
|
|
88
|
-
metrics = self.inference_metrics(default_run)
|
|
89
90
|
|
|
90
|
-
#
|
|
91
|
-
|
|
91
|
+
# Check if the model changed
|
|
92
|
+
previous_model_name = kwargs.get("previous_model_name")
|
|
93
|
+
current_inference_run = kwargs.get("inference_run")
|
|
94
|
+
model_changed = previous_model_name != model.name
|
|
95
|
+
|
|
96
|
+
# Only preserve the inference run if the model hasn't changed AND the selection is valid
|
|
97
|
+
if not model_changed and current_inference_run and current_inference_run in inference_runs:
|
|
98
|
+
# Same model, preserve the user's selection - use no_update for dropdown value
|
|
99
|
+
metrics = self.inference_metrics(current_inference_run)
|
|
100
|
+
return [header, details, inference_runs, no_update, metrics]
|
|
101
|
+
else:
|
|
102
|
+
# New model or invalid selection - use default
|
|
103
|
+
metrics = self.inference_metrics(default_run)
|
|
104
|
+
return [header, details, inference_runs, default_run, metrics]
|
|
92
105
|
|
|
93
106
|
def register_internal_callbacks(self):
|
|
94
107
|
@callback(
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""A Model Plot plugin that displays the appropriate visualization based on model type.
|
|
2
|
+
|
|
3
|
+
For classifiers: Shows a Confusion Matrix
|
|
4
|
+
For regressors: Shows a Scatter Plot of predictions vs actuals
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dash import html, no_update
|
|
8
|
+
|
|
9
|
+
# Workbench Imports
|
|
10
|
+
from workbench.api import ModelType
|
|
11
|
+
from workbench.cached.cached_model import CachedModel
|
|
12
|
+
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
13
|
+
from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
|
|
14
|
+
from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ModelPlot(PluginInterface):
|
|
18
|
+
"""Model Plot Plugin - switches between ConfusionMatrix and ScatterPlot based on model type."""
|
|
19
|
+
|
|
20
|
+
auto_load_page = PluginPage.NONE
|
|
21
|
+
plugin_input_type = PluginInputType.MODEL
|
|
22
|
+
|
|
23
|
+
def __init__(self):
|
|
24
|
+
"""Initialize the ModelPlot plugin class"""
|
|
25
|
+
self.component_id = None
|
|
26
|
+
self.model = None
|
|
27
|
+
self.inference_run = None
|
|
28
|
+
|
|
29
|
+
# Internal plugins
|
|
30
|
+
self.scatter_plot = ScatterPlot()
|
|
31
|
+
self.confusion_matrix = ConfusionMatrix()
|
|
32
|
+
|
|
33
|
+
# Call the parent class constructor
|
|
34
|
+
super().__init__()
|
|
35
|
+
|
|
36
|
+
def create_component(self, component_id: str) -> html.Div:
|
|
37
|
+
"""Create a container with both ScatterPlot and ConfusionMatrix components.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
component_id (str): The ID of the web component
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
html.Div: Container with both plot types (one hidden based on model type)
|
|
44
|
+
"""
|
|
45
|
+
self.component_id = component_id
|
|
46
|
+
|
|
47
|
+
# Create internal components
|
|
48
|
+
scatter_component = self.scatter_plot.create_component(f"{component_id}-scatter")
|
|
49
|
+
confusion_component = self.confusion_matrix.create_component(f"{component_id}-confusion")
|
|
50
|
+
|
|
51
|
+
# Build properties list: visibility styles + scatter props + confusion props
|
|
52
|
+
self.properties = [
|
|
53
|
+
(f"{component_id}-scatter-container", "style"),
|
|
54
|
+
(f"{component_id}-confusion-container", "style"),
|
|
55
|
+
]
|
|
56
|
+
self.properties.extend(self.scatter_plot.properties)
|
|
57
|
+
self.properties.extend(self.confusion_matrix.properties)
|
|
58
|
+
|
|
59
|
+
# Aggregate signals from both plugins
|
|
60
|
+
self.signals = self.scatter_plot.signals + self.confusion_matrix.signals
|
|
61
|
+
|
|
62
|
+
# Create container with both components
|
|
63
|
+
# Show scatter plot by default (will display "Waiting for Data..." until model loads)
|
|
64
|
+
return html.Div(
|
|
65
|
+
id=component_id,
|
|
66
|
+
children=[
|
|
67
|
+
html.Div(
|
|
68
|
+
scatter_component,
|
|
69
|
+
id=f"{component_id}-scatter-container",
|
|
70
|
+
style={"display": "block"},
|
|
71
|
+
),
|
|
72
|
+
html.Div(
|
|
73
|
+
confusion_component,
|
|
74
|
+
id=f"{component_id}-confusion-container",
|
|
75
|
+
style={"display": "none"},
|
|
76
|
+
),
|
|
77
|
+
],
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
def update_properties(self, model: CachedModel, **kwargs) -> list:
|
|
81
|
+
"""Update the plot based on model type.
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
model (CachedModel): The model to visualize
|
|
85
|
+
**kwargs:
|
|
86
|
+
- inference_run (str): Inference capture name (default: "auto_inference")
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
list: Property values [scatter_style, confusion_style, ...scatter_props, ...confusion_props]
|
|
90
|
+
"""
|
|
91
|
+
# Cache for theme re-rendering
|
|
92
|
+
self.model = model
|
|
93
|
+
self.inference_run = kwargs.get("inference_run", "full_cross_fold")
|
|
94
|
+
|
|
95
|
+
# Determine model type and set visibility
|
|
96
|
+
is_classifier = model.model_type == ModelType.CLASSIFIER
|
|
97
|
+
scatter_style = {"display": "none"} if is_classifier else {"display": "block"}
|
|
98
|
+
confusion_style = {"display": "block"} if is_classifier else {"display": "none"}
|
|
99
|
+
|
|
100
|
+
if is_classifier:
|
|
101
|
+
# Update ConfusionMatrix, no_update for ScatterPlot
|
|
102
|
+
cm_props = self.confusion_matrix.update_properties(model, inference_run=self.inference_run)
|
|
103
|
+
scatter_props = [no_update] * len(self.scatter_plot.properties)
|
|
104
|
+
else:
|
|
105
|
+
# Update ScatterPlot with regression data
|
|
106
|
+
df = model.get_inference_predictions(self.inference_run)
|
|
107
|
+
if df is None:
|
|
108
|
+
# Still update visibility styles, but no_update for plugin properties
|
|
109
|
+
scatter_props = [no_update] * len(self.scatter_plot.properties)
|
|
110
|
+
cm_props = [no_update] * len(self.confusion_matrix.properties)
|
|
111
|
+
return [scatter_style, confusion_style] + scatter_props + cm_props
|
|
112
|
+
|
|
113
|
+
# Get target column for the x-axis
|
|
114
|
+
target = model.target()
|
|
115
|
+
if isinstance(target, list):
|
|
116
|
+
target = next((t for t in target if t in self.inference_run), target[0])
|
|
117
|
+
|
|
118
|
+
# Check if "confidence" column exists for coloring
|
|
119
|
+
color_col = "confidence" if "confidence" in df.columns else "prediction"
|
|
120
|
+
|
|
121
|
+
scatter_props = self.scatter_plot.update_properties(
|
|
122
|
+
df,
|
|
123
|
+
x=target,
|
|
124
|
+
y="prediction",
|
|
125
|
+
color=color_col,
|
|
126
|
+
regression_line=True,
|
|
127
|
+
)
|
|
128
|
+
cm_props = [no_update] * len(self.confusion_matrix.properties)
|
|
129
|
+
|
|
130
|
+
return [scatter_style, confusion_style] + scatter_props + cm_props
|
|
131
|
+
|
|
132
|
+
def set_theme(self, theme: str) -> list:
|
|
133
|
+
"""Re-render the appropriate plot when the theme changes."""
|
|
134
|
+
if self.model is None:
|
|
135
|
+
return [no_update] * len(self.properties)
|
|
136
|
+
|
|
137
|
+
# Just call update_properties which will re-render the right plot
|
|
138
|
+
return self.update_properties(self.model, inference_run=self.inference_run)
|
|
139
|
+
|
|
140
|
+
def register_internal_callbacks(self):
|
|
141
|
+
"""Register internal callbacks for both sub-plugins."""
|
|
142
|
+
self.scatter_plot.register_internal_callbacks()
|
|
143
|
+
self.confusion_matrix.register_internal_callbacks()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
if __name__ == "__main__":
|
|
147
|
+
"""Run the Unit Test for the Plugin."""
|
|
148
|
+
from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
|
|
149
|
+
|
|
150
|
+
# Test with a classifier (shows Confusion Matrix)
|
|
151
|
+
classifier_model = CachedModel("wine-classification")
|
|
152
|
+
PluginUnitTest(ModelPlot, input_data=classifier_model, theme="dark").run()
|
|
153
|
+
|
|
154
|
+
# Test with a regressor (shows Scatter Plot)
|
|
155
|
+
regressor_model = CachedModel("abalone-regression")
|
|
156
|
+
PluginUnitTest(ModelPlot, input_data=regressor_model, theme="dark").run()
|
|
@@ -8,7 +8,6 @@ from dash.exceptions import PreventUpdate
|
|
|
8
8
|
|
|
9
9
|
# Workbench Imports
|
|
10
10
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
11
|
-
from workbench.utils.theme_manager import ThemeManager
|
|
12
11
|
from workbench.utils.plot_utils import prediction_intervals
|
|
13
12
|
from workbench.utils.chem_utils.vis import molecule_hover_tooltip
|
|
14
13
|
from workbench.utils.clientside_callbacks import circle_overlay_callback
|
|
@@ -37,7 +36,6 @@ class ScatterPlot(PluginInterface):
|
|
|
37
36
|
self.hover_columns = []
|
|
38
37
|
self.df = None
|
|
39
38
|
self.show_axes = show_axes
|
|
40
|
-
self.theme_manager = ThemeManager()
|
|
41
39
|
self.has_smiles = False # Track if dataframe has smiles column for molecule hover
|
|
42
40
|
self.smiles_column = None
|
|
43
41
|
self.id_column = None
|
|
@@ -242,6 +240,15 @@ class ScatterPlot(PluginInterface):
|
|
|
242
240
|
|
|
243
241
|
return [figure, x_options, y_options, color_options, x_default, y_default, color_default, regression_line_value]
|
|
244
242
|
|
|
243
|
+
def set_theme(self, theme: str) -> list:
|
|
244
|
+
"""Re-render the scatter plot when the theme changes."""
|
|
245
|
+
# If no data yet, return no_update for all properties
|
|
246
|
+
if self.df is None or self.df.empty:
|
|
247
|
+
return [no_update] * len(self.properties)
|
|
248
|
+
|
|
249
|
+
# Re-render with defaults (user dropdown selections reset, but theme changes are rare)
|
|
250
|
+
return self.update_properties(self.df)
|
|
251
|
+
|
|
245
252
|
def create_scatter_plot(
|
|
246
253
|
self,
|
|
247
254
|
df: pd.DataFrame,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""SHAP Summary Plot visualization component for XGBoost models"""
|
|
2
2
|
|
|
3
|
-
from dash import dcc
|
|
3
|
+
from dash import dcc, no_update
|
|
4
4
|
import plotly.graph_objects as go
|
|
5
5
|
import pandas as pd
|
|
6
6
|
import numpy as np
|
|
@@ -9,7 +9,6 @@ from typing import Dict, List
|
|
|
9
9
|
# Workbench Imports
|
|
10
10
|
from workbench.cached.cached_model import CachedModel
|
|
11
11
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
12
|
-
from workbench.utils.theme_manager import ThemeManager
|
|
13
12
|
from workbench.utils.plot_utils import beeswarm_offsets
|
|
14
13
|
|
|
15
14
|
|
|
@@ -22,7 +21,7 @@ class ShapSummaryPlot(PluginInterface):
|
|
|
22
21
|
def __init__(self):
|
|
23
22
|
"""Initialize the ShapSummaryPlot plugin class"""
|
|
24
23
|
self.component_id = None
|
|
25
|
-
self.
|
|
24
|
+
self.model = None # Store the model for re-rendering on theme change
|
|
26
25
|
super().__init__()
|
|
27
26
|
|
|
28
27
|
def create_component(self, component_id: str) -> dcc.Graph:
|
|
@@ -39,6 +38,9 @@ class ShapSummaryPlot(PluginInterface):
|
|
|
39
38
|
|
|
40
39
|
def update_properties(self, model: CachedModel, **kwargs) -> list:
|
|
41
40
|
"""Create a SHAP Summary Plot for feature importance visualization."""
|
|
41
|
+
# Store for re-rendering on theme change
|
|
42
|
+
self.model = model
|
|
43
|
+
|
|
42
44
|
# Basic validation
|
|
43
45
|
shap_data = model.shap_data()
|
|
44
46
|
shap_sample_rows = model.shap_sample()
|
|
@@ -221,9 +223,15 @@ class ShapSummaryPlot(PluginInterface):
|
|
|
221
223
|
)
|
|
222
224
|
return main_fig
|
|
223
225
|
|
|
226
|
+
def set_theme(self, theme: str) -> list:
|
|
227
|
+
"""Re-render the SHAP summary plot when the theme changes."""
|
|
228
|
+
if self.model is None:
|
|
229
|
+
return [no_update] * len(self.properties)
|
|
230
|
+
return self.update_properties(self.model)
|
|
231
|
+
|
|
224
232
|
def register_internal_callbacks(self):
|
|
225
233
|
"""Register internal callbacks for the plugin."""
|
|
226
|
-
pass #
|
|
234
|
+
pass # No internal callbacks needed
|
|
227
235
|
|
|
228
236
|
|
|
229
237
|
if __name__ == "__main__":
|