workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/model_scripts/custom_models/uq_models/generated_model_script.py (file removed; 393 lines deleted)
@@ -1,393 +0,0 @@
- # Model: NGBoost Regressor with Distribution output
- from ngboost import NGBRegressor
- from ngboost.distns import Cauchy, T
- from xgboost import XGBRegressor # Point Estimator
- from sklearn.model_selection import train_test_split
-
- # Model Performance Scores
- from sklearn.metrics import (
-     mean_absolute_error,
-     r2_score,
-     root_mean_squared_error
- )
-
- from io import StringIO
- import json
- import argparse
- import joblib
- import os
- import numpy as np
- import pandas as pd
- from typing import List, Tuple
-
- # Local Imports
- from proximity import Proximity
-
-
-
- # Template Placeholders
- TEMPLATE_PARAMS = {
-     "id_column": "udm_mol_id",
-     "target": "udm_asy_res_value",
-     "features": ['bcut2d_logplow', 'numradicalelectrons', 'smr_vsa5', 'fr_lactam', 'fr_morpholine', 'fr_aldehyde', 'slogp_vsa1', 'fr_amidine', 'bpol', 'fr_ester', 'fr_azo', 'kappa3', 'peoe_vsa5', 'fr_ketone_topliss', 'vsa_estate9', 'estate_vsa9', 'bcut2d_mrhi', 'fr_ndealkylation1', 'numrotatablebonds', 'minestateindex', 'fr_quatn', 'peoe_vsa3', 'fr_epoxide', 'fr_aniline', 'minpartialcharge', 'fr_nitroso', 'fpdensitymorgan2', 'fr_oxime', 'fr_sulfone', 'smr_vsa1', 'kappa1', 'fr_pyridine', 'numaromaticrings', 'vsa_estate6', 'molmr', 'estate_vsa1', 'fr_dihydropyridine', 'vsa_estate10', 'fr_alkyl_halide', 'chi2n', 'fr_thiocyan', 'fpdensitymorgan1', 'fr_unbrch_alkane', 'slogp_vsa9', 'chi4n', 'fr_nitro_arom', 'fr_al_oh', 'fr_furan', 'fr_c_s', 'peoe_vsa8', 'peoe_vsa14', 'numheteroatoms', 'fr_ndealkylation2', 'maxabspartialcharge', 'vsa_estate2', 'peoe_vsa7', 'apol', 'numhacceptors', 'fr_tetrazole', 'vsa_estate1', 'peoe_vsa9', 'naromatom', 'bcut2d_chghi', 'fr_sh', 'fr_halogen', 'slogp_vsa4', 'fr_benzodiazepine', 'molwt', 'fr_isocyan', 'fr_prisulfonamd', 'maxabsestateindex', 'minabsestateindex', 'peoe_vsa11', 'slogp_vsa12', 'estate_vsa5', 'numaliphaticcarbocycles', 'bcut2d_mwlow', 'slogp_vsa7', 'fr_allylic_oxid', 'fr_methoxy', 'fr_nh0', 'fr_coo2', 'fr_phenol', 'nacid', 'nbase', 'chi3v', 'fr_ar_nh', 'fr_nitrile', 'fr_imidazole', 'fr_urea', 'bcut2d_mrlow', 'chi1', 'smr_vsa6', 'fr_aryl_methyl', 'narombond', 'fr_alkyl_carbamate', 'fr_piperzine', 'exactmolwt', 'qed', 'chi0n', 'fr_sulfonamd', 'fr_thiazole', 'numvalenceelectrons', 'fr_phos_acid', 'peoe_vsa12', 'fr_nh1', 'fr_hdrzine', 'fr_c_o_nocoo', 'fr_lactone', 'estate_vsa6', 'bcut2d_logphi', 'vsa_estate7', 'peoe_vsa13', 'numsaturatedcarbocycles', 'fr_nitro', 'fr_phenol_noorthohbond', 'rotratio', 'fr_barbitur', 'fr_isothiocyan', 'balabanj', 'fr_arn', 'fr_imine', 'maxpartialcharge', 'fr_sulfide', 'slogp_vsa11', 'fr_hoccn', 'fr_n_o', 'peoe_vsa1', 'slogp_vsa6', 'heavyatommolwt', 'fractioncsp3', 'estate_vsa8', 'peoe_vsa10', 'numaliphaticrings', 'fr_thiophene', 'maxestateindex', 'smr_vsa10', 'labuteasa', 'smr_vsa2', 'fpdensitymorgan3', 'smr_vsa9', 'slogp_vsa10', 'numaromaticheterocycles', 'fr_nh2', 'fr_diazo', 'chi3n', 'fr_ar_coo', 'slogp_vsa5', 'fr_bicyclic', 'fr_amide', 'estate_vsa10', 'fr_guanido', 'chi1n', 'numsaturatedrings', 'fr_piperdine', 'fr_term_acetylene', 'estate_vsa4', 'slogp_vsa3', 'fr_coo', 'fr_ether', 'estate_vsa7', 'bcut2d_chglo', 'fr_oxazole', 'peoe_vsa6', 'hallkieralpha', 'peoe_vsa2', 'chi2v', 'nocount', 'vsa_estate5', 'fr_nhpyrrole', 'fr_al_coo', 'bertzct', 'estate_vsa11', 'minabspartialcharge', 'slogp_vsa8', 'fr_imide', 'kappa2', 'numaliphaticheterocycles', 'numsaturatedheterocycles', 'fr_hdrzone', 'smr_vsa4', 'fr_ar_n', 'nrot', 'smr_vsa8', 'slogp_vsa2', 'chi4v', 'fr_phos_ester', 'fr_para_hydroxylation', 'smr_vsa3', 'nhohcount', 'estate_vsa2', 'mollogp', 'tpsa', 'fr_azide', 'peoe_vsa4', 'numhdonors', 'fr_al_oh_notert', 'fr_c_o', 'chi0', 'fr_nitro_arom_nonortho', 'vsa_estate3', 'fr_benzene', 'fr_ketone', 'vsa_estate8', 'smr_vsa7', 'fr_ar_oh', 'fr_priamide', 'ringcount', 'estate_vsa3', 'numaromaticcarbocycles', 'bcut2d_mwhi', 'chi1v', 'heavyatomcount', 'vsa_estate4', 'chi0v', 'chiral_centers', 'r_cnt', 's_cnt', 'db_stereo', 'e_cnt', 'z_cnt', 'chiral_fp', 'db_fp'],
-     "compressed_features": [],
-     "train_all_data": False,
-     "track_columns": "udm_asy_res_value"
- }
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """
-     Check if the provided dataframe is empty and raise an exception if it is.
-
-     Args:
-         df (pd.DataFrame): DataFrame to check
-         df_name (str): Name of the DataFrame
-     """
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """
-     Matches and renames DataFrame columns to match model feature names (case-insensitive).
-     Prioritizes exact matches, then case-insensitive matches.
-
-     Raises ValueError if any model features cannot be matched.
-     """
-     df_columns_lower = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-     missing = []
-     for feature in model_features:
-         if feature in df.columns:
-             continue # Exact match
-         elif feature.lower() in df_columns_lower:
-             rename_dict[df_columns_lower[feature.lower()]] = feature
-         else:
-             missing.append(feature)
-
-     if missing:
-         raise ValueError(f"Features not found: {missing}")
-
-     # Rename the DataFrame columns to match the model features
-     return df.rename(columns=rename_dict)
-
-
- def convert_categorical_types(df: pd.DataFrame, features: list, category_mappings={}) -> tuple:
-     """
-     Converts appropriate columns to categorical type with consistent mappings.
-
-     Args:
-         df (pd.DataFrame): The DataFrame to process.
-         features (list): List of feature names to consider for conversion.
-         category_mappings (dict, optional): Existing category mappings. If empty dict, we're in
-             training mode. If populated, we're in inference mode.
-
-     Returns:
-         tuple: (processed DataFrame, category mappings dictionary)
-     """
-     # Training mode
-     if category_mappings == {}:
-         for col in df.select_dtypes(include=["object", "string"]):
-             if col in features and df[col].nunique() < 20:
-                 print(f"Training mode: Converting {col} to category")
-                 df[col] = df[col].astype("category")
-                 category_mappings[col] = df[col].cat.categories.tolist() # Store category mappings
-
-     # Inference mode
-     else:
-         for col, categories in category_mappings.items():
-             if col in df.columns:
-                 print(f"Inference mode: Applying categorical mapping for {col}")
-                 df[col] = pd.Categorical(df[col], categories=categories) # Apply consistent categorical mapping
-
-     return df, category_mappings
-
-
- def decompress_features(
-     df: pd.DataFrame, features: List[str], compressed_features: List[str]
- ) -> Tuple[pd.DataFrame, List[str]]:
-     """Prepare features for the model by decompressing bitstring features
-
-     Args:
-         df (pd.DataFrame): The features DataFrame
-         features (List[str]): Full list of feature names
-         compressed_features (List[str]): List of feature names to decompress (bitstrings)
-
-     Returns:
-         pd.DataFrame: DataFrame with the decompressed features
-         List[str]: Updated list of feature names after decompression
-
-     Raises:
-         ValueError: If any missing values are found in the specified features
-     """
-
-     # Check for any missing values in the required features
-     missing_counts = df[features].isna().sum()
-     if missing_counts.any():
-         missing_features = missing_counts[missing_counts > 0]
-         print(
-             f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
-             "WARNING: You might want to remove/replace all NaN values before processing."
-         )
-
-     # Decompress the specified compressed features
-     decompressed_features = features.copy()
-     for feature in compressed_features:
-         if (feature not in df.columns) or (feature not in features):
-             print(f"Feature '{feature}' not in the features list, skipping decompression.")
-             continue
-
-         # Remove the feature from the list of features to avoid duplication
-         decompressed_features.remove(feature)
-
-         # Handle all compressed features as bitstrings
-         bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-         prefix = feature[:3]
-
-         # Create all new columns at once - avoids fragmentation
-         new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-         new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
-
-         # Add to features list
-         decompressed_features.extend(new_col_names)
-
-         # Drop original column and concatenate new ones
-         df = df.drop(columns=[feature])
-         df = pd.concat([df, new_df], axis=1)
-
-     return df, decompressed_features
-
-
- if __name__ == "__main__":
-     # Template Parameters
-     id_column = TEMPLATE_PARAMS["id_column"]
-     target = TEMPLATE_PARAMS["target"]
-     features = TEMPLATE_PARAMS["features"]
-     orig_features = features.copy()
-     compressed_features = TEMPLATE_PARAMS["compressed_features"]
-     train_all_data = TEMPLATE_PARAMS["train_all_data"]
-     track_columns = TEMPLATE_PARAMS["track_columns"] # Can be None
-     validation_split = 0.2
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Read the training data into DataFrames
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train)
-         if file.endswith(".csv")
-     ]
-     print(f"Training Files: {training_files}")
-
-     # Combine files and read them all into a single pandas dataframe
-     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the dataframe is empty
-     check_dataframe(all_df, "training_df")
-
-     # Features/Target output
-     print(f"Target: {target}")
-     print(f"Features: {str(features)}")
-
-     # Convert any features that might be categorical to 'category' type
-     all_df, category_mappings = convert_categorical_types(all_df, features)
-
-     # If we have compressed features, decompress them
-     if compressed_features:
-         print(f"Decompressing features {compressed_features}...")
-         all_df, features = decompress_features(all_df, features, compressed_features)
-
-     # Do we want to train on all the data?
-     if train_all_data:
-         print("Training on ALL of the data")
-         df_train = all_df.copy()
-         df_val = all_df.copy()
-
-     # Does the dataframe have a training column?
-     elif "training" in all_df.columns:
-         print("Found training column, splitting data based on training column")
-         df_train = all_df[all_df["training"]]
-         df_val = all_df[~all_df["training"]]
-     else:
-         # Just do a random training Split
-         print("WARNING: No training column found, splitting data with random state=42")
-         df_train, df_val = train_test_split(
-             all_df, test_size=validation_split, random_state=42
-         )
-     print(f"FIT/TRAIN: {df_train.shape}")
-     print(f"VALIDATION: {df_val.shape}")
-
-     # We're using XGBoost for point predictions and NGBoost for uncertainty quantification
-     xgb_model = XGBRegressor()
-     ngb_model = NGBRegressor() # Dist=Cauchy) Seems to give HUGE prediction intervals
-     ngb_model = NGBRegressor(
-         Dist=T,
-         learning_rate=0.005,
-         minibatch_frac=0.1, # Very small batches
-         col_sample=0.8 # This parameter DOES exist
-     ) # Testing this out
-     print("NGBoost using T distribution for uncertainty quantification")
-
-     # Prepare features and targets for training
-     X_train = df_train[features]
-     X_validate = df_val[features]
-     y_train = df_train[target]
-     y_validate = df_val[target]
-
-     # Train both models using the training data
-     xgb_model.fit(X_train, y_train)
-     ngb_model.fit(X_train, y_train, X_val=X_validate, Y_val=y_validate)
-
-     # Make Predictions on the Validation Set
-     print(f"Making Predictions on Validation Set...")
-     preds = xgb_model.predict(X_validate)
-
-     # Calculate various model performance metrics (regression)
-     rmse = root_mean_squared_error(y_validate, preds)
-     mae = mean_absolute_error(y_validate, preds)
-     r2 = r2_score(y_validate, preds)
-     print(f"RMSE: {rmse:.3f}")
-     print(f"MAE: {mae:.3f}")
-     print(f"R2: {r2:.3f}")
-     print(f"NumRows: {len(df_val)}")
-
-     # Save the trained XGBoost model
-     xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
-
-     # Save the trained NGBoost model
-     joblib.dump(ngb_model, os.path.join(args.model_dir, "ngb_model.joblib"))
-
-     # Save the features (this will validate input during predictions)
-     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-         json.dump(orig_features, fp) # We save the original features, not the decompressed ones
-
-     # Now the Proximity model
-     model = Proximity(df_train, id_column, features, target, track_columns=track_columns)
-
-     # Now serialize the model
-     model.serialize(args.model_dir)
-
-
- #
- # Inference Section
- #
- def model_fn(model_dir) -> dict:
-     """Load and return XGBoost, NGBoost, and Prox Model from model directory."""
-
-     # Load XGBoost regressor
-     xgb_path = os.path.join(model_dir, "xgb_model.json")
-     xgb_model = XGBRegressor(enable_categorical=True)
-     xgb_model.load_model(xgb_path)
-
-     # Load NGBoost regressor
-     ngb_model = joblib.load(os.path.join(model_dir, "ngb_model.joblib"))
-
-     # Deserialize the proximity model
-     prox_model = Proximity.deserialize(model_dir)
-
-     return {
-         "xgboost": xgb_model,
-         "ngboost": ngb_model,
-         "proximity": prox_model
-     }
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         csv_output = output_df.fillna("N/A").to_csv(index=False) # CSV with N/A for missing values
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- def predict_fn(df, models) -> pd.DataFrame:
-     """Make Predictions with our XGB Quantile Regression Model
-
-     Args:
-         df (pd.DataFrame): The input DataFrame
-         models (dict): The dictionary of models to use for predictions
-
-     Returns:
-         pd.DataFrame: The DataFrame with the predictions added
-     """
-
-     # Grab our feature columns (from training)
-     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-     with open(os.path.join(model_dir, "feature_columns.json")) as fp:
-         model_features = json.load(fp)
-
-     # Match features in a case-insensitive manner
-     matched_df = match_features_case_insensitive(df, model_features)
-
-     # Use XGBoost for point predictions
-     df["prediction"] = models["xgboost"].predict(matched_df[model_features])
-
-     # NGBoost predict returns distribution objects
-     y_dists = models["ngboost"].pred_dist(matched_df[model_features])
-
-     # Extract parameters from distribution
-     dist_params = y_dists.params
-
-     # Extract mean and std from distribution parameters
-     df["prediction_uq"] = dist_params['loc'] # mean
-     df["prediction_std"] = dist_params['scale'] # standard deviation
-
-     # Add 95% prediction intervals using ppf (percent point function)
-     # Note: Our hybrid model uses XGB point prediction and NGBoost UQ
-     # so we need to adjust the bounds to include the point prediction
-     df["q_025"] = np.minimum(y_dists.ppf(0.025), df["prediction"])
-     df["q_975"] = np.maximum(y_dists.ppf(0.975), df["prediction"])
-
-     # Add 90% prediction intervals
-     df["q_05"] = y_dists.ppf(0.05) # 5th percentile
-     df["q_95"] = y_dists.ppf(0.95) # 95th percentile
-
-     # Add 80% prediction intervals
-     df["q_10"] = y_dists.ppf(0.10) # 10th percentile
-     df["q_90"] = y_dists.ppf(0.90) # 90th percentile
-
-     # Add 50% prediction intervals
-     df["q_25"] = y_dists.ppf(0.25) # 25th percentile
-     df["q_75"] = y_dists.ppf(0.75) # 75th percentile
-
-     # Reorder the quantile columns for easier reading
-     quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
-     other_cols = [col for col in df.columns if col not in quantile_cols]
-     df = df[other_cols + quantile_cols]
-
-     # Compute Nearest neighbors with Proximity model
-     models["proximity"].neighbors(df)
-
-     # Return the modified DataFrame
-     return df
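
A note on the removed script above: its core technique is a hybrid uncertainty-quantification (UQ) pattern in which XGBoost supplies the point prediction while NGBoost's predictive distribution supplies the quantiles, with the outer 95% bounds widened via np.minimum/np.maximum so they always bracket the point estimate. Below is a minimal, self-contained sketch of that pattern; the synthetic data and variable names are illustrative, not taken from the package.

# Minimal sketch of the hybrid XGBoost + NGBoost UQ pattern from the removed
# script. Synthetic data; names like `point` and `y_dists` are illustrative.
import numpy as np
from ngboost import NGBRegressor
from ngboost.distns import T
from xgboost import XGBRegressor

rng = np.random.default_rng(42)
X = rng.normal(size=(500, 8))
y = 2.0 * X[:, 0] + rng.normal(scale=0.5, size=500)

xgb_model = XGBRegressor().fit(X, y)                             # point estimator
ngb_model = NGBRegressor(Dist=T, learning_rate=0.005).fit(X, y)  # UQ estimator

point = xgb_model.predict(X)      # point predictions (XGBoost)
y_dists = ngb_model.pred_dist(X)  # per-row predictive distributions (NGBoost)

# ppf is the percent point function (inverse CDF); the min/max widen the 95%
# interval so it always contains the XGBoost point prediction, exactly as in
# predict_fn above.
q_025 = np.minimum(y_dists.ppf(0.025), point)
q_975 = np.maximum(y_dists.ppf(0.975), point)
print(f"mean 95% interval width: {np.mean(q_975 - q_025):.3f}")

Because the point estimate and the distribution come from different models, the raw 2.5%/97.5% quantiles are not guaranteed to bracket the point prediction; the min/max adjustment is what keeps the reported interval coherent.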
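Similarly, the decompress_features helper in the removed script expands a fingerprint stored as a single bitstring column into one uint8 column per bit, naming the new columns from the first three characters of the feature. A toy illustration of that expansion, using a hypothetical db_fp column:

# Toy illustration (hypothetical data) of the bitstring expansion performed by
# decompress_features above: each character of the bitstring becomes a column.
import numpy as np
import pandas as pd

df = pd.DataFrame({"id": [1, 2], "db_fp": ["1010", "0111"]})
bit_matrix = np.array([list(bs) for bs in df["db_fp"]], dtype=np.uint8)

prefix = "db_fp"[:3]  # "db_", matching the script's feature[:3] naming
new_cols = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]  # db__0 .. db__3

expanded = pd.concat(
    [df.drop(columns=["db_fp"]), pd.DataFrame(bit_matrix, columns=new_cols, index=df.index)],
    axis=1,
)
print(expanded.to_string(index=False))

Building the full bit matrix and concatenating all new columns at once, rather than inserting them one by one, is what the script's "avoids fragmentation" comment refers to: repeated single-column inserts fragment a pandas DataFrame and trigger performance warnings.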