workbench 0.8.219__py3-none-any.whl → 0.8.231__py3-none-any.whl

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.
Files changed (73)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +2 -0
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  5. workbench/algorithms/dataframe/projection_2d.py +8 -2
  6. workbench/algorithms/dataframe/proximity.py +3 -0
  7. workbench/algorithms/dataframe/smart_aggregator.py +161 -0
  8. workbench/algorithms/sql/column_stats.py +0 -1
  9. workbench/algorithms/sql/correlations.py +0 -1
  10. workbench/algorithms/sql/descriptive_stats.py +0 -1
  11. workbench/api/feature_set.py +0 -1
  12. workbench/api/meta.py +0 -1
  13. workbench/cached/cached_meta.py +0 -1
  14. workbench/cached/cached_model.py +37 -7
  15. workbench/core/artifacts/endpoint_core.py +12 -2
  16. workbench/core/artifacts/feature_set_core.py +238 -225
  17. workbench/core/cloud_platform/cloud_meta.py +0 -1
  18. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  19. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  20. workbench/model_script_utils/model_script_utils.py +30 -0
  21. workbench/model_script_utils/uq_harness.py +0 -1
  22. workbench/model_scripts/chemprop/chemprop.template +196 -68
  23. workbench/model_scripts/chemprop/generated_model_script.py +197 -72
  24. workbench/model_scripts/chemprop/model_script_utils.py +30 -0
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  26. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  27. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
  29. workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
  30. workbench/model_scripts/pytorch_model/pytorch.template +47 -29
  31. workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
  32. workbench/model_scripts/script_generation.py +0 -1
  33. workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
  34. workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
  35. workbench/model_scripts/xgb_model/uq_harness.py +0 -1
  36. workbench/scripts/ml_pipeline_sqs.py +71 -2
  37. workbench/themes/dark/custom.css +85 -8
  38. workbench/themes/dark/plotly.json +6 -6
  39. workbench/themes/light/custom.css +172 -64
  40. workbench/themes/light/plotly.json +9 -9
  41. workbench/themes/midnight_blue/custom.css +82 -29
  42. workbench/themes/midnight_blue/plotly.json +1 -1
  43. workbench/utils/aws_utils.py +0 -1
  44. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  45. workbench/utils/chem_utils/projections.py +16 -6
  46. workbench/utils/chem_utils/vis.py +137 -27
  47. workbench/utils/clientside_callbacks.py +41 -0
  48. workbench/utils/markdown_utils.py +57 -0
  49. workbench/utils/model_utils.py +0 -1
  50. workbench/utils/pipeline_utils.py +0 -1
  51. workbench/utils/plot_utils.py +52 -36
  52. workbench/utils/theme_manager.py +95 -30
  53. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  54. workbench/web_interface/components/model_plot.py +2 -0
  55. workbench/web_interface/components/plugin_unit_test.py +0 -1
  56. workbench/web_interface/components/plugins/ag_table.py +2 -4
  57. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  58. workbench/web_interface/components/plugins/model_details.py +10 -6
  59. workbench/web_interface/components/plugins/scatter_plot.py +184 -85
  60. workbench/web_interface/components/settings_menu.py +185 -0
  61. workbench/web_interface/page_views/main_page.py +0 -1
  62. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
  63. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
  64. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
  65. workbench/themes/quartz/base_css.url +0 -1
  66. workbench/themes/quartz/custom.css +0 -117
  67. workbench/themes/quartz/plotly.json +0 -642
  68. workbench/themes/quartz_dark/base_css.url +0 -1
  69. workbench/themes/quartz_dark/custom.css +0 -131
  70. workbench/themes/quartz_dark/plotly.json +0 -642
  71. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
  72. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
  73. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
@@ -76,10 +76,28 @@ class ThemeManager:
76
76
  def set_theme(cls, theme_name: str):
77
77
  """Set the current theme."""
78
78
 
79
- # For 'auto', we try to grab a theme from the Parameter Store
80
- # if we can't find one, we'll set the theme to the default
79
+ # For 'auto', we check multiple sources in priority order:
80
+ # 1. Browser cookie (from localStorage, for per-user preference)
81
+ # 2. Parameter Store (for org-wide default)
82
+ # 3. Default theme
81
83
  if theme_name == "auto":
