workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  import base64
2
2
  import pandas as pd
3
- from dash import dcc, html, callback, Input, Output, no_update
3
+ from dash import dcc, html, callback, clientside_callback, Input, Output, no_update
4
4
  import plotly.graph_objects as go
5
5
  import plotly.express as px
6
6
  from dash.exceptions import PreventUpdate
@@ -8,7 +8,7 @@ from dash.exceptions import PreventUpdate
8
8
  # Workbench Imports
9
9
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
10
10
  from workbench.utils.theme_manager import ThemeManager
11
- from workbench.utils.plot_utils import prediction_intervals
11
+ from workbench.utils.plot_utils import prediction_intervals, molecule_hover_tooltip
12
12
 
13
13
 
14
14
  class ScatterPlot(PluginInterface):
@@ -18,6 +18,12 @@ class ScatterPlot(PluginInterface):
18
18
  auto_load_page = PluginPage.NONE
19
19
  plugin_input_type = PluginInputType.DATAFRAME
20
20
 
21
+ # Pre-computed circle overlay SVG
22
+ _circle_svg = """<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" style="overflow: visible;">
23
+ <circle cx="50" cy="50" r="10" stroke="rgba(255, 255, 255, 1)" stroke-width="3" fill="none" />
24
+ </svg>"""
25
+ _circle_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(_circle_svg.encode('utf-8')).decode('utf-8')}"
26
+
21
27
  def __init__(self, show_axes: bool = True):
