workbench-0.8.174-py3-none-any.whl → workbench-0.8.227-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/workbench/utils/metrics_utils.py
@@ -0,0 +1,256 @@
+"""Metrics utilities for computing model performance from predictions."""
+
+import logging
+from typing import List, Optional
+
+import numpy as np
+import pandas as pd
+from scipy.stats import spearmanr
+from sklearn.metrics import (
+    mean_absolute_error,
+    median_absolute_error,
+    precision_recall_fscore_support,
+    r2_score,
+    roc_auc_score,
+    root_mean_squared_error,
+)
+
+log = logging.getLogger("workbench")
+
+
+def validate_proba_columns(predictions_df: pd.DataFrame, class_labels: List[str], guessing: bool = False) -> bool:
+    """Validate that probability columns match class labels.
+
+    Args:
+        predictions_df: DataFrame with prediction results
+        class_labels: List of class labels
+        guessing: Whether class labels were guessed from data
+
+    Returns:
+        True if validation passes
+
+    Raises:
+        ValueError: If probability columns don't match class labels
+    """
+    proba_columns = [col.replace("_proba", "") for col in predictions_df.columns if col.endswith("_proba")]
+
+    if sorted(class_labels) != sorted(proba_columns):
+        label_type = "GUESSED class_labels" if guessing else "class_labels"
+        raise ValueError(f"_proba columns {proba_columns} != {label_type} {class_labels}!")
+    return True
+
+
+def compute_classification_metrics(
+    predictions_df: pd.DataFrame,
+    target_col: str,
+    class_labels: Optional[List[str]] = None,
+    prediction_col: str = "prediction",
+) -> pd.DataFrame:
+    """Compute classification metrics from a predictions DataFrame.
+
+    Args:
+        predictions_df: DataFrame with target and prediction columns
+        target_col: Name of the target column
+        class_labels: List of class labels in order (if None, inferred from target column)
+        prediction_col: Name of the prediction column (default: "prediction")
+
+    Returns:
+        DataFrame with per-class metrics (precision, recall, f1, roc_auc, support)
+        plus a weighted 'all' row. Returns empty DataFrame if validation fails.
+    """
+    # Validate inputs
+    if predictions_df.empty:
+        log.warning("Empty DataFrame provided. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if prediction_col not in predictions_df.columns:
+        log.warning(f"Prediction column '{prediction_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if target_col not in predictions_df.columns:
+        log.warning(f"Target column '{target_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle NaN predictions
+    df = predictions_df.copy()
+    nan_pred = df[prediction_col].isnull().sum()
+    if nan_pred > 0:
+        log.warning(f"Dropping {nan_pred} rows with NaN predictions.")
+        df = df[~df[prediction_col].isnull()]
+
+    if df.empty:
+        log.warning("No valid rows after dropping NaNs. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle class labels
+    guessing = False
+    if class_labels is None:
+        log.warning("Class labels not provided. Inferring from target column.")
+        class_labels = df[target_col].unique().tolist()
+        guessing = True
+
+    # Validate probability columns if present
+    proba_cols = [col for col in df.columns if col.endswith("_proba")]
+    if proba_cols:
+        validate_proba_columns(df, class_labels, guessing=guessing)
+
+    y_true = df[target_col]
+    y_pred = df[prediction_col]
+
+    # Precision, recall, f1, support per class
+    prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=class_labels, zero_division=0)
+
+    # ROC AUC per class (requires probability columns and sorted labels)
+    proba_col_names = [f"{label}_proba" for label in class_labels]
+    if all(col in df.columns for col in proba_col_names):
+        # roc_auc_score requires labels to be sorted, so we sort and reorder results back
+        sorted_labels = sorted(class_labels)
+        sorted_proba_cols = [f"{label}_proba" for label in sorted_labels]
+        y_score_sorted = df[sorted_proba_cols].values
+        roc_auc_sorted = roc_auc_score(y_true, y_score_sorted, labels=sorted_labels, multi_class="ovr", average=None)
+        # Map back to original class_labels order
+        label_to_auc = dict(zip(sorted_labels, roc_auc_sorted))
+        roc_auc = np.array([label_to_auc[label] for label in class_labels])
+    else:
+        roc_auc = np.array([None] * len(class_labels))
+
+    # Build per-class metrics
+    metrics_df = pd.DataFrame(
+        {
+            target_col: class_labels,
+            "precision": prec,
+            "recall": rec,
+            "f1": f1,
+            "roc_auc": roc_auc,
+            "support": support.astype(int),
+        }
+    )
+
+    # Add weighted 'all' row
+    total = support.sum()
+    all_row = {
+        target_col: "all",
+        "precision": (prec * support).sum() / total,
+        "recall": (rec * support).sum() / total,
+        "f1": (f1 * support).sum() / total,
+        "roc_auc": (roc_auc * support).sum() / total if roc_auc[0] is not None else None,
+        "support": int(total),
+    }
+    metrics_df = pd.concat([metrics_df, pd.DataFrame([all_row])], ignore_index=True)
+
+    return metrics_df
+
+
+def compute_regression_metrics(
+    predictions_df: pd.DataFrame,
+    target_col: str,
+    prediction_col: str = "prediction",
+) -> pd.DataFrame:
+    """Compute regression metrics from a predictions DataFrame.
+
+    Args:
+        predictions_df: DataFrame with target and prediction columns
+        target_col: Name of the target column
+        prediction_col: Name of the prediction column (default: "prediction")
+
+    Returns:
+        DataFrame with regression metrics (rmse, mae, medae, r2, spearmanr, support)
+        Returns empty DataFrame if validation fails or no valid data.
+    """
+    # Validate inputs
+    if predictions_df.empty:
+        log.warning("Empty DataFrame provided. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if prediction_col not in predictions_df.columns:
+        log.warning(f"Prediction column '{prediction_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if target_col not in predictions_df.columns:
+        log.warning(f"Target column '{target_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle NaN values
+    df = predictions_df[[target_col, prediction_col]].copy()
+    nan_target = df[target_col].isnull().sum()
+    nan_pred = df[prediction_col].isnull().sum()
+    if nan_target > 0 or nan_pred > 0:
+        log.warning(f"NaNs found: {target_col}={nan_target}, {prediction_col}={nan_pred}. Dropping NaN rows.")
+        df = df.dropna()
+
+    if df.empty:
+        log.warning("No valid rows after dropping NaNs. Returning empty metrics.")
+        return pd.DataFrame()
+
+    y_true = df[target_col].values
+    y_pred = df[prediction_col].values
+
+    return pd.DataFrame(
+        [
+            {
+                "rmse": root_mean_squared_error(y_true, y_pred),
+                "mae": mean_absolute_error(y_true, y_pred),
+                "medae": median_absolute_error(y_true, y_pred),
+                "r2": r2_score(y_true, y_pred),
+                "spearmanr": spearmanr(y_true, y_pred).correlation,
+                "support": len(y_true),
+            }
+        ]
+    )
+
+
+def compute_metrics_from_predictions(
+    predictions_df: pd.DataFrame,
+    target_col: str,
+    class_labels: Optional[List[str]] = None,
+    prediction_col: str = "prediction",
+) -> pd.DataFrame:
+    """Compute metrics from a predictions DataFrame.
+
+    Automatically determines if this is classification or regression based on
+    whether class_labels is provided.
+
+    Args:
+        predictions_df: DataFrame with target and prediction columns
+        target_col: Name of the target column
+        class_labels: List of class labels for classification (None for regression)
+        prediction_col: Name of the prediction column (default: "prediction")
+
+    Returns:
+        DataFrame with computed metrics
+    """
+    if target_col not in predictions_df.columns:
+        raise ValueError(f"Target column '{target_col}' not found in predictions DataFrame")
+    if prediction_col not in predictions_df.columns:
+        raise ValueError(f"Prediction column '{prediction_col}' not found in predictions DataFrame")
+
+    if class_labels:
+        return compute_classification_metrics(predictions_df, target_col, class_labels, prediction_col)
+    else:
+        return compute_regression_metrics(predictions_df, target_col, prediction_col)
+
+
+if __name__ == "__main__":
+    # Test with sample data
+    print("Testing classification metrics...")
+    class_df = pd.DataFrame(
+        {
+            "target": ["a", "b", "c", "a", "b", "c", "a", "b", "c", "a"],
+            "prediction": ["a", "b", "c", "a", "b", "a", "a", "b", "c", "b"],
+            "a_proba": [0.8, 0.1, 0.1, 0.7, 0.2, 0.4, 0.9, 0.1, 0.1, 0.3],
+            "b_proba": [0.1, 0.8, 0.1, 0.2, 0.7, 0.3, 0.05, 0.8, 0.2, 0.6],
+            "c_proba": [0.1, 0.1, 0.8, 0.1, 0.1, 0.3, 0.05, 0.1, 0.7, 0.1],
+        }
+    )
+    metrics = compute_metrics_from_predictions(class_df, "target", ["a", "b", "c"])
+    print(metrics.to_string(index=False))
+
+    print("\nTesting regression metrics...")
+    reg_df = pd.DataFrame(
+        {
+            "target": [1.0, 2.0, 3.0, 4.0, 5.0],
+            "prediction": [1.1, 2.2, 2.9, 4.1, 4.8],
+        }
+    )
+    metrics = compute_metrics_from_predictions(reg_df, "target")
+    print(metrics.to_string(index=False))