workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic; see the registry's advisory page for this release for more details.

Files changed (145):
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,22 @@
1
1
  """XGBoost Model Utilities"""
2
2
 
3
+ import glob
4
+ import hashlib
3
5
  import logging
4
6
  import os
5
- import tempfile
6
- import tarfile
7
7
  import pickle
8
- import glob
8
+ import tempfile
9
+ from typing import Any, List, Optional, Tuple
10
+
9
11
  import awswrangler as wr
10
- from typing import Optional, List, Tuple
11
- import hashlib
12
+ import joblib
12
13
  import pandas as pd
13
- import numpy as np
14
14
  import xgboost as xgb
15
- from typing import Dict, Any
16
- from sklearn.model_selection import KFold, StratifiedKFold
17
- from sklearn.metrics import (
18
- precision_recall_fscore_support,
19
- confusion_matrix,
20
- mean_squared_error,
21
- mean_absolute_error,
22
- r2_score,
23
- )
24
- from sklearn.preprocessing import LabelEncoder
25
15
 
26
16
  # Workbench Imports
27
- from workbench.utils.model_utils import load_category_mappings_from_s3
17
+ from workbench.utils.aws_utils import pull_s3_data
18
+ from workbench.utils.metrics_utils import compute_metrics_from_predictions
19
+ from workbench.utils.model_utils import load_category_mappings_from_s3, safe_extract_tarfile
28
20
  from workbench.utils.pandas_utils import convert_categorical_types
29
21
 
30
22
  # Set up the log
@@ -34,14 +26,12 @@ log = logging.getLogger("workbench")
34
26
  def xgboost_model_from_s3(model_artifact_uri: str):
35
27
  """
36
28
  Download and extract XGBoost model artifact from S3, then load the model into memory.
37
- Handles both direct XGBoost model files and pickled models.
38
- Ensures categorical feature support is enabled.
39
29
 
40
30
  Args:
41
31
  model_artifact_uri (str): S3 URI of the model artifact.
42
32
 
43
33
  Returns:
44
- Loaded XGBoost model or None if unavailable.
34
+ Loaded XGBoost model (XGBClassifier, XGBRegressor, or Booster) or None if unavailable.
45
35
  """
46
36
 
47
37
  with tempfile.TemporaryDirectory() as tmpdir:
@@ -50,69 +40,90 @@ def xgboost_model_from_s3(model_artifact_uri: str):
50
40
  wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)
51
41
 
52
42
  # Extract tarball
53
- with tarfile.open(local_tar_path, "r:gz") as tar:
54
- tar.extractall(path=tmpdir, filter="data")
43
+ safe_extract_tarfile(local_tar_path, tmpdir)
55
44
 
56
45
  # Define model file patterns to search for (in order of preference)
57
46
  patterns = [
58
- # Direct XGBoost model files
59
- os.path.join(tmpdir, "xgboost-model"),
60
- os.path.join(tmpdir, "xgb_model*.json"),
61
- os.path.join(tmpdir, "model"),
62
- os.path.join(tmpdir, "*.bin"),
47
+ # Joblib models (preferred - preserves everything)
48
+ os.path.join(tmpdir, "*model*.joblib"),
49
+ os.path.join(tmpdir, "xgb*.joblib"),
50
+ os.path.join(tmpdir, "**", "*model*.joblib"),
51
+ os.path.join(tmpdir, "**", "xgb*.joblib"),
52
+ # Pickle models (also preserves everything)
53
+ os.path.join(tmpdir, "*model*.pkl"),
54
+ os.path.join(tmpdir, "xgb*.pkl"),
55
+ os.path.join(tmpdir, "**", "*model*.pkl"),
56
+ os.path.join(tmpdir, "**", "xgb*.pkl"),
57
+ # JSON models (fallback - requires reconstruction)
58
+ os.path.join(tmpdir, "*model*.json"),
59
+ os.path.join(tmpdir, "xgb*.json"),
63
60
  os.path.join(tmpdir, "**", "*model*.json"),
64
- os.path.join(tmpdir, "**", "rmse.json"),
65
- # Pickled models
66
- os.path.join(tmpdir, "*.pkl"),
67
- os.path.join(tmpdir, "**", "*.pkl"),
68
- os.path.join(tmpdir, "*.pickle"),
69
- os.path.join(tmpdir, "**", "*.pickle"),
61
+ os.path.join(tmpdir, "**", "xgb*.json"),
70
62
  ]
