workbench 0.8.193__py3-none-any.whl → 0.8.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
- workbench/algorithms/dataframe/proximity.py +212 -234
- workbench/algorithms/graph/light/proximity_graph.py +8 -7
- workbench/api/endpoint.py +2 -3
- workbench/api/model.py +2 -5
- workbench/core/artifacts/endpoint_core.py +25 -16
- workbench/core/artifacts/feature_set_core.py +126 -4
- workbench/core/artifacts/model_core.py +9 -14
- workbench/core/transforms/features_to_model/features_to_model.py +3 -3
- workbench/core/views/training_view.py +75 -0
- workbench/core/views/view.py +1 -1
- workbench/model_scripts/custom_models/proximity/proximity.py +212 -234
- workbench/model_scripts/custom_models/uq_models/proximity.py +212 -234
- workbench/model_scripts/pytorch_model/generated_model_script.py +567 -0
- workbench/model_scripts/uq_models/generated_model_script.py +589 -0
- workbench/model_scripts/uq_models/mapie.template +103 -6
- workbench/model_scripts/xgb_model/generated_model_script.py +4 -4
- workbench/repl/workbench_shell.py +3 -3
- workbench/utils/model_utils.py +10 -7
- workbench/utils/xgboost_model_utils.py +93 -34
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/model_details.py +2 -5
- {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/METADATA +1 -1
- {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/RECORD +29 -27
- {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/WHEEL +0 -0
- {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,7 @@ import joblib
 import os
 import numpy as np
 import pandas as pd
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Dict
 
 # Template Placeholders
 TEMPLATE_PARAMS = {
@@ -26,6 +26,46 @@ TEMPLATE_PARAMS = {
 }
 
 
+def compute_confidence(
+    df: pd.DataFrame,
+    median_interval_width: float,
+    lower_q: str = "q_10",
+    upper_q: str = "q_90",
+    alpha: float = 1.0,
+    beta: float = 1.0,
+) -> pd.DataFrame:
+    """
+    Compute confidence scores (0.0 to 1.0) based on prediction interval width
+    and distance from median using exponential decay.
+
+    Args:
+        df: DataFrame with 'prediction', 'q_50', and quantile columns
+        median_interval_width: Pre-computed median interval width from training data
+        lower_q: Lower quantile column name (default: 'q_10')
+        upper_q: Upper quantile column name (default: 'q_90')
+        alpha: Weight for interval width term (default: 1.0)
+        beta: Weight for distance from median term (default: 1.0)
+
+    Returns:
+        DataFrame with added 'confidence' column
+    """
+    # Interval width
+    interval_width = (df[upper_q] - df[lower_q]).abs()
+
+    # Distance from median, normalized by interval width
+    distance_from_median = (df['prediction'] - df['q_50']).abs()
+    normalized_distance = distance_from_median / (interval_width + 1e-6)
+
+    # Cap the distance penalty at 1.0
+    normalized_distance = np.minimum(normalized_distance, 1.0)
+
+    # Confidence using exponential decay
+    interval_term = interval_width / median_interval_width
+    df['confidence'] = np.exp(-(alpha * interval_term + beta * normalized_distance))
+
+    return df
+
+
 # Function to check if dataframe is empty
 def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
     """
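
Aside: the score this adds works out to confidence = exp(-(alpha * width / median_width + beta * min(|prediction - q_50| / width, 1))), so a prediction whose interval matches the training-median width and sits exactly on q_50 scores exp(-1) ≈ 0.37, and narrower, better-centered intervals score higher. A minimal sketch of the same arithmetic on toy data (all values below are illustrative, not from the package):

```python
import numpy as np
import pandas as pd

# Two toy rows: a narrow, well-centered interval vs. a wide, off-center one
df = pd.DataFrame({
    "prediction": [5.0, 5.0],
    "q_10": [4.8, 3.0],
    "q_50": [5.0, 6.0],
    "q_90": [5.2, 9.0],
})
median_interval_width = 1.0  # stand-in for the value computed at training time

width = (df["q_90"] - df["q_10"]).abs()
dist = np.minimum((df["prediction"] - df["q_50"]).abs() / (width + 1e-6), 1.0)
df["confidence"] = np.exp(-(width / median_interval_width + dist))
print(df["confidence"].round(3).tolist())  # [0.67, 0.002]: narrow beats wide
```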
@@ -98,7 +138,7 @@ def convert_categorical_types(df: pd.DataFrame, features: list, category_mapping
 
 
 def decompress_features(
-
+    df: pd.DataFrame, features: List[str], compressed_features: List[str]
 ) -> Tuple[pd.DataFrame, List[str]]:
     """Prepare features for the model by decompressing bitstring features
 
@@ -302,6 +342,46 @@ if __name__ == "__main__":
         widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
         print(f" {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
 
+    # Compute normalization statistics for confidence calculation
+    print(f"\nComputing normalization statistics for confidence scores...")
+
+    # Create a temporary validation dataframe with predictions
+    temp_val_df = df_val.copy()
+    temp_val_df["prediction"] = xgb_model.predict(X_validate)
+
+    # Add all quantile predictions
+    for conf_level in confidence_levels:
+        model_name = f"mapie_{conf_level:.2f}"
+        model = mapie_models[model_name]
+        y_pred, y_pis = model.predict_interval(X_validate)
+
+        if conf_level == 0.50:
+            temp_val_df["q_25"] = y_pis[:, 0, 0]
+            temp_val_df["q_75"] = y_pis[:, 1, 0]
+            # y_pred is the median prediction
+            temp_val_df["q_50"] = y_pred
+        elif conf_level == 0.68:
+            temp_val_df["q_16"] = y_pis[:, 0, 0]
+            temp_val_df["q_84"] = y_pis[:, 1, 0]
+        elif conf_level == 0.80:
+            temp_val_df["q_10"] = y_pis[:, 0, 0]
+            temp_val_df["q_90"] = y_pis[:, 1, 0]
+        elif conf_level == 0.90:
+            temp_val_df["q_05"] = y_pis[:, 0, 0]
+            temp_val_df["q_95"] = y_pis[:, 1, 0]
+        elif conf_level == 0.95:
+            temp_val_df["q_025"] = y_pis[:, 0, 0]
+            temp_val_df["q_975"] = y_pis[:, 1, 0]
+
+    # Compute normalization stats using q_10 and q_90 (default range)
+    interval_width = (temp_val_df["q_90"] - temp_val_df["q_10"]).abs()
+    median_interval_width = float(interval_width.median())
+    print(f" Median interval width (q_10-q_90): {median_interval_width:.6f}")
+
+    # Save median interval width for confidence calculation
+    with open(os.path.join(args.model_dir, "median_interval_width.json"), "w") as fp:
+        json.dump(median_interval_width, fp)
+
     # Save the trained XGBoost model
     joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
 
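
Aside: the `y_pis[:, 0, 0]` / `y_pis[:, 1, 0]` indexing assumes MAPIE-style interval output of shape `(n_samples, 2, n_confidence_levels)`, with the lower bound at index 0 and the upper bound at index 1 of the middle axis. A small numpy sketch of that layout (the array is fabricated for illustration):

```python
import numpy as np

# Fake interval output for 3 samples at one confidence level: shape (3, 2, 1)
y_pis = np.array([
    [[4.2], [5.8]],
    [[1.0], [2.5]],
    [[7.7], [9.9]],
])

lower = y_pis[:, 0, 0]  # [4.2, 1.0, 7.7]
upper = y_pis[:, 1, 0]  # [5.8, 2.5, 9.9]
print(upper - lower)    # per-sample interval widths
```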
@@ -365,11 +445,19 @@ def model_fn(model_dir) -> dict:
         with open(category_path) as fp:
             category_mappings = json.load(fp)
 
+    # Load median interval width for confidence calculation
+    median_interval_width = None
+    median_width_path = os.path.join(model_dir, "median_interval_width.json")
+    if os.path.exists(median_width_path):
+        with open(median_width_path) as fp:
+            median_interval_width = json.load(fp)
+
     return {
         "xgb_model": xgb_model,
         "mapie_models": mapie_models,
         "confidence_levels": config["confidence_levels"],
         "category_mappings": category_mappings,
+        "median_interval_width": median_interval_width,
     }
 
 
@@ -449,6 +537,8 @@ def predict_fn(df, models) -> pd.DataFrame:
         if conf_level == 0.50:  # 50% CI
             df["q_25"] = y_pis[:, 0, 0]
             df["q_75"] = y_pis[:, 1, 0]
+            # y_pred is the median prediction
+            df["q_50"] = y_pred
         elif conf_level == 0.68:  # 68% CI
             df["q_16"] = y_pis[:, 0, 0]
             df["q_84"] = y_pis[:, 1, 0]
@@ -462,14 +552,11 @@ def predict_fn(df, models) -> pd.DataFrame:
             df["q_025"] = y_pis[:, 0, 0]
             df["q_975"] = y_pis[:, 1, 0]
 
-    # Add median (q_50) from XGBoost prediction
-    df["q_50"] = df["prediction"]
-
     # Calculate a pseudo-standard deviation from the 68% interval width
     df["prediction_std"] = (df["q_84"] - df["q_16"]).abs() / 2.0
 
     # Reorder the quantile columns for easier reading
-    quantile_cols = ["q_025", "q_05", "q_10", "q_16", "q_25", "q_75", "q_84", "q_90", "q_95", "q_975"]
+    quantile_cols = ["q_025", "q_05", "q_10", "q_16", "q_25", "q_50", "q_75", "q_84", "q_90", "q_95", "q_975"]
     other_cols = [col for col in df.columns if col not in quantile_cols]
     df = df[other_cols + quantile_cols]
 
@@ -489,4 +576,14 @@ def predict_fn(df, models) -> pd.DataFrame:
     df["q_95"] = np.maximum(df["q_95"], df["prediction"])
     df["q_975"] = np.maximum(df["q_975"], df["prediction"])
 
+    # Compute confidence scores using pre-computed normalization stats
+    df = compute_confidence(
+        df,
+        lower_q="q_10",
+        upper_q="q_90",
+        alpha=1.0,
+        beta=1.0,
+        median_interval_width=models["median_interval_width"],
+    )
+
     return df
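
Aside: the `np.maximum` clamps above keep the upper quantiles from falling below the point prediction after the intervals are assembled (the lower quantiles presumably get the mirror-image `np.minimum` treatment earlier in the function). A toy illustration of the clamp:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"prediction": [5.0, 5.0], "q_95": [5.5, 4.9]})

# An upper quantile below the point prediction is inconsistent; clamp it up
df["q_95"] = np.maximum(df["q_95"], df["prediction"])
print(df["q_95"].tolist())  # [5.5, 5.0] -- only the inconsistent row moves
```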
@@ -28,11 +28,11 @@ from typing import List, Tuple
 
 # Template Parameters
 TEMPLATE_PARAMS = {
-    "model_type": "
-    "target": "
-    "features": ['
+    "model_type": "classifier",
+    "target": "wine_class",
+    "features": ['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280_od315_of_diluted_wines', 'proline'],
     "compressed_features": [],
-    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/
+    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/wine-classification/training",
     "train_all_data": False,
     "hyperparameters": {},
 }
@@ -525,7 +525,7 @@ class WorkbenchShell:
     def get_meta(self):
         return self.meta
 
-    def plot_manager(self, data, plot_type: str = "
+    def plot_manager(self, data, plot_type: str = "scatter", **kwargs):
         """Plot Manager for Workbench"""
         from workbench.web_interface.components.plugins import ag_table, graph_plot, scatter_plot
 
@@ -564,10 +564,10 @@ class WorkbenchShell:
 
         plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)
 
-        #
-        plugin_test.run()
+        # Open the browser and run the dash server
         url = f"http://127.0.0.1:{plugin_test.port}"
         webbrowser.open(url)
+        plugin_test.run()
 
 
 # Launch Shell Entry Point
workbench/utils/model_utils.py  CHANGED
@@ -113,9 +113,16 @@ def proximity_model_local(model: "Model"):
     fs = FeatureSet(model.get_input())
     id_column = fs.id_column
 
-    # Create the Proximity Model from
-
-
+    # Create the Proximity Model from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the Proximity Model
+    return Proximity(full_df, id_column, features, target, track_columns=features)
 
 
 def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
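
Aside: the in_model flag is a plain set-membership test, letting downstream proximity queries tell FeatureSet rows the model actually trained on apart from held-out rows. A toy version of the same marking (column names are illustrative):

```python
import pandas as pd

full_df = pd.DataFrame({"id": [1, 2, 3, 4], "feat": [0.1, 0.2, 0.3, 0.4]})
model_df = pd.DataFrame({"id": [1, 3]})  # rows that went into training

model_ids = set(model_df["id"])
full_df["in_model"] = full_df["id"].isin(model_ids)
print(full_df)  # ids 1 and 3 come back True, ids 2 and 4 False
```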
@@ -165,9 +172,6 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
     """
     from workbench.api import Model, ModelType, FeatureSet  # noqa: F401 (avoid circular import)
 
-    # Get the custom script path for the UQ model
-    script_path = get_custom_script_path("uq_models", "mapie.template")
-
     # Get Feature and Target Columns from the existing given Model
     features = model.features()
     target = model.target()
@@ -182,7 +186,6 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
         description=f"UQ Model for {model.name}",
         tags=["uq", model.name],
         train_all_data=train_all_data,
-        custom_script=script_path,
         custom_args={"id_column": fs.id_column, "track_columns": [target]},
     )
     return uq_model
@@ -7,12 +7,11 @@ import joblib
 import pickle
 import glob
 import awswrangler as wr
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Any
 import hashlib
 import pandas as pd
 import numpy as np
 import xgboost as xgb
-from typing import Dict, Any
 from sklearn.model_selection import KFold, StratifiedKFold
 from sklearn.metrics import (
     precision_recall_fscore_support,
@@ -20,13 +19,14 @@ from sklearn.metrics (
     mean_absolute_error,
     r2_score,
     median_absolute_error,
+    roc_auc_score,
 )
 from scipy.stats import spearmanr
 from sklearn.preprocessing import LabelEncoder
 
 # Workbench Imports
 from workbench.utils.model_utils import load_category_mappings_from_s3, safe_extract_tarfile
-from workbench.utils.pandas_utils import convert_categorical_types
+from workbench.utils.pandas_utils import convert_categorical_types, expand_proba_column
 
 # Set up the log
 log = logging.getLogger("workbench")
@@ -258,7 +258,7 @@ def leaf_stats(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
     return result_df
 
 
-def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[
+def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.DataFrame, pd.DataFrame]:
     """
     Performs K-fold cross-validation with detailed metrics.
     Args:
@@ -266,10 +266,8 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         nfolds: Number of folds for cross-validation (default is 5)
     Returns:
         Tuple of:
-        -
-
-        - summary_metrics: Summary metrics across folds
-        - DataFrame with columns: id, target, prediction (out-of-fold predictions for all samples)
+        - DataFrame with per-class metrics (and 'all' row for overall metrics)
+        - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
     """
     from workbench.api import FeatureSet
 
@@ -278,7 +276,7 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
     loaded_model = xgboost_model_from_s3(model_artifact_uri)
     if loaded_model is None:
         log.error("No XGBoost model found in the artifact.")
-        return
+        return pd.DataFrame(), pd.DataFrame()
 
     # Check if we got a full sklearn model or need to create one
     if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
@@ -304,7 +302,7 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         xgb_model._Booster = loaded_model
     else:
         log.error(f"Unexpected model type: {type(loaded_model)}")
-        return
+        return pd.DataFrame(), pd.DataFrame()
 
     # Prepare data
     fs = FeatureSet(workbench_model.get_input())
@@ -335,12 +333,12 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         y_for_cv = y
 
     # Prepare KFold
+    # Note: random_state=42 seems to not actually give us reproducible results
     kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
 
     # Initialize results collection
     fold_metrics = []
-    predictions_df = pd.DataFrame({id_col: ids, target_col: y})
-    # Note: 'prediction' column will be created automatically with correct dtype
+    predictions_df = pd.DataFrame({id_col: ids, target_col: y})
 
     # Perform cross-validation
     for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
@@ -355,6 +353,8 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         val_indices = X_val.index
         if is_classifier:
             predictions_df.loc[val_indices, "prediction"] = label_encoder.inverse_transform(preds.astype(int))
+            y_proba = xgb_model.predict_proba(X_val)
+            predictions_df.loc[val_indices, "pred_proba"] = pd.Series(y_proba.tolist(), index=val_indices)
         else:
             predictions_df.loc[val_indices, "prediction"] = preds
 
@@ -362,10 +362,34 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         if is_classifier:
             y_val_orig = label_encoder.inverse_transform(y_val)
             preds_orig = label_encoder.inverse_transform(preds.astype(int))
+
+            # Overall weighted metrics
             prec, rec, f1, _ = precision_recall_fscore_support(
                 y_val_orig, preds_orig, average="weighted", zero_division=0
             )
-
+
+            # Per-class F1
+            prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
+                y_val_orig, preds_orig, average=None, zero_division=0, labels=label_encoder.classes_
+            )
+
+            # ROC-AUC (overall and per-class)
+            roc_auc_overall = roc_auc_score(y_val, y_proba, multi_class="ovr", average="macro")
+            roc_auc_per_class = roc_auc_score(y_val, y_proba, multi_class="ovr", average=None)
+
+            fold_metrics.append(
+                {
+                    "fold": fold_idx,
+                    "precision": prec,
+                    "recall": rec,
+                    "f1": f1,
+                    "roc_auc": roc_auc_overall,
+                    "precision_per_class": prec_per_class,
+                    "recall_per_class": rec_per_class,
+                    "f1_per_class": f1_per_class,
+                    "roc_auc_per_class": roc_auc_per_class,
+                }
+            )
         else:
             spearman_corr, _ = spearmanr(y_val, preds)
             fold_metrics.append(
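
Aside: the per-class ROC-AUC call relies on scikit-learn accepting average=None together with multi_class="ovr" for multiclass targets (scikit-learn 1.2+, if memory serves). A self-contained sketch with fabricated labels and probabilities:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 0, 1, 1, 2, 2])
y_proba = np.array([
    [0.8, 0.1, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.3, 0.5, 0.2],
    [0.1, 0.2, 0.7],
    [0.2, 0.2, 0.6],
])

print(roc_auc_score(y_true, y_proba, multi_class="ovr", average="macro"))  # one overall score
print(roc_auc_score(y_true, y_proba, multi_class="ovr", average=None))    # one score per class
```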
@@ -379,32 +403,67 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
                 }
             )
 
-    # Calculate summary metrics
+    # Calculate summary metrics
     fold_df = pd.DataFrame(fold_metrics)
-    metric_names = ["precision", "recall", "fscore"] if is_classifier else ["rmse", "mae", "medae", "r2", "spearmanr"]
-    summary_metrics = {metric: f"{fold_df[metric].mean():.3f} ±{fold_df[metric].std():.3f}" for metric in metric_names}
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if is_classifier:
+        # Expand the *_proba columns into separate columns for easier handling
+        predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
+
+        # Build per-class metrics DataFrame
+        metric_rows = []
+
+        # Per-class rows
+        for idx, class_name in enumerate(label_encoder.classes_):
+            prec_scores = np.array([fold["precision_per_class"][idx] for fold in fold_metrics])
+            rec_scores = np.array([fold["recall_per_class"][idx] for fold in fold_metrics])
+            f1_scores = np.array([fold["f1_per_class"][idx] for fold in fold_metrics])
+            roc_auc_scores = np.array([fold["roc_auc_per_class"][idx] for fold in fold_metrics])
+
+            y_orig = label_encoder.inverse_transform(y_for_cv)
+            support = int((y_orig == class_name).sum())
+
+            metric_rows.append(
+                {
+                    "class": class_name,
+                    "precision": prec_scores.mean(),
+                    "recall": rec_scores.mean(),
+                    "f1": f1_scores.mean(),
+                    "roc_auc": roc_auc_scores.mean(),
+                    "support": support,
+                }
             )
 
-
-
+        # Overall 'all' row
+        metric_rows.append(
+            {
+                "class": "all",
+                "precision": fold_df["precision"].mean(),
+                "recall": fold_df["recall"].mean(),
+                "f1": fold_df["f1"].mean(),
+                "roc_auc": fold_df["roc_auc"].mean(),
+                "support": len(y_for_cv),
+            }
+        )
+
+        metrics_df = pd.DataFrame(metric_rows)
+
+    else:
+        # Regression metrics
+        metrics_df = pd.DataFrame(
+            [
+                {
+                    "rmse": fold_df["rmse"].mean(),
+                    "mae": fold_df["mae"].mean(),
+                    "medae": fold_df["medae"].mean(),
+                    "r2": fold_df["r2"].mean(),
+                    "spearmanr": fold_df["spearmanr"].mean(),
+                    "support": len(y_for_cv),
+                }
+            ]
+        )
 
-    return
+    return metrics_df, predictions_df
 
 
 def leave_one_out_inference(workbench_model: Any) -> pd.DataFrame:
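
Aside: with the new signature a caller gets both frames back directly; a hedged usage sketch (the model name is hypothetical, and the import path assumes these helpers stay in workbench/utils/xgboost_model_utils.py):

```python
from workbench.api import Model
from workbench.utils.xgboost_model_utils import cross_fold_inference

model = Model("wine-classification")  # hypothetical model name
metrics_df, predictions_df = cross_fold_inference(model, nfolds=5)

print(metrics_df)             # per-class rows plus an 'all' summary row
print(predictions_df.head())  # id, target, prediction, and *_proba columns
```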
@@ -156,10 +156,13 @@ class PluginUnitTest:
         """Run the Dash server for the plugin, handling common errors gracefully."""
         while self.is_port_in_use(self.port):
             log.info(f"Port {self.port} is in use. Trying the next one...")
-            self.port += 1
+            self.port += 1
 
         log.info(f"Starting Dash server on port {self.port}...")
-
+        try:
+            self.app.run(debug=True, use_reloader=False, port=self.port)
+        except KeyboardInterrupt:
+            log.info("Shutting down Dash server...")
 
     @staticmethod
     def is_port_in_use(port):
@@ -45,8 +45,6 @@ class ModelDetails(PluginInterface):
                 html.H5(children="Inference Metrics", style={"marginTop": "20px"}),
                 dcc.Dropdown(id=f"{self.component_id}-dropdown", className="dropdown"),
                 dcc.Markdown(id=f"{self.component_id}-metrics"),
-                html.H5(children="Cross Fold Metrics", style={"marginTop": "20px"}),
-                dcc.Markdown(id=f"{self.component_id}-cross-metrics", dangerously_allow_html=True),
             ],
         )
 
@@ -57,7 +55,6 @@ class ModelDetails(PluginInterface):
             (f"{self.component_id}-dropdown", "options"),
             (f"{self.component_id}-dropdown", "value"),
             (f"{self.component_id}-metrics", "children"),
-            (f"{self.component_id}-cross-metrics", "children"),
         ]
         self.signals = [(f"{self.component_id}-dropdown", "value")]
 
@@ -84,10 +81,9 @@ class ModelDetails(PluginInterface):
         # Populate the inference runs dropdown
         inference_runs, default_run = self.get_inference_runs()
         metrics = self.inference_metrics(default_run)
-        cross_metrics = self.cross_metrics()
 
         # Return the updated property values for the plugin
-        return [header, details, inference_runs, default_run, metrics
+        return [header, details, inference_runs, default_run, metrics]
 
     def register_internal_callbacks(self):
         @callback(
@@ -225,6 +221,7 @@ class ModelDetails(PluginInterface):
 
     def cross_metrics(self) -> str:
         # Get cross fold metrics if they exist
+        # Note: Currently not used since we show cross fold metrics in the dropdown
         model_name = self.current_model.name
        cross_fold_data = self.params.get(f"/workbench/models/{model_name}/inference/cross_fold", warn=False)
         if not cross_fold_data:
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: workbench
-Version: 0.8.193
+Version: 0.8.197
 Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
 Author-email: SuperCowPowers LLC <support@supercowpowers.com>
 License: MIT License