workbench 0.8.224__py3-none-any.whl → 0.8.234__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +2 -0
  3. workbench/algorithms/dataframe/smart_aggregator.py +161 -0
  4. workbench/algorithms/sql/column_stats.py +0 -1
  5. workbench/algorithms/sql/correlations.py +0 -1
  6. workbench/algorithms/sql/descriptive_stats.py +0 -1
  7. workbench/api/meta.py +0 -1
  8. workbench/cached/cached_meta.py +0 -1
  9. workbench/cached/cached_model.py +37 -7
  10. workbench/core/artifacts/endpoint_core.py +12 -2
  11. workbench/core/artifacts/feature_set_core.py +66 -8
  12. workbench/core/cloud_platform/cloud_meta.py +0 -1
  13. workbench/model_script_utils/model_script_utils.py +30 -0
  14. workbench/model_script_utils/uq_harness.py +0 -1
  15. workbench/model_scripts/chemprop/chemprop.template +3 -0
  16. workbench/model_scripts/chemprop/generated_model_script.py +3 -3
  17. workbench/model_scripts/chemprop/model_script_utils.py +30 -0
  18. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  19. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  20. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
  21. workbench/model_scripts/pytorch_model/generated_model_script.py +50 -32
  22. workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
  23. workbench/model_scripts/pytorch_model/pytorch.template +47 -29
  24. workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
  25. workbench/model_scripts/script_generation.py +0 -1
  26. workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +0 -1
  28. workbench/themes/dark/custom.css +85 -8
  29. workbench/themes/dark/plotly.json +6 -6
  30. workbench/themes/light/custom.css +172 -70
  31. workbench/themes/light/plotly.json +9 -9
  32. workbench/themes/midnight_blue/custom.css +48 -29
  33. workbench/themes/midnight_blue/plotly.json +1 -1
  34. workbench/utils/aws_utils.py +0 -1
  35. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  36. workbench/utils/chem_utils/vis.py +137 -27
  37. workbench/utils/clientside_callbacks.py +41 -0
  38. workbench/utils/markdown_utils.py +61 -0
  39. workbench/utils/pipeline_utils.py +0 -1
  40. workbench/utils/plot_utils.py +8 -110
  41. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  42. workbench/web_interface/components/model_plot.py +2 -0
  43. workbench/web_interface/components/plugin_unit_test.py +0 -1
  44. workbench/web_interface/components/plugins/ag_table.py +2 -4
  45. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  46. workbench/web_interface/components/plugins/model_details.py +28 -11
  47. workbench/web_interface/components/plugins/scatter_plot.py +56 -43
  48. workbench/web_interface/components/settings_menu.py +2 -1
  49. workbench/web_interface/page_views/main_page.py +0 -1
  50. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/METADATA +31 -29
  51. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/RECORD +55 -59
  52. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/WHEEL +1 -1
  53. workbench/themes/quartz/base_css.url +0 -1
  54. workbench/themes/quartz/custom.css +0 -117
  55. workbench/themes/quartz/plotly.json +0 -642
  56. workbench/themes/quartz_dark/base_css.url +0 -1
  57. workbench/themes/quartz_dark/custom.css +0 -131
  58. workbench/themes/quartz_dark/plotly.json +0 -642
  59. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/entry_points.txt +0 -0
  60. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/licenses/LICENSE +0 -0
  61. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/top_level.txt +0 -0
@@ -2,32 +2,122 @@
2
2
 
3
3
  import logging
4
4
  import base64
5
- import re
6
5
  from typing import Optional, Tuple
7
6
  from rdkit import Chem
8
7
  from rdkit.Chem import AllChem, Draw
9
8
  from rdkit.Chem.Draw import rdMolDraw2D
9
+ from dash import html
10
+
11
+ # Workbench Imports
12
+ from workbench.utils.color_utils import is_dark
10
13
 
11
14
  # Set up the logger
12
15
  log = logging.getLogger("workbench")
13
16
 
