workbench 0.8.224__py3-none-any.whl → 0.8.231__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +66 -8
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +3 -0
- workbench/model_scripts/chemprop/generated_model_script.py +3 -3
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +50 -32
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -70
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +48 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +8 -110
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +10 -6
- workbench/web_interface/components/plugins/scatter_plot.py +56 -43
- workbench/web_interface/components/settings_menu.py +2 -1
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/METADATA +31 -29
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/RECORD +55 -59
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
|
@@ -2,32 +2,122 @@
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import base64
|
|
5
|
-
import re
|
|
6
5
|
from typing import Optional, Tuple
|
|
7
6
|
from rdkit import Chem
|
|
8
7
|
from rdkit.Chem import AllChem, Draw
|
|
9
8
|
from rdkit.Chem.Draw import rdMolDraw2D
|
|
9
|
+
from dash import html
|
|
10
|
+
|
|
11
|
+
# Workbench Imports
|
|
12
|
+
from workbench.utils.color_utils import is_dark
|
|
10
13
|
|
|
11
14
|
# Set up the logger
|
|
12
15
|
log = logging.getLogger("workbench")
|
|
13
16
|
|
|
14
17
|
|
|
15
|
-
def
|
|
16
|
-
|
|
18
|
+
def molecule_hover_tooltip(
|
|
19
|
+
smiles: str, mol_id: str = None, width: int = 300, height: int = 200, background: str = None
|
|
20
|
+
) -> list:
|
|
21
|
+
"""Generate a molecule hover tooltip from a SMILES string.
|
|
22
|
+
|
|
23
|
+
This function creates a visually appealing tooltip with a dark background
|
|
24
|
+
that displays the molecule ID at the top and structure below when hovering
|
|
25
|
+
over scatter plot points.
|
|
17
26
|
|
|
18
27
|
Args:
|
|
19
|
-
|
|
28
|
+
smiles: SMILES string representing the molecule
|
|
29
|
+
mol_id: Optional molecule ID to display at the top of the tooltip
|
|
30
|
+
width: Width of the molecule image in pixels (default: 300)
|
|
31
|
+
height: Height of the molecule image in pixels (default: 200)
|
|
32
|
+
background: Optional background color (if None, uses dark gray)
|
|
20
33
|
|
|
21
34
|
Returns:
|
|
22
|
-
|
|
35
|
+
list: A list containing an html.Div with the ID header and molecule SVG,
|
|
36
|
+
or an html.Div with an error message if rendering fails
|
|
23
37
|
"""
|
|
24
|
-
|
|
25
|
-
if not match:
|
|
26
|
-
log.warning(f"Invalid color format: {color}, defaulting to dark")
|
|
27
|
-
return True # Default to dark mode on error
|
|
38
|
+
try:
|
|
28
39
|
|
|
29
|
-
|
|
30
|
-
|
|
40
|
+
# Use provided background or default to dark gray
|
|
41
|
+
if background is None:
|
|
42
|
+
background = "rgba(64, 64, 64, 1)"
|
|
43
|
+
|
|
44
|
+
# Generate the SVG image from SMILES (base64 encoded data URI)
|
|
45
|
+
img = svg_from_smiles(smiles, width, height, background=background)
|
|
46
|
+
|
|
47
|
+
if img is None:
|
|
48
|
+
log.warning(f"Could not render molecule for SMILES: {smiles}")
|
|
49
|
+
return [
|
|
50
|
+
html.Div(
|
|
51
|
+
"Invalid SMILES",
|
|
52
|
+
className="custom-tooltip",
|
|
53
|
+
style={
|
|
54
|
+
"padding": "10px",
|
|
55
|
+
"color": "rgb(255, 140, 140)",
|
|
56
|
+
"width": f"{width}px",
|
|
57
|
+
"height": f"{height}px",
|
|
58
|
+
"display": "flex",
|
|
59
|
+
"alignItems": "center",
|
|
60
|
+
"justifyContent": "center",
|
|
61
|
+
},
|
|
62
|
+
)
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
# Build the tooltip with ID header and molecule image
|
|
66
|
+
children = []
|
|
67
|
+
|
|
68
|
+
# Add ID header if provided
|
|
69
|
+
if mol_id is not None:
|
|
70
|
+
# Set text color based on background brightness
|
|
71
|
+
text_color = "rgb(200, 200, 200)" if is_dark(background) else "rgb(60, 60, 60)"
|
|
72
|
+
children.append(
|
|
73
|
+
html.Div(
|
|
74
|
+
str(mol_id),
|
|
75
|
+
style={
|
|
76
|
+
"textAlign": "center",
|
|
77
|
+
"padding": "8px",
|
|
78
|
+
"color": text_color,
|
|
79
|
+
"fontSize": "14px",
|
|
80
|
+
"fontWeight": "bold",
|
|
81
|
+
"borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
|
|
82
|
+
},
|
|
83
|
+
)
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
# Add molecule image
|
|
87
|
+
children.append(
|
|
88
|
+
html.Img(
|
|
89
|
+
src=img,
|
|
90
|
+
style={"padding": "0px", "margin": "0px", "display": "block"},
|
|
91
|
+
width=str(width),
|
|
92
|
+
height=str(height),
|
|
93
|
+
)
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
return [
|
|
97
|
+
html.Div(
|
|
98
|
+
children,
|
|
99
|
+
className="custom-tooltip",
|
|
100
|
+
style={"padding": "0px", "margin": "0px"},
|
|
101
|
+
)
|
|
102
|
+
]
|
|
103
|
+
|
|
104
|
+
except ImportError as e:
|
|
105
|
+
log.error(f"RDKit not available for molecule rendering: {e}")
|
|
106
|
+
return [
|
|
107
|
+
html.Div(
|
|
108
|
+
"RDKit not installed",
|
|
109
|
+
className="custom-tooltip",
|
|
110
|
+
style={
|
|
111
|
+
"padding": "10px",
|
|
112
|
+
"color": "rgb(255, 195, 140)",
|
|
113
|
+
"width": f"{width}px",
|
|
114
|
+
"height": f"{height}px",
|
|
115
|
+
"display": "flex",
|
|
116
|
+
"alignItems": "center",
|
|
117
|
+
"justifyContent": "center",
|
|
118
|
+
},
|
|
119
|
+
)
|
|
120
|
+
]
|
|
31
121
|
|
|
32
122
|
|
|
33
123
|
def _rgba_to_tuple(rgba: str) -> Tuple[float, float, float, float]:
|
|
@@ -75,7 +165,13 @@ def _configure_draw_options(options: Draw.MolDrawOptions, background: str) -> No
|
|
|
75
165
|
options: RDKit drawing options object
|
|
76
166
|
background: Background color string
|
|
77
167
|
"""
|
|
78
|
-
|
|
168
|
+
try:
|
|
169
|
+
if is_dark(background):
|
|
170
|
+
rdMolDraw2D.SetDarkMode(options)
|
|
171
|
+
# Light backgrounds use RDKit defaults (no action needed)
|
|
172
|
+
except ValueError:
|
|
173
|
+
# Default to dark mode if color format is invalid
|
|
174
|
+
log.warning(f"Invalid color format: {background}, defaulting to dark mode")
|
|
79
175
|
rdMolDraw2D.SetDarkMode(options)
|
|
80
176
|
options.setBackgroundColour(_rgba_to_tuple(background))
|
|
81
177
|
|
|
@@ -137,7 +233,7 @@ def svg_from_smiles(
|
|
|
137
233
|
drawer.DrawMolecule(mol)
|
|
138
234
|
drawer.FinishDrawing()
|
|
139
235
|
|
|
140
|
-
# Encode SVG
|
|
236
|
+
# Encode SVG as base64 data URI
|
|
141
237
|
svg = drawer.GetDrawingText()
|
|
142
238
|
encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
|
|
143
239
|
return f"data:image/svg+xml;base64,{encoded_svg}"
|
|
@@ -222,7 +318,7 @@ if __name__ == "__main__":
|
|
|
222
318
|
# Test 6: Color parsing functions
|
|
223
319
|
print("\n6. Testing color utility functions...")
|
|
224
320
|
test_colors = [
|
|
225
|
-
("invalid_color",
|
|
321
|
+
("invalid_color", None, (0.25, 0.25, 0.25, 1.0)), # Should raise ValueError
|
|
226
322
|
("rgba(255, 255, 255, 1)", False, (1.0, 1.0, 1.0, 1.0)),
|
|
227
323
|
("rgba(0, 0, 0, 1)", True, (0.0, 0.0, 0.0, 1.0)),
|
|
228
324
|
("rgba(64, 64, 64, 0.5)", True, (0.251, 0.251, 0.251, 0.5)),
|
|
@@ -230,24 +326,38 @@ if __name__ == "__main__":
|
|
|
230
326
|
]
|
|
231
327
|
|
|
232
328
|
for color, expected_dark, expected_tuple in test_colors:
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
329
|
+
try:
|
|
330
|
+
is_dark_result = is_dark(color)
|
|
331
|
+
if expected_dark is None:
|
|
332
|
+
print(f" ✗ is_dark('{color[:20]}...'): Expected ValueError but got {is_dark_result}")
|
|
333
|
+
else:
|
|
334
|
+
dark_status = "✓" if is_dark_result == expected_dark else "✗"
|
|
335
|
+
print(f" {dark_status} is_dark('{color[:20]}...'): {is_dark_result} == {expected_dark}")
|
|
336
|
+
except ValueError:
|
|
337
|
+
if expected_dark is None:
|
|
338
|
+
print(f" ✓ is_dark('{color[:20]}...'): Correctly raised ValueError")
|
|
339
|
+
else:
|
|
340
|
+
print(f" ✗ is_dark('{color[:20]}...'): Unexpected ValueError")
|
|
238
341
|
|
|
342
|
+
tuple_result = _rgba_to_tuple(color)
|
|
239
343
|
# Check tuple values with tolerance for floating point
|
|
240
344
|
tuple_match = all(abs(a - b) < 0.01 for a, b in zip(tuple_result, expected_tuple))
|
|
241
345
|
tuple_status = "✓" if tuple_match else "✗"
|
|
242
346
|
print(f" {tuple_status} rgba_to_tuple('{color[:20]}...'): matches expected")
|
|
243
347
|
|
|
244
|
-
# Test the
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
348
|
+
# Test the tooltip generation in a simple Dash app
|
|
349
|
+
from dash import Dash
|
|
350
|
+
|
|
351
|
+
app = Dash(__name__)
|
|
352
|
+
app.layout = html.Div(
|
|
353
|
+
[
|
|
354
|
+
html.Div("Tooltip Preview:", style={"color": "white", "marginBottom": "20px"}),
|
|
355
|
+
*molecule_hover_tooltip("CC(=O)OC1=CC=CC=C1C(=O)O", mol_id="Aspirin", background="rgba(200, 30, 30, 1)"),
|
|
356
|
+
],
|
|
357
|
+
style={"background": "#1a1a1a", "padding": "50px"},
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
if __name__ == "__main__":
|
|
361
|
+
app.run(debug=True)
|
|
252
362
|
|
|
253
363
|
print("\n✅ All tests completed!")
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Clientside JavaScript callbacks for Dash components.
|
|
2
|
+
|
|
3
|
+
These functions return JavaScript code strings for use with Dash's clientside_callback.
|
|
4
|
+
Using clientside callbacks avoids server round-trips for simple UI interactions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def circle_overlay_callback(circle_data_uri: str) -> str:
|
|
9
|
+
"""Returns JS function for circle overlay on scatter plot hover.
|
|
10
|
+
|
|
11
|
+
Args:
|
|
12
|
+
circle_data_uri: Base64-encoded SVG data URI for the circle overlay
|
|
13
|
+
|
|
14
|
+
Returns:
|
|
15
|
+
JavaScript function string for use with clientside_callback
|
|
16
|
+
"""
|
|
17
|
+
return f"""
|
|
18
|
+
function(hoverData) {{
|
|
19
|
+
if (!hoverData) {{
|
|
20
|
+
return [false, window.dash_clientside.no_update, window.dash_clientside.no_update];
|
|
21
|
+
}}
|
|
22
|
+
var bbox = hoverData.points[0].bbox;
|
|
23
|
+
var centerX = (bbox.x0 + bbox.x1) / 2;
|
|
24
|
+
var centerY = (bbox.y0 + bbox.y1) / 2;
|
|
25
|
+
var adjustedBbox = {{
|
|
26
|
+
x0: centerX - 50,
|
|
27
|
+
x1: centerX + 50,
|
|
28
|
+
y0: centerY - 162,
|
|
29
|
+
y1: centerY - 62
|
|
30
|
+
}};
|
|
31
|
+
var imgElement = {{
|
|
32
|
+
type: 'Img',
|
|
33
|
+
namespace: 'dash_html_components',
|
|
34
|
+
props: {{
|
|
35
|
+
src: '{circle_data_uri}',
|
|
36
|
+
style: {{width: '100px', height: '100px'}}
|
|
37
|
+
}}
|
|
38
|
+
}};
|
|
39
|
+
return [true, adjustedBbox, [imgElement]];
|
|
40
|
+
}}
|
|
41
|
+
"""
|
|
@@ -185,6 +185,63 @@ def dict_to_collapsible_html(data: dict, title: str = None, collapse_all: bool =
|
|
|
185
185
|
return result
|
|
186
186
|
|
|
187
187
|
|
|
188
|
+
def df_to_html_table(df, round_digits: int = 2, margin_bottom: int = 30) -> str:
|
|
189
|
+
"""Convert a DataFrame to a compact styled HTML table (horizontal layout).
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
df: DataFrame with metrics (can be single or multi-row)
|
|
193
|
+
round_digits: Number of decimal places to round to (default: 2)
|
|
194
|
+
margin_bottom: Bottom margin in pixels (default: 30)
|
|
195
|
+
|
|
196
|
+
Returns:
|
|
197
|
+
str: HTML table string
|
|
198
|
+
"""
|
|
199
|
+
# Handle index: reset if named (keeps as column), otherwise drop
|
|
200
|
+
if df.index.name:
|
|
201
|
+
df = df.reset_index()
|
|
202
|
+
else:
|
|
203
|
+
df = df.reset_index(drop=True)
|
|
204
|
+
|
|
205
|
+
# Round numeric columns
|
|
206
|
+
df = df.round(round_digits)
|
|
207
|
+
|
|
208
|
+
# Table styles
|
|
209
|
+
container_style = f"display: flex; justify-content: center; margin-top: 10px; margin-bottom: {margin_bottom}px;"
|
|
210
|
+
table_style = "border-collapse: collapse; width: 100%; font-size: 15px;"
|
|
211
|
+
header_style = (
|
|
212
|
+
"background: linear-gradient(to bottom, #4a4a4a 0%, #2d2d2d 100%); "
|
|
213
|
+
"color: white; padding: 4px 8px; text-align: center;"
|
|
214
|
+
)
|
|
215
|
+
cell_style = "padding: 3px 8px; text-align: center; border-bottom: 1px solid #444;"
|
|
216
|
+
|
|
217
|
+
# Build the HTML table (wrapped in centered container)
|
|
218
|
+
html = f'<div style="{container_style}"><table style="{table_style}">'
|
|
219
|
+
|
|
220
|
+
# Header row
|
|
221
|
+
html += "<tr>"
|
|
222
|
+
for col in df.columns:
|
|
223
|
+
html += f'<th style="{header_style}">{col}</th>'
|
|
224
|
+
html += "</tr>"
|
|
225
|
+
|
|
226
|
+
# Data rows
|
|
227
|
+
for _, row in df.iterrows():
|
|
228
|
+
html += "<tr>"
|
|
229
|
+
for val in row:
|
|
230
|
+
# Format value: integers without decimal, floats rounded
|
|
231
|
+
if isinstance(val, float):
|
|
232
|
+
if val == int(val):
|
|
233
|
+
formatted_val = int(val)
|
|
234
|
+
else:
|
|
235
|
+
formatted_val = round(val, round_digits)
|
|
236
|
+
else:
|
|
237
|
+
formatted_val = val
|
|
238
|
+
html += f'<td style="{cell_style}">{formatted_val}</td>'
|
|
239
|
+
html += "</tr>"
|
|
240
|
+
|
|
241
|
+
html += "</table></div>"
|
|
242
|
+
return html
|
|
243
|
+
|
|
244
|
+
|
|
188
245
|
if __name__ == "__main__":
|
|
189
246
|
"""Exercise the Markdown Utilities"""
|
|
190
247
|
from workbench.api.model import Model
|
workbench/utils/plot_utils.py
CHANGED
|
@@ -4,7 +4,6 @@ import logging
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import pandas as pd
|
|
6
6
|
import plotly.graph_objects as go
|
|
7
|
-
from dash import html
|
|
8
7
|
|
|
9
8
|
log = logging.getLogger("workbench")
|
|
10
9
|
|
|
@@ -143,13 +142,13 @@ def prediction_intervals(df, figure, x_col):
|
|
|
143
142
|
|
|
144
143
|
# Sort dataframe by x_col for connected lines
|
|
145
144
|
sorted_df = df.sort_values(by=x_col)
|
|
146
|
-
# Add outer band (q_025 to q_975) - more transparent
|
|
145
|
+
# Add outer band (q_025 to q_975) - desaturated blue-gray, more transparent
|
|
147
146
|
figure.add_trace(
|
|
148
147
|
go.Scatter(
|
|
149
148
|
x=sorted_df[x_col],
|
|
150
149
|
y=sorted_df["q_025"],
|
|
151
150
|
mode="lines",
|
|
152
|
-
line=dict(width=1, color="rgba(
|
|
151
|
+
line=dict(width=1, color="rgba(120, 130, 180, 0.3)"),
|
|
153
152
|
name="2.5 Percentile",
|
|
154
153
|
hoverinfo="skip",
|
|
155
154
|
showlegend=False,
|
|
@@ -160,21 +159,21 @@ def prediction_intervals(df, figure, x_col):
|
|
|
160
159
|
x=sorted_df[x_col],
|
|
161
160
|
y=sorted_df["q_975"],
|
|
162
161
|
mode="lines",
|
|
163
|
-
line=dict(width=1, color="rgba(
|
|
162
|
+
line=dict(width=1, color="rgba(120, 130, 180, 0.3)"),
|
|
164
163
|
name="97.5 Percentile",
|
|
165
164
|
hoverinfo="skip",
|
|
166
165
|
showlegend=False,
|
|
167
166
|
fill="tonexty",
|
|
168
|
-
fillcolor="rgba(
|
|
167
|
+
fillcolor="rgba(120, 130, 180, 0.15)",
|
|
169
168
|
)
|
|
170
169
|
)
|
|
171
|
-
# Add inner band (q_25 to q_75) -
|
|
170
|
+
# Add inner band (q_25 to q_75) - desaturated green-gray, slightly more visible
|
|
172
171
|
figure.add_trace(
|
|
173
172
|
go.Scatter(
|
|
174
173
|
x=sorted_df[x_col],
|
|
175
174
|
y=sorted_df["q_25"],
|
|
176
175
|
mode="lines",
|
|
177
|
-
line=dict(width=1, color="rgba(
|
|
176
|
+
line=dict(width=1, color="rgba(130, 180, 140, 0.3)"),
|
|
178
177
|
name="25 Percentile",
|
|
179
178
|
hoverinfo="skip",
|
|
180
179
|
showlegend=False,
|
|
@@ -185,118 +184,17 @@ def prediction_intervals(df, figure, x_col):
|
|
|
185
184
|
x=sorted_df[x_col],
|
|
186
185
|
y=sorted_df["q_75"],
|
|
187
186
|
mode="lines",
|
|
188
|
-
line=dict(width=1, color="rgba(
|
|
187
|
+
line=dict(width=1, color="rgba(130, 180, 140, 0.3)"),
|
|
189
188
|
name="75 Percentile",
|
|
190
189
|
hoverinfo="skip",
|
|
191
190
|
showlegend=False,
|
|
192
191
|
fill="tonexty",
|
|
193
|
-
fillcolor="rgba(
|
|
192
|
+
fillcolor="rgba(130, 180, 140, 0.18)",
|
|
194
193
|
)
|
|
195
194
|
)
|
|
196
195
|
return figure
|
|
197
196
|
|
|
198
197
|
|
|
199
|
-
def molecule_hover_tooltip(smiles: str, mol_id: str = None, width: int = 300, height: int = 200) -> list:
|
|
200
|
-
"""Generate a molecule hover tooltip from a SMILES string.
|
|
201
|
-
|
|
202
|
-
This function creates a visually appealing tooltip with a dark background
|
|
203
|
-
that displays the molecule ID at the top and structure below when hovering
|
|
204
|
-
over scatter plot points.
|
|
205
|
-
|
|
206
|
-
Args:
|
|
207
|
-
smiles: SMILES string representing the molecule
|
|
208
|
-
mol_id: Optional molecule ID to display at the top of the tooltip
|
|
209
|
-
width: Width of the molecule image in pixels (default: 300)
|
|
210
|
-
height: Height of the molecule image in pixels (default: 200)
|
|
211
|
-
|
|
212
|
-
Returns:
|
|
213
|
-
list: A list containing an html.Div with the ID header and molecule SVG,
|
|
214
|
-
or an html.Div with an error message if rendering fails
|
|
215
|
-
"""
|
|
216
|
-
try:
|
|
217
|
-
from workbench.utils.chem_utils.vis import svg_from_smiles
|
|
218
|
-
from workbench.utils.theme_manager import ThemeManager
|
|
219
|
-
|
|
220
|
-
# Get the background color from the current theme
|
|
221
|
-
background = ThemeManager().background()
|
|
222
|
-
|
|
223
|
-
# Generate the SVG image from SMILES
|
|
224
|
-
img = svg_from_smiles(smiles, width, height, background=background)
|
|
225
|
-
|
|
226
|
-
if img is None:
|
|
227
|
-
log.warning(f"Could not render molecule for SMILES: {smiles}")
|
|
228
|
-
return [
|
|
229
|
-
html.Div(
|
|
230
|
-
"Invalid SMILES",
|
|
231
|
-
className="custom-tooltip",
|
|
232
|
-
style={
|
|
233
|
-
"padding": "10px",
|
|
234
|
-
"color": "rgb(255, 140, 140)",
|
|
235
|
-
"width": f"{width}px",
|
|
236
|
-
"height": f"{height}px",
|
|
237
|
-
"display": "flex",
|
|
238
|
-
"alignItems": "center",
|
|
239
|
-
"justifyContent": "center",
|
|
240
|
-
},
|
|
241
|
-
)
|
|
242
|
-
]
|
|
243
|
-
|
|
244
|
-
# Build the tooltip with ID header and molecule image
|
|
245
|
-
children = []
|
|
246
|
-
|
|
247
|
-
# Add ID header if provided
|
|
248
|
-
if mol_id is not None:
|
|
249
|
-
children.append(
|
|
250
|
-
html.Div(
|
|
251
|
-
str(mol_id),
|
|
252
|
-
style={
|
|
253
|
-
"textAlign": "center",
|
|
254
|
-
"padding": "8px",
|
|
255
|
-
"color": "rgb(200, 200, 200)",
|
|
256
|
-
"fontSize": "14px",
|
|
257
|
-
"fontWeight": "bold",
|
|
258
|
-
"borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
|
|
259
|
-
},
|
|
260
|
-
)
|
|
261
|
-
)
|
|
262
|
-
|
|
263
|
-
# Add molecule image
|
|
264
|
-
children.append(
|
|
265
|
-
html.Img(
|
|
266
|
-
src=img,
|
|
267
|
-
style={"padding": "0px", "margin": "0px", "display": "block"},
|
|
268
|
-
width=str(width),
|
|
269
|
-
height=str(height),
|
|
270
|
-
)
|
|
271
|
-
)
|
|
272
|
-
|
|
273
|
-
return [
|
|
274
|
-
html.Div(
|
|
275
|
-
children,
|
|
276
|
-
className="custom-tooltip",
|
|
277
|
-
style={"padding": "0px", "margin": "0px"},
|
|
278
|
-
)
|
|
279
|
-
]
|
|
280
|
-
|
|
281
|
-
except ImportError as e:
|
|
282
|
-
log.error(f"RDKit not available for molecule rendering: {e}")
|
|
283
|
-
return [
|
|
284
|
-
html.Div(
|
|
285
|
-
"RDKit not installed",
|
|
286
|
-
className="custom-tooltip",
|
|
287
|
-
style={
|
|
288
|
-
"padding": "10px",
|
|
289
|
-
"color": "rgb(255, 195, 140)",
|
|
290
|
-
"width": f"{width}px",
|
|
291
|
-
"height": f"{height}px",
|
|
292
|
-
"display": "flex",
|
|
293
|
-
"alignItems": "center",
|
|
294
|
-
"justifyContent": "center",
|
|
295
|
-
},
|
|
296
|
-
)
|
|
297
|
-
]
|
|
298
|
-
|
|
299
|
-
|
|
300
198
|
if __name__ == "__main__":
|
|
301
199
|
"""Exercise the Plot Utilities"""
|
|
302
200
|
import plotly.express as px
|
|
@@ -10,8 +10,10 @@ from workbench.api import Model, ModelType
|
|
|
10
10
|
from workbench.web_interface.components.component_interface import ComponentInterface
|
|
11
11
|
from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
|
|
12
12
|
from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
|
|
13
|
+
from workbench.utils.deprecated_utils import deprecated
|
|
13
14
|
|
|
14
15
|
|
|
16
|
+
@deprecated(version="0.9")
|
|
15
17
|
class ModelPlot(ComponentInterface):
|
|
16
18
|
"""Model Metrics Components"""
|
|
17
19
|
|
|
@@ -3,7 +3,6 @@ import dash_bootstrap_components as dbc
|
|
|
3
3
|
import logging
|
|
4
4
|
import socket
|
|
5
5
|
|
|
6
|
-
|
|
7
6
|
# Workbench Imports
|
|
8
7
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginInputType
|
|
9
8
|
from workbench.api import DataSource, FeatureSet, Model, Endpoint, Meta
|
|
@@ -22,9 +22,7 @@ class AGTable(PluginInterface):
|
|
|
22
22
|
header_height = 30
|
|
23
23
|
row_height = 25
|
|
24
24
|
|
|
25
|
-
def create_component(
|
|
26
|
-
self, component_id: str, header_color: str = "rgb(120, 60, 60)", max_height: int = 500
|
|
27
|
-
) -> AgGrid:
|
|
25
|
+
def create_component(self, component_id: str, max_height: int = 500) -> AgGrid:
|
|
28
26
|
"""Create a Table Component without any data."""
|
|
29
27
|
self.component_id = component_id
|
|
30
28
|
self.max_height = max_height
|
|
@@ -112,4 +110,4 @@ if __name__ == "__main__":
|
|
|
112
110
|
test_df = pd.DataFrame(data)
|
|
113
111
|
|
|
114
112
|
# Run the Unit Test on the Plugin
|
|
115
|
-
PluginUnitTest(AGTable, theme="
|
|
113
|
+
PluginUnitTest(AGTable, theme="dark", input_data=test_df, max_height=500).run()
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
from dash import dcc, callback, Output, Input, State
|
|
4
4
|
import plotly.graph_objects as go
|
|
5
5
|
|
|
6
|
-
|
|
7
6
|
# Workbench Imports
|
|
8
7
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
9
8
|
from workbench.utils.theme_manager import ThemeManager
|
|
@@ -22,7 +21,6 @@ class ConfusionMatrix(PluginInterface):
|
|
|
22
21
|
self.component_id = None
|
|
23
22
|
self.current_highlight = None # Store the currently highlighted cell
|
|
24
23
|
self.theme_manager = ThemeManager()
|
|
25
|
-
self.colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
|
|
26
24
|
|
|
27
25
|
# Call the parent class constructor
|
|
28
26
|
super().__init__()
|
|
@@ -65,9 +63,8 @@ class ConfusionMatrix(PluginInterface):
|
|
|
65
63
|
if df is None:
|
|
66
64
|
return [self.display_text("No Data")]
|
|
67
65
|
|
|
68
|
-
#
|
|
69
|
-
|
|
70
|
-
# color_scale = sequential.Plasma
|
|
66
|
+
# Get the colorscale from the current theme
|
|
67
|
+
colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
|
|
71
68
|
|
|
72
69
|
# The confusion matrix is displayed in reverse order (flip the dataframe for correct orientation)
|
|
73
70
|
df = df.iloc[::-1]
|
|
@@ -89,7 +86,7 @@ class ConfusionMatrix(PluginInterface):
|
|
|
89
86
|
title="Count",
|
|
90
87
|
outlinewidth=1,
|
|
91
88
|
),
|
|
92
|
-
colorscale=
|
|
89
|
+
colorscale=colorscale,
|
|
93
90
|
)
|
|
94
91
|
)
|
|
95
92
|
|
|
@@ -8,9 +8,14 @@ from dash import html, callback, dcc, Input, Output, State
|
|
|
8
8
|
# Workbench Imports
|
|
9
9
|
from workbench.api import ModelType, ParameterStore
|
|
10
10
|
from workbench.cached.cached_model import CachedModel
|
|
11
|
-
from workbench.utils.markdown_utils import
|
|
11
|
+
from workbench.utils.markdown_utils import (
|
|
12
|
+
health_tag_markdown,
|
|
13
|
+
tags_to_markdown,
|
|
14
|
+
dict_to_markdown,
|
|
15
|
+
dict_to_collapsible_html,
|
|
16
|
+
df_to_html_table,
|
|
17
|
+
)
|
|
12
18
|
from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
|
|
13
|
-
from workbench.utils.markdown_utils import tags_to_markdown, dict_to_markdown, dict_to_collapsible_html
|
|
14
19
|
|
|
15
20
|
|
|
16
21
|
class ModelDetails(PluginInterface):
|
|
@@ -44,7 +49,7 @@ class ModelDetails(PluginInterface):
|
|
|
44
49
|
dcc.Markdown(id=f"{self.component_id}-summary", dangerously_allow_html=True),
|
|
45
50
|
html.H5(children="Inference Metrics", style={"marginTop": "20px"}),
|
|
46
51
|
dcc.Dropdown(id=f"{self.component_id}-dropdown", className="dropdown"),
|
|
47
|
-
dcc.Markdown(id=f"{self.component_id}-metrics"),
|
|
52
|
+
dcc.Markdown(id=f"{self.component_id}-metrics", dangerously_allow_html=True),
|
|
48
53
|
],
|
|
49
54
|
)
|
|
50
55
|
|
|
@@ -175,15 +180,14 @@ class ModelDetails(PluginInterface):
|
|
|
175
180
|
markdown += " \nNo Data \n"
|
|
176
181
|
else:
|
|
177
182
|
markdown += " \n"
|
|
178
|
-
metrics = metrics.round(3)
|
|
179
183
|
|
|
180
184
|
# If the model is a classification model, have the index sorting match the class labels
|
|
181
185
|
if self.current_model.model_type == ModelType.CLASSIFIER:
|
|
182
|
-
# Sort the metrics by the class labels (if they match)
|
|
183
186
|
class_labels = self.current_model.class_labels()
|
|
184
187
|
if set(metrics.index) == set(class_labels):
|
|
185
188
|
metrics = metrics.reindex(class_labels)
|
|
186
|
-
|
|
189
|
+
|
|
190
|
+
markdown += df_to_html_table(metrics)
|
|
187
191
|
|
|
188
192
|
# Get additional inference metrics if they exist
|
|
189
193
|
model_name = self.current_model.name
|