workbench 0.8.205__py3-none-any.whl → 0.8.212__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. workbench/algorithms/models/noise_model.py +388 -0
  2. workbench/api/endpoint.py +3 -6
  3. workbench/api/feature_set.py +1 -1
  4. workbench/api/model.py +5 -11
  5. workbench/cached/cached_model.py +4 -4
  6. workbench/core/artifacts/endpoint_core.py +57 -145
  7. workbench/core/artifacts/model_core.py +21 -19
  8. workbench/core/transforms/features_to_model/features_to_model.py +2 -2
  9. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +1 -1
  10. workbench/model_script_utils/model_script_utils.py +335 -0
  11. workbench/model_script_utils/pytorch_utils.py +395 -0
  12. workbench/model_script_utils/uq_harness.py +278 -0
  13. workbench/model_scripts/chemprop/chemprop.template +289 -666
  14. workbench/model_scripts/chemprop/generated_model_script.py +292 -669
  15. workbench/model_scripts/chemprop/model_script_utils.py +335 -0
  16. workbench/model_scripts/chemprop/requirements.txt +2 -10
  17. workbench/model_scripts/pytorch_model/generated_model_script.py +355 -612
  18. workbench/model_scripts/pytorch_model/model_script_utils.py +335 -0
  19. workbench/model_scripts/pytorch_model/pytorch.template +350 -607
  20. workbench/model_scripts/pytorch_model/pytorch_utils.py +395 -0
  21. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  22. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  23. workbench/model_scripts/script_generation.py +2 -5
  24. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  25. workbench/model_scripts/xgb_model/generated_model_script.py +349 -412
  26. workbench/model_scripts/xgb_model/model_script_utils.py +335 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  28. workbench/model_scripts/xgb_model/xgb_model.template +344 -407
  29. workbench/scripts/training_test.py +85 -0
  30. workbench/utils/chemprop_utils.py +18 -656
  31. workbench/utils/metrics_utils.py +172 -0
  32. workbench/utils/model_utils.py +104 -47
  33. workbench/utils/pytorch_utils.py +32 -472
  34. workbench/utils/xgboost_local_crossfold.py +267 -0
  35. workbench/utils/xgboost_model_utils.py +49 -356
  36. workbench/web_interface/components/plugins/model_details.py +30 -68
  37. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/METADATA +5 -5
  38. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/RECORD +42 -31
  39. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/entry_points.txt +1 -0
  40. workbench/model_scripts/uq_models/mapie.template +0 -605
  41. workbench/model_scripts/uq_models/requirements.txt +0 -1
  42. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/WHEEL +0 -0
  43. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/top_level.txt +0 -0
@@ -1,32 +1,23 @@
1
1
  """XGBoost Model Utilities"""
2
2
 
3
+ import glob
4
+ import hashlib
3
5
  import logging
4
6
  import os
5
- import tempfile
6
- import joblib
7
7
  import pickle
8
- import glob
8
+ import tempfile
9
+ from typing import Any, List, Optional, Tuple
10
+
9
11
  import awswrangler as wr
10
- from typing import Optional, List, Tuple, Any
11
- import hashlib
12
+ import joblib
12
13
  import pandas as pd
13
- import numpy as np
14
14
  import xgboost as xgb
15
- from sklearn.model_selection import KFold, StratifiedKFold
16
- from sklearn.metrics import (
17
- precision_recall_fscore_support,
18
- mean_squared_error,
19
- mean_absolute_error,
20
- r2_score,
21
- median_absolute_error,
22
- roc_auc_score,
23
- )
24
- from scipy.stats import spearmanr
25
- from sklearn.preprocessing import LabelEncoder
26
15
 
27
16
  # Workbench Imports
17
+ from workbench.utils.aws_utils import pull_s3_data
18
+ from workbench.utils.metrics_utils import compute_metrics_from_predictions
28
19
  from workbench.utils.model_utils import load_category_mappings_from_s3, safe_extract_tarfile
29
- from workbench.utils.pandas_utils import convert_categorical_types, expand_proba_column
20
+ from workbench.utils.pandas_utils import convert_categorical_types
30
21
 
