workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of workbench might be problematic.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/utils/model_utils.py
CHANGED

@@ -3,6 +3,7 @@
 import logging
 import pandas as pd
 import numpy as np
+from scipy.stats import spearmanr
 import importlib.resources
 from pathlib import Path
 import os
@@ -92,13 +93,158 @@ def get_custom_script_path(package: str, script_name: str) -> Path:
     return script_path


-def
-    """Create a
+def proximity_model_local(model: "Model", include_all_columns: bool = False):
+    """Create a FeatureSpaceProximity Model for this Model
+
+    Args:
+        model (Model): The Model/FeatureSet used to create the proximity model
+        include_all_columns (bool): Include all DataFrame columns in neighbor results (default: False)
+
+    Returns:
+        FeatureSpaceProximity: The proximity model
+    """
+    from workbench.algorithms.dataframe.feature_space_proximity import FeatureSpaceProximity  # noqa: F401
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the Proximity Model from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the FeatureSpaceProximity Model
+    return FeatureSpaceProximity(
+        full_df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
+    )
+
+
+def fingerprint_prox_model_local(
+    model: "Model",
+    include_all_columns: bool = False,
+    radius: int = 2,
+    n_bits: int = 1024,
+    counts: bool = False,
+):
+    """Create a FingerprintProximity Model for this Model
+
+    Args:
+        model (Model): The Model used to create the fingerprint proximity model
+        include_all_columns (bool): Include all DataFrame columns in neighbor results (default: False)
+        radius (int): Morgan fingerprint radius (default: 2)
+        n_bits (int): Number of bits for the fingerprint (default: 1024)
+        counts (bool): Use count fingerprints instead of binary (default: False)
+
+    Returns:
+        FingerprintProximity: The fingerprint proximity model
+    """
+    from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity  # noqa: F401
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+    # Get Target Column from the existing given Model
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the Proximity Model from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the FingerprintProximity Model
+    return FingerprintProximity(
+        full_df,
+        id_column=id_column,
+        target=target,
+        include_all_columns=include_all_columns,
+        radius=radius,
+        n_bits=n_bits,
+    )
+
+
+def noise_model_local(model: "Model"):
+    """Create a NoiseModel for detecting noisy/problematic samples in a Model's training data.
+
+    Args:
+        model (Model): The Model used to create the noise model
+
+    Returns:
+        NoiseModel: The noise model with precomputed noise scores for all samples
+    """
+    from workbench.algorithms.models.noise_model import NoiseModel  # noqa: F401 (avoid circular import)
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the NoiseModel from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the NoiseModel
+    return NoiseModel(full_df, id_column, features, target)
+
+
+def cleanlab_model_local(model: "Model"):
+    """Create a CleanlabModels instance for detecting data quality issues in a Model's training data.
+
+    Args:
+        model (Model): The Model used to create the cleanlab models
+
+    Returns:
+        CleanlabModels: Factory providing access to CleanLearning and Datalab models.
+            - clean_learning(): CleanLearning model with enhanced get_label_issues()
+            - datalab(): Datalab instance with report(), get_issues()
+    """
+    from workbench.algorithms.models.cleanlab_model import create_cleanlab_model  # noqa: F401 (avoid circular import)
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+    model_type = model.model_type
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Get the full FeatureSet data
+    full_df = fs.pull_dataframe()
+
+    # Create and return the CleanLearning model
+    return create_cleanlab_model(full_df, id_column, features, target, model_type=model_type)
+
+
+def published_proximity_model(model: "Model", prox_model_name: str, include_all_columns: bool = False) -> "Model":
+    """Create a published proximity model based on the given model

     Args:
         model (Model): The model to create the proximity model from
         prox_model_name (str): The name of the proximity model to create
-
+        include_all_columns (bool): Include all DataFrame columns in results (default: False)
     Returns:
         Model: The proximity model
     """
@@ -121,45 +267,23 @@ def proximity_model(model: "Model", prox_model_name: str, track_columns: list =
         description=f"Proximity Model for {model.name}",
         tags=["proximity", model.name],
         custom_script=script_path,
-        custom_args={"
+        custom_args={"include_all_columns": include_all_columns},
     )
     return prox_model


-def
-    """Create a Uncertainty Quantification (UQ) model based on the given model
-
-    Args:
-        model (Model): The model to create the UQ model from
-        uq_model_name (str): The name of the UQ model to create
-        train_all_data (bool, optional): Whether to train the UQ model on all data (default: False)
-
-    Returns:
-        Model: The UQ model
+def safe_extract_tarfile(tar_path: str, extract_path: str) -> None:
     """
-
+    Extract a tarball safely, using data filter if available.

-
-
-
-
-
-
-
-
-    fs = FeatureSet(model.get_input())
-    uq_model = fs.to_model(
-        name=uq_model_name,
-        model_type=ModelType.UQ_REGRESSOR,
-        feature_list=features,
-        target_column=target,
-        description=f"UQ Model for {model.name}",
-        tags=["uq", model.name],
-        train_all_data=train_all_data,
-        custom_script=script_path,
-        custom_args={"id_column": fs.id_column, "track_columns": [target]},
-    )
-    return uq_model
+    The filter parameter was backported to Python 3.8+, 3.9+, 3.10.13+, 3.11+
+    as a security patch, but may not be present in older patch versions.
+    """
+    with tarfile.open(tar_path, "r:gz") as tar:
+        if hasattr(tarfile, "data_filter"):
+            tar.extractall(path=extract_path, filter="data")
+        else:
+            tar.extractall(path=extract_path)


 def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
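Note (added context, not part of the diff): the new safe_extract_tarfile() helper above centralizes the PEP 706 extraction-filter check, so callers no longer call tar.extractall() directly. A minimal, hedged usage sketch with placeholder paths:

from workbench.utils.model_utils import safe_extract_tarfile

# "model.tar.gz" and "/tmp/model" are illustrative paths, not values from the diff
safe_extract_tarfile("model.tar.gz", "/tmp/model")
# When tarfile.data_filter is available, members whose paths would land outside the
# destination directory and special files (devices, pipes) are rejected during extraction.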
@@ -180,8 +304,7 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
         wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)

         # Extract tarball
-
-            tar.extractall(path=tmpdir, filter="data")
+        safe_extract_tarfile(local_tar_path, tmpdir)

         # Look for category mappings in base directory only
         mappings_path = os.path.join(tmpdir, "category_mappings.json")
@@ -197,6 +320,63 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
     return category_mappings


+def load_hyperparameters_from_s3(model_artifact_uri: str) -> Optional[dict]:
+    """
+    Download and extract hyperparameters from a model artifact in S3.
+
+    Args:
+        model_artifact_uri (str): S3 URI of the model artifact (model.tar.gz).
+
+    Returns:
+        dict: The loaded hyperparameters or None if not found.
+    """
+    hyperparameters = None
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # Download model artifact
+        local_tar_path = os.path.join(tmpdir, "model.tar.gz")
+        wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)
+
+        # Extract tarball
+        safe_extract_tarfile(local_tar_path, tmpdir)
+
+        # Look for hyperparameters in base directory only
+        hyperparameters_path = os.path.join(tmpdir, "hyperparameters.json")
+
+        if os.path.exists(hyperparameters_path):
+            try:
+                with open(hyperparameters_path, "r") as f:
+                    hyperparameters = json.load(f)
+                log.info(f"Loaded hyperparameters from {hyperparameters_path}")
+            except Exception as e:
+                log.warning(f"Failed to load hyperparameters from {hyperparameters_path}: {e}")
+
+    return hyperparameters
+
+
+def get_model_hyperparameters(workbench_model: Any) -> Optional[dict]:
+    """Get the hyperparameters used to train a Workbench model.
+
+    This retrieves the hyperparameters.json file from the model artifacts
+    that was saved during model training.
+
+    Args:
+        workbench_model: Workbench model object
+
+    Returns:
+        dict: The hyperparameters used during training, or None if not found
+    """
+    # Get the model artifact URI
+    model_artifact_uri = workbench_model.model_data_url()
+
+    if model_artifact_uri is None:
+        log.warning(f"No model artifact found for {workbench_model.uuid}")
+        return None
+
+    log.info(f"Loading hyperparameters from {model_artifact_uri}")
+    return load_hyperparameters_from_s3(model_artifact_uri)
+
+
 def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     """
     Evaluate uncertainty quantification model with essential metrics.
@@ -217,31 +397,51 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     if "prediction" not in df.columns:
         raise ValueError("Prediction column 'prediction' not found in DataFrame.")

+    # Drop rows with NaN predictions (e.g., from models that can't handle missing features)
+    n_total = len(df)
+    df = df.dropna(subset=["prediction", target_col])
+    n_valid = len(df)
+    if n_valid < n_total:
+        log.info(f"UQ metrics: dropped {n_total - n_valid} rows with NaN predictions")
+
     # --- Coverage and Interval Width ---
     if "q_025" in df.columns and "q_975" in df.columns:
         lower_95, upper_95 = df["q_025"], df["q_975"]
+        lower_90, upper_90 = df["q_05"], df["q_95"]
+        lower_80, upper_80 = df["q_10"], df["q_90"]
+        lower_68 = df.get("q_16", df["q_10"])  # fallback to 80% interval
+        upper_68 = df.get("q_84", df["q_90"])  # fallback to 80% interval
         lower_50, upper_50 = df["q_25"], df["q_75"]
     elif "prediction_std" in df.columns:
         lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
         upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
+        lower_90 = df["prediction"] - 1.645 * df["prediction_std"]
+        upper_90 = df["prediction"] + 1.645 * df["prediction_std"]
+        lower_80 = df["prediction"] - 1.282 * df["prediction_std"]
+        upper_80 = df["prediction"] + 1.282 * df["prediction_std"]
+        lower_68 = df["prediction"] - 1.0 * df["prediction_std"]
+        upper_68 = df["prediction"] + 1.0 * df["prediction_std"]
         lower_50 = df["prediction"] - 0.674 * df["prediction_std"]
         upper_50 = df["prediction"] + 0.674 * df["prediction_std"]
     else:
         raise ValueError(
             "Either quantile columns (q_025, q_975, q_25, q_75) or 'prediction_std' column must be present."
         )
+    median_std = df["prediction_std"].median()
     coverage_95 = np.mean((df[target_col] >= lower_95) & (df[target_col] <= upper_95))
-
-
-
+    coverage_90 = np.mean((df[target_col] >= lower_90) & (df[target_col] <= upper_90))
+    coverage_80 = np.mean((df[target_col] >= lower_80) & (df[target_col] <= upper_80))
+    coverage_68 = np.mean((df[target_col] >= lower_68) & (df[target_col] <= upper_68))
+    median_width_95 = np.median(upper_95 - lower_95)
+    median_width_90 = np.median(upper_90 - lower_90)
+    median_width_80 = np.median(upper_80 - lower_80)
+    median_width_50 = np.median(upper_50 - lower_50)
+    median_width_68 = np.median(upper_68 - lower_68)

     # --- CRPS (measures calibration + sharpness) ---
-
-
-
-        mean_crps = np.mean(crps)
-    else:
-        mean_crps = np.nan
+    z = (df[target_col] - df["prediction"]) / df["prediction_std"]
+    crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
+    mean_crps = np.mean(crps)

     # --- Interval Score @ 95% (penalizes miscoverage) ---
     alpha_95 = 0.05
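Background on the constants in the hunk above (explanatory note, not part of the diff): when only prediction_std is available, the intervals are central Gaussian intervals, prediction ± z·std with z = 0.674, 1.0, 1.282, 1.645, 1.96 giving roughly 50%, 68.3%, 80%, 90%, and 95% coverage. The CRPS line is the closed form of the continuous ranked probability score for a Gaussian predictive distribution:

\mathrm{CRPS}\big(\mathcal{N}(\mu,\sigma^2),\,y\big) = \sigma\left[\,z\,(2\Phi(z)-1) + 2\varphi(z) - \tfrac{1}{\sqrt{\pi}}\,\right], \qquad z = \frac{y-\mu}{\sigma}

with \Phi and \varphi the standard normal CDF and PDF (norm.cdf and norm.pdf in the code).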
@@ -252,38 +452,50 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     )
     mean_is_95 = np.mean(is_95)

-    # ---
+    # --- Interval to Error Correlation ---
     abs_residuals = np.abs(df[target_col] - df["prediction"])
-
-
+    width_68 = upper_68 - lower_68
+
+    # Spearman correlation for robustness
+    interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]

     # Collect results
     results = {
+        "coverage_68": coverage_68,
+        "coverage_80": coverage_80,
+        "coverage_90": coverage_90,
         "coverage_95": coverage_95,
-        "
-        "
-        "
-        "
-        "
-        "
+        "median_std": median_std,
+        "median_width_50": median_width_50,
+        "median_width_68": median_width_68,
+        "median_width_80": median_width_80,
+        "median_width_90": median_width_90,
+        "median_width_95": median_width_95,
+        "interval_to_error_corr": interval_to_error_corr,
         "n_samples": len(df),
     }

     print("\n=== UQ Metrics ===")
+    print(f"Coverage @ 68%: {coverage_68:.3f} (target: 0.68)")
+    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
+    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
     print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
-    print(f"
-    print(f"
-    print(f"
+    print(f"Median Prediction StdDev: {median_std:.3f}")
+    print(f"Median 50% Width: {median_width_50:.3f}")
+    print(f"Median 68% Width: {median_width_68:.3f}")
+    print(f"Median 80% Width: {median_width_80:.3f}")
+    print(f"Median 90% Width: {median_width_90:.3f}")
+    print(f"Median 95% Width: {median_width_95:.3f}")
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
-    print(f"
+    print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
     print(f"Samples: {len(df)}")
     return results


 if __name__ == "__main__":
     """Exercise the Model Utilities"""
-    from workbench.api import Model
+    from workbench.api import Model

     # Get the instance information
     print(model_instance_info())
@@ -298,24 +510,11 @@ if __name__ == "__main__":
     # Get the custom script path
     print(get_custom_script_path("chem_info", "molecular_descriptors.py"))

-    # Test
+    # Test loading hyperparameters
     m = Model("aqsol-regression")
+    hyperparams = get_model_hyperparameters(m)
+    print(hyperparams)
+
+    # Test the proximity model
     # prox_model = proximity_model(m, "aqsol-prox")
     # print(prox_model)#
-
-    # Test the UQ model
-    # uq_model_instance = uq_model(m, "aqsol-uq")
-    # print(uq_model_instance)
-    # uq_model_instance.to_endpoint()
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq-100")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
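Taken together, the new helpers in this file can be exercised roughly like the module's own __main__ block. A hedged sketch; the artifact names ("aqsol-regression", "aqsol-uq") and target column ("solubility") are the examples used elsewhere in this diff and may not exist in a given account:

from workbench.api import Endpoint, Model
from workbench.utils.model_utils import (
    get_model_hyperparameters,
    noise_model_local,
    proximity_model_local,
    uq_metrics,
)

m = Model("aqsol-regression")
print(get_model_hyperparameters(m))            # hyperparameters.json pulled from the model artifact in S3

prox = proximity_model_local(m)                # FeatureSpaceProximity built over the full FeatureSet
noise = noise_model_local(m)                   # NoiseModel with precomputed noise scores

df = Endpoint("aqsol-uq").auto_inference(capture=True)    # as in the removed __main__ example
print(uq_metrics(df, target_col="solubility"))            # coverage, widths, CRPS, interval/error correlation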
workbench/utils/monitor_utils.py
CHANGED

@@ -14,7 +14,7 @@ from workbench.utils.s3_utils import read_content_from_s3
 log = logging.getLogger("workbench")


-def
+def pull_data_capture_for_testing(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
     """
     Read and process captured data from S3.

@@ -26,7 +26,12 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non

     Returns:
         Union[pd.DataFrame, None]: A dataframe of the captured data (or None if no data is found).
+
+    Notes:
+        This method is really only for testing and debugging.
     """
+    log.important("This method is for testing and debugging only.")
+
     # List files in the specified S3 path
     files = wr.s3.list_objects(data_capture_path)
     if not files:
@@ -64,59 +69,53 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non
 def process_data_capture(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Process the captured data DataFrame to extract input and output data.
-
+    Handles cases where input or output might not be captured.
+
     Args:
         df (DataFrame): DataFrame with captured data.
     Returns:
         tuple[DataFrame, DataFrame]: Input and output DataFrames.
     """
+
+    def parse_endpoint_data(data: dict) -> pd.DataFrame:
+        """Parse endpoint data based on encoding type."""
+        encoding = data["encoding"].upper()
+
+        if encoding == "CSV":
+            return pd.read_csv(StringIO(data["data"]))
+        elif encoding == "JSON":
+            json_data = json.loads(data["data"])
+            if isinstance(json_data, dict):
+                return pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
+            else:
+                return pd.DataFrame(json_data)
+        else:
+            return None  # Unknown encoding
+
     input_dfs = []
     output_dfs = []

-
+    # Use itertuples() instead of iterrows() for better performance
+    for row in df.itertuples(index=True):
         try:
-            capture_data = row
-
-            #
-            if "endpointInput"
-
-
-
-
-
-
-            )
-
-
-            # Process input data
-            input_data = capture_data["endpointInput"]
-            if input_data["encoding"].upper() == "CSV":
-                input_df = pd.read_csv(StringIO(input_data["data"]))
-            elif input_data["encoding"].upper() == "JSON":
-                json_data = json.loads(input_data["data"])
-                if isinstance(json_data, dict):
-                    input_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    input_df = pd.DataFrame(json_data)
-
-            # Process output data
-            output_data = capture_data["endpointOutput"]
-            if output_data["encoding"].upper() == "CSV":
-                output_df = pd.read_csv(StringIO(output_data["data"]))
-            elif output_data["encoding"].upper() == "JSON":
-                json_data = json.loads(output_data["data"])
-                if isinstance(json_data, dict):
-                    output_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    output_df = pd.DataFrame(json_data)
-
-            # If we get here, both processed successfully
-            input_dfs.append(input_df)
-            output_dfs.append(output_df)
+            capture_data = row.captureData
+
+            # Process input data if present
+            if "endpointInput" in capture_data:
+                input_df = parse_endpoint_data(capture_data["endpointInput"])
+                if input_df is not None:
+                    input_dfs.append(input_df)
+
+            # Process output data if present
+            if "endpointOutput" in capture_data:
+                output_df = parse_endpoint_data(capture_data["endpointOutput"])
+                if output_df is not None:
+                    output_dfs.append(output_df)

         except Exception as e:
-            log.
+            log.debug(f"Row {row.Index}: Failed to process row: {e}")
             continue
+
     # Combine and return results
     return (
         pd.concat(input_dfs, ignore_index=True) if input_dfs else pd.DataFrame(),
@@ -178,23 +177,6 @@ def parse_monitoring_results(results_json: str) -> Dict[str, Any]:
         return {"error": str(e)}


-"""TEMP
-# If the status is "CompletedWithViolations", we grab the lastest
-# violation file and add it to the result
-if status == "CompletedWithViolations":
-    violation_file = f"{self.monitoring_path}/
-    {last_run['CreationTime'].strftime('%Y/%m/%d')}/constraint_violations.json"
-    if wr.s3.does_object_exist(violation_file):
-        violations_json = read_content_from_s3(violation_file)
-        violations = parse_monitoring_results(violations_json)
-        result["violations"] = violations.get("constraint_violations", [])
-        result["violation_count"] = len(result["violations"])
-    else:
-        result["violations"] = []
-        result["violation_count"] = 0
-"""
-
-
 def preprocessing_script(feature_list: list[str]) -> str:
     """
     A preprocessing script for monitoring jobs.
@@ -245,8 +227,8 @@ if __name__ == "__main__":
     from workbench.api.monitor import Monitor

     # Test pulling data capture
-    mon = Monitor("
-    df =
+    mon = Monitor("abalone-regression-rt")
+    df = pull_data_capture_for_testing(mon.data_capture_path)
     print("Data Capture:")
     print(df.head())

@@ -262,4 +244,4 @@ if __name__ == "__main__":
     # Test preprocessing script
     script = preprocessing_script(["feature1", "feature2", "feature3"])
     print("\nPreprocessing Script:")
-    print(script)
+    # print(script)
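For reference (added context, not part of the diff), process_data_capture() now expects a DataFrame with a captureData column of dicts, where endpointInput and endpointOutput are optional and each carries an encoding ("CSV" or "JSON") plus a raw data payload. A hedged, self-contained sketch with made-up values:

import pandas as pd

from workbench.utils.monitor_utils import process_data_capture

# Illustrative capture record; the feature and prediction values are invented for the example
record = {
    "endpointInput": {"encoding": "CSV", "data": "feature_1,feature_2\n0.1,0.2\n"},
    "endpointOutput": {"encoding": "JSON", "data": '{"prediction": [0.42]}'},
}
df = pd.DataFrame({"captureData": [record]})

# Rows with missing keys or unknown encodings are skipped rather than raising
input_df, output_df = process_data_capture(df)
print(input_df)   # one row of features
print(output_df)  # one row with the prediction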
workbench/utils/pandas_utils.py
CHANGED

@@ -152,7 +152,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li

     # Check for differences in common columns
     for column in common_columns:
-        if pd.api.types.is_string_dtype(df1[column])
+        if pd.api.types.is_string_dtype(df1[column]) and pd.api.types.is_string_dtype(df2[column]):
             # String comparison with NaNs treated as equal
             differences = ~(df1[column].fillna("") == df2[column].fillna(""))
         elif pd.api.types.is_float_dtype(df1[column]) or pd.api.types.is_float_dtype(df2[column]):
@@ -161,8 +161,8 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
                 pd.isna(df1[column]) & pd.isna(df2[column])
             )
         else:
-            # Other types (
-            differences =
+            # Other types (int, Int64, etc.) - compare with NaNs treated as equal
+            differences = (df1[column] != df2[column]) & ~(pd.isna(df1[column]) & pd.isna(df2[column]))

         # If differences exist, display them
         if differences.any():
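A small illustration (added, not part of the diff) of the behavior the new else-branch guards against: with a plain !=, two missing values compare as different, so the extra ~(isna & isna) mask keeps NaN/NaN pairs from being reported as changes.

import numpy as np
import pandas as pd

s1 = pd.Series([1, 2, np.nan], dtype="object")
s2 = pd.Series([1, 3, np.nan], dtype="object")

naive = s1 != s2                                      # [False, True, True]  -> NaN vs NaN flagged
nan_safe = (s1 != s2) & ~(pd.isna(s1) & pd.isna(s2))  # [False, True, False] -> only the real mismatch
print(naive.tolist(), nan_safe.tolist())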
|