22
28
  """Initialize the Scatter Plot Plugin
23
29
 
@@ -29,7 +35,10 @@ class ScatterPlot(PluginInterface):
29
35
  self.df = None
30
36
  self.show_axes = show_axes
31
37
  self.theme_manager = ThemeManager()
32
- self.colorscale = self.theme_manager.colorscale()
38
+ self.has_smiles = False # Track if dataframe has smiles column for molecule hover
39
+ self.smiles_column = None
40
+ self.id_column = None
41
+ self.hover_background = None # Cached background color for molecule hover tooltip
33
42
 
34
43
  # Call the parent class constructor
35
44
  super().__init__()
@@ -51,10 +60,10 @@ class ScatterPlot(PluginInterface):
51
60
  (f"{component_id}-x-dropdown", "options"),
52
61
  (f"{component_id}-y-dropdown", "options"),
53
62
  (f"{component_id}-color-dropdown", "options"),
54
- (f"{component_id}-label-dropdown", "options"),
55
63
  (f"{component_id}-x-dropdown", "value"),
56
64
  (f"{component_id}-y-dropdown", "value"),
57
65
  (f"{component_id}-color-dropdown", "value"),
66
+ (f"{component_id}-regression-line", "value"),
58
67
  ]
59
68
  self.signals = [(f"{component_id}-graph", "hoverData"), (f"{component_id}-graph", "clickData")]
60
69
 
@@ -69,51 +78,67 @@ class ScatterPlot(PluginInterface):
69
78
  id=f"{component_id}-graph",
70
79
  figure=self.display_text("Waiting for Data..."),
71
80
  config={"scrollZoom": True},
72
- style={"height": "100%"},
81
+ style={"height": "600px", "width": "100%"},
73
82
  clear_on_unhover=True,
74
83
  ),
75
84
  # Controls: X, Y, Color, Label Dropdowns, and Regression Line Checkbox
76
85
  html.Div(
77
86
  [
78
- html.Label("X", style={"marginLeft": "40px", "marginRight": "5px", "fontWeight": "bold"}),
87
+ html.Label(
88
+ "X",
89
+ style={
90
+ "marginLeft": "20px",
91
+ "marginRight": "5px",
92
+ "fontWeight": "bold",
93
+ "display": "flex",
94
+ "alignItems": "center",
95
+ },
96
+ ),
79
97
  dcc.Dropdown(
80
98
  id=f"{component_id}-x-dropdown",
81
- className="dropdown",
82
- style={"min-width": "50px", "flex": 1}, # Responsive width
99
+ style={"minWidth": "150px", "flex": 1},
83
100
  clearable=False,
84
101
  ),
85
- html.Label("Y", style={"marginLeft": "30px", "marginRight": "5px", "fontWeight": "bold"}),
102
+ html.Label(
103
+ "Y",
104
+ style={
105
+ "marginLeft": "20px",
106
+ "marginRight": "5px",
107
+ "fontWeight": "bold",
108
+ "display": "flex",
109
+ "alignItems": "center",
110
+ },
111
+ ),
86
112
  dcc.Dropdown(
87
113
  id=f"{component_id}-y-dropdown",
88
- className="dropdown",
89
- style={"min-width": "50px", "flex": 1}, # Responsive width
114
+ style={"minWidth": "150px", "flex": 1},
90
115
  clearable=False,
91
116
  ),
92
- html.Label("Color", style={"marginLeft": "30px", "marginRight": "5px", "fontWeight": "bold"}),
93
- dcc.Dropdown(
94
- id=f"{component_id}-color-dropdown",
95
- className="dropdown",
96
- style={"min-width": "50px", "flex": 1}, # Responsive width
97
- clearable=False,
117
+ html.Label(
118
+ "Color",
119
+ style={
120
+ "marginLeft": "20px",
121
+ "marginRight": "5px",
122
+ "fontWeight": "bold",
123
+ "display": "flex",
124
+ "alignItems": "center",
125
+ },
98
126
  ),
99
- html.Label("Label", style={"marginLeft": "30px", "marginRight": "5px", "fontWeight": "bold"}),
100
127
  dcc.Dropdown(
101
- id=f"{component_id}-label-dropdown",
102
- className="dropdown",
103
- style={"min-width": "50px", "flex": 1},
104
- options=[{"label": "None", "value": "none"}],
105
- value="none",
128
+ id=f"{component_id}-color-dropdown",
129
+ style={"minWidth": "150px", "flex": 1},
106
130
  clearable=False,
107
131
  ),
108
132
  dcc.Checklist(
109
133
  id=f"{component_id}-regression-line",
110
134
  options=[{"label": " Diagonal", "value": "show"}],
111
135
  value=[],
112
- style={"margin": "10px"},
136
+ style={"marginLeft": "20px", "display": "flex", "alignItems": "center"},
113
137
  ),
114
138
  ],
115
- style={"padding": "0px 0px 10px 0px", "display": "flex", "gap": "10px"},
139
+ style={"padding": "0px 0px 10px 0px", "display": "flex", "alignItems": "center", "gap": "5px"},
116
140
  ),
141
+ # Circle overlay tooltip (centered on hovered point)
117
142
  dcc.Tooltip(
118
143
  id=f"{component_id}-overlay",
119
144
  background_color="rgba(0,0,0,0)",
@@ -121,6 +146,14 @@ class ScatterPlot(PluginInterface):
121
146
  direction="bottom",
122
147
  loading_text="",
123
148
  ),
149
+ # Molecule tooltip (offset from hovered point) - only used when smiles column exists
150
+ dcc.Tooltip(
151
+ id=f"{component_id}-molecule-tooltip",
152
+ background_color="rgba(0,0,0,0)",
153
+ border_color="rgba(0,0,0,0)",
154
+ direction="bottom",
155
+ loading_text="",
156
+ ),
124
157
  ],
125
158
  style={"height": "100%", "display": "flex", "flexDirection": "column"}, # Full viewport height
126
159
  )
@@ -139,12 +172,16 @@ class ScatterPlot(PluginInterface):
139
172
  - hover_columns: The columns to show when hovering over a point
140
173
  - suppress_hover_display: Suppress hover display (default: False)
141
174
  - custom_data: Custom data that get passed to hoverData callbacks
175
+ - id_column: Column to use for molecule tooltip header (auto-detects "id" if not specified)
142
176
 
143
177
  Returns:
144
178
  list: A list of updated property values (figure, x options, y options, color options,
145
- label options, x default, y default,
146
- color default).
179
+ x default, y default, color default).
147
180
  """
