workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
This release has been flagged as potentially problematic.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +319 -204
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -82
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +260 -76
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/utils/plot_utils.py
CHANGED
@@ -1,14 +1,20 @@
 """Plot Utilities for Workbench"""

+import logging
 import numpy as np
 import pandas as pd
 import plotly.graph_objects as go
+from dash import html
+
+from workbench.utils.color_utils import is_dark
+
+log = logging.getLogger("workbench")


 # For approximating beeswarm effect
 def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
     """
-    Generate
+    Generate beeswarm offsets using random jitter with collision avoidance.

     Args:
         values: Array of positions to be adjusted
@@ -22,42 +28,55 @@ def beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3):
     values = np.asarray(values)
     rounded = np.round(values, precision)
     offsets = np.zeros_like(values, dtype=float)
-
-    # Sort indices by original values
-    sorted_idx = np.argsort(values)
+    rng = np.random.default_rng(42)  # Fixed seed for reproducibility

     for val in np.unique(rounded):
         # Get indices belonging to this group
-
+        group_mask = rounded == val
+        group_idx = np.where(group_mask)[0]

         if len(group_idx) > 1:
             # Track occupied positions for collision detection
             occupied = []

             for idx in group_idx:
-                #
-
-
-
-
-
-
-
-
-
-                #
-
-
-
-
-
-
-
-
-
-
-
-
+                # Try random positions, starting near center and expanding outward
+                best_offset = 0
+                found = False
+
+                # First point goes to center
+                if not occupied:
+                    found = True
+                else:
+                    # Try random positions with increasing spread
+                    for attempt in range(50):
+                        # Gradually increase the range of random offsets
+                        spread = min(max_offset, point_size * (1 + attempt * 0.5))
+                        offset = rng.uniform(-spread, spread)
+
+                        # Check for collision with occupied positions
+                        if not any(abs(offset - pos) < point_size for pos in occupied):
+                            best_offset = offset
+                            found = True
+                            break
+
+                    # If no free position found after attempts, find the least crowded spot
+                    if not found:
+                        # Try a grid of positions and pick one with most space
+                        candidates = np.linspace(-max_offset, max_offset, 20)
+                        rng.shuffle(candidates)
+                        for candidate in candidates:
+                            if not any(abs(candidate - pos) < point_size * 0.8 for pos in occupied):
+                                best_offset = candidate
+                                found = True
+                                break
+
+                    # Last resort: just use a random position within bounds
+                    if not found:
+                        best_offset = rng.uniform(-max_offset, max_offset)
+
+                offsets[idx] = best_offset
+                occupied.append(best_offset)

     return offsets

@@ -132,7 +151,7 @@ def prediction_intervals(df, figure, x_col):
             x=sorted_df[x_col],
             y=sorted_df["q_025"],
             mode="lines",
-            line=dict(width=1, color="rgba(99, 110, 250, 0.
+            line=dict(width=1, color="rgba(99, 110, 250, 0.5)"),
             name="2.5 Percentile",
             hoverinfo="skip",
             showlegend=False,
@@ -143,12 +162,12 @@ def prediction_intervals(df, figure, x_col):
             x=sorted_df[x_col],
             y=sorted_df["q_975"],
             mode="lines",
-            line=dict(width=1, color="rgba(99, 110, 250, 0.
+            line=dict(width=1, color="rgba(99, 110, 250, 0.5)"),
             name="97.5 Percentile",
             hoverinfo="skip",
             showlegend=False,
             fill="tonexty",
-            fillcolor="rgba(99, 110, 250, 0.
+            fillcolor="rgba(99, 110, 250, 0.35)",
         )
     )
     # Add inner band (q_25 to q_75) - less transparent
@@ -157,7 +176,7 @@ def prediction_intervals(df, figure, x_col):
             x=sorted_df[x_col],
             y=sorted_df["q_25"],
             mode="lines",
-            line=dict(width=1, color="rgba(99, 250, 110, 0.
+            line=dict(width=1, color="rgba(99, 250, 110, 0.5)"),
             name="25 Percentile",
             hoverinfo="skip",
             showlegend=False,
@@ -168,17 +187,123 @@ def prediction_intervals(df, figure, x_col):
             x=sorted_df[x_col],
             y=sorted_df["q_75"],
             mode="lines",
-            line=dict(width=1, color="rgba(99, 250,
+            line=dict(width=1, color="rgba(99, 250, 110, 0.5)"),
             name="75 Percentile",
             hoverinfo="skip",
             showlegend=False,
             fill="tonexty",
-            fillcolor="rgba(99, 250, 110, 0.
+            fillcolor="rgba(99, 250, 110, 0.35)",
         )
     )
     return figure


+def molecule_hover_tooltip(
+    smiles: str, mol_id: str = None, width: int = 300, height: int = 200, background: str = None
+) -> list:
+    """Generate a molecule hover tooltip from a SMILES string.
+
+    This function creates a visually appealing tooltip with a dark background
+    that displays the molecule ID at the top and structure below when hovering
+    over scatter plot points.
+
+    Args:
+        smiles: SMILES string representing the molecule
+        mol_id: Optional molecule ID to display at the top of the tooltip
+        width: Width of the molecule image in pixels (default: 300)
+        height: Height of the molecule image in pixels (default: 200)
+        background: Optional background color (if None, uses dark gray)
+
+    Returns:
+        list: A list containing an html.Div with the ID header and molecule SVG,
+            or an html.Div with an error message if rendering fails
+    """
+    try:
+        from workbench.utils.chem_utils.vis import svg_from_smiles
+
+        # Use provided background or default to dark gray
+        if background is None:
+            background = "rgba(64, 64, 64, 1)"
+
+        # Generate the SVG image from SMILES (base64 encoded data URI)
+        img = svg_from_smiles(smiles, width, height, background=background)
+
+        if img is None:
+            log.warning(f"Could not render molecule for SMILES: {smiles}")
+            return [
+                html.Div(
+                    "Invalid SMILES",
+                    className="custom-tooltip",
+                    style={
+                        "padding": "10px",
+                        "color": "rgb(255, 140, 140)",
+                        "width": f"{width}px",
+                        "height": f"{height}px",
+                        "display": "flex",
+                        "alignItems": "center",
+                        "justifyContent": "center",
+                    },
+                )
+            ]
+
+        # Build the tooltip with ID header and molecule image
+        children = []
+
+        # Add ID header if provided
+        if mol_id is not None:
+            # Set text color based on background brightness
+            text_color = "rgb(200, 200, 200)" if is_dark(background) else "rgb(60, 60, 60)"
+            children.append(
+                html.Div(
+                    str(mol_id),
+                    style={
+                        "textAlign": "center",
+                        "padding": "8px",
+                        "color": text_color,
+                        "fontSize": "14px",
+                        "fontWeight": "bold",
+                        "borderBottom": "1px solid rgba(128, 128, 128, 0.5)",
+                    },
+                )
+            )
+
+        # Add molecule image
+        children.append(
+            html.Img(
+                src=img,
+                style={"padding": "0px", "margin": "0px", "display": "block"},
+                width=str(width),
+                height=str(height),
+            )
+        )
+
+        return [
+            html.Div(
+                children,
+                className="custom-tooltip",
+                style={"padding": "0px", "margin": "0px"},
+            )
+        ]
+
+    except ImportError as e:
+        log.error(f"RDKit not available for molecule rendering: {e}")
+        return [
+            html.Div(
+                "RDKit not installed",
+                className="custom-tooltip",
+                style={
+                    "padding": "10px",
+                    "color": "rgb(255, 195, 140)",
+                    "width": f"{width}px",
+                    "height": f"{height}px",
+                    "display": "flex",
+                    "alignItems": "center",
+                    "justifyContent": "center",
+                },
+            )
+        ]
+
+
 if __name__ == "__main__":
     """Exercise the Plot Utilities"""
     import plotly.express as px
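For orientation, a minimal sketch of how the reworked beeswarm_offsets helper might be used to spread overlapping points in a strip-style plot; the sample values and the figure layout are illustrative, not part of this diff:

import numpy as np
import plotly.graph_objects as go

from workbench.utils.plot_utils import beeswarm_offsets

# A handful of values that collide when rounded to 2 decimals
values = np.array([0.50, 0.50, 0.51, 0.90, 0.90, 0.90, 0.12])

# Offsets are jitter amounts; values that round to the same number get non-overlapping offsets
offsets = beeswarm_offsets(values, point_size=0.05, precision=2, max_offset=0.3)

# Plot the jitter on the x-axis and the original values on the y-axis
fig = go.Figure(go.Scatter(x=offsets, y=values, mode="markers"))
fig.show()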
workbench/utils/pytorch_utils.py
ADDED
@@ -0,0 +1,87 @@
+"""PyTorch Tabular utilities for Workbench models."""
+
+import logging
+import os
+import tarfile
+import tempfile
+from typing import Any, Tuple
+
+import awswrangler as wr
+import pandas as pd
+
+from workbench.utils.aws_utils import pull_s3_data
+from workbench.utils.metrics_utils import compute_metrics_from_predictions
+
+log = logging.getLogger("workbench")
+
+
+def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
+    """Download and extract a PyTorch model artifact from S3.
+
+    Args:
+        s3_uri: S3 URI of the model.tar.gz artifact
+        model_dir: Local directory to extract the model to
+    """
+    with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
+        tmp_path = tmp.name
+
+    try:
+        wr.s3.download(path=s3_uri, local_file=tmp_path)
+        with tarfile.open(tmp_path, "r:gz") as tar:
+            tar.extractall(model_dir)
+        log.info(f"Extracted model to {model_dir}")
+    finally:
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
+
+
+def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Pull cross-validation results from AWS training artifacts.
+
+    This retrieves the validation predictions saved during model training and
+    computes metrics directly from them. For PyTorch models trained with
+    n_folds > 1, these are out-of-fold predictions from k-fold cross-validation.
+
+    Args:
+        workbench_model: Workbench model object
+
+    Returns:
+        Tuple of:
+            - DataFrame with computed metrics
+            - DataFrame with validation predictions
+    """
+    # Get the validation predictions from S3
+    s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
+    predictions_df = pull_s3_data(s3_path)
+
+    if predictions_df is None:
+        raise ValueError(f"No validation predictions found at {s3_path}")
+
+    log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
+
+    # Compute metrics from predictions
+    target = workbench_model.target()
+    class_labels = workbench_model.class_labels()
+
+    if target in predictions_df.columns and "prediction" in predictions_df.columns:
+        metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
+    else:
+        metrics_df = pd.DataFrame()
+
+    return metrics_df, predictions_df
+
+
+if __name__ == "__main__":
+    from workbench.api import Model
+
+    # Test pulling CV results
+    model_name = "aqsol-reg-pytorch"
+    print(f"Loading Workbench model: {model_name}")
+    model = Model(model_name)
+    print(f"Model Framework: {model.model_framework}")
+
+    # Pull CV results from training artifacts
+    metrics_df, predictions_df = pull_cv_results(model)
+    print(f"\nMetrics:\n{metrics_df}")
+    print(f"\nPredictions shape: {predictions_df.shape}")
+    print(f"Predictions columns: {predictions_df.columns.tolist()}")
workbench/utils/shap_utils.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple, Dict, Union
 from workbench.utils.xgboost_model_utils import xgboost_model_from_s3
 from workbench.utils.model_utils import load_category_mappings_from_s3
 from workbench.utils.pandas_utils import convert_categorical_types
+from workbench.model_script_utils.model_script_utils import decompress_features

 # Set up the log
 log = logging.getLogger("workbench")
@@ -111,61 +112,6 @@ def shap_values_data(
     return result_df, feature_df


-def decompress_features(
-    df: pd.DataFrame, features: List[str], compressed_features: List[str]
-) -> Tuple[pd.DataFrame, List[str]]:
-    """Prepare features for the XGBoost model
-
-    Args:
-        df (pd.DataFrame): The features DataFrame
-        features (List[str]): Full list of feature names
-        compressed_features (List[str]): List of feature names to decompress (bitstrings)
-
-    Returns:
-        pd.DataFrame: DataFrame with the decompressed features
-        List[str]: Updated list of feature names after decompression
-
-    Raises:
-        ValueError: If any missing values are found in the specified features
-    """
-
-    # Check for any missing values in the required features
-    missing_counts = df[features].isna().sum()
-    if missing_counts.any():
-        missing_features = missing_counts[missing_counts > 0]
-        print(
-            f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
-            "WARNING: You might want to remove/replace all NaN values before processing."
-        )
-
-    # Decompress the specified compressed features
-    decompressed_features = features
-    for feature in compressed_features:
-        if (feature not in df.columns) or (feature not in features):
-            print(f"Feature '{feature}' not in the features list, skipping decompression.")
-            continue
-
-        # Remove the feature from the list of features to avoid duplication
-        decompressed_features.remove(feature)
-
-        # Handle all compressed features as bitstrings
-        bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-        prefix = feature[:3]
-
-        # Create all new columns at once - avoids fragmentation
-        new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-        new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
-
-        # Add to features list
-        decompressed_features.extend(new_col_names)
-
-        # Drop original column and concatenate new ones
-        df = df.drop(columns=[feature])
-        df = pd.concat([df, new_df], axis=1)
-
-    return df, decompressed_features
-
-
 def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
     """
     Internal function to calculate SHAP values for Workbench Models.
@@ -212,6 +158,14 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
         log.error("No XGBoost model found in the artifact.")
         return None, None, None, None

+    # Get the booster (SHAP requires the booster, not the sklearn wrapper)
+    if hasattr(xgb_model, "get_booster"):
+        # Full sklearn model - extract the booster
+        booster = xgb_model.get_booster()
+    else:
+        # Already a booster
+        booster = xgb_model
+
     # Load category mappings if available
     category_mappings = load_category_mappings_from_s3(model_artifact_uri)

@@ -229,8 +183,8 @@ def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
     # Create a DMatrix with categorical support
     dmatrix = xgb.DMatrix(X, enable_categorical=True)

-    # Use XGBoost's built-in SHAP calculation
-    shap_values =
+    # Use XGBoost's built-in SHAP calculation (booster method, not sklearn)
+    shap_values = booster.predict(dmatrix, pred_contribs=True, strict_shape=True)
     features_with_bias = features + ["bias"]

     # Now we need to subset the columns based on top 10 SHAP values
workbench/utils/theme_manager.py
CHANGED
@@ -76,10 +76,28 @@ class ThemeManager:
     def set_theme(cls, theme_name: str):
         """Set the current theme."""

-        # For 'auto', we
-        #
+        # For 'auto', we check multiple sources in priority order:
+        #   1. Browser cookie (from localStorage, for per-user preference)
+        #   2. Parameter Store (for org-wide default)
+        #   3. Default theme
         if theme_name == "auto":
-            theme_name =
+            theme_name = None
+
+            # 1. Check Flask request cookie (set from localStorage)
+            try:
+                from flask import request, has_request_context
+
+                if has_request_context():
+                    theme_name = request.cookies.get("wb_theme")
+            except Exception:
+                pass
+
+            # 2. Fall back to ParameterStore
+            if not theme_name:
+                theme_name = cls.ps.get("/workbench/dashboard/theme", warn=False)
+
+            # 3. Fall back to default
+            theme_name = theme_name or cls.default_theme

         # Check if the theme is in our available themes
         if theme_name not in cls.available_themes:
@@ -104,9 +122,27 @@ class ThemeManager:
         cls.current_theme_name = theme_name
         cls.log.info(f"Theme set to '{theme_name}'")

+    # Bootstrap themes that are dark mode (from Bootswatch)
+    _dark_bootstrap_themes = {"DARKLY", "CYBORG", "SLATE", "SOLAR", "SUPERHERO", "VAPOR"}
+
     @classmethod
     def dark_mode(cls) -> bool:
-        """Check if the current theme is a dark mode theme.
+        """Check if the current theme is a dark mode theme.
+
+        Determines dark mode by checking if the Bootstrap base theme is a known dark theme.
+        Falls back to checking if 'dark' is in the theme name.
+        """
+        theme = cls.available_themes.get(cls.current_theme_name, {})
+        base_css = theme.get("base_css", "")
+
+        # Check if the base CSS URL contains a known dark Bootstrap theme
+        if base_css:
+            base_css_upper = base_css.upper()
+            for dark_theme in cls._dark_bootstrap_themes:
+                if dark_theme in base_css_upper:
+                    return True
+
+        # Fallback: check if 'dark' is in the theme name
         return "dark" in cls.current_theme().lower()

     @classmethod
@@ -184,30 +220,57 @@ class ThemeManager:

     @classmethod
     def css_files(cls) -> list[str]:
-        """Get the list of CSS files for the current theme.
-
+        """Get the list of CSS files for the current theme.
+
+        Note: Uses /base.css route for dynamic theme switching instead of CDN URLs.
+        """
         css_files = []

-        #
-
-        css_files.append(theme["base_css"])
+        # Use Flask route for base CSS (allows dynamic theme switching)
+        css_files.append("/base.css")

         # Add the DBC template CSS
         css_files.append(cls.dbc_css)

         # Add custom.css if it exists
-
-        css_files.append("/custom.css")
+        css_files.append("/custom.css")

         return css_files

+    @classmethod
+    def _get_theme_from_cookie(cls):
+        """Get the theme dict based on the wb_theme cookie, falling back to current theme."""
+        from flask import request
+
+        theme_name = request.cookies.get("wb_theme")
+        if theme_name and theme_name in cls.available_themes:
+            return cls.available_themes[theme_name], theme_name
+        return cls.available_themes[cls.current_theme_name], cls.current_theme_name
+
     @classmethod
     def register_css_route(cls, app):
-        """Register Flask
+        """Register Flask routes for CSS and before_request hook for theme switching."""
+        from flask import redirect
+
+        @app.server.before_request
+        def check_theme_cookie():
+            """Check for theme cookie on each request and update theme if needed."""
+            _, theme_name = cls._get_theme_from_cookie()
+            if theme_name != cls.current_theme_name:
+                cls.set_theme(theme_name)
+
+        @app.server.route("/base.css")
+        def serve_base_css():
+            """Redirect to the appropriate Bootstrap theme CSS based on cookie."""
+            theme, _ = cls._get_theme_from_cookie()
+            if theme["base_css"]:
+                return redirect(theme["base_css"])
+            return "", 404

         @app.server.route("/custom.css")
         def serve_custom_css():
-
+            """Serve the custom.css file based on cookie."""
+            theme, _ = cls._get_theme_from_cookie()
             if theme["custom_css"]:
                 return send_from_directory(theme["custom_css"].parent, theme["custom_css"].name)
             return "", 404
@@ -250,23 +313,25 @@ class ThemeManager:
         # Loop over each path in the theme path
         for theme_path in cls.theme_path_list:
            for theme_dir in theme_path.iterdir():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                # Skip hidden directories (e.g., .idea, .git)
+                if not theme_dir.is_dir() or theme_dir.name.startswith("."):
+                    continue
+                theme_name = theme_dir.name
+
+                # Grab the base.css URL
+                base_css_url = cls._get_base_css_url(theme_dir)
+
+                # Grab the plotly template json, custom.css, and branding json
+                plotly_template = theme_dir / "plotly.json"
+                custom_css = theme_dir / "custom.css"
+                branding = theme_dir / "branding.json"
+
+                cls.available_themes[theme_name] = {
+                    "base_css": base_css_url,
+                    "plotly_template": plotly_template,
+                    "custom_css": custom_css if custom_css.exists() else None,
+                    "branding": branding if branding.exists() else None,
+                }

         if not cls.available_themes:
             cls.log.warning(f"No themes found in '{cls.theme_path_list}'...")