workbench 0.8.181__py3-none-any.whl → 0.8.182__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

@@ -420,8 +420,12 @@ class EndpointCore(Artifact):

  # Capture the inference results and metrics
  if capture_name is not None:
+
+ # If we don't have an id_column, we'll pull it from the model's FeatureSet
+ if id_column is None:
+ fs = FeatureSetCore(model.get_input())
+ id_column = fs.id_column
  description = capture_name.replace("_", " ").title()
- features = model.features()
  self._capture_inference_results(
  capture_name, prediction_df, target_column, model_type, metrics, description, features, id_column
  )
@@ -764,7 +768,7 @@ class EndpointCore(Artifact):

  # Add the ID column
  if id_column and id_column in pred_results_df.columns:
- output_columns.append(id_column)
+ output_columns.insert(0, id_column)

  # Write the predictions to our S3 Model Inference Folder
  self.log.info(f"Writing predictions to {inference_capture_path}/inference_predictions.csv")
@@ -18,8 +18,8 @@ from typing import List, Tuple

  # Template Placeholders
  TEMPLATE_PARAMS = {
- "target": "solubility",
- "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
+ "target": "udm_asy_res_free_percent",
+ "features": ['vsa_estate6', 'naromatom', 'mollogp', 'fr_nh2', 'mp', 'c2sp2', 'xch_3d', 'axp_6d', 'bcut2d_mrhi', 'fr_benzene', 'mz', 'slogp_vsa6', 'fr_halogen', 'bcut2d_mwhi', 'vsa_estate4', 'slogp_vsa3', 'estate_vsa5', 'minestateindex', 'axp_3dv', 'estate_vsa3', 'vsa_estate9', 'molwt', 'hallkieralpha', 'fpdensitymorgan1', 'peoe_vsa13', 'xpc_5d', 'bcut2d_chghi', 'peoe_vsa8', 'axp_0dv', 'axp_2d', 'chi2v', 'bcut2d_logphi', 'axp_5d', 'peoe_vsa2', 'estate_vsa6', 'qed', 'numrotatablebonds', 'xc_3dv', 'peoe_vsa3', 'balabanj', 'slogp_vsa5', 'mv', 'vsa_estate2', 'bcut2d_mwlow', 'xch_7d', 'chi3n', 'vsa_estate8', 'estate_vsa4', 'xp_7dv', 'fr_nh1', 'vsa_estate3', 'fr_ketone_topliss', 'minpartialcharge', 'phi', 'peoe_vsa10', 'vsa_estate7', 'estate_vsa7', 'tpsa', 'kappa3', 'kappa2', 'bcut2d_logplow', 'xch_6d', 'maxpartialcharge', 'vsa_estate1', 'peoe_vsa9', 'axp_1d', 'fr_ar_n', 'chi2n', 'vsa_estate5', 'xp_4dv', 'slogp_vsa10', 'num_stereobonds', 'peoe_vsa11', 'bcut2d_chglo', 'chi1v', 'peoe_vsa7', 'bertzct', 'axp_2dv', 'estate_vsa2', 'smr_vsa9', 'peoe_vsa6', 'num_s_centers', 'num_r_centers', 'xch_7dv', 'xc_5d', 'axp_4dv', 'xc_5dv', 'mi', 'xc_3d', 'fpdensitymorgan2', 'xp_0dv', 'nhohcount', 'numatomstereocenters', 'mse', 'smr_vsa3', 'peoe_vsa12', 'nocount', 'fpdensitymorgan3', 'minabsestateindex', 'bcut2d_mrlow', 'axp_5dv', 'sz', 'vsa_estate10', 'axp_3d', 'xch_6dv', 'xch_4d', 'xc_6d', 'estate_vsa8', 'mpe', 'smr_vsa7', 'numhdonors', 'smr_vsa1', 'xp_5d', 'fr_para_hydroxylation', 'chi3v', 'xpc_6dv', 'nbase', 'heavyatommolwt', 'avgipc', 'maxestateindex', 'smr_vsa6', 'fr_bicyclic', 'xc_4dv', 'xp_7d', 'smr_vsa5', 'xpc_4d', 'smr_vsa4', 'peoe_vsa4', 'numheteroatoms', 'fr_nhpyrrole', 'axp_4d', 'smr_vsa10', 'xp_6d', 'sps', 'mare', 'slogp_vsa2', 'axp_0d', 'slogp_vsa4', 'fr_al_oh', 'numheterocycles', 'labuteasa', 'xp_3d', 'chi4n', 'fractioncsp3', 'maxabspartialcharge', 'fr_al_oh_notert', 'peoe_vsa1', 'axp_7dv', 'slogp_vsa11', 'peoe_vsa5', 'xpc_5dv', 'xpc_6d', 'xp_2d', 'xp_3dv', 'fr_ndealkylation1', 'axp_7d', 'estate_vsa9', 'molmr', 'num_stereocenters', 'si', 'estate_vsa1', 'xc_6dv', 'chi0v', 'fr_oxazole', 'axp_6dv', 'xp_6dv', 'xp_4d', 'numaliphaticheterocycles', 'fr_imine', 'fr_imidazole', 'xp_5dv', 'fr_piperdine', 'slogp_vsa7', 'chi1', 'c1sp2', 'numaromaticheterocycles', 'xpc_4dv', 'c3sp2', 'fr_aniline', 'fr_piperzine', 'axp_1dv', 'xch_4dv', 'chi4v', 'chi1n', 'minabspartialcharge', 'slogp_vsa1', 'fr_nh0', 'chi0n', 'c2sp3', 'xc_4d', 'xch_5dv', 'peoe_vsa14', 'xch_5d', 'numsaturatedrings', 'fr_pyridine', 'kappa1', 'slogp_vsa8', 'xp_2dv', 'fr_ar_coo', 'numvalenceelectrons'],
  "compressed_features": [],
  "train_all_data": True,
  "hyperparameters": {},
@@ -117,8 +117,8 @@ def generate_model_script(template_params: dict) -> str:
  template_name = "xgb_model.template"
  model_script_dir = "xgb_model"
  elif template_params["model_type"] == ModelType.UQ_REGRESSOR:
- template_name = "quant_regression.template"
- model_script_dir = "quant_regression"
+ template_name = "mapie.template"
+ model_script_dir = "uq_models"
  elif template_params["model_type"] == ModelType.ENSEMBLE_REGRESSOR:
  template_name = "ensemble_xgb.template"
  model_script_dir = "ensemble_xgb"
@@ -0,0 +1,492 @@
+ # Model: XGBoost for point predictions + LightGBM with MAPIE for conformalized intervals
+ from mapie.regression import ConformalizedQuantileRegressor
+ from lightgbm import LGBMRegressor
+ from xgboost import XGBRegressor
+ from sklearn.model_selection import train_test_split
+
+ # Model Performance Scores
+ from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
+
+ from io import StringIO
+ import json
+ import argparse
+ import joblib
+ import os
+ import numpy as np
+ import pandas as pd
+ from typing import List, Tuple
+
+ # Template Placeholders
+ TEMPLATE_PARAMS = {
+ "target": "udm_asy_res_free_percent",
+ "features": ['vsa_estate6', 'naromatom', 'mollogp', 'fr_nh2', 'mp', 'c2sp2', 'xch_3d', 'axp_6d', 'bcut2d_mrhi', 'fr_benzene', 'mz', 'slogp_vsa6', 'fr_halogen', 'bcut2d_mwhi', 'vsa_estate4', 'slogp_vsa3', 'estate_vsa5', 'minestateindex', 'axp_3dv', 'estate_vsa3', 'vsa_estate9', 'molwt', 'hallkieralpha', 'fpdensitymorgan1', 'peoe_vsa13', 'xpc_5d', 'bcut2d_chghi', 'peoe_vsa8', 'axp_0dv', 'axp_2d', 'chi2v', 'bcut2d_logphi', 'axp_5d', 'peoe_vsa2', 'estate_vsa6', 'qed', 'numrotatablebonds', 'xc_3dv', 'peoe_vsa3', 'balabanj', 'slogp_vsa5', 'mv', 'vsa_estate2', 'bcut2d_mwlow', 'xch_7d', 'chi3n', 'vsa_estate8', 'estate_vsa4', 'xp_7dv', 'fr_nh1', 'vsa_estate3', 'fr_ketone_topliss', 'minpartialcharge', 'phi', 'peoe_vsa10', 'vsa_estate7', 'estate_vsa7', 'tpsa', 'kappa3', 'kappa2', 'bcut2d_logplow', 'xch_6d', 'maxpartialcharge', 'vsa_estate1', 'peoe_vsa9', 'axp_1d', 'fr_ar_n', 'chi2n', 'vsa_estate5', 'xp_4dv', 'slogp_vsa10', 'num_stereobonds', 'peoe_vsa11', 'bcut2d_chglo', 'chi1v', 'peoe_vsa7', 'bertzct', 'axp_2dv', 'estate_vsa2', 'smr_vsa9', 'peoe_vsa6', 'num_s_centers', 'num_r_centers', 'xch_7dv', 'xc_5d', 'axp_4dv', 'xc_5dv', 'mi', 'xc_3d', 'fpdensitymorgan2', 'xp_0dv', 'nhohcount', 'numatomstereocenters', 'mse', 'smr_vsa3', 'peoe_vsa12', 'nocount', 'fpdensitymorgan3', 'minabsestateindex', 'bcut2d_mrlow', 'axp_5dv', 'sz', 'vsa_estate10', 'axp_3d', 'xch_6dv', 'xch_4d', 'xc_6d', 'estate_vsa8', 'mpe', 'smr_vsa7', 'numhdonors', 'smr_vsa1', 'xp_5d', 'fr_para_hydroxylation', 'chi3v', 'xpc_6dv', 'nbase', 'heavyatommolwt', 'avgipc', 'maxestateindex', 'smr_vsa6', 'fr_bicyclic', 'xc_4dv', 'xp_7d', 'smr_vsa5', 'xpc_4d', 'smr_vsa4', 'peoe_vsa4', 'numheteroatoms', 'fr_nhpyrrole', 'axp_4d', 'smr_vsa10', 'xp_6d', 'sps', 'mare', 'slogp_vsa2', 'axp_0d', 'slogp_vsa4', 'fr_al_oh', 'numheterocycles', 'labuteasa', 'xp_3d', 'chi4n', 'fractioncsp3', 'maxabspartialcharge', 'fr_al_oh_notert', 'peoe_vsa1', 'axp_7dv', 'slogp_vsa11', 'peoe_vsa5', 'xpc_5dv', 'xpc_6d', 'xp_2d', 'xp_3dv', 'fr_ndealkylation1', 'axp_7d', 'estate_vsa9', 'molmr', 'num_stereocenters', 'si', 'estate_vsa1', 'xc_6dv', 'chi0v', 'fr_oxazole', 'axp_6dv', 'xp_6dv', 'xp_4d', 'numaliphaticheterocycles', 'fr_imine', 'fr_imidazole', 'xp_5dv', 'fr_piperdine', 'slogp_vsa7', 'chi1', 'c1sp2', 'numaromaticheterocycles', 'xpc_4dv', 'c3sp2', 'fr_aniline', 'fr_piperzine', 'axp_1dv', 'xch_4dv', 'chi4v', 'chi1n', 'minabspartialcharge', 'slogp_vsa1', 'fr_nh0', 'chi0n', 'c2sp3', 'xc_4d', 'xch_5dv', 'peoe_vsa14', 'xch_5d', 'numsaturatedrings', 'fr_pyridine', 'kappa1', 'slogp_vsa8', 'xp_2dv', 'fr_ar_coo', 'numvalenceelectrons'],
+ "compressed_features": [],
+ "train_all_data": True,
+ "hyperparameters": {},
+ }
+
+
+ # Function to check if dataframe is empty
+ def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
+ """
+ Check if the provided dataframe is empty and raise an exception if it is.
+
+ Args:
+ df (pd.DataFrame): DataFrame to check
+ df_name (str): Name of the DataFrame
+ """
+ if df.empty:
+ msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
+ print(msg)
+ raise ValueError(msg)
+
+
+ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
+ """
+ Matches and renames DataFrame columns to match model feature names (case-insensitive).
+ Prioritizes exact matches, then case-insensitive matches.
+
+ Raises ValueError if any model features cannot be matched.
+ """
+ df_columns_lower = {col.lower(): col for col in df.columns}
+ rename_dict = {}
+ missing = []
+ for feature in model_features:
+ if feature in df.columns:
+ continue # Exact match
+ elif feature.lower() in df_columns_lower:
+ rename_dict[df_columns_lower[feature.lower()]] = feature
+ else:
+ missing.append(feature)
+
+ if missing:
+ raise ValueError(f"Features not found: {missing}")
+
+ # Rename the DataFrame columns to match the model features
+ return df.rename(columns=rename_dict)
+
+
+ def convert_categorical_types(df: pd.DataFrame, features: list, category_mappings={}) -> tuple:
+ """
+ Converts appropriate columns to categorical type with consistent mappings.
+
+ Args:
+ df (pd.DataFrame): The DataFrame to process.
+ features (list): List of feature names to consider for conversion.
+ category_mappings (dict, optional): Existing category mappings. If empty dict, we're in
+ training mode. If populated, we're in inference mode.
+
+ Returns:
+ tuple: (processed DataFrame, category mappings dictionary)
+ """
+ # Training mode
+ if category_mappings == {}:
+ for col in df.select_dtypes(include=["object", "string"]):
+ if col in features and df[col].nunique() < 20:
+ print(f"Training mode: Converting {col} to category")
+ df[col] = df[col].astype("category")
+ category_mappings[col] = df[col].cat.categories.tolist() # Store category mappings
+
+ # Inference mode
+ else:
+ for col, categories in category_mappings.items():
+ if col in df.columns:
+ print(f"Inference mode: Applying categorical mapping for {col}")
+ df[col] = pd.Categorical(df[col], categories=categories) # Apply consistent categorical mapping
+
+ return df, category_mappings
+
+
+ def decompress_features(
+ df: pd.DataFrame, features: List[str], compressed_features: List[str]
+ ) -> Tuple[pd.DataFrame, List[str]]:
+ """Prepare features for the model by decompressing bitstring features
+
+ Args:
+ df (pd.DataFrame): The features DataFrame
+ features (List[str]): Full list of feature names
+ compressed_features (List[str]): List of feature names to decompress (bitstrings)
+
+ Returns:
+ pd.DataFrame: DataFrame with the decompressed features
+ List[str]: Updated list of feature names after decompression
+
+ Raises:
+ ValueError: If any missing values are found in the specified features
+ """
+
+ # Check for any missing values in the required features
+ missing_counts = df[features].isna().sum()
+ if missing_counts.any():
+ missing_features = missing_counts[missing_counts > 0]
+ print(
+ f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
+ "WARNING: You might want to remove/replace all NaN values before processing."
+ )
+
+ # Decompress the specified compressed features
+ decompressed_features = features.copy()
+ for feature in compressed_features:
+ if (feature not in df.columns) or (feature not in features):
+ print(f"Feature '{feature}' not in the features list, skipping decompression.")
+ continue
+
+ # Remove the feature from the list of features to avoid duplication
+ decompressed_features.remove(feature)
+
+ # Handle all compressed features as bitstrings
+ bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
+ prefix = feature[:3]
+
+ # Create all new columns at once - avoids fragmentation
+ new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
+ new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
+
+ # Add to features list
+ decompressed_features.extend(new_col_names)
+
+ # Drop original column and concatenate new ones
+ df = df.drop(columns=[feature])
+ df = pd.concat([df, new_df], axis=1)
+
+ return df, decompressed_features
+
+
+ if __name__ == "__main__":
+ # Template Parameters
+ target = TEMPLATE_PARAMS["target"]
+ features = TEMPLATE_PARAMS["features"]
+ orig_features = features.copy()
+ compressed_features = TEMPLATE_PARAMS["compressed_features"]
+ train_all_data = TEMPLATE_PARAMS["train_all_data"]
+ hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
+ validation_split = 0.2
+
+ # Script arguments for input/output directories
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+ parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+ parser.add_argument(
+ "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
+ )
+ args = parser.parse_args()
+
+ # Read the training data into DataFrames
+ training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
+ print(f"Training Files: {training_files}")
+
+ # Combine files and read them all into a single pandas dataframe
+ all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
+
+ # Check if the dataframe is empty
+ check_dataframe(all_df, "training_df")
+
+ # Features/Target output
+ print(f"Target: {target}")
+ print(f"Features: {str(features)}")
+
+ # Convert any features that might be categorical to 'category' type
+ all_df, category_mappings = convert_categorical_types(all_df, features)
+
+ # If we have compressed features, decompress them
+ if compressed_features:
+ print(f"Decompressing features {compressed_features}...")
+ all_df, features = decompress_features(all_df, features, compressed_features)
+
+ # Do we want to train on all the data?
+ if train_all_data:
+ print("Training on ALL of the data")
+ df_train = all_df.copy()
+ df_val = all_df.copy()
+
+ # Does the dataframe have a training column?
+ elif "training" in all_df.columns:
+ print("Found training column, splitting data based on training column")
+ df_train = all_df[all_df["training"]]
+ df_val = all_df[~all_df["training"]]
+ else:
+ # Just do a random training Split
+ print("WARNING: No training column found, splitting data with random state=42")
+ df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
+ print(f"FIT/TRAIN: {df_train.shape}")
+ print(f"VALIDATION: {df_val.shape}")
+
+ # Prepare features and targets for training
+ X_train = df_train[features]
+ X_validate = df_val[features]
+ y_train = df_train[target]
+ y_validate = df_val[target]
+
+ # Train XGBoost for point predictions
+ print("\nTraining XGBoost for point predictions...")
+ print(f" Hyperparameters: {hyperparameters}")
+ xgb_model = XGBRegressor(enable_categorical=True, **hyperparameters)
+ xgb_model.fit(X_train, y_train)
+
+ # Evaluate XGBoost performance
+ y_pred_xgb = xgb_model.predict(X_validate)
+ xgb_rmse = root_mean_squared_error(y_validate, y_pred_xgb)
+ xgb_mae = mean_absolute_error(y_validate, y_pred_xgb)
+ xgb_r2 = r2_score(y_validate, y_pred_xgb)
+
+ print(f"\nXGBoost Point Prediction Performance:")
+ print(f"RMSE: {xgb_rmse:.3f}")
+ print(f"MAE: {xgb_mae:.3f}")
+ print(f"R2: {xgb_r2:.3f}")
+
+ # Define confidence levels we want to model
+ confidence_levels = [0.50, 0.68, 0.80, 0.90, 0.95] # 50%, 68%, 80%, 90%, 95% confidence intervals
+
+ # Store MAPIE models for each confidence level
+ mapie_models = {}
+
+ # Train models for each confidence level
+ for confidence_level in confidence_levels:
+ alpha = 1 - confidence_level
+ lower_q = alpha / 2
+ upper_q = 1 - alpha / 2
+
+ print(f"\nTraining quantile models for {confidence_level * 100:.0f}% confidence interval...")
+ print(f" Quantiles: {lower_q:.3f}, {upper_q:.3f}, 0.500")
+
+ # Train three models for this confidence level
+ quantile_estimators = []
+ for q in [lower_q, upper_q, 0.5]:
+ print(f" Training model for quantile {q:.3f}...")
+ est = LGBMRegressor(
+ objective="quantile",
+ alpha=q,
+ n_estimators=1000,
+ max_depth=6,
+ learning_rate=0.01,
+ num_leaves=31,
+ min_child_samples=20,
+ subsample=0.8,
+ colsample_bytree=0.8,
+ random_state=42,
+ verbose=-1,
+ force_col_wise=True,
+ )
+ est.fit(X_train, y_train)
+ quantile_estimators.append(est)
+
+ # Create MAPIE CQR model for this confidence level
+ print(f" Setting up MAPIE CQR for {confidence_level * 100:.0f}% confidence...")
+ mapie_model = ConformalizedQuantileRegressor(
+ quantile_estimators, confidence_level=confidence_level, prefit=True
+ )
+
+ # Conformalize the model
+ print(f" Conformalizing with validation data...")
+ mapie_model.conformalize(X_validate, y_validate)
+
+ # Store the model
+ mapie_models[f"mapie_{confidence_level:.2f}"] = mapie_model
+
+ # Validate coverage for this confidence level
+ y_pred, y_pis = mapie_model.predict_interval(X_validate)
+ coverage = np.mean((y_validate >= y_pis[:, 0, 0]) & (y_validate <= y_pis[:, 1, 0]))
+ print(f" Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
+
+ print(f"\nOverall Model Performance Summary:")
+ print(f"XGBoost RMSE: {xgb_rmse:.3f}")
+ print(f"XGBoost MAE: {xgb_mae:.3f}")
+ print(f"XGBoost R2: {xgb_r2:.3f}")
+ print(f"NumRows: {len(df_val)}")
+
+ # Analyze interval widths across confidence levels
+ print(f"\nInterval Width Analysis:")
+ for conf_level in confidence_levels:
+ model = mapie_models[f"mapie_{conf_level:.2f}"]
+ _, y_pis = model.predict_interval(X_validate)
+ widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
+ print(f" {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
+
+ # Save the trained XGBoost model
+ joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
+
+ # Save all MAPIE models
+ for model_name, model in mapie_models.items():
+ joblib.dump(model, os.path.join(args.model_dir, f"{model_name}.joblib"))
+
+ # Save the feature list
+ with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
+ json.dump(features, fp)
+
+ # Save category mappings if any
+ if category_mappings:
+ with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
+ json.dump(category_mappings, fp)
+
+ # Save model configuration
+ model_config = {
+ "model_type": "XGBoost_MAPIE_CQR_LightGBM",
+ "confidence_levels": confidence_levels,
+ "n_features": len(features),
+ "target": target,
+ "validation_metrics": {
+ "xgb_rmse": float(xgb_rmse),
+ "xgb_mae": float(xgb_mae),
+ "xgb_r2": float(xgb_r2),
+ "n_validation": len(df_val),
+ },
+ }
+ with open(os.path.join(args.model_dir, "model_config.json"), "w") as fp:
+ json.dump(model_config, fp, indent=2)
+
+ print(f"\nModel training complete!")
+ print(f"Saved 1 XGBoost model and {len(mapie_models)} MAPIE models to {args.model_dir}")
+
+
+ #
+ # Inference Section
+ #
+ def model_fn(model_dir) -> dict:
+ """Load XGBoost and all MAPIE models from the specified directory."""
+
+ # Load model configuration to know which models to load
+ with open(os.path.join(model_dir, "model_config.json")) as fp:
+ config = json.load(fp)
+
+ # Load XGBoost regressor
+ xgb_path = os.path.join(model_dir, "xgb_model.joblib")
+ xgb_model = joblib.load(xgb_path)
+
+ # Load all MAPIE models
+ mapie_models = {}
+ for conf_level in config["confidence_levels"]:
+ model_name = f"mapie_{conf_level:.2f}"
+ mapie_models[model_name] = joblib.load(os.path.join(model_dir, f"{model_name}.joblib"))
+
+ # Load category mappings if they exist
+ category_mappings = {}
+ category_path = os.path.join(model_dir, "category_mappings.json")
+ if os.path.exists(category_path):
+ with open(category_path) as fp:
+ category_mappings = json.load(fp)
+
+ return {
+ "xgb_model": xgb_model,
+ "mapie_models": mapie_models,
+ "confidence_levels": config["confidence_levels"],
+ "category_mappings": category_mappings,
+ }
+
+
+ def input_fn(input_data, content_type):
+ """Parse input data and return a DataFrame."""
+ if not input_data:
+ raise ValueError("Empty input data is not supported!")
+
+ # Decode bytes to string if necessary
+ if isinstance(input_data, bytes):
+ input_data = input_data.decode("utf-8")
+
+ if "text/csv" in content_type:
+ return pd.read_csv(StringIO(input_data))
+ elif "application/json" in content_type:
+ return pd.DataFrame(json.loads(input_data))
+ else:
+ raise ValueError(f"{content_type} not supported!")
+
+
+ def output_fn(output_df, accept_type):
+ """Supports both CSV and JSON output formats."""
+ if "text/csv" in accept_type:
+ # Convert categorical columns to string to avoid fillna issues
+ for col in output_df.select_dtypes(include=["category"]).columns:
+ output_df[col] = output_df[col].astype(str)
+ csv_output = output_df.fillna("N/A").to_csv(index=False)
+ return csv_output, "text/csv"
+ elif "application/json" in accept_type:
+ return output_df.to_json(orient="records"), "application/json"
+ else:
+ raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
+
+
+ def predict_fn(df, models) -> pd.DataFrame:
+ """Make predictions using XGBoost for point estimates and MAPIE for conformalized intervals
+
+ Args:
+ df (pd.DataFrame): The input DataFrame
+ models (dict): Dictionary containing XGBoost and MAPIE models
+
+ Returns:
+ pd.DataFrame: DataFrame with XGBoost predictions and conformalized intervals
+ """
+
+ # Flag for outlier stretch adjustment for the prediction intervals
+ # if the predicted values are outside the intervals
+ outlier_stretch = False
+
+ # Grab our feature columns (from training)
+ model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+ with open(os.path.join(model_dir, "feature_columns.json")) as fp:
+ model_features = json.load(fp)
+
+ # Match features in a case-insensitive manner
+ matched_df = match_features_case_insensitive(df, model_features)
+
+ # Apply categorical mappings if they exist
+ if models.get("category_mappings"):
+ matched_df, _ = convert_categorical_types(matched_df, model_features, models["category_mappings"])
+
+ # Get features for prediction
+ X = matched_df[model_features]
+
+ # Get XGBoost point predictions
+ df["prediction"] = models["xgb_model"].predict(X)
+
+ # Get predictions from each MAPIE model for conformalized intervals
+ for conf_level in models["confidence_levels"]:
+ model_name = f"mapie_{conf_level:.2f}"
+ model = models["mapie_models"][model_name]
+
+ # Get conformalized predictions
+ y_pred, y_pis = model.predict_interval(X)
+
+ # Map confidence levels to quantile names
+ if conf_level == 0.50: # 50% CI
+ df["q_25"] = y_pis[:, 0, 0]
+ df["q_75"] = y_pis[:, 1, 0]
+ elif conf_level == 0.68: # 68% CI
+ df["q_16"] = y_pis[:, 0, 0]
+ df["q_84"] = y_pis[:, 1, 0]
+ elif conf_level == 0.80: # 80% CI
+ df["q_10"] = y_pis[:, 0, 0]
+ df["q_90"] = y_pis[:, 1, 0]
+ elif conf_level == 0.90: # 90% CI
+ df["q_05"] = y_pis[:, 0, 0]
+ df["q_95"] = y_pis[:, 1, 0]
+ elif conf_level == 0.95: # 95% CI
+ df["q_025"] = y_pis[:, 0, 0]
+ df["q_975"] = y_pis[:, 1, 0]
+
+ # Add median (q_50) from XGBoost prediction
+ df["q_50"] = df["prediction"]
+
+ # Calculate a pseudo-standard deviation from the 68% interval width
+ df["prediction_std"] = (df["q_84"] - df["q_16"]) / 2.0
+
+ # Reorder the quantile columns for easier reading
+ quantile_cols = ["q_025", "q_05", "q_10", "q_16", "q_25", "q_75", "q_84", "q_90", "q_95", "q_975"]
+ other_cols = [col for col in df.columns if col not in quantile_cols]
+ df = df[other_cols + quantile_cols]
+
+ # Adjust the outer quantiles to ensure they encompass the prediction
+ if outlier_stretch:
+ # Lower intervals adjustments
+ df["q_025"] = np.minimum(df["q_025"], df["prediction"])
+ df["q_05"] = np.minimum(df["q_05"], df["prediction"])
+ df["q_10"] = np.minimum(df["q_10"], df["prediction"])
+ df["q_16"] = np.minimum(df["q_16"], df["prediction"])
+ df["q_25"] = np.minimum(df["q_25"], df["prediction"])
+
+ # Upper intervals adjustments
+ df["q_75"] = np.maximum(df["q_75"], df["prediction"])
+ df["q_84"] = np.maximum(df["q_84"], df["prediction"])
+ df["q_90"] = np.maximum(df["q_90"], df["prediction"])
+ df["q_95"] = np.maximum(df["q_95"], df["prediction"])
+ df["q_975"] = np.maximum(df["q_975"], df["prediction"])
+
+ return df
@@ -0,0 +1 @@
+ # Note: Most libs are already in the training/inference images, ONLY specify additional libs here
@@ -560,7 +560,7 @@ class WorkbenchShell:
  from workbench.web_interface.components.plugin_unit_test import PluginUnitTest

  # Get kwargs
- theme = kwargs.get("theme", "dark")
+ theme = kwargs.get("theme", "midnight_blue")

  plugin_test = PluginUnitTest(plugin_class, theme=theme, input_data=data, **kwargs)

@@ -159,7 +159,7 @@ class ScatterPlot(PluginInterface):
  self.df = self.df.drop(columns=aws_cols, errors="ignore")

  # Set hover columns and custom data
- self.hover_columns = kwargs.get("hover_columns", self.df.columns.tolist()[:10])
+ self.hover_columns = kwargs.get("hover_columns", sorted(self.df.columns.tolist()[:15]))
  self.suppress_hover_display = kwargs.get("suppress_hover_display", False)
  self.custom_data = kwargs.get("custom_data", [])

@@ -427,7 +427,7 @@ if __name__ == "__main__":

  from workbench.api import DFStore

- df = DFStore().get("/workbench/models/aqsol-uq/auto_inference")
+ df = DFStore().get("/workbench/models/aqsol-uq-100/full_cross_fold_inference")

  # Run the Unit Test on the Plugin
  PluginUnitTest(
@@ -436,6 +436,6 @@ if __name__ == "__main__":
  theme="midnight_blue",
  x="solubility",
  y="prediction",
- color="residuals_abs",
+ color="prediction_std",
  suppress_hover_display=True,
  ).run()
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: workbench
- Version: 0.8.181
+ Version: 0.8.182
  Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
  Author-email: SuperCowPowers LLC <support@supercowpowers.com>
  License-Expression: MIT
@@ -54,7 +54,7 @@ workbench/core/artifacts/cached_artifact_mixin.py,sha256=ngqFLZ4cQx_TFouXZgXZQsv
  workbench/core/artifacts/data_capture_core.py,sha256=q8f79rRTYiZ7T4IQRWXl8ZvPpcvZyNxYERwvo8o0OQc,14858
  workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
  workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
- workbench/core/artifacts/endpoint_core.py,sha256=iOBKnlfG3xVj9-Z9MX_IxxnSs6jMNXJXLgCsnWgyUqM,51657
+ workbench/core/artifacts/endpoint_core.py,sha256=b3cNj1UnlHmQdG1C8bmD2jWpD4h-O6F-75fWSm01uGU,51850
  workbench/core/artifacts/feature_set_core.py,sha256=7b1o_PzxtwaYC-W2zxlkltiO0fYULA8CVGWwHNmqgtI,31457
  workbench/core/artifacts/model_core.py,sha256=ECDwQ0qM5qb1yGJ07U70BVdfkrW9m7p9e6YJWib3uR0,50855
  workbench/core/artifacts/monitor_core.py,sha256=M307yz7tEzOEHgv-LmtVy9jKjSbM98fHW3ckmNYrwlU,27897
@@ -122,7 +122,7 @@ workbench/core/views/training_view.py,sha256=UWW8Asxtm_kV7Z8NooitMA4xC5vTc7lSWwT
  workbench/core/views/view.py,sha256=Ujzw6zLROP9oKfKm3zJwaOyfpyjh5uM9fAu1i3kUOig,11764
  workbench/core/views/view_utils.py,sha256=y0YuPW-90nAfgAD1UW_49-j7Mvncfm7-5rV8I_97CK8,12274
  workbench/core/views/storage/mdq_view.py,sha256=qf_ep1KwaXOIfO930laEwNIiCYP7VNOqjE3VdHfopRE,5195
- workbench/model_scripts/script_generation.py,sha256=dL23XYwEsHIStc7i53DtF_47FqOrI9gq0kQAT6sNpZ8,7923
+ workbench/model_scripts/script_generation.py,sha256=dLxVRrvrrI_HQatJRAXta6UEbFFbkgITNvDJllQZyCM,7905
  workbench/model_scripts/custom_models/chem_info/Readme.md,sha256=mH1lxJ4Pb7F5nBnVXaiuxpi8zS_yjUw_LBJepVKXhlA,574
  workbench/model_scripts/custom_models/chem_info/mol_descriptors.py,sha256=c8gkHZ-8s3HJaW9zN9pnYGK7YVW8Y0xFqQ1G_ysrF2Y,18789
  workbench/model_scripts/custom_models/chem_info/mol_standardize.py,sha256=qPLCdVMSXMOWN-01O1isg2zq7eQyFAI0SNatHkRq1uw,17524
@@ -140,8 +140,7 @@ workbench/model_scripts/custom_models/uq_models/Readme.md,sha256=UVpL-lvtTrLqwBe
  workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=ca3CaAk6HVuNv1HnPgABTzRY3oDrRxomjgD4V1ZDwoc,6448
  workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=xlKLHeLQkScONnrlbAGIsrCm2wwsvcfv4Vdrw4nlc_8,13457
  workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=3nMlCi8nEbc4N-MQTzjfIcljfDQkUmWeLBfmd18m5fg,6632
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=PCCDF3DuiH13wMltuCzorVb79uLjKuX_9-ryuooQK5o,19131
- workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=8VzoP-Wp3ECVIDqXVkiTS6bwmn3cd3dDZ2WjYPzXTi8,18955
+ workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=Y89qD3gJ8wx9klXXDUQNfoLTImVFcdYLfRz-SA8mppE,21461
  workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=XTfhODRaHlI1jZGo9pSe-TqNsk2_nuSw0xMO2fKzDv8,14011
  workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=v1rviYTJGJnQRGgAyveXhOQlS-WFCTlc2vdnWq6HIXk,8241
  workbench/model_scripts/custom_models/uq_models/proximity.py,sha256=zqmNlX70LnWXr5fdtFFQppSNTLjlOciQVrjGr-g9jRE,13716
@@ -154,16 +153,17 @@ workbench/model_scripts/ensemble_xgb/requirements.txt,sha256=jWlGc7HH7vqyukTm38L
  workbench/model_scripts/pytorch_model/generated_model_script.py,sha256=Mr1IMQJE_ML899qjzhjkrP521IjvcAvqU0pk--FB7KY,22356
  workbench/model_scripts/pytorch_model/pytorch.template,sha256=_gRp6DH294FLxF21UpSTq7s9RFfrLjViKvjXQ4yDfBQ,21999
  workbench/model_scripts/pytorch_model/requirements.txt,sha256=ICS5nW0wix44EJO2tJszJSaUrSvhSfdedn6FcRInGx4,181
- workbench/model_scripts/quant_regression/quant_regression.template,sha256=2F25lZ7m_VafHvuGrC__R3uB1NzKgZu94eWJS9sWpYg,9783
- workbench/model_scripts/quant_regression/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
  workbench/model_scripts/scikit_learn/generated_model_script.py,sha256=c73ZpJBlU5k13Nx-ZDkLXu7da40CYyhwjwwmuPq6uLg,12870
  workbench/model_scripts/scikit_learn/requirements.txt,sha256=aVvwiJ3LgBUhM_PyFlb2gHXu_kpGPho3ANBzlOkfcvs,107
  workbench/model_scripts/scikit_learn/scikit_learn.template,sha256=QQvqx-eX9ZTbYmyupq6R6vIQwosmsmY_MRBPaHyfjdk,12586
+ workbench/model_scripts/uq_models/generated_model_script.py,sha256=Y89qD3gJ8wx9klXXDUQNfoLTImVFcdYLfRz-SA8mppE,21461
+ workbench/model_scripts/uq_models/mapie.template,sha256=8VzoP-Wp3ECVIDqXVkiTS6bwmn3cd3dDZ2WjYPzXTi8,18955
+ workbench/model_scripts/uq_models/requirements.txt,sha256=fw7T7t_YJAXK3T6Ysbesxh_Agx_tv0oYx72cEBTqRDY,98
  workbench/model_scripts/xgb_model/generated_model_script.py,sha256=Tbn7EMXxZZO8rDdKQ5fYCbpltACsMXNvuusLL9p-U5c,22319
  workbench/model_scripts/xgb_model/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
  workbench/model_scripts/xgb_model/xgb_model.template,sha256=0uXknIEqgUaIFUfu2gfkxa3WHUr8HBBqBepGUTDvrhQ,17917
  workbench/repl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- workbench/repl/workbench_shell.py,sha256=Vhg4BQr2r4D4ymekrVtOFi0MaRvaH4V2UgcWRgvN_3U,22122
+ workbench/repl/workbench_shell.py,sha256=Gw7-tWCc6k7nwWrEimg-YuQOnWCsOzc2QMiX3Rt-Wpk,22131
  workbench/resources/open_source_api.key,sha256=3S0OTblsmC0msUPdE_dbBmI83xJNmYscuwLJ57JmuOc,433
  workbench/resources/signature_verify_pub.pem,sha256=V3-u-3_z2PH-805ybkKvzDOBwAbvHxcKn0jLBImEtzM,272
  workbench/scripts/check_double_bond_stereo.py,sha256=p5hnL54Weq77ES0HCELq9JeoM-PyUGkvVSeWYF2dKyo,7776
@@ -278,7 +278,7 @@ workbench/web_interface/components/plugins/molecule_panel.py,sha256=xGCEI5af8F5l
  workbench/web_interface/components/plugins/molecule_viewer.py,sha256=xavixcu4RNzh6Nj_-3-XlK09DgpNx5jGmo3wEPNftiE,4529
  workbench/web_interface/components/plugins/pipeline_details.py,sha256=caiFIakHk-1dGGNW7wlio2X7iAm2_tCNbSjDzoRWGEk,5534
  workbench/web_interface/components/plugins/proximity_mini_graph.py,sha256=b_YYnvLczJUhaDbrrXnyjUDYF7C4R4ufCZXtJiyRnJ0,7233
- workbench/web_interface/components/plugins/scatter_plot.py,sha256=j8J1-m_xZjG0hgaMevbRvKaTAze0GglpMMDlP3WA_6U,19106
+ workbench/web_interface/components/plugins/scatter_plot.py,sha256=8tYnHlgi2UnuKLoxi9-89QF8ZHFrPhpRkzBY5gzlVdo,19130
  workbench/web_interface/components/plugins/shap_summary_plot.py,sha256=_V-xxVehU-60IpYWvAqTW5x_6u6pbjz9mI8r0ppIXKg,9454
  workbench/web_interface/page_views/data_sources_page_view.py,sha256=SXNUG6n_eP9i4anddEXd5E9rMRt-R2EyNR-bbe8OQK4,4673
  workbench/web_interface/page_views/endpoints_page_view.py,sha256=EI3hA18pEn-mAPEzGAw0W-wM8qJR2j_8pQEJlbJCENk,2770
@@ -287,9 +287,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
  workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
  workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
  workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
- workbench-0.8.181.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
- workbench-0.8.181.dist-info/METADATA,sha256=DZJcZg7gfOSERy7Y-Qia0fY9l9jcloqrkwk9OGaNAc4,9210
- workbench-0.8.181.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- workbench-0.8.181.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
- workbench-0.8.181.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
- workbench-0.8.181.dist-info/RECORD,,
+ workbench-0.8.182.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
+ workbench-0.8.182.dist-info/METADATA,sha256=mNhjFYFQS3MP33Z2CCAI2ydB_BPSpTxlvSAQiwTQdy8,9210
+ workbench-0.8.182.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ workbench-0.8.182.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
+ workbench-0.8.182.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
+ workbench-0.8.182.dist-info/RECORD,,
@@ -1,274 +0,0 @@
- # Imports for XGB Model
- import xgboost as xgb
- import awswrangler as wr
- from sklearn.model_selection import train_test_split
-
- # Model Performance Scores
- from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
-
- from io import StringIO
- import json
- import argparse
- import os
- import pandas as pd
-
- # Template Placeholders
- TEMPLATE_PARAMS = {
- "model_type": "{{model_type}}",
- "target_column": "{{target_column}}",
- "features": "{{feature_list}}",
- "model_metrics_s3_path": "{{model_metrics_s3_path}}",
- "train_all_data": "{{train_all_data}}",
- }
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
- """
- Check if the provided dataframe is empty and raise an exception if it is.
-
- Args:
- df (pd.DataFrame): DataFrame to check
- df_name (str): Name of the DataFrame
- """
- if df.empty:
- msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
- print(msg)
- raise ValueError(msg)
-
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
- """
- Matches and renames DataFrame columns to match model feature names (case-insensitive).
- Prioritizes exact matches, then case-insensitive matches.
-
- Raises ValueError if any model features cannot be matched.
- """
- df_columns_lower = {col.lower(): col for col in df.columns}
- rename_dict = {}
- missing = []
- for feature in model_features:
- if feature in df.columns:
- continue # Exact match
- elif feature.lower() in df_columns_lower:
- rename_dict[df_columns_lower[feature.lower()]] = feature
- else:
- missing.append(feature)
-
- if missing:
- raise ValueError(f"Features not found: {missing}")
-
- # Rename the DataFrame columns to match the model features
- return df.rename(columns=rename_dict)
-
-
- if __name__ == "__main__":
- """The main function is for training the XGBoost Quantile Regression models"""
-
- # Harness Template Parameters
- target = TEMPLATE_PARAMS["target_column"]
- features = TEMPLATE_PARAMS["features"]
- model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
- train_all_data = TEMPLATE_PARAMS["train_all_data"]
- validation_split = 0.2
- quantiles = [0.025, 0.25, 0.50, 0.75, 0.975]
- q_models = {}
-
- # Script arguments for input/output directories
- parser = argparse.ArgumentParser()
- parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
- parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
- parser.add_argument(
- "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
- )
- args = parser.parse_args()
-
- # Load training data from the specified directory
- training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
- print(f"Training Files: {training_files}")
-
- # Combine files and read them all into a single pandas dataframe
- df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
- # Check if the DataFrame is empty
- check_dataframe(df, "training_df")
-
- # Training data split logic
- if train_all_data:
- # Use all data for both training and validation
- print("Training on all data...")
- df_train = df.copy()
- df_val = df.copy()
- elif "training" in df.columns:
- # Split data based on a 'training' column if it exists
- print("Splitting data based on 'training' column...")
- df_train = df[df["training"]].copy()
- df_val = df[~df["training"]].copy()
- else:
- # Perform a random split if no 'training' column is found
- print("Splitting data randomly...")
- df_train, df_val = train_test_split(df, test_size=validation_split, random_state=42)
-
- # Features/Target output
- print(f"Target: {target}")
- print(f"Features: {str(features)}")
- print(f"Data Shape: {df.shape}")
-
- # Prepare features and targets for training
- X_train = df_train[features]
- X_val = df_val[features]
- y_train = df_train[target]
- y_val = df_val[target]
-
- # Train models for each of the quantiles
- for q in quantiles:
- params = {
- "objective": "reg:quantileerror",
- "quantile_alpha": q,
- }
- model = xgb.XGBRegressor(**params)
- model.fit(X_train, y_train)
-
- # Convert quantile to string
- q_str = f"q_{int(q * 100)}" if (q * 100) == int(q * 100) else f"q_{int(q * 1000):03d}"
-
- # Store the model
- q_models[q_str] = model
-
- # Run predictions for each quantile
- quantile_predictions = {q: model.predict(X_val) for q, model in q_models.items()}
-
- # Create a copy of the validation DataFrame and add the new columns
- result_df = df_val[[target]].copy()
-
- # Add the quantile predictions to the DataFrame
- for name, preds in quantile_predictions.items():
- result_df[name] = preds
-
- # Add the median as the main prediction
- result_df["prediction"] = result_df["q_50"]
-
- # Now compute residuals on the prediction
- result_df["residual"] = result_df[target] - result_df["prediction"]
- result_df["residual_abs"] = result_df["residual"].abs()
-
- # Save the results dataframe to S3
- wr.s3.to_csv(
- result_df,
- path=f"{model_metrics_s3_path}/validation_predictions.csv",
- index=False,
- )
-
- # Report Performance Metrics
- rmse = root_mean_squared_error(result_df[target], result_df["prediction"])
- mae = mean_absolute_error(result_df[target], result_df["prediction"])
- r2 = r2_score(result_df[target], result_df["prediction"])
- print(f"RMSE: {rmse:.3f}")
- print(f"MAE: {mae:.3f}")
- print(f"R2: {r2:.3f}")
- print(f"NumRows: {len(result_df)}")
-
- # Now save the quantile models
- for name, model in q_models.items():
- model_path = os.path.join(args.model_dir, f"{name}.json")
- print(f"Saving model: {model_path}")
- model.save_model(model_path)
-
- # Also save the features (this will validate input during predictions)
- with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
- json.dump(features, fp)
-
-
- def model_fn(model_dir) -> dict:
- """Deserialized and return all the fitted models from the model directory.
-
- Args:
- model_dir (str): The directory where the models are stored.
-
- Returns:
- dict: A dictionary of the models.
- """
-
- # Load ALL the Quantile models from the model directory
- models = {}
- for file in os.listdir(model_dir):
- if file.startswith("q") and file.endswith(".json"): # The Quantile models
- # Load the model
- model_path = os.path.join(model_dir, file)
- print(f"Loading model: {model_path}")
- model = xgb.XGBRegressor()
- model.load_model(model_path)
-
- # Store the quantile model
- q_name = os.path.splitext(file)[0]
- models[q_name] = model
-
- # Return all the models
- return models
-
-
- def input_fn(input_data, content_type):
- """Parse input data and return a DataFrame."""
- if not input_data:
- raise ValueError("Empty input data is not supported!")
-
- # Decode bytes to string if necessary
- if isinstance(input_data, bytes):
- input_data = input_data.decode("utf-8")
-
- if "text/csv" in content_type:
- return pd.read_csv(StringIO(input_data))
- elif "application/json" in content_type:
- return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
- else:
- raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
- """Supports both CSV and JSON output formats."""
- if "text/csv" in accept_type:
- csv_output = output_df.fillna("N/A").to_csv(index=False) # CSV with N/A for missing values
- return csv_output, "text/csv"
- elif "application/json" in accept_type:
- return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
- else:
- raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- def predict_fn(df, models) -> pd.DataFrame:
- """Make Predictions with our XGB Quantile Regression Model
-
- Args:
- df (pd.DataFrame): The input DataFrame
- models (dict): The dictionary of models to use for predictions
-
- Returns:
- pd.DataFrame: The DataFrame with the predictions added
- """
-
- # Grab our feature columns (from training)
- model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
- with open(os.path.join(model_dir, "feature_columns.json")) as fp:
- model_features = json.load(fp)
- print(f"Model Features: {model_features}")
-
- # We're going match features in a case-insensitive manner, accounting for all the permutations
- # - Model has a feature list that's any case ("Id", "taCos", "cOunT", "likes_tacos")
- # - Incoming data has columns that are mixed case ("ID", "Tacos", "Count", "Likes_Tacos")
- matched_df = match_features_case_insensitive(df, model_features)
-
- # Predict the features against all the models
- for name, model in models.items():
- df[name] = model.predict(matched_df[model_features])
-
- # Use the median prediction as the main prediction
- df["prediction"] = df["q_50"]
-
- # Estimate the standard deviation of the predictions using the interquartile range
- df["prediction_std"] = (df["q_75"] - df["q_25"]) / 1.35
-
- # Reorganize the columns so they are in alphabetical order
- df = df.reindex(sorted(df.columns), axis=1)
-
- # All done, return the DataFrame
- return df
@@ -1 +0,0 @@
- # Note: In general this file should be empty (as the default inference image has all required libraries)