workbench 0.8.198__py3-none-any.whl → 0.8.203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. workbench/algorithms/dataframe/proximity.py +11 -4
  2. workbench/api/__init__.py +2 -1
  3. workbench/api/df_store.py +17 -108
  4. workbench/api/feature_set.py +48 -11
  5. workbench/api/model.py +1 -1
  6. workbench/api/parameter_store.py +3 -52
  7. workbench/core/artifacts/__init__.py +11 -2
  8. workbench/core/artifacts/artifact.py +5 -5
  9. workbench/core/artifacts/df_store_core.py +114 -0
  10. workbench/core/artifacts/endpoint_core.py +261 -78
  11. workbench/core/artifacts/feature_set_core.py +69 -1
  12. workbench/core/artifacts/model_core.py +48 -14
  13. workbench/core/artifacts/parameter_store_core.py +98 -0
  14. workbench/core/transforms/features_to_model/features_to_model.py +50 -33
  15. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  16. workbench/core/views/view.py +2 -2
  17. workbench/model_scripts/chemprop/chemprop.template +933 -0
  18. workbench/model_scripts/chemprop/generated_model_script.py +933 -0
  19. workbench/model_scripts/chemprop/requirements.txt +11 -0
  20. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  21. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  22. workbench/model_scripts/custom_models/proximity/proximity.py +11 -4
  23. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +11 -5
  24. workbench/model_scripts/custom_models/uq_models/meta_uq.template +11 -5
  25. workbench/model_scripts/custom_models/uq_models/ngboost.template +11 -5
  26. workbench/model_scripts/custom_models/uq_models/proximity.py +11 -4
  27. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +11 -5
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +365 -173
  29. workbench/model_scripts/pytorch_model/pytorch.template +362 -170
  30. workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
  31. workbench/model_scripts/script_generation.py +10 -7
  32. workbench/model_scripts/uq_models/generated_model_script.py +43 -27
  33. workbench/model_scripts/uq_models/mapie.template +40 -24
  34. workbench/model_scripts/xgb_model/generated_model_script.py +36 -7
  35. workbench/model_scripts/xgb_model/xgb_model.template +36 -7
  36. workbench/repl/workbench_shell.py +14 -5
  37. workbench/resources/open_source_api.key +1 -1
  38. workbench/scripts/endpoint_test.py +162 -0
  39. workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
  40. workbench/utils/chemprop_utils.py +761 -0
  41. workbench/utils/pytorch_utils.py +527 -0
  42. workbench/utils/xgboost_model_utils.py +10 -5
  43. workbench/web_interface/components/model_plot.py +7 -1
  44. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/METADATA +3 -3
  45. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/RECORD +49 -43
  46. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/entry_points.txt +2 -1
  47. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  48. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  49. workbench/model_scripts/__pycache__/script_generation.cpython-312.pyc +0 -0
  50. workbench/model_scripts/__pycache__/script_generation.cpython-313.pyc +0 -0
  51. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/WHEEL +0 -0
  52. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/licenses/LICENSE +0 -0
  53. {workbench-0.8.198.dist-info → workbench-0.8.203.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,302 @@
1
# Model Imports (this will be replaced with the imports for the template)
# NOTE(review): the placeholder below is a bare `None`, so this script imports no model class;
# compare with the "model_class" entry in TEMPLATE_PARAMS.
None

# Template Placeholders
# Values substituted by the script generator; read by both the training section and predict_fn.
TEMPLATE_PARAMS = {
    "model_type": "regressor",
    "target_column": "udm_asy_res_efflux_ratio",
    "feature_list": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 
'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 
'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo', 'tertiary_amine_count', 'type_i_pattern_count', 'type_ii_pattern_count', 'aromatic_interaction_score', 'molecular_axis_length', 'molecular_asymmetry', 'molecular_volume_3d', 'radius_of_gyration', 'asphericity', 'charge_centroid_distance', 'nitrogen_span', 'amide_count', 'hba_hbd_ratio', 'intramolecular_hbond_potential', 'amphiphilic_moment'],
    # NOTE(review): `PyTorch` is not defined or imported anywhere in this script, so evaluating this
    # dict raises NameError as soon as the module is imported (for training or for inference).
    # This looks like a PyTorch model that was routed to the scikit-learn template. Confirm this and
    # regenerate the script.
    "model_class": PyTorch,
    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-pytorch-test/training",
    "train_all_data": False,
}

import awswrangler as wr
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

from io import StringIO
import json
import argparse
import joblib
import os
import pandas as pd
from typing import List

# Global model_type for both training and inference.
# The training section uses it to pick the workflow; predict_fn uses it to decide
# whether a 2-component transform is labeled x/y.
model_type = TEMPLATE_PARAMS["model_type"]
29
+
30
+
31
# Function to check if dataframe is empty
def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
    """Abort training when the given DataFrame has no rows.

    Args:
        df (pd.DataFrame): The frame to inspect.
        df_name (str): Human-readable name used in the error message.

    Raises:
        ValueError: If ``df`` is empty (the message is also printed so it shows up in the job log).
    """
    if not df.empty:
        return
    msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
    print(msg)
    raise ValueError(msg)