14
17
 
15
- def _is_dark(color: str) -> bool:
16
- """Determine if an rgba color is dark based on RGB average.
18
+ def molecule_hover_tooltip(
19
+ smiles: str, mol_id: str = None, width: int = 300, height: int = 200, background: str = None
20
+ ) -> list:
21
+ """Generate a molecule hover tooltip from a SMILES string.
22
+
23
+ This function creates a visually appealing tooltip with a dark background
24
+ that displays the molecule ID at the top and structure below when hovering
25
+ over scatter plot points.
17
26
 
18
27
  Args:
19
- color: Color in rgba(...) format
28
+ smiles: SMILES string representing the molecule
29
+ mol_id: Optional molecule ID to display at the top of the tooltip
30
+ width: Width of the molecule image in pixels (default: 300)
31
+ height: Height of the molecule image in pixels (default: 200)
32
+ background: Optional background color (if None, uses dark gray)
20
33
 
21
34
  Returns:
22
- True if the color is dark, False otherwise
35
+ list: A list containing an html.Div with the ID header and molecule SVG,
36
+ or an html.Div with an error message if rendering fails
23
37
  """
24
- match = re.match(r"rgba?\((\d+),\s*(\d+),\s*(\d+)", color)
25
- if not match:
26
- log.warning(f"Invalid color format: {color}, defaulting to dark")
27
- return True # Default to dark mode on error
38
+ try:
28
39
 
29
- r, g, b = map(int, match.groups())
30
- return (r + g + b) / 3 < 128
40
+ # Use provided background or default to dark gray
41
+ if background is None:
42
+ background = "rgba(64, 64, 64, 1)"
43
+
44
+ # Generate the SVG image from SMILES (base64 encoded data URI)
45
+ img = svg_from_smiles(smiles, width, height, background=background)
46
+
47
+ if img is None:
48
+ log.warning(f"Could not render molecule for SMILES: {smiles}")
49
+ return [
50
+ html.Div(
51
+ "Invalid SMILES",
52
+ className="custom-tooltip",
53
+ style={
54
+ "padding": "10px",
55
+ "color": "rgb(255, 140, 140)",
56
+ "width": f"{width}px",
57
+ "height": f"{height}px",
58
+ "display": "flex",
59
+ "alignItems": "center",
60
+ "justifyContent": "center",
61
+ },
62
+ )
63
+ ]
64
+
65
+ # Build the tooltip with ID header and molecule image
66
+ children = []
67
+
68
+ # Add ID header if provided
69
+ if mol_id is not None:
70
+ # Set text color based on background brightness
71
+ text_color = "rgb(200, 200, 200)" if is_dark(background) else "rgb(60, 60, 60)"
72
+ children.append(
73
+ html.Div(
74
+ str(mol_id),
75
+ style={
76
+ "textAlign": "center",
77
+ "padding": "8px",
78
+ "color": text_color,
79
+ "fontSize": "14px",
80
+ "fontWeight": "bold",
81
+ "borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
82
+ },
83
+ )
84
+ )
85
+
86
+ # Add molecule image
87
+ children.append(
88
+ html.Img(
89
+ src=img,
90
+ style={"padding": "0px", "margin": "0px", "display": "block"},
91
+ width=str(width),
92
+ height=str(height),
93
+ )
94
+ )
95
+
96
+ return [
97
+ html.Div(
98
+ children,
99
+ className="custom-tooltip",
100
+ style={"padding": "0px", "margin": "0px"},
101
+ )
102
+ ]
103
+
104
+ except ImportError as e:
105
+ log.error(f"RDKit not available for molecule rendering: {e}")
106
+ return [
107
+ html.Div(
108
+ "RDKit not installed",
109
+ className="custom-tooltip",
110
+ style={
111
+ "padding": "10px",
112
+ "color": "rgb(255, 195, 140)",
113
+ "width": f"{width}px",
114
+ "height": f"{height}px",
115
+ "display": "flex",
116
+ "alignItems": "center",
117
+ "justifyContent": "center",
118
+ },
119
+ )
120
+ ]
31
121
 
32
122
 
33
123
  def _rgba_to_tuple(rgba: str) -> Tuple[float, float, float, float]:
@@ -75,7 +165,13 @@ def _configure_draw_options(options: Draw.MolDrawOptions, background: str) -> No
75
165
  options: RDKit drawing options object
76
166
  background: Background color string
77
167
  """