31
22
  # Set up the log
32
23
  log = logging.getLogger("workbench")
@@ -258,327 +249,45 @@ def leaf_stats(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
258
249
  return result_df
259
250
 
260
251
 
261
- def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.DataFrame, pd.DataFrame]:
262
- """
263
- Performs K-fold cross-validation with detailed metrics.
252
+ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
253
+ """Pull cross-validation results from AWS training artifacts.
254
+
255
+ This retrieves the validation predictions saved during model training and
256
+ computes metrics directly from them. For XGBoost models trained with
257
+ n_folds > 1, these are out-of-fold predictions from k-fold cross-validation.
258
+
264
259
  Args:
265
260
  workbench_model: Workbench model object
266
- nfolds: Number of folds for cross-validation (default is 5)
261
+
267
262
  Returns:
268
263
  Tuple of:
269
- - DataFrame with per-class metrics (and 'all' row for overall metrics)
270
- - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
264
+ - DataFrame with computed metrics
265
+ - DataFrame with validation predictions
271
266
  """
272
- from workbench.api import FeatureSet
273
-
274
- # Load model
275
- model_artifact_uri = workbench_model.model_data_url()
276
- loaded_model = xgboost_model_from_s3(model_artifact_uri)
277
- if loaded_model is None:
278
- log.error("No XGBoost model found in the artifact.")
279
- return pd.DataFrame(), pd.DataFrame()
280
-
281
- # Check if we got a full sklearn model or need to create one
282
- if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
283
- is_classifier = isinstance(loaded_model, xgb.XGBClassifier)
284
-
285
- # Get the model's hyperparameters and ensure enable_categorical=True
286
- params = loaded_model.get_params()
287
- params["enable_categorical"] = True
288
-
289
- # Create new model with same params but enable_categorical=True
290
- if is_classifier:
291
- xgb_model = xgb.XGBClassifier(**params)
292
- else:
293
- xgb_model = xgb.XGBRegressor(**params)
294
-
295
- elif isinstance(loaded_model, xgb.Booster):
296
- # Legacy: got a booster, need to wrap it
297
- log.warning("Deprecated: Loaded model is a Booster, wrapping in sklearn model.")
298
- is_classifier = workbench_model.model_type.value == "classifier"
299
- xgb_model = (
300
- xgb.XGBClassifier(enable_categorical=True) if is_classifier else xgb.XGBRegressor(enable_categorical=True)
301
- )
302
- xgb_model._Booster = loaded_model
303
- else:
304
- log.error(f"Unexpected model type: {type(loaded_model)}")
305
- return pd.DataFrame(), pd.DataFrame()
306
-
307
- # Prepare data
308
- fs = FeatureSet(workbench_model.get_input())
309
- df = workbench_model.training_view().pull_dataframe()
310
-
311
- # Extract sample weights if present
312
- sample_weights = df.get("sample_weight")
313
- if sample_weights is not None:
314
- log.info(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}")
315
-
316
- # Get columns
317
- id_col = fs.id_column
318
- target_col = workbench_model.target()
319
- feature_cols = workbench_model.features()
320
- print(f"Target column: {target_col}")
321
- print(f"Feature columns: {len(feature_cols)} features")
322
-
323
- # Convert string[python] to object, then to category for XGBoost compatibility
324
- for col in feature_cols:
325
- if pd.api.types.is_string_dtype(df[col]):
326
- df[col] = df[col].astype("object").astype("category")
327
-
328
- X = df[feature_cols]
329
- y = df[target_col]
330
- ids = df[id_col]
331
-
332
- # Encode target if classifier
333
- label_encoder = LabelEncoder() if is_classifier else None
334
- if label_encoder:
335
- y_encoded = label_encoder.fit_transform(y)
336
- y_for_cv = pd.Series(y_encoded, index=y.index, name=target_col)
337
- else:
338
- y_for_cv = y
339
-
340
- # Prepare KFold
341
- kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
342
-
343
- # Initialize results collection
344
- fold_metrics = []
345
- predictions_df = pd.DataFrame({id_col: ids, target_col: y})
346
-
347
- # Perform cross-validation
348
- for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
349
- X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
350
- y_train, y_val = y_for_cv.iloc[train_idx], y_for_cv.iloc[val_idx]
351
-
352
- # Get sample weights for training fold
353
- weights_train = sample_weights.iloc[train_idx] if sample_weights is not None else None
354
-
355
- # Train and predict
356
- xgb_model.fit(X_train, y_train, sample_weight=weights_train)
357
- preds = xgb_model.predict(X_val)
358
-
359
- # Store predictions (decode if classifier)
360
- val_indices = X_val.index
361
- if is_classifier:
362
- predictions_df.loc[val_indices, "prediction"] = label_encoder.inverse_transform(preds.astype(int))
363
- y_proba = xgb_model.predict_proba(X_val)
364
- predictions_df.loc[val_indices, "pred_proba"] = pd.Series(y_proba.tolist(), index=val_indices)
365
- else:
366
- predictions_df.loc[val_indices, "prediction"] = preds
367
-
368
- # Calculate fold metrics
369
- if is_classifier:
370
- y_val_orig = label_encoder.inverse_transform(y_val)
371
- preds_orig = label_encoder.inverse_transform(preds.astype(int))
372
-
373
- # Overall weighted metrics
374
- prec, rec, f1, _ = precision_recall_fscore_support(
375
- y_val_orig, preds_orig, average="weighted", zero_division=0
376
- )
377
-
378
- # Per-class F1
379
- prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
380
- y_val_orig, preds_orig, average=None, zero_division=0, labels=label_encoder.classes_
381
- )
382
-
383
- # ROC-AUC (overall and per-class)
384
- roc_auc_overall = roc_auc_score(y_val, y_proba, multi_class="ovr", average="macro")
385
- roc_auc_per_class = roc_auc_score(y_val, y_proba, multi_class="ovr", average=None)
386
-
387
- fold_metrics.append(
388
- {
389
- "fold": fold_idx,
390
- "precision": prec,
391
- "recall": rec,
392
- "f1": f1,
393
- "roc_auc": roc_auc_overall,
394
- "precision_per_class": prec_per_class,
395
- "recall_per_class": rec_per_class,
396
- "f1_per_class": f1_per_class,
397
- "roc_auc_per_class": roc_auc_per_class,
398
- }
399
- )
400
- else:
401
- spearman_corr, _ = spearmanr(y_val, preds)
402
- fold_metrics.append(
403
- {
404
- "fold": fold_idx,
405
- "rmse": np.sqrt(mean_squared_error(y_val, preds)),
406
- "mae": mean_absolute_error(y_val, preds),
407
- "medae": median_absolute_error(y_val, preds),
408
- "r2": r2_score(y_val, preds),
409
- "spearmanr": spearman_corr,
410
- }
411
- )
412
-
413
- # Calculate summary metrics
414
- fold_df = pd.DataFrame(fold_metrics)
415
-
416
- if is_classifier:
417
- # Expand the *_proba columns into separate columns for easier handling
418
- predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
419
-
420
- # Build per-class metrics DataFrame
421
- metric_rows = []
422
-
423
- # Per-class rows
424
- for idx, class_name in enumerate(label_encoder.classes_):
425
- prec_scores = np.array([fold["precision_per_class"][idx] for fold in fold_metrics])
426
- rec_scores = np.array([fold["recall_per_class"][idx] for fold in fold_metrics])
427
- f1_scores = np.array([fold["f1_per_class"][idx] for fold in fold_metrics])
428
- roc_auc_scores = np.array([fold["roc_auc_per_class"][idx] for fold in fold_metrics])
429
-
430
- y_orig = label_encoder.inverse_transform(y_for_cv)
431
- support = int((y_orig == class_name).sum())
432
-
433
- metric_rows.append(
434
- {
435
- "class": class_name,
436
- "precision": prec_scores.mean(),
437
- "recall": rec_scores.mean(),
438
- "f1": f1_scores.mean(),
439
- "roc_auc": roc_auc_scores.mean(),
440
- "support": support,
441
- }
442
- )
443
-
444
- # Overall 'all' row
445
- metric_rows.append(
446
- {
447
- "class": "all",
448
- "precision": fold_df["precision"].mean(),
449
- "recall": fold_df["recall"].mean(),
450
- "f1": fold_df["f1"].mean(),
451
- "roc_auc": fold_df["roc_auc"].mean(),
452
- "support": len(y_for_cv),
453
- }
454
- )
455
-
456
- metrics_df = pd.DataFrame(metric_rows)
457
-
458
- else:
459
- # Regression metrics
460
- metrics_df = pd.DataFrame(
461
- [
462
- {
463
- "rmse": fold_df["rmse"].mean(),
464
- "mae": fold_df["mae"].mean(),
465
- "medae": fold_df["medae"].mean(),
466
- "r2": fold_df["r2"].mean(),
467
- "spearmanr": fold_df["spearmanr"].mean(),
468
- "support": len(y_for_cv),
469
- }
470
- ]
471
- )
267
+ # Get the validation predictions from S3
268
+ s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
269
+ predictions_df = pull_s3_data(s3_path)
472
270
 
473
- return metrics_df, predictions_df
271
+ if predictions_df is None:
272
+ raise ValueError(f"No validation predictions found at {s3_path}")
474
273
 
274
+ log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
475
275
 
476
- def leave_one_out_inference(workbench_model: Any) -> pd.DataFrame:
477
- """
478
- Performs leave-one-out cross-validation (parallelized).
479
- For datasets > 1000 rows, first identifies top 100 worst predictions via 10-fold CV,
480
- then performs true leave-one-out on those 100 samples.
481
- Each model trains on ALL data except one sample.
482
- """
483
- from workbench.api import FeatureSet
484
- from joblib import Parallel, delayed
485
- from tqdm import tqdm
276
+ # Compute metrics from predictions
277
+ target = workbench_model.target()
278
+ class_labels = workbench_model.class_labels()
486
279
 
487
- def train_and_predict_one(model_params, is_classifier, X, y, train_idx, val_idx):
488
- """Train on train_idx, predict on val_idx."""
489
- model = xgb.XGBClassifier(**model_params) if is_classifier else xgb.XGBRegressor(**model_params)
490
- model.fit(X[train_idx], y[train_idx])
491
- return model.predict(X[val_idx])[0]
492
-
493
- # Load model and get params
494
- model_artifact_uri = workbench_model.model_data_url()
495
- loaded_model = xgboost_model_from_s3(model_artifact_uri)
496
- if loaded_model is None:
497
- log.error("No XGBoost model found in the artifact.")
498
- return pd.DataFrame()
499
-
500
- if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
501
- is_classifier = isinstance(loaded_model, xgb.XGBClassifier)
502
- model_params = loaded_model.get_params()
503
- elif isinstance(loaded_model, xgb.Booster):
504
- log.warning("Deprecated: Loaded model is a Booster, wrapping in sklearn model.")
505
- is_classifier = workbench_model.model_type.value == "classifier"
506
- model_params = {"enable_categorical": True}
507
- else:
508
- log.error(f"Unexpected model type: {type(loaded_model)}")
509
- return pd.DataFrame()
510
-
511
- # Load and prepare data
512
- fs = FeatureSet(workbench_model.get_input())
513
- df = workbench_model.training_view().pull_dataframe()
514
- id_col = fs.id_column
515
- target_col = workbench_model.target()
516
- feature_cols = workbench_model.features()
517
-
518
- # Convert string[python] to object, then to category for XGBoost compatibility
519
- # This avoids XGBoost's issue with pandas 2.x string[python] dtype in categorical categories
520
- for col in feature_cols:
521
- if pd.api.types.is_string_dtype(df[col]):
522
- # Double conversion: string[python] -> object -> category
523
- df[col] = df[col].astype("object").astype("category")
524
-
525
- # Determine which samples to run LOO on
526
- if len(df) > 1000:
527
- log.important(f"Dataset has {len(df)} rows. Running 10-fold CV to identify top 1000 worst predictions...")
528
- _, predictions_df = cross_fold_inference(workbench_model, nfolds=10)
529
- predictions_df["residual_abs"] = np.abs(predictions_df[target_col] - predictions_df["prediction"])
530
- worst_samples = predictions_df.nlargest(1000, "residual_abs")
531
- worst_ids = worst_samples[id_col].values
532
- loo_indices = df[df[id_col].isin(worst_ids)].index.values
533
- log.important(f"Running leave-one-out CV on 1000 worst samples. Each model trains on {len(df)-1} rows...")
280
+ if target in predictions_df.columns and "prediction" in predictions_df.columns:
281
+ metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
534
282
  else:
535
- log.important(f"Running leave-one-out CV on all {len(df)} samples...")
536
- loo_indices = df.index.values
537
-
538
- # Prepare full dataset for training
539
- X_full = df[feature_cols].values
540
- y_full = df[target_col].values
541
-
542
- # Encode target if classifier
543
- label_encoder = LabelEncoder() if is_classifier else None
544
- if label_encoder:
545
- y_full = label_encoder.fit_transform(y_full)
546
-
547
- # Generate LOO splits
548
- splits = []
549
- for loo_idx in loo_indices:
550
- train_idx = np.delete(np.arange(len(X_full)), loo_idx)
551
- val_idx = np.array([loo_idx])
552
- splits.append((train_idx, val_idx))
553
-
554
- # Parallel execution
555
- predictions = Parallel(n_jobs=4)(
556
- delayed(train_and_predict_one)(model_params, is_classifier, X_full, y_full, train_idx, val_idx)
557
- for train_idx, val_idx in tqdm(splits, desc="LOO CV")
558
- )
559
-
560
- # Build results dataframe
561
- predictions_array = np.array(predictions)
562
- if label_encoder:
563
- predictions_array = label_encoder.inverse_transform(predictions_array.astype(int))
564
-
565
- predictions_df = pd.DataFrame(
566
- {
567
- id_col: df.loc[loo_indices, id_col].values,
568
- target_col: df.loc[loo_indices, target_col].values,
569
- "prediction": predictions_array,
570
- }
571
- )
283
+ metrics_df = pd.DataFrame()
572
284
 
573
- predictions_df["residual_abs"] = np.abs(predictions_df[target_col] - predictions_df["prediction"])
574
-
575
- return predictions_df
285
+ return metrics_df, predictions_df
576
286
 
577
287
 
578
288
  if __name__ == "__main__":
579
289
  """Exercise the Model Utilities"""
580
290
  from workbench.api import Model
581
- from pprint import pprint
582
291
 
583
292
  # Test the XGBoost model loading and feature importance
584
293
  model = Model("abalone-regression")
@@ -594,38 +303,22 @@ if __name__ == "__main__":
594
303
  print(f"Model parameters: {xgb_model.get_params()}")
595
304
  print(f"enable_categorical: {xgb_model.enable_categorical}")
596
305
 
597
- # Test with UQ Model
598
- uq_model = Model("aqsol-uq")
599
- _xgb_model = xgboost_model_from_s3(uq_model.model_data_url())
600
-
601
- print("\n=== CROSS FOLD REGRESSION EXAMPLE ===")
306
+ print("\n=== PULL CV RESULTS EXAMPLE ===")
602
307
  model = Model("abalone-regression")
603
- results, df = cross_fold_inference(model)
604
- pprint(results)
605
- print(df.head())
606
-
607
- print("\n=== CROSS FOLD CLASSIFICATION EXAMPLE ===")
308
+ metrics_df, predictions_df = pull_cv_results(model)
309
+ print(f"\nMetrics:\n{metrics_df}")
310
+ print(f"\nPredictions shape: {predictions_df.shape}")
311
+ print(f"Predictions columns: {predictions_df.columns.tolist()}")
312
+ print(predictions_df.head())
313
+
314
+ # Test on a Classifier model
315
+ print("\n=== CLASSIFIER MODEL TEST ===")
608
316
  model = Model("wine-classification")
609
- results, df = cross_fold_inference(model)
610
- pprint(results)
611
- print(df.head())
612
-
613
- # Test XGBoost add_leaf_hash
614
- """
615
- input_df = FeatureSet(model.get_input()).pull_dataframe()
616
- leaf_df = add_leaf_hash(model, input_df)
617
- print("DataFrame with Leaf Hash:")
618
- print(leaf_df)
619
-
620
- # Okay, we're going to copy row 3 and insert it into row 7 to make sure the leaf_hash is the same
621
- input_df.iloc[7] = input_df.iloc[3]
622
- print("DataFrame with Leaf Hash (3 and 7 should match):")
623
- leaf_df = add_leaf_hash(model, input_df)
624
- print(leaf_df)
625
-
626
- # Test leaf_stats
627
- target_col = "class_number_of_rings"
628
- stats_df = leaf_stats(leaf_df, target_col)
629
- print("DataFrame with Leaf Statistics:")
630
- print(stats_df)
631
- """
317
+ features = feature_importance(model)
318
+ print("Feature Importance:")
319
+ print(features)
320
+ metrics_df, predictions_df = pull_cv_results(model)
321
+ print(f"\nMetrics:\n{metrics_df}")
322
+ print(f"\nPredictions shape: {predictions_df.shape}")
323
+ print(f"Predictions columns: {predictions_df.columns.tolist()}")
324
+ print(predictions_df.head())
@@ -41,7 +41,7 @@ class ModelDetails(PluginInterface):
41
41
  id=self.component_id,
42
42
  children=[
43
43
  html.H4(id=f"{self.component_id}-header", children="Model: Loading..."),
44
- dcc.Markdown(id=f"{self.component_id}-summary"),
44
+ dcc.Markdown(id=f"{self.component_id}-summary", dangerously_allow_html=True),
45
45
  html.H5(children="Inference Metrics", style={"marginTop": "20px"}),
46
46
  dcc.Dropdown(id=f"{self.component_id}-dropdown", className="dropdown"),
47
47
  dcc.Markdown(id=f"{self.component_id}-metrics"),
@@ -106,63 +106,37 @@ class ModelDetails(PluginInterface):
106
106
  Returns:
107
107
  str: A markdown string
108
108
  """