38
+
39
+
40
# Function to expand probability column into individual class probability columns
def expand_proba_column(df: pd.DataFrame, class_labels: List[str]) -> pd.DataFrame:
    """Split the list-valued 'pred_proba' column into one '<label>_proba' column per class.

    Args:
        df (pd.DataFrame): Frame with a 'pred_proba' column whose cells are per-class probability lists.
        class_labels (List[str]): Class labels, in the same order as the probabilities in each list.

    Returns:
        pd.DataFrame: A new frame with a fresh 0..n-1 index, the original columns minus 'pred_proba',
        followed by the '<label>_proba' columns.

    Raises:
        ValueError: If ``df`` has no 'pred_proba' column.
    """
    if "pred_proba" not in df.columns:
        raise ValueError('DataFrame does not contain a "pred_proba" column')

    proba_cols = [f"{label}_proba" for label in class_labels]
    expanded = pd.DataFrame(df["pred_proba"].tolist(), columns=proba_cols)

    # Reset to a RangeIndex so the column-wise concat lines rows up with `expanded` positionally
    remaining = df.drop(columns=["pred_proba"]).reset_index(drop=True)
    return pd.concat([remaining, expanded], axis=1)
57
+
58
+
59
# Function to match DataFrame columns to model features (case-insensitive)
def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
    """Rename DataFrame columns so they match the model's feature names, ignoring case.

    An exact name match wins. Otherwise a column whose lower-cased name equals the lower-cased
    feature is renamed to the feature's spelling. Columns that match no feature are left untouched.

    Args:
        df (pd.DataFrame): Incoming data.
        model_features (list): Feature names as the model expects them.

    Returns:
        pd.DataFrame: A renamed copy of ``df``.
    """
    present = set(df.columns)
    lowered = {col.lower(): col for col in df.columns}

    renames = {}
    for feat in model_features:
        if feat in present:
            renames[feat] = feat
            continue
        key = feat.lower()
        if key in lowered:
            renames[lowered[key]] = feat

    return df.rename(columns=renames)
78
+
79
+
80
#
# Training Section
#
if __name__ == "__main__":
    # Template Parameters
    target = TEMPLATE_PARAMS["target_column"]  # Can be None for unsupervised models
    feature_list = TEMPLATE_PARAMS["feature_list"]
    model_class = TEMPLATE_PARAMS["model_class"]
    model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
    train_all_data = TEMPLATE_PARAMS["train_all_data"]
    validation_split = 0.2  # Fraction held out when falling back to a random split

    # Script arguments for input/output directories (defaults come from the SM_* environment variables)
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument(
        "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
    )
    args = parser.parse_args()

    # Load training data from every CSV in the training directory
    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
    if not training_files:
        # Fail with a clear message instead of pd.concat's generic "No objects to concatenate"
        msg = f"*** No CSV training files found in {args.train}! ***STOPPING***"
        print(msg)
        raise ValueError(msg)
    all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

    # Check if the DataFrame is empty
    check_dataframe(all_df, "training_df")

    # Initialize the model using the specified model class
    model = model_class()

    # Clusterer and projection models get a StandardScaler in front of them
    needs_standardization = model_type in ["clusterer", "projection"]
    if needs_standardization:
        model = Pipeline([("scaler", StandardScaler()), ("model", model)])

    # Handle logic based on the model_type
    if model_type in ["classifier", "regressor"]:
        # Supervised Models: Prepare for training
        if train_all_data:
            # Use all data for both training and validation
            print("Training on all data...")
            df_train = all_df.copy()
            df_val = all_df.copy()
        elif "training" in all_df.columns:
            # Split data based on a 'training' column if it exists
            print("Splitting data based on 'training' column...")
            # Coerce to bool so a 0/1 integer column also works as a row mask: indexing with an int
            # Series is a column-label lookup, and ~ on ints is bitwise NOT rather than logical NOT
            training_mask = all_df["training"].astype(bool)
            df_train = all_df[training_mask].copy()
            df_val = all_df[~training_mask].copy()
        else:
            # Perform a random split if no 'training' column is found
            print("Splitting data randomly...")
            df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)

        # Encode the target variable if the model is a classifier
        label_encoder = None
        if model_type == "classifier" and target:
            label_encoder = LabelEncoder()
            df_train[target] = label_encoder.fit_transform(df_train[target])
            df_val[target] = label_encoder.transform(df_val[target])

        # Prepare features and targets for training
        X_train = df_train[feature_list]
        X_val = df_val[feature_list]
        y_train = df_train[target] if target else None
        y_val = df_val[target] if target else None

        # Train the model using the training data
        model.fit(X_train, y_train)

        # Make predictions and handle classification-specific logic
        preds = model.predict(X_val)
        if model_type == "classifier" and target:
            # Get class probabilities and expand them into separate columns
            probs = model.predict_proba(X_val)
            df_val["pred_proba"] = [p.tolist() for p in probs]
            df_val = expand_proba_column(df_val, label_encoder.classes_)

            # Decode the target and prediction labels
            df_val[target] = label_encoder.inverse_transform(df_val[target])
            preds = label_encoder.inverse_transform(preds)

        # Add predictions to the validation DataFrame
        df_val["prediction"] = preds

        # Save the validation predictions to S3
        output_columns = [target, "prediction"] + [col for col in df_val.columns if col.endswith("_proba")]
        wr.s3.to_csv(df_val[output_columns], path=f"{model_metrics_s3_path}/validation_predictions.csv", index=False)

    elif model_type == "clusterer":
        # Unsupervised Clustering Models: Assign cluster labels
        all_df["cluster"] = model.fit_predict(all_df[feature_list])

    elif model_type == "projection":
        # Projection Models: Apply transformation and label first three components as x, y, z
        transformed_data = model.fit_transform(all_df[feature_list])
        num_components = transformed_data.shape[1]

        # Special labels for the first three components, if they exist
        special_labels = ["x", "y", "z"]
        for i in range(num_components):
            if i < len(special_labels):
                all_df[special_labels[i]] = transformed_data[:, i]
            else:
                all_df[f"component_{i + 1}"] = transformed_data[:, i]

    elif model_type == "transformer":
        # Transformer Models: Apply transformation and use generic component labels
        transformed_data = model.fit_transform(all_df[feature_list])
        for i in range(transformed_data.shape[1]):
            all_df[f"component_{i + 1}"] = transformed_data[:, i]

    # Save the trained model and any necessary assets
    joblib.dump(model, os.path.join(args.model_dir, "model.joblib"))
    if model_type == "classifier" and label_encoder:
        joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))

    # Save the feature list so predict_fn can select/validate input columns
    with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
        json.dump(feature_list, fp)
