workbench 0.8.174-py3-none-any.whl → 0.8.227-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of workbench has been flagged as a potentially problematic release.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -4,11 +4,11 @@ None
 # Template Placeholders
 TEMPLATE_PARAMS = {
     "model_type": "regressor",
-    "target_column": "
-    "feature_list": ['molwt', '
+    "target_column": "udm_asy_res_efflux_ratio",
+    "feature_list": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo', 'tertiary_amine_count', 'type_i_pattern_count', 'type_ii_pattern_count', 'aromatic_interaction_score', 'molecular_axis_length', 'molecular_asymmetry', 'molecular_volume_3d', 'radius_of_gyration', 'asphericity', 'charge_centroid_distance', 'nitrogen_span', 'amide_count', 'hba_hbd_ratio', 'intramolecular_hbond_potential', 'amphiphilic_moment'],
     "model_class": PyTorch,
-    "model_metrics_s3_path": "s3://
-    "train_all_data": False
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-pytorch-test/training",
+    "train_all_data": False,
 }
 
 import awswrangler as wr
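The hunk above shows what a rendered TEMPLATE_PARAMS block looks like: each quoted {{placeholder}} from the .template file has been replaced with a concrete literal (target column, feature list, S3 metrics path). A minimal sketch of that substitution step, assuming a straight string-replace scheme; fill_template is a hypothetical helper, not workbench's actual API:

# Hypothetical sketch of the placeholder substitution that turns a *.template
# file into a generated_model_script.py. fill_template is illustrative only;
# note the real generator also emits some values unquoted (e.g. model_class).
from pathlib import Path


def fill_template(template_path: str, params: dict) -> str:
    """Replace quoted "{{key}}" placeholders with Python literals."""
    source = Path(template_path).read_text()
    for key, value in params.items():
        source = source.replace(f'"{{{{{key}}}}}"', repr(value))
    return source


script = fill_template(
    "pytorch.template",
    {"target_column": "udm_asy_res_efflux_ratio", "train_all_data": False},
)
Path("generated_model_script.py").write_text(script)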
@@ -99,10 +99,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
     # Check if the DataFrame is empty
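Right after this load, the script validates the combined DataFrame. The new shared model_script_utils module exports a check_dataframe helper for this (see the uq_models script later in this diff), but its body is not shown here, so the following fail-fast guard is an assumption about its behavior, not the actual implementation:

# Hypothetical sketch of the emptiness guard behind "Check if the DataFrame
# is empty"; names and message text are assumed.
import sys

import pandas as pd


def check_dataframe(df: pd.DataFrame, name: str) -> None:
    """Fail fast if a required DataFrame has no rows."""
    if df.empty:
        print(f"FATAL: {name} is empty, aborting")
        sys.exit(1)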
@@ -116,10 +113,7 @@ if __name__ == "__main__":
 
     if needs_standardization:
         # Create a pipeline with standardization and the model
-        model = Pipeline([
-            ("scaler", StandardScaler()),
-            ("model", model)
-        ])
+        model = Pipeline([("scaler", StandardScaler()), ("model", model)])
 
     # Handle logic based on the model_type
     if model_type in ["classifier", "regressor"]:
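The collapse to a one-liner is purely cosmetic: both forms build the same two-step scikit-learn Pipeline that standardizes features before fitting. A self-contained demonstration of the pattern (Ridge and the random data are stand-ins for whatever model_class the template configures):

# Standalone demo of the scaler+model Pipeline wrapping pattern from the hunk
# above; the estimator and toy data are illustrative.
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = np.random.rand(100, 4), np.random.rand(100)
model = Ridge()
model = Pipeline([("scaler", StandardScaler()), ("model", model)])  # same as the one-liner
model.fit(X, y)
print(model.predict(X[:3]))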
@@ -206,6 +200,7 @@ if __name__ == "__main__":
     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
         json.dump(feature_list, fp)
 
+
 #
 # Inference Section
 #
@@ -8,7 +8,7 @@ TEMPLATE_PARAMS = {
     "feature_list": "{{feature_list}}",
     "model_class": "{{model_class}}",
     "model_metrics_s3_path": "{{model_metrics_s3_path}}",
-    "train_all_data": "{{train_all_data}}"
+    "train_all_data": "{{train_all_data}}",
 }
 
 import awswrangler as wr
@@ -99,10 +99,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
     # Check if the DataFrame is empty
@@ -116,10 +113,7 @@ if __name__ == "__main__":
 
     if needs_standardization:
         # Create a pipeline with standardization and the model
-        model = Pipeline([
-            ("scaler", StandardScaler()),
-            ("model", model)
-        ])
+        model = Pipeline([("scaler", StandardScaler()), ("model", model)])
 
     # Handle logic based on the model_type
     if model_type in ["classifier", "regressor"]:
@@ -206,6 +200,7 @@ if __name__ == "__main__":
     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
         json.dump(feature_list, fp)
 
+
 #
 # Inference Section
 #
@@ -6,7 +6,6 @@ import logging
 from pathlib import Path
 import importlib.util
 
-
 # Setup the logger
 log = logging.getLogger("workbench")
 
@@ -93,32 +92,36 @@ def generate_model_script(template_params: dict) -> str:
         template_params (dict): Dictionary containing the parameters:
             - model_imports (str): Import string for the model class
             - model_type (ModelType): The enumerated type of model to generate
+            - model_framework (str): The enumerated model framework to use
             - model_class (str): The model class to use (e.g., "RandomForestRegressor")
             - target_column (str): Column name of the target variable
             - feature_list (list[str]): A list of columns for the features
             - model_metrics_s3_path (str): The S3 path to store the model metrics
             - train_all_data (bool): Whether to train on all (100%) of the data
             - hyperparameters (dict, optional): Hyperparameters for the model (default: None)
+            - child_endpoints (list[str], optional): For META models, list of child endpoint names
 
     Returns:
         str: The name of the generated model script
     """
-    from workbench.api import ModelType  # Avoid circular import
+    from workbench.api import ModelType, ModelFramework  # Avoid circular import
 
     # Determine which template to use based on model type
    if template_params.get("model_class"):
-
-
-
-
-
-
-
+        template_name = "scikit_learn.template"
+        model_script_dir = "scikit_learn"
+    elif template_params["model_framework"] == ModelFramework.PYTORCH:
+        template_name = "pytorch.template"
+        model_script_dir = "pytorch_model"
+    elif template_params["model_framework"] == ModelFramework.CHEMPROP:
+        template_name = "chemprop.template"
+        model_script_dir = "chemprop"
+    elif template_params["model_framework"] == ModelFramework.META:
+        template_name = "meta_model.template"
+        model_script_dir = "meta_model"
+    elif template_params["model_type"] in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.CLASSIFIER]:
         template_name = "xgb_model.template"
         model_script_dir = "xgb_model"
-    elif template_params["model_type"] == ModelType.UQ_REGRESSOR:
-        template_name = "quant_regression.template"
-        model_script_dir = "quant_regression"
     elif template_params["model_type"] == ModelType.ENSEMBLE_REGRESSOR:
         template_name = "ensemble_xgb.template"
         model_script_dir = "ensemble_xgb"
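The new branch ladder dispatches on an explicit model_class first, then on the new ModelFramework enum (PYTORCH, CHEMPROP, META), and only then falls back to model_type; the old quant_regression branch is gone because UQ_REGRESSOR now shares the xgb_model template. The same selection logic, sketched as a lookup table with stand-in enums (only the member names visible in this diff are assumed to exist; workbench's real enums may differ):

# Sketch of the template-selection logic as a dispatch table.
from enum import Enum


class ModelFramework(Enum):
    PYTORCH = "pytorch"
    CHEMPROP = "chemprop"
    META = "meta"


FRAMEWORK_TEMPLATES = {
    ModelFramework.PYTORCH: ("pytorch.template", "pytorch_model"),
    ModelFramework.CHEMPROP: ("chemprop.template", "chemprop"),
    ModelFramework.META: ("meta_model.template", "meta_model"),
}


def select_template(params: dict) -> tuple[str, str]:
    """Return (template_name, model_script_dir) for the given parameters."""
    if params.get("model_class"):
        return "scikit_learn.template", "scikit_learn"
    if params.get("model_framework") in FRAMEWORK_TEMPLATES:
        return FRAMEWORK_TEMPLATES[params["model_framework"]]
    return "xgb_model.template", "xgb_model"  # regressor/uq_regressor/classifier default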
@@ -0,0 +1,248 @@
+# Model: XGBoost for point predictions + MAPIE UQ Harness for conformalized intervals
+from xgboost import XGBRegressor
+from sklearn.model_selection import train_test_split
+
+import json
+import argparse
+import joblib
+import os
+import numpy as np
+import pandas as pd
+
+# Shared model script utilities
+from model_script_utils import (
+    check_dataframe,
+    match_features_case_insensitive,
+    convert_categorical_types,
+    decompress_features,
+    input_fn,
+    output_fn,
+    compute_regression_metrics,
+    print_regression_metrics,
+)
+
+# UQ Harness for uncertainty quantification
+from uq_harness import (
+    train_uq_models,
+    save_uq_models,
+    load_uq_models,
+    predict_intervals,
+    compute_confidence,
+)
+
+# Template Placeholders
+TEMPLATE_PARAMS = {
+    "target": "solubility",
+    "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
+    "compressed_features": [],
+    "train_all_data": False,
+    "hyperparameters": {'training_config': {'max_epochs': 150}, 'model_config': {'layers': '128-64-32'}},
+}
+
+
+if __name__ == "__main__":
+    # Template Parameters
+    target = TEMPLATE_PARAMS["target"]
+    features = TEMPLATE_PARAMS["features"]
+    orig_features = features.copy()
+    compressed_features = TEMPLATE_PARAMS["compressed_features"]
+    train_all_data = TEMPLATE_PARAMS["train_all_data"]
+    hyperparameters = TEMPLATE_PARAMS["hyperparameters"] or {}
+    validation_split = 0.2
+
+    # Script arguments for input/output directories
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+    parser.add_argument(
+        "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
+    )
+    args = parser.parse_args()
+
+    # Read the training data into DataFrames
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
+    print(f"Training Files: {training_files}")
+
+    # Combine files and read them all into a single pandas dataframe
+    all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
+
+    # Check if the dataframe is empty
+    check_dataframe(all_df, "training_df")
+
+    # Features/Target output
+    print(f"Target: {target}")
+    print(f"Features: {str(features)}")
+
+    # Convert any features that might be categorical to 'category' type
+    all_df, category_mappings = convert_categorical_types(all_df, features)
+
+    # If we have compressed features, decompress them
+    if compressed_features:
+        print(f"Decompressing features {compressed_features}...")
+        all_df, features = decompress_features(all_df, features, compressed_features)
+
+    # Do we want to train on all the data?
+    if train_all_data:
+        print("Training on ALL of the data")
+        df_train = all_df.copy()
+        df_val = all_df.copy()
+
+    # Does the dataframe have a training column?
+    elif "training" in all_df.columns:
+        print("Found training column, splitting data based on training column")
+        df_train = all_df[all_df["training"]]
+        df_val = all_df[~all_df["training"]]
+    else:
+        # Just do a random training Split
+        print("WARNING: No training column found, splitting data with random state=42")
+        df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
+    print(f"FIT/TRAIN: {df_train.shape}")
+    print(f"VALIDATION: {df_val.shape}")
+
+    # Extract sample weights if present
+    if "sample_weight" in df_train.columns:
+        sample_weights = df_train["sample_weight"]
+        print(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}, mean={sample_weights.mean():.2f}")
+    else:
+        sample_weights = None
+        print("No sample weights found, training with equal weights")
+
+    # Prepare features and targets for training
+    X_train = df_train[features]
+    X_validate = df_val[features]
+    y_train = df_train[target]
+    y_validate = df_val[target]
+
+    # ==========================================
+    # Train XGBoost for point predictions
+    # ==========================================
+    print("\nTraining XGBoost for point predictions...")
+    print(f" Hyperparameters: {hyperparameters}")
+    xgb_model = XGBRegressor(enable_categorical=True, **hyperparameters)
+    xgb_model.fit(X_train, y_train, sample_weight=sample_weights)
+
+    # Evaluate XGBoost performance
+    y_pred_xgb = xgb_model.predict(X_validate)
+    xgb_metrics = compute_regression_metrics(y_validate, y_pred_xgb)
+
+    print(f"\nXGBoost Point Prediction Performance:")
+    print_regression_metrics(xgb_metrics)
+
+    # ==========================================
+    # Train UQ models using the harness
+    # ==========================================
+    uq_models, uq_metadata = train_uq_models(X_train, y_train, X_validate, y_validate)
+
+    print(f"\nOverall Model Performance Summary:")
+    print_regression_metrics(xgb_metrics)
+
+    # ==========================================
+    # Save all models
+    # ==========================================
+    # Save the trained XGBoost model
+    joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
+
+    # Save UQ models using the harness
+    save_uq_models(uq_models, uq_metadata, args.model_dir)
+
+    # Save the feature list
+    with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
+        json.dump(features, fp)
+
+    # Save category mappings if any
+    if category_mappings:
+        with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
+            json.dump(category_mappings, fp)
+
+    # Save model configuration
+    model_config = {
+        "model_type": "XGBoost_MAPIE_UQ",
+        "confidence_levels": uq_metadata["confidence_levels"],
+        "n_features": len(features),
+        "target": target,
+        "validation_metrics": {
+            "xgb_rmse": float(xgb_metrics["rmse"]),
+            "xgb_mae": float(xgb_metrics["mae"]),
+            "xgb_r2": float(xgb_metrics["r2"]),
+            "n_validation": len(df_val),
+        },
+    }
+    with open(os.path.join(args.model_dir, "model_config.json"), "w") as fp:
+        json.dump(model_config, fp, indent=2)
+
+    print(f"\nModel training complete!")
+    print(f"Saved XGBoost model and {len(uq_models)} UQ models to {args.model_dir}")
+
+
+#
+# Inference Section
+#
+def model_fn(model_dir) -> dict:
+    """Load XGBoost and all UQ models from the specified directory."""
+
+    # Load model configuration
+    with open(os.path.join(model_dir, "model_config.json")) as fp:
+        config = json.load(fp)
+
+    # Load XGBoost regressor
+    xgb_path = os.path.join(model_dir, "xgb_model.joblib")
+    xgb_model = joblib.load(xgb_path)
+
+    # Load UQ models using the harness
+    uq_models, uq_metadata = load_uq_models(model_dir)
+
+    # Load category mappings if they exist
+    category_mappings = {}
+    category_path = os.path.join(model_dir, "category_mappings.json")
+    if os.path.exists(category_path):
+        with open(category_path) as fp:
+            category_mappings = json.load(fp)
+
+    return {
+        "xgb_model": xgb_model,
+        "uq_models": uq_models,
+        "uq_metadata": uq_metadata,
+        "category_mappings": category_mappings,
+    }
+
+
+def predict_fn(df, models) -> pd.DataFrame:
+    """Make predictions using XGBoost for point estimates and UQ harness for intervals.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame
+        models (dict): Dictionary containing XGBoost and UQ models
+
+    Returns:
+        pd.DataFrame: DataFrame with predictions and conformalized intervals
+    """
+    # Grab our feature columns (from training)
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+    with open(os.path.join(model_dir, "feature_columns.json")) as fp:
+        model_features = json.load(fp)
+
+    # Match features in a case-insensitive manner
+    matched_df = match_features_case_insensitive(df, model_features)
+
+    # Apply categorical mappings if they exist
+    if models.get("category_mappings"):
+        matched_df, _ = convert_categorical_types(matched_df, model_features, models["category_mappings"])
+
+    # Get features for prediction
+    X = matched_df[model_features]
+
+    # Get XGBoost point predictions
+    df["prediction"] = models["xgb_model"].predict(X)
+
+    # Get prediction intervals using UQ harness
+    df = predict_intervals(df, X, models["uq_models"], models["uq_metadata"])
+
+    # Compute confidence scores
+    df = compute_confidence(
+        df,
+        median_interval_width=models["uq_metadata"]["median_interval_width"],
+        lower_q="q_10",
+        upper_q="q_90",
+    )
+
+    return df