71
63
 
72
64
  # Try each pattern
73
65
  for pattern in patterns:
74
- # Use glob to find all matching files
75
66
  for model_path in glob.glob(pattern, recursive=True):
76
- # Determine file type by extension
67
+ # Skip files that are clearly not XGBoost models
68
+ filename = os.path.basename(model_path).lower()
69
+ if any(skip in filename for skip in ["label_encoder", "scaler", "preprocessor", "transformer"]):
70
+ log.debug(f"Skipping non-model file: {model_path}")
71
+ continue
72
+
77
73
  _, ext = os.path.splitext(model_path)
78
74
 
79
75
  try:
80
- if ext.lower() in [".pkl", ".pickle"]:
81
- # Handle pickled models
76
+ if ext == ".joblib":
77
+ model = joblib.load(model_path)
78
+ # Verify it's actually an XGBoost model
79
+ if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
80
+ log.important(f"Loaded XGBoost model from joblib: {model_path}")
81
+ return model
82
+ else:
83
+ log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
84
+
85
+ elif ext in [".pkl", ".pickle"]:
82
86
  with open(model_path, "rb") as f:
83
87
  model = pickle.load(f)
84
-
85
- # Handle different model types
86
- if isinstance(model, xgb.Booster):
87
- log.important(f"Loaded XGBoost Booster from pickle: {model_path}")
88
+ # Verify it's actually an XGBoost model
89
+ if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
90
+ log.important(f"Loaded XGBoost model from pickle: {model_path}")
88
91
  return model
89
- elif hasattr(model, "get_booster"):
90
- log.important(f"Loaded XGBoost model from pipeline: {model_path}")
91
- booster = model.get_booster()
92
- return booster
93
- else:
94
- # Handle direct XGBoost model files
92
+ else:
93
+ log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
94
+
95
+ elif ext == ".json":
96
+ # JSON files should be XGBoost models by definition
95
97
  booster = xgb.Booster()
96
98
  booster.load_model(model_path)
97
- log.important(f"Loaded XGBoost model directly: {model_path}")
99
+ log.important(f"Loaded XGBoost booster from JSON: {model_path}")
98
100
  return booster
101
+
99
102
  except Exception as e:
100
- log.info(f"Failed to load model from {model_path}: {e}")
101
- continue # Try the next file
103
+ log.debug(f"Failed to load {model_path}: {e}")
104
+ continue
102
105
 
103
- # If no model found
104
106
  log.error("No XGBoost model found in the artifact.")
105
107
  return None
106
108
 
107
109
 
108
- def feature_importance(workbench_model, importance_type: str = "weight") -> Optional[List[Tuple[str, float]]]:
110
+ def feature_importance(workbench_model, importance_type: str = "gain") -> Optional[List[Tuple[str, float]]]:
109
111
  """
110
112
  Get sorted feature importances from a Workbench Model object.
111
113
 
112
114
  Args:
113
115
  workbench_model: Workbench model object
114
- importance_type: Type of feature importance.
115
- Options: 'weight', 'gain', 'cover', 'total_gain', 'total_cover'
116
+ importance_type: Type of feature importance. Options:
117
+ - 'gain' (default): Average improvement in loss/objective when feature is used.
118
+ Best for understanding predictive power of features.
119
+ - 'weight': Number of times a feature appears in trees (split count).
120
+ Useful for understanding model complexity and feature usage frequency.
121
+ - 'cover': Average number of samples affected when feature is used.
122
+ Shows the relative quantity of observations related to this feature.
123
+ - 'total_gain': Total improvement in loss/objective across all splits.
124
+ Similar to 'gain' but not averaged (can be biased toward frequent features).
125
+ - 'total_cover': Total number of samples affected across all splits.
126
+ Similar to 'cover' but not averaged.
116
127
 
117
128
  Returns:
118
129
  List of tuples (feature, importance) sorted by importance value (descending).
@@ -121,7 +132,8 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti
121
132
 
122
133
  Note:
123
134
  XGBoost's get_score() only returns features with non-zero importance.
124
- This function ensures all model features are included in the output.
135
+ This function ensures all model features are included in the output,
136
+ adding zero values for features that weren't used in any tree splits.
125
137
  """