82
- theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False) or cls.default_theme
84
+ theme_name = None
85
+
86
+ # 1. Check Flask request cookie (set from localStorage)
87
+ try:
88
+ from flask import request, has_request_context
89
+
90
+ if has_request_context():
91
+ theme_name = request.cookies.get("wb_theme")
92
+ except Exception:
93
+ pass
94
+
95
+ # 2. Fall back to ParameterStore
96
+ if not theme_name:
97
+ theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False)
98
+
99
+ # 3. Fall back to default
100
+ theme_name = theme_name or cls.default_theme
83
101
 
84
102
  # Check if the theme is in our available themes
85
103
  if theme_name not in cls.available_themes:
@@ -104,9 +122,27 @@ class ThemeManager:
104
122
  cls.current_theme_name = theme_name
105
123
  cls.log.info(f"Theme set to '{theme_name}'")
106
124
 
125
+ # Bootstrap themes that are dark mode (from Bootswatch)
126
+ _dark_bootstrap_themes = {"DARKLY", "CYBORG", "SLATE", "SOLAR", "SUPERHERO", "VAPOR"}
127
+
107
128
  @classmethod
108
129
  def dark_mode(cls) -> bool:
109
- """Check if the current theme is a dark mode theme."""
130
+ """Check if the current theme is a dark mode theme.
131
+
132
+ Determines dark mode by checking if the Bootstrap base theme is a known dark theme.
133
+ Falls back to checking if 'dark' is in the theme name.
134
+ """
135
+ theme = cls.available_themes.get(cls.current_theme_name, {})
136
+ base_css = theme.get("base_css", "")
137
+
138
+ # Check if the base CSS URL contains a known dark Bootstrap theme
139
+ if base_css:
140
+ base_css_upper = base_css.upper()
141
+ for dark_theme in cls._dark_bootstrap_themes:
142
+ if dark_theme in base_css_upper:
143
+ return True
144
+
145
+ # Fallback: check if 'dark' is in the theme name
110
146
  return "dark" in cls.current_theme().lower()
111
147
 
112
148
  @classmethod
@@ -184,30 +220,57 @@ class ThemeManager:
184
220
 
185
221
  @classmethod
186
222
  def css_files(cls) -> list[str]:
187
- """Get the list of CSS files for the current theme."""
188
- theme = cls.available_themes[cls.current_theme_name]
223
+ """Get the list of CSS files for the current theme.
224
+
225
+ Note: Uses /base.css route for dynamic theme switching instead of CDN URLs.
226
+ """
189
227
  css_files = []
190
228
 
191
- # Add base.css or its CDN URL
192
- if theme["base_css"]:
193
- css_files.append(theme["base_css"])
229
+ # Use Flask route for base CSS (allows dynamic theme switching)
230
+ css_files.append("/base.css")
194
231
 
195
232
  # Add the DBC template CSS
196
233
  css_files.append(cls.dbc_css)
197
234
 
198
235
  # Add custom.css if it exists
199
- if theme["custom_css"]:
200
- css_files.append("/custom.css")
236
+ css_files.append("/custom.css")
201
237
 
202
238
  return css_files
203
239
 
240
+ @classmethod
241
+ def _get_theme_from_cookie(cls):
242
+ """Get the theme dict based on the wb_theme cookie, falling back to current theme."""
243
+ from flask import request
244
+
245
+ theme_name = request.cookies.get("wb_theme")
246
+ if theme_name and theme_name in cls.available_themes:
247
+ return cls.available_themes[theme_name], theme_name
248
+ return cls.available_themes[cls.current_theme_name], cls.current_theme_name
249
+
204
250
  @classmethod
205
251
  def register_css_route(cls, app):
206
- """Register Flask route for custom.css."""
252
+ """Register Flask routes for CSS and before_request hook for theme switching."""
253
+ from flask import redirect
254
+
255
+ @app.server.before_request
256
+ def check_theme_cookie():
257
+ """Check for theme cookie on each request and update theme if needed."""
258
+ _, theme_name = cls._get_theme_from_cookie()
259
+ if theme_name != cls.current_theme_name:
260
+ cls.set_theme(theme_name)
261
+
262
+ @app.server.route("/base.css")
263
+ def serve_base_css():
264
+ """Redirect to the appropriate Bootstrap theme CSS based on cookie."""
265
+ theme, _ = cls._get_theme_from_cookie()
266
+ if theme["base_css"]:
267
+ return redirect(theme["base_css"])
268
+ return "", 404
207
269
 
