workbench 0.8.219__py3-none-any.whl → 0.8.231__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +2 -0
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  5. workbench/algorithms/dataframe/projection_2d.py +8 -2
  6. workbench/algorithms/dataframe/proximity.py +3 -0
  7. workbench/algorithms/dataframe/smart_aggregator.py +161 -0
  8. workbench/algorithms/sql/column_stats.py +0 -1
  9. workbench/algorithms/sql/correlations.py +0 -1
  10. workbench/algorithms/sql/descriptive_stats.py +0 -1
  11. workbench/api/feature_set.py +0 -1
  12. workbench/api/meta.py +0 -1
  13. workbench/cached/cached_meta.py +0 -1
  14. workbench/cached/cached_model.py +37 -7
  15. workbench/core/artifacts/endpoint_core.py +12 -2
  16. workbench/core/artifacts/feature_set_core.py +238 -225
  17. workbench/core/cloud_platform/cloud_meta.py +0 -1
  18. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  19. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  20. workbench/model_script_utils/model_script_utils.py +30 -0
  21. workbench/model_script_utils/uq_harness.py +0 -1
  22. workbench/model_scripts/chemprop/chemprop.template +196 -68
  23. workbench/model_scripts/chemprop/generated_model_script.py +197 -72
  24. workbench/model_scripts/chemprop/model_script_utils.py +30 -0
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  26. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  27. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
  29. workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
  30. workbench/model_scripts/pytorch_model/pytorch.template +47 -29
  31. workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
  32. workbench/model_scripts/script_generation.py +0 -1
  33. workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
  34. workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
  35. workbench/model_scripts/xgb_model/uq_harness.py +0 -1
  36. workbench/scripts/ml_pipeline_sqs.py +71 -2
  37. workbench/themes/dark/custom.css +85 -8
  38. workbench/themes/dark/plotly.json +6 -6
  39. workbench/themes/light/custom.css +172 -64
  40. workbench/themes/light/plotly.json +9 -9
  41. workbench/themes/midnight_blue/custom.css +82 -29
  42. workbench/themes/midnight_blue/plotly.json +1 -1
  43. workbench/utils/aws_utils.py +0 -1
  44. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  45. workbench/utils/chem_utils/projections.py +16 -6
  46. workbench/utils/chem_utils/vis.py +137 -27
  47. workbench/utils/clientside_callbacks.py +41 -0
  48. workbench/utils/markdown_utils.py +57 -0
  49. workbench/utils/model_utils.py +0 -1
  50. workbench/utils/pipeline_utils.py +0 -1
  51. workbench/utils/plot_utils.py +52 -36
  52. workbench/utils/theme_manager.py +95 -30
  53. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  54. workbench/web_interface/components/model_plot.py +2 -0
  55. workbench/web_interface/components/plugin_unit_test.py +0 -1
  56. workbench/web_interface/components/plugins/ag_table.py +2 -4
  57. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  58. workbench/web_interface/components/plugins/model_details.py +10 -6
  59. workbench/web_interface/components/plugins/scatter_plot.py +184 -85
  60. workbench/web_interface/components/settings_menu.py +185 -0
  61. workbench/web_interface/page_views/main_page.py +0 -1
  62. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
  63. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
  64. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
  65. workbench/themes/quartz/base_css.url +0 -1
  66. workbench/themes/quartz/custom.css +0 -117
  67. workbench/themes/quartz/plotly.json +0 -642
  68. workbench/themes/quartz_dark/base_css.url +0 -1
  69. workbench/themes/quartz_dark/custom.css +0 -131
  70. workbench/themes/quartz_dark/plotly.json +0 -642
  71. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
  72. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
  73. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  import base64
2
+ import numpy as np
2
3
  import pandas as pd
3
- from dash import dcc, html, callback, Input, Output, no_update
4
+ from dash import dcc, html, callback, clientside_callback, Input, Output, no_update
4
5
  import plotly.graph_objects as go
5
6
  import plotly.express as px
6
7
  from dash.exceptions import PreventUpdate
@@ -9,6 +10,8 @@ from dash.exceptions import PreventUpdate
9
10
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
10
11
  from workbench.utils.theme_manager import ThemeManager
11
12
  from workbench.utils.plot_utils import prediction_intervals
