workbench 0.8.219__py3-none-any.whl → 0.8.231__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/feature_set.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +238 -225
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +196 -68
- workbench/model_scripts/chemprop/generated_model_script.py +197 -72
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -64
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +82 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/model_utils.py +0 -1
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +52 -36
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +10 -6
- workbench/web_interface/components/plugins/scatter_plot.py +184 -85
- workbench/web_interface/components/settings_menu.py +185 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import base64
|
|
2
|
+
import numpy as np
|
|
2
3
|
import pandas as pd
|
|
3
|
-
from dash import dcc, html, callback, Input, Output, no_update
|
|
4
|
+
from dash import dcc, html, callback, clientside_callback, Input, Output, no_update
|
|
4
5
|
import plotly.graph_objects as go
|
|
5
6
|
import plotly.express as px
|
|
6
7
|
from dash.exceptions import PreventUpdate
|
|
@@ -9,6 +10,8 @@ from dash.exceptions import PreventUpdate
|
|
|
9
10
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
10
11
|
from workbench.utils.theme_manager import ThemeManager
|
|
11
12
|
from workbench.utils.plot_utils import prediction_intervals
|
|
13
|
+
from workbench.utils.chem_utils.vis import molecule_hover_tooltip
|
|
14
|
+
from workbench.utils.clientside_callbacks import circle_overlay_callback
|
|
12
15
|
|
|
13
16
|
|
|
14
17
|
class ScatterPlot(PluginInterface):
|
|
@@ -18,6 +21,12 @@ class ScatterPlot(PluginInterface):
|
|
|
18
21
|
auto_load_page = PluginPage.NONE
|
|
19
22
|
plugin_input_type = PluginInputType.DATAFRAME
|
|
20
23
|
|
|
24
|
+
# Pre-computed circle overlay SVG
|
|
25
|
+
_circle_svg = """<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" style="overflow: visible;">
|
|
26
|
+
<circle cx="50" cy="50" r="10" stroke="rgba(255, 255, 255, 1)" stroke-width="3" fill="none" />
|
|
27
|
+
</svg>"""
|
|
28
|
+
_circle_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(_circle_svg.encode('utf-8')).decode('utf-8')}"
|
|
29
|
+
|
|
21
30
|
def __init__(self, show_axes: bool = True):
|
|
22
31
|
"""Initialize the Scatter Plot Plugin
|
|
23
32
|
|
|
@@ -29,7 +38,10 @@ class ScatterPlot(PluginInterface):
|
|
|
29
38
|
self.df = None
|
|
30
39
|
self.show_axes = show_axes
|
|
31
40
|
self.theme_manager = ThemeManager()
|
|
32
|
-
self.
|
|
41
|
+
self.has_smiles = False # Track if dataframe has smiles column for molecule hover
|
|
42
|
+
self.smiles_column = None
|
|
43
|
+
self.id_column = None
|
|
44
|
+
self.hover_background = None # Cached background color for molecule hover tooltip
|
|
33
45
|
|
|
34
46
|
# Call the parent class constructor
|
|
35
47
|
super().__init__()
|
|
@@ -51,10 +63,10 @@ class ScatterPlot(PluginInterface):
|
|
|
51
63
|
(f"{component_id}-x-dropdown", "options"),
|
|
52
64
|
(f"{component_id}-y-dropdown", "options"),
|
|
53
65
|
(f"{component_id}-color-dropdown", "options"),
|
|
54
|
-
(f"{component_id}-label-dropdown", "options"),
|
|
55
66
|
(f"{component_id}-x-dropdown", "value"),
|
|
56
67
|
(f"{component_id}-y-dropdown", "value"),
|
|
57
68
|
(f"{component_id}-color-dropdown", "value"),
|
|
69
|
+
(f"{component_id}-regression-line", "value"),
|
|
58
70
|
]
|
|
59
71
|
self.signals = [(f"{component_id}-graph", "hoverData"), (f"{component_id}-graph", "clickData")]
|
|
60
72
|
|
|
@@ -69,51 +81,67 @@ class ScatterPlot(PluginInterface):
|
|
|
69
81
|
id=f"{component_id}-graph",
|
|
70
82
|
figure=self.display_text("Waiting for Data..."),
|
|
71
83
|
config={"scrollZoom": True},
|
|
72
|
-
style={"height": "100%"},
|
|
84
|
+
style={"height": "500px", "width": "100%"},
|
|
73
85
|
clear_on_unhover=True,
|
|
74
86
|
),
|
|
75
87
|
# Controls: X, Y, and Color Dropdowns, plus the Diagonal (regression line) Checkbox
|
|
76
88
|
html.Div(
|
|
77
89
|
[
|
|
78
|
-
html.Label(
|
|
90
|
+
html.Label(
|
|
91
|
+
"X",
|
|
92
|
+
style={
|
|
93
|
+
"marginLeft": "20px",
|
|
94
|
+
"marginRight": "5px",
|
|
95
|
+
"fontWeight": "bold",
|
|
96
|
+
"display": "flex",
|
|
97
|
+
"alignItems": "center",
|
|
98
|
+
},
|
|
99
|
+
),
|
|
79
100
|
dcc.Dropdown(
|
|
80
101
|
id=f"{component_id}-x-dropdown",
|
|
81
|
-
|
|
82
|
-
style={"min-width": "50px", "flex": 1}, # Responsive width
|
|
102
|
+
style={"minWidth": "150px", "flex": 1},
|
|
83
103
|
clearable=False,
|
|
84
104
|
),
|
|
85
|
-
html.Label(
|
|
105
|
+
html.Label(
|
|
106
|
+
"Y",
|
|
107
|
+
style={
|
|
108
|
+
"marginLeft": "20px",
|
|
109
|
+
"marginRight": "5px",
|
|
110
|
+
"fontWeight": "bold",
|
|
111
|
+
"display": "flex",
|
|
112
|
+
"alignItems": "center",
|
|
113
|
+
},
|
|
114
|
+
),
|
|
86
115
|
dcc.Dropdown(
|
|
87
116
|
id=f"{component_id}-y-dropdown",
|
|
88
|
-
|
|
89
|
-
style={"min-width": "50px", "flex": 1}, # Responsive width
|
|
117
|
+
style={"minWidth": "150px", "flex": 1},
|
|
90
118
|
clearable=False,
|
|
91
119
|
),
|
|
92
|
-
html.Label(
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
120
|
+
html.Label(
|
|
121
|
+
"Color",
|
|
122
|
+
style={
|
|
123
|
+
"marginLeft": "20px",
|
|
124
|
+
"marginRight": "5px",
|
|
125
|
+
"fontWeight": "bold",
|
|
126
|
+
"display": "flex",
|
|
127
|
+
"alignItems": "center",
|
|
128
|
+
},
|
|
98
129
|
),
|
|
99
|
-
html.Label("Label", style={"marginLeft": "30px", "marginRight": "5px", "fontWeight": "bold"}),
|
|
100
130
|
dcc.Dropdown(
|
|
101
|
-
id=f"{component_id}-
|
|
102
|
-
|
|
103
|
-
style={"min-width": "50px", "flex": 1},
|
|
104
|
-
options=[{"label": "None", "value": "none"}],
|
|
105
|
-
value="none",
|
|
131
|
+
id=f"{component_id}-color-dropdown",
|
|
132
|
+
style={"minWidth": "150px", "flex": 1},
|
|
106
133
|
clearable=False,
|
|
107
134
|
),
|
|
108
135
|
dcc.Checklist(
|
|
109
136
|
id=f"{component_id}-regression-line",
|
|
110
137
|
options=[{"label": " Diagonal", "value": "show"}],
|
|
111
138
|
value=[],
|
|
112
|
-
style={"
|
|
139
|
+
style={"marginLeft": "20px", "display": "flex", "alignItems": "center"},
|
|
113
140
|
),
|
|
114
141
|
],
|
|
115
|
-
style={"padding": "0px 0px 10px 0px", "display": "flex", "gap": "
|
|
142
|
+
style={"padding": "0px 0px 10px 0px", "display": "flex", "alignItems": "center", "gap": "5px"},
|
|
116
143
|
),
|
|
144
|
+
# Circle overlay tooltip (centered on hovered point)
|
|
117
145
|
dcc.Tooltip(
|
|
118
146
|
id=f"{component_id}-overlay",
|
|
119
147
|
background_color="rgba(0,0,0,0)",
|
|
@@ -121,6 +149,14 @@ class ScatterPlot(PluginInterface):
|
|
|
121
149
|
direction="bottom",
|
|
122
150
|
loading_text="",
|
|
123
151
|
),
|
|
152
|
+
# Molecule tooltip (offset from hovered point) - only used when smiles column exists
|
|
153
|
+
dcc.Tooltip(
|
|
154
|
+
id=f"{component_id}-molecule-tooltip",
|
|
155
|
+
background_color="rgba(0,0,0,0)",
|
|
156
|
+
border_color="rgba(0,0,0,0)",
|
|
157
|
+
direction="bottom",
|
|
158
|
+
loading_text="",
|
|
159
|
+
),
|
|
124
160
|
],
|
|
125
161
|
style={"height": "100%", "display": "flex", "flexDirection": "column"}, # Fill the full height of the parent container
|
|
126
162
|
)
|
|
@@ -139,12 +175,16 @@ class ScatterPlot(PluginInterface):
|
|
|
139
175
|
- hover_columns: The columns to show when hovering over a point
|
|
140
176
|
- suppress_hover_display: Suppress hover display (default: False)
|
|
141
177
|
- custom_data: Custom data that get passed to hoverData callbacks
|
|
178
|
+
- id_column: Column to use for molecule tooltip header (auto-detects "id" if not specified)
|
|
142
179
|
|
|
143
180
|
Returns:
|
|
144
181
|
list: A list of updated property values (figure, x options, y options, color options,
|
|
145
|
-
|
|
146
|
-
color default).
|
|
182
|
+
x default, y default, color default, regression line checklist value).
|
|
147
183
|
"""
|
|
184
|
+
# Get the colorscale and background color from the current theme
|
|
185
|
+
self.colorscale = self.theme_manager.colorscale()
|
|
186
|
+
self.hover_background = self.theme_manager.background()
|
|
187
|
+
|
|
148
188
|
# Get the limit for the number of rows to plot
|
|
149
189
|
limit = kwargs.get("limit", 20000)
|
|
150
190
|
|
|
@@ -163,19 +203,28 @@ class ScatterPlot(PluginInterface):
|
|
|
163
203
|
self.suppress_hover_display = kwargs.get("suppress_hover_display", False)
|
|
164
204
|
self.custom_data = kwargs.get("custom_data", [])
|
|
165
205
|
|
|
206
|
+
# Check if the dataframe has smiles/id columns for molecule hover rendering
|
|
207
|
+
self.smiles_column = next((col for col in self.df.columns if col.lower() == "smiles"), None)
|
|
208
|
+
# Use provided id_column, or auto-detect "id" column, or fall back to first column
|
|
209
|
+
self.id_column = kwargs.get("id_column") or next(
|
|
210
|
+
(col for col in self.df.columns if col.lower() == "id"), self.df.columns[0]
|
|
211
|
+
)
|
|
212
|
+
self.has_smiles = self.smiles_column is not None
|
|
213
|
+
|
|
166
214
|
# Identify numeric columns
|
|
167
215
|
numeric_columns = self.df.select_dtypes(include="number").columns.tolist()
|
|
168
216
|
if len(numeric_columns) < 3:
|
|
169
217
|
raise ValueError("At least three numeric columns are required for x, y, and color.")
|
|
170
218
|
|
|
171
|
-
# Default x, y, and color (for color,
|
|
219
|
+
# Default x, y, and color (for color, prefer 'confidence' if it exists)
|
|
172
220
|
x_default = kwargs.get("x", numeric_columns[0])
|
|
173
221
|
y_default = kwargs.get("y", numeric_columns[1])
|
|
174
|
-
|
|
222
|
+
default_color = "confidence" if "confidence" in self.df.columns else numeric_columns[2]
|
|
223
|
+
color_default = kwargs.get("color", default_color)
|
|
175
224
|
regression_line = kwargs.get("regression_line", False)
|
|
176
225
|
|
|
177
226
|
# Create the default scatter plot
|
|
178
|
-
figure = self.create_scatter_plot(self.df, x_default, y_default, color_default,
|
|
227
|
+
figure = self.create_scatter_plot(self.df, x_default, y_default, color_default, regression_line)
|
|
179
228
|
|
|
180
229
|
# Dropdown options for x and y: use provided dropdown_columns or fallback to numeric columns
|
|
181
230
|
dropdown_columns = kwargs.get("dropdown_columns", numeric_columns)
|
|
@@ -188,11 +237,10 @@ class ScatterPlot(PluginInterface):
|
|
|
188
237
|
color_columns = numeric_columns + cat_columns
|
|
189
238
|
color_options = [{"label": col, "value": col} for col in color_columns]
|
|
190
239
|
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
label_options.extend([{"label": col, "value": col} for col in self.df.columns])
|
|
240
|
+
# Regression line checklist value (list with "show" if enabled, empty list if disabled)
|
|
241
|
+
regression_line_value = ["show"] if regression_line else []
|
|
194
242
|
|
|
195
|
-
return [figure, x_options, y_options, color_options,
|
|
243
|
+
return [figure, x_options, y_options, color_options, x_default, y_default, color_default, regression_line_value]
|
|
196
244
|
|
|
197
245
|
def create_scatter_plot(
|
|
198
246
|
self,
|
|
@@ -200,9 +248,7 @@ class ScatterPlot(PluginInterface):
|
|
|
200
248
|
x_col: str,
|
|
201
249
|
y_col: str,
|
|
202
250
|
color_col: str,
|
|
203
|
-
label_col: str,
|
|
204
251
|
regression_line: bool = False,
|
|
205
|
-
marker_size: int = 15,
|
|
206
252
|
) -> go.Figure:
|
|
207
253
|
"""Create a Plotly Scatter Plot figure.
|
|
208
254
|
|
|
@@ -211,24 +257,46 @@ class ScatterPlot(PluginInterface):
|
|
|
211
257
|
x_col (str): The column to use for the x-axis.
|
|
212
258
|
y_col (str): The column to use for the y-axis.
|
|
213
259
|
color_col (str): The column to use for the color scale.
|
|
214
|
-
label_col (str): The column to use for point labels.
|
|
215
260
|
regression_line (bool): Whether to include a regression line.
|
|
216
|
-
marker_size (int): Size of the markers. Default is 15.
|
|
217
261
|
|
|
218
262
|
Returns:
|
|
219
263
|
go.Figure: A Plotly Figure object.
|
|
220
264
|
"""
|
|
221
|
-
|
|
222
|
-
|
|
265
|
+
|
|
266
|
+
# If aggregation_count is present, sort so largest counts are drawn first (underneath)
|
|
267
|
+
# and compute marker sizes using square root (between log and linear)
|
|
268
|
+
if "aggregation_count" in df.columns:
|
|
269
|
+
df = df.sort_values("aggregation_count", ascending=False).reset_index(drop=True)
|
|
270
|
+
# Scale: base_size (15) + (sqrt(count) - 1) * factor, so count=1 stays at base_size
|
|
271
|
+
marker_sizes = 15 + (np.sqrt(df["aggregation_count"]) - 1) * 3
|
|
272
|
+
else:
|
|
273
|
+
marker_sizes = 15
|
|
223
274
|
|
|
224
275
|
# Helper to generate hover text for each point.
|
|
225
276
|
def generate_hover_text(row):
|
|
226
277
|
return "<br>".join([f"{col}: {row[col]}" for col in self.hover_columns])
|
|
227
278
|
|
|
228
|
-
# Generate hover text for all points
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
279
|
+
# Generate hover text for all points (unless suppressed or using molecule hover)
|
|
280
|
+
suppress_hover = self.suppress_hover_display or self.has_smiles
|
|
281
|
+
if suppress_hover:
|
|
282
|
+
# Use "none" to hide the default hover display but still fire hoverData callbacks
|
|
283
|
+
# Don't set hovertemplate when suppressing - it would override hoverinfo
|
|
284
|
+
hovertext = None
|
|
285
|
+
hovertemplate = None
|
|
286
|
+
hoverinfo = "none"
|
|
287
|
+
else:
|
|
288
|
+
hovertext = df.apply(generate_hover_text, axis=1)
|
|
289
|
+
hovertemplate = "%{hovertext}<extra></extra>"
|
|
290
|
+
hoverinfo = None
|
|
291
|
+
|
|
292
|
+
# Build customdata columns - include smiles and id if available for molecule hover
|
|
293
|
+
custom_data_cols = list(self.custom_data) if self.custom_data else []
|
|
294
|
+
if self.has_smiles:
|
|
295
|
+
# Add smiles as first column, id as second (if available)
|
|
296
|
+
if self.smiles_column not in custom_data_cols:
|
|
297
|
+
custom_data_cols = [self.smiles_column] + custom_data_cols
|
|
298
|
+
if self.id_column and self.id_column not in custom_data_cols:
|
|
299
|
+
custom_data_cols.insert(1, self.id_column)
|
|
232
300
|
|
|
233
301
|
# Determine marker settings based on the type of the color column.
|
|
234
302
|
if pd.api.types.is_numeric_dtype(df[color_col]):
|
|
@@ -240,18 +308,16 @@ class ScatterPlot(PluginInterface):
|
|
|
240
308
|
x=df[x_col],
|
|
241
309
|
y=df[y_col],
|
|
242
310
|
mode="markers",
|
|
243
|
-
text=df[label_col].astype(str) if show_labels else None,
|
|
244
|
-
textposition="top center",
|
|
245
311
|
hoverinfo=hoverinfo,
|
|
246
312
|
hovertext=hovertext,
|
|
247
313
|
hovertemplate=hovertemplate,
|
|
248
|
-
customdata=df[
|
|
314
|
+
customdata=df[custom_data_cols] if custom_data_cols else None,
|
|
249
315
|
marker=dict(
|
|
250
|
-
size=
|
|
316
|
+
size=marker_sizes,
|
|
251
317
|
color=marker_color,
|
|
252
318
|
colorscale=self.colorscale,
|
|
253
319
|
colorbar=colorbar,
|
|
254
|
-
opacity=0.
|
|
320
|
+
opacity=0.9,
|
|
255
321
|
line=dict(color="rgba(0,0,0,0.25)", width=1),
|
|
256
322
|
),
|
|
257
323
|
)
|
|
@@ -266,20 +332,27 @@ class ScatterPlot(PluginInterface):
|
|
|
266
332
|
data = []
|
|
267
333
|
for i, cat in enumerate(categories):
|
|
268
334
|
sub_df = df[df[color_col] == cat]
|
|
269
|
-
sub_hovertext = hovertext.loc[sub_df.index]
|
|
335
|
+
sub_hovertext = hovertext.loc[sub_df.index] if hovertext is not None else None
|
|
336
|
+
# Get marker sizes for this subset (handles both array and scalar)
|
|
337
|
+
if isinstance(marker_sizes, (pd.Series, np.ndarray)):
|
|
338
|
+
sub_marker_sizes = (
|
|
339
|
+
marker_sizes.loc[sub_df.index]
|
|
340
|
+
if isinstance(marker_sizes, pd.Series)
|
|
341
|
+
else marker_sizes[sub_df.index]
|
|
342
|
+
)
|
|
343
|
+
else:
|
|
344
|
+
sub_marker_sizes = marker_sizes
|
|
270
345
|
trace = go.Scattergl(
|
|
271
346
|
x=sub_df[x_col],
|
|
272
347
|
y=sub_df[y_col],
|
|
273
348
|
mode="markers",
|
|
274
|
-
text=sub_df[label_col] if show_labels else None, # Add text if labels enabled
|
|
275
|
-
textposition="top center", # Position labels above points
|
|
276
349
|
name=cat,
|
|
277
350
|
hoverinfo=hoverinfo,
|
|
278
351
|
hovertext=sub_hovertext,
|
|
279
352
|
hovertemplate=hovertemplate,
|
|
280
|
-
customdata=sub_df[
|
|
353
|
+
customdata=sub_df[custom_data_cols] if custom_data_cols else None,
|
|
281
354
|
marker=dict(
|
|
282
|
-
size=
|
|
355
|
+
size=sub_marker_sizes,
|
|
283
356
|
color=discrete_colors[i % len(discrete_colors)],
|
|
284
357
|
opacity=0.8,
|
|
285
358
|
line=dict(color="rgba(0,0,0,0.25)", width=1),
|
|
@@ -345,64 +418,73 @@ class ScatterPlot(PluginInterface):
|
|
|
345
418
|
Input(f"{self.component_id}-x-dropdown", "value"),
|
|
346
419
|
Input(f"{self.component_id}-y-dropdown", "value"),
|
|
347
420
|
Input(f"{self.component_id}-color-dropdown", "value"),
|
|
348
|
-
Input(f"{self.component_id}-label-dropdown", "value"),
|
|
349
421
|
Input(f"{self.component_id}-regression-line", "value"),
|
|
350
422
|
],
|
|
351
423
|
prevent_initial_call=True,
|
|
352
424
|
)
|
|
353
|
-
def _update_scatter_plot(x_value, y_value, color_value,
|
|
425
|
+
def _update_scatter_plot(x_value, y_value, color_value, regression_line):
|
|
354
426
|
"""Update the Scatter Plot Graph based on the dropdown values."""
|
|
355
427
|
|
|
356
428
|
# Check if the dataframe is not empty and the values are not None
|
|
357
429
|
if not self.df.empty and x_value and y_value and color_value:
|
|
358
|
-
|
|
359
|
-
figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, label_value, regression_line)
|
|
430
|
+
figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, regression_line)
|
|
360
431
|
return figure
|
|
361
432
|
|
|
362
433
|
raise PreventUpdate
|
|
363
434
|
|
|
364
|
-
|
|
435
|
+
# Clientside callback for circle overlay - runs in browser, no server round trip
|
|
436
|
+
clientside_callback(
|
|
437
|
+
circle_overlay_callback(self._circle_data_uri),
|
|
365
438
|
Output(f"{self.component_id}-overlay", "show"),
|
|
366
439
|
Output(f"{self.component_id}-overlay", "bbox"),
|
|
367
440
|
Output(f"{self.component_id}-overlay", "children"),
|
|
368
441
|
Input(f"{self.component_id}-graph", "hoverData"),
|
|
369
442
|
)
|
|
370
|
-
def _scatter_overlay(hover_data):
|
|
371
|
-
if hover_data is None:
|
|
372
|
-
# Hide the overlay if no hover data
|
|
373
|
-
return False, no_update, no_update
|
|
374
|
-
|
|
375
|
-
# Extract bounding box from hoverData
|
|
376
|
-
bbox = hover_data["points"][0]["bbox"]
|
|
377
443
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
444
|
+
@callback(
|
|
445
|
+
Output(f"{self.component_id}-molecule-tooltip", "show"),
|
|
446
|
+
Output(f"{self.component_id}-molecule-tooltip", "bbox"),
|
|
447
|
+
Output(f"{self.component_id}-molecule-tooltip", "children"),
|
|
448
|
+
Input(f"{self.component_id}-graph", "hoverData"),
|
|
449
|
+
)
|
|
450
|
+
def _scatter_molecule_overlay(hover_data):
|
|
451
|
+
"""Show molecule tooltip when smiles data is available."""
|
|
452
|
+
if hover_data is None or not self.has_smiles:
|
|
453
|
+
return False, no_update, no_update
|
|
385
454
|
|
|
386
|
-
#
|
|
387
|
-
|
|
388
|
-
|
|
455
|
+
# Extract customdata (contains smiles and id)
|
|
456
|
+
customdata = hover_data["points"][0].get("customdata")
|
|
457
|
+
if customdata is None:
|
|
458
|
+
return False, no_update, no_update
|
|
389
459
|
|
|
390
|
-
#
|
|
391
|
-
|
|
460
|
+
# SMILES is the first element, ID is second (if available)
|
|
461
|
+
if isinstance(customdata, (list, tuple)):
|
|
462
|
+
smiles = customdata[0]
|
|
463
|
+
mol_id = customdata[1] if len(customdata) > 1 and self.id_column else None
|
|
464
|
+
else:
|
|
465
|
+
smiles = customdata
|
|
466
|
+
mol_id = None
|
|
467
|
+
|
|
468
|
+
# Generate molecule tooltip with ID header (use cached background color)
|
|
469
|
+
mol_width, mol_height = 300, 200
|
|
470
|
+
children = molecule_hover_tooltip(
|
|
471
|
+
smiles, mol_id=mol_id, width=mol_width, height=mol_height, background=self.hover_background
|
|
472
|
+
)
|
|
392
473
|
|
|
393
|
-
#
|
|
474
|
+
# Position molecule tooltip above and slightly right of the point
|
|
475
|
+
bbox = hover_data["points"][0]["bbox"]
|
|
394
476
|
center_x = (bbox["x0"] + bbox["x1"]) / 2
|
|
395
477
|
center_y = (bbox["y0"] + bbox["y1"]) / 2
|
|
478
|
+
x_offset = 5 # Slight offset to the right
|
|
479
|
+
y_offset = mol_height + 50 # Above the point
|
|
396
480
|
|
|
397
|
-
# The tooltip should be centered on the point (note: 'bottom' tooltip, so we adjust y position)
|
|
398
481
|
adjusted_bbox = {
|
|
399
|
-
"x0": center_x
|
|
400
|
-
"x1": center_x +
|
|
401
|
-
"y0": center_y -
|
|
402
|
-
"y1": center_y -
|
|
482
|
+
"x0": center_x + x_offset,
|
|
483
|
+
"x1": center_x + x_offset + mol_width,
|
|
484
|
+
"y0": center_y - mol_height - y_offset,
|
|
485
|
+
"y1": center_y - y_offset,
|
|
403
486
|
}
|
|
404
|
-
|
|
405
|
-
return True, adjusted_bbox, [svg_image]
|
|
487
|
+
return True, adjusted_bbox, children
|
|
406
488
|
|
|
407
489
|
|
|
408
490
|
if __name__ == "__main__":
|
|
@@ -426,6 +508,8 @@ if __name__ == "__main__":
|
|
|
426
508
|
df = model.get_inference_predictions("full_cross_fold")
|
|
427
509
|
|
|
428
510
|
# Run the Unit Test on the Plugin
|
|
511
|
+
# Test currently commented out
|
|
512
|
+
"""
|
|
429
513
|
PluginUnitTest(
|
|
430
514
|
ScatterPlot,
|
|
431
515
|
input_data=df,
|
|
@@ -435,3 +519,18 @@ if __name__ == "__main__":
|
|
|
435
519
|
color="prediction_std",
|
|
436
520
|
suppress_hover_display=True,
|
|
437
521
|
).run()
|
|
522
|
+
"""
|
|
523
|
+
|
|
524
|
+
# Test with molecule hover (smiles column)
|
|
525
|
+
from workbench.api import FeatureSet
|
|
526
|
+
|
|
527
|
+
fs = FeatureSet("aqsol_features")
|
|
528
|
+
mol_df = fs.pull_dataframe()[:1000] # Limit to 1000 rows for testing
|
|
529
|
+
|
|
530
|
+
# Run the Unit Test with molecule data (hover over points to see molecule structures)
|
|
531
|
+
PluginUnitTest(
|
|
532
|
+
ScatterPlot,
|
|
533
|
+
input_data=mol_df,
|
|
534
|
+
theme="midnight_blue",
|
|
535
|
+
suppress_hover_display=True,
|
|
536
|
+
).run()
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""SettingsMenu: A settings menu component for the Workbench Dashboard."""
|
|
2
|
+
|
|
3
|
+
from dash import html, dcc
|
|
4
|
+
import dash_bootstrap_components as dbc
|
|
5
|
+
|
|
6
|
+
# Workbench Imports
|
|
7
|
+
from workbench.utils.theme_manager import ThemeManager
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SettingsMenu:
    """A settings menu with admin links and theme selection.

    The menu is a hamburger-style dropdown with Home/Status/License links
    followed by one entry per available theme. The two static methods return
    JavaScript source intended for Dash clientside callbacks; this class only
    supplies the code, it does not register the callbacks itself.
    """

    def __init__(self):
        """Initialize the SettingsMenu."""
        self.tm = ThemeManager()

    def create_component(self, component_id: str) -> "html.Div":
        """Create a settings menu dropdown component.

        Args:
            component_id (str): The ID prefix for the component.

        Returns:
            html.Div: A Div containing the settings menu dropdown.
        """
        themes = self.tm.list_themes()

        # One menu entry per theme (alphabetical). Each entry carries an initially
        # empty, fixed-width span that can hold a checkmark for the active theme.
        # Pattern-matching (dict) IDs let a single callback address all entries.
        theme_items = [
            dbc.DropdownMenuItem(
                [
                    html.Span(
                        "",
                        id={"type": f"{component_id}-checkmark", "theme": theme},
                        style={
                            "fontFamily": "monospace",
                            "marginRight": "5px",
                            "width": "20px",
                            "display": "inline-block",
                        },
                    ),
                    theme.replace("_", " ").title(),
                ],
                id={"type": f"{component_id}-theme-item", "theme": theme},
            )
            for theme in sorted(themes)
        ]

        # Hamburger icon (3 rounded lines); the first two bars get a bottom margin
        bar_style = {
            "width": "20px",
            "height": "3px",
            "backgroundColor": "currentColor",
            "borderRadius": "2px",
        }
        spaced_bar_style = {**bar_style, "marginBottom": "4px"}
        hamburger_icon = html.Div(
            [
                html.Div(style=dict(spaced_bar_style)),
                html.Div(style=dict(spaced_bar_style)),
                html.Div(style=dict(bar_style)),
            ],
            style={"display": "flex", "flexDirection": "column", "alignItems": "center", "justifyContent": "center"},
        )

        # Build menu items: Home, Status, License, divider, "Themes" header, theme entries
        # (Status and License are external links that open in a new tab)
        menu_items = [
            dbc.DropdownMenuItem("Home", href="/"),
            dbc.DropdownMenuItem("Status", href="/status", external_link=True, target="_blank"),
            dbc.DropdownMenuItem("License", href="/license", external_link=True, target="_blank"),
            dbc.DropdownMenuItem(divider=True),
            dbc.DropdownMenuItem("Themes", header=True),
            *theme_items,
        ]

        return html.Div(
            [
                dbc.DropdownMenu(
                    label=hamburger_icon,
                    children=menu_items,
                    id=f"{component_id}-dropdown",
                    toggle_style={
                        "background": "transparent",
                        "border": "none",
                        "boxShadow": "none",
                        "padding": "5px 10px",
                    },
                    caret=False,
                    align_end=True,
                ),
                # Dummy store for the clientside callback output
                dcc.Store(id=f"{component_id}-dummy", data=None),
                # Store to trigger checkmark update on load
                dcc.Store(id=f"{component_id}-init", data=True),
            ],
            id=component_id,
        )

    @staticmethod
    def get_clientside_callback_code(component_id: str) -> str:
        """Get the JavaScript code for the theme selection clientside callback.

        The returned function takes the list of click counts and the matching
        list of pattern IDs, picks the first entry with a non-zero click count,
        stores its theme in ``localStorage`` and in a ``wb_theme`` cookie
        (one-year max-age), and reloads the page. It always returns ``no_update``.

        Args:
            component_id (str): The ID prefix used in create_component.
                Not used by the returned code; kept as part of the interface.

        Returns:
            str: JavaScript code for the clientside callback.
        """
        return """
        function(n_clicks_list, ids) {
            // Find which button was clicked
            if (!n_clicks_list || n_clicks_list.every(n => !n)) {
                return window.dash_clientside.no_update;
            }

            // Find the clicked theme
            let clickedTheme = null;
            for (let i = 0; i < n_clicks_list.length; i++) {
                if (n_clicks_list[i]) {
                    clickedTheme = ids[i].theme;
                    break;
                }
            }

            if (clickedTheme) {
                // Store in localStorage
                localStorage.setItem('wb_theme', clickedTheme);
                // Set cookie for Flask to read on reload
                document.cookie = `wb_theme=${clickedTheme}; path=/; max-age=31536000`;
                // Reload the page to apply the new theme
                window.location.reload();
            }

            return window.dash_clientside.no_update;
        }
        """

    @staticmethod
    def get_checkmark_callback_code() -> str:
        """Get the JavaScript code to update checkmarks based on localStorage.

        The returned function reads the current theme from ``localStorage``
        (falling back to the ``wb_theme`` cookie) and returns one string per
        pattern ID: a checkmark for the current theme, an empty string otherwise.

        Returns:
            str: JavaScript code for the checkmark update callback.
        """
        # NOTE: this is a regular (non-raw) string, so Python expands the \u2713
        # escape and the emitted JavaScript contains the literal checkmark character.
        return """
        function(init, ids) {
            // Get current theme from localStorage (or cookie as fallback)
            let currentTheme = localStorage.getItem('wb_theme');
            if (!currentTheme) {
                // Try to read from cookie
                const cookies = document.cookie.split(';');
                for (let cookie of cookies) {
                    const [name, value] = cookie.trim().split('=');
                    if (name === 'wb_theme') {
                        currentTheme = value;
                        break;
                    }
                }
            }

            // Return checkmarks for each theme
            return ids.map(id => id.theme === currentTheme ? '\u2713' : '');
        }
        """
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
if __name__ == "__main__":
    # Smoke test: build the component once and report the theme state
    settings_menu = SettingsMenu()
    settings_menu.create_component("test-settings-menu")
    print("SettingsMenu component created successfully")
    print(f"Available themes: {settings_menu.tm.list_themes()}")
    print(f"Current theme: {settings_menu.tm.current_theme()}")
|