208
270
  @app.server.route("/custom.css")
209
271
  def serve_custom_css():
210
- theme = cls.available_themes[cls.current_theme_name]
272
+ """Serve the custom.css file based on cookie."""
273
+ theme, _ = cls._get_theme_from_cookie()
211
274
  if theme["custom_css"]:
212
275
  return send_from_directory(theme["custom_css"].parent, theme["custom_css"].name)
213
276
  return "", 404
@@ -250,23 +313,25 @@ class ThemeManager:
250
313
  # Loop over each path in the theme path
251
314
  for theme_path in cls.theme_path_list:
252
315
  for theme_dir in theme_path.iterdir():
253
- if theme_dir.is_dir():
254
- theme_name = theme_dir.name
255
-
256
- # Grab the base.css URL
257
- base_css_url = cls._get_base_css_url(theme_dir)
258
-
259
- # Grab the plotly template json, custom.css, and branding json
260
- plotly_template = theme_dir / "plotly.json"
261
- custom_css = theme_dir / "custom.css"
262
- branding = theme_dir / "branding.json"
263
-
264
- cls.available_themes[theme_name] = {
265
- "base_css": base_css_url,
266
- "plotly_template": plotly_template,
267
- "custom_css": custom_css if custom_css.exists() else None,
268
- "branding": branding if branding.exists() else None,
269
- }
316
+ # Skip hidden directories (e.g., .idea, .git)
317
+ if not theme_dir.is_dir() or theme_dir.name.startswith("."):
318
+ continue
319
+ theme_name = theme_dir.name
320
+
321
+ # Grab the base.css URL
322
+ base_css_url = cls._get_base_css_url(theme_dir)
323
+
324
+ # Grab the plotly template json, custom.css, and branding json
325
+ plotly_template = theme_dir / "plotly.json"
326
+ custom_css = theme_dir / "custom.css"
327
+ branding = theme_dir / "branding.json"
328
+
329
+ cls.available_themes[theme_name] = {
330
+ "base_css": base_css_url,
331
+ "plotly_template": plotly_template,
332
+ "custom_css": custom_css if custom_css.exists() else None,
333
+ "branding": branding if branding.exists() else None,
334
+ }
270
335
 
271
336
  if not cls.available_themes:
272
337
  cls.log.warning(f"No themes found in '{cls.theme_path_list}'...")
@@ -11,7 +11,6 @@ import logging
11
11
  from workbench.algorithms.dataframe.aggregation import aggregate
12
12
  from workbench.algorithms.dataframe.projection_2d import Projection2D
13
13
 
14
-
15
14
  # Workbench Logger
16
15
  log = logging.getLogger("workbench")
17
16
 
@@ -10,8 +10,10 @@ from workbench.api import Model, ModelType
10
10
  from workbench.web_interface.components.component_interface import ComponentInterface
11
11
  from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
12
12
  from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
13
+ from workbench.utils.deprecated_utils import deprecated
13
14
 
14
15
 
16
+ @deprecated(version="0.9")
15
17
  class ModelPlot(ComponentInterface):
16
18
  """Model Metrics Components"""
17
19
 
@@ -3,7 +3,6 @@ import dash_bootstrap_components as dbc
3
3
  import logging
4
4
  import socket
5
5
 
6
-
7
6
  # Workbench Imports
8
7
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginInputType
9
8
  from workbench.api import DataSource, FeatureSet, Model, Endpoint, Meta
@@ -22,9 +22,7 @@ class AGTable(PluginInterface):
22
22
  header_height = 30
23
23
  row_height = 25
24
24
 
25
- def create_component(
26
- self, component_id: str, header_color: str = "rgb(120, 60, 60)", max_height: int = 500
27
- ) -> AgGrid:
25
+ def create_component(self, component_id: str, max_height: int = 500) -> AgGrid:
28
26
  """Create a Table Component without any data."""
