workbench 0.8.219__py3-none-any.whl → 0.8.231__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/feature_set.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +238 -225
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +196 -68
- workbench/model_scripts/chemprop/generated_model_script.py +197 -72
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -64
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +82 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/model_utils.py +0 -1
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +52 -36
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +10 -6
- workbench/web_interface/components/plugins/scatter_plot.py +184 -85
- workbench/web_interface/components/settings_menu.py +185 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
workbench/utils/theme_manager.py
CHANGED
|
@@ -76,10 +76,28 @@ class ThemeManager:
|
|
|
76
76
|
def set_theme(cls, theme_name: str):
|
|
77
77
|
"""Set the current theme."""
|
|
78
78
|
|
|
79
|
-
# For 'auto', we
|
|
80
|
-
#
|
|
79
|
+
# For 'auto', we check multiple sources in priority order:
|
|
80
|
+
# 1. Browser cookie (from localStorage, for per-user preference)
|
|
81
|
+
# 2. Parameter Store (for org-wide default)
|
|
82
|
+
# 3. Default theme
|
|
81
83
|
if theme_name == "auto":
|
|
82
|
-
theme_name =
|
|
84
|
+
theme_name = None
|
|
85
|
+
|
|
86
|
+
# 1. Check Flask request cookie (set from localStorage)
|
|
87
|
+
try:
|
|
88
|
+
from flask import request, has_request_context
|
|
89
|
+
|
|
90
|
+
if has_request_context():
|
|
91
|
+
theme_name = request.cookies.get("wb_theme")
|
|
92
|
+
except Exception:
|
|
93
|
+
pass
|
|
94
|
+
|
|
95
|
+
# 2. Fall back to ParameterStore
|
|
96
|
+
if not theme_name:
|
|
97
|
+
theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False)
|
|
98
|
+
|
|
99
|
+
# 3. Fall back to default
|
|
100
|
+
theme_name = theme_name or cls.default_theme
|
|
83
101
|
|
|
84
102
|
# Check if the theme is in our available themes
|
|
85
103
|
if theme_name not in cls.available_themes:
|
|
@@ -104,9 +122,27 @@ class ThemeManager:
|
|
|
104
122
|
cls.current_theme_name = theme_name
|
|
105
123
|
cls.log.info(f"Theme set to '{theme_name}'")
|
|
106
124
|
|
|
125
|
+
# Bootstrap themes that are dark mode (from Bootswatch)
|
|
126
|
+
_dark_bootstrap_themes = {"DARKLY", "CYBORG", "SLATE", "SOLAR", "SUPERHERO", "VAPOR"}
|
|
127
|
+
|
|
107
128
|
@classmethod
|
|
108
129
|
def dark_mode(cls) -> bool:
|
|
109
|
-
"""Check if the current theme is a dark mode theme.
|
|
130
|
+
"""Check if the current theme is a dark mode theme.
|
|
131
|
+
|
|
132
|
+
Determines dark mode by checking if the Bootstrap base theme is a known dark theme.
|
|
133
|
+
Falls back to checking if 'dark' is in the theme name.
|
|
134
|
+
"""
|
|
135
|
+
theme = cls.available_themes.get(cls.current_theme_name, {})
|
|
136
|
+
base_css = theme.get("base_css", "")
|
|
137
|
+
|
|
138
|
+
# Check if the base CSS URL contains a known dark Bootstrap theme
|
|
139
|
+
if base_css:
|
|
140
|
+
base_css_upper = base_css.upper()
|
|
141
|
+
for dark_theme in cls._dark_bootstrap_themes:
|
|
142
|
+
if dark_theme in base_css_upper:
|
|
143
|
+
return True
|
|
144
|
+
|
|
145
|
+
# Fallback: check if 'dark' is in the theme name
|
|
110
146
|
return "dark" in cls.current_theme().lower()
|
|
111
147
|
|
|
112
148
|
@classmethod
|
|
@@ -184,30 +220,57 @@ class ThemeManager:
|
|
|
184
220
|
|
|
185
221
|
@classmethod
|
|
186
222
|
def css_files(cls) -> list[str]:
|
|
187
|
-
"""Get the list of CSS files for the current theme.
|
|
188
|
-
|
|
223
|
+
"""Get the list of CSS files for the current theme.
|
|
224
|
+
|
|
225
|
+
Note: Uses /base.css route for dynamic theme switching instead of CDN URLs.
|
|
226
|
+
"""
|
|
189
227
|
css_files = []
|
|
190
228
|
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
css_files.append(theme["base_css"])
|
|
229
|
+
# Use Flask route for base CSS (allows dynamic theme switching)
|
|
230
|
+
css_files.append("/base.css")
|
|
194
231
|
|
|
195
232
|
# Add the DBC template CSS
|
|
196
233
|
css_files.append(cls.dbc_css)
|
|
197
234
|
|
|
198
235
|
# Add custom.css if it exists
|
|
199
|
-
|
|
200
|
-
css_files.append("/custom.css")
|
|
236
|
+
css_files.append("/custom.css")
|
|
201
237
|
|
|
202
238
|
return css_files
|
|
203
239
|
|
|
240
|
+
@classmethod
|
|
241
|
+
def _get_theme_from_cookie(cls):
|
|
242
|
+
"""Get the theme dict based on the wb_theme cookie, falling back to current theme."""
|
|
243
|
+
from flask import request
|
|
244
|
+
|
|
245
|
+
theme_name = request.cookies.get("wb_theme")
|
|
246
|
+
if theme_name and theme_name in cls.available_themes:
|
|
247
|
+
return cls.available_themes[theme_name], theme_name
|
|
248
|
+
return cls.available_themes[cls.current_theme_name], cls.current_theme_name
|
|
249
|
+
|
|
204
250
|
@classmethod
|
|
205
251
|
def register_css_route(cls, app):
|
|
206
|
-
"""Register Flask
|
|
252
|
+
"""Register Flask routes for CSS and before_request hook for theme switching."""
|
|
253
|
+
from flask import redirect
|
|
254
|
+
|
|
255
|
+
@app.server.before_request
|
|
256
|
+
def check_theme_cookie():
|
|
257
|
+
"""Check for theme cookie on each request and update theme if needed."""
|
|
258
|
+
_, theme_name = cls._get_theme_from_cookie()
|
|
259
|
+
if theme_name != cls.current_theme_name:
|
|
260
|
+
cls.set_theme(theme_name)
|
|
261
|
+
|
|
262
|
+
@app.server.route("/base.css")
|
|
263
|
+
def serve_base_css():
|
|
264
|
+
"""Redirect to the appropriate Bootstrap theme CSS based on cookie."""
|
|
265
|
+
theme, _ = cls._get_theme_from_cookie()
|
|
266
|
+
if theme["base_css"]:
|
|
267
|
+
return redirect(theme["base_css"])
|
|
268
|
+
return "", 404
|
|
207
269
|
|
|
208
270
|
@app.server.route("/custom.css")
|
|
209
271
|
def serve_custom_css():
|
|
210
|
-
|
|
272
|
+
"""Serve the custom.css file based on cookie."""
|
|
273
|
+
theme, _ = cls._get_theme_from_cookie()
|
|
211
274
|
if theme["custom_css"]:
|
|
212
275
|
return send_from_directory(theme["custom_css"].parent, theme["custom_css"].name)
|
|
213
276
|
return "", 404
|
|
@@ -250,23 +313,25 @@ class ThemeManager:
|
|
|
250
313
|
# Loop over each path in the theme path
|
|
251
314
|
for theme_path in cls.theme_path_list:
|
|
252
315
|
for theme_dir in theme_path.iterdir():
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
316
|
+
# Skip hidden directories (e.g., .idea, .git)
|
|
317
|
+
if not theme_dir.is_dir() or theme_dir.name.startswith("."):
|
|
318
|
+
continue
|
|
319
|
+
theme_name = theme_dir.name
|
|
320
|
+
|
|
321
|
+
# Grab the base.css URL
|
|
322
|
+
base_css_url = cls._get_base_css_url(theme_dir)
|
|
323
|
+
|
|
324
|
+
# Grab the plotly template json, custom.css, and branding json
|
|
325
|
+
plotly_template = theme_dir / "plotly.json"
|
|
326
|
+
custom_css = theme_dir / "custom.css"
|
|
327
|
+
branding = theme_dir / "branding.json"
|
|
328
|
+
|
|
329
|
+
cls.available_themes[theme_name] = {
|
|
330
|
+
"base_css": base_css_url,
|
|
331
|
+
"plotly_template": plotly_template,
|
|
332
|
+
"custom_css": custom_css if custom_css.exists() else None,
|
|
333
|
+
"branding": branding if branding.exists() else None,
|
|
334
|
+
}
|
|
270
335
|
|
|
271
336
|
if not cls.available_themes:
|
|
272
337
|
cls.log.warning(f"No themes found in '{cls.theme_path_list}'...")
|
|
@@ -10,8 +10,10 @@ from workbench.api import Model, ModelType
|
|
|
10
10
|
from workbench.web_interface.components.component_interface import ComponentInterface
|
|
11
11
|
from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
|
|
12
12
|
from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
|
|
13
|
+
from workbench.utils.deprecated_utils import deprecated
|
|
13
14
|
|
|
14
15
|
|
|
16
|
+
@deprecated(version="0.9")
|
|
15
17
|
class ModelPlot(ComponentInterface):
|
|
16
18
|
"""Model Metrics Components"""
|
|
17
19
|
|
|
@@ -3,7 +3,6 @@ import dash_bootstrap_components as dbc
|
|
|
3
3
|
import logging
|
|
4
4
|
import socket
|
|
5
5
|
|
|
6
|
-
|
|
7
6
|
# Workbench Imports
|
|
8
7
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginInputType
|
|
9
8
|
from workbench.api import DataSource, FeatureSet, Model, Endpoint, Meta
|
|
@@ -22,9 +22,7 @@ class AGTable(PluginInterface):
|
|
|
22
22
|
header_height = 30
|
|
23
23
|
row_height = 25
|
|
24
24
|
|
|
25
|
-
def create_component(
|
|
26
|
-
self, component_id: str, header_color: str = "rgb(120, 60, 60)", max_height: int = 500
|
|
27
|
-
) -> AgGrid:
|
|
25
|
+
def create_component(self, component_id: str, max_height: int = 500) -> AgGrid:
|
|
28
26
|
"""Create a Table Component without any data."""
|
|
29
27
|
self.component_id = component_id
|
|
30
28
|
self.max_height = max_height
|
|
@@ -112,4 +110,4 @@ if __name__ == "__main__":
|
|
|
112
110
|
test_df = pd.DataFrame(data)
|
|
113
111
|
|
|
114
112
|
# Run the Unit Test on the Plugin
|
|
115
|
-
PluginUnitTest(AGTable, theme="
|
|
113
|
+
PluginUnitTest(AGTable, theme="dark", input_data=test_df, max_height=500).run()
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
from dash import dcc, callback, Output, Input, State
|
|
4
4
|
import plotly.graph_objects as go
|
|
5
5
|
|
|
6
|
-
|
|
7
6
|
# Workbench Imports
|
|
8
7
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
9
8
|
from workbench.utils.theme_manager import ThemeManager
|
|
@@ -22,7 +21,6 @@ class ConfusionMatrix(PluginInterface):
|
|
|
22
21
|
self.component_id = None
|
|
23
22
|
self.current_highlight = None # Store the currently highlighted cell
|
|
24
23
|
self.theme_manager = ThemeManager()
|
|
25
|
-
self.colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
|
|
26
24
|
|
|
27
25
|
# Call the parent class constructor
|
|
28
26
|
super().__init__()
|
|
@@ -65,9 +63,8 @@ class ConfusionMatrix(PluginInterface):
|
|
|
65
63
|
if df is None:
|
|
66
64
|
return [self.display_text("No Data")]
|
|
67
65
|
|
|
68
|
-
#
|
|
69
|
-
|
|
70
|
-
# color_scale = sequential.Plasma
|
|
66
|
+
# Get the colorscale from the current theme
|
|
67
|
+
colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
|
|
71
68
|
|
|
72
69
|
# The confusion matrix is displayed in reverse order (flip the dataframe for correct orientation)
|
|
73
70
|
df = df.iloc[::-1]
|
|
@@ -89,7 +86,7 @@ class ConfusionMatrix(PluginInterface):
|
|
|
89
86
|
title="Count",
|
|
90
87
|
outlinewidth=1,
|
|
91
88
|
),
|
|
92
|
-
colorscale=
|
|
89
|
+
colorscale=colorscale,
|
|
93
90
|
)
|
|
94
91
|
)
|
|
95
92
|
|
|
@@ -8,9 +8,14 @@ from dash import html, callback, dcc, Input, Output, State
|
|
|
8
8
|
# Workbench Imports
|
|
9
9
|
from workbench.api import ModelType, ParameterStore
|
|
10
10
|
from workbench.cached.cached_model import CachedModel
|
|
11
|
-
from workbench.utils.markdown_utils import
|
|
11
|
+
from workbench.utils.markdown_utils import (
|
|
12
|
+
health_tag_markdown,
|
|
13
|
+
tags_to_markdown,
|
|
14
|
+
dict_to_markdown,
|
|
15
|
+
dict_to_collapsible_html,
|
|
16
|
+
df_to_html_table,
|
|
17
|
+
)
|
|
12
18
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
13
|
-
from workbench.utils.markdown_utils import tags_to_markdown, dict_to_markdown, dict_to_collapsible_html
|
|
14
19
|
|
|
15
20
|
|
|
16
21
|
class ModelDetails(PluginInterface):
|
|
@@ -44,7 +49,7 @@ class ModelDetails(PluginInterface):
|
|
|
44
49
|
dcc.Markdown(id=f"{self.component_id}-summary", dangerously_allow_html=True),
|
|
45
50
|
html.H5(children="Inference Metrics", style={"marginTop": "20px"}),
|
|
46
51
|
dcc.Dropdown(id=f"{self.component_id}-dropdown", className="dropdown"),
|
|
47
|
-
dcc.Markdown(id=f"{self.component_id}-metrics"),
|
|
52
|
+
dcc.Markdown(id=f"{self.component_id}-metrics", dangerously_allow_html=True),
|
|
48
53
|
],
|
|
49
54
|
)
|
|
50
55
|
|
|
@@ -175,15 +180,14 @@ class ModelDetails(PluginInterface):
|
|
|
175
180
|
markdown += " \nNo Data \n"
|
|
176
181
|
else:
|
|
177
182
|
markdown += " \n"
|
|
178
|
-
metrics = metrics.round(3)
|
|
179
183
|
|
|
180
184
|
# If the model is a classification model, have the index sorting match the class labels
|
|
181
185
|
if self.current_model.model_type == ModelType.CLASSIFIER:
|
|
182
|
-
# Sort the metrics by the class labels (if they match)
|
|
183
186
|
class_labels = self.current_model.class_labels()
|
|
184
187
|
if set(metrics.index) == set(class_labels):
|
|
185
188
|
metrics = metrics.reindex(class_labels)
|
|
186
|
-
|
|
189
|
+
|
|
190
|
+
markdown += df_to_html_table(metrics)
|
|
187
191
|
|
|
188
192
|
# Get additional inference metrics if they exist
|
|
189
193
|
model_name = self.current_model.name
|