workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"""XGBoost Local Cross-Fold Validation Utilities
|
|
2
|
+
|
|
3
|
+
This module contains functions for running cross-validation locally on XGBoost models.
|
|
4
|
+
For most use cases, prefer using pull_cv_results() from xgboost_model_utils.py which
|
|
5
|
+
retrieves the CV results that were saved during training on SageMaker.
|
|
6
|
+
|
|
7
|
+
These local cross-fold functions are useful for:
|
|
8
|
+
- Re-running CV with different fold counts
|
|
9
|
+
- Leave-one-out cross-validation
|
|
10
|
+
- Custom CV experiments
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
from typing import Any, Tuple
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import xgboost as xgb
|
|
19
|
+
from sklearn.model_selection import KFold, StratifiedKFold
|
|
20
|
+
from sklearn.preprocessing import LabelEncoder
|
|
21
|
+
|
|
22
|
+
from workbench.utils.metrics_utils import compute_metrics_from_predictions
|
|
23
|
+
from workbench.utils.pandas_utils import expand_proba_column
|
|
24
|
+
from workbench.utils.xgboost_model_utils import xgboost_model_from_s3
|
|
25
|
+
|
|
26
|
+
log = logging.getLogger("workbench")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Run K-fold cross-validation locally and collect out-of-fold predictions and metrics.

    A new XGBoost estimator is built from the model stored in the artifact (always with
    enable_categorical=True) and is fit once per fold on the model's training view.
    Classifiers use stratified folds and label-encoded targets; regressors use plain KFold.

    Note: For most use cases, prefer using pull_cv_results() from xgboost_model_utils.py
          which retrieves the CV results that were saved during training.

    Args:
        workbench_model: Workbench model object
        nfolds: Number of folds for cross-validation (default is 5)
    Returns:
        Tuple of:
            - DataFrame with per-class metrics (and 'all' row for overall metrics)
            - DataFrame with columns: id, target, prediction, and *_proba columns (for classifiers)
        Two empty DataFrames are returned when no usable XGBoost model is found.
    """
    from workbench.api import FeatureSet

    # Pull the trained model out of the S3 artifact
    artifact_uri = workbench_model.model_data_url()
    loaded_model = xgboost_model_from_s3(artifact_uri)
    if loaded_model is None:
        log.error("No XGBoost model found in the artifact.")
        return pd.DataFrame(), pd.DataFrame()

    # Build the estimator that will be refit on every fold
    if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
        is_classifier = isinstance(loaded_model, xgb.XGBClassifier)
        estimator_cls = xgb.XGBClassifier if is_classifier else xgb.XGBRegressor

        # Same hyperparameters as the stored model, but categorical support is forced on
        hyperparams = loaded_model.get_params()
        hyperparams["enable_categorical"] = True
        xgb_model = estimator_cls(**hyperparams)

    elif isinstance(loaded_model, xgb.Booster):
        # Legacy artifact: only a raw booster was saved, so wrap it in a default sklearn estimator
        log.warning("Deprecated: Loaded model is a Booster, wrapping in sklearn model.")
        is_classifier = workbench_model.model_type.value == "classifier"
        estimator_cls = xgb.XGBClassifier if is_classifier else xgb.XGBRegressor
        xgb_model = estimator_cls(enable_categorical=True)
        xgb_model._Booster = loaded_model

    else:
        log.error(f"Unexpected model type: {type(loaded_model)}")
        return pd.DataFrame(), pd.DataFrame()

    # Training data comes from the model's training view; the FeatureSet supplies the id column
    feature_set = FeatureSet(workbench_model.get_input())
    df = workbench_model.training_view().pull_dataframe()

    # Optional per-row weights (None when the column is absent)
    sample_weights = df.get("sample_weight")
    if sample_weights is not None:
        log.info(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}")

    # Column bookkeeping
    id_col = feature_set.id_column
    target_col = workbench_model.target()
    feature_cols = workbench_model.features()
    print(f"Target column: {target_col}")
    print(f"Feature columns: {len(feature_cols)} features")

    # String columns go string[python] -> object -> category for XGBoost compatibility
    for name in feature_cols:
        if pd.api.types.is_string_dtype(df[name]):
            df[name] = df[name].astype("object").astype("category")

    features = df[feature_cols]
    target = df[target_col]

    # Classifiers train on integer-encoded labels; regressors use the raw target
    if is_classifier:
        label_encoder = LabelEncoder()
        cv_target = pd.Series(label_encoder.fit_transform(target), index=target.index, name=target_col)
        splitter = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=42)
    else:
        label_encoder = None
        cv_target = target
        splitter = KFold(n_splits=nfolds, shuffle=True, random_state=42)

    # Out-of-fold predictions are written into this frame, which shares df's index
    predictions_df = pd.DataFrame({id_col: df[id_col], target_col: target})

    for train_pos, val_pos in splitter.split(features, cv_target):
        train_features = features.iloc[train_pos]
        val_features = features.iloc[val_pos]
        train_target = cv_target.iloc[train_pos]
        train_weights = None if sample_weights is None else sample_weights.iloc[train_pos]

        # Fit on the training fold, predict the held-out fold
        xgb_model.fit(train_features, train_target, sample_weight=train_weights)
        fold_preds = xgb_model.predict(val_features)

        # Rows are addressed by index label, matching predictions_df
        val_labels = val_features.index
        if is_classifier:
            # Decode back to the original class labels and keep the per-class probabilities
            predictions_df.loc[val_labels, "prediction"] = label_encoder.inverse_transform(fold_preds.astype(int))
            fold_proba = xgb_model.predict_proba(val_features)
            predictions_df.loc[val_labels, "pred_proba"] = pd.Series(fold_proba.tolist(), index=val_labels)
        else:
            predictions_df.loc[val_labels, "prediction"] = fold_preds

    # Classifiers: expand the probability lists into one column per class
    class_labels = None
    if is_classifier:
        predictions_df = expand_proba_column(predictions_df, label_encoder.classes_)
        class_labels = list(label_encoder.classes_)

    # Metrics are computed over the complete set of out-of-fold predictions
    metrics_df = compute_metrics_from_predictions(predictions_df, target_col, class_labels)
    return metrics_df, predictions_df
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def leave_one_out_inference(workbench_model: Any) -> pd.DataFrame:
    """
    Performs leave-one-out cross-validation (parallelized).

    For datasets > 1000 rows, first identifies the 1000 worst predictions via 10-fold CV,
    then performs true leave-one-out on those samples only.
    Each model trains on ALL data except one sample.

    Args:
        workbench_model: Workbench model object
    Returns:
        DataFrame with columns: id, target, prediction, residual_abs (one row per held-out
        sample). residual_abs is |target - prediction| when the two can be subtracted;
        otherwise (e.g. string class labels) it is 1.0 for a wrong prediction and 0.0 for
        a correct one. An empty DataFrame is returned when no usable XGBoost model is found.
    """
    from workbench.api import FeatureSet
    from joblib import Parallel, delayed
    from tqdm import tqdm

    def train_and_predict_one(model_params, is_classifier, X, y, train_idx, val_idx):
        """Train on the rows at positions train_idx, predict the single row at val_idx."""
        model = xgb.XGBClassifier(**model_params) if is_classifier else xgb.XGBRegressor(**model_params)
        # X stays a DataFrame so that category dtypes reach XGBoost intact
        model.fit(X.iloc[train_idx], y[train_idx])
        return model.predict(X.iloc[val_idx])[0]

    def abs_residual(actual, predicted):
        """Absolute error, falling back to a 0/1 mismatch flag for non-subtractable labels."""
        try:
            return np.abs(actual - predicted)
        except TypeError:
            # Non-numeric class labels cannot be subtracted
            return (actual != predicted).astype(float)

    # Load model and get params
    model_artifact_uri = workbench_model.model_data_url()
    loaded_model = xgboost_model_from_s3(model_artifact_uri)
    if loaded_model is None:
        log.error("No XGBoost model found in the artifact.")
        return pd.DataFrame()

    if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
        is_classifier = isinstance(loaded_model, xgb.XGBClassifier)
        model_params = loaded_model.get_params()
        # Feature columns may be converted to category below, so categorical support must be on
        # (same as cross_fold_inference)
        model_params["enable_categorical"] = True
    elif isinstance(loaded_model, xgb.Booster):
        log.warning("Deprecated: Loaded model is a Booster, wrapping in sklearn model.")
        is_classifier = workbench_model.model_type.value == "classifier"
        model_params = {"enable_categorical": True}
    else:
        log.error(f"Unexpected model type: {type(loaded_model)}")
        return pd.DataFrame()

    # Load and prepare data
    fs = FeatureSet(workbench_model.get_input())
    df = workbench_model.training_view().pull_dataframe()
    id_col = fs.id_column
    target_col = workbench_model.target()
    feature_cols = workbench_model.features()

    # Convert string[python] to object, then to category for XGBoost compatibility
    # This avoids XGBoost's issue with pandas 2.x string[python] dtype in categorical categories
    for col in feature_cols:
        if pd.api.types.is_string_dtype(df[col]):
            # Double conversion: string[python] -> object -> category
            df[col] = df[col].astype("object").astype("category")

    # Determine which samples to run LOO on (as row POSITIONS, independent of df's index labels)
    if len(df) > 1000:
        log.important(f"Dataset has {len(df)} rows. Running 10-fold CV to identify top 1000 worst predictions...")
        _, predictions_df = cross_fold_inference(workbench_model, nfolds=10)
        predictions_df["residual_abs"] = abs_residual(predictions_df[target_col], predictions_df["prediction"])
        worst_samples = predictions_df.nlargest(1000, "residual_abs")
        worst_ids = worst_samples[id_col].values
        loo_positions = np.flatnonzero(df[id_col].isin(worst_ids).to_numpy())
        log.important(f"Running leave-one-out CV on 1000 worst samples. Each model trains on {len(df)-1} rows...")
    else:
        log.important(f"Running leave-one-out CV on all {len(df)} samples...")
        loo_positions = np.arange(len(df))

    # Prepare full dataset for training (features kept as a DataFrame to preserve dtypes)
    X_full = df[feature_cols]
    y_full = df[target_col].values

    # Encode target if classifier
    label_encoder = LabelEncoder() if is_classifier else None
    if label_encoder is not None:
        y_full = label_encoder.fit_transform(y_full)

    # Generate LOO splits: every row position except the held-out one
    all_positions = np.arange(len(X_full))
    splits = [(np.delete(all_positions, pos), np.array([pos])) for pos in loo_positions]

    # Parallel execution
    # NOTE(review): these fits do not use the "sample_weight" column - confirm this is intended
    predictions = Parallel(n_jobs=4)(
        delayed(train_and_predict_one)(model_params, is_classifier, X_full, y_full, train_idx, val_idx)
        for train_idx, val_idx in tqdm(splits, desc="LOO CV")
    )

    # Build results dataframe
    predictions_array = np.array(predictions)
    if label_encoder is not None:
        predictions_array = label_encoder.inverse_transform(predictions_array.astype(int))

    predictions_df = pd.DataFrame(
        {
            id_col: df[id_col].iloc[loo_positions].values,
            target_col: df[target_col].iloc[loo_positions].values,
            "prediction": predictions_array,
        }
    )

    predictions_df["residual_abs"] = abs_residual(predictions_df[target_col], predictions_df["prediction"])

    return predictions_df
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
if __name__ == "__main__":
    """Exercise the Local Cross-Fold Utilities"""
    from workbench.api import Model
    from pprint import pprint

    # Run the same demo for one regression model and one classification model
    examples = (
        ("\n=== LOCAL CROSS FOLD REGRESSION EXAMPLE ===", "abalone-regression"),
        ("\n=== LOCAL CROSS FOLD CLASSIFICATION EXAMPLE ===", "wine-classification"),
    )
    for banner, model_name in examples:
        print(banner)
        model = Model(model_name)
        results, df = cross_fold_inference(model)

        # Show the metrics table, then a peek at the out-of-fold predictions
        pprint(results)
        print(df.head())
|