workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/utils/metrics_utils.py (new file)
@@ -0,0 +1,256 @@
+"""Metrics utilities for computing model performance from predictions."""
+
+import logging
+from typing import List, Optional
+
+import numpy as np
+import pandas as pd
+from scipy.stats import spearmanr
+from sklearn.metrics import (
+    mean_absolute_error,
+    median_absolute_error,
+    precision_recall_fscore_support,
+    r2_score,
+    roc_auc_score,
+    root_mean_squared_error,
+)
+
+log = logging.getLogger("workbench")
+
+
+def validate_proba_columns(predictions_df: pd.DataFrame, class_labels: List[str], guessing: bool = False) -> bool:
+    """Validate that probability columns match class labels.
+
+    Args:
+        predictions_df: DataFrame with prediction results
+        class_labels: List of class labels
+        guessing: Whether class labels were guessed from data
+
+    Returns:
+        True if validation passes
+
+    Raises:
+        ValueError: If probability columns don't match class labels
+    """
+    proba_columns = [col.replace("_proba", "") for col in predictions_df.columns if col.endswith("_proba")]
+
+    if sorted(class_labels) != sorted(proba_columns):
+        label_type = "GUESSED class_labels" if guessing else "class_labels"
+        raise ValueError(f"_proba columns {proba_columns} != {label_type} {class_labels}!")
+    return True
+
+
+def compute_classification_metrics(
+    predictions_df: pd.DataFrame,
+    target_col: str,
+    class_labels: Optional[List[str]] = None,
+    prediction_col: str = "prediction",
+) -> pd.DataFrame:
+    """Compute classification metrics from a predictions DataFrame.
+
+    Args:
+        predictions_df: DataFrame with target and prediction columns
+        target_col: Name of the target column
+        class_labels: List of class labels in order (if None, inferred from target column)
+        prediction_col: Name of the prediction column (default: "prediction")
+
+    Returns:
+        DataFrame with per-class metrics (precision, recall, f1, roc_auc, support)
+        plus a weighted 'all' row. Returns empty DataFrame if validation fails.
+    """
+    # Validate inputs
+    if predictions_df.empty:
+        log.warning("Empty DataFrame provided. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if prediction_col not in predictions_df.columns:
+        log.warning(f"Prediction column '{prediction_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if target_col not in predictions_df.columns:
+        log.warning(f"Target column '{target_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle NaN predictions
+    df = predictions_df.copy()
+    nan_pred = df[prediction_col].isnull().sum()
+    if nan_pred > 0:
+        log.warning(f"Dropping {nan_pred} rows with NaN predictions.")
+        df = df[~df[prediction_col].isnull()]
+
+    if df.empty:
+        log.warning("No valid rows after dropping NaNs. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle class labels
+    guessing = False
+    if class_labels is None:
+        log.warning("Class labels not provided. Inferring from target column.")
+        class_labels = df[target_col].unique().tolist()
+        guessing = True
+
+    # Validate probability columns if present
+    proba_cols = [col for col in df.columns if col.endswith("_proba")]
+    if proba_cols:
+        validate_proba_columns(df, class_labels, guessing=guessing)
+
+    y_true = df[target_col]
+    y_pred = df[prediction_col]
+
+    # Precision, recall, f1, support per class
+    prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=class_labels, zero_division=0)
+
+    # ROC AUC per class (requires probability columns and sorted labels)
+    proba_col_names = [f"{label}_proba" for label in class_labels]
+    if all(col in df.columns for col in proba_col_names):
+        # roc_auc_score requires labels to be sorted, so we sort and reorder results back
+        sorted_labels = sorted(class_labels)
+        sorted_proba_cols = [f"{label}_proba" for label in sorted_labels]
+        y_score_sorted = df[sorted_proba_cols].values
+        roc_auc_sorted = roc_auc_score(y_true, y_score_sorted, labels=sorted_labels, multi_class="ovr", average=None)
+        # Map back to original class_labels order
+        label_to_auc = dict(zip(sorted_labels, roc_auc_sorted))
+        roc_auc = np.array([label_to_auc[label] for label in class_labels])
+    else:
+        roc_auc = np.array([None] * len(class_labels))
+
+    # Build per-class metrics
+    metrics_df = pd.DataFrame(
+        {
+            target_col: class_labels,
+            "precision": prec,
+            "recall": rec,
+            "f1": f1,
+            "roc_auc": roc_auc,
+            "support": support.astype(int),
+        }
+    )
+
+    # Add weighted 'all' row
+    total = support.sum()
+    all_row = {
+        target_col: "all",
+        "precision": (prec * support).sum() / total,
+        "recall": (rec * support).sum() / total,
+        "f1": (f1 * support).sum() / total,
+        "roc_auc": (roc_auc * support).sum() / total if roc_auc[0] is not None else None,
+        "support": int(total),
+    }
+    metrics_df = pd.concat([metrics_df, pd.DataFrame([all_row])], ignore_index=True)
+
+    return metrics_df
+
+
+def compute_regression_metrics(
+    predictions_df: pd.DataFrame,
+    target_col: str,
+    prediction_col: str = "prediction",
+) -> pd.DataFrame:
+    """Compute regression metrics from a predictions DataFrame.
+
+    Args:
+        predictions_df: DataFrame with target and prediction columns
+        target_col: Name of the target column
+        prediction_col: Name of the prediction column (default: "prediction")
+
+    Returns:
+        DataFrame with regression metrics (rmse, mae, medae, r2, spearmanr, support)
+        Returns empty DataFrame if validation fails or no valid data.
+    """
+    # Validate inputs
+    if predictions_df.empty:
+        log.warning("Empty DataFrame provided. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if prediction_col not in predictions_df.columns:
+        log.warning(f"Prediction column '{prediction_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    if target_col not in predictions_df.columns:
+        log.warning(f"Target column '{target_col}' not found in DataFrame. Returning empty metrics.")
+        return pd.DataFrame()
+
+    # Handle NaN values
+    df = predictions_df[[target_col, prediction_col]].copy()
+    nan_target = df[target_col].isnull().sum()
+    nan_pred = df[prediction_col].isnull().sum()
+    if nan_target > 0 or nan_pred > 0:
+        log.warning(f"NaNs found: {target_col}={nan_target}, {prediction_col}={nan_pred}. Dropping NaN rows.")
+        df = df.dropna()
+
+    if df.empty:
+        log.warning("No valid rows after dropping NaNs. Returning empty metrics.")
+        return pd.DataFrame()
+
+    y_true = df[target_col].values
+    y_pred = df[prediction_col].values
+
+    return pd.DataFrame(
+        [
+            {
+                "rmse": root_mean_squared_error(y_true, y_pred),
+                "mae": mean_absolute_error(y_true, y_pred),
+                "medae": median_absolute_error(y_true, y_pred),
+                "r2": r2_score(y_true, y_pred),
+                "spearmanr": spearmanr(y_true, y_pred).correlation,
+                "support": len(y_true),
+            }
+        ]
+    )
+
+
+def compute_metrics_from_predictions(
+    predictions_df: pd.DataFrame,
+    target_col: str,
+    class_labels: Optional[List[str]] = None,
+    prediction_col: str = "prediction",
+) -> pd.DataFrame:
+    """Compute metrics from a predictions DataFrame.
+
+    Automatically determines if this is classification or regression based on
+    whether class_labels is provided.
+
+    Args:
+        predictions_df: DataFrame with target and prediction columns
+        target_col: Name of the target column
+        class_labels: List of class labels for classification (None for regression)
+        prediction_col: Name of the prediction column (default: "prediction")
+
+    Returns:
+        DataFrame with computed metrics
+    """
+    if target_col not in predictions_df.columns:
+        raise ValueError(f"Target column '{target_col}' not found in predictions DataFrame")
+    if prediction_col not in predictions_df.columns:
+        raise ValueError(f"Prediction column '{prediction_col}' not found in predictions DataFrame")
+
+    if class_labels:
+        return compute_classification_metrics(predictions_df, target_col, class_labels, prediction_col)
+    else:
+        return compute_regression_metrics(predictions_df, target_col, prediction_col)
+
+
+if __name__ == "__main__":
+    # Test with sample data
+    print("Testing classification metrics...")
+    class_df = pd.DataFrame(
+        {
+            "target": ["a", "b", "c", "a", "b", "c", "a", "b", "c", "a"],
+            "prediction": ["a", "b", "c", "a", "b", "a", "a", "b", "c", "b"],
+            "a_proba": [0.8, 0.1, 0.1, 0.7, 0.2, 0.4, 0.9, 0.1, 0.1, 0.3],
+            "b_proba": [0.1, 0.8, 0.1, 0.2, 0.7, 0.3, 0.05, 0.8, 0.2, 0.6],
+            "c_proba": [0.1, 0.1, 0.8, 0.1, 0.1, 0.3, 0.05, 0.1, 0.7, 0.1],
+        }
+    )
+    metrics = compute_metrics_from_predictions(class_df, "target", ["a", "b", "c"])
+    print(metrics.to_string(index=False))
+
+    print("\nTesting regression metrics...")
+    reg_df = pd.DataFrame(
+        {
+            "target": [1.0, 2.0, 3.0, 4.0, 5.0],
+            "prediction": [1.1, 2.2, 2.9, 4.1, 4.8],
+        }
+    )
+    metrics = compute_metrics_from_predictions(reg_df, "target")
+    print(metrics.to_string(index=False))
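
For reference, a minimal usage sketch (not part of the diff) against the new module. It exercises the inferred-label path of compute_classification_metrics: when class_labels is omitted, the labels are guessed from the target column (guessing=True) and the *_proba columns are validated against that guess. Note that compute_metrics_from_predictions treats a missing class_labels as regression, so this path is only reachable by calling compute_classification_metrics directly. The low/mid/high columns below are invented for illustration.

import pandas as pd

from workbench.utils.metrics_utils import compute_classification_metrics

preds = pd.DataFrame(
    {
        "target": ["low", "mid", "high", "low", "mid", "high"],
        "prediction": ["low", "mid", "high", "mid", "mid", "high"],
        "low_proba": [0.7, 0.2, 0.1, 0.4, 0.1, 0.1],
        "mid_proba": [0.2, 0.6, 0.2, 0.5, 0.8, 0.2],
        "high_proba": [0.1, 0.2, 0.7, 0.1, 0.1, 0.7],
    }
)

# Logs a warning, infers ["low", "mid", "high"] from the target column, and
# returns per-class precision/recall/f1/roc_auc plus the weighted "all" row.
metrics = compute_classification_metrics(preds, target_col="target")
print(metrics.to_string(index=False))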