202
+
203
+
204
#
# Inference Section
#
def model_fn(model_dir):
    """Load the fitted estimator that training saved as ``model.joblib`` inside ``model_dir``."""
    model_path = os.path.join(model_dir, "model.joblib")
    return joblib.load(model_path)
210
+
211
+
212
def input_fn(input_data, content_type):
    """Deserialize a request payload into a DataFrame.

    Args:
        input_data (bytes | str): Raw request body; bytes are decoded as UTF-8.
        content_type (str): MIME type; must contain "text/csv" or "application/json".
            JSON is expected to be an array of records.

    Returns:
        pd.DataFrame: The parsed rows.

    Raises:
        ValueError: On an empty payload or an unsupported content type.
    """
    if not input_data:
        raise ValueError("Empty input data is not supported!")

    text = input_data.decode("utf-8") if isinstance(input_data, bytes) else input_data

    if "text/csv" in content_type:
        return pd.read_csv(StringIO(text))
    if "application/json" in content_type:
        records = json.loads(text)
        return pd.DataFrame(records)
    raise ValueError(f"{content_type} not supported!")
227
+
228
+
229
def output_fn(output_df, accept_type):
    """Serialize the prediction DataFrame as CSV or JSON.

    Args:
        output_df (pd.DataFrame): Frame returned by predict_fn.
        accept_type (str): Requested MIME type; must contain "text/csv" or "application/json".

    Returns:
        tuple[str, str]: The serialized body and its content type. CSV writes missing values as
        "N/A"; JSON is an array of records with NaN emitted as null.

    Raises:
        RuntimeError: If the accept type is not supported.
    """
    if "text/csv" in accept_type:
        return output_df.fillna("N/A").to_csv(index=False), "text/csv"
    if "application/json" in accept_type:
        return output_df.to_json(orient="records"), "application/json"
    raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