13
+ from workbench.utils.chem_utils.vis import molecule_hover_tooltip
14
+ from workbench.utils.clientside_callbacks import circle_overlay_callback
12
15
 
13
16
 
14
17
  class ScatterPlot(PluginInterface):
@@ -18,6 +21,12 @@ class ScatterPlot(PluginInterface):
18
21
  auto_load_page = PluginPage.NONE
19
22
  plugin_input_type = PluginInputType.DATAFRAME
20
23
 
24
+ # Pre-computed circle overlay SVG
25
+ _circle_svg = """<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" style="overflow: visible;">
26
+ <circle cx="50" cy="50" r="10" stroke="rgba(255, 255, 255, 1)" stroke-width="3" fill="none" />
27
+ </svg>"""
28
+ _circle_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(_circle_svg.encode('utf-8')).decode('utf-8')}"
29
+
21
30
  def __init__(self, show_axes: bool = True):
22
31
  """Initialize the Scatter Plot Plugin
23
32
 
@@ -29,7 +38,10 @@ class ScatterPlot(PluginInterface):
29
38
  self.df = None
30
39
  self.show_axes = show_axes
31
40
  self.theme_manager = ThemeManager()
32
- self.colorscale = self.theme_manager.colorscale()
41
+ self.has_smiles = False # Track if dataframe has smiles column for molecule hover
42
+ self.smiles_column = None
43
+ self.id_column = None
44
+ self.hover_background = None # Cached background color for molecule hover tooltip
33
45
 
34
46
  # Call the parent class constructor
35
47
  super().__init__()
@@ -51,10 +63,10 @@ class ScatterPlot(PluginInterface):
51
63
  (f"{component_id}-x-dropdown", "options"),
52
64
  (f"{component_id}-y-dropdown", "options"),
53
65
  (f"{component_id}-color-dropdown", "options"),
54
- (f"{component_id}-label-dropdown", "options"),
55
66
  (f"{component_id}-x-dropdown", "value"),
56
67
  (f"{component_id}-y-dropdown", "value"),
57
68
  (f"{component_id}-color-dropdown", "value"),
69
+ (f"{component_id}-regression-line", "value"),
58
70
  ]
59
71
  self.signals = [(f"{component_id}-graph", "hoverData"), (f"{component_id}-graph", "clickData")]
60
72
 
@@ -69,51 +81,67 @@ class ScatterPlot(PluginInterface):
69
81
  id=f"{component_id}-graph",
70
82
  figure=self.display_text("Waiting for Data..."),
71
83
  config={"scrollZoom": True},
72
- style={"height": "100%"},
84
+ style={"height": "500px", "width": "100%"},
73
85
  clear_on_unhover=True,
74
86
  ),
75
87
  # Controls: X, Y, Color, Label Dropdowns, and Regression Line Checkbox
76
88
  html.Div(
77
89
  [
78
- html.Label("X", style={"marginLeft": "40px", "marginRight": "5px", "fontWeight": "bold"}),
90
+ html.Label(
91
+ "X",
92
+ style={
93
+ "marginLeft": "20px",
94
+ "marginRight": "5px",
95
+ "fontWeight": "bold",
96
+ "display": "flex",
97
+ "alignItems": "center",
98
+ },
99
+ ),
79
100
  dcc.Dropdown(
80
101
  id=f"{component_id}-x-dropdown",
81
- className="dropdown",
82
- style={"min-width": "50px", "flex": 1}, # Responsive width
102
+ style={"minWidth": "150px", "flex": 1},
83
103
  clearable=False,
84
104
  ),
85
- html.Label("Y", style={"marginLeft": "30px", "marginRight": "5px", "fontWeight": "bold"}),
105
+ html.Label(
106
+ "Y",
107
+ style={
108
+ "marginLeft": "20px",
109
+ "marginRight": "5px",
110
+ "fontWeight": "bold",
111
+ "display": "flex",
112
+ "alignItems": "center",
113
+ },
114
+ ),
86
115
  dcc.Dropdown(
87
116
  id=f"{component_id}-y-dropdown",
88
- className="dropdown",
89
- style={"min-width": "50px", "flex": 1}, # Responsive width
117
+ style={"minWidth": "150px", "flex": 1},
90
118
  clearable=False,
91
119
  ),
92
- html.Label("Color", style={"marginLeft": "30px", "marginRight": "5px", "fontWeight": "bold"}),
93
- dcc.Dropdown(
94
- id=f"{component_id}-color-dropdown",
95
- className="dropdown",
96
- style={"min-width": "50px", "flex": 1}, # Responsive width
97
- clearable=False,
120
+ html.Label(
121
+ "Color",
122
+ style={
123
+ "marginLeft": "20px",
124
+ "marginRight": "5px",
125
+ "fontWeight": "bold",
126
+ "display": "flex",
127
+ "alignItems": "center",
128
+ },
98
129
  ),
99
- html.Label("Label", style={"marginLeft": "30px", "marginRight": "5px", "fontWeight": "bold"}),
100
130
  dcc.Dropdown(
101
- id=f"{component_id}-label-dropdown",
102
- className="dropdown",
103
- style={"min-width": "50px", "flex": 1},
104
- options=[{"label": "None", "value": "none"}],
105
- value="none",
131
+ id=f"{component_id}-color-dropdown",
132
+ style={"minWidth": "150px", "flex": 1},
106
133
  clearable=False,
107
134
  ),
108
135
  dcc.Checklist(
109
136
  id=f"{component_id}-regression-line",
110
137
  options=[{"label": " Diagonal", "value": "show"}],
111
138
  value=[],
112
- style={"margin": "10px"},
139
+ style={"marginLeft": "20px", "display": "flex", "alignItems": "center"},
113
140
  ),
114
141
  ],
115
- style={"padding": "0px 0px 10px 0px", "display": "flex", "gap": "10px"},
142
+ style={"padding": "0px 0px 10px 0px", "display": "flex", "alignItems": "center", "gap": "5px"},
116
143
  ),
144
+ # Circle overlay tooltip (centered on hovered point)
117
145
  dcc.Tooltip(
118
146
  id=f"{component_id}-overlay",
119
147
  background_color="rgba(0,0,0,0)",
@@ -121,6 +149,14 @@ class ScatterPlot(PluginInterface):
121
149
  direction="bottom",
122
150
  loading_text="",
123
151
  ),
152
+ # Molecule tooltip (offset from hovered point) - only used when smiles column exists
153
+ dcc.Tooltip(
154
+ id=f"{component_id}-molecule-tooltip",
155
+ background_color="rgba(0,0,0,0)",
156
+ border_color="rgba(0,0,0,0)",
157
+ direction="bottom",
158
+ loading_text="",
159
+ ),
124
160
  ],
125
161
  style={"height": "100%", "display": "flex", "flexDirection": "column"}, # Full viewport height
126
162
  )
@@ -139,12 +175,16 @@ class ScatterPlot(PluginInterface):
139
175
  - hover_columns: The columns to show when hovering over a point
140
176
  - suppress_hover_display: Suppress hover display (default: False)
141
177
  - custom_data: Custom data that get passed to hoverData callbacks
178
+ - id_column: Column to use for molecule tooltip header (auto-detects "id" if not specified)
142
179
 
143
180
  Returns:
144
181
  list: A list of updated property values (figure, x options, y options, color options,
145
- label options, x default, y default,
146
- color default).
182
+ x default, y default, color default).
147
183
  """
