workbench 0.8.217__py3-none-any.whl → 0.8.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/feature_set.py +0 -1
- workbench/core/artifacts/endpoint_core.py +2 -2
- workbench/core/artifacts/feature_set_core.py +185 -230
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +15 -11
- workbench/model_scripts/chemprop/chemprop.template +195 -70
- workbench/model_scripts/chemprop/generated_model_script.py +198 -73
- workbench/model_scripts/chemprop/model_script_utils.py +15 -11
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +80 -43
- workbench/model_scripts/pytorch_model/generated_model_script.py +2 -2
- workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
- workbench/model_scripts/xgb_model/generated_model_script.py +7 -7
- workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/light/custom.css +7 -1
- workbench/themes/midnight_blue/custom.css +34 -0
- workbench/utils/chem_utils/fingerprints.py +80 -43
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/meta_model_simulator.py +41 -13
- workbench/utils/model_utils.py +0 -1
- workbench/utils/plot_utils.py +146 -28
- workbench/utils/shap_utils.py +1 -55
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/plugins/scatter_plot.py +152 -66
- workbench/web_interface/components/settings_menu.py +184 -0
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/RECORD +38 -37
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +1 -0
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
|
@@ -8,7 +8,7 @@ from dash.exceptions import PreventUpdate
|
|
|
8
8
|
# Workbench Imports
|
|
9
9
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
10
10
|
from workbench.utils.theme_manager import ThemeManager
|
|
11
|
-
from workbench.utils.plot_utils import prediction_intervals
|
|
11
|
+
from workbench.utils.plot_utils import prediction_intervals, molecule_hover_tooltip
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class ScatterPlot(PluginInterface):
|
|
@@ -18,6 +18,12 @@ class ScatterPlot(PluginInterface):
|
|
|
18
18
|
auto_load_page = PluginPage.NONE
|
|
19
19
|
plugin_input_type = PluginInputType.DATAFRAME
|
|
20
20
|
|
|
21
|
+
# Pre-computed circle overlay SVG
|
|
22
|
+
_circle_svg = """<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" style="overflow: visible;">
|
|
23
|
+
<circle cx="50" cy="50" r="10" stroke="rgba(255, 255, 255, 1)" stroke-width="3" fill="none" />
|
|
24
|
+
</svg>"""
|
|
25
|
+
_circle_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(_circle_svg.encode('utf-8')).decode('utf-8')}"
|
|
26
|
+
|
|
21
27
|
def __init__(self, show_axes: bool = True):
|
|
22
28
|
"""Initialize the Scatter Plot Plugin
|
|
23
29
|
|
|
@@ -30,6 +36,9 @@ class ScatterPlot(PluginInterface):
|
|
|
30
36
|
self.show_axes = show_axes
|
|
31
37
|
self.theme_manager = ThemeManager()
|
|
32
38
|
self.colorscale = self.theme_manager.colorscale()
|
|
39
|
+
self.has_smiles = False # Track if dataframe has smiles column for molecule hover
|
|
40
|
+
self.smiles_column = None
|
|
41
|
+
self.id_column = None
|
|
33
42
|
|
|
34
43
|
# Call the parent class constructor
|
|
35
44
|
super().__init__()
|
|
@@ -51,10 +60,10 @@ class ScatterPlot(PluginInterface):
|
|
|
51
60
|
(f"{component_id}-x-dropdown", "options"),
|
|
52
61
|
(f"{component_id}-y-dropdown", "options"),
|
|
53
62
|
(f"{component_id}-color-dropdown", "options"),
|
|
54
|
-
(f"{component_id}-label-dropdown", "options"),
|
|
55
63
|
(f"{component_id}-x-dropdown", "value"),
|
|
56
64
|
(f"{component_id}-y-dropdown", "value"),
|
|
57
65
|
(f"{component_id}-color-dropdown", "value"),
|
|
66
|
+
(f"{component_id}-regression-line", "value"),
|
|
58
67
|
]
|
|
59
68
|
self.signals = [(f"{component_id}-graph", "hoverData"), (f"{component_id}-graph", "clickData")]
|
|
60
69
|
|
|
@@ -75,45 +84,61 @@ class ScatterPlot(PluginInterface):
|
|
|
75
84
|
# Controls: X, Y, Color, Label Dropdowns, and Regression Line Checkbox
|
|
76
85
|
html.Div(
|
|
77
86
|
[
|
|
78
|
-
html.Label(
|
|
87
|
+
html.Label(
|
|
88
|
+
"X",
|
|
89
|
+
style={
|
|
90
|
+
"marginLeft": "20px",
|
|
91
|
+
"marginRight": "5px",
|
|
92
|
+
"fontWeight": "bold",
|
|
93
|
+
"display": "flex",
|
|
94
|
+
"alignItems": "center",
|
|
95
|
+
},
|
|
96
|
+
),
|
|
79
97
|
dcc.Dropdown(
|
|
80
98
|
id=f"{component_id}-x-dropdown",
|
|
81
|
-
|
|
82
|
-
style={"min-width": "50px", "flex": 1}, # Responsive width
|
|
99
|
+
style={"minWidth": "150px", "flex": 1},
|
|
83
100
|
clearable=False,
|
|
84
101
|
),
|
|
85
|
-
html.Label(
|
|
102
|
+
html.Label(
|
|
103
|
+
"Y",
|
|
104
|
+
style={
|
|
105
|
+
"marginLeft": "20px",
|
|
106
|
+
"marginRight": "5px",
|
|
107
|
+
"fontWeight": "bold",
|
|
108
|
+
"display": "flex",
|
|
109
|
+
"alignItems": "center",
|
|
110
|
+
},
|
|
111
|
+
),
|
|
86
112
|
dcc.Dropdown(
|
|
87
113
|
id=f"{component_id}-y-dropdown",
|
|
88
|
-
|
|
89
|
-
style={"min-width": "50px", "flex": 1}, # Responsive width
|
|
114
|
+
style={"minWidth": "150px", "flex": 1},
|
|
90
115
|
clearable=False,
|
|
91
116
|
),
|
|
92
|
-
html.Label(
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
117
|
+
html.Label(
|
|
118
|
+
"Color",
|
|
119
|
+
style={
|
|
120
|
+
"marginLeft": "20px",
|
|
121
|
+
"marginRight": "5px",
|
|
122
|
+
"fontWeight": "bold",
|
|
123
|
+
"display": "flex",
|
|
124
|
+
"alignItems": "center",
|
|
125
|
+
},
|
|
98
126
|
),
|
|
99
|
-
html.Label("Label", style={"marginLeft": "30px", "marginRight": "5px", "fontWeight": "bold"}),
|
|
100
127
|
dcc.Dropdown(
|
|
101
|
-
id=f"{component_id}-
|
|
102
|
-
|
|
103
|
-
style={"min-width": "50px", "flex": 1},
|
|
104
|
-
options=[{"label": "None", "value": "none"}],
|
|
105
|
-
value="none",
|
|
128
|
+
id=f"{component_id}-color-dropdown",
|
|
129
|
+
style={"minWidth": "150px", "flex": 1},
|
|
106
130
|
clearable=False,
|
|
107
131
|
),
|
|
108
132
|
dcc.Checklist(
|
|
109
133
|
id=f"{component_id}-regression-line",
|
|
110
134
|
options=[{"label": " Diagonal", "value": "show"}],
|
|
111
135
|
value=[],
|
|
112
|
-
style={"
|
|
136
|
+
style={"marginLeft": "20px", "display": "flex", "alignItems": "center"},
|
|
113
137
|
),
|
|
114
138
|
],
|
|
115
|
-
style={"padding": "0px 0px 10px 0px", "display": "flex", "gap": "
|
|
139
|
+
style={"padding": "0px 0px 10px 0px", "display": "flex", "alignItems": "center", "gap": "5px"},
|
|
116
140
|
),
|
|
141
|
+
# Circle overlay tooltip (centered on hovered point)
|
|
117
142
|
dcc.Tooltip(
|
|
118
143
|
id=f"{component_id}-overlay",
|
|
119
144
|
background_color="rgba(0,0,0,0)",
|
|
@@ -121,6 +146,14 @@ class ScatterPlot(PluginInterface):
|
|
|
121
146
|
direction="bottom",
|
|
122
147
|
loading_text="",
|
|
123
148
|
),
|
|
149
|
+
# Molecule tooltip (offset from hovered point) - only used when smiles column exists
|
|
150
|
+
dcc.Tooltip(
|
|
151
|
+
id=f"{component_id}-molecule-tooltip",
|
|
152
|
+
background_color="rgba(0,0,0,0)",
|
|
153
|
+
border_color="rgba(0,0,0,0)",
|
|
154
|
+
direction="bottom",
|
|
155
|
+
loading_text="",
|
|
156
|
+
),
|
|
124
157
|
],
|
|
125
158
|
style={"height": "100%", "display": "flex", "flexDirection": "column"}, # Full viewport height
|
|
126
159
|
)
|
|
@@ -139,11 +172,11 @@ class ScatterPlot(PluginInterface):
|
|
|
139
172
|
- hover_columns: The columns to show when hovering over a point
|
|
140
173
|
- suppress_hover_display: Suppress hover display (default: False)
|
|
141
174
|
- custom_data: Custom data that get passed to hoverData callbacks
|
|
175
|
+
- id_column: Column to use for molecule tooltip header (auto-detects "id" if not specified)
|
|
142
176
|
|
|
143
177
|
Returns:
|
|
144
178
|
list: A list of updated property values (figure, x options, y options, color options,
|
|
145
|
-
|
|
146
|
-
color default).
|
|
179
|
+
x default, y default, color default).
|
|
147
180
|
"""
|
|
148
181
|
# Get the limit for the number of rows to plot
|
|
149
182
|
limit = kwargs.get("limit", 20000)
|
|
@@ -163,6 +196,11 @@ class ScatterPlot(PluginInterface):
|
|
|
163
196
|
self.suppress_hover_display = kwargs.get("suppress_hover_display", False)
|
|
164
197
|
self.custom_data = kwargs.get("custom_data", [])
|
|
165
198
|
|
|
199
|
+
# Check if the dataframe has smiles/id columns for molecule hover rendering
|
|
200
|
+
self.smiles_column = next((col for col in self.df.columns if col.lower() == "smiles"), None)
|
|
201
|
+
self.id_column = kwargs.get("id_column") or next((col for col in self.df.columns if col.lower() == "id"), None)
|
|
202
|
+
self.has_smiles = self.smiles_column is not None
|
|
203
|
+
|
|
166
204
|
# Identify numeric columns
|
|
167
205
|
numeric_columns = self.df.select_dtypes(include="number").columns.tolist()
|
|
168
206
|
if len(numeric_columns) < 3:
|
|
@@ -175,7 +213,7 @@ class ScatterPlot(PluginInterface):
|
|
|
175
213
|
regression_line = kwargs.get("regression_line", False)
|
|
176
214
|
|
|
177
215
|
# Create the default scatter plot
|
|
178
|
-
figure = self.create_scatter_plot(self.df, x_default, y_default, color_default,
|
|
216
|
+
figure = self.create_scatter_plot(self.df, x_default, y_default, color_default, regression_line)
|
|
179
217
|
|
|
180
218
|
# Dropdown options for x and y: use provided dropdown_columns or fallback to numeric columns
|
|
181
219
|
dropdown_columns = kwargs.get("dropdown_columns", numeric_columns)
|
|
@@ -188,11 +226,10 @@ class ScatterPlot(PluginInterface):
|
|
|
188
226
|
color_columns = numeric_columns + cat_columns
|
|
189
227
|
color_options = [{"label": col, "value": col} for col in color_columns]
|
|
190
228
|
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
label_options.extend([{"label": col, "value": col} for col in self.df.columns])
|
|
229
|
+
# Regression line checklist value (list with "show" if enabled, empty list if disabled)
|
|
230
|
+
regression_line_value = ["show"] if regression_line else []
|
|
194
231
|
|
|
195
|
-
return [figure, x_options, y_options, color_options,
|
|
232
|
+
return [figure, x_options, y_options, color_options, x_default, y_default, color_default, regression_line_value]
|
|
196
233
|
|
|
197
234
|
def create_scatter_plot(
|
|
198
235
|
self,
|
|
@@ -200,7 +237,6 @@ class ScatterPlot(PluginInterface):
|
|
|
200
237
|
x_col: str,
|
|
201
238
|
y_col: str,
|
|
202
239
|
color_col: str,
|
|
203
|
-
label_col: str,
|
|
204
240
|
regression_line: bool = False,
|
|
205
241
|
marker_size: int = 15,
|
|
206
242
|
) -> go.Figure:
|
|
@@ -211,24 +247,38 @@ class ScatterPlot(PluginInterface):
|
|
|
211
247
|
x_col (str): The column to use for the x-axis.
|
|
212
248
|
y_col (str): The column to use for the y-axis.
|
|
213
249
|
color_col (str): The column to use for the color scale.
|
|
214
|
-
label_col (str): The column to use for point labels.
|
|
215
250
|
regression_line (bool): Whether to include a regression line.
|
|
216
251
|
marker_size (int): Size of the markers. Default is 15.
|
|
217
252
|
|
|
218
253
|
Returns:
|
|
219
254
|
go.Figure: A Plotly Figure object.
|
|
220
255
|
"""
|
|
221
|
-
# Check if we need to show labels
|
|
222
|
-
show_labels = label_col != "none" and len(df) < 1000
|
|
223
256
|
|
|
224
257
|
# Helper to generate hover text for each point.
|
|
225
258
|
def generate_hover_text(row):
|
|
226
259
|
return "<br>".join([f"{col}: {row[col]}" for col in self.hover_columns])
|
|
227
260
|
|
|
228
|
-
# Generate hover text for all points
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
261
|
+
# Generate hover text for all points (unless suppressed or using molecule hover)
|
|
262
|
+
suppress_hover = self.suppress_hover_display or self.has_smiles
|
|
263
|
+
if suppress_hover:
|
|
264
|
+
# Use "none" to hide the default hover display but still fire hoverData callbacks
|
|
265
|
+
# Don't set hovertemplate when suppressing - it would override hoverinfo
|
|
266
|
+
hovertext = None
|
|
267
|
+
hovertemplate = None
|
|
268
|
+
hoverinfo = "none"
|
|
269
|
+
else:
|
|
270
|
+
hovertext = df.apply(generate_hover_text, axis=1)
|
|
271
|
+
hovertemplate = "%{hovertext}<extra></extra>"
|
|
272
|
+
hoverinfo = None
|
|
273
|
+
|
|
274
|
+
# Build customdata columns - include smiles and id if available for molecule hover
|
|
275
|
+
custom_data_cols = list(self.custom_data) if self.custom_data else []
|
|
276
|
+
if self.has_smiles:
|
|
277
|
+
# Add smiles as first column, id as second (if available)
|
|
278
|
+
if self.smiles_column not in custom_data_cols:
|
|
279
|
+
custom_data_cols = [self.smiles_column] + custom_data_cols
|
|
280
|
+
if self.id_column and self.id_column not in custom_data_cols:
|
|
281
|
+
custom_data_cols.insert(1, self.id_column)
|
|
232
282
|
|
|
233
283
|
# Determine marker settings based on the type of the color column.
|
|
234
284
|
if pd.api.types.is_numeric_dtype(df[color_col]):
|
|
@@ -240,12 +290,10 @@ class ScatterPlot(PluginInterface):
|
|
|
240
290
|
x=df[x_col],
|
|
241
291
|
y=df[y_col],
|
|
242
292
|
mode="markers",
|
|
243
|
-
text=df[label_col].astype(str) if show_labels else None,
|
|
244
|
-
textposition="top center",
|
|
245
293
|
hoverinfo=hoverinfo,
|
|
246
294
|
hovertext=hovertext,
|
|
247
295
|
hovertemplate=hovertemplate,
|
|
248
|
-
customdata=df[
|
|
296
|
+
customdata=df[custom_data_cols] if custom_data_cols else None,
|
|
249
297
|
marker=dict(
|
|
250
298
|
size=marker_size,
|
|
251
299
|
color=marker_color,
|
|
@@ -266,18 +314,16 @@ class ScatterPlot(PluginInterface):
|
|
|
266
314
|
data = []
|
|
267
315
|
for i, cat in enumerate(categories):
|
|
268
316
|
sub_df = df[df[color_col] == cat]
|
|
269
|
-
sub_hovertext = hovertext.loc[sub_df.index]
|
|
317
|
+
sub_hovertext = hovertext.loc[sub_df.index] if hovertext is not None else None
|
|
270
318
|
trace = go.Scattergl(
|
|
271
319
|
x=sub_df[x_col],
|
|
272
320
|
y=sub_df[y_col],
|
|
273
321
|
mode="markers",
|
|
274
|
-
text=sub_df[label_col] if show_labels else None, # Add text if labels enabled
|
|
275
|
-
textposition="top center", # Position labels above points
|
|
276
322
|
name=cat,
|
|
277
323
|
hoverinfo=hoverinfo,
|
|
278
324
|
hovertext=sub_hovertext,
|
|
279
325
|
hovertemplate=hovertemplate,
|
|
280
|
-
customdata=sub_df[
|
|
326
|
+
customdata=sub_df[custom_data_cols] if custom_data_cols else None,
|
|
281
327
|
marker=dict(
|
|
282
328
|
size=marker_size,
|
|
283
329
|
color=discrete_colors[i % len(discrete_colors)],
|
|
@@ -345,18 +391,16 @@ class ScatterPlot(PluginInterface):
|
|
|
345
391
|
Input(f"{self.component_id}-x-dropdown", "value"),
|
|
346
392
|
Input(f"{self.component_id}-y-dropdown", "value"),
|
|
347
393
|
Input(f"{self.component_id}-color-dropdown", "value"),
|
|
348
|
-
Input(f"{self.component_id}-label-dropdown", "value"),
|
|
349
394
|
Input(f"{self.component_id}-regression-line", "value"),
|
|
350
395
|
],
|
|
351
396
|
prevent_initial_call=True,
|
|
352
397
|
)
|
|
353
|
-
def _update_scatter_plot(x_value, y_value, color_value,
|
|
398
|
+
def _update_scatter_plot(x_value, y_value, color_value, regression_line):
|
|
354
399
|
"""Update the Scatter Plot Graph based on the dropdown values."""
|
|
355
400
|
|
|
356
401
|
# Check if the dataframe is not empty and the values are not None
|
|
357
402
|
if not self.df.empty and x_value and y_value and color_value:
|
|
358
|
-
|
|
359
|
-
figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, label_value, regression_line)
|
|
403
|
+
figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, regression_line)
|
|
360
404
|
return figure
|
|
361
405
|
|
|
362
406
|
raise PreventUpdate
|
|
@@ -367,43 +411,68 @@ class ScatterPlot(PluginInterface):
|
|
|
367
411
|
Output(f"{self.component_id}-overlay", "children"),
|
|
368
412
|
Input(f"{self.component_id}-graph", "hoverData"),
|
|
369
413
|
)
|
|
370
|
-
def
|
|
414
|
+
def _scatter_circle_overlay(hover_data):
|
|
415
|
+
"""Show white circle overlay centered on the hovered point."""
|
|
371
416
|
if hover_data is None:
|
|
372
|
-
# Hide the overlay if no hover data
|
|
373
417
|
return False, no_update, no_update
|
|
374
418
|
|
|
375
419
|
# Extract bounding box from hoverData
|
|
376
420
|
bbox = hover_data["points"][0]["bbox"]
|
|
377
421
|
|
|
378
|
-
#
|
|
379
|
-
|
|
380
|
-
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" style="overflow: visible;">
|
|
381
|
-
<!-- Circle for the node -->
|
|
382
|
-
<circle cx="50" cy="50" r="10" stroke="rgba(255, 255, 255, 1)" stroke-width="3" fill="none" />
|
|
383
|
-
</svg>
|
|
384
|
-
"""
|
|
385
|
-
|
|
386
|
-
# Encode the SVG as Base64
|
|
387
|
-
encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
|
|
388
|
-
data_uri = f"data:image/svg+xml;base64,{encoded_svg}"
|
|
389
|
-
|
|
390
|
-
# Use an img tag for the overlay
|
|
391
|
-
svg_image = html.Img(src=data_uri, style={"width": "100px", "height": "100px"})
|
|
422
|
+
# Use pre-computed circle SVG
|
|
423
|
+
svg_image = html.Img(src=self._circle_data_uri, style={"width": "100px", "height": "100px"})
|
|
392
424
|
|
|
393
425
|
# Get the center of the bounding box
|
|
394
426
|
center_x = (bbox["x0"] + bbox["x1"]) / 2
|
|
395
427
|
center_y = (bbox["y0"] + bbox["y1"]) / 2
|
|
396
428
|
|
|
397
|
-
# The tooltip should be centered on the point
|
|
429
|
+
# The tooltip should be centered on the point
|
|
398
430
|
adjusted_bbox = {
|
|
399
431
|
"x0": center_x - 50,
|
|
400
432
|
"x1": center_x + 50,
|
|
401
433
|
"y0": center_y - 162,
|
|
402
434
|
"y1": center_y - 62,
|
|
403
435
|
}
|
|
404
|
-
# Return the updated values for the overlay
|
|
405
436
|
return True, adjusted_bbox, [svg_image]
|
|
406
437
|
|
|
438
|
+
@callback(
|
|
439
|
+
Output(f"{self.component_id}-molecule-tooltip", "show"),
|
|
440
|
+
Output(f"{self.component_id}-molecule-tooltip", "bbox"),
|
|
441
|
+
Output(f"{self.component_id}-molecule-tooltip", "children"),
|
|
442
|
+
Input(f"{self.component_id}-graph", "hoverData"),
|
|
443
|
+
)
|
|
444
|
+
def _scatter_molecule_overlay(hover_data):
|
|
445
|
+
"""Show molecule tooltip when smiles data is available."""
|
|
446
|
+
if hover_data is None or not self.has_smiles:
|
|
447
|
+
return False, no_update, no_update
|
|
448
|
+
|
|
449
|
+
# Extract customdata (contains smiles and id)
|
|
450
|
+
customdata = hover_data["points"][0].get("customdata")
|
|
451
|
+
if customdata is None:
|
|
452
|
+
return False, no_update, no_update
|
|
453
|
+
|
|
454
|
+
# SMILES is the first element, ID is second (if available)
|
|
455
|
+
if isinstance(customdata, (list, tuple)):
|
|
456
|
+
smiles = customdata[0]
|
|
457
|
+
mol_id = customdata[1] if len(customdata) > 1 and self.id_column else None
|
|
458
|
+
else:
|
|
459
|
+
smiles = customdata
|
|
460
|
+
mol_id = None
|
|
461
|
+
|
|
462
|
+
# Generate molecule tooltip with ID header
|
|
463
|
+
mol_width, mol_height = 300, 200
|
|
464
|
+
children = molecule_hover_tooltip(smiles, mol_id=mol_id, width=mol_width, height=mol_height)
|
|
465
|
+
|
|
466
|
+
# Extract bounding box and offset the molecule tooltip to the right of the point
|
|
467
|
+
bbox = hover_data["points"][0]["bbox"]
|
|
468
|
+
adjusted_bbox = {
|
|
469
|
+
"x0": bbox["x0"] + 15,
|
|
470
|
+
"x1": bbox["x1"] + mol_width + 15,
|
|
471
|
+
"y0": bbox["y0"] - (2 * mol_height + 60),
|
|
472
|
+
"y1": bbox["y1"] - (mol_height + 60),
|
|
473
|
+
}
|
|
474
|
+
return True, adjusted_bbox, children
|
|
475
|
+
|
|
407
476
|
|
|
408
477
|
if __name__ == "__main__":
|
|
409
478
|
"""Run the Unit Test for the Plugin."""
|
|
@@ -426,6 +495,8 @@ if __name__ == "__main__":
|
|
|
426
495
|
df = model.get_inference_predictions("full_cross_fold")
|
|
427
496
|
|
|
428
497
|
# Run the Unit Test on the Plugin
|
|
498
|
+
# Test currently commented out
|
|
499
|
+
"""
|
|
429
500
|
PluginUnitTest(
|
|
430
501
|
ScatterPlot,
|
|
431
502
|
input_data=df,
|
|
@@ -435,3 +506,18 @@ if __name__ == "__main__":
|
|
|
435
506
|
color="prediction_std",
|
|
436
507
|
suppress_hover_display=True,
|
|
437
508
|
).run()
|
|
509
|
+
"""
|
|
510
|
+
|
|
511
|
+
# Test with molecule hover (smiles column)
|
|
512
|
+
from workbench.api import FeatureSet
|
|
513
|
+
|
|
514
|
+
fs = FeatureSet("aqsol_features")
|
|
515
|
+
mol_df = fs.pull_dataframe()[:1000] # Limit to 1000 rows for testing
|
|
516
|
+
|
|
517
|
+
# Run the Unit Test with molecule data (hover over points to see molecule structures)
|
|
518
|
+
PluginUnitTest(
|
|
519
|
+
ScatterPlot,
|
|
520
|
+
input_data=mol_df,
|
|
521
|
+
theme="midnight_blue",
|
|
522
|
+
suppress_hover_display=True,
|
|
523
|
+
).run()
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""SettingsMenu: A settings menu component for the Workbench Dashboard."""
|
|
2
|
+
|
|
3
|
+
from dash import html, dcc
|
|
4
|
+
import dash_bootstrap_components as dbc
|
|
5
|
+
|
|
6
|
+
# Workbench Imports
|
|
7
|
+
from workbench.utils.theme_manager import ThemeManager
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SettingsMenu:
|
|
11
|
+
"""A settings menu with admin links and theme selection."""
|
|
12
|
+
|
|
13
|
+
def __init__(self):
|
|
14
|
+
"""Initialize the SettingsMenu."""
|
|
15
|
+
self.tm = ThemeManager()
|
|
16
|
+
|
|
17
|
+
def create_component(self, component_id: str) -> html.Div:
|
|
18
|
+
"""Create a settings menu dropdown component.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
component_id (str): The ID prefix for the component.
|
|
22
|
+
|
|
23
|
+
Returns:
|
|
24
|
+
html.Div: A Div containing the settings menu dropdown.
|
|
25
|
+
"""
|
|
26
|
+
themes = self.tm.list_themes()
|
|
27
|
+
|
|
28
|
+
# Create theme submenu items
|
|
29
|
+
theme_items = []
|
|
30
|
+
for theme in sorted(themes):
|
|
31
|
+
theme_items.append(
|
|
32
|
+
dbc.DropdownMenuItem(
|
|
33
|
+
[
|
|
34
|
+
html.Span(
|
|
35
|
+
"",
|
|
36
|
+
id={"type": f"{component_id}-checkmark", "theme": theme},
|
|
37
|
+
style={
|
|
38
|
+
"fontFamily": "monospace",
|
|
39
|
+
"marginRight": "5px",
|
|
40
|
+
"width": "20px",
|
|
41
|
+
"display": "inline-block",
|
|
42
|
+
},
|
|
43
|
+
),
|
|
44
|
+
theme.replace("_", " ").title(),
|
|
45
|
+
],
|
|
46
|
+
id={"type": f"{component_id}-theme-item", "theme": theme},
|
|
47
|
+
)
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
# Hamburger icon (3 rounded lines)
|
|
51
|
+
hamburger_icon = html.Div(
|
|
52
|
+
[
|
|
53
|
+
html.Div(
|
|
54
|
+
style={
|
|
55
|
+
"width": "20px",
|
|
56
|
+
"height": "3px",
|
|
57
|
+
"backgroundColor": "currentColor",
|
|
58
|
+
"borderRadius": "2px",
|
|
59
|
+
"marginBottom": "4px",
|
|
60
|
+
}
|
|
61
|
+
),
|
|
62
|
+
html.Div(
|
|
63
|
+
style={
|
|
64
|
+
"width": "20px",
|
|
65
|
+
"height": "3px",
|
|
66
|
+
"backgroundColor": "currentColor",
|
|
67
|
+
"borderRadius": "2px",
|
|
68
|
+
"marginBottom": "4px",
|
|
69
|
+
}
|
|
70
|
+
),
|
|
71
|
+
html.Div(
|
|
72
|
+
style={"width": "20px", "height": "3px", "backgroundColor": "currentColor", "borderRadius": "2px"}
|
|
73
|
+
),
|
|
74
|
+
],
|
|
75
|
+
style={"display": "flex", "flexDirection": "column", "alignItems": "center", "justifyContent": "center"},
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
# Build menu items: Status, License, divider, Themes submenu
|
|
79
|
+
menu_items = [
|
|
80
|
+
dbc.DropdownMenuItem("Status", href="/status", external_link=True, target="_blank"),
|
|
81
|
+
dbc.DropdownMenuItem("License", href="/license", external_link=True, target="_blank"),
|
|
82
|
+
dbc.DropdownMenuItem(divider=True),
|
|
83
|
+
dbc.DropdownMenuItem("Themes", header=True),
|
|
84
|
+
*theme_items,
|
|
85
|
+
]
|
|
86
|
+
|
|
87
|
+
return html.Div(
|
|
88
|
+
[
|
|
89
|
+
dbc.DropdownMenu(
|
|
90
|
+
label=hamburger_icon,
|
|
91
|
+
children=menu_items,
|
|
92
|
+
id=f"{component_id}-dropdown",
|
|
93
|
+
toggle_style={
|
|
94
|
+
"background": "transparent",
|
|
95
|
+
"border": "none",
|
|
96
|
+
"boxShadow": "none",
|
|
97
|
+
"padding": "5px 10px",
|
|
98
|
+
},
|
|
99
|
+
caret=False,
|
|
100
|
+
align_end=True,
|
|
101
|
+
),
|
|
102
|
+
# Dummy store for the clientside callback output
|
|
103
|
+
dcc.Store(id=f"{component_id}-dummy", data=None),
|
|
104
|
+
# Store to trigger checkmark update on load
|
|
105
|
+
dcc.Store(id=f"{component_id}-init", data=True),
|
|
106
|
+
],
|
|
107
|
+
id=component_id,
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
@staticmethod
|
|
111
|
+
def get_clientside_callback_code(component_id: str) -> str:
|
|
112
|
+
"""Get the JavaScript code for the theme selection clientside callback.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
component_id (str): The ID prefix used in create_component.
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
str: JavaScript code for the clientside callback.
|
|
119
|
+
"""
|
|
120
|
+
return """
|
|
121
|
+
function(n_clicks_list, ids) {
|
|
122
|
+
// Find which button was clicked
|
|
123
|
+
if (!n_clicks_list || n_clicks_list.every(n => !n)) {
|
|
124
|
+
return window.dash_clientside.no_update;
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
// Find the clicked theme
|
|
128
|
+
let clickedTheme = null;
|
|
129
|
+
for (let i = 0; i < n_clicks_list.length; i++) {
|
|
130
|
+
if (n_clicks_list[i]) {
|
|
131
|
+
clickedTheme = ids[i].theme;
|
|
132
|
+
break;
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
if (clickedTheme) {
|
|
137
|
+
// Store in localStorage
|
|
138
|
+
localStorage.setItem('wb_theme', clickedTheme);
|
|
139
|
+
// Set cookie for Flask to read on reload
|
|
140
|
+
document.cookie = `wb_theme=${clickedTheme}; path=/; max-age=31536000`;
|
|
141
|
+
// Reload the page to apply the new theme
|
|
142
|
+
window.location.reload();
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
return window.dash_clientside.no_update;
|
|
146
|
+
}
|
|
147
|
+
"""
|
|
148
|
+
|
|
149
|
+
@staticmethod
|
|
150
|
+
def get_checkmark_callback_code() -> str:
|
|
151
|
+
"""Get the JavaScript code to update checkmarks based on localStorage.
|
|
152
|
+
|
|
153
|
+
Returns:
|
|
154
|
+
str: JavaScript code for the checkmark update callback.
|
|
155
|
+
"""
|
|
156
|
+
return """
|
|
157
|
+
function(init, ids) {
|
|
158
|
+
// Get current theme from localStorage (or cookie as fallback)
|
|
159
|
+
let currentTheme = localStorage.getItem('wb_theme');
|
|
160
|
+
if (!currentTheme) {
|
|
161
|
+
// Try to read from cookie
|
|
162
|
+
const cookies = document.cookie.split(';');
|
|
163
|
+
for (let cookie of cookies) {
|
|
164
|
+
const [name, value] = cookie.trim().split('=');
|
|
165
|
+
if (name === 'wb_theme') {
|
|
166
|
+
currentTheme = value;
|
|
167
|
+
break;
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
// Return checkmarks for each theme
|
|
173
|
+
return ids.map(id => id.theme === currentTheme ? '\u2713' : '');
|
|
174
|
+
}
|
|
175
|
+
"""
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
if __name__ == "__main__":
|
|
179
|
+
# Quick test to verify component creation
|
|
180
|
+
menu = SettingsMenu()
|
|
181
|
+
component = menu.create_component("test-settings-menu")
|
|
182
|
+
print("SettingsMenu component created successfully")
|
|
183
|
+
print(f"Available themes: {menu.tm.list_themes()}")
|
|
184
|
+
print(f"Current theme: {menu.tm.current_theme()}")
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: workbench
|
|
3
|
-
Version: 0.8.
|
|
3
|
+
Version: 0.8.224
|
|
4
4
|
Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
|
|
5
5
|
Author-email: SuperCowPowers LLC <support@supercowpowers.com>
|
|
6
6
|
License: MIT License
|
|
@@ -54,6 +54,7 @@ Requires-Dist: requests>=2.26.0
|
|
|
54
54
|
Requires-Dist: rdkit>=2024.9.5
|
|
55
55
|
Requires-Dist: mordredcommunity>=2.0.6
|
|
56
56
|
Requires-Dist: workbench-bridges>=0.1.16
|
|
57
|
+
Requires-Dist: cleanlab>=2.9.0
|
|
57
58
|
Provides-Extra: ui
|
|
58
59
|
Requires-Dist: plotly>=6.0.0; extra == "ui"
|
|
59
60
|
Requires-Dist: dash>=3.0.0; extra == "ui"
|
|
@@ -71,18 +72,8 @@ Requires-Dist: flake8; extra == "dev"
|
|
|
71
72
|
Requires-Dist: black; extra == "dev"
|
|
72
73
|
Provides-Extra: all
|
|
73
74
|
Requires-Dist: networkx>=3.2; extra == "all"
|
|
74
|
-
Requires-Dist:
|
|
75
|
-
Requires-Dist:
|
|
76
|
-
Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "all"
|
|
77
|
-
Requires-Dist: dash-bootstrap-templates>=1.3.0; extra == "all"
|
|
78
|
-
Requires-Dist: dash_ag_grid; extra == "all"
|
|
79
|
-
Requires-Dist: tabulate>=0.9.0; extra == "all"
|
|
80
|
-
Requires-Dist: pytest; extra == "all"
|
|
81
|
-
Requires-Dist: pytest-sugar; extra == "all"
|
|
82
|
-
Requires-Dist: coverage; extra == "all"
|
|
83
|
-
Requires-Dist: pytest-cov; extra == "all"
|
|
84
|
-
Requires-Dist: flake8; extra == "all"
|
|
85
|
-
Requires-Dist: black; extra == "all"
|
|
75
|
+
Requires-Dist: workbench[ui]; extra == "all"
|
|
76
|
+
Requires-Dist: workbench[dev]; extra == "all"
|
|
86
77
|
Dynamic: license-file
|
|
87
78
|
|
|
88
79
|
|