238
+
239
+
240
def predict_fn(df, model):
    """Run the fitted model over an inference DataFrame and append the results.

    Reads ``feature_columns.json`` and, if present, ``label_encoder.joblib`` from the directory named
    by the ``SM_MODEL_DIR`` environment variable (default ``/opt/ml/model``). It then dispatches on
    what the model supports:

    * ``predict``       -> adds a ``prediction`` column (decoded with the label encoder if one was saved)
    * ``fit_predict``   -> adds a ``cluster`` column
    * ``fit_transform`` -> adds ``x``/``y`` for a 2-component projection, else ``component_<i>`` columns

    If the model also exposes ``predict_proba``, one ``<label>_proba`` column per class is appended.

    Args:
        df (pd.DataFrame): Input rows; feature columns are matched case-insensitively. Result columns
            are assigned onto this frame in place.
        model: The object returned by ``model_fn``.

    Returns:
        pd.DataFrame: ``df`` with the result columns added (re-indexed 0..n-1 when probability
        columns are added).

    Raises:
        ValueError: If the model supports none of predict / fit_predict / fit_transform.
    """
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

    # Load feature columns from the saved file
    with open(os.path.join(model_dir, "feature_columns.json")) as fp:
        model_features = json.load(fp)

    # Load label encoder if available (training only saves one for classifiers)
    label_encoder = None
    label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
    if os.path.exists(label_encoder_path):
        label_encoder = joblib.load(label_encoder_path)

    # Match features in a case-insensitive manner, then select them once
    matched_df = match_features_case_insensitive(df, model_features)
    X = matched_df[model_features]

    results = {}
    if hasattr(model, "predict"):
        # Supervised models (classifier or regressor)
        results["prediction"] = model.predict(X)
    elif hasattr(model, "fit_predict"):
        # Clustering models without a separate predict (e.g., DBSCAN)
        results["cluster"] = model.fit_predict(X)
    elif hasattr(model, "fit_transform"):
        # Transformation/projection models (e.g., t-SNE, PCA)
        transformed_data = model.fit_transform(X)
        if model_type == "projection" and transformed_data.shape[1] == 2:
            results["x"] = transformed_data[:, 0]
            results["y"] = transformed_data[:, 1]
        else:
            for i in range(transformed_data.shape[1]):
                results[f"component_{i + 1}"] = transformed_data[:, i]
    else:
        raise ValueError("Model does not support predict, fit_predict, or fit_transform methods.")

    # Decode predictions back to the original class labels (classification only)
    if label_encoder is not None and "prediction" in results:
        results["prediction"] = label_encoder.inverse_transform(results["prediction"])

    # Add the results to the DataFrame
    for key, value in results.items():
        df[key] = value

    # Add probability columns if the model supports it
    if hasattr(model, "predict_proba"):
        probs = model.predict_proba(X)
        df["pred_proba"] = [p.tolist() for p in probs]
        # A model can expose predict_proba without a saved label encoder (e.g. a GaussianMixture
        # "clusterer"); fall back to its own classes_, or positional labels, instead of None.classes_
        if label_encoder is not None:
            class_labels = label_encoder.classes_
        else:
            class_labels = getattr(model, "classes_", list(range(probs.shape[1])))
        df = expand_proba_column(df, class_labels)

    return df
@@ -93,6 +93,7 @@ def generate_model_script(template_params: dict) -> str:
93
93
  template_params (dict): Dictionary containing the parameters:
94
94
  - model_imports (str): Import string for the model class
95
95
  - model_type (ModelType): The enumerated type of model to generate
96
+ - model_framework (str): The enumerated model framework to use
96
97
  - model_class (str): The model class to use (e.g., "RandomForestRegressor")
97
98
  - target_column (str): Column name of the target variable
98
99
  - feature_list (list[str]): A list of columns for the features
@@ -103,16 +104,18 @@ def generate_model_script(template_params: dict) -> str:
103
104
  Returns:
104
105
  str: The name of the generated model script