184
+ # Get the colorscale and background color from the current theme
185
+ self.colorscale = self.theme_manager.colorscale()
186
+ self.hover_background = self.theme_manager.background()
187
+
148
188
  # Get the limit for the number of rows to plot
149
189
  limit = kwargs.get("limit", 20000)
150
190
 
@@ -163,19 +203,28 @@ class ScatterPlot(PluginInterface):
163
203
  self.suppress_hover_display = kwargs.get("suppress_hover_display", False)
164
204
  self.custom_data = kwargs.get("custom_data", [])
165
205
 
206
+ # Check if the dataframe has smiles/id columns for molecule hover rendering
207
+ self.smiles_column = next((col for col in self.df.columns if col.lower() == "smiles"), None)
208
+ # Use provided id_column, or auto-detect "id" column, or fall back to first column
209
+ self.id_column = kwargs.get("id_column") or next(
210
+ (col for col in self.df.columns if col.lower() == "id"), self.df.columns[0]
211
+ )
212
+ self.has_smiles = self.smiles_column is not None
213
+
166
214
  # Identify numeric columns
167
215
  numeric_columns = self.df.select_dtypes(include="number").columns.tolist()
168
216
  if len(numeric_columns) < 3:
169
217
  raise ValueError("At least three numeric columns are required for x, y, and color.")