181
+ # Get the colorscale and background color from the current theme
182
+ self.colorscale = self.theme_manager.colorscale()
183
+ self.hover_background = self.theme_manager.background()
184
+
148
185
  # Get the limit for the number of rows to plot
149
186
  limit = kwargs.get("limit", 20000)
150
187
 
@@ -159,23 +196,32 @@ class ScatterPlot(PluginInterface):
159
196
  self.df = self.df.drop(columns=aws_cols, errors="ignore")
160
197
 
161
198
  # Set hover columns and custom data
162
- self.hover_columns = kwargs.get("hover_columns", self.df.columns.tolist()[:10])
199
+ self.hover_columns = kwargs.get("hover_columns", sorted(self.df.columns.tolist()[:15]))
163
200
  self.suppress_hover_display = kwargs.get("suppress_hover_display", False)
164
201
  self.custom_data = kwargs.get("custom_data", [])
165
202
 
203
+ # Check if the dataframe has smiles/id columns for molecule hover rendering
204
+ self.smiles_column = next((col for col in self.df.columns if col.lower() == "smiles"), None)
205
+ # Use provided id_column, or auto-detect "id" column, or fall back to first column
206
+ self.id_column = kwargs.get("id_column") or next(
207
+ (col for col in self.df.columns if col.lower() == "id"), self.df.columns[0]
208
+ )
209
+ self.has_smiles = self.smiles_column is not None
210
+
166
211
  # Identify numeric columns
167
212
  numeric_columns = self.df.select_dtypes(include="number").columns.tolist()
168
213
  if len(numeric_columns) < 3:
169
214
  raise ValueError("At least three numeric columns are required for x, y, and color.")
170
215
 
171
- # Default x, y, and color (for color, default to a numeric column)
216
+ # Default x, y, and color (for color, prefer 'confidence' if it exists)
172
217
  x_default = kwargs.get("x", numeric_columns[0])
173
218
  y_default = kwargs.get("y", numeric_columns[1])
174
- color_default = kwargs.get("color", numeric_columns[2])
219
+ default_color = "confidence" if "confidence" in self.df.columns else numeric_columns[2]
220
+ color_default = kwargs.get("color", default_color)
175
221
  regression_line = kwargs.get("regression_line", False)
176
222
 
177
223
  # Create the default scatter plot
178
- figure = self.create_scatter_plot(self.df, x_default, y_default, color_default, "none", regression_line)
224
+ figure = self.create_scatter_plot(self.df, x_default, y_default, color_default, regression_line)
179
225
 
180
226
  # Dropdown options for x and y: use provided dropdown_columns or fallback to numeric columns
181
227
  dropdown_columns = kwargs.get("dropdown_columns", numeric_columns)
@@ -188,11 +234,10 @@ class ScatterPlot(PluginInterface):
188
234
  color_columns = numeric_columns + cat_columns
189
235
  color_options = [{"label": col, "value": col} for col in color_columns]
190
236
 
191
- # For label dropdown, include None option and all columns
192
- label_options = [{"label": "None", "value": "none"}]
193
- label_options.extend([{"label": col, "value": col} for col in self.df.columns])
237
+ # Regression line checklist value (list with "show" if enabled, empty list if disabled)
238
+ regression_line_value = ["show"] if regression_line else []
194
239
 
195
- return [figure, x_options, y_options, color_options, label_options, x_default, y_default, color_default]
240
+ return [figure, x_options, y_options, color_options, x_default, y_default, color_default, regression_line_value]
196
241
 
197
242
  def create_scatter_plot(
198
243
  self,
@@ -200,7 +245,6 @@ class ScatterPlot(PluginInterface):
200
245
  x_col: str,
201
246
  y_col: str,
202
247
  color_col: str,
203
- label_col: str,
204
248
  regression_line: bool = False,
205
249
  marker_size: int = 15,
206
250
  ) -> go.Figure:
@@ -211,24 +255,38 @@ class ScatterPlot(PluginInterface):
211
255
  x_col (str): The column to use for the x-axis.
