workbench 0.8.174__py3-none-any.whl → 0.8.176__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of workbench might be problematic.
- workbench/core/artifacts/artifact.py +2 -2
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +197 -96
- workbench/model_scripts/custom_models/uq_models/mapie.template +1 -9
- {workbench-0.8.174.dist-info → workbench-0.8.176.dist-info}/METADATA +1 -1
- {workbench-0.8.174.dist-info → workbench-0.8.176.dist-info}/RECORD +11 -11
- {workbench-0.8.174.dist-info → workbench-0.8.176.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.176.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.176.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.176.dist-info}/top_level.txt +0 -0
```diff
--- a/workbench/core/artifacts/artifact.py
+++ b/workbench/core/artifacts/artifact.py
@@ -238,8 +238,8 @@ class Artifact(ABC):
         """

         # Check for ReadOnly Role
-        if self.aws_account_clamp.
-            self.log.info("Cannot add metadata with a ReadOnly
+        if self.aws_account_clamp.read_only:
+            self.log.info("Cannot add metadata with a ReadOnly Permissions...")
             return

         # Sanity check
```
```diff
--- a/workbench/core/artifacts/data_capture_core.py
+++ b/workbench/core/artifacts/data_capture_core.py
@@ -231,6 +231,13 @@ class DataCaptureCore:
             self.log.info(f"Processing {len(files)} files from {from_date} onwards.")
         else:
             self.log.info(f"Processing all {len(files)} files...")
+
+        # Check if any files remain after filtering
+        if not files:
+            self.log.info("No files to process after date filtering.")
+            return pd.DataFrame(), pd.DataFrame()
+
+        # Sort files by name (assumed to include timestamp)
         files.sort()

         # Get all timestamps in one batch if needed
@@ -337,7 +344,7 @@ if __name__ == "__main__":
     # print(pred_df.head())

     # Check that data capture is working
-    input_df, output_df = dc.get_captured_data()
+    input_df, output_df = dc.get_captured_data(from_date="2025-09-01")
     if input_df.empty and output_df.empty:
         print("No data capture files found, for a new endpoint it may take a few minutes to start capturing data")
     else:
```
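With the new guard, `get_captured_data()` returns a pair of empty DataFrames when date filtering leaves nothing to process, rather than running the rest of the pipeline on an empty file list. A toy sketch of the pattern (the hypothetical `load_captures` stands in for the real method, which reads capture files from storage; filenames are assumed to embed an ISO timestamp so a lexicographic compare doubles as date filtering):

```python
import pandas as pd

def load_captures(files, from_date=None):
    """Toy stand-in for DataCaptureCore.get_captured_data()."""
    if from_date:
        files = [f for f in files if f >= from_date]  # lexicographic date filter
    # The guard added in this release: bail out with empty frames instead of
    # continuing into the processing loop with nothing to do
    if not files:
        return pd.DataFrame(), pd.DataFrame()
    files.sort()  # sort by name (assumed to include timestamp)
    # ...the real code parses each capture file; here we just echo the names
    return pd.DataFrame({"file": files}), pd.DataFrame({"file": files})

input_df, output_df = load_captures(["2025-08-30.jsonl", "2025-09-02.jsonl"], from_date="2025-09-01")
print(input_df.empty)  # False: one file survives the filter
```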
```diff
--- a/workbench/core/cloud_platform/aws/aws_account_clamp.py
+++ b/workbench/core/cloud_platform/aws/aws_account_clamp.py
@@ -55,9 +55,10 @@ class AWSAccountClamp:
         # Check our Assume Role
         self.log.info("Checking Workbench Assumed Role...")
         role_info = self.aws_session.assumed_role_info()
+        self.log.info(f"Assumed Role: {role_info}")

-        # Check if
-        self.
+        # Check if we have tag write permissions (if we don't, we are read-only)
+        self.read_only = not self.check_tag_permissions()

         # Check our Workbench API Key and Load the License
         self.log.info("Checking Workbench API License...")
@@ -141,6 +142,45 @@ class AWSAccountClamp:
         """
         return self.boto3_session.client("sagemaker")

+    def check_tag_permissions(self):
+        """Check if current role has permission to add tags to SageMaker endpoints.
+
+        Returns:
+            bool: True if AddTags is allowed, False otherwise
+        """
+        try:
+            sagemaker = self.boto3_session.client("sagemaker")
+
+            # Use a non-existent endpoint name
+            fake_endpoint = "workbench-permission-check-dummy-endpoint"
+
+            # Try to add tags to the non-existent endpoint
+            sagemaker.add_tags(
+                ResourceArn=f"arn:aws:sagemaker:{self.region}:{self.account_id}:endpoint/{fake_endpoint}",
+                Tags=[{"Key": "PermissionCheck", "Value": "Test"}],
+            )
+
+            # If we get here, we have permission (but endpoint doesn't exist)
+            return True
+
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+
+            # AccessDeniedException = no permission
+            if error_code == "AccessDeniedException":
+                self.log.debug("No AddTags permission (AccessDeniedException)")
+                return False
+
+            # ResourceNotFound = we have permission, but endpoint doesn't exist
+            elif error_code in ["ResourceNotFound", "ValidationException"]:
+                self.log.debug("AddTags permission verified (resource not found)")
+                return True
+
+            # Unexpected error, assume no permission for safety
+            else:
+                self.log.debug(f"Unexpected error checking permissions: {error_code}")
+                return False
+

 if __name__ == "__main__":
     """Exercise the AWS Account Clamp Class"""
@@ -165,3 +205,9 @@ if __name__ == "__main__":
     print("\n\n*** AWS Sagemaker Session/Client Check ***")
     sm_client = aws_account_clamp.sagemaker_client()
     print(sm_client.list_feature_groups()["FeatureGroupSummaries"])
+
+    print("\n\n*** AWS Tag Permission Check ***")
+    if aws_account_clamp.check_tag_permissions():
+        print("Tag Permission Check Success...")
+    else:
+        print("Tag Permission Check Failed...")
```
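The new `check_tag_permissions` probe exploits the fact that IAM authorization is evaluated before SageMaker looks up the resource: an `AccessDeniedException` against a dummy ARN means tagging is blocked, while a not-found style error means the call cleared IAM. A minimal standalone sketch of the same pattern (`can_add_tags` is a hypothetical helper; assumes default boto3 credentials):

```python
import boto3
from botocore.exceptions import ClientError

def can_add_tags(region: str, account_id: str) -> bool:
    """Probe AddTags permission by tagging an endpoint that does not exist."""
    sm = boto3.client("sagemaker", region_name=region)
    arn = f"arn:aws:sagemaker:{region}:{account_id}:endpoint/permission-probe-dummy"
    try:
        sm.add_tags(ResourceArn=arn, Tags=[{"Key": "PermissionCheck", "Value": "Test"}])
        return True  # effectively unreachable: the endpoint does not exist
    except ClientError as e:
        code = e.response["Error"]["Code"]
        # Not-found/validation errors mean IAM authorized the call;
        # AccessDeniedException (or anything unexpected) counts as no permission
        return code in ("ResourceNotFound", "ValidationException")
```

A nice property of this probe is that it never mutates anything: either IAM rejects the call outright, or SageMaker rejects it because the endpoint is fictitious.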
```diff
--- a/workbench/model_scripts/custom_models/uq_models/generated_model_script.py
+++ b/workbench/model_scripts/custom_models/uq_models/generated_model_script.py
@@ -1,7 +1,7 @@
-# Model:
-from
-from
-from xgboost import XGBRegressor
+# Model: XGBoost for point predictions + LightGBM with MAPIE for conformalized intervals
+from mapie.regression import ConformalizedQuantileRegressor
+from lightgbm import LGBMRegressor
+from xgboost import XGBRegressor
 from sklearn.model_selection import train_test_split

 # Model Performance Scores
@@ -20,19 +20,12 @@ import numpy as np
 import pandas as pd
 from typing import List, Tuple

-# Local Imports
-from proximity import Proximity
-
-
-
 # Template Placeholders
 TEMPLATE_PARAMS = {
-    "id_column": "udm_mol_id",
     "target": "udm_asy_res_value",
-    "features": ['bcut2d_logplow', '
+    "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "compressed_features": [],
-    "train_all_data":
-    "track_columns": "udm_asy_res_value"
+    "train_all_data": True
 }


@@ -108,7 +101,7 @@ def convert_categorical_types(df: pd.DataFrame, features: list, category_mappings


 def decompress_features(
-
+    df: pd.DataFrame, features: List[str], compressed_features: List[str]
 ) -> Tuple[pd.DataFrame, List[str]]:
     """Prepare features for the model by decompressing bitstring features

@@ -164,13 +157,11 @@ def decompress_features(

 if __name__ == "__main__":
     # Template Parameters
-    id_column = TEMPLATE_PARAMS["id_column"]
     target = TEMPLATE_PARAMS["target"]
     features = TEMPLATE_PARAMS["features"]
     orig_features = features.copy()
     compressed_features = TEMPLATE_PARAMS["compressed_features"]
     train_all_data = TEMPLATE_PARAMS["train_all_data"]
-    track_columns = TEMPLATE_PARAMS["track_columns"]  # Can be None
     validation_split = 0.2

     # Script arguments for input/output directories
@@ -228,78 +219,167 @@ if __name__ == "__main__":
     print(f"FIT/TRAIN: {df_train.shape}")
     print(f"VALIDATION: {df_val.shape}")

-    # We're using XGBoost for point predictions and NGBoost for uncertainty quantification
-    xgb_model = XGBRegressor()
-    ngb_model = NGBRegressor()  # Dist=Cauchy) Seems to give HUGE prediction intervals
-    ngb_model = NGBRegressor(
-        Dist=T,
-        learning_rate=0.005,
-        minibatch_frac=0.1,  # Very small batches
-        col_sample=0.8  # This parameter DOES exist
-    )  # Testing this out
-    print("NGBoost using T distribution for uncertainty quantification")
-
     # Prepare features and targets for training
     X_train = df_train[features]
     X_validate = df_val[features]
     y_train = df_train[target]
     y_validate = df_val[target]

-    # Train
+    # Train XGBoost for point predictions
+    print("\nTraining XGBoost for point predictions...")
+    xgb_model = XGBRegressor(enable_categorical=True)
     xgb_model.fit(X_train, y_train)
-
-
-
-
-
-
-
-
-
-
-    print(f"
-
-
+
+    # Evaluate XGBoost performance
+    y_pred_xgb = xgb_model.predict(X_validate)
+    xgb_rmse = root_mean_squared_error(y_validate, y_pred_xgb)
+    xgb_mae = mean_absolute_error(y_validate, y_pred_xgb)
+    xgb_r2 = r2_score(y_validate, y_pred_xgb)
+
+    print(f"\nXGBoost Point Prediction Performance:")
+    print(f"RMSE: {xgb_rmse:.3f}")
+    print(f"MAE: {xgb_mae:.3f}")
+    print(f"R2: {xgb_r2:.3f}")
+
+    # Define confidence levels we want to model
+    confidence_levels = [0.50, 0.80, 0.90, 0.95]  # 50%, 80%, 90%, 95% confidence intervals
+
+    # Store MAPIE models for each confidence level
+    mapie_models = {}
+
+    # Train models for each confidence level
+    for confidence_level in confidence_levels:
+        alpha = 1 - confidence_level
+        lower_q = alpha / 2
+        upper_q = 1 - alpha / 2
+
+        print(f"\nTraining quantile models for {confidence_level * 100:.0f}% confidence interval...")
+        print(f"  Quantiles: {lower_q:.3f}, {upper_q:.3f}, 0.500")
+
+        # Train three models for this confidence level
+        quantile_estimators = []
+        for q in [lower_q, upper_q, 0.5]:
+            print(f"  Training model for quantile {q:.3f}...")
+            est = LGBMRegressor(
+                objective="quantile",
+                alpha=q,
+                n_estimators=1000,
+                max_depth=6,
+                learning_rate=0.01,
+                num_leaves=31,
+                min_child_samples=20,
+                subsample=0.8,
+                colsample_bytree=0.8,
+                random_state=42,
+                verbose=-1,
+                force_col_wise=True
+            )
+            est.fit(X_train, y_train)
+            quantile_estimators.append(est)
+
+        # Create MAPIE CQR model for this confidence level
+        print(f"  Setting up MAPIE CQR for {confidence_level * 100:.0f}% confidence...")
+        mapie_model = ConformalizedQuantileRegressor(
+            quantile_estimators,
+            confidence_level=confidence_level,
+            prefit=True
+        )
+
+        # Conformalize the model
+        print(f"  Conformalizing with validation data...")
+        mapie_model.conformalize(X_validate, y_validate)
+
+        # Store the model
+        mapie_models[f"mapie_{confidence_level:.2f}"] = mapie_model
+
+        # Validate coverage for this confidence level
+        y_pred, y_pis = mapie_model.predict_interval(X_validate)
+        coverage = np.mean((y_validate >= y_pis[:, 0, 0]) & (y_validate <= y_pis[:, 1, 0]))
+        print(f"  Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
+
+    print(f"\nOverall Model Performance Summary:")
+    print(f"XGBoost RMSE: {xgb_rmse:.3f}")
+    print(f"XGBoost MAE: {xgb_mae:.3f}")
+    print(f"XGBoost R2: {xgb_r2:.3f}")
     print(f"NumRows: {len(df_val)}")

+    # Analyze interval widths across confidence levels
+    print(f"\nInterval Width Analysis:")
+    for conf_level in confidence_levels:
+        model = mapie_models[f"mapie_{conf_level:.2f}"]
+        _, y_pis = model.predict_interval(X_validate)
+        widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
+        print(f"  {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
+
     # Save the trained XGBoost model
     xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))

-    # Save
-
+    # Save all MAPIE models
+    for model_name, model in mapie_models.items():
+        joblib.dump(model, os.path.join(args.model_dir, f"{model_name}.joblib"))

-    # Save the
+    # Save the feature list
     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-        json.dump(
-
-    #
-
+        json.dump(features, fp)
+
+    # Save category mappings if any
+    if category_mappings:
+        with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
+            json.dump(category_mappings, fp)
+
+    # Save model configuration
+    model_config = {
+        "model_type": "XGBoost_MAPIE_CQR_LightGBM",
+        "confidence_levels": confidence_levels,
+        "n_features": len(features),
+        "target": target,
+        "validation_metrics": {
+            "xgb_rmse": float(xgb_rmse),
+            "xgb_mae": float(xgb_mae),
+            "xgb_r2": float(xgb_r2),
+            "n_validation": len(df_val)
+        }
+    }
+    with open(os.path.join(args.model_dir, "model_config.json"), "w") as fp:
+        json.dump(model_config, fp, indent=2)

-
-    model
+    print(f"\nModel training complete!")
+    print(f"Saved 1 XGBoost model and {len(mapie_models)} MAPIE models to {args.model_dir}")


 #
 # Inference Section
 #
 def model_fn(model_dir) -> dict:
-    """Load
+    """Load XGBoost and all MAPIE models from the specified directory."""
+
+    # Load model configuration to know which models to load
+    with open(os.path.join(model_dir, "model_config.json")) as fp:
+        config = json.load(fp)

     # Load XGBoost regressor
     xgb_path = os.path.join(model_dir, "xgb_model.json")
     xgb_model = XGBRegressor(enable_categorical=True)
     xgb_model.load_model(xgb_path)

-    # Load
-
+    # Load all MAPIE models
+    mapie_models = {}
+    for conf_level in config["confidence_levels"]:
+        model_name = f"mapie_{conf_level:.2f}"
+        mapie_models[model_name] = joblib.load(os.path.join(model_dir, f"{model_name}.joblib"))

-    #
-
+    # Load category mappings if they exist
+    category_mappings = {}
+    category_path = os.path.join(model_dir, "category_mappings.json")
+    if os.path.exists(category_path):
+        with open(category_path) as fp:
+            category_mappings = json.load(fp)

     return {
-        "
-        "
-        "
+        "xgb_model": xgb_model,
+        "mapie_models": mapie_models,
+        "confidence_levels": config["confidence_levels"],
+        "category_mappings": category_mappings
     }

```
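The replacement training flow is conformalized quantile regression (CQR): for each target confidence level, fit lower, upper, and median LightGBM quantile models, wrap them in MAPIE with `prefit=True`, calibrate on held-out data via `conformalize()`, and read intervals from `predict_interval()`. A condensed, self-contained sketch of one such level (synthetic data; assumes the MAPIE v1-style API used in the hunk above, where `y_pis` has shape `(n_samples, 2, 1)`):

```python
import numpy as np
from lightgbm import LGBMRegressor
from mapie.regression import ConformalizedQuantileRegressor
from sklearn.model_selection import train_test_split

# Synthetic regression data
rng = np.random.default_rng(42)
X = rng.normal(size=(2000, 5))
y = 3 * X[:, 0] + np.sin(X[:, 1]) + rng.normal(scale=0.5, size=2000)
X_train, X_conf, y_train, y_conf = train_test_split(X, y, test_size=0.3, random_state=42)

confidence_level = 0.90
alpha = 1 - confidence_level

# One quantile model each for the lower bound, upper bound, and median
estimators = [
    LGBMRegressor(objective="quantile", alpha=q, verbose=-1).fit(X_train, y_train)
    for q in [alpha / 2, 1 - alpha / 2, 0.5]
]

# Wrap the prefit estimators, then calibrate the interval offsets on held-out data
cqr = ConformalizedQuantileRegressor(estimators, confidence_level=confidence_level, prefit=True)
cqr.conformalize(X_conf, y_conf)

# y_pis[:, 0, 0] is the lower bound, y_pis[:, 1, 0] the upper bound
y_pred, y_pis = cqr.predict_interval(X_conf)
coverage = np.mean((y_conf >= y_pis[:, 0, 0]) & (y_conf <= y_pis[:, 1, 0]))
print(f"Empirical coverage at {confidence_level:.0%} target: {coverage:.1%}")
```

Note that the script conformalizes and checks coverage on the same validation split, so the reported coverage will sit close to the target largely by construction; a separate test split would give an independent estimate.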
```diff
--- a/workbench/model_scripts/custom_models/uq_models/generated_model_script.py
+++ b/workbench/model_scripts/custom_models/uq_models/generated_model_script.py
@@ -315,7 +395,7 @@ def input_fn(input_data, content_type):
     if "text/csv" in content_type:
         return pd.read_csv(StringIO(input_data))
     elif "application/json" in content_type:
-        return pd.DataFrame(json.loads(input_data))
+        return pd.DataFrame(json.loads(input_data))
     else:
         raise ValueError(f"{content_type} not supported!")

@@ -323,23 +403,26 @@ def input_fn(input_data, content_type):
 def output_fn(output_df, accept_type):
     """Supports both CSV and JSON output formats."""
     if "text/csv" in accept_type:
-
+        # Convert categorical columns to string to avoid fillna issues
+        for col in output_df.select_dtypes(include=['category']).columns:
+            output_df[col] = output_df[col].astype(str)
+        csv_output = output_df.fillna("N/A").to_csv(index=False)
         return csv_output, "text/csv"
     elif "application/json" in accept_type:
-        return output_df.to_json(orient="records"), "application/json"
+        return output_df.to_json(orient="records"), "application/json"
     else:
         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")


 def predict_fn(df, models) -> pd.DataFrame:
-    """Make
+    """Make predictions using XGBoost for point estimates and MAPIE for conformalized intervals

     Args:
         df (pd.DataFrame): The input DataFrame
-        models (dict):
+        models (dict): Dictionary containing XGBoost and MAPIE models

     Returns:
-        pd.DataFrame:
+        pd.DataFrame: DataFrame with XGBoost predictions and conformalized intervals
     """

     # Grab our feature columns (from training)
```
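The `output_fn` change exists because `fillna()` on a pandas categorical column raises a `TypeError` when the fill value is not one of the declared categories; casting to plain strings first sidesteps that. A quick demonstration of the failure mode (reflecting current pandas behavior, to the best of my knowledge):

```python
import pandas as pd

# A categorical column with a missing value
s = pd.Series(pd.Categorical(["high", "low", None]))

# Filling with a value outside the declared categories raises TypeError
# ("Cannot setitem on a Categorical with a new category")
try:
    s.fillna("N/A")
except TypeError as e:
    print(f"TypeError: {e}")

# The workaround from output_fn: cast to str first, after which fillna is safe
# (the cast itself renders the missing value as the string "nan")
print(s.astype(str).fillna("N/A").tolist())  # ['high', 'low', 'nan']
```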
```diff
--- a/workbench/model_scripts/custom_models/uq_models/generated_model_script.py
+++ b/workbench/model_scripts/custom_models/uq_models/generated_model_script.py
@@ -350,44 +433,62 @@ def predict_fn(df, models) -> pd.DataFrame:
     # Match features in a case-insensitive manner
     matched_df = match_features_case_insensitive(df, model_features)

-    #
-
-
-
-
-
-
-    dist_params = y_dists.params
-
-    # Extract mean and std from distribution parameters
-    df["prediction_uq"] = dist_params['loc']  # mean
-    df["prediction_std"] = dist_params['scale']  # standard deviation
-
-    # Add 95% prediction intervals using ppf (percent point function)
-    # Note: Our hybrid model uses XGB point prediction and NGBoost UQ
-    # so we need to adjust the bounds to include the point prediction
-    df["q_025"] = np.minimum(y_dists.ppf(0.025), df["prediction"])
-    df["q_975"] = np.maximum(y_dists.ppf(0.975), df["prediction"])
-
-    # Add 90% prediction intervals
-    df["q_05"] = y_dists.ppf(0.05)  # 5th percentile
-    df["q_95"] = y_dists.ppf(0.95)  # 95th percentile
-
-    # Add 80% prediction intervals
-    df["q_10"] = y_dists.ppf(0.10)  # 10th percentile
-    df["q_90"] = y_dists.ppf(0.90)  # 90th percentile
+    # Apply categorical mappings if they exist
+    if models.get("category_mappings"):
+        matched_df, _ = convert_categorical_types(
+            matched_df,
+            model_features,
+            models["category_mappings"]
+        )

-    #
-
-
+    # Get features for prediction
+    X = matched_df[model_features]
+
+    # Get XGBoost point predictions
+    df["prediction"] = models["xgb_model"].predict(X)
+
+    # Get predictions from each MAPIE model for conformalized intervals
+    for conf_level in models["confidence_levels"]:
+        model_name = f"mapie_{conf_level:.2f}"
+        model = models["mapie_models"][model_name]
+
+        # Get conformalized predictions
+        y_pred, y_pis = model.predict_interval(X)
+
+        # Map confidence levels to quantile names
+        if conf_level == 0.50:  # 50% CI
+            df["q_25"] = y_pis[:, 0, 0]
+            df["q_75"] = y_pis[:, 1, 0]
+        elif conf_level == 0.80:  # 80% CI
+            df["q_10"] = y_pis[:, 0, 0]
+            df["q_90"] = y_pis[:, 1, 0]
+        elif conf_level == 0.90:  # 90% CI
+            df["q_05"] = y_pis[:, 0, 0]
+            df["q_95"] = y_pis[:, 1, 0]
+        elif conf_level == 0.95:  # 95% CI
+            df["q_025"] = y_pis[:, 0, 0]
+            df["q_975"] = y_pis[:, 1, 0]
+
+    # Add median (q_50) from XGBoost prediction
+    df["q_50"] = df["prediction"]
+
+    # Calculate uncertainty metrics based on 95% interval
+    interval_width = df["q_975"] - df["q_025"]
+    df["prediction_std"] = interval_width / 3.92

     # Reorder the quantile columns for easier reading
     quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
     other_cols = [col for col in df.columns if col not in quantile_cols]
     df = df[other_cols + quantile_cols]

-    #
-
+    # Uncertainty score
+    df["uncertainty_score"] = interval_width / (np.abs(df["prediction"]) + 1e-6)
+
+    # Confidence bands
+    df["confidence_band"] = pd.cut(
+        df["uncertainty_score"],
+        bins=[0, 0.5, 1.0, 2.0, np.inf],
+        labels=["high", "medium", "low", "very_low"]
+    )

-    # Return the modified DataFrame
     return df
```
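Two of the derived columns deserve a note. `prediction_std` divides the 95% interval width by 3.92 because, under a normality assumption, the central 95% interval of a Gaussian spans about 2 × 1.96 standard deviations:

```python
from scipy.stats import norm

# Width of the central 95% interval of a standard normal, in standard deviations
width = norm.ppf(0.975) - norm.ppf(0.025)
print(round(width, 2))  # 3.92, hence prediction_std = (q_975 - q_025) / 3.92
```

`uncertainty_score` is that same width normalized by the magnitude of the prediction (the `1e-6` term guards against division by near-zero predictions), and `confidence_band` buckets the score so that narrow intervals relative to the prediction land in the "high" confidence band.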
```diff
--- a/workbench/model_scripts/custom_models/uq_models/mapie.template
+++ b/workbench/model_scripts/custom_models/uq_models/mapie.template
@@ -227,15 +227,7 @@ if __name__ == "__main__":

     # Train XGBoost for point predictions
     print("\nTraining XGBoost for point predictions...")
-    xgb_model = XGBRegressor(
-        n_estimators=1000,
-        max_depth=6,
-        learning_rate=0.01,
-        subsample=0.8,
-        colsample_bytree=0.8,
-        random_state=42,
-        verbosity=0
-    )
+    xgb_model = XGBRegressor(enable_categorical=True)
     xgb_model.fit(X_train, y_train)

     # Evaluate XGBoost performance
```
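Both the template and the generated script swap the hand-tuned XGBoost hyperparameters for defaults plus `enable_categorical=True`. That flag enables native categorical splits, but only when the input DataFrame actually carries pandas `category` dtypes, which is what `convert_categorical_types` prepares. A minimal sketch, assuming a recent XGBoost where the default `hist` tree method supports categoricals:

```python
import pandas as pd
from xgboost import XGBRegressor

# One numeric and one pandas-categorical feature
df = pd.DataFrame({
    "mw": [180.2, 250.3, 310.1, 150.7],
    "series": pd.Categorical(["a", "b", "a", "c"]),
})
y = [1.2, 3.4, 2.8, 0.9]

# Native categorical splits: no one-hot encoding needed, but the dtype
# must be "category" or XGBoost will reject the object column
model = XGBRegressor(enable_categorical=True)
model.fit(df, y)
print(model.predict(df.head(2)))
```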
```diff
--- a/workbench-0.8.174.dist-info/METADATA
+++ b/workbench-0.8.176.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: workbench
-Version: 0.8.174
+Version: 0.8.176
 Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
 Author-email: SuperCowPowers LLC <support@supercowpowers.com>
 License-Expression: MIT
```
```diff
--- a/workbench-0.8.174.dist-info/RECORD
+++ b/workbench-0.8.176.dist-info/RECORD
@@ -48,10 +48,10 @@ workbench/cached/cached_model.py,sha256=iMc_fySUE5qau3feduVXMNb24JY0sBjt1g6WeLLc
 workbench/cached/cached_pipeline.py,sha256=QOVnEKu5RbIdlNpJUi-0Ebh0_-C68RigSPwKh4dvZTM,1948
 workbench/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workbench/core/artifacts/__init__.py,sha256=ps7rA_rbWnDbvWbg4kvu--IKMY8WmbPRyv4Si0xub1Q,965
-workbench/core/artifacts/artifact.py,sha256=
+workbench/core/artifacts/artifact.py,sha256=WFGC1F61d7uFSRB7UTWYOF8O_wk8F9rn__THJL2veLM,17752
 workbench/core/artifacts/athena_source.py,sha256=RNmCe7s6uH4gVHpcdJcL84aSbF5Q1ahJBLLGwHYRXEU,26081
 workbench/core/artifacts/cached_artifact_mixin.py,sha256=ngqFLZ4cQx_TFouXZgXZQsv_7W6XCvxVGXXSfzzaft8,3775
-workbench/core/artifacts/data_capture_core.py,sha256=
+workbench/core/artifacts/data_capture_core.py,sha256=q8f79rRTYiZ7T4IQRWXl8ZvPpcvZyNxYERwvo8o0OQc,14858
 workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
 workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
 workbench/core/artifacts/endpoint_core.py,sha256=lwgiz0jttW8C4YqcKaA8nf231WI3kol-nLnKcAbFJko,49049
@@ -60,7 +60,7 @@ workbench/core/artifacts/model_core.py,sha256=6d5dV4DGUBgD9E_Gpk0F5x7OEc4oiDKokv
 workbench/core/artifacts/monitor_core.py,sha256=M307yz7tEzOEHgv-LmtVy9jKjSbM98fHW3ckmNYrwlU,27897
 workbench/core/cloud_platform/cloud_meta.py,sha256=-g4-LTC3D0PXb3VfaXdLR1ERijKuHdffeMK_zhD-koQ,8809
 workbench/core/cloud_platform/aws/README.md,sha256=QT5IQXoUHbIA0qQ2wO6_2P2lYjYQFVYuezc22mWY4i8,97
-workbench/core/cloud_platform/aws/aws_account_clamp.py,sha256=
+workbench/core/cloud_platform/aws/aws_account_clamp.py,sha256=V5iVsoGvSRilARtTdExnt27QptzAcJaW0s3nm2B8-ow,8286
 workbench/core/cloud_platform/aws/aws_df_store.py,sha256=utRIlTCPwFneHHZ8_Z3Hw3rOJSeryiFA4wBtucxULRQ,15055
 workbench/core/cloud_platform/aws/aws_graph_store.py,sha256=ytYxQTplUmeWbsPmxyZbf6mO9qyTl60ewlJG8MyfyEY,9414
 workbench/core/cloud_platform/aws/aws_meta.py,sha256=eY9Pn6pl2yAyseACFb2nitR-0vLwG4i8CSEXe8Iaswc,34778
@@ -140,8 +140,8 @@ workbench/model_scripts/custom_models/uq_models/Readme.md,sha256=UVpL-lvtTrLqwBe
 workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=U4LIlpp8Rbu3apyzPR7-55lvlutpTsCro_PUvQ5pklY,6457
 workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=0IJnSBACQ556ldEiPqR7yPCOOLJs1hQhHmPBvB2d9tY,13491
 workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=QbDUfkiPCwJ-c-4Twgu4utZuYZaAyeW_3T1IP-_tutw,6683
-workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=
-workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=
+workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=AcLf-vXOmn_vpTeiKpNKCW_dRhR8Co1sMFC84EPT4IE,22392
+workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=VkFM0eZM2d-hzDbngk9s08DD5vn2nQRD4coCUfj36Fk,18181
 workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=eawh0Fp3DhbdCXzWN6KloczT5ZS_ou4ayW65yUTTE4o,14109
 workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=9-O6P-SW50ul5Wl6es2DMWXSbrwOg7HWsdc8Qdln0MM,8278
 workbench/model_scripts/custom_models/uq_models/proximity.py,sha256=zqmNlX70LnWXr5fdtFFQppSNTLjlOciQVrjGr-g9jRE,13716
@@ -288,9 +288,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
 workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
 workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
 workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
+workbench-0.8.176.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
+workbench-0.8.176.dist-info/METADATA,sha256=4uDF0MKfrLJrqmAiwYsUlCOA8o5BlxNTLweZLFwtYS0,9210
+workbench-0.8.176.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+workbench-0.8.176.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
+workbench-0.8.176.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
+workbench-0.8.176.dist-info/RECORD,,
```
The remaining files (WHEEL, entry_points.txt, licenses/LICENSE, and top_level.txt) are unchanged between the two versions.