170
218
 
171
- # Default x, y, and color (for color, default to a numeric column)
219
+ # Default x, y, and color (for color, prefer 'confidence' if it exists)
172
220
  x_default = kwargs.get("x", numeric_columns[0])
173
221
  y_default = kwargs.get("y", numeric_columns[1])
174
- color_default = kwargs.get("color", numeric_columns[2])
222
+ default_color = "confidence" if "confidence" in self.df.columns else numeric_columns[2]
223
+ color_default = kwargs.get("color", default_color)
175
224
  regression_line = kwargs.get("regression_line", False)
176
225
 
177
226
  # Create the default scatter plot
178
- figure = self.create_scatter_plot(self.df, x_default, y_default, color_default, "none", regression_line)
227
+ figure = self.create_scatter_plot(self.df, x_default, y_default, color_default, regression_line)
179
228
 
180
229
  # Dropdown options for x and y: use provided dropdown_columns or fallback to numeric columns
181
230
  dropdown_columns = kwargs.get("dropdown_columns", numeric_columns)
@@ -188,11 +237,10 @@ class ScatterPlot(PluginInterface):
188
237
  color_columns = numeric_columns + cat_columns
189
238
  color_options = [{"label": col, "value": col} for col in color_columns]
190
239
 
191
- # For label dropdown, include None option and all columns
192
- label_options = [{"label": "None", "value": "none"}]
193
- label_options.extend([{"label": col, "value": col} for col in self.df.columns])
240
+ # Regression line checklist value (list with "show" if enabled, empty list if disabled)
241
+ regression_line_value = ["show"] if regression_line else []
194
242
 
195
- return [figure, x_options, y_options, color_options, label_options, x_default, y_default, color_default]
243
+ return [figure, x_options, y_options, color_options, x_default, y_default, color_default, regression_line_value]
196
244
 
197
245
  def create_scatter_plot(
198
246
  self,
@@ -200,9 +248,7 @@ class ScatterPlot(PluginInterface):
200
248
  x_col: str,
201
249
  y_col: str,
202
250
  color_col: str,
203
- label_col: str,
204
251
  regression_line: bool = False,
205
- marker_size: int = 15,
206
252
  ) -> go.Figure:
207
253
  """Create a Plotly Scatter Plot figure.
208
254
 
@@ -211,24 +257,46 @@ class ScatterPlot(PluginInterface):
211
257
  x_col (str): The column to use for the x-axis.
212
258
  y_col (str): The column to use for the y-axis.
213
259
  color_col (str): The column to use for the color scale.
214
- label_col (str): The column to use for point labels.
215
260
  regression_line (bool): Whether to include a regression line.
216
- marker_size (int): Size of the markers. Default is 15.
217
261
 
218
262
  Returns:
219
263
  go.Figure: A Plotly Figure object.
220
264
  """
221
- # Check if we need to show labels
222
- show_labels = label_col != "none" and len(df) < 1000
265
+
266
+ # If aggregation_count is present, sort so largest counts are drawn first (underneath)
267
+ # and compute marker sizes using square root (between log and linear)
268
+ if "aggregation_count" in df.columns:
269
+ df = df.sort_values("aggregation_count", ascending=False).reset_index(drop=True)
270
+ # Scale: base_size (15) + (sqrt(count) - 1) * factor, so count=1 stays at base_size
271
+ marker_sizes = 15 + (np.sqrt(df["aggregation_count"]) - 1) * 3
272
+ else:
273
+ marker_sizes = 15
223
274
 
224
275
  # Helper to generate hover text for each point.
225
276
  def generate_hover_text(row):
226
277
  return "<br>".join([f"{col}: {row[col]}" for col in self.hover_columns])
227
278
 