109
-
110
- # Get these fields from the model
111
- show_fields = [
112
- "health_tags",
113
- "input",
114
- "workbench_registered_endpoints",
115
- "workbench_model_type",
116
- "workbench_model_target",
117
- "workbench_model_features",
118
- "param_meta",
119
- "workbench_tags",
120
- ]
121
-
122
- # Construct the markdown string
123
109
  summary = self.current_model.summary()
124
110
  markdown = ""
125
- for key in show_fields:
126
-
127
- # Special case for the health tags
128
- if key == "health_tags":
129
- markdown += health_tag_markdown(summary.get(key, []))
130
- continue
131
-
132
- # Special case for the features
133
- if key == "workbench_model_features":
134
- value = summary.get(key, [])
135
- key = "features"
136
- value = f"({len(value)}) {', '.join(value)[:100]}..."
137
- markdown += f"**{key}:** {value} \n"
138
- continue
139
-
140
- # Special case for Parameter Store Metadata
141
- if key == "param_meta":
142
- model_name = summary["name"]
143
- meta_data = self.params.get(f"/workbench/models/{model_name}/meta", warn=False)
144
- if meta_data:
145
- markdown += dict_to_markdown(meta_data, title="Additional Metadata")
146
- continue
147
-
148
- # Special case for tags
149
- if key == "workbench_tags":
150
- tags = summary.get(key, "")
151
- markdown += tags_to_markdown(tags)
152
- continue
153
-
154
- # Get the value
155
- value = summary.get(key, "-")
156
-
157
- # If the value is a list, convert it to a comma-separated string
158
- if isinstance(value, list):
159
- value = ", ".join(value)
160
-
161
- # Chop off the "workbench_" prefix
162
- key = key.replace("workbench_", "")
163
-
164
- # Add to markdown string
165
- markdown += f"**{key}:** {value} \n"
111
+
112
+ # Health tags
113
+ markdown += health_tag_markdown(summary.get("health_tags", []))
114
+
115
+ # Simple fields
116
+ markdown += f"**input:** {summary.get('input', '-')} \n"
117
+ endpoints = ", ".join(summary.get("workbench_registered_endpoints", []))
118
+ markdown += f"**registered_endpoints:** {endpoints or '-'} \n"
119
+ markdown += f"**model_type:** {summary.get('workbench_model_type', '-')} \n"
120
+ markdown += f"**model_target:** {summary.get('workbench_model_target', '-')} \n"
121
+
122
+ # Features (truncated)
123
+ features = summary.get("workbench_model_features", [])
124
+ features_str = f"({len(features)}) {', '.join(features)[:100]}..."
125
+ markdown += f"**features:** {features_str} \n"
126
+
127
+ # Parameter Store metadata
128
+ model_name = summary["name"]
129
+ meta_data = self.params.get(f"/workbench/models/{model_name}/meta", warn=False)
130
+ if meta_data:
131
+ markdown += dict_to_markdown(meta_data, title="Additional Metadata")
132
+
133
+ # Tags
134
+ markdown += tags_to_markdown(summary.get("workbench_tags", "")) + " \n"
135
+
136
+ # Hyperparameters
137
+ hyperparams = summary.get("hyperparameters")
138
+ if hyperparams and isinstance(hyperparams, dict):
139
+ markdown += dict_to_collapsible_html(hyperparams, title="Hyperparameters", collapse_all=True)
166
140
 