105
106
  """
106
- from workbench.api import ModelType # Avoid circular import
107
+ from workbench.api import ModelType, ModelFramework # Avoid circular import
107
108
 
108
109
  # Determine which template to use based on model type
109
110
  if template_params.get("model_class"):
110
- if template_params["model_class"].lower() == "pytorch":
111
- template_name = "pytorch.template"
112
- model_script_dir = "pytorch_model"
113
- else:
114
- template_name = "scikit_learn.template"
115
- model_script_dir = "scikit_learn"
111
+ template_name = "scikit_learn.template"
112
+ model_script_dir = "scikit_learn"
113
+ elif template_params["model_framework"] == ModelFramework.PYTORCH_TABULAR:
114
+ template_name = "pytorch.template"
115
+ model_script_dir = "pytorch_model"
116
+ elif template_params["model_framework"] == ModelFramework.CHEMPROP:
117
+ template_name = "chemprop.template"
118
+ model_script_dir = "chemprop"
116
119
  elif template_params["model_type"] in [ModelType.REGRESSOR, ModelType.CLASSIFIER]:
117
120
  template_name = "xgb_model.template"
118
121
  model_script_dir = "xgb_model"
@@ -5,7 +5,8 @@ from xgboost import XGBRegressor
5
5
  from sklearn.model_selection import train_test_split
6
6
 
7
7
  # Model Performance Scores
8
- from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
8
+ from sklearn.metrics import mean_absolute_error, median_absolute_error, r2_score, root_mean_squared_error
9
+ from scipy.stats import spearmanr
9
10
 
10
11
  from io import StringIO
11
12
  import json
@@ -18,11 +19,11 @@ from typing import List, Tuple, Optional, Dict
18
19
 
19
20
  # Template Placeholders
20
21
  TEMPLATE_PARAMS = {
21
- "target": "udm_asy_res_efflux_ratio",
22
- "features": ['smr_vsa4', 'tpsa', 'numhdonors', 'nhohcount', 'peoe_vsa1', 'mollogp', 'peoe_vsa8', 'nitrogen_span', 'smr_vsa3', 'vsa_estate2', 'chi1v', 'molmr', 'estate_vsa4', 'xc_4dv', 'vsa_estate3', 'vsa_estate6', 'qed', 'estate_vsa8', 'chi2v', 'molecular_asymmetry', 'asphericity', 'vsa_estate4', 'minpartialcharge', 'axp_1d', 'num_s_centers', 'charge_centroid_distance', 'xpc_4dv', 'axp_0dv', 'estate_vsa2', 'peoe_vsa3', 'molecular_axis_length', 'mi', 'aromatic_interaction_score', 'vsa_estate8', 'bcut2d_logphi', 'molecular_volume_3d', 'balabanj', 'fr_al_oh', 'minabsestateindex', 'axp_7dv', 'axp_7d', 'bcut2d_chglo', 'vsa_estate9', 'xch_6d', 'kappa3', 'bcut2d_mrlow', 'estate_vsa3', 'c3sp3', 'chi3n', 'type_ii_pattern_count', 'xp_3d', 'bcut2d_logplow', 'fr_nhpyrrole', 'peoe_vsa9', 'slogp_vsa3', 'peoe_vsa2', 'maxabspartialcharge', 'fpdensitymorgan1', 'xch_7d', 'peoe_vsa11', 'axp_3d', 'bcut2d_mwlow', 'maxestateindex', 'minestateindex', 'radius_of_gyration', 'avgipc', 'smr_vsa6', 'vsa_estate7', 'fpdensitymorgan3', 'estate_vsa6', 'xp_7dv', 'xp_6dv', 'chi4n', 'vsa_estate5', 'fr_imidazole', 'xc_3dv', 'slogp_vsa2', 'num_r_centers', 'xch_5dv', 'bcut2d_mrhi', 'xp_4dv', 'xp_6d', 'mm', 'xpc_6d', 'numsaturatedcarbocycles', 'axp_3dv', 'chi3v', 'numvalenceelectrons', 'mare', 'c1sp2', 'smr_vsa9', 'xp_3dv', 'axp_1dv', 'fpdensitymorgan2', 'slogp_vsa5', 'sps', 'xc_3d', 'bertzct', 'estate_vsa10', 'axp_4d', 'smr_vsa1', 'peoe_vsa10', 'hallkieralpha', 'axp_5dv', 'chi0v', 'xch_7dv', 'mv', 'estate_vsa9', 'fr_ketone_topliss', 'estate_vsa5', 'molwt', 'estate_vsa7', 'type_i_pattern_count', 'xp_5d', 'heavyatommolwt', 'smr_vsa10', 'xc_4d', 'estate_vsa1', 'vsa_estate10', 'axp_6dv', 'axp_2d', 'mp', 'xc_5d', 'xch_6dv', 'xp_7d', 'peoe_vsa7', 'axp_0d', 'xp_2dv', 'axp_6d', 'xc_5dv', 'chi4v', 'xch_4dv', 'mz', 'tertiary_amine_count', 'xpc_6dv', 'peoe_vsa13', 'xpc_4d', 'hybratio', 'axp_5d', 'kappa2', 'slogp_vsa6', 'xpc_5dv', 'phi', 'xch_4d', 'smr_vsa5', 'kappa1', 'xp_5dv', 'bcut2d_chghi', 
'numrotatablebonds', 'fr_ar_n', 'maxpartialcharge', 'bcut2d_mwhi', 'peoe_vsa4', 'c3sp2', 'smr_vsa7', 'slogp_vsa4', 'fr_nh0', 'xch_5d', 'slogp_vsa1', 'slogp_vsa10', 'axp_2dv', 'xc_6dv', 'numaliphaticrings', 'axp_4dv', 'chi0', 'labuteasa', 'c1sp3', 'numaliphaticcarbocycles', 'xp_0dv', 'fr_hoccn', 'fr_piperdine', 'fractioncsp3', 'si', 'slogp_vsa8', 'sv', 'fr_thiazole', 'fr_guanido', 'spe', 'peoe_vsa6', 'fr_pyridine', 'nocount', 'fr_piperzine', 'chi2n', 'chi0n', 'mse', 'fr_aniline', 'xpc_5d', 'peoe_vsa12', 'fr_ndealkylation1', 'fr_al_oh_notert', 'fr_methoxy', 'numheteroatoms', 'c2sp3', 'fr_nh1', 'sp', 'chi1', 'peoe_vsa14', 'numatomstereocenters', 'ringcount', 'mpe', 'slogp_vsa7', 'frac_defined_stereo', 'fr_morpholine', 'c2sp2', 'xp_2d', 'vsa_estate1', 'slogp_vsa11', 'fr_benzene', 'nbase', 'xp_4d', 'num_stereocenters', 'fr_arn', 'minabspartialcharge', 'chi1n', 'sare', 'numspiroatoms', 'xp_0d', 'fr_aryl_methyl', 'fr_imine', 'fr_priamide', 'num_defined_stereocenters', 'numunspecifiedatomstereocenters', 'fr_oxazole'],
22
+ "target": "mppb",
23
+ "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 
'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 
'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
23
24
  "compressed_features": [],
24
25
  "train_all_data": True,
25
- "hyperparameters": {},
26
+ "hyperparameters": {'objective': 'reg:absoluteerror', 'n_estimators': 300, 'max_depth': 6, 'learning_rate': 0.03, 'subsample': 0.8, 'colsample_bytree': 0.6, 'colsample_bylevel': 0.8, 'min_child_weight': 5, 'gamma': 0.1, 'reg_alpha': 0.3, 'reg_lambda': 1.5, 'random_state': 42},
26
27
  }
27
28
 
28
29
 
@@ -251,6 +252,14 @@ if __name__ == "__main__":
251
252
  print(f"FIT/TRAIN: {df_train.shape}")
252
253
  print(f"VALIDATION: {df_val.shape}")
253
254
 
255
+ # Extract sample weights if present
256
+ if 'sample_weight' in df_train.columns:
257
+ sample_weights = df_train['sample_weight']
258
+ print(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}, mean={sample_weights.mean():.2f}")
259
+ else:
260
+ sample_weights = None
261
+ print("No sample weights found, training with equal weights")
262
+
254
263
  # Prepare features and targets for training
255
264
  X_train = df_train[features]
256
265
  X_validate = df_val[features]
@@ -261,7 +270,7 @@ if __name__ == "__main__":
261
270
  print("\nTraining XGBoost for point predictions...")
262
271
  print(f" Hyperparameters: {hyperparameters}")
263
272
  xgb_model = XGBRegressor(enable_categorical=True, **hyperparameters)
264
- xgb_model.fit(X_train, y_train)
273
+ xgb_model.fit(X_train, y_train, sample_weight=sample_weights)
265
274
 
266
275
  # Evaluate XGBoost performance
267
276
  y_pred_xgb = xgb_model.predict(X_validate)
@@ -269,10 +278,15 @@ if __name__ == "__main__":
269
278
  xgb_mae = mean_absolute_error(y_validate, y_pred_xgb)
270
279
  xgb_r2 = r2_score(y_validate, y_pred_xgb)
271
280
 
281
+ xgb_medae = median_absolute_error(y_validate, y_pred_xgb)
282
+ xgb_spearman = spearmanr(y_validate, y_pred_xgb).correlation
283
+
272
284
  print(f"\nXGBoost Point Prediction Performance:")
273
- print(f"RMSE: {xgb_rmse:.3f}")
274
- print(f"MAE: {xgb_mae:.3f}")
275
- print(f"R2: {xgb_r2:.3f}")
285
+ print(f"rmse: {xgb_rmse:.3f}")
286
+ print(f"mae: {xgb_mae:.3f}")
287
+ print(f"medae: {xgb_medae:.3f}")
288
+ print(f"r2: {xgb_r2:.3f}")
289
+ print(f"spearmanr: {xgb_spearman:.3f}")
276
290
 
277
291
  # Define confidence levels we want to model
278
292
  confidence_levels = [0.50, 0.68, 0.80, 0.90, 0.95] # 50%, 68%, 80%, 90%, 95% confidence intervals
@@ -328,11 +342,14 @@ if __name__ == "__main__":
328
342
  coverage = np.mean((y_validate >= y_pis[:, 0, 0]) & (y_validate <= y_pis[:, 1, 0]))
329
343
  print(f" Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
330
344
 
345
+ support = len(df_val)
331
346
  print(f"\nOverall Model Performance Summary:")
332
- print(f"XGBoost RMSE: {xgb_rmse:.3f}")
333
- print(f"XGBoost MAE: {xgb_mae:.3f}")
334
- print(f"XGBoost R2: {xgb_r2:.3f}")
335
- print(f"NumRows: {len(df_val)}")
347
+ print(f"rmse: {xgb_rmse:.3f}")
348
+ print(f"mae: {xgb_mae:.3f}")
349
+ print(f"medae: {xgb_medae:.3f}")
350
+ print(f"r2: {xgb_r2:.3f}")
351
+ print(f"spearmanr: {xgb_spearman:.3f}")
352
+ print(f"support: {support}")
336
353
 
337
354
  # Analyze interval widths across confidence levels
338
355
  print(f"\nInterval Width Analysis:")
@@ -345,9 +362,8 @@ if __name__ == "__main__":
345
362
  # Compute normalization statistics for confidence calculation
346
363
  print(f"\nComputing normalization statistics for confidence scores...")
347
364
 
348
- # Create a temporary validation dataframe with predictions
349
- temp_val_df = df_val.copy()
350
- temp_val_df["prediction"] = xgb_model.predict(X_validate)
365
+ # Add predictions directly to validation dataframe
366
+ df_val["prediction"] = xgb_model.predict(X_validate)
351
367
 
352
368
  # Add all quantile predictions
353
369
  for conf_level in confidence_levels:
@@ -356,25 +372,25 @@ if __name__ == "__main__":
356
372
  y_pred, y_pis = model.predict_interval(X_validate)
357
373
 
358
374
  if conf_level == 0.50:
359
- temp_val_df["q_25"] = y_pis[:, 0, 0]
360
- temp_val_df["q_75"] = y_pis[:, 1, 0]
375
+ df_val["q_25"] = y_pis[:, 0, 0]
376
+ df_val["q_75"] = y_pis[:, 1, 0]
361
377
  # y_pred is the median prediction
362
- temp_val_df["q_50"] = y_pred
378
+ df_val["q_50"] = y_pred
363
379
  elif conf_level == 0.68:
364
- temp_val_df["q_16"] = y_pis[:, 0, 0]
365
- temp_val_df["q_84"] = y_pis[:, 1, 0]
380
+ df_val["q_16"] = y_pis[:, 0, 0]
381
+ df_val["q_84"] = y_pis[:, 1, 0]
366
382
  elif conf_level == 0.80:
367
- temp_val_df["q_10"] = y_pis[:, 0, 0]
368
- temp_val_df["q_90"] = y_pis[:, 1, 0]
383
+ df_val["q_10"] = y_pis[:, 0, 0]
384
+ df_val["q_90"] = y_pis[:, 1, 0]
369
385
  elif conf_level == 0.90:
370
- temp_val_df["q_05"] = y_pis[:, 0, 0]
371
- temp_val_df["q_95"] = y_pis[:, 1, 0]
386
+ df_val["q_05"] = y_pis[:, 0, 0]
387
+ df_val["q_95"] = y_pis[:, 1, 0]
372
388
  elif conf_level == 0.95:
373
- temp_val_df["q_025"] = y_pis[:, 0, 0]
374
- temp_val_df["q_975"] = y_pis[:, 1, 0]
389
+ df_val["q_025"] = y_pis[:, 0, 0]
390
+ df_val["q_975"] = y_pis[:, 1, 0]
375
391
 
376
392
  # Compute normalization stats using q_10 and q_90 (default range)
377
- interval_width = (temp_val_df["q_90"] - temp_val_df["q_10"]).abs()
393
+ interval_width = (df_val["q_90"] - df_val["q_10"]).abs()
378
394
  median_interval_width = float(interval_width.median())
379
395
  print(f" Median interval width (q_10-q_90): {median_interval_width:.6f}")
380
396
 
@@ -5,7 +5,8 @@ from xgboost import XGBRegressor
5
5
  from sklearn.model_selection import train_test_split
6
6
 
7
7
  # Model Performance Scores
8
- from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
8
+ from sklearn.metrics import mean_absolute_error, median_absolute_error, r2_score, root_mean_squared_error
9
+ from scipy.stats import spearmanr
9
10
 
10
11
  from io import StringIO
11
12
  import json
@@ -251,6 +252,14 @@ if __name__ == "__main__":
251
252
  print(f"FIT/TRAIN: {df_train.shape}")
252
253
  print(f"VALIDATION: {df_val.shape}")
253
254
 
255
+ # Extract sample weights if present
256
+ if 'sample_weight' in df_train.columns:
257
+ sample_weights = df_train['sample_weight']
258
+ print(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}, mean={sample_weights.mean():.2f}")
259
+ else:
260
+ sample_weights = None
261
+ print("No sample weights found, training with equal weights")
262
+
254
263
  # Prepare features and targets for training
255
264
  X_train = df_train[features]
256
265
  X_validate = df_val[features]
@@ -261,7 +270,7 @@ if __name__ == "__main__":
261
270
  print("\nTraining XGBoost for point predictions...")
262
271
  print(f" Hyperparameters: {hyperparameters}")
263
272
  xgb_model = XGBRegressor(enable_categorical=True, **hyperparameters)
264
- xgb_model.fit(X_train, y_train)
273
+ xgb_model.fit(X_train, y_train, sample_weight=sample_weights)
265
274
 
266
275
  # Evaluate XGBoost performance
267
276
  y_pred_xgb = xgb_model.predict(X_validate)
@@ -269,10 +278,15 @@ if __name__ == "__main__":
269
278
  xgb_mae = mean_absolute_error(y_validate, y_pred_xgb)
270
279
  xgb_r2 = r2_score(y_validate, y_pred_xgb)
271
280
 
281
+ xgb_medae = median_absolute_error(y_validate, y_pred_xgb)
282
+ xgb_spearman = spearmanr(y_validate, y_pred_xgb).correlation
283
+
272
284
  print(f"\nXGBoost Point Prediction Performance:")
273
- print(f"RMSE: {xgb_rmse:.3f}")
274
- print(f"MAE: {xgb_mae:.3f}")
275
- print(f"R2: {xgb_r2:.3f}")
285
+ print(f"rmse: {xgb_rmse:.3f}")
286
+ print(f"mae: {xgb_mae:.3f}")
287
+ print(f"medae: {xgb_medae:.3f}")
288
+ print(f"r2: {xgb_r2:.3f}")
289
+ print(f"spearmanr: {xgb_spearman:.3f}")
276
290
 
277
291
  # Define confidence levels we want to model
278
292
  confidence_levels = [0.50, 0.68, 0.80, 0.90, 0.95] # 50%, 68%, 80%, 90%, 95% confidence intervals
@@ -328,11 +342,14 @@ if __name__ == "__main__":
328
342
  coverage = np.mean((y_validate >= y_pis[:, 0, 0]) & (y_validate <= y_pis[:, 1, 0]))
329
343
  print(f" Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
330
344
 
345
+ support = len(df_val)
331
346
  print(f"\nOverall Model Performance Summary:")
332
- print(f"XGBoost RMSE: {xgb_rmse:.3f}")
333
- print(f"XGBoost MAE: {xgb_mae:.3f}")
334
- print(f"XGBoost R2: {xgb_r2:.3f}")
335
- print(f"NumRows: {len(df_val)}")
347
+ print(f"rmse: {xgb_rmse:.3f}")
348
+ print(f"mae: {xgb_mae:.3f}")
349
+ print(f"medae: {xgb_medae:.3f}")
350
+ print(f"r2: {xgb_r2:.3f}")
351
+ print(f"spearmanr: {xgb_spearman:.3f}")
352
+ print(f"support: {support}")
336
353
 
337
354
  # Analyze interval widths across confidence levels
338
355
  print(f"\nInterval Width Analysis:")
@@ -345,9 +362,8 @@ if __name__ == "__main__":
345
362
  # Compute normalization statistics for confidence calculation
346
363
  print(f"\nComputing normalization statistics for confidence scores...")
347
364
 
348
- # Create a temporary validation dataframe with predictions
349
- temp_val_df = df_val.copy()
350
- temp_val_df["prediction"] = xgb_model.predict(X_validate)
365
+ # Add predictions directly to validation dataframe
366
+ df_val["prediction"] = xgb_model.predict(X_validate)
351
367
 
352
368
  # Add all quantile predictions
353
369
  for conf_level in confidence_levels:
@@ -356,25 +372,25 @@ if __name__ == "__main__":
356
372
  y_pred, y_pis = model.predict_interval(X_validate)
357
373
 
358
374
  if conf_level == 0.50:
359
- temp_val_df["q_25"] = y_pis[:, 0, 0]
360
- temp_val_df["q_75"] = y_pis[:, 1, 0]
375
+ df_val["q_25"] = y_pis[:, 0, 0]
376
+ df_val["q_75"] = y_pis[:, 1, 0]
361
377
  # y_pred is the median prediction
362
- temp_val_df["q_50"] = y_pred
378
+ df_val["q_50"] = y_pred
363
379
  elif conf_level == 0.68:
364
- temp_val_df["q_16"] = y_pis[:, 0, 0]
365
- temp_val_df["q_84"] = y_pis[:, 1, 0]
380
+ df_val["q_16"] = y_pis[:, 0, 0]
381
+ df_val["q_84"] = y_pis[:, 1, 0]
366
382
  elif conf_level == 0.80:
367
- temp_val_df["q_10"] = y_pis[:, 0, 0]
368
- temp_val_df["q_90"] = y_pis[:, 1, 0]
383
+ df_val["q_10"] = y_pis[:, 0, 0]
384
+ df_val["q_90"] = y_pis[:, 1, 0]
369
385
  elif conf_level == 0.90:
370
- temp_val_df["q_05"] = y_pis[:, 0, 0]
371
- temp_val_df["q_95"] = y_pis[:, 1, 0]
386
+ df_val["q_05"] = y_pis[:, 0, 0]
387
+ df_val["q_95"] = y_pis[:, 1, 0]
372
388
  elif conf_level == 0.95:
373
- temp_val_df["q_025"] = y_pis[:, 0, 0]
374
- temp_val_df["q_975"] = y_pis[:, 1, 0]
389
+ df_val["q_025"] = y_pis[:, 0, 0]
390
+ df_val["q_975"] = y_pis[:, 1, 0]
375
391
 
376
392
  # Compute normalization stats using q_10 and q_90 (default range)
377
- interval_width = (temp_val_df["q_90"] - temp_val_df["q_10"]).abs()
393
+ interval_width = (df_val["q_90"] - df_val["q_10"]).abs()
378
394
  median_interval_width = float(interval_width.median())
379
395
  print(f" Median interval width (q_10-q_90): {median_interval_width:.6f}")
380
396