228
- # Generate hover text for all points.
229
- hovertext = df.apply(generate_hover_text, axis=1)
230
- hovertemplate = "%{hovertext}<extra></extra>"
231
- hoverinfo = "none" if self.suppress_hover_display else None
279
+ # Generate hover text for all points (unless suppressed or using molecule hover)
280
+ suppress_hover = self.suppress_hover_display or self.has_smiles
281
+ if suppress_hover:
282
+ # Use "none" to hide the default hover display but still fire hoverData callbacks
283
+ # Don't set hovertemplate when suppressing - it would override hoverinfo
284
+ hovertext = None
285
+ hovertemplate = None
286
+ hoverinfo = "none"
287
+ else:
288
+ hovertext = df.apply(generate_hover_text, axis=1)
289
+ hovertemplate = "%{hovertext}<extra></extra>"
290
+ hoverinfo = None
291
+
292
+ # Build customdata columns - include smiles and id if available for molecule hover
293
+ custom_data_cols = list(self.custom_data) if self.custom_data else []
294
+ if self.has_smiles:
295
+ # Add smiles as first column, id as second (if available)
296
+ if self.smiles_column not in custom_data_cols:
297
+ custom_data_cols = [self.smiles_column] + custom_data_cols
298
+ if self.id_column and self.id_column not in custom_data_cols:
299
+ custom_data_cols.insert(1, self.id_column)
232
300
 
233
301
  # Determine marker settings based on the type of the color column.
234
302
  if pd.api.types.is_numeric_dtype(df[color_col]):
@@ -240,18 +308,16 @@ class ScatterPlot(PluginInterface):
240
308
  x=df[x_col],
241
309
  y=df[y_col],
242
310
  mode="markers",
243
- text=df[label_col].astype(str) if show_labels else None,
244
- textposition="top center",
245
311
  hoverinfo=hoverinfo,
246
312
  hovertext=hovertext,
247
313
  hovertemplate=hovertemplate,
248
- customdata=df[self.custom_data],
314
+ customdata=df[custom_data_cols] if custom_data_cols else None,
249
315
  marker=dict(
250
- size=marker_size,
316
+ size=marker_sizes,
251
317
  color=marker_color,
252
318
  colorscale=self.colorscale,
253
319
  colorbar=colorbar,
254
- opacity=0.8,
320
+ opacity=0.9,
255
321
  line=dict(color="rgba(0,0,0,0.25)", width=1),
256
322
  ),
257
323
  )
@@ -266,20 +332,27 @@ class ScatterPlot(PluginInterface):
266
332
  data = []
267
333
  for i, cat in enumerate(categories):
268
334
  sub_df = df[df[color_col] == cat]
269
- sub_hovertext = hovertext.loc[sub_df.index]
335
+ sub_hovertext = hovertext.loc[sub_df.index] if hovertext is not None else None
336
+ # Get marker sizes for this subset (handles both array and scalar)
337
+ if isinstance(marker_sizes, (pd.Series, np.ndarray)):
338
+ sub_marker_sizes = (
339
+ marker_sizes.loc[sub_df.index]
340
+ if isinstance(marker_sizes, pd.Series)
341
+ else marker_sizes[sub_df.index]
342
+ )
343
+ else:
344
+ sub_marker_sizes = marker_sizes
270
345
  trace = go.Scattergl(
271
346
  x=sub_df[x_col],
272
347
  y=sub_df[y_col],
273
348
  mode="markers",
274
- text=sub_df[label_col] if show_labels else None, # Add text if labels enabled
275
- textposition="top center", # Position labels above points
276
349
  name=cat,
277
350
  hoverinfo=hoverinfo,
278
351
  hovertext=sub_hovertext,
279
352
  hovertemplate=hovertemplate,
280
- customdata=sub_df[self.custom_data],
353
+ customdata=sub_df[custom_data_cols] if custom_data_cols else None,
281
354
  marker=dict(
282
- size=marker_size,
355
+ size=sub_marker_sizes,
283
356
  color=discrete_colors[i % len(discrete_colors)],
284
357
  opacity=0.8,
285
358
  line=dict(color="rgba(0,0,0,0.25)", width=1),
@@ -345,64 +418,73 @@ class ScatterPlot(PluginInterface):
345
418
  Input(f"{self.component_id}-x-dropdown", "value"),
346
419
  Input(f"{self.component_id}-y-dropdown", "value"),
347
420
  Input(f"{self.component_id}-color-dropdown", "value"),
348
- Input(f"{self.component_id}-label-dropdown", "value"),
349
421
  Input(f"{self.component_id}-regression-line", "value"),
350
422
  ],
351
423
  prevent_initial_call=True,
352
424
  )
