workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of workbench has been flagged as potentially problematic.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +319 -204
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -82
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +260 -76
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/utils/xgboost_model_utils.py

@@ -1,30 +1,22 @@
 """XGBoost Model Utilities"""

+import glob
+import hashlib
 import logging
 import os
-import tempfile
-import tarfile
 import pickle
-import
+import tempfile
+from typing import Any, List, Optional, Tuple
+
 import awswrangler as wr
-
-import hashlib
+import joblib
 import pandas as pd
-import numpy as np
 import xgboost as xgb
-from typing import Dict, Any
-from sklearn.model_selection import KFold, StratifiedKFold
-from sklearn.metrics import (
-    precision_recall_fscore_support,
-    confusion_matrix,
-    mean_squared_error,
-    mean_absolute_error,
-    r2_score,
-)
-from sklearn.preprocessing import LabelEncoder

 # Workbench Imports
-from workbench.utils.
+from workbench.utils.aws_utils import pull_s3_data
+from workbench.utils.metrics_utils import compute_metrics_from_predictions
+from workbench.utils.model_utils import load_category_mappings_from_s3, safe_extract_tarfile
 from workbench.utils.pandas_utils import convert_categorical_types

 # Set up the log

@@ -34,14 +26,12 @@ log = logging.getLogger("workbench")
 def xgboost_model_from_s3(model_artifact_uri: str):
     """
     Download and extract XGBoost model artifact from S3, then load the model into memory.
-    Handles both direct XGBoost model files and pickled models.
-    Ensures categorical feature support is enabled.

     Args:
         model_artifact_uri (str): S3 URI of the model artifact.

     Returns:
-        Loaded XGBoost model or None if unavailable.
+        Loaded XGBoost model (XGBClassifier, XGBRegressor, or Booster) or None if unavailable.
     """

     with tempfile.TemporaryDirectory() as tmpdir:

@@ -50,69 +40,90 @@ def xgboost_model_from_s3(model_artifact_uri: str):
         wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)

         # Extract tarball
-
-        tar.extractall(path=tmpdir, filter="data")
+        safe_extract_tarfile(local_tar_path, tmpdir)

         # Define model file patterns to search for (in order of preference)
         patterns = [
-            #
-            os.path.join(tmpdir, "
-            os.path.join(tmpdir, "
-            os.path.join(tmpdir, "model"),
-            os.path.join(tmpdir, "*.
+            # Joblib models (preferred - preserves everything)
+            os.path.join(tmpdir, "*model*.joblib"),
+            os.path.join(tmpdir, "xgb*.joblib"),
+            os.path.join(tmpdir, "**", "*model*.joblib"),
+            os.path.join(tmpdir, "**", "xgb*.joblib"),
+            # Pickle models (also preserves everything)
+            os.path.join(tmpdir, "*model*.pkl"),
+            os.path.join(tmpdir, "xgb*.pkl"),
+            os.path.join(tmpdir, "**", "*model*.pkl"),
+            os.path.join(tmpdir, "**", "xgb*.pkl"),
+            # JSON models (fallback - requires reconstruction)
+            os.path.join(tmpdir, "*model*.json"),
+            os.path.join(tmpdir, "xgb*.json"),
             os.path.join(tmpdir, "**", "*model*.json"),
-            os.path.join(tmpdir, "**", "
-            # Pickled models
-            os.path.join(tmpdir, "*.pkl"),
-            os.path.join(tmpdir, "**", "*.pkl"),
-            os.path.join(tmpdir, "*.pickle"),
-            os.path.join(tmpdir, "**", "*.pickle"),
+            os.path.join(tmpdir, "**", "xgb*.json"),
         ]

         # Try each pattern
         for pattern in patterns:
-            # Use glob to find all matching files
             for model_path in glob.glob(pattern, recursive=True):
-                #
+                # Skip files that are clearly not XGBoost models
+                filename = os.path.basename(model_path).lower()
+                if any(skip in filename for skip in ["label_encoder", "scaler", "preprocessor", "transformer"]):
+                    log.debug(f"Skipping non-model file: {model_path}")
+                    continue
+
                 _, ext = os.path.splitext(model_path)

                 try:
-                    if ext
-
+                    if ext == ".joblib":
+                        model = joblib.load(model_path)
+                        # Verify it's actually an XGBoost model
+                        if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
+                            log.important(f"Loaded XGBoost model from joblib: {model_path}")
+                            return model
+                        else:
+                            log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
+
+                    elif ext in [".pkl", ".pickle"]:
                         with open(model_path, "rb") as f:
                             model = pickle.load(f)
-
-
-
-                        log.important(f"Loaded XGBoost Booster from pickle: {model_path}")
+                        # Verify it's actually an XGBoost model
+                        if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
+                            log.important(f"Loaded XGBoost model from pickle: {model_path}")
                             return model
-
-                        log.
-
-
-
-                    # Handle direct XGBoost model files
+                        else:
+                            log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
+
+                    elif ext == ".json":
+                        # JSON files should be XGBoost models by definition
                         booster = xgb.Booster()
                         booster.load_model(model_path)
-                        log.important(f"Loaded XGBoost
+                        log.important(f"Loaded XGBoost booster from JSON: {model_path}")
                         return booster
+
                 except Exception as e:
-                    log.
-                    continue
+                    log.debug(f"Failed to load {model_path}: {e}")
+                    continue

-        # If no model found
         log.error("No XGBoost model found in the artifact.")
         return None


-def feature_importance(workbench_model, importance_type: str = "
+def feature_importance(workbench_model, importance_type: str = "gain") -> Optional[List[Tuple[str, float]]]:
     """
     Get sorted feature importances from a Workbench Model object.

     Args:
         workbench_model: Workbench model object
-        importance_type: Type of feature importance.
-
+        importance_type: Type of feature importance. Options:
+            - 'gain' (default): Average improvement in loss/objective when feature is used.
+              Best for understanding predictive power of features.
+            - 'weight': Number of times a feature appears in trees (split count).
+              Useful for understanding model complexity and feature usage frequency.
+            - 'cover': Average number of samples affected when feature is used.
+              Shows the relative quantity of observations related to this feature.
+            - 'total_gain': Total improvement in loss/objective across all splits.
+              Similar to 'gain' but not averaged (can be biased toward frequent features).
+            - 'total_cover': Total number of samples affected across all splits.
+              Similar to 'cover' but not averaged.

     Returns:
         List of tuples (feature, importance) sorted by importance value (descending).

@@ -121,7 +132,8 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti

     Note:
         XGBoost's get_score() only returns features with non-zero importance.
-        This function ensures all model features are included in the output
+        This function ensures all model features are included in the output,
+        adding zero values for features that weren't used in any tree splits.
     """
     model_artifact_uri = workbench_model.model_data_url()
     xgb_model = xgboost_model_from_s3(model_artifact_uri)

@@ -129,11 +141,18 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti
         log.error("No XGBoost model found in the artifact.")
         return None

-    #
-
+    # Check if we got a full sklearn model or just a booster (for backwards compatibility)
+    if hasattr(xgb_model, "get_booster"):
+        # Full sklearn model - get the booster for feature importance
+        booster = xgb_model.get_booster()
+        all_features = booster.feature_names
+    else:
+        # Already a booster (legacy JSON load)
+        booster = xgb_model
+        all_features = xgb_model.feature_names

-    # Get
-
+    # Get feature importances (only non-zero features)
+    importances = booster.get_score(importance_type=importance_type)

     # Create complete importance dict with zeros for missing features
     complete_importances = {feat: importances.get(feat, 0.0) for feat in all_features}

@@ -230,149 +249,45 @@ def leaf_stats(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
     return result_df


-def
-    """
-
+def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Pull cross-validation results from AWS training artifacts.
+
+    This retrieves the validation predictions saved during model training and
+    computes metrics directly from them. For XGBoost models trained with
+    n_folds > 1, these are out-of-fold predictions from k-fold cross-validation.
+
     Args:
         workbench_model: Workbench model object
-
+
     Returns:
-
-        -
-        -
-        - overall_metrics: Overall metrics for all folds
+        Tuple of:
+        - DataFrame with computed metrics
+        - DataFrame with validation predictions
     """
-
+    # Get the validation predictions from S3
+    s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
+    predictions_df = pull_s3_data(s3_path)

-
-
-    model_artifact_uri = workbench_model.model_data_url()
-    loaded_booster = xgboost_model_from_s3(model_artifact_uri)
-    if loaded_booster is None:
-        log.error("No XGBoost model found in the artifact.")
-        return {}
-    # Create the model wrapper
-    is_classifier = model_type == "classifier"
-    xgb_model = (
-        xgb.XGBClassifier(enable_categorical=True) if is_classifier else xgb.XGBRegressor(enable_categorical=True)
-    )
-    xgb_model._Booster = loaded_booster
-    # Prepare data
-    fs = FeatureSet(workbench_model.get_input())
-    df = fs.pull_dataframe()
-    feature_cols = workbench_model.features()
-    # Convert string features to categorical
-    for col in feature_cols:
-        if df[col].dtype in ["object", "string"]:
-            df[col] = df[col].astype("category")
-    # Split X and y
-    X = df[workbench_model.features()]
-    y = df[workbench_model.target()]
-
-    # Encode target if it's a classification problem
-    label_encoder = LabelEncoder() if is_classifier else None
-    if label_encoder:
-        y = pd.Series(label_encoder.fit_transform(y), name=workbench_model.target())
-    # Prepare KFold
-    kfold = (
-        StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=42)
-        if is_classifier
-        else KFold(n_splits=nfolds, shuffle=True, random_state=42)
-    )
+    if predictions_df is None:
+        raise ValueError(f"No validation predictions found at {s3_path}")

-
-
-
-
-
-
-
-
-        xgb_model.fit(X_train, y_train)
-        preds = xgb_model.predict(X_val)
-        all_predictions.extend(preds)
-        all_actuals.extend(y_val)
-
-        # Calculate metrics for this fold
-        fold_metrics = {"fold": fold_idx + 1}
-
-        if is_classifier:
-            y_val_original = label_encoder.inverse_transform(y_val)
-            preds_original = label_encoder.inverse_transform(preds.astype(int))
-            scores = precision_recall_fscore_support(
-                y_val_original, preds_original, average="weighted", zero_division=0
-            )
-            fold_metrics.update({"precision": float(scores[0]), "recall": float(scores[1]), "fscore": float(scores[2])})
-        else:
-            fold_metrics.update(
-                {
-                    "rmse": float(np.sqrt(mean_squared_error(y_val, preds))),
-                    "mae": float(mean_absolute_error(y_val, preds)),
-                    "r2": float(r2_score(y_val, preds)),
-                }
-            )
-
-        fold_results.append(fold_metrics)
-    # Calculate overall metrics
-    overall_metrics = {}
-    if is_classifier:
-        all_actuals_original = label_encoder.inverse_transform(all_actuals)
-        all_predictions_original = label_encoder.inverse_transform(all_predictions)
-        scores = precision_recall_fscore_support(
-            all_actuals_original, all_predictions_original, average="weighted", zero_division=0
-        )
-        overall_metrics.update(
-            {
-                "precision": float(scores[0]),
-                "recall": float(scores[1]),
-                "fscore": float(scores[2]),
-                "confusion_matrix": confusion_matrix(
-                    all_actuals_original, all_predictions_original, labels=label_encoder.classes_
-                ).tolist(),
-                "label_names": list(label_encoder.classes_),
-            }
-        )
+    log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
+
+    # Compute metrics from predictions
+    target = workbench_model.target()
+    class_labels = workbench_model.class_labels()
+
+    if target in predictions_df.columns and "prediction" in predictions_df.columns:
+        metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
     else:
-
-
-
-            "mae": float(mean_absolute_error(all_actuals, all_predictions)),
-            "r2": float(r2_score(all_actuals, all_predictions)),
-        }
-        )
-    # Calculate summary metrics across folds
-    summary_metrics = {}
-    metrics_to_aggregate = ["precision", "recall", "fscore"] if is_classifier else ["rmse", "mae", "r2"]
-
-    for metric in metrics_to_aggregate:
-        values = [fold[metric] for fold in fold_results]
-        summary_metrics[metric] = f"{float(np.mean(values)):.3f} ±{float(np.std(values)):.3f}"
-    # Format fold results as strings (TBD section)
-    formatted_folds = {}
-    for fold_data in fold_results:
-        fold_key = f"Fold {fold_data['fold']}"
-        if is_classifier:
-            formatted_folds[fold_key] = (
-                f"precision: {fold_data['precision']:.3f} "
-                f"recall: {fold_data['recall']:.3f} "
-                f"fscore: {fold_data['fscore']:.3f}"
-            )
-        else:
-            formatted_folds[fold_key] = (
-                f"rmse: {fold_data['rmse']:.3f} mae: {fold_data['mae']:.3f} r2: {fold_data['r2']:.3f}"
-            )
-    # Return the results
-    return {
-        "summary_metrics": summary_metrics,
-        # "overall_metrics": overall_metrics,
-        "folds": formatted_folds,
-    }
+        metrics_df = pd.DataFrame()
+
+    return metrics_df, predictions_df


 if __name__ == "__main__":
     """Exercise the Model Utilities"""
-    from workbench.api import Model
-    from pprint import pprint
+    from workbench.api import Model

     # Test the XGBoost model loading and feature importance
     model = Model("abalone-regression")

@@ -384,34 +299,26 @@ if __name__ == "__main__":
     model_artifact_uri = model.model_data_url()
     xgb_model = xgboost_model_from_s3(model_artifact_uri)

-    #
-
-
-
-    # Test XGBoost add_leaf_hash
-    input_df = FeatureSet(model.get_input()).pull_dataframe()
-    leaf_df = add_leaf_hash(model, input_df)
-    print("DataFrame with Leaf Hash:")
-    print(leaf_df)
-
-    # Okay, we're going to copy row 3 and insert it into row 7 to make sure the leaf_hash is the same
-    input_df.iloc[7] = input_df.iloc[3]
-    print("DataFrame with Leaf Hash (3 and 7 should match):")
-    leaf_df = add_leaf_hash(model, input_df)
-    print(leaf_df)
-
-    # Test leaf_stats
-    target_col = "class_number_of_rings"
-    stats_df = leaf_stats(leaf_df, target_col)
-    print("DataFrame with Leaf Statistics:")
-    print(stats_df)
-
-    print("\n=== CROSS FOLD REGRESSION EXAMPLE ===")
-    model = Model("abalone-regression")
-    results = cross_fold_inference(model)
-    pprint(results)
+    # Verify enable_categorical is preserved (for debugging/confidence)
+    print(f"Model parameters: {xgb_model.get_params()}")
+    print(f"enable_categorical: {xgb_model.enable_categorical}")

-    print("\n===
+    print("\n=== PULL CV RESULTS EXAMPLE ===")
+    model = Model("abalone-regression")
+    metrics_df, predictions_df = pull_cv_results(model)
+    print(f"\nMetrics:\n{metrics_df}")
+    print(f"\nPredictions shape: {predictions_df.shape}")
+    print(f"Predictions columns: {predictions_df.columns.tolist()}")
+    print(predictions_df.head())
+
+    # Test on a Classifier model
+    print("\n=== CLASSIFIER MODEL TEST ===")
     model = Model("wine-classification")
-
-
+    features = feature_importance(model)
+    print("Feature Importance:")
+    print(features)
+    metrics_df, predictions_df = pull_cv_results(model)
+    print(f"\nMetrics:\n{metrics_df}")
+    print(f"\nPredictions shape: {predictions_df.shape}")
+    print(f"Predictions columns: {predictions_df.columns.tolist()}")
+    print(predictions_df.head())
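For orientation, here is a minimal usage sketch of the reworked utilities shown in the diff above (the "abalone-regression" model name is taken from the module's own __main__ block; feature_importance now defaults to importance_type="gain", and pull_cv_results returns computed metrics plus the out-of-fold validation predictions saved during training):

    from workbench.api import Model
    from workbench.utils.xgboost_model_utils import feature_importance, pull_cv_results

    model = Model("abalone-regression")

    # Sorted (feature, importance) tuples; every model feature included, zeros for unused ones
    print(feature_importance(model, importance_type="gain"))

    # Metrics computed from the validation predictions saved during training
    metrics_df, predictions_df = pull_cv_results(model)
    print(metrics_df)
    print(predictions_df.head())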
workbench/web_interface/components/model_plot.py

@@ -10,8 +10,10 @@ from workbench.api import Model, ModelType
 from workbench.web_interface.components.component_interface import ComponentInterface
 from workbench.web_interface.components.plugins.confusion_matrix import ConfusionMatrix
 from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
+from workbench.utils.deprecated_utils import deprecated


+@deprecated(version="0.9")
 class ModelPlot(ComponentInterface):
     """Model Metrics Components"""


@@ -36,10 +38,22 @@ class ModelPlot(ComponentInterface):
            if df is None:
                return self.display_text("No Data")

-           #
+           # Grab the target(s) for this model
            target = model.target()
+
+           # For multi-task models, match target to inference_run name or default to first
+           if isinstance(target, list):
+               target = next((t for t in target if t in inference_run), target[0])
+
+           # Compute error for coloring
            df["error"] = abs(df["prediction"] - df[target])
-           return ScatterPlot().update_properties(
+           return ScatterPlot().update_properties(
+               df,
+               color="error",
+               regression_line=True,
+               x=target,
+               y="prediction",
+           )[0]
        else:
            return self.display_text(f"Model Type: {model.model_type}\n\n Awesome Plot Coming Soon!")

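A small illustration of the multi-task target handling added to ModelPlot above; the values here are hypothetical, but in the plugin itself target comes from model.target() and inference_run is the name of the inference run:

    # If the model has multiple targets, pick the one whose name appears in the
    # inference run, falling back to the first target otherwise.
    target = ["solubility", "logd"]        # hypothetical multi-task targets
    inference_run = "auto_inference_logd"  # hypothetical inference run name
    if isinstance(target, list):
        target = next((t for t in target if t in inference_run), target[0])
    print(target)  # -> "logd"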
workbench/web_interface/components/plugin_unit_test.py

@@ -3,7 +3,6 @@ import dash_bootstrap_components as dbc
 import logging
 import socket

-
 # Workbench Imports
 from workbench.web_interface.components.plugin_interface import PluginInterface, PluginInputType
 from workbench.api import DataSource, FeatureSet, Model, Endpoint, Meta

@@ -156,10 +155,13 @@ class PluginUnitTest:
         """Run the Dash server for the plugin, handling common errors gracefully."""
         while self.is_port_in_use(self.port):
             log.info(f"Port {self.port} is in use. Trying the next one...")
-            self.port += 1
+            self.port += 1

         log.info(f"Starting Dash server on port {self.port}...")
-
+        try:
+            self.app.run(debug=True, use_reloader=False, port=self.port)
+        except KeyboardInterrupt:
+            log.info("Shutting down Dash server...")

     @staticmethod
     def is_port_in_use(port):
workbench/web_interface/components/plugins/ag_table.py

@@ -22,9 +22,7 @@ class AGTable(PluginInterface):
     header_height = 30
     row_height = 25

-    def create_component(
-        self, component_id: str, header_color: str = "rgb(120, 60, 60)", max_height: int = 500
-    ) -> AgGrid:
+    def create_component(self, component_id: str, max_height: int = 500) -> AgGrid:
         """Create a Table Component without any data."""
         self.component_id = component_id
         self.max_height = max_height

@@ -112,4 +110,4 @@ if __name__ == "__main__":
     test_df = pd.DataFrame(data)

     # Run the Unit Test on the Plugin
-    PluginUnitTest(AGTable, theme="
+    PluginUnitTest(AGTable, theme="dark", input_data=test_df, max_height=500).run()
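A brief sketch of the updated AGTable call signature above: create_component no longer accepts a header_color argument, so callers pass just the component id and an optional max_height (the "my_table" id here is hypothetical):

    from workbench.web_interface.components.plugins.ag_table import AGTable

    table = AGTable()
    ag_grid = table.create_component("my_table", max_height=400)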
workbench/web_interface/components/plugins/confusion_matrix.py

@@ -3,7 +3,6 @@
 from dash import dcc, callback, Output, Input, State
 import plotly.graph_objects as go

-
 # Workbench Imports
 from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
 from workbench.utils.theme_manager import ThemeManager

@@ -22,7 +21,6 @@ class ConfusionMatrix(PluginInterface):
         self.component_id = None
         self.current_highlight = None  # Store the currently highlighted cell
         self.theme_manager = ThemeManager()
-        self.colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))

         # Call the parent class constructor
         super().__init__()

@@ -65,9 +63,8 @@ class ConfusionMatrix(PluginInterface):
         if df is None:
             return [self.display_text("No Data")]

-        #
-
-        # color_scale = sequential.Plasma
+        # Get the colorscale from the current theme
+        colorscale = add_alpha_to_first_color(self.theme_manager.colorscale("heatmap"))

         # The confusion matrix is displayed in reverse order (flip the dataframe for correct orientation)
         df = df.iloc[::-1]

@@ -89,7 +86,7 @@
                     title="Count",
                     outlinewidth=1,
                 ),
-                colorscale=
+                colorscale=colorscale,
             )
         )