29
27
  self.component_id = component_id
30
28
  self.max_height = max_height
@@ -112,4 +110,4 @@ if __name__ == "__main__":
112
110
  test_df = pd.DataFrame(data)
113
111
 
114
112
  # Run the Unit Test on the Plugin
115
- PluginUnitTest(AGTable, theme="quartz", input_data=test_df, max_height=500).run()
113
+ PluginUnitTest(AGTable, theme="dark", input_data=test_df, max_height=500).run()
@@ -3,7 +3,6 @@
3
3
  from dash import dcc, callback, Output, Input, State
4
4
  import plotly.graph_objects as go
5
5
 
6
-
7
6
  # Workbench Imports
8
7
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
9
8
  from workbench.utils.theme_manager import ThemeManager
@@ -22,7 +21,6 @@ class ConfusionMatrix(PluginInterface):
22
21
  self.component_id = None
23
22
  self.current_highlight = None # Store the currently highlighted cell
24
23
  self.theme_manager = ThemeManager()
25
- self.colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
26
24
 
27
25
  # Call the parent class constructor
28
26
  super().__init__()
@@ -65,9 +63,8 @@ class ConfusionMatrix(PluginInterface):
65
63
  if df is None:
66
64
  return [self.display_text("No Data")]
67
65
 
68
- # Use Plotly's default theme-friendly colorscale
69
- # from plotly.colors import sequential
70
- # color_scale = sequential.Plasma
66
+ # Get the colorscale from the current theme
67
+ colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
71
68
 
72
69
  # The confusion matrix is displayed in reverse order (flip the dataframe for correct orientation)
73
70
  df = df.iloc[::-1]
@@ -89,7 +86,7 @@ class ConfusionMatrix(PluginInterface):
89
86
  title="Count",
90
87
  outlinewidth=1,
91
88
  ),
92
- colorscale=self.colorscale,
89
+ colorscale=colorscale,
93
90
  )
94
91
  )
95
92
 
@@ -8,9 +8,14 @@ from dash import html, callback, dcc, Input, Output, State
8
8
  # Workbench Imports
9
9
  from workbench.api import ModelType, ParameterStore
10
10
  from workbench.cached.cached_model import CachedModel
11
- from workbench.utils.markdown_utils import health_tag_markdown
11
+ from workbench.utils.markdown_utils import (
12
+ health_tag_markdown,
13
+ tags_to_markdown,
14
+ dict_to_markdown,
15
+ dict_to_collapsible_html,
16
+ df_to_html_table,
17
+ )
12
18
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
13
- from workbench.utils.markdown_utils import tags_to_markdown, dict_to_markdown, dict_to_collapsible_html
14
19
 
15
20
 
16
21
  class ModelDetails(PluginInterface):
@@ -44,7 +49,7 @@ class ModelDetails(PluginInterface):
44
49
  dcc.Markdown(id=f"{self.component_id}-summary", dangerously_allow_html=True),
45
50
  html.H5(children="Inference Metrics", style={"marginTop": "20px"}),
46
51
  dcc.Dropdown(id=f"{self.component_id}-dropdown", className="dropdown"),
47
- dcc.Markdown(id=f"{self.component_id}-metrics"),
52
+ dcc.Markdown(id=f"{self.component_id}-metrics", dangerously_allow_html=True),
48
53
  ],
49
54
  )
50
55
 
@@ -175,15 +180,14 @@ class ModelDetails(PluginInterface):
175
180
  markdown += " \nNo Data \n"
176
181
  else:
177
182
  markdown += " \n"
178
- metrics = metrics.round(3)
179
183
 
180
184
  # If the model is a classification model, have the index sorting match the class labels
181
185
  if self.current_model.model_type == ModelType.CLASSIFIER:
182
- # Sort the metrics by the class labels (if they match)
183
186
  class_labels = self.current_model.class_labels()
184
187
  if set(metrics.index) == set(class_labels):
185
188
  metrics = metrics.reindex(class_labels)
186
- markdown += metrics.to_markdown()
189
+
190
+ markdown += df_to_html_table(metrics)
187
191
 
188
192
  # Get additional inference metrics if they exist
189
193
  model_name = self.current_model.name