353
- def _update_scatter_plot(x_value, y_value, color_value, label_value, regression_line):
425
+ def _update_scatter_plot(x_value, y_value, color_value, regression_line):
354
426
  """Update the Scatter Plot Graph based on the dropdown values."""
355
427
 
356
428
  # Check if the dataframe is not empty and the values are not None
357
429
  if not self.df.empty and x_value and y_value and color_value:
358
- # Update Plotly Scatter Plot with the label value
359
- figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, label_value, regression_line)
430
+ figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, regression_line)
360
431
  return figure
361
432
 
362
433
  raise PreventUpdate
363
434
 
364
- @callback(
435
+ # Clientside callback for circle overlay - runs in browser, no server round trip
436
+ clientside_callback(
437
+ circle_overlay_callback(self._circle_data_uri),
365
438
  Output(f"{self.component_id}-overlay", "show"),
366
439
  Output(f"{self.component_id}-overlay", "bbox"),
367
440
  Output(f"{self.component_id}-overlay", "children"),
368
441
  Input(f"{self.component_id}-graph", "hoverData"),
369
442
  )
370
- def _scatter_overlay(hover_data):
371
- if hover_data is None:
372
- # Hide the overlay if no hover data
373
- return False, no_update, no_update
374
-
375
- # Extract bounding box from hoverData
376
- bbox = hover_data["points"][0]["bbox"]
377
443
 
378
- # Create an SVG with a circle at the center
379
- svg = """
380
- <svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" style="overflow: visible;">
381
- <!-- Circle for the node -->
382
- <circle cx="50" cy="50" r="10" stroke="rgba(255, 255, 255, 1)" stroke-width="3" fill="none" />
383
- </svg>
384
- """
444
+ @callback(
445
+ Output(f"{self.component_id}-molecule-tooltip", "show"),
446
+ Output(f"{self.component_id}-molecule-tooltip", "bbox"),
447
+ Output(f"{self.component_id}-molecule-tooltip", "children"),
448
+ Input(f"{self.component_id}-graph", "hoverData"),
449
+ )
450
+ def _scatter_molecule_overlay(hover_data):
451
+ """Show molecule tooltip when smiles data is available."""
452
+ if hover_data is None or not self.has_smiles:
453
+ return False, no_update, no_update
385
454
 
386
- # Encode the SVG as Base64
387
- encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
388
- data_uri = f"data:image/svg+xml;base64,{encoded_svg}"
455
+ # Extract customdata (contains smiles and id)
456
+ customdata = hover_data["points"][0].get("customdata")
457
+ if customdata is None:
458
+ return False, no_update, no_update
389
459
 
390
- # Use an img tag for the overlay
391
- svg_image = html.Img(src=data_uri, style={"width": "100px", "height": "100px"})
460
+ # SMILES is the first element, ID is second (if available)
461
+ if isinstance(customdata, (list, tuple)):
462
+ smiles = customdata[0]
463
+ mol_id = customdata[1] if len(customdata) > 1 and self.id_column else None
464
+ else:
465
+ smiles = customdata
466
+ mol_id = None
467
+
468
+ # Generate molecule tooltip with ID header (use cached background color)
469
+ mol_width, mol_height = 300, 200
470
+ children = molecule_hover_tooltip(
471
+ smiles, mol_id=mol_id, width=mol_width, height=mol_height, background=self.hover_background
472
+ )
392
473
 
393
- # Get the center of the bounding box
474
+ # Position molecule tooltip above and slightly right of the point
475
+ bbox = hover_data["points"][0]["bbox"]
394
476
  center_x = (bbox["x0"] + bbox["x1"]) / 2
395
477
  center_y = (bbox["y0"] + bbox["y1"]) / 2
478
+ x_offset = 5 # Slight offset to the right
479
+ y_offset = mol_height + 50 # Above the point
396
480
 
397
- # The tooltip should be centered on the point (note: 'bottom' tooltip, so we adjust y position)
398
481
  adjusted_bbox = {
399
- "x0": center_x - 50,
400
- "x1": center_x + 50,
401
- "y0": center_y - 162,
402
- "y1": center_y - 62,
482
+ "x0": center_x + x_offset,
483
+ "x1": center_x + x_offset + mol_width,
484
+ "y0": center_y - mol_height - y_offset,
485
+ "y1": center_y - y_offset,
403
486
  }
404
- # Return the updated values for the overlay
405
- return True, adjusted_bbox, [svg_image]
487
+ return True, adjusted_bbox, children
406
488
 
407
489
 
408
490
  if __name__ == "__main__":
@@ -426,6 +508,8 @@ if __name__ == "__main__":
426
508
  df = model.get_inference_predictions("full_cross_fold")
427
509
 
428
510
  # Run the Unit Test on the Plugin
511
+ # Test currently commented out
512
+ """
429
513
  PluginUnitTest(
430
514
  ScatterPlot,
431
515
  input_data=df,
@@ -435,3 +519,18 @@ if __name__ == "__main__":
435
519
  color="prediction_std",
436
520
  suppress_hover_display=True,
437
521
  ).run()
522
+ """
523
+
524
+ # Test with molecule hover (smiles column)
525
+ from workbench.api import FeatureSet
526
+
527
+ fs = FeatureSet("aqsol_features")
528
+ mol_df = fs.pull_dataframe()[:1000] # Limit to 1000 rows for testing
529
+
530
+ # Run the Unit Test with molecule data (hover over points to see molecule structures)
531
+ PluginUnitTest(
532
+ ScatterPlot,
533
+ input_data=mol_df,
534
+ theme="midnight_blue",
535
+ suppress_hover_display=True,
536
+ ).run()
@@ -0,0 +1,185 @@
1
+ """SettingsMenu: A settings menu component for the Workbench Dashboard."""
2
+
3
+ from dash import html, dcc
4
+ import dash_bootstrap_components as dbc
5
+
6
+ # Workbench Imports
7
+ from workbench.utils.theme_manager import ThemeManager
8
+
9
+
10
+ class SettingsMenu:
11
+ """A settings menu with admin links and theme selection."""
12
+
13
+ def __init__(self):
14
+ """Initialize the SettingsMenu."""
15
+ self.tm = ThemeManager()
16
+
17
+ def create_component(self, component_id: str) -> html.Div:
18
+ """Create a settings menu dropdown component.
19
+
20
+ Args:
21
+ component_id (str): The ID prefix for the component.
22
+
23
+ Returns:
24
+ html.Div: A Div containing the settings menu dropdown.
25
+ """
26
+ themes = self.tm.list_themes()
27
+
28
+ # Create theme submenu items
29
+ theme_items = []
30
+ for theme in sorted(themes):
31
+ theme_items.append(
32
+ dbc.DropdownMenuItem(
33
+ [
34
+ html.Span(
35
+ "",
36
+ id={"type": f"{component_id}-checkmark", "theme": theme},
37
+ style={
38
+ "fontFamily": "monospace",
39
+ "marginRight": "5px",
40
+ "width": "20px",
41
+ "display": "inline-block",
42
+ },
43
+ ),
44
+ theme.replace("_", " ").title(),
45
+ ],
46
+ id={"type": f"{component_id}-theme-item", "theme": theme},
47
+ )
48
+ )
49
+
50
+ # Hamburger icon (3 rounded lines)
51
+ hamburger_icon = html.Div(
52
+ [
53
+ html.Div(
54
+ style={
55
+ "width": "20px",
56
+ "height": "3px",
57
+ "backgroundColor": "currentColor",
58
+ "borderRadius": "2px",
59
+ "marginBottom": "4px",
60
+ }
61
+ ),
62
+ html.Div(
63
+ style={
64
+ "width": "20px",
65
+ "height": "3px",
66
+ "backgroundColor": "currentColor",
67
+ "borderRadius": "2px",
68
+ "marginBottom": "4px",
69
+ }
70
+ ),
71
+ html.Div(
72
+ style={"width": "20px", "height": "3px", "backgroundColor": "currentColor", "borderRadius": "2px"}
73
+ ),
74
+ ],
75
+ style={"display": "flex", "flexDirection": "column", "alignItems": "center", "justifyContent": "center"},
76
+ )
77
+
78
+ # Build menu items: Home, Status, License, divider, Themes submenu
79
+ menu_items = [
80
+ dbc.DropdownMenuItem("Home", href="/"),
81
+ dbc.DropdownMenuItem("Status", href="/status", external_link=True, target="_blank"),
82
+ dbc.DropdownMenuItem("License", href="/license", external_link=True, target="_blank"),
83
+ dbc.DropdownMenuItem(divider=True),
84
+ dbc.DropdownMenuItem("Themes", header=True),
85
+ *theme_items,
86
+ ]
87
+
88
+ return html.Div(
89
+ [
90
+ dbc.DropdownMenu(
91
+ label=hamburger_icon,
92
+ children=menu_items,
93
+ id=f"{component_id}-dropdown",
94
+ toggle_style={
95
+ "background": "transparent",
96
+ "border": "none",
97
+ "boxShadow": "none",
98
+ "padding": "5px 10px",
99
+ },
100
+ caret=False,
101
+ align_end=True,
102
+ ),
103
+ # Dummy store for the clientside callback output
104
+ dcc.Store(id=f"{component_id}-dummy", data=None),
105
+ # Store to trigger checkmark update on load
106
+ dcc.Store(id=f"{component_id}-init", data=True),
107
+ ],
108
+ id=component_id,
109
+ )
110
+
111
+ @staticmethod
112
+ def get_clientside_callback_code(component_id: str) -> str:
113
+ """Get the JavaScript code for the theme selection clientside callback.
114
+
115
+ Args:
116
+ component_id (str): The ID prefix used in create_component.
117
+
118
+ Returns:
119
+ str: JavaScript code for the clientside callback.
120
+ """
121
+ return """
122
+ function(n_clicks_list, ids) {
123
+ // Find which button was clicked
124
+ if (!n_clicks_list || n_clicks_list.every(n => !n)) {
125
+ return window.dash_clientside.no_update;
126
+ }
127
+
128
+ // Find the clicked theme
129
+ let clickedTheme = null;
130
+ for (let i = 0; i < n_clicks_list.length; i++) {
131
+ if (n_clicks_list[i]) {
132
+ clickedTheme = ids[i].theme;
133
+ break;
134
+ }
135
+ }
136
+
137
+ if (clickedTheme) {
138
+ // Store in localStorage
139
+ localStorage.setItem('wb_theme', clickedTheme);
140
+ // Set cookie for Flask to read on reload
141
+ document.cookie = `wb_theme=${clickedTheme}; path=/; max-age=31536000`;
142
+ // Reload the page to apply the new theme
143
+ window.location.reload();
144
+ }
145
+
146
+ return window.dash_clientside.no_update;
147
+ }
148
+ """
149
+
150
+ @staticmethod
151
+ def get_checkmark_callback_code() -> str:
152
+ """Get the JavaScript code to update checkmarks based on localStorage.
153
+
154
+ Returns:
155
+ str: JavaScript code for the checkmark update callback.
156
+ """
157
+ return """
158
+ function(init, ids) {
159
+ // Get current theme from localStorage (or cookie as fallback)
160
+ let currentTheme = localStorage.getItem('wb_theme');
161
+ if (!currentTheme) {
162
+ // Try to read from cookie
163
+ const cookies = document.cookie.split(';');
164
+ for (let cookie of cookies) {
165
+ const [name, value] = cookie.trim().split('=');
166
+ if (name === 'wb_theme') {
167
+ currentTheme = value;
168
+ break;
169
+ }
170
+ }
171
+ }
172
+
173
+ // Return checkmarks for each theme
174
+ return ids.map(id => id.theme === currentTheme ? '\u2713' : '');
175
+ }
176
+ """
177
+
178
+
179
+ if __name__ == "__main__":
180
+ # Quick test to verify component creation
181
+ menu = SettingsMenu()
182
+ component = menu.create_component("test-settings-menu")
183
+ print("SettingsMenu component created successfully")
184
+ print(f"Available themes: {menu.tm.list_themes()}")
185
+ print(f"Current theme: {menu.tm.current_theme()}")