126
138
  model_artifact_uri = workbench_model.model_data_url()
127
139
  xgb_model = xgboost_model_from_s3(model_artifact_uri)
@@ -129,11 +141,18 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti
129
141
  log.error("No XGBoost model found in the artifact.")
130
142
  return None
131
143
 
132
- # Get feature importances (only non-zero features)
133
- importances = xgb_model.get_score(importance_type=importance_type)
144
+ # Check if we got a full sklearn model or just a booster (for backwards compatibility)
145
+ if hasattr(xgb_model, "get_booster"):
146
+ # Full sklearn model - get the booster for feature importance
147
+ booster = xgb_model.get_booster()
148
+ all_features = booster.feature_names
149
+ else:
150
+ # Already a booster (legacy JSON load)
151
+ booster = xgb_model
152
+ all_features = xgb_model.feature_names
134
153
 
135
- # Get all feature names from the model
136
- all_features = xgb_model.feature_names
154
+ # Get feature importances (only non-zero features)
155
+ importances = booster.get_score(importance_type=importance_type)
137
156
 
138
157
  # Create complete importance dict with zeros for missing features
139
158
  complete_importances = {feat: importances.get(feat, 0.0) for feat in all_features}
@@ -230,149 +249,45 @@ def leaf_stats(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
230
249
  return result_df
231
250
 
232
251
 
233
- def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Dict[str, Any]:
234
- """
235
- Performs K-fold cross-validation with detailed metrics.
252
+ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
253
+ """Pull cross-validation results from AWS training artifacts.
254
+
255
+ This retrieves the validation predictions saved during model training and
256
+ computes metrics directly from them. For XGBoost models trained with
257
+ n_folds > 1, these are out-of-fold predictions from k-fold cross-validation.
258
+
236
259
  Args:
237
260
  workbench_model: Workbench model object
238
- nfolds: Number of folds for cross-validation (default is 5)
261
+
239
262
  Returns:
240
- Dictionary containing:
241
- - folds: Dictionary of formatted strings for each fold
242
- - summary_metrics: Summary metrics across folds
243
- - overall_metrics: Overall metrics for all folds
263
+ Tuple of:
264
+ - DataFrame with computed metrics
265
+ - DataFrame with validation predictions
244
266
  """
245
- from workbench.api import FeatureSet
267
+ # Get the validation predictions from S3
268
+ s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
269
+ predictions_df = pull_s3_data(s3_path)
246
270
 
247
- # Load model
248
- model_type = workbench_model.model_type.value
249
- model_artifact_uri = workbench_model.model_data_url()
250
- loaded_booster = xgboost_model_from_s3(model_artifact_uri)
251
- if loaded_booster is None:
252
- log.error("No XGBoost model found in the artifact.")
253
- return {}
254
- # Create the model wrapper
255
- is_classifier = model_type == "classifier"
256
- xgb_model = (
257
- xgb.XGBClassifier(enable_categorical=True) if is_classifier else xgb.XGBRegressor(enable_categorical=True)
258
- )
259
- xgb_model._Booster = loaded_booster
260
- # Prepare data
261
- fs = FeatureSet(workbench_model.get_input())
262
- df = fs.pull_dataframe()
263
- feature_cols = workbench_model.features()
264
- # Convert string features to categorical
265
- for col in feature_cols:
266
- if df[col].dtype in ["object", "string"]:
267
- df[col] = df[col].astype("category")
268
- # Split X and y
269
- X = df[workbench_model.features()]
270
- y = df[workbench_model.target()]
271
-
272
- # Encode target if it's a classification problem
273
- label_encoder = LabelEncoder() if is_classifier else None
274
- if label_encoder:
275
- y = pd.Series(label_encoder.fit_transform(y), name=workbench_model.target())
276
- # Prepare KFold
277
- kfold = (
278
- StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=42)
279
- if is_classifier
280
- else KFold(n_splits=nfolds, shuffle=True, random_state=42)
281
- )
271
+ if predictions_df is None:
272
+ raise ValueError(f"No validation predictions found at {s3_path}")
282
273
 
283
- fold_results = []
284
- all_predictions = []
285
- all_actuals = []
286
- for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y)):
287
- X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
288
- y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
289
-
290
- # Train the model
291
- xgb_model.fit(X_train, y_train)
292
- preds = xgb_model.predict(X_val)
293
- all_predictions.extend(preds)
294
- all_actuals.extend(y_val)
295
-
296
- # Calculate metrics for this fold
297
- fold_metrics = {"fold": fold_idx + 1}
298
-
299
- if is_classifier:
300
- y_val_original = label_encoder.inverse_transform(y_val)
301
- preds_original = label_encoder.inverse_transform(preds.astype(int))
302
- scores = precision_recall_fscore_support(
303
- y_val_original, preds_original, average="weighted", zero_division=0
304
- )
305
- fold_metrics.update({"precision": float(scores[0]), "recall": float(scores[1]), "fscore": float(scores[2])})
306
- else:
307
- fold_metrics.update(
308
- {
309
- "rmse": float(np.sqrt(mean_squared_error(y_val, preds))),
310
- "mae": float(mean_absolute_error(y_val, preds)),
311
- "r2": float(r2_score(y_val, preds)),
312
- }
313
- )
314
-
315
- fold_results.append(fold_metrics)
316
- # Calculate overall metrics
317
- overall_metrics = {}
318
- if is_classifier:
319
- all_actuals_original = label_encoder.inverse_transform(all_actuals)
320
- all_predictions_original = label_encoder.inverse_transform(all_predictions)
321
- scores = precision_recall_fscore_support(
322
- all_actuals_original, all_predictions_original, average="weighted", zero_division=0
323
- )
324
- overall_metrics.update(
325
- {
326
- "precision": float(scores[0]),
327
- "recall": float(scores[1]),
328
- "fscore": float(scores[2]),
329
- "confusion_matrix": confusion_matrix(
330
- all_actuals_original, all_predictions_original, labels=label_encoder.classes_
331
- ).tolist(),
332
- "label_names": list(label_encoder.classes_),
333
- }
334
- )
274
+ log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
275
+
276
+ # Compute metrics from predictions
277
+ target = workbench_model.target()
278
+ class_labels = workbench_model.class_labels()
279
+
280
+ if target in predictions_df.columns and "prediction" in predictions_df.columns:
281
+ metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
335
282
  else:
336
- overall_metrics.update(
337
- {
338
- "rmse": float(np.sqrt(mean_squared_error(all_actuals, all_predictions))),
339
- "mae": float(mean_absolute_error(all_actuals, all_predictions)),
340
- "r2": float(r2_score(all_actuals, all_predictions)),
341
- }
342
- )
343
- # Calculate summary metrics across folds
344
- summary_metrics = {}
345
- metrics_to_aggregate = ["precision", "recall", "fscore"] if is_classifier else ["rmse", "mae", "r2"]
346
-
347
- for metric in metrics_to_aggregate:
348
- values = [fold[metric] for fold in fold_results]
349
- summary_metrics[metric] = f"{float(np.mean(values)):.3f} ±{float(np.std(values)):.3f}"
350
- # Format fold results as strings (TBD section)
351
- formatted_folds = {}
352
- for fold_data in fold_results:
353
- fold_key = f"Fold {fold_data['fold']}"
354
- if is_classifier:
355
- formatted_folds[fold_key] = (
356
- f"precision: {fold_data['precision']:.3f} "
357
- f"recall: {fold_data['recall']:.3f} "
358
- f"fscore: {fold_data['fscore']:.3f}"
359
- )
360
- else:
361
- formatted_folds[fold_key] = (
362
- f"rmse: {fold_data['rmse']:.3f} mae: {fold_data['mae']:.3f} r2: {fold_data['r2']:.3f}"
363
- )
364
- # Return the results
365
- return {
366
- "summary_metrics": summary_metrics,
367
- # "overall_metrics": overall_metrics,
368
- "folds": formatted_folds,
369
- }
283
+ metrics_df = pd.DataFrame()
284
+
285
+ return metrics_df, predictions_df
370
286
 
371
287
 
372
288
  if __name__ == "__main__":
373
289
  """Exercise the Model Utilities"""
374
- from workbench.api import Model, FeatureSet
375
- from pprint import pprint
290
+ from workbench.api import Model
376
291
 
377
292
  # Test the XGBoost model loading and feature importance
378
293
  model = Model("abalone-regression")
@@ -384,34 +299,26 @@ if __name__ == "__main__":
384
299
  model_artifact_uri = model.model_data_url()
385
300
  xgb_model = xgboost_model_from_s3(model_artifact_uri)
386
301
 
387
- # Test with UQ Model
388
- uq_model = Model("aqsol-uq")
389
- _xgb_model = xgboost_model_from_s3(uq_model.model_data_url())
390
-
391
- # Test XGBoost add_leaf_hash
392
- input_df = FeatureSet(model.get_input()).pull_dataframe()
393
- leaf_df = add_leaf_hash(model, input_df)
394
- print("DataFrame with Leaf Hash:")
395
- print(leaf_df)
396
-
397
- # Okay, we're going to copy row 3 and insert it into row 7 to make sure the leaf_hash is the same
398
- input_df.iloc[7] = input_df.iloc[3]
399
- print("DataFrame with Leaf Hash (3 and 7 should match):")
400
- leaf_df = add_leaf_hash(model, input_df)
401
- print(leaf_df)
402
-
403
- # Test leaf_stats
404
- target_col = "class_number_of_rings"
405
- stats_df = leaf_stats(leaf_df, target_col)
406
- print("DataFrame with Leaf Statistics:")
407
- print(stats_df)
408
-
409
- print("\n=== CROSS FOLD REGRESSION EXAMPLE ===")
410
- model = Model("abalone-regression")
411
- results = cross_fold_inference(model)
412
- pprint(results)
302
+ # Verify enable_categorical is preserved (for debugging/confidence)
303
+ print(f"Model parameters: {xgb_model.get_params()}")
304
+ print(f"enable_categorical: {xgb_model.enable_categorical}")
413
305
 
414
- print("\n=== CROSS FOLD CLASSIFICATION EXAMPLE ===")
306
+ print("\n=== PULL CV RESULTS EXAMPLE ===")
307
+ model = Model("abalone-regression")
308
+ metrics_df, predictions_df = pull_cv_results(model)
309
+ print(f"\nMetrics:\n{metrics_df}")
310
+ print(f"\nPredictions shape: {predictions_df.shape}")
311
+ print(f"Predictions columns: {predictions_df.columns.tolist()}")
312
+ print(predictions_df.head())
313
+
314
+ # Test on a Classifier model
315
+ print("\n=== CLASSIFIER MODEL TEST ===")
415
316
  model = Model("wine-classification")
416
- results = cross_fold_inference(model)
417
- pprint(results)
317
+ features = feature_importance(model)
318
+ print("Feature Importance:")
319
+ print(features)
320
+ metrics_df, predictions_df = pull_cv_results(model)
321
+ print(f"\nMetrics:\n{metrics_df}")
322
+ print(f"\nPredictions shape: {predictions_df.shape}")
323
+ print(f"Predictions columns: {predictions_df.columns.tolist()}")
324
+ print(predictions_df.head())
@@ -11,7 +11,6 @@ import logging
11
11
  from workbench.algorithms.dataframe.aggregation import aggregate
12
12
  from workbench.algorithms.dataframe.projection_2d import Projection2D
13
13
 
14
-
15
14
  # Workbench Logger
16
15
  log = logging.getLogger("workbench")
17
16
 
@@ -10,8 +10,10 @@ from workbench.api import Model, ModelType
10
10
  from workbench.web_interface.components.component_interface import ComponentInterface
11
11
  from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
12
12
  from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
13
+ from workbench.utils.deprecated_utils import deprecated
13
14
 
14
15
 
16
+ @deprecated(version="0.9")
15
17
  class ModelPlot(ComponentInterface):
16
18
  """Model Metrics Components"""
17
19
 
@@ -36,10 +38,22 @@ class ModelPlot(ComponentInterface):
36
38
  if df is None:
37
39
  return self.display_text("No Data")
38
40
 
39
- # Calculate the distance from the diagonal for each point
41
+ # Grab the target(s) for this model
40
42
  target = model.target()
43
+
44
+ # For multi-task models, match target to inference_run name or default to first
45
+ if isinstance(target, list):
46
+ target = next((t for t in target if t in inference_run), target[0])
47
+
48
+ # Compute error for coloring
41
49
  df["error"] = abs(df["prediction"] - df[target])
42
- return ScatterPlot().update_properties(df, color="error", regression_line=True)[0]
50
+ return ScatterPlot().update_properties(
51
+ df,
52
+ color="error",
53
+ regression_line=True,
54
+ x=target,
55
+ y="prediction",
56
+ )[0]
43
57
  else:
44
58
  return self.display_text(f"Model Type: {model.model_type}\n\n Awesome Plot Coming Soon!")
45
59
 
@@ -3,7 +3,6 @@ import dash_bootstrap_components as dbc
3
3
  import logging
4
4
  import socket
5
5
 
6
-
7
6
  # Workbench Imports
8
7
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginInputType
9
8
  from workbench.api import DataSource, FeatureSet, Model, Endpoint, Meta
@@ -156,10 +155,13 @@ class PluginUnitTest:
156
155
  """Run the Dash server for the plugin, handling common errors gracefully."""
157
156
  while self.is_port_in_use(self.port):
158
157
  log.info(f"Port {self.port} is in use. Trying the next one...")
159
- self.port += 1 # Increment the port number until an available one is found
158
+ self.port += 1
160
159
 
161
160
  log.info(f"Starting Dash server on port {self.port}...")
162
- self.app.run(debug=True, use_reloader=False, port=self.port)
161
+ try:
162
+ self.app.run(debug=True, use_reloader=False, port=self.port)
163
+ except KeyboardInterrupt:
164
+ log.info("Shutting down Dash server...")
163
165
 
164
166
  @staticmethod
165
167
  def is_port_in_use(port):
@@ -22,9 +22,7 @@ class AGTable(PluginInterface):
22
22
  header_height = 30
23
23
  row_height = 25
24
24
 
25
- def create_component(
26
- self, component_id: str, header_color: str = "rgb(120, 60, 60)", max_height: int = 500
27
- ) -> AgGrid:
25
+ def create_component(self, component_id: str, max_height: int = 500) -> AgGrid:
28
26
  """Create a Table Component without any data."""
29
27
  self.component_id = component_id
30
28
  self.max_height = max_height
@@ -112,4 +110,4 @@ if __name__ == "__main__":
112
110
  test_df = pd.DataFrame(data)
113
111
 
114
112
  # Run the Unit Test on the Plugin
115
- PluginUnitTest(AGTable, theme="quartz", input_data=test_df, max_height=500).run()
113
+ PluginUnitTest(AGTable, theme="dark", input_data=test_df, max_height=500).run()
@@ -3,7 +3,6 @@
3
3
  from dash import dcc, callback, Output, Input, State
4
4
  import plotly.graph_objects as go
5
5
 
6
-
7
6
  # Workbench Imports
8
7
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
9
8
  from workbench.utils.theme_manager import ThemeManager
@@ -22,7 +21,6 @@ class ConfusionMatrix(PluginInterface):
22
21
  self.component_id = None
23
22
  self.current_highlight = None # Store the currently highlighted cell
24
23
  self.theme_manager = ThemeManager()
25
- self.colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
26
24
 
27
25
  # Call the parent class constructor
28
26
  super().__init__()
@@ -65,9 +63,8 @@ class ConfusionMatrix(PluginInterface):
65
63
  if df is None:
66
64
  return [self.display_text("No Data")]
67
65
 
68
- # Use Plotly's default theme-friendly colorscale
69
- # from plotly.colors import sequential
70
- # color_scale = sequential.Plasma
66
+ # Get the colorscale from the current theme
67
+ colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))
71
68
 
72
69
  # The confusion matrix is displayed in reverse order (flip the dataframe for correct orientation)
73
70
  df = df.iloc[::-1]
@@ -89,7 +86,7 @@ class ConfusionMatrix(PluginInterface):
89
86
  title="Count",
90
87
  outlinewidth=1,
91
88
  ),
92
- colorscale=self.colorscale,
89
+ colorscale=colorscale,
93
90
  )
94
91
  )
95
92