workbench 0.8.193__py3-none-any.whl → 0.8.197__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +212 -234
  4. workbench/algorithms/graph/light/proximity_graph.py +8 -7
  5. workbench/api/endpoint.py +2 -3
  6. workbench/api/model.py +2 -5
  7. workbench/core/artifacts/endpoint_core.py +25 -16
  8. workbench/core/artifacts/feature_set_core.py +126 -4
  9. workbench/core/artifacts/model_core.py +9 -14
  10. workbench/core/transforms/features_to_model/features_to_model.py +3 -3
  11. workbench/core/views/training_view.py +75 -0
  12. workbench/core/views/view.py +1 -1
  13. workbench/model_scripts/custom_models/proximity/proximity.py +212 -234
  14. workbench/model_scripts/custom_models/uq_models/proximity.py +212 -234
  15. workbench/model_scripts/pytorch_model/generated_model_script.py +567 -0
  16. workbench/model_scripts/uq_models/generated_model_script.py +589 -0
  17. workbench/model_scripts/uq_models/mapie.template +103 -6
  18. workbench/model_scripts/xgb_model/generated_model_script.py +4 -4
  19. workbench/repl/workbench_shell.py +3 -3
  20. workbench/utils/model_utils.py +10 -7
  21. workbench/utils/xgboost_model_utils.py +93 -34
  22. workbench/web_interface/components/plugin_unit_test.py +5 -2
  23. workbench/web_interface/components/plugins/model_details.py +2 -5
  24. {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/METADATA +1 -1
  25. {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/RECORD +29 -27
  26. {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/WHEEL +0 -0
  27. {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/entry_points.txt +0 -0
  28. {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/licenses/LICENSE +0 -0
  29. {workbench-0.8.193.dist-info → workbench-0.8.197.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,7 @@ import joblib
 import os
 import numpy as np
 import pandas as pd
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Dict
 
 # Template Placeholders
 TEMPLATE_PARAMS = {
@@ -26,6 +26,46 @@ TEMPLATE_PARAMS = {
 }
 
 
+def compute_confidence(
+    df: pd.DataFrame,
+    median_interval_width: float,
+    lower_q: str = "q_10",
+    upper_q: str = "q_90",
+    alpha: float = 1.0,
+    beta: float = 1.0,
+) -> pd.DataFrame:
+    """
+    Compute confidence scores (0.0 to 1.0) based on prediction interval width
+    and distance from median using exponential decay.
+
+    Args:
+        df: DataFrame with 'prediction', 'q_50', and quantile columns
+        median_interval_width: Pre-computed median interval width from training data
+        lower_q: Lower quantile column name (default: 'q_10')
+        upper_q: Upper quantile column name (default: 'q_90')
+        alpha: Weight for interval width term (default: 1.0)
+        beta: Weight for distance from median term (default: 1.0)
+
+    Returns:
+        DataFrame with added 'confidence' column
+    """
+    # Interval width
+    interval_width = (df[upper_q] - df[lower_q]).abs()
+
+    # Distance from median, normalized by interval width
+    distance_from_median = (df['prediction'] - df['q_50']).abs()
+    normalized_distance = distance_from_median / (interval_width + 1e-6)
+
+    # Cap the distance penalty at 1.0
+    normalized_distance = np.minimum(normalized_distance, 1.0)
+
+    # Confidence using exponential decay
+    interval_term = interval_width / median_interval_width
+    df['confidence'] = np.exp(-(alpha * interval_term + beta * normalized_distance))
+
+    return df
+
+
 # Function to check if dataframe is empty
 def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
     """
@@ -98,7 +138,7 @@ def convert_categorical_types(df: pd.DataFrame, features: list, category_mapping
 
 
 def decompress_features(
-    df: pd.DataFrame, features: List[str], compressed_features: List[str]
+    df: pd.DataFrame, features: List[str], compressed_features: List[str]
 ) -> Tuple[pd.DataFrame, List[str]]:
     """Prepare features for the model by decompressing bitstring features
 
@@ -302,6 +342,46 @@ if __name__ == "__main__":
         widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
         print(f" {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
 
+    # Compute normalization statistics for confidence calculation
+    print(f"\nComputing normalization statistics for confidence scores...")
+
+    # Create a temporary validation dataframe with predictions
+    temp_val_df = df_val.copy()
+    temp_val_df["prediction"] = xgb_model.predict(X_validate)
+
+    # Add all quantile predictions
+    for conf_level in confidence_levels:
+        model_name = f"mapie_{conf_level:.2f}"
+        model = mapie_models[model_name]
+        y_pred, y_pis = model.predict_interval(X_validate)
+
+        if conf_level == 0.50:
+            temp_val_df["q_25"] = y_pis[:, 0, 0]
+            temp_val_df["q_75"] = y_pis[:, 1, 0]
+            # y_pred is the median prediction
+            temp_val_df["q_50"] = y_pred
+        elif conf_level == 0.68:
+            temp_val_df["q_16"] = y_pis[:, 0, 0]
+            temp_val_df["q_84"] = y_pis[:, 1, 0]
+        elif conf_level == 0.80:
+            temp_val_df["q_10"] = y_pis[:, 0, 0]
+            temp_val_df["q_90"] = y_pis[:, 1, 0]
+        elif conf_level == 0.90:
+            temp_val_df["q_05"] = y_pis[:, 0, 0]
+            temp_val_df["q_95"] = y_pis[:, 1, 0]
+        elif conf_level == 0.95:
+            temp_val_df["q_025"] = y_pis[:, 0, 0]
+            temp_val_df["q_975"] = y_pis[:, 1, 0]
+
+    # Compute normalization stats using q_10 and q_90 (default range)
+    interval_width = (temp_val_df["q_90"] - temp_val_df["q_10"]).abs()
+    median_interval_width = float(interval_width.median())
+    print(f" Median interval width (q_10-q_90): {median_interval_width:.6f}")
+
+    # Save median interval width for confidence calculation
+    with open(os.path.join(args.model_dir, "median_interval_width.json"), "w") as fp:
+        json.dump(median_interval_width, fp)
+
     # Save the trained XGBoost model
     joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
 
@@ -365,11 +445,19 @@ def model_fn(model_dir) -> dict:
        with open(category_path) as fp:
            category_mappings = json.load(fp)
 
+    # Load median interval width for confidence calculation
+    median_interval_width = None
+    median_width_path = os.path.join(model_dir, "median_interval_width.json")
+    if os.path.exists(median_width_path):
+        with open(median_width_path) as fp:
+            median_interval_width = json.load(fp)
+
     return {
        "xgb_model": xgb_model,
        "mapie_models": mapie_models,
        "confidence_levels": config["confidence_levels"],
        "category_mappings": category_mappings,
+       "median_interval_width": median_interval_width,
    }
 
 
@@ -449,6 +537,8 @@ def predict_fn(df, models) -> pd.DataFrame:
        if conf_level == 0.50: # 50% CI
            df["q_25"] = y_pis[:, 0, 0]
            df["q_75"] = y_pis[:, 1, 0]
+           # y_pred is the median prediction
+           df["q_50"] = y_pred
        elif conf_level == 0.68: # 68% CI
            df["q_16"] = y_pis[:, 0, 0]
            df["q_84"] = y_pis[:, 1, 0]
@@ -462,14 +552,11 @@
            df["q_025"] = y_pis[:, 0, 0]
            df["q_975"] = y_pis[:, 1, 0]
 
-    # Add median (q_50) from XGBoost prediction
-    df["q_50"] = df["prediction"]
-
     # Calculate a pseudo-standard deviation from the 68% interval width
     df["prediction_std"] = (df["q_84"] - df["q_16"]).abs() / 2.0
 
     # Reorder the quantile columns for easier reading
-    quantile_cols = ["q_025", "q_05", "q_10", "q_16", "q_25", "q_75", "q_84", "q_90", "q_95", "q_975"]
+    quantile_cols = ["q_025", "q_05", "q_10", "q_16", "q_25", "q_50", "q_75", "q_84", "q_90", "q_95", "q_975"]
     other_cols = [col for col in df.columns if col not in quantile_cols]
     df = df[other_cols + quantile_cols]
 
@@ -489,4 +576,14 @@ def predict_fn(df, models) -> pd.DataFrame:
     df["q_95"] = np.maximum(df["q_95"], df["prediction"])
     df["q_975"] = np.maximum(df["q_975"], df["prediction"])
 
+    # Compute confidence scores using pre-computed normalization stats
+    df = compute_confidence(
+        df,
+        lower_q="q_10",
+        upper_q="q_90",
+        alpha=1.0,
+        beta=1.0,
+        median_interval_width=models["median_interval_width"],
+    )
+
     return df
@@ -28,11 +28,11 @@ from typing import List, Tuple
 
 # Template Parameters
 TEMPLATE_PARAMS = {
-    "model_type": "regressor",
-    "target": "solubility",
-    "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
+    "model_type": "classifier",
+    "target": "wine_class",
+    "features": ['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280_od315_of_diluted_wines', 'proline'],
     "compressed_features": [],
-    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-regression/training",
+    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/wine-classification/training",
     "train_all_data": False,
     "hyperparameters": {},
 }
@@ -525,7 +525,7 @@ class WorkbenchShell:
     def get_meta(self):
         return self.meta
 
-    def plot_manager(self, data, plot_type: str = "table", **kwargs):
+    def plot_manager(self, data, plot_type: str = "scatter", **kwargs):
         """Plot Manager for Workbench"""
         from workbench.web_interface.components.plugins import ag_table, graph_plot, scatter_plot
 
@@ -564,10 +564,10 @@ class WorkbenchShell:
 
         plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)
 
-        # Run the server and open in the browser
-        plugin_test.run()
+        # Open the browser and run the dash server
         url = f"http://127.0.0.1:{plugin_test.port}"
         webbrowser.open(url)
+        plugin_test.run()
 
 
 # Launch Shell Entry Point
@@ -113,9 +113,16 @@ def proximity_model_local(model: "Model"):
     fs = FeatureSet(model.get_input())
     id_column = fs.id_column
 
-    # Create the Proximity Model from our Training Data
-    df = model.training_view().pull_dataframe()
-    return Proximity(df, id_column, features, target, track_columns=features)
+    # Create the Proximity Model from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the Proximity Model
+    return Proximity(full_df, id_column, features, target, track_columns=features)
 
 
 def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
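The new `in_model` column is a plain boolean membership flag over the full FeatureSet. A minimal illustration with hypothetical ids (not package code):

```python
# Illustration only: how the "in_model" flag is derived via set membership.
import pandas as pd

full_df = pd.DataFrame({"id": [1, 2, 3, 4]})   # stand-in for fs.pull_dataframe()
model_df = pd.DataFrame({"id": [1, 3]})        # stand-in for the model's training view

full_df["in_model"] = full_df["id"].isin(set(model_df["id"]))
print(full_df.to_dict("records"))
# [{'id': 1, 'in_model': True}, {'id': 2, 'in_model': False},
#  {'id': 3, 'in_model': True}, {'id': 4, 'in_model': False}]
```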
@@ -165,9 +172,6 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
     """
     from workbench.api import Model, ModelType, FeatureSet # noqa: F401 (avoid circular import)
 
-    # Get the custom script path for the UQ model
-    script_path = get_custom_script_path("uq_models", "mapie.template")
-
     # Get Feature and Target Columns from the existing given Model
     features = model.features()
     target = model.target()
@@ -182,7 +186,6 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
         description=f"UQ Model for {model.name}",
         tags=["uq", model.name],
         train_all_data=train_all_data,
-        custom_script=script_path,
         custom_args={"id_column": fs.id_column, "track_columns": [target]},
     )
     return uq_model
@@ -7,12 +7,11 @@ import joblib
 import pickle
 import glob
 import awswrangler as wr
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Any
 import hashlib
 import pandas as pd
 import numpy as np
 import xgboost as xgb
-from typing import Dict, Any
 from sklearn.model_selection import KFold, StratifiedKFold
 from sklearn.metrics import (
     precision_recall_fscore_support,
@@ -20,13 +19,14 @@ from sklearn.metrics import (
     mean_absolute_error,
     r2_score,
     median_absolute_error,
+    roc_auc_score,
 )
 from scipy.stats import spearmanr
 from sklearn.preprocessing import LabelEncoder
 
 # Workbench Imports
 from workbench.utils.model_utils import load_category_mappings_from_s3, safe_extract_tarfile
-from workbench.utils.pandas_utils import convert_categorical_types
+from workbench.utils.pandas_utils import convert_categorical_types, expand_proba_column
 
 # Set up the log
 log = logging.getLogger("workbench")
@@ -258,7 +258,7 @@ def leaf_stats(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
     return result_df
 
 
-def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[str, Any], pd.DataFrame]:
+def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.DataFrame, pd.DataFrame]:
     """
     Performs K-fold cross-validation with detailed metrics.
     Args:
@@ -266,10 +266,8 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         nfolds: Number of folds for cross-validation (default is 5)
     Returns:
         Tuple of:
-        - Dictionary containing:
-            - folds: Dictionary of formatted strings for each fold
-            - summary_metrics: Summary metrics across folds
-        - DataFrame with columns: id, target, prediction (out-of-fold predictions for all samples)
+        - DataFrame with per-class metrics (and 'all' row for overall metrics)
+        - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
     """
     from workbench.api import FeatureSet
 
@@ -278,7 +276,7 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
     loaded_model = xgboost_model_from_s3(model_artifact_uri)
     if loaded_model is None:
         log.error("No XGBoost model found in the artifact.")
-        return {}, pd.DataFrame()
+        return pd.DataFrame(), pd.DataFrame()
 
     # Check if we got a full sklearn model or need to create one
     if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
@@ -304,7 +302,7 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         xgb_model._Booster = loaded_model
     else:
         log.error(f"Unexpected model type: {type(loaded_model)}")
-        return {}, pd.DataFrame()
+        return pd.DataFrame(), pd.DataFrame()
 
     # Prepare data
     fs = FeatureSet(workbench_model.get_input())
@@ -335,12 +333,12 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         y_for_cv = y
 
     # Prepare KFold
+    # Note: random_state=42 seems to not actually give us reproducible results
     kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
 
     # Initialize results collection
     fold_metrics = []
-    predictions_df = pd.DataFrame({id_col: ids, target_col: y}) # Keep original values
-    # Note: 'prediction' column will be created automatically with correct dtype
+    predictions_df = pd.DataFrame({id_col: ids, target_col: y})
 
     # Perform cross-validation
     for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
@@ -355,6 +353,8 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         val_indices = X_val.index
         if is_classifier:
             predictions_df.loc[val_indices, "prediction"] = label_encoder.inverse_transform(preds.astype(int))
+            y_proba = xgb_model.predict_proba(X_val)
+            predictions_df.loc[val_indices, "pred_proba"] = pd.Series(y_proba.tolist(), index=val_indices)
         else:
             predictions_df.loc[val_indices, "prediction"] = preds
 
@@ -362,10 +362,34 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
         if is_classifier:
             y_val_orig = label_encoder.inverse_transform(y_val)
             preds_orig = label_encoder.inverse_transform(preds.astype(int))
+
+            # Overall weighted metrics
             prec, rec, f1, _ = precision_recall_fscore_support(
                 y_val_orig, preds_orig, average="weighted", zero_division=0
             )
-            fold_metrics.append({"fold": fold_idx, "precision": prec, "recall": rec, "fscore": f1})
+
+            # Per-class F1
+            prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
+                y_val_orig, preds_orig, average=None, zero_division=0, labels=label_encoder.classes_
+            )
+
+            # ROC-AUC (overall and per-class)
+            roc_auc_overall = roc_auc_score(y_val, y_proba, multi_class="ovr", average="macro")
+            roc_auc_per_class = roc_auc_score(y_val, y_proba, multi_class="ovr", average=None)
+
+            fold_metrics.append(
+                {
+                    "fold": fold_idx,
+                    "precision": prec,
+                    "recall": rec,
+                    "f1": f1,
+                    "roc_auc": roc_auc_overall,
+                    "precision_per_class": prec_per_class,
+                    "recall_per_class": rec_per_class,
+                    "f1_per_class": f1_per_class,
+                    "roc_auc_per_class": roc_auc_per_class,
+                }
+            )
         else:
             spearman_corr, _ = spearmanr(y_val, preds)
             fold_metrics.append(
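The per-class rows in the new metrics output come from `average=None`, which makes the sklearn metric functions return one score per class in the order of `labels`. A small illustration with made-up labels (not package code):

```python
# Sketch: average=None returns per-class arrays, one entry per label in order.
from sklearn.metrics import precision_recall_fscore_support

y_true = ["a", "a", "b", "c", "c", "c"]
y_pred = ["a", "b", "b", "c", "c", "a"]

prec, rec, f1, support = precision_recall_fscore_support(
    y_true, y_pred, average=None, zero_division=0, labels=["a", "b", "c"]
)
print(f1.round(2))   # one F1 per class, ordered [a, b, c]
print(support)       # per-class sample counts: [2 1 3]
```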
@@ -379,32 +403,67 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[st
                 }
             )
 
-    # Calculate summary metrics (mean ± std)
+    # Calculate summary metrics
     fold_df = pd.DataFrame(fold_metrics)
-    metric_names = ["precision", "recall", "fscore"] if is_classifier else ["rmse", "mae", "medae", "r2", "spearmanr"]
-    summary_metrics = {metric: f"{fold_df[metric].mean():.3f} ±{fold_df[metric].std():.3f}" for metric in metric_names}
 
-    # Format fold results for display
-    formatted_folds = {}
-    for _, row in fold_df.iterrows():
-        fold_key = f"Fold {int(row['fold'])}"
-        if is_classifier:
-            formatted_folds[fold_key] = (
-                f"precision: {row['precision']:.3f} " f"recall: {row['recall']:.3f} " f"fscore: {row['fscore']:.3f}"
-            )
-        else:
-            formatted_folds[fold_key] = (
-                f"rmse: {row['rmse']:.3f} "
-                f"mae: {row['mae']:.3f} "
-                f"medae: {row['medae']:.3f} "
-                f"r2: {row['r2']:.3f} "
-                f"spearmanr: {row['spearmanr']:.3f}"
+    if is_classifier:
+        # Expand the *_proba columns into separate columns for easier handling
+        predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
+
+        # Build per-class metrics DataFrame
+        metric_rows = []
+
+        # Per-class rows
+        for idx, class_name in enumerate(label_encoder.classes_):
+            prec_scores = np.array([fold["precision_per_class"][idx] for fold in fold_metrics])
+            rec_scores = np.array([fold["recall_per_class"][idx] for fold in fold_metrics])
+            f1_scores = np.array([fold["f1_per_class"][idx] for fold in fold_metrics])
+            roc_auc_scores = np.array([fold["roc_auc_per_class"][idx] for fold in fold_metrics])
+
+            y_orig = label_encoder.inverse_transform(y_for_cv)
+            support = int((y_orig == class_name).sum())
+
+            metric_rows.append(
+                {
+                    "class": class_name,
+                    "precision": prec_scores.mean(),
+                    "recall": rec_scores.mean(),
+                    "f1": f1_scores.mean(),
+                    "roc_auc": roc_auc_scores.mean(),
+                    "support": support,
+                }
             )
 
-    # Build return dictionary
-    metrics_dict = {"summary_metrics": summary_metrics, "folds": formatted_folds}
+        # Overall 'all' row
+        metric_rows.append(
+            {
+                "class": "all",
+                "precision": fold_df["precision"].mean(),
+                "recall": fold_df["recall"].mean(),
+                "f1": fold_df["f1"].mean(),
+                "roc_auc": fold_df["roc_auc"].mean(),
+                "support": len(y_for_cv),
+            }
+        )
+
+        metrics_df = pd.DataFrame(metric_rows)
+
+    else:
+        # Regression metrics
+        metrics_df = pd.DataFrame(
+            [
+                {
+                    "rmse": fold_df["rmse"].mean(),
+                    "mae": fold_df["mae"].mean(),
+                    "medae": fold_df["medae"].mean(),
+                    "r2": fold_df["r2"].mean(),
+                    "spearmanr": fold_df["spearmanr"].mean(),
+                    "support": len(y_for_cv),
+                }
+            ]
+        )
 
-    return metrics_dict, predictions_df
+    return metrics_df, predictions_df
 
 
 def leave_one_out_inference(workbench_model: Any) -> pd.DataFrame:
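With this change `cross_fold_inference` no longer returns a dictionary of formatted fold strings; it returns two DataFrames. A hedged sketch of consuming the new shape, assuming the function lives in `workbench.utils.xgboost_model_utils` (per the file list above) and that "wine-classification" is a hypothetical existing Workbench model:

```python
# Sketch: consuming the new (metrics_df, predictions_df) return value.
from workbench.api import Model
from workbench.utils.xgboost_model_utils import cross_fold_inference

model = Model("wine-classification")                      # hypothetical model name
metrics_df, predictions_df = cross_fold_inference(model, nfolds=5)

# metrics_df: one row per class plus an "all" summary row, with columns
#   class, precision, recall, f1, roc_auc, support (regression gets a single-row frame).
print(metrics_df)

# predictions_df: out-of-fold predictions for every sample
#   (id, target, prediction, plus *_proba columns for classifiers).
print(predictions_df.head())
```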
@@ -156,10 +156,13 @@ class PluginUnitTest:
         """Run the Dash server for the plugin, handling common errors gracefully."""
         while self.is_port_in_use(self.port):
             log.info(f"Port {self.port} is in use. Trying the next one...")
-            self.port += 1 # Increment the port number until an available one is found
+            self.port += 1
 
         log.info(f"Starting Dash server on port {self.port}...")
-        self.app.run(debug=True, use_reloader=False, port=self.port)
+        try:
+            self.app.run(debug=True, use_reloader=False, port=self.port)
+        except KeyboardInterrupt:
+            log.info("Shutting down Dash server...")
 
     @staticmethod
     def is_port_in_use(port):
@@ -45,8 +45,6 @@ class ModelDetails(PluginInterface):
                 html.H5(children="Inference Metrics", style={"marginTop": "20px"}),
                 dcc.Dropdown(id=f"{self.component_id}-dropdown", className="dropdown"),
                 dcc.Markdown(id=f"{self.component_id}-metrics"),
-                html.H5(children="Cross Fold Metrics", style={"marginTop": "20px"}),
-                dcc.Markdown(id=f"{self.component_id}-cross-metrics", dangerously_allow_html=True),
             ],
         )
 
@@ -57,7 +55,6 @@ class ModelDetails(PluginInterface):
             (f"{self.component_id}-dropdown", "options"),
            (f"{self.component_id}-dropdown", "value"),
            (f"{self.component_id}-metrics", "children"),
-           (f"{self.component_id}-cross-metrics", "children"),
        ]
        self.signals = [(f"{self.component_id}-dropdown", "value")]
 
@@ -84,10 +81,9 @@
         # Populate the inference runs dropdown
         inference_runs, default_run = self.get_inference_runs()
         metrics = self.inference_metrics(default_run)
-        cross_metrics = self.cross_metrics()
 
         # Return the updated property values for the plugin
-        return [header, details, inference_runs, default_run, metrics, cross_metrics]
+        return [header, details, inference_runs, default_run, metrics]
 
     def register_internal_callbacks(self):
         @callback(
@@ -225,6 +221,7 @@
 
     def cross_metrics(self) -> str:
         # Get cross fold metrics if they exist
+        # Note: Currently not used since we show cross fold metrics in the dropdown
         model_name = self.current_model.name
         cross_fold_data = self.params.get(f"/workbench/models/{model_name}/inference/cross_fold", warn=False)
         if not cross_fold_data:
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: workbench
-Version: 0.8.193
+Version: 0.8.197
 Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
 Author-email: SuperCowPowers LLC <support@supercowpowers.com>
 License: MIT License