167
141
  return markdown
168
142
 
@@ -219,18 +193,6 @@ class ModelDetails(PluginInterface):
219
193
  markdown += dict_to_markdown(inference_data, title="Additional Inference Metrics")
220
194
  return markdown
221
195
 
222
- def cross_metrics(self) -> str:
223
- # Get cross fold metrics if they exist
224
- # Note: Currently not used since we show cross fold metrics in the dropdown
225
- model_name = self.current_model.name
226
- cross_fold_data = self.params.get(f"/workbench/models/{model_name}/inference/cross_fold", warn=False)
227
- if not cross_fold_data:
228
- return "**No Cross Fold Data**"
229
-
230
- # Convert the cross fold data to a markdown string
231
- html = dict_to_collapsible_html(cross_fold_data)
232
- return html
233
-
234
196
  def get_inference_runs(self):
235
197
  """Get the inference runs for the model
236
198
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: workbench
3
- Version: 0.8.205
3
+ Version: 0.8.212
4
4
  Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
5
5
  Author-email: SuperCowPowers LLC <support@supercowpowers.com>
6
6
  License: MIT License
@@ -52,10 +52,10 @@ Requires-Dist: joblib>=1.3.2
52
52
  Requires-Dist: requests>=2.26.0
53
53
  Requires-Dist: rdkit>=2024.9.5
54
54
  Requires-Dist: mordredcommunity>=2.0.6
55
- Requires-Dist: workbench-bridges>=0.1.15
55
+ Requires-Dist: workbench-bridges>=0.1.16
56
56
  Provides-Extra: ui
57
57
  Requires-Dist: plotly>=6.0.0; extra == "ui"
58
- Requires-Dist: dash>3.0.0; extra == "ui"
58
+ Requires-Dist: dash>=3.0.0; extra == "ui"
59
59
  Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "ui"
60
60
  Requires-Dist: dash-bootstrap-templates>=1.3.0; extra == "ui"
61
61
  Requires-Dist: dash_ag_grid; extra == "ui"
@@ -70,8 +70,8 @@ Requires-Dist: flake8; extra == "dev"
70
70
  Requires-Dist: black; extra == "dev"
71
71
  Provides-Extra: all
72
72
  Requires-Dist: networkx>=3.2; extra == "all"
73
- Requires-Dist: plotly>=5.18.0; extra == "all"
74
- Requires-Dist: dash<3.0.0,>=2.16.1; extra == "all"
73
+ Requires-Dist: plotly>=6.0.0; extra == "all"
74
+ Requires-Dist: dash>=3.0.0; extra == "all"
75
75
  Requires-Dist: dash-bootstrap-components>=1.6.0; extra == "all"
76
76
  Requires-Dist: dash-bootstrap-templates>=1.3.0; extra == "all"
77
77
  Requires-Dist: dash_ag_grid; extra == "all"