78
- if _is_dark(background):
168
+ try:
169
+ if is_dark(background):
170
+ rdMolDraw2D.SetDarkMode(options)
171
+ # Light backgrounds use RDKit defaults (no action needed)
172
+ except ValueError:
173
+ # Default to dark mode if color format is invalid
174
+ log.warning(f"Invalid color format: {background}, defaulting to dark mode")
79
175
  rdMolDraw2D.SetDarkMode(options)
80
176
  options.setBackgroundColour(_rgba_to_tuple(background))
81
177
 
@@ -137,7 +233,7 @@ def svg_from_smiles(
137
233
  drawer.DrawMolecule(mol)
138
234
  drawer.FinishDrawing()
139
235
 
140
- # Encode SVG
236
+ # Encode SVG as base64 data URI
141
237
  svg = drawer.GetDrawingText()
142
238
  encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
143
239
  return f"data:image/svg+xml;base64,{encoded_svg}"
@@ -222,7 +318,7 @@ if __name__ == "__main__":
222
318
  # Test 6: Color parsing functions
223
319
  print("\n6. Testing color utility functions...")
224
320
  test_colors = [
225
- ("invalid_color", True, (0.25, 0.25, 0.25, 1.0)), # Should use defaults
321
+ ("invalid_color", None, (0.25, 0.25, 0.25, 1.0)), # Should raise ValueError
226
322
  ("rgba(255, 255, 255, 1)", False, (1.0, 1.0, 1.0, 1.0)),
227
323
  ("rgba(0, 0, 0, 1)", True, (0.0, 0.0, 0.0, 1.0)),
228
324
  ("rgba(64, 64, 64, 0.5)", True, (0.251, 0.251, 0.251, 0.5)),
@@ -230,24 +326,38 @@ if __name__ == "__main__":
230
326
  ]
231
327
 
232
328
  for color, expected_dark, expected_tuple in test_colors:
233
- is_dark_result = _is_dark(color)
234
- tuple_result = _rgba_to_tuple(color)
235
-
236
- dark_status = "✓" if is_dark_result == expected_dark else "✗"
237
- print(f" {dark_status} is_dark('{color[:20]}...'): {is_dark_result} == {expected_dark}")
329
+ try:
330
+ is_dark_result = is_dark(color)
331
+ if expected_dark is None:
332
+ print(f" is_dark('{color[:20]}...'): Expected ValueError but got {is_dark_result}")
333
+ else:
334
+ dark_status = "✓" if is_dark_result == expected_dark else "✗"
335
+ print(f" {dark_status} is_dark('{color[:20]}...'): {is_dark_result} == {expected_dark}")
336
+ except ValueError:
337
+ if expected_dark is None:
338
+ print(f" ✓ is_dark('{color[:20]}...'): Correctly raised ValueError")
339
+ else:
340
+ print(f" ✗ is_dark('{color[:20]}...'): Unexpected ValueError")
238
341
 
342
+ tuple_result = _rgba_to_tuple(color)
239
343
  # Check tuple values with tolerance for floating point
240
344
  tuple_match = all(abs(a - b) < 0.01 for a, b in zip(tuple_result, expected_tuple))
241
345
  tuple_status = "✓" if tuple_match else "✗"
242
346
  print(f" {tuple_status} rgba_to_tuple('{color[:20]}...'): matches expected")
243
347
 
244
- # Test the show function (will open image windows)
245
- print("\n7. Testing show function (will open image windows)...")
246
- try:
247
- show(test_molecules["aspirin"])
248
- show(test_molecules["aspirin"], background="rgba(220, 220, 220, 1)")
249
- print(" ✓ show() function executed (check for image window)")
250
- except Exception as e:
251
- print(f" ✗ show() function failed: {e}")
348
+ # Test the tooltip generation in a simple Dash app
349
+ from dash import Dash
350
+
351
+ app = Dash(__name__)
352
+ app.layout = html.Div(
353
+ [
354
+ html.Div("Tooltip Preview:", style={"color": "white", "marginBottom": "20px"}),
355
+ *molecule_hover_tooltip("CC(=O)OC1=CC=CC=C1C(=O)O", mol_id="Aspirin", background="rgba(200, 30, 30, 1)"),
356
+ ],
357
+ style={"background": "#1a1a1a", "padding": "50px"},
358
+ )
359
+
360
+ if __name__ == "__main__":
361
+ app.run(debug=True)
252
362
 
253
363
  print("\n✅ All tests completed!")
@@ -0,0 +1,41 @@
1
+ """Clientside JavaScript callbacks for Dash components.
2
+
3
+ These functions return JavaScript code strings for use with Dash's clientside_callback.
4
+ Using clientside callbacks avoids server round-trips for simple UI interactions.
5
+ """
6
+
7
+
8
+ def circle_overlay_callback(circle_data_uri: str) -> str:
9
+ """Returns JS function for circle overlay on scatter plot hover.
10
+
11
+ Args:
12
+ circle_data_uri: Base64-encoded SVG data URI for the circle overlay
13
+
14
+ Returns:
15
+ JavaScript function string for use with clientside_callback
16
+ """
17
+ return f"""
18
+ function(hoverData) {{
19
+ if (!hoverData) {{
20
+ return [false, window.dash_clientside.no_update, window.dash_clientside.no_update];
21
+ }}
22
+ var bbox = hoverData.points[0].bbox;
23
+ var centerX = (bbox.x0 + bbox.x1) / 2;
24
+ var centerY = (bbox.y0 + bbox.y1) / 2;
25
+ var adjustedBbox = {{
26
+ x0: centerX - 50,
27
+ x1: centerX + 50,
28
+ y0: centerY - 162,
29
+ y1: centerY - 62
30
+ }};
31
+ var imgElement = {{
32
+ type: 'Img',
33
+ namespace: 'dash_html_components',
34
+ props: {{
35
+ src: '{circle_data_uri}',
36
+ style: {{width: '100px', height: '100px'}}
37
+ }}
38
+ }};
39
+ return [true, adjustedBbox, [imgElement]];
40
+ }}
41
+ """
@@ -1,5 +1,7 @@
1
1
  """Markdown Utility/helper methods"""
2
2
 
3
+ import math
4
+
3
5
  from workbench.utils.symbols import health_icons
4
6
 
5
7
 
@@ -185,6 +187,65 @@ def dict_to_collapsible_html(data: dict, title: str = None, collapse_all: bool =
185
187
  return result
186
188
 
187
189
 
190
+ def df_to_html_table(df, round_digits: int = 2, margin_bottom: int = 30) -> str:
191
+ """Convert a DataFrame to a compact styled HTML table (horizontal layout).
192
+
193
+ Args:
194
+ df: DataFrame with metrics (can be single or multi-row)
195
+ round_digits: Number of decimal places to round to (default: 2)
196
+ margin_bottom: Bottom margin in pixels (default: 30)
197
+
198
+ Returns:
199
+ str: HTML table string
200
+ """
201
+ # Handle index: reset if named (keeps as column), otherwise drop
202
+ if df.index.name:
203
+ df = df.reset_index()
204
+ else:
205
+ df = df.reset_index(drop=True)
206
+
207
+ # Round numeric columns
208
+ df = df.round(round_digits)
209
+
210
+ # Table styles
211
+ container_style = f"display: flex; justify-content: center; margin-top: 10px; margin-bottom: {margin_bottom}px;"
212
+ table_style = "border-collapse: collapse; width: 100%; font-size: 15px;"
213
+ header_style = (
214
+ "background: linear-gradient(to bottom, #4a4a4a 0%, #2d2d2d 100%); "
215
+ "color: white; padding: 4px 8px; text-align: center;"
216
+ )
217
+ cell_style = "padding: 3px 8px; text-align: center; border-bottom: 1px solid #444;"
218
+
219
+ # Build the HTML table (wrapped in centered container)
220
+ html = f'<div style="{container_style}"><table style="{table_style}">'
221
+
222
+ # Header row
223
+ html += "<tr>"
224
+ for col in df.columns:
225
+ html += f'<th style="{header_style}">{col}</th>'
226
+ html += "</tr>"
227
+
228
+ # Data rows
229
+ for _, row in df.iterrows():
230
+ html += "<tr>"
231
+ for val in row:
232
+ # Format value: integers without decimal, floats rounded
233
+ if isinstance(val, float):
234
+ if math.isnan(val):
235
+ formatted_val = "NaN"
236
+ elif val == int(val):
237
+ formatted_val = int(val)
238
+ else:
239
+ formatted_val = round(val, round_digits)
240
+ else:
241
+ formatted_val = val
242
+ html += f'<td style="{cell_style}">{formatted_val}</td>'
243
+ html += "</tr>"
244
+
245
+ html += "</table></div>"
246
+ return html
247
+
248
+
188
249
  if __name__ == "__main__":
189
250
  """Exercise the Markdown Utilities"""
190
251
  from workbench.api.model import Model
@@ -6,7 +6,6 @@ import json
6
6
  # Workbench Imports
7
7
  from workbench.api import DataSource, FeatureSet, Model, Endpoint, ParameterStore
8
8
 
9
-
10
9
  # Set up the logging
11
10
  log = logging.getLogger("workbench")
12
11
 
@@ -4,7 +4,6 @@ import logging
4
4
  import numpy as np
5
5
  import pandas as pd
6
6
  import plotly.graph_objects as go
7
- from dash import html
8
7
 
9
8
  log = logging.getLogger("workbench")
10
9
 
@@ -143,13 +142,13 @@ def prediction_intervals(df, figure, x_col):
143
142
 
144
143
  # Sort dataframe by x_col for connected lines
145
144
  sorted_df = df.sort_values(by=x_col)
146
- # Add outer band (q_025 to q_975) - more transparent
145
+ # Add outer band (q_025 to q_975) - desaturated blue-gray, more transparent
147
146
  figure.add_trace(
148
147
  go.Scatter(
149
148
  x=sorted_df[x_col],
150
149
  y=sorted_df["q_025"],
151
150
  mode="lines",
152
- line=dict(width=1, color="rgba(99, 110, 250, 0.25)"),
151
+ line=dict(width=1, color="rgba(120, 130, 180, 0.3)"),
153
152
  name="2.5 Percentile",
154
153
  hoverinfo="skip",
155
154
  showlegend=False,
@@ -160,21 +159,21 @@ def prediction_intervals(df, figure, x_col):
160
159
  x=sorted_df[x_col],
161
160
  y=sorted_df["q_975"],
162
161
  mode="lines",
163
- line=dict(width=1, color="rgba(99, 110, 250, 0.25)"),
162
+ line=dict(width=1, color="rgba(120, 130, 180, 0.3)"),
164
163
  name="97.5 Percentile",
165
164
  hoverinfo="skip",
166
165
  showlegend=False,
167
166
  fill="tonexty",
168
- fillcolor="rgba(99, 110, 250, 0.2)",
167
+ fillcolor="rgba(120, 130, 180, 0.15)",
169
168
  )
170
169
  )
171
- # Add inner band (q_25 to q_75) - less transparent
170
+ # Add inner band (q_25 to q_75) - desaturated green-gray, slightly more visible
172
171
  figure.add_trace(
173
172
  go.Scatter(
174
173
  x=sorted_df[x_col],
175
174
  y=sorted_df["q_25"],
176
175
  mode="lines",
177
- line=dict(width=1, color="rgba(99, 250, 110, 0.25)"),
176
+ line=dict(width=1, color="rgba(130, 180, 140, 0.3)"),
178
177
  name="25 Percentile",
179
178
  hoverinfo="skip",
180
179
  showlegend=False,
@@ -185,118 +184,17 @@ def prediction_intervals(df, figure, x_col):
185
184
  x=sorted_df[x_col],
186
185
  y=sorted_df["q_75"],
187
186
  mode="lines",
188
- line=dict(width=1, color="rgba(99, 250, 100, 0.25)"),
187
+ line=dict(width=1, color="rgba(130, 180, 140, 0.3)"),
189
188
  name="75 Percentile",
190
189
  hoverinfo="skip",
191
190
  showlegend=False,
192
191
  fill="tonexty",
193
- fillcolor="rgba(99, 250, 110, 0.2)",
192
+ fillcolor="rgba(130, 180, 140, 0.18)",
194
193
  )
195
194
  )
196
195
  return figure
197
196
 
198
197
 
199
- def molecule_hover_tooltip(smiles: str, mol_id: str = None, width: int = 300, height: int = 200) -> list:
200
- """Generate a molecule hover tooltip from a SMILES string.
201
-
202
- This function creates a visually appealing tooltip with a dark background
203
- that displays the molecule ID at the top and structure below when hovering
204
- over scatter plot points.
205
-
206
- Args:
207
- smiles: SMILES string representing the molecule
208
- mol_id: Optional molecule ID to display at the top of the tooltip
209
- width: Width of the molecule image in pixels (default: 300)
210
- height: Height of the molecule image in pixels (default: 200)
211
-
212
- Returns:
213
- list: A list containing an html.Div with the ID header and molecule SVG,
214
- or an html.Div with an error message if rendering fails
215
- """
216
- try:
217
- from workbench.utils.chem_utils.vis import svg_from_smiles
218
- from workbench.utils.theme_manager import ThemeManager
219
-
220
- # Get the background color from the current theme
221
- background = ThemeManager().background()
222
-
223
- # Generate the SVG image from SMILES
224
- img = svg_from_smiles(smiles, width, height, background=background)
225
-
226
- if img is None:
227
- log.warning(f"Could not render molecule for SMILES: {smiles}")
228
- return [
229
- html.Div(
230
- "Invalid SMILES",
231
- className="custom-tooltip",
232
- style={
233
- "padding": "10px",
234
- "color": "rgb(255, 140, 140)",
235
- "width": f"{width}px",
236
- "height": f"{height}px",
237
- "display": "flex",
238
- "alignItems": "center",
239
- "justifyContent": "center",
240
- },
241
- )
242
- ]
243
-
244
- # Build the tooltip with ID header and molecule image
245
- children = []
246
-
247
- # Add ID header if provided
248
- if mol_id is not None:
249
- children.append(
250
- html.Div(
251
- str(mol_id),
252
- style={
253
- "textAlign": "center",
254
- "padding": "8px",
255
- "color": "rgb(200, 200, 200)",
256
- "fontSize": "14px",
257
- "fontWeight": "bold",
258
- "borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
259
- },
260
- )
261
- )
262
-
263
- # Add molecule image
264
- children.append(
265
- html.Img(
266
- src=img,
267
- style={"padding": "0px", "margin": "0px", "display": "block"},
268
- width=str(width),
269
- height=str(height),
270
- )
271
- )
272
-
273
- return [
274
- html.Div(
275
- children,
276
- className="custom-tooltip",
277
- style={"padding": "0px", "margin": "0px"},
278
- )
279
- ]
280
-
281
- except ImportError as e:
282
- log.error(f"RDKit not available for molecule rendering: {e}")
283
- return [
284
- html.Div(
285
- "RDKit not installed",
286
- className="custom-tooltip",
287
- style={
288
- "padding": "10px",
289
- "color": "rgb(255, 195, 140)",
290
- "width": f"{width}px",
291
- "height": f"{height}px",
292
- "display": "flex",
293
- "alignItems": "center",
294
- "justifyContent": "center",
295
- },
296
- )
297
- ]
298
-
299
-
300
198
  if __name__ == "__main__":
301
199
  """Exercise the Plot Utilities"""
302
200
  import plotly.express as px
@@ -11,7 +11,6 @@ import logging
11
11
  from workbench.algorithms.dataframe.aggregation import aggregate
12
12
  from workbench.algorithms.dataframe.projection_2d import Projection2D
13
13
 
14
-
15
14
  # Workbench Logger
16
15
  log = logging.getLogger("workbench")
17
16
 
@@ -10,8 +10,10 @@ from workbench.api import Model, ModelType
10
10
  from workbench.web_interface.components.component_interface import ComponentInterface
11
11
  from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
12
12
  from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
13
+ from workbench.utils.deprecated_utils import deprecated
13
14
 
14
15
 
16
+ @deprecated(version="0.9")
15
17
  class ModelPlot(ComponentInterface):
16
18
  """Model Metrics Components"""
17
19
 
@@ -3,7 +3,6 @@ import dash_bootstrap_components as dbc
3
3
  import logging
4
4
  import socket
5
5
 
6
-
7
6
  # Workbench Imports
8
7
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginInputType
9
8
  from workbench.api import DataSource, FeatureSet, Model, Endpoint, Meta
@@ -22,9 +22,7 @@ class AGTable(PluginInterface):
22
22
  header_height = 30
23
23
  row_height = 25
24
24
 
25
- def create_component(
26
- self, component_id: str, header_color: str = "rgb(120, 60, 60)", max_height: int = 500
27
- ) -> AgGrid:
25
+ def create_component(self, component_id: str, max_height: int = 500) -> AgGrid:
28
26
  """Create a Table Component without any data."""
29
27
  self.component_id = component_id
30
28
  self.max_height = max_height
@@ -112,4 +110,4 @@ if __name__ == "__main__":
112
110
  test_df = pd.DataFrame(data)
113
111
 
114
112
  # Run the Unit Test on the Plugin
115
- PluginUnitTest(AGTable, theme="quartz", input_data=test_df, max_height=500).run()
113
+ PluginUnitTest(AGTable, theme="dark", input_data=test_df, max_height=500).run()
@@ -3,7 +3,6 @@
3
3
  from dash import dcc, callback, Output, Input, State
4
4
  import plotly.graph_objects as go
5
5
 
6
-
7
6
  # Workbench Imports
8
7
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
9
8
  from workbench.utils.theme_manager import ThemeManager
@@ -22,7 +21,6 @@ class ConfusionMatrix(PluginInterface):
22
21
  self.component_id = None
23
22
  self.current_highlight = None # Store the currently highlighted cell
24
23
  self.theme_manager = ThemeManager()
25
- self.colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
26
24
 
27
25
  # Call the parent class constructor
28
26
  super().__init__()
@@ -65,9 +63,8 @@ class ConfusionMatrix(PluginInterface):
65
63
  if df is None:
66
64
  return [self.display_text("No Data")]
67
65
 
68
- # Use Plotly's default theme-friendly colorscale
69
- # from plotly.colors import sequential
70
- # color_scale = sequential.Plasma
66
+ # Get the colorscale from the current theme
67
+ colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
71
68
 
72
69
  # The confusion matrix is displayed in reverse order (flip the dataframe for correct orientation)
73
70
  df = df.iloc[::-1]
@@ -89,7 +86,7 @@ class ConfusionMatrix(PluginInterface):
89
86
  title="Count",
90
87
  outlinewidth=1,
91
88
  ),
92
- colorscale=self.colorscale,
89
+ colorscale=colorscale,
93
90
  )
94
91
  )
95
92