workbench 0.8.178__py3-none-any.whl → 0.8.179__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench has been flagged as potentially problematic.
- workbench/api/endpoint.py +3 -2
- workbench/core/artifacts/endpoint_core.py +5 -5
- workbench/core/artifacts/feature_set_core.py +32 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +37 -34
- workbench/model_scripts/custom_models/uq_models/mapie.template +35 -32
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +7 -22
- workbench/model_scripts/custom_models/uq_models/ngboost.template +5 -12
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
- workbench/model_scripts/pytorch_model/pytorch.template +9 -18
- workbench/model_scripts/quant_regression/quant_regression.template +5 -10
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/xgb_model/generated_model_script.py +24 -33
- workbench/model_scripts/xgb_model/xgb_model.template +23 -32
- workbench/utils/model_utils.py +2 -1
- workbench/utils/xgboost_model_utils.py +160 -137
- {workbench-0.8.178.dist-info → workbench-0.8.179.dist-info}/METADATA +1 -1
- {workbench-0.8.178.dist-info → workbench-0.8.179.dist-info}/RECORD +25 -25
- {workbench-0.8.178.dist-info → workbench-0.8.179.dist-info}/WHEEL +0 -0
- {workbench-0.8.178.dist-info → workbench-0.8.179.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.178.dist-info → workbench-0.8.179.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.178.dist-info → workbench-0.8.179.dist-info}/top_level.txt +0 -0
workbench/utils/xgboost_model_utils.py

@@ -4,6 +4,7 @@ import logging
 import os
 import tempfile
 import tarfile
+import joblib
 import pickle
 import glob
 import awswrangler as wr
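The new joblib import backs the loader changes below: a joblib (or pickle) artifact round-trips the full sklearn wrapper, including enable_categorical, while a JSON dump keeps only the raw Booster trees. A minimal sketch of the two save paths, with toy data and illustrative file names (not Workbench's actual artifact layout):

import joblib
import numpy as np
import xgboost as xgb

# Toy model just so there is something fitted to save
X = np.random.rand(20, 3)
y = X.sum(axis=1)
model = xgb.XGBRegressor(n_estimators=5, enable_categorical=True)
model.fit(X, y)

# Preferred: joblib round-trips the sklearn wrapper and its parameters
joblib.dump(model, "xgb_model.joblib")

# Fallback: JSON keeps only the raw trees; wrapper parameters such as
# enable_categorical are lost and must be reconstructed on load
model.get_booster().save_model("xgb_model.json")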
@@ -16,7 +17,6 @@ from typing import Dict, Any
 from sklearn.model_selection import KFold, StratifiedKFold
 from sklearn.metrics import (
     precision_recall_fscore_support,
-    confusion_matrix,
     mean_squared_error,
     mean_absolute_error,
     r2_score,
@@ -34,14 +34,12 @@ log = logging.getLogger("workbench")
 def xgboost_model_from_s3(model_artifact_uri: str):
     """
     Download and extract XGBoost model artifact from S3, then load the model into memory.
-    Handles both direct XGBoost model files and pickled models.
-    Ensures categorical feature support is enabled.

     Args:
         model_artifact_uri (str): S3 URI of the model artifact.

     Returns:
-        Loaded XGBoost model or None if unavailable.
+        Loaded XGBoost model (XGBClassifier, XGBRegressor, or Booster) or None if unavailable.
     """

     with tempfile.TemporaryDirectory() as tmpdir:
@@ -55,64 +53,86 @@ def xgboost_model_from_s3(model_artifact_uri: str):

         # Define model file patterns to search for (in order of preference)
         patterns = [
-            #
-            os.path.join(tmpdir, "
-            os.path.join(tmpdir, "
-            os.path.join(tmpdir, "model"),
-            os.path.join(tmpdir, "*.
+            # Joblib models (preferred - preserves everything)
+            os.path.join(tmpdir, "*model*.joblib"),
+            os.path.join(tmpdir, "xgb*.joblib"),
+            os.path.join(tmpdir, "**", "*model*.joblib"),
+            os.path.join(tmpdir, "**", "xgb*.joblib"),
+            # Pickle models (also preserves everything)
+            os.path.join(tmpdir, "*model*.pkl"),
+            os.path.join(tmpdir, "xgb*.pkl"),
+            os.path.join(tmpdir, "**", "*model*.pkl"),
+            os.path.join(tmpdir, "**", "xgb*.pkl"),
+            # JSON models (fallback - requires reconstruction)
+            os.path.join(tmpdir, "*model*.json"),
+            os.path.join(tmpdir, "xgb*.json"),
             os.path.join(tmpdir, "**", "*model*.json"),
-            os.path.join(tmpdir, "**", "
-            # Pickled models
-            os.path.join(tmpdir, "*.pkl"),
-            os.path.join(tmpdir, "**", "*.pkl"),
-            os.path.join(tmpdir, "*.pickle"),
-            os.path.join(tmpdir, "**", "*.pickle"),
+            os.path.join(tmpdir, "**", "xgb*.json"),
         ]

         # Try each pattern
         for pattern in patterns:
-            # Use glob to find all matching files
             for model_path in glob.glob(pattern, recursive=True):
-                #
+                # Skip files that are clearly not XGBoost models
+                filename = os.path.basename(model_path).lower()
+                if any(skip in filename for skip in ["label_encoder", "scaler", "preprocessor", "transformer"]):
+                    log.debug(f"Skipping non-model file: {model_path}")
+                    continue
+
                 _, ext = os.path.splitext(model_path)

                 try:
-                    if ext
-
+                    if ext == ".joblib":
+                        model = joblib.load(model_path)
+                        # Verify it's actually an XGBoost model
+                        if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
+                            log.important(f"Loaded XGBoost model from joblib: {model_path}")
+                            return model
+                        else:
+                            log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
+
+                    elif ext in [".pkl", ".pickle"]:
                         with open(model_path, "rb") as f:
                             model = pickle.load(f)
-
-
-
-                        log.important(f"Loaded XGBoost Booster from pickle: {model_path}")
+                        # Verify it's actually an XGBoost model
+                        if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
+                            log.important(f"Loaded XGBoost model from pickle: {model_path}")
                             return model
-
-                        log.
-
-
-
-                    # Handle direct XGBoost model files
+                        else:
+                            log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
+
+                    elif ext == ".json":
+                        # JSON files should be XGBoost models by definition
                         booster = xgb.Booster()
                         booster.load_model(model_path)
-                        log.important(f"Loaded XGBoost
+                        log.important(f"Loaded XGBoost booster from JSON: {model_path}")
                         return booster
+
                 except Exception as e:
-                    log.
-                    continue
+                    log.debug(f"Failed to load {model_path}: {e}")
+                    continue

-        # If no model found
         log.error("No XGBoost model found in the artifact.")
         return None


-def feature_importance(workbench_model, importance_type: str = "
+def feature_importance(workbench_model, importance_type: str = "gain") -> Optional[List[Tuple[str, float]]]:
     """
     Get sorted feature importances from a Workbench Model object.

     Args:
         workbench_model: Workbench model object
-        importance_type: Type of feature importance.
-
+        importance_type: Type of feature importance. Options:
+            - 'gain' (default): Average improvement in loss/objective when feature is used.
+              Best for understanding predictive power of features.
+            - 'weight': Number of times a feature appears in trees (split count).
+              Useful for understanding model complexity and feature usage frequency.
+            - 'cover': Average number of samples affected when feature is used.
+              Shows the relative quantity of observations related to this feature.
+            - 'total_gain': Total improvement in loss/objective across all splits.
+              Similar to 'gain' but not averaged (can be biased toward frequent features).
+            - 'total_cover': Total number of samples affected across all splits.
+              Similar to 'cover' but not averaged.

     Returns:
         List of tuples (feature, importance) sorted by importance value (descending).
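Because the loader can now hand back either a full sklearn estimator (joblib/pickle artifacts) or a bare Booster (legacy JSON artifacts), callers should branch on the returned type. A usage sketch, assuming a deployed Workbench model named "abalone-regression" and that Model is importable from workbench.api as in this script's __main__ block:

import xgboost as xgb
from workbench.api import Model
from workbench.utils.xgboost_model_utils import xgboost_model_from_s3

model = Model("abalone-regression")  # assumed to exist in your account
loaded = xgboost_model_from_s3(model.model_data_url())

if isinstance(loaded, (xgb.XGBClassifier, xgb.XGBRegressor)):
    # Full wrapper: hyperparameters survived the round trip
    print("sklearn wrapper, enable_categorical =", loaded.enable_categorical)
elif isinstance(loaded, xgb.Booster):
    # Legacy JSON artifact: raw trees only
    print("bare Booster, features =", loaded.feature_names)
else:
    print("no XGBoost model found in the artifact")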
@@ -121,7 +141,8 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Optional[List[Tuple[str, float]]]:

     Note:
         XGBoost's get_score() only returns features with non-zero importance.
-        This function ensures all model features are included in the output
+        This function ensures all model features are included in the output,
+        adding zero values for features that weren't used in any tree splits.
     """
     model_artifact_uri = workbench_model.model_data_url()
     xgb_model = xgboost_model_from_s3(model_artifact_uri)
@@ -129,11 +150,18 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Optional[List[Tuple[str, float]]]:
         log.error("No XGBoost model found in the artifact.")
         return None

-    #
-
+    # Check if we got a full sklearn model or just a booster (for backwards compatibility)
+    if hasattr(xgb_model, "get_booster"):
+        # Full sklearn model - get the booster for feature importance
+        booster = xgb_model.get_booster()
+        all_features = booster.feature_names
+    else:
+        # Already a booster (legacy JSON load)
+        booster = xgb_model
+        all_features = xgb_model.feature_names

-    # Get
-
+    # Get feature importances (only non-zero features)
+    importances = booster.get_score(importance_type=importance_type)

     # Create complete importance dict with zeros for missing features
     complete_importances = {feat: importances.get(feat, 0.0) for feat in all_features}
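The zero-fill comprehension matters because get_score() silently drops any feature that never appears in a split, so unused features would otherwise vanish from the report. The behavior in isolation, on toy data with a plain Booster:

import numpy as np
import xgboost as xgb

# Only f0 carries signal, so f1/f2 may never be used in a split
X = np.random.rand(100, 3)
y = 2.0 * X[:, 0]
dtrain = xgb.DMatrix(X, label=y, feature_names=["f0", "f1", "f2"])
booster = xgb.train({"max_depth": 2}, dtrain, num_boost_round=5)

importances = booster.get_score(importance_type="gain")  # omits unused features
complete = {f: importances.get(f, 0.0) for f in booster.feature_names}
print(sorted(complete.items(), key=lambda kv: kv[1], reverse=True))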
@@ -230,143 +258,132 @@ def leaf_stats(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
     return result_df


-def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Dict[str, Any]:
+def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Tuple[Dict[str, Any], pd.DataFrame]:
     """
     Performs K-fold cross-validation with detailed metrics.
     Args:
         workbench_model: Workbench model object
         nfolds: Number of folds for cross-validation (default is 5)
     Returns:
-
-        -
-
-
+        Tuple of:
+        - Dictionary containing:
+            - folds: Dictionary of formatted strings for each fold
+            - summary_metrics: Summary metrics across folds
+        - DataFrame with columns: id, target, prediction (out-of-fold predictions for all samples)
     """
     from workbench.api import FeatureSet

     # Load model
-    model_type = workbench_model.model_type.value
     model_artifact_uri = workbench_model.model_data_url()
-
-    if
+    loaded_model = xgboost_model_from_s3(model_artifact_uri)
+    if loaded_model is None:
         log.error("No XGBoost model found in the artifact.")
-        return {}
-
-
-
-
-
+        return {}, pd.DataFrame()
+
+    # Check if we got a full sklearn model or need to create one
+    if isinstance(loaded_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
+        xgb_model = loaded_model
+        is_classifier = isinstance(xgb_model, xgb.XGBClassifier)
+    elif isinstance(loaded_model, xgb.Booster):
+        # Legacy: got a booster, need to wrap it
+        log.warning("Deprecated: Loaded model is a Booster, wrapping in sklearn model.")
+        is_classifier = workbench_model.model_type.value == "classifier"
+        xgb_model = (
+            xgb.XGBClassifier(enable_categorical=True) if is_classifier else xgb.XGBRegressor(enable_categorical=True)
+        )
+        xgb_model._Booster = loaded_model
+    else:
+        log.error(f"Unexpected model type: {type(loaded_model)}")
+        return {}, pd.DataFrame()
+
     # Prepare data
     fs = FeatureSet(workbench_model.get_input())
     df = fs.view("training").pull_dataframe()
+
+    # Get id column - assuming FeatureSet has an id_column attribute or similar
+    id_col = fs.id_column
+    target_col = workbench_model.target()
     feature_cols = workbench_model.features()
+
     # Convert string features to categorical
     for col in feature_cols:
         if df[col].dtype in ["object", "string"]:
             df[col] = df[col].astype("category")
-    # Split X and y
-    X = df[workbench_model.features()]
-    y = df[workbench_model.target()]

-
+    X = df[feature_cols]
+    y = df[target_col]
+    ids = df[id_col]
+
+    # Encode target if classifier
     label_encoder = LabelEncoder() if is_classifier else None
     if label_encoder:
-
+        y_encoded = label_encoder.fit_transform(y)
+        y_for_cv = pd.Series(y_encoded, index=y.index, name=target_col)
+    else:
+        y_for_cv = y
+
     # Prepare KFold
-    kfold = (
-
-
-
-    )
+    kfold = (StratifiedKFold if is_classifier else KFold)(n_splits=nfolds, shuffle=True, random_state=42)
+
+    # Initialize results collection
+    fold_metrics = []
+    predictions_df = pd.DataFrame({id_col: ids, target_col: y})  # Keep original values
+    # Note: 'prediction' column will be created automatically with correct dtype

-
-
-    all_actuals = []
-    for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y)):
+    # Perform cross-validation
+    for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y_for_cv), 1):
         X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
-        y_train, y_val =
+        y_train, y_val = y_for_cv.iloc[train_idx], y_for_cv.iloc[val_idx]

-        # Train
+        # Train and predict
         xgb_model.fit(X_train, y_train)
         preds = xgb_model.predict(X_val)
-        all_predictions.extend(preds)
-        all_actuals.extend(y_val)

-        #
-
+        # Store predictions (decode if classifier)
+        val_indices = X_val.index
+        if is_classifier:
+            predictions_df.loc[val_indices, "prediction"] = label_encoder.inverse_transform(preds.astype(int))
+        else:
+            predictions_df.loc[val_indices, "prediction"] = preds

+        # Calculate fold metrics
         if is_classifier:
-
-
-
-
+            y_val_orig = label_encoder.inverse_transform(y_val)
+            preds_orig = label_encoder.inverse_transform(preds.astype(int))
+            prec, rec, f1, _ = precision_recall_fscore_support(
+                y_val_orig, preds_orig, average="weighted", zero_division=0
             )
-            fold_metrics.
+            fold_metrics.append({"fold": fold_idx, "precision": prec, "recall": rec, "fscore": f1})
         else:
-            fold_metrics.
+            fold_metrics.append(
                 {
-                    "
-                    "
-                    "
+                    "fold": fold_idx,
+                    "rmse": np.sqrt(mean_squared_error(y_val, preds)),
+                    "mae": mean_absolute_error(y_val, preds),
+                    "r2": r2_score(y_val, preds),
                 }
             )

-
-
-
-
-
-
-            scores = precision_recall_fscore_support(
-                all_actuals_original, all_predictions_original, average="weighted", zero_division=0
-            )
-            overall_metrics.update(
-                {
-                    "precision": float(scores[0]),
-                    "recall": float(scores[1]),
-                    "fscore": float(scores[2]),
-                    "confusion_matrix": confusion_matrix(
-                        all_actuals_original, all_predictions_original, labels=label_encoder.classes_
-                    ).tolist(),
-                    "label_names": list(label_encoder.classes_),
-                }
-            )
-        else:
-            overall_metrics.update(
-                {
-                    "rmse": float(np.sqrt(mean_squared_error(all_actuals, all_predictions))),
-                    "mae": float(mean_absolute_error(all_actuals, all_predictions)),
-                    "r2": float(r2_score(all_actuals, all_predictions)),
-                }
-            )
-    # Calculate summary metrics across folds
-    summary_metrics = {}
-    metrics_to_aggregate = ["precision", "recall", "fscore"] if is_classifier else ["rmse", "mae", "r2"]
-
-    for metric in metrics_to_aggregate:
-        values = [fold[metric] for fold in fold_results]
-        summary_metrics[metric] = f"{float(np.mean(values)):.3f} ±{float(np.std(values)):.3f}"
-    # Format fold results as strings (TBD section)
+    # Calculate summary metrics (mean ± std)
+    fold_df = pd.DataFrame(fold_metrics)
+    metric_names = ["precision", "recall", "fscore"] if is_classifier else ["rmse", "mae", "r2"]
+    summary_metrics = {metric: f"{fold_df[metric].mean():.3f} ±{fold_df[metric].std():.3f}" for metric in metric_names}
+
+    # Format fold results for display
     formatted_folds = {}
-    for
-        fold_key = f"Fold {
+    for _, row in fold_df.iterrows():
+        fold_key = f"Fold {int(row['fold'])}"
         if is_classifier:
             formatted_folds[fold_key] = (
-                f"precision: {
-                f"recall: {fold_data['recall']:.3f} "
-                f"fscore: {fold_data['fscore']:.3f}"
+                f"precision: {row['precision']:.3f} " f"recall: {row['recall']:.3f} " f"fscore: {row['fscore']:.3f}"
             )
         else:
-            formatted_folds[fold_key] =
-
-
-
-
-
-        # "overall_metrics": overall_metrics,
-        "folds": formatted_folds,
-    }
+            formatted_folds[fold_key] = f"rmse: {row['rmse']:.3f} " f"mae: {row['mae']:.3f} " f"r2: {row['r2']:.3f}"
+
+    # Build return dictionary
+    metrics_dict = {"summary_metrics": summary_metrics, "folds": formatted_folds}
+
+    return metrics_dict, predictions_df


 if __name__ == "__main__":
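The refactor replaces the old pooled all_predictions/all_actuals lists with index-aligned .loc writes into predictions_df, so every row receives exactly one out-of-fold prediction tied to its id. The same bookkeeping pattern in isolation, on toy data with a stand-in sklearn estimator rather than XGBoost:

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

df = pd.DataFrame({"id": range(10), "target": np.arange(10.0)})
X = df[["id"]].astype(float)
y = df["target"]
preds_df = df[["id", "target"]].copy()

for train_idx, val_idx in KFold(n_splits=5, shuffle=True, random_state=42).split(X):
    est = LinearRegression().fit(X.iloc[train_idx], y.iloc[train_idx])
    # Index-aligned write: each validation row gets its out-of-fold prediction
    preds_df.loc[X.index[val_idx], "prediction"] = est.predict(X.iloc[val_idx])

print(preds_df)  # every row holds exactly one out-of-fold prediction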
@@ -384,6 +401,10 @@ if __name__ == "__main__":
     model_artifact_uri = model.model_data_url()
     xgb_model = xgboost_model_from_s3(model_artifact_uri)

+    # Verify enable_categorical is preserved (for debugging/confidence)
+    print(f"Model parameters: {xgb_model.get_params()}")
+    print(f"enable_categorical: {xgb_model.enable_categorical}")
+
     # Test with UQ Model
     uq_model = Model("aqsol-uq")
     _xgb_model = xgboost_model_from_s3(uq_model.model_data_url())
@@ -408,10 +429,12 @@ if __name__ == "__main__":

     print("\n=== CROSS FOLD REGRESSION EXAMPLE ===")
     model = Model("abalone-regression")
-    results = cross_fold_inference(model)
+    results, df = cross_fold_inference(model)
     pprint(results)
+    print(df.head())

     print("\n=== CROSS FOLD CLASSIFICATION EXAMPLE ===")
     model = Model("wine-classification")
-    results = cross_fold_inference(model)
+    results, df = cross_fold_inference(model)
     pprint(results)
+    print(df.head())
{workbench-0.8.178.dist-info → workbench-0.8.179.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: workbench
-Version: 0.8.178
+Version: 0.8.179
 Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
 Author-email: SuperCowPowers LLC <support@supercowpowers.com>
 License-Expression: MIT
{workbench-0.8.178.dist-info → workbench-0.8.179.dist-info}/RECORD

@@ -31,7 +31,7 @@ workbench/api/__init__.py,sha256=kvrP70ypDOMdPGj_Eeftdh8J0lu_1qQVne6GXMkD4_E,102
 workbench/api/compound.py,sha256=kf5EaM5qjWwsZutcxqj9IC_MPnDV1uVHDMns9OA_GOo,2545
 workbench/api/data_source.py,sha256=Ngz36YZWxFfpJbmURhM1LQPYjh5kdpZNGo6_fCRePbA,8321
 workbench/api/df_store.py,sha256=Wybb3zO-jPpAi2Ns8Ks1-lagvXAaBlRpBZHhnnl3Lms,6131
-workbench/api/endpoint.py,sha256=
+workbench/api/endpoint.py,sha256=RDQs78OFRSp9GYGHHph0gLtAfF4p3We96bzdJ4-9bzw,3926
 workbench/api/feature_set.py,sha256=Yxei3tvWR4gSLcdJnNndux07dNeKNu1HKgsChJtHxEM,6633
 workbench/api/graph_store.py,sha256=LremJyPrQFgsHb7hxsctuCsoxx3p7TKtaY5qALHe6pc,4372
 workbench/api/meta.py,sha256=1_9989cPvf3hd3tA-83hLijOGNnhwXAF8aZF45adeDQ,8596
@@ -54,8 +54,8 @@ workbench/core/artifacts/cached_artifact_mixin.py,sha256=ngqFLZ4cQx_TFouXZgXZQsv
 workbench/core/artifacts/data_capture_core.py,sha256=q8f79rRTYiZ7T4IQRWXl8ZvPpcvZyNxYERwvo8o0OQc,14858
 workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
 workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
-workbench/core/artifacts/endpoint_core.py,sha256=
-workbench/core/artifacts/feature_set_core.py,sha256=
+workbench/core/artifacts/endpoint_core.py,sha256=S9QHtXrQXkCTY-coaYSsAGJi60nZj3Aq9ruyVE1PUQs,49224
+workbench/core/artifacts/feature_set_core.py,sha256=7b1o_PzxtwaYC-W2zxlkltiO0fYULA8CVGWwHNmqgtI,31457
 workbench/core/artifacts/model_core.py,sha256=ECDwQ0qM5qb1yGJ07U70BVdfkrW9m7p9e6YJWib3uR0,50855
 workbench/core/artifacts/monitor_core.py,sha256=M307yz7tEzOEHgv-LmtVy9jKjSbM98fHW3ckmNYrwlU,27897
 workbench/core/cloud_platform/cloud_meta.py,sha256=-g4-LTC3D0PXb3VfaXdLR1ERijKuHdffeMK_zhD-koQ,8809
@@ -132,36 +132,36 @@ workbench/model_scripts/custom_models/chem_info/requirements.txt,sha256=7HBUzvNi
 workbench/model_scripts/custom_models/meta_endpoints/example.py,sha256=hzOAuLhIGB8vei-555ruNxpsE1GhuByHGjGB0zw8GSs,1726
 workbench/model_scripts/custom_models/network_security/Readme.md,sha256=Z2gtiu0hLHvEJ1x-_oFq3qJZcsK81sceBAGAGltpqQ8,222
 workbench/model_scripts/custom_models/proximity/Readme.md,sha256=RlMFAJZgAT2mCgDk-UwR_R0Y_NbCqeI5-8DUsxsbpWQ,289
-workbench/model_scripts/custom_models/proximity/feature_space_proximity.template,sha256=
+workbench/model_scripts/custom_models/proximity/feature_space_proximity.template,sha256=eOllmqB20BWtTiV53dgpIqXKtgSbPFDW_zf8PvM3oF0,4813
 workbench/model_scripts/custom_models/proximity/generated_model_script.py,sha256=RdbKbXtrSNYQJvB-oLcRHpJ6w0TM7zbmMfuocHb7GM0,7967
 workbench/model_scripts/custom_models/proximity/proximity.py,sha256=zqmNlX70LnWXr5fdtFFQppSNTLjlOciQVrjGr-g9jRE,13716
 workbench/model_scripts/custom_models/proximity/requirements.txt,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workbench/model_scripts/custom_models/uq_models/Readme.md,sha256=UVpL-lvtTrLqwBeQFinLhd_uNrEw4JUlggIdUSDrd-w,188
-workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=
-workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=
-workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=
-workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=
-workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=
-workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=
-workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=
+workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=ca3CaAk6HVuNv1HnPgABTzRY3oDrRxomjgD4V1ZDwoc,6448
+workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=xlKLHeLQkScONnrlbAGIsrCm2wwsvcfv4Vdrw4nlc_8,13457
+workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=3nMlCi8nEbc4N-MQTzjfIcljfDQkUmWeLBfmd18m5fg,6632
+workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=hE54OgIsR20-5RBjqqjUp5Z_aMzrg1Ykhwx4J3kG8Pw,19401
+workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=FEBYxM-gVzwHWYga11DZf3ORUqmCrT0eQk8apixwx_E,19000
+workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=XTfhODRaHlI1jZGo9pSe-TqNsk2_nuSw0xMO2fKzDv8,14011
+workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=v1rviYTJGJnQRGgAyveXhOQlS-WFCTlc2vdnWq6HIXk,8241
 workbench/model_scripts/custom_models/uq_models/proximity.py,sha256=zqmNlX70LnWXr5fdtFFQppSNTLjlOciQVrjGr-g9jRE,13716
 workbench/model_scripts/custom_models/uq_models/requirements.txt,sha256=fw7T7t_YJAXK3T6Ysbesxh_Agx_tv0oYx72cEBTqRDY,98
 workbench/model_scripts/custom_script_example/custom_model_script.py,sha256=T8aydawgRVAdSlDimoWpXxG2YuWWQkbcjBVjAeSG2_0,6408
 workbench/model_scripts/custom_script_example/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
-workbench/model_scripts/ensemble_xgb/ensemble_xgb.template,sha256=
+workbench/model_scripts/ensemble_xgb/ensemble_xgb.template,sha256=pWmuo-EVz0owvkRI-h9mUTYt1-ouyD-_yyQu6SQbYZ4,10350
 workbench/model_scripts/ensemble_xgb/generated_model_script.py,sha256=dsjUGm22xI1ThGn97HPKtooyEPK-HOQnf5chnZ7-MXk,10675
 workbench/model_scripts/ensemble_xgb/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
 workbench/model_scripts/pytorch_model/generated_model_script.py,sha256=Mr1IMQJE_ML899qjzhjkrP521IjvcAvqU0pk--FB7KY,22356
-workbench/model_scripts/pytorch_model/pytorch.template,sha256=
+workbench/model_scripts/pytorch_model/pytorch.template,sha256=_gRp6DH294FLxF21UpSTq7s9RFfrLjViKvjXQ4yDfBQ,21999
 workbench/model_scripts/pytorch_model/requirements.txt,sha256=ICS5nW0wix44EJO2tJszJSaUrSvhSfdedn6FcRInGx4,181
-workbench/model_scripts/quant_regression/quant_regression.template,sha256=
+workbench/model_scripts/quant_regression/quant_regression.template,sha256=2F25lZ7m_VafHvuGrC__R3uB1NzKgZu94eWJS9sWpYg,9783
 workbench/model_scripts/quant_regression/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
 workbench/model_scripts/scikit_learn/generated_model_script.py,sha256=c73ZpJBlU5k13Nx-ZDkLXu7da40CYyhwjwwmuPq6uLg,12870
 workbench/model_scripts/scikit_learn/requirements.txt,sha256=aVvwiJ3LgBUhM_PyFlb2gHXu_kpGPho3ANBzlOkfcvs,107
-workbench/model_scripts/scikit_learn/scikit_learn.template,sha256=
-workbench/model_scripts/xgb_model/generated_model_script.py,sha256=
+workbench/model_scripts/scikit_learn/scikit_learn.template,sha256=QQvqx-eX9ZTbYmyupq6R6vIQwosmsmY_MRBPaHyfjdk,12586
+workbench/model_scripts/xgb_model/generated_model_script.py,sha256=Tbn7EMXxZZO8rDdKQ5fYCbpltACsMXNvuusLL9p-U5c,22319
 workbench/model_scripts/xgb_model/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
-workbench/model_scripts/xgb_model/xgb_model.template,sha256=
+workbench/model_scripts/xgb_model/xgb_model.template,sha256=9wed0Ii6KZpEiVj6fHGDkwpxPkLIhcSIDls7Ij0HUF0,17874
 workbench/repl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workbench/repl/workbench_shell.py,sha256=Vhg4BQr2r4D4ymekrVtOFi0MaRvaH4V2UgcWRgvN_3U,22122
 workbench/resources/open_source_api.key,sha256=3S0OTblsmC0msUPdE_dbBmI83xJNmYscuwLJ57JmuOc,433
@@ -219,7 +219,7 @@ workbench/utils/lambda_utils.py,sha256=7GhGRPyXn9o-toWb9HBGSnI8-DhK9YRkwhCSk_mNK
 workbench/utils/license_manager.py,sha256=sDuhk1mZZqUbFmnuFXehyGnui_ALxrmYBg7gYwoo7ho,6975
 workbench/utils/log_utils.py,sha256=7n1NJXO_jUX82e6LWAQug6oPo3wiPDBYsqk9gsYab_A,3167
 workbench/utils/markdown_utils.py,sha256=4lEqzgG4EVmLcvvKKNUwNxVCySLQKJTJmWDiaDroI1w,8306
-workbench/utils/model_utils.py,sha256=
+workbench/utils/model_utils.py,sha256=b71URPIsbkmFy2XEglftO47lQGIl5CpahjCbb2CnTpM,12801
 workbench/utils/monitor_utils.py,sha256=kVaJ7BgUXs3VPMFYfLC03wkIV4Dq-pEhoXS0wkJFxCc,7858
 workbench/utils/pandas_utils.py,sha256=uTUx-d1KYfjbS9PMQp2_9FogCV7xVZR6XLzU5YAGmfs,39371
 workbench/utils/performance_utils.py,sha256=WDNvz-bOdC99cDuXl0urAV4DJ7alk_V3yzKPwvqgST4,1329
@@ -242,7 +242,7 @@ workbench/utils/workbench_cache.py,sha256=IQchxB81iR4eVggHBxUJdXxUCRkqWz1jKe5gxN
 workbench/utils/workbench_event_bridge.py,sha256=z1GmXOB-Qs7VOgC6Hjnp2DI9nSEWepaSXejACxTIR7o,4150
 workbench/utils/workbench_logging.py,sha256=WCuMWhQwibrvcGAyj96h2wowh6dH7zNlDJ7sWUzdCeI,10263
 workbench/utils/workbench_sqs.py,sha256=RwM80z7YWwdtMaCKh7KWF8v38f7eBRU7kyC7ZhTRuI0,2072
-workbench/utils/xgboost_model_utils.py,sha256=
+workbench/utils/xgboost_model_utils.py,sha256=wSUrs9VlftaTZ-cWZMEeHY6TmcLvxwrKk4S4lr7kWWw,17482
 workbench/utils/chem_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workbench/utils/chem_utils/fingerprints.py,sha256=Qvs8jaUwguWUq3Q3j695MY0t0Wk3BvroW-oWBwalMUo,5255
 workbench/utils/chem_utils/misc.py,sha256=Nevf8_opu-uIPrv_1_0ubuFVVo2_fGUkMoLAHB3XAeo,7372
@@ -287,9 +287,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
 workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
 workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
 workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
-workbench-0.8.178.dist-info/licenses/LICENSE,sha256=
-workbench-0.8.178.dist-info/METADATA,sha256=
-workbench-0.8.178.dist-info/WHEEL,sha256=
-workbench-0.8.178.dist-info/entry_points.txt,sha256=
-workbench-0.8.178.dist-info/top_level.txt,sha256=
-workbench-0.8.178.dist-info/RECORD,,
+workbench-0.8.179.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
+workbench-0.8.179.dist-info/METADATA,sha256=d1akwCm_MgH4MPba49bBYehDcMGFY5GhfXTmswPKJbM,9210
+workbench-0.8.179.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+workbench-0.8.179.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
+workbench-0.8.179.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
+workbench-0.8.179.dist-info/RECORD,,
The remaining files (WHEEL, entry_points.txt, licenses/LICENSE, top_level.txt) are unchanged between versions.