212
256
  y_col (str): The column to use for the y-axis.
213
257
  color_col (str): The column to use for the color scale.
214
- label_col (str): The column to use for point labels.
215
258
  regression_line (bool): Whether to include a regression line.
216
259
  marker_size (int): Size of the markers. Default is 15.
217
260
 
218
261
  Returns:
219
262
  go.Figure: A Plotly Figure object.
220
263
  """
221
- # Check if we need to show labels
222
- show_labels = label_col != "none" and len(df) < 1000
223
264
 
224
265
  # Helper to generate hover text for each point.
225
266
  def generate_hover_text(row):
226
267
  return "<br>".join([f"{col}: {row[col]}" for col in self.hover_columns])
227
268
 
228
- # Generate hover text for all points.
229
- hovertext = df.apply(generate_hover_text, axis=1)
230
- hovertemplate = "%{hovertext}<extra></extra>"
231
- hoverinfo = "none" if self.suppress_hover_display else None
269
+ # Generate hover text for all points (unless suppressed or using molecule hover)
270
+ suppress_hover = self.suppress_hover_display or self.has_smiles
271
+ if suppress_hover:
272
+ # Use "none" to hide the default hover display but still fire hoverData callbacks
273
+ # Don't set hovertemplate when suppressing - it would override hoverinfo
274
+ hovertext = None
275
+ hovertemplate = None
276
+ hoverinfo = "none"
277
+ else:
278
+ hovertext = df.apply(generate_hover_text, axis=1)
279
+ hovertemplate = "%{hovertext}<extra></extra>"
280
+ hoverinfo = None
281
+
282
+ # Build customdata columns - include smiles and id if available for molecule hover
283
+ custom_data_cols = list(self.custom_data) if self.custom_data else []
284
+ if self.has_smiles:
285
+ # Add smiles as first column, id as second (if available)
286
+ if self.smiles_column not in custom_data_cols:
287
+ custom_data_cols = [self.smiles_column] + custom_data_cols
288
+ if self.id_column and self.id_column not in custom_data_cols:
289
+ custom_data_cols.insert(1, self.id_column)
232
290
 
233
291
  # Determine marker settings based on the type of the color column.
234
292
  if pd.api.types.is_numeric_dtype(df[color_col]):
@@ -240,18 +298,16 @@ class ScatterPlot(PluginInterface):
240
298
  x=df[x_col],
241
299
  y=df[y_col],
242
300
  mode="markers",
243
- text=df[label_col].astype(str) if show_labels else None,
244
- textposition="top center",
245
301
  hoverinfo=hoverinfo,
246
302
  hovertext=hovertext,
247
303
  hovertemplate=hovertemplate,
248
- customdata=df[self.custom_data],
304
+ customdata=df[custom_data_cols] if custom_data_cols else None,
249
305
  marker=dict(
250
306
  size=marker_size,
251
307
  color=marker_color,
252
308
  colorscale=self.colorscale,
253
309
  colorbar=colorbar,
254
- opacity=0.8,
310
+ opacity=0.9,
255
311
  line=dict(color="rgba(0,0,0,0.25)", width=1),
256
312
  ),
257
313
  )
@@ -266,18 +322,16 @@ class ScatterPlot(PluginInterface):
266
322
  data = []
267
323
  for i, cat in enumerate(categories):
268
324
  sub_df = df[df[color_col] == cat]
269
- sub_hovertext = hovertext.loc[sub_df.index]
325
+ sub_hovertext = hovertext.loc[sub_df.index] if hovertext is not None else None
270
326
  trace = go.Scattergl(
271
327
  x=sub_df[x_col],
272
328
  y=sub_df[y_col],
273
329
  mode="markers",
274
- text=sub_df[label_col] if show_labels else None, # Add text if labels enabled
275
- textposition="top center", # Position labels above points
276
330
  name=cat,
277
331
  hoverinfo=hoverinfo,
278
332
  hovertext=sub_hovertext,
279
333
  hovertemplate=hovertemplate,
280
- customdata=sub_df[self.custom_data],
334
+ customdata=sub_df[custom_data_cols] if custom_data_cols else None,
281
335
  marker=dict(
282
336
  size=marker_size,
283
337
  color=discrete_colors[i % len(discrete_colors)],
@@ -345,64 +399,97 @@ class ScatterPlot(PluginInterface):
345
399
  Input(f"{self.component_id}-x-dropdown", "value"),
346
400
  Input(f"{self.component_id}-y-dropdown", "value"),
347
401
  Input(f"{self.component_id}-color-dropdown", "value"),
348
- Input(f"{self.component_id}-label-dropdown", "value"),
349
402
  Input(f"{self.component_id}-regression-line", "value"),
350
403
  ],
351
404
  prevent_initial_call=True,
352
405
  )
353
- def _update_scatter_plot(x_value, y_value, color_value, label_value, regression_line):
406
+ def _update_scatter_plot(x_value, y_value, color_value, regression_line):
354
407
  """Update the Scatter Plot Graph based on the dropdown values."""
355
408
 
356
409
  # Check if the dataframe is not empty and the values are not None
357
410
  if not self.df.empty and x_value and y_value and color_value:
358
- # Update Plotly Scatter Plot with the label value
359
- figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, label_value, regression_line)
411
+ figure = self.create_scatter_plot(self.df, x_value, y_value, color_value, regression_line)
360
412
  return figure
361
413
 
362
414
  raise PreventUpdate
363
415
 
364
- @callback(
416
+ # Clientside callback for circle overlay - runs in browser, no server round trip
417
+ clientside_callback(
418
+ f"""
419
+ function(hoverData) {{
420
+ if (!hoverData) {{
421
+ return [false, window.dash_clientside.no_update, window.dash_clientside.no_update];
422
+ }}
423
+ var bbox = hoverData.points[0].bbox;
424
+ var centerX = (bbox.x0 + bbox.x1) / 2;
425
+ var centerY = (bbox.y0 + bbox.y1) / 2;
426
+ var adjustedBbox = {{
427
+ x0: centerX - 50,
428
+ x1: centerX + 50,
429
+ y0: centerY - 162,
430
+ y1: centerY - 62
431
+ }};
432
+ var imgElement = {{
433
+ type: 'Img',
434
+ namespace: 'dash_html_components',
435
+ props: {{
436
+ src: '{self._circle_data_uri}',
437
+ style: {{width: '100px', height: '100px'}}
438
+ }}
439
+ }};
440
+ return [true, adjustedBbox, [imgElement]];
441
+ }}
442
+ """,
365
443
  Output(f"{self.component_id}-overlay", "show"),
366
444
  Output(f"{self.component_id}-overlay", "bbox"),
367
445
  Output(f"{self.component_id}-overlay", "children"),
368
446
  Input(f"{self.component_id}-graph", "hoverData"),
369
447
  )
370
- def _scatter_overlay(hover_data):
371
- if hover_data is None:
372
- # Hide the overlay if no hover data
373
- return False, no_update, no_update
374
448
 
375
- # Extract bounding box from hoverData
376
- bbox = hover_data["points"][0]["bbox"]
377
-
378
- # Create an SVG with a circle at the center
379
- svg = """
380
- <svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" style="overflow: visible;">
381
- <!-- Circle for the node -->
382
- <circle cx="50" cy="50" r="10" stroke="rgba(255, 255, 255, 1)" stroke-width="3" fill="none" />
383
- </svg>
384
- """
449
+ @callback(
450
+ Output(f"{self.component_id}-molecule-tooltip", "show"),
451
+ Output(f"{self.component_id}-molecule-tooltip", "bbox"),
452
+ Output(f"{self.component_id}-molecule-tooltip", "children"),
453
+ Input(f"{self.component_id}-graph", "hoverData"),
454
+ )
455
+ def _scatter_molecule_overlay(hover_data):
456
+ """Show molecule tooltip when smiles data is available."""
457
+ if hover_data is None or not self.has_smiles:
458
+ return False, no_update, no_update
385
459
 
386
- # Encode the SVG as Base64
387
- encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
388
- data_uri = f"data:image/svg+xml;base64,{encoded_svg}"
460
+ # Extract customdata (contains smiles and id)
461
+ customdata = hover_data["points"][0].get("customdata")
462
+ if customdata is None:
463
+ return False, no_update, no_update
389
464
 
390
- # Use an img tag for the overlay
391
- svg_image = html.Img(src=data_uri, style={"width": "100px", "height": "100px"})
465
+ # SMILES is the first element, ID is second (if available)
466
+ if isinstance(customdata, (list, tuple)):
467
+ smiles = customdata[0]
468
+ mol_id = customdata[1] if len(customdata) > 1 and self.id_column else None
469
+ else:
470
+ smiles = customdata
471
+ mol_id = None
472
+
473
+ # Generate molecule tooltip with ID header (use cached background color)
474
+ mol_width, mol_height = 300, 200
475
+ children = molecule_hover_tooltip(
476
+ smiles, mol_id=mol_id, width=mol_width, height=mol_height, background=self.hover_background
477
+ )
392
478
 
393
- # Get the center of the bounding box
479
+ # Position molecule tooltip above and slightly right of the point
480
+ bbox = hover_data["points"][0]["bbox"]
394
481
  center_x = (bbox["x0"] + bbox["x1"]) / 2
395
482
  center_y = (bbox["y0"] + bbox["y1"]) / 2
483
+ x_offset = 5 # Slight offset to the right
484
+ y_offset = mol_height + 50 # Above the point
396
485
 
397
- # The tooltip should be centered on the point (note: 'bottom' tooltip, so we adjust y position)
398
486
  adjusted_bbox = {
399
- "x0": center_x - 50,
400
- "x1": center_x + 50,
401
- "y0": center_y - 162,
402
- "y1": center_y - 62,
487
+ "x0": center_x + x_offset,
488
+ "x1": center_x + x_offset + mol_width,
489
+ "y0": center_y - mol_height - y_offset,
490
+ "y1": center_y - y_offset,
403
491
  }
404
- # Return the updated values for the overlay
405
- return True, adjusted_bbox, [svg_image]
492
+ return True, adjusted_bbox, children
406
493
 
407
494
 
408
495
  if __name__ == "__main__":
@@ -420,22 +507,35 @@ if __name__ == "__main__":
420
507
  df = pd.DataFrame(data)
421
508
 
422
509
  # Get a UQ regressor model
423
- # from workbench.api import Endpoint, DFStore
424
- # end = Endpoint("aqsol-uq")
425
- # df = end.auto_inference()
426
- # DFStore().upsert("/workbench/models/aqsol-uq/auto_inference", df)
427
-
428
- from workbench.api import DFStore
510
+ from workbench.api import Model
429
511
 
430
- df = DFStore().get("/workbench/models/aqsol-uq/auto_inference")
512
+ model = Model("logd-reg-xgb")
513
+ df = model.get_inference_predictions("full_cross_fold")
431
514
 
432
515
  # Run the Unit Test on the Plugin
516
+ # Test currently commented out
517
+ """
433
518
  PluginUnitTest(
434
519
  ScatterPlot,
435
520
  input_data=df,
436
521
  theme="midnight_blue",
437
- x="solubility",
522
+ x="logd",
438
523
  y="prediction",
439
- color="residuals_abs",
524
+ color="prediction_std",
525
+ suppress_hover_display=True,
526
+ ).run()
527
+ """
528
+
529
+ # Test with molecule hover (smiles column)
530
+ from workbench.api import FeatureSet
531
+
532
+ fs = FeatureSet("aqsol_features")
533
+ mol_df = fs.pull_dataframe()[:1000] # Limit to 1000 rows for testing
534
+
535
+ # Run the Unit Test with molecule data (hover over points to see molecule structures)
536
+ PluginUnitTest(
537
+ ScatterPlot,
538
+ input_data=mol_df,
539
+ theme="midnight_blue",
440
540
  suppress_hover_display=True,
441
541
  ).run()