workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +319 -204
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -82
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +260 -76
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import base64
|
|
2
2
|
import pandas as pd
|
|
3
|
-
from dash import dcc, html, callback, Input, Output, no_update
|
|
3
|
+
from dash import dcc, html, callback, clientside_callback, Input, Output, no_update
|
|
4
4
|
import plotly.graph_objects as go
|
|
5
5
|
import plotly.express as px
|
|
6
6
|
from dash.exceptions import PreventUpdate
|
|
@@ -8,7 +8,7 @@ from dash.exceptions import PreventUpdate
|
|
|
8
8
|
# Workbench Imports
|
|
9
9
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
10
10
|
from workbench.utils.theme_manager import ThemeManager
|
|
11
|
-
from workbench.utils.plot_utils import prediction_intervals
|
|
11
|
+
from workbench.utils.plot_utils import prediction_intervals, molecule_hover_tooltip
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class ScatterPlot(PluginInterface):
|
|
@@ -18,6 +18,12 @@ class ScatterPlot(PluginInterface):
|
|
|
18
18
|
auto_load_page = PluginPage.NONE
|
|
19
19
|
plugin_input_type = PluginInputType.DATAFRAME
|
|
20
20
|
|
|
21
|
+
# Pre-computed circle overlay SVG
|
|
22
|
+
_circle_svg = """<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" style="overflow: visible;">
|
|
23
|
+
<circle cx="50" cy="50" r="10" stroke="rgba(255, 255, 255, 1)" stroke-width="3" fill="none" />
|
|
24
|
+
</svg>"""
|
|
25
|
+
_circle_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(_circle_svg.encode('utf-8')).decode('utf-8')}"
|
|
26
|
+
|
|
21
27
|
def __init__(self, show_axes: bool = True):
|
|
22
28
|
"""Initialize the Scatter Plot Plugin
|
|
23
29
|
|
|
@@ -29,7 +35,10 @@ class ScatterPlot(PluginInterface):
|
|
|
29
35
|
self.df = None
|
|
30
36
|
self.show_axes = show_axes
|
|
31
37
|
self.theme_manager = ThemeManager()
|
|
32
|
-
self.
|
|
38
|
+
self.has_smiles = False # Track if dataframe has smiles column for molecule hover
|
|
39
|
+
self.smiles_column = None
|
|
40
|
+
self.id_column = None
|
|
41
|
+
self.hover_background = None # Cached background color for molecule hover tooltip
|
|
33
42
|
|
|
34
43
|
# Call the parent class constructor
|
|
35
44
|
super().__init__()
|
|
@@ -51,10 +60,10 @@ class ScatterPlot(PluginInterface):
|
|
|
51
60
|
(f"{component_id}-x-dropdown", "options"),
|
|
52
61
|
(f"{component_id}-y-dropdown", "options"),
|
|
53
62
|
(f"{component_id}-color-dropdown", "options"),
|
|
54
|
-
(f"{component_id}-label-dropdown", "options"),
|
|
55
63
|
(f"{component_id}-x-dropdown", "value"),
|
|
56
64
|
(f"{component_id}-y-dropdown", "value"),
|
|
57
65
|
(f"{component_id}-color-dropdown", "value"),
|
|
66
|
+
(f"{component_id}-regression-line", "value"),
|
|
58
67
|
]
|
|
59
68
|
self.signals = [(f"{component_id}-graph", "hoverData"), (f"{component_id}-graph", "clickData")]
|
|
60
69
|
|
|
@@ -69,51 +78,67 @@ class ScatterPlot(PluginInterface):
|
|
|
69
78
|
id=f"{component_id}-graph",
|
|
70
79
|
figure=self.display_text("Waiting for Data..."),
|
|
71
80
|
config={"scrollZoom": True},
|
|
72
|
-
style={"height": "100%"},
|
|
81
|
+
style={"height": "600px", "width": "100%"},
|
|
73
82
|
clear_on_unhover=True,
|
|
74
83
|
),
|
|
75
84
|
# Controls: X, Y, Color, Label Dropdowns, and Regression Line Checkbox
|
|
76
85
|
html.Div(
|
|
77
86
|
[
|
|
78
|
-
html.Label(
|
|
87
|
+
html.Label(
|
|
88
|
+
"X",
|
|
89
|
+
style={
|
|
90
|
+
"marginLeft": "20px",
|
|
91
|
+
"marginRight": "5px",
|
|
92
|
+
"fontWeight": "bold",
|
|
93
|
+
"display": "flex",
|
|
94
|
+
"alignItems": "center",
|
|
95
|
+
},
|
|
96
|
+
),
|
|
79
97
|
dcc.Dropdown(
|
|
80
98
|
id=f"{component_id}-x-dropdown",
|
|
81
|
-
|
|
82
|
-
style={"min-width": "50px", "flex": 1}, # Responsive width
|
|
99
|
+
style={"minWidth": "150px", "flex": 1},
|
|
83
100
|
clearable=False,
|
|
84
101
|
),
|
|
85
|
-
html.Label(
|
|
102
|
+
html.Label(
|
|
103
|
+
"Y",
|
|
104
|
+
style={
|
|
105
|
+
"marginLeft": "20px",
|
|
106
|
+
"marginRight": "5px",
|
|
107
|
+
"fontWeight": "bold",
|
|
108
|
+
"display": "flex",
|
|
109
|
+
"alignItems": "center",
|
|
110
|
+
},
|
|
111
|
+
),
|
|
86
112
|
dcc.Dropdown(
|
|
87
113
|
id=f"{component_id}-y-dropdown",
|
|
88
|
-
|
|
89
|
-
style={"min-width": "50px", "flex": 1}, # Responsive width
|
|
114
|
+
style={"minWidth": "150px", "flex": 1},
|
|
90
115
|
clearable=False,
|
|
91
116
|
),
|
|
92
|
-
html.Label(
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
117
|
+
html.Label(
|
|
118
|
+
"Color",
|
|
119
|
+
style={
|
|
120
|
+
"marginLeft": "20px",
|
|
121
|
+
"marginRight": "5px",
|
|
122
|
+
"fontWeight": "bold",
|
|
123
|
+
"display": "flex",
|
|
124
|
+
"alignItems": "center",
|
|
125
|
+
},
|
|
98
126
|
),
|
|
99
|
-
html.Label("Label", style={"marginLeft": "30px", "marginRight": "5px", "fontWeight": "bold"}),
|
|
100
127
|
dcc.Dropdown(
|
|
101
|
-
id=f"{component_id}-
|
|
102
|
-
|
|
103
|
-
style={"min-width": "50px", "flex": 1},
|
|
104
|
-
options=[{"label": "None", "value": "none"}],
|
|
105
|
-
value="none",
|
|
128
|
+
id=f"{component_id}-color-dropdown",
|
|
129
|
+
style={"minWidth": "150px", "flex": 1},
|
|
106
130
|
clearable=False,
|
|
107
131
|
),
|
|
108
132
|
dcc.Checklist(
|
|
109
133
|
id=f"{component_id}-regression-line",
|
|
110
134
|
options=[{"label": " Diagonal", "value": "show"}],
|
|
111
135
|
value=[],
|
|
112
|
-
style={"
|
|
136
|
+
style={"marginLeft": "20px", "display": "flex", "alignItems": "center"},
|
|
113
137
|
),
|
|
114
138
|
],
|
|
115
|
-
style={"padding": "0px 0px 10px 0px", "display": "flex", "gap": "
|
|
139
|
+
style={"padding": "0px 0px 10px 0px", "display": "flex", "alignItems": "center", "gap": "5px"},
|
|
116
140
|
),
|
|
141
|
+
# Circle overlay tooltip (centered on hovered point)
|
|
117
142
|
dcc.Tooltip(
|
|
118
143
|
id=f"{component_id}-overlay",
|
|
119
144
|
background_color="rgba(0,0,0,0)",
|
|
@@ -121,6 +146,14 @@ class ScatterPlot(PluginInterface):
|
|
|
121
146
|
direction="bottom",
|
|
122
147
|
loading_text="",
|
|
123
148
|
),
|
|
149
|
+
# Molecule tooltip (offset from hovered point) - only used when smiles column exists
|
|
150
|
+
dcc.Tooltip(
|
|
151
|
+
id=f"{component_id}-molecule-tooltip",
|
|
152
|
+
background_color="rgba(0,0,0,0)",
|
|
153
|
+
border_color="rgba(0,0,0,0)",
|
|
154
|
+
direction="bottom",
|
|
155
|
+
loading_text="",
|
|
156
|
+
),
|
|
124
157
|
],
|
|
125
158
|
style={"height": "100%", "display": "flex", "flexDirection": "column"}, # Full viewport height
|
|
126
159
|
)
|
|
@@ -139,12 +172,16 @@ class ScatterPlot(PluginInterface):
|
|
|
139
172
|
- hover_columns: The columns to show when hovering over a point
|
|
140
173
|
- suppress_hover_display: Suppress hover display (default: False)
|
|
141
174
|
- custom_data: Custom data that get passed to hoverData callbacks
|
|
175
|
+
- id_column: Column to use for molecule tooltip header (auto-detects "id" if not specified)
|
|
142
176
|
|
|
143
177
|
Returns:
|
|
144
178
|
list: A list of updated property values (figure, x options, y options, color options,
|
|
145
|
-
|
|
146
|
-
color default).
|
|
179
|
+
x default, y default, color default).
|
|
147
180
|
"""
|
|
181
|
+
# Get the colorscale and background color from the current theme
|
|
182
|
+
self.colorscale = self.theme_manager.colorscale()
|
|
183
|
+
self.hover_background = self.theme_manager.background()
|
|
184
|
+
|
|
148
185
|
# Get the limit for the number of rows to plot
|
|
149
186
|
limit = kwargs.get("limit", 20000)
|
|
150
187
|
|
|
@@ -159,23 +196,32 @@ class ScatterPlot(PluginInterface):
|
|
|
159
196
|
self.df = self.df.drop(columns=aws_cols, errors="ignore")
|
|
160
197
|
|
|
161
198
|
# Set hover columns and custom data
|
|
162
|
-
self.hover_columns = kwargs.get("hover_columns", self.df.columns.tolist()[:
|
|
199
|
+
self.hover_columns = kwargs.get("hover_columns", sorted(self.df.columns.tolist()[:15]))
|
|
163
200
|
self.suppress_hover_display = kwargs.get("suppress_hover_display", False)
|
|
164
201
|
self.custom_data = kwargs.get("custom_data", [])
|
|
165
202
|
|
|
203
|
+
# Check if the dataframe has smiles/id columns for molecule hover rendering
|
|
204
|
+
self.smiles_column = next((col for col in self.df.columns if col.lower() == "smiles"), None)
|
|
205
|
+
# Use provided id_column, or auto-detect "id" column, or fall back to first column
|
|
206
|
+
self.id_column = kwargs.get("id_column") or next(
|
|
207
|
+
(col for col in self.df.columns if col.lower() == "id"), self.df.columns[0]
|
|
208
|
+
)
|
|
209
|
+
self.has_smiles = self.smiles_column is not None
|
|
210
|
+
|
|
166
211
|
# Identify numeric columns
|
|
167
212
|
numeric_columns = self.df.select_dtypes(include="number").columns.tolist()
|
|
168
213
|
if len(numeric_columns) < 3:
|
|
169
214
|
raise ValueError("At least three numeric columns are required for x, y, and color.")
|
|
170
215
|
|
|
171
|
-
# Default x, y, and color (for color,
|
|
216
|
+
# Default x, y, and color (for color, prefer 'confidence' if it exists)
|
|
172
217
|
x_default = kwargs.get("x", numeric_columns[0])
|
|
173
218
|
y_default = kwargs.get("y", numeric_columns[1])
|
|
174
|
-
|
|
219
|
+
default_color = "confidence" if "confidence" in self.df.columns else numeric_columns[2]
|
|
220
|
+
color_default = kwargs.get("color", default_color)
|
|
175
221
|
regression_line = kwargs.get("regression_line", False)
|
|
176
222
|
|
|
177
223
|
# Create the default scatter plot
|
|
178
|
-
figure = self.create_scatter_plot(self.df, x_default, y_default, color_default,
|
|
224
|
+
figure = self.create_scatter_plot(self.df, x_default, y_default, color_default, regression_line)
|
|
179
225
|
|
|
180
226
|
# Dropdown options for x and y: use provided dropdown_columns or fallback to numeric columns
|
|
181
227
|
dropdown_columns = kwargs.get("dropdown_columns", numeric_columns)
|
|
@@ -188,11 +234,10 @@ class ScatterPlot(PluginInterface):
|
|
|
188
234
|
color_columns = numeric_columns + cat_columns
|
|
189
235
|
color_options = [{"label": col, "value": col} for col in color_columns]
|
|
190
236
|
|
|
191
|
-
#
|
|
192
|
-
|
|
193
|
-
label_options.extend([{"label": col, "value": col} for col in self.df.columns])
|
|
237
|
+
# Regression line checklist value (list with "show" if enabled, empty list if disabled)
|
|
238
|
+
regression_line_value = ["show"] if regression_line else []
|
|
194
239
|
|
|
195
|
-
return [figure, x_options, y_options, color_options,
|
|
240
|
+
return [figure, x_options, y_options, color_options, x_default, y_default, color_default, regression_line_value]
|
|
196
241
|
|
|
197
242
|
def create_scatter_plot(
|
|
198
243
|
self,
|
|
@@ -200,7 +245,6 @@ class ScatterPlot(PluginInterface):
|
|
|
200
245
|
x_col: str,
|
|
201
246
|
y_col: str,
|
|
202
247
|
color_col: str,
|
|
203
|
-
label_col: str,
|
|
204
248
|
regression_line: bool = False,
|
|
205
249
|
marker_size: int = 15,
|
|
206
250
|
) -> go.Figure:
|
|
@@ -211,24 +255,38 @@ class ScatterPlot(PluginInterface):
|
|
|
211
255
|
x_col (str): The column to use for the x-axis.
|
|
212
256
|
y_col (str): The column to use for the y-axis.
|
|
213
257
|
color_col (str): The column to use for the color scale.
|
|
214
|
-
label_col (str): The column to use for point labels.
|
|
215
258
|
regression_line (bool): Whether to include a regression line.
|
|
216
259
|
marker_size (int): Size of the markers. Default is 15.
|
|
217
260
|
|
|
218
261
|
Returns:
|
|
219
262
|
go.Figure: A Plotly Figure object.
|
|
220
263
|
"""
|
|
221
|
-
# Check if we need to show labels
|
|
222
|
-
show_labels = label_col != "none" and len(df) < 1000
|
|
223
264
|
|
|
224
265
|
# Helper to generate hover text for each point.
|
|
225
266
|
def generate_hover_text(row):
|
|
226
267
|
return "<br>".join([f"{col}: {row[col]}" for col in self.hover_columns])
|
|
227
268
|
|
|
228
|
-
# Generate hover text for all points
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
269
|
+
# Generate hover text for all points (unless suppressed or using molecule hover)
|
|
270
|
+
suppress_hover = self.suppress_hover_display or self.has_smiles
|
|
271
|
+
if suppress_hover:
|
|
272
|
+
# Use "none" to hide the default hover display but still fire hoverData callbacks
|
|
273
|
+
# Don't set hovertemplate when suppressing - it would override hoverinfo
|
|
274
|
+
hovertext = None
|
|
275
|
+
hovertemplate = None
|
|
276
|
+
hoverinfo = "none"
|
|
277
|
+
else:
|
|
278
|
+
hovertext = df.apply(generate_hover_text, axis=1)
|
|
279
|
+
hovertemplate = "%{hovertext}<extra></extra>"
|
|
280
|
+
hoverinfo = None
|
|
281
|
+
|
|
282
|
+
# Build customdata columns - include smiles and id if available for molecule hover
|
|
283
|
+
custom_data_cols = list(self.custom_data) if self.custom_data else []
|
|
284
|
+
if self.has_smiles:
|
|
285
|
+
# Add smiles as first column, id as second (if available)
|
|
286
|
+
if self.smiles_column not in custom_data_cols:
|
|
287
|
+
custom_data_cols = [self.smiles_column] + custom_data_cols
|
|
288
|
+
if self.id_column and self.id_column not in custom_data_cols:
|
|
289
|
+
custom_data_cols.insert(1, self.id_column)
|
|
232
290
|
|
|
233
291
|
# Determine marker settings based on the type of the color column.
|
|
234
292
|
if pd.api.types.is_numeric_dtype(df[color_col]):
|
|
@@ -240,18 +298,16 @@ class ScatterPlot(PluginInterface):
|
|
|
240
298
|
x=df[x_col],
|
|
241
299
|
y=df[y_col],
|
|
242
300
|
mode="markers",
|
|
243
|
-
text=df[label_col].astype(str) if show_labels else None,
|
|
244
|
-
textposition="top center",
|
|
245
301
|
hoverinfo=hoverinfo,
|
|
246
302
|
hovertext=hovertext,
|
|
247
303
|
hovertemplate=hovertemplate,
|
|
248
|
-
customdata=df[
|
|
304
|
+
customdata=df[custom_data_cols] if custom_data_cols else None,
|
|
249
305
|
marker=dict(
|
|
250
306
|
size=marker_size,
|
|
251
307
|
color=marker_color,
|
|
252
308
|
colorscale=self.colorscale,
|
|
253
309
|
colorbar=colorbar,
|
|
254
|
-
opacity=0.
|
|
310
|
+
opacity=0.9,
|
|
255
311
|
line=dict(color="rgba(0,0,0,0.25)", width=1),
|
|
256
312
|
),
|
|
257
313
|
)
|
|
@@ -266,18 +322,16 @@ class ScatterPlot(PluginInterface):
|
|
|
266
322
|
data = []
|
|
267
323
|
for i, cat in enumerate(categories):
|
|
268
324
|
sub_df = df[df[color_col] == cat]
|
|
269
|
-
sub_hovertext = hovertext.loc[sub_df.index]
|
|
325
|
+
sub_hovertext = hovertext.loc[sub_df.index] if hovertext is not None else None
|
|
270
326
|
trace = go.Scattergl(
|
|
271
327
|
x=sub_df[x_col],
|
|
272
328
|
y=sub_df[y_col],
|
|
273
329
|
mode="markers",
|
|
274
|
-
text=sub_df[label_col] if show_labels else None, # Add text if labels enabled
|
|
275
|
-
textposition="top center", # Position labels above points
|
|
276
330
|
name=cat,
|
|
277
331
|
hoverinfo=hoverinfo,
|
|
278
332
|
hovertext=sub_hovertext,
|
|
279
333
|
hovertemplate=hovertemplate,
|
|
280
|
-
customdata=sub_df[
|
|
334
|
+
customdata=sub_df[custom_data_cols] if custom_data_cols else None,
|
|
281
335
|
marker=dict(
|
|
282
336
|
size=marker_size,
|
|
283
337
|
color=discrete_colors[i % len(discrete_colors)],
|
|
@@ -345,64 +399,97 @@ class ScatterPlot(PluginInterface):
|
|
|
345
399
|
Input(f"{self.component_id}-x-dropdown", "value"),
|
|
346
400
|
Input(f"{self.component_id}-y-dropdown", "value"),
|
|
347
401
|
Input(f"{self.component_id}-color-dropdown", "value"),
|
|
348
|
-
Input(f"{self.component_id}-label-dropdown", "value"),
|
|
349
402
|
Input(f"{self.component_id}-regression-line", "value"),
|
|
350
403
|
],
|
|
351
404
|
prevent_initial_call=True,
|
|
352
405
|
)
|
|
353
|
-
def _update_scatter_plot(x_value, y_value, color_value,
|
|
406
|
+
def _update_scatter_plot(x_value, y_value, color_value, regression_line):
|
|
354
407
|
"""Update the Scatter Plot Graph based on the dropdown values."""
|
|
355
408
|
|
|
356
409
|
# Check if the dataframe is not empty and the values are not None
|
|
357
410
|
if not self.df.empty and x_value and y_value and color_value:
|
|
358
|
-
|
|
359
|
-
figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, label_value, regression_line)
|
|
411
|
+
figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, regression_line)
|
|
360
412
|
return figure
|
|
361
413
|
|
|
362
414
|
raise PreventUpdate
|
|
363
415
|
|
|
364
|
-
|
|
416
|
+
# Clientside callback for circle overlay - runs in browser, no server round trip
|
|
417
|
+
clientside_callback(
|
|
418
|
+
f"""
|
|
419
|
+
function(hoverData) {{
|
|
420
|
+
if (!hoverData) {{
|
|
421
|
+
return [false, window.dash_clientside.no_update, window.dash_clientside.no_update];
|
|
422
|
+
}}
|
|
423
|
+
var bbox = hoverData.points[0].bbox;
|
|
424
|
+
var centerX = (bbox.x0 + bbox.x1) / 2;
|
|
425
|
+
var centerY = (bbox.y0 + bbox.y1) / 2;
|
|
426
|
+
var adjustedBbox = {{
|
|
427
|
+
x0: centerX - 50,
|
|
428
|
+
x1: centerX + 50,
|
|
429
|
+
y0: centerY - 162,
|
|
430
|
+
y1: centerY - 62
|
|
431
|
+
}};
|
|
432
|
+
var imgElement = {{
|
|
433
|
+
type: 'Img',
|
|
434
|
+
namespace: 'dash_html_components',
|
|
435
|
+
props: {{
|
|
436
|
+
src: '{self._circle_data_uri}',
|
|
437
|
+
style: {{width: '100px', height: '100px'}}
|
|
438
|
+
}}
|
|
439
|
+
}};
|
|
440
|
+
return [true, adjustedBbox, [imgElement]];
|
|
441
|
+
}}
|
|
442
|
+
""",
|
|
365
443
|
Output(f"{self.component_id}-overlay", "show"),
|
|
366
444
|
Output(f"{self.component_id}-overlay", "bbox"),
|
|
367
445
|
Output(f"{self.component_id}-overlay", "children"),
|
|
368
446
|
Input(f"{self.component_id}-graph", "hoverData"),
|
|
369
447
|
)
|
|
370
|
-
def _scatter_overlay(hover_data):
|
|
371
|
-
if hover_data is None:
|
|
372
|
-
# Hide the overlay if no hover data
|
|
373
|
-
return False, no_update, no_update
|
|
374
448
|
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
449
|
+
@callback(
|
|
450
|
+
Output(f"{self.component_id}-molecule-tooltip", "show"),
|
|
451
|
+
Output(f"{self.component_id}-molecule-tooltip", "bbox"),
|
|
452
|
+
Output(f"{self.component_id}-molecule-tooltip", "children"),
|
|
453
|
+
Input(f"{self.component_id}-graph", "hoverData"),
|
|
454
|
+
)
|
|
455
|
+
def _scatter_molecule_overlay(hover_data):
|
|
456
|
+
"""Show molecule tooltip when smiles data is available."""
|
|
457
|
+
if hover_data is None or not self.has_smiles:
|
|
458
|
+
return False, no_update, no_update
|
|
385
459
|
|
|
386
|
-
#
|
|
387
|
-
|
|
388
|
-
|
|
460
|
+
# Extract customdata (contains smiles and id)
|
|
461
|
+
customdata = hover_data["points"][0].get("customdata")
|
|
462
|
+
if customdata is None:
|
|
463
|
+
return False, no_update, no_update
|
|
389
464
|
|
|
390
|
-
#
|
|
391
|
-
|
|
465
|
+
# SMILES is the first element, ID is second (if available)
|
|
466
|
+
if isinstance(customdata, (list, tuple)):
|
|
467
|
+
smiles = customdata[0]
|
|
468
|
+
mol_id = customdata[1] if len(customdata) > 1 and self.id_column else None
|
|
469
|
+
else:
|
|
470
|
+
smiles = customdata
|
|
471
|
+
mol_id = None
|
|
472
|
+
|
|
473
|
+
# Generate molecule tooltip with ID header (use cached background color)
|
|
474
|
+
mol_width, mol_height = 300, 200
|
|
475
|
+
children = molecule_hover_tooltip(
|
|
476
|
+
smiles, mol_id=mol_id, width=mol_width, height=mol_height, background=self.hover_background
|
|
477
|
+
)
|
|
392
478
|
|
|
393
|
-
#
|
|
479
|
+
# Position molecule tooltip above and slightly right of the point
|
|
480
|
+
bbox = hover_data["points"][0]["bbox"]
|
|
394
481
|
center_x = (bbox["x0"] + bbox["x1"]) / 2
|
|
395
482
|
center_y = (bbox["y0"] + bbox["y1"]) / 2
|
|
483
|
+
x_offset = 5 # Slight offset to the right
|
|
484
|
+
y_offset = mol_height + 50 # Above the point
|
|
396
485
|
|
|
397
|
-
# The tooltip should be centered on the point (note: 'bottom' tooltip, so we adjust y position)
|
|
398
486
|
adjusted_bbox = {
|
|
399
|
-
"x0": center_x
|
|
400
|
-
"x1": center_x +
|
|
401
|
-
"y0": center_y -
|
|
402
|
-
"y1": center_y -
|
|
487
|
+
"x0": center_x + x_offset,
|
|
488
|
+
"x1": center_x + x_offset + mol_width,
|
|
489
|
+
"y0": center_y - mol_height - y_offset,
|
|
490
|
+
"y1": center_y - y_offset,
|
|
403
491
|
}
|
|
404
|
-
|
|
405
|
-
return True, adjusted_bbox, [svg_image]
|
|
492
|
+
return True, adjusted_bbox, children
|
|
406
493
|
|
|
407
494
|
|
|
408
495
|
if __name__ == "__main__":
|
|
@@ -420,22 +507,35 @@ if __name__ == "__main__":
|
|
|
420
507
|
df = pd.DataFrame(data)
|
|
421
508
|
|
|
422
509
|
# Get a UQ regressor model
|
|
423
|
-
|
|
424
|
-
# end = Endpoint("aqsol-uq")
|
|
425
|
-
# df = end.auto_inference()
|
|
426
|
-
# DFStore().upsert("/workbench/models/aqsol-uq/auto_inference", df)
|
|
427
|
-
|
|
428
|
-
from workbench.api import DFStore
|
|
510
|
+
from workbench.api import Model
|
|
429
511
|
|
|
430
|
-
|
|
512
|
+
model = Model("logd-reg-xgb")
|
|
513
|
+
df = model.get_inference_predictions("full_cross_fold")
|
|
431
514
|
|
|
432
515
|
# Run the Unit Test on the Plugin
|
|
516
|
+
# Test currently commented out
|
|
517
|
+
"""
|
|
433
518
|
PluginUnitTest(
|
|
434
519
|
ScatterPlot,
|
|
435
520
|
input_data=df,
|
|
436
521
|
theme="midnight_blue",
|
|
437
|
-
x="
|
|
522
|
+
x="logd",
|
|
438
523
|
y="prediction",
|
|
439
|
-
color="
|
|
524
|
+
color="prediction_std",
|
|
525
|
+
suppress_hover_display=True,
|
|
526
|
+
).run()
|
|
527
|
+
"""
|
|
528
|
+
|
|
529
|
+
# Test with molecule hover (smiles column)
|
|
530
|
+
from workbench.api import FeatureSet
|
|
531
|
+
|
|
532
|
+
fs = FeatureSet("aqsol_features")
|
|
533
|
+
mol_df = fs.pull_dataframe()[:1000] # Limit to 1000 rows for testing
|
|
534
|
+
|
|
535
|
+
# Run the Unit Test with molecule data (hover over points to see molecule structures)
|
|
536
|
+
PluginUnitTest(
|
|
537
|
+
ScatterPlot,
|
|
538
|
+
input_data=mol_df,
|
|
539
|
+
theme="midnight_blue",
|
|
440
540
|
suppress_hover_display=True,
|
|
441
541
|
).run()
|