workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -4,11 +4,11 @@
  # Template Placeholders
  TEMPLATE_PARAMS = {
      "model_type": "regressor",
-     "target_column": "solubility",
-     "feature_list": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
+     "target_column": "udm_asy_res_efflux_ratio",
+     "feature_list": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo', 'tertiary_amine_count', 'type_i_pattern_count', 'type_ii_pattern_count', 'aromatic_interaction_score', 'molecular_axis_length', 'molecular_asymmetry', 'molecular_volume_3d', 'radius_of_gyration', 'asphericity', 'charge_centroid_distance', 'nitrogen_span', 'amide_count', 'hba_hbd_ratio', 'intramolecular_hbond_potential', 'amphiphilic_moment'],
      "model_class": PyTorch,
-     "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-pytorch-reg/training",
-     "train_all_data": False
+     "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-pytorch-test/training",
+     "train_all_data": False,
  }

  import awswrangler as wr
@@ -99,10 +99,7 @@ if __name__ == "__main__":
      args = parser.parse_args()

      # Load training data from the specified directory
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
-     ]
+     training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
      all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

      # Check if the DataFrame is empty
@@ -116,10 +113,7 @@ if __name__ == "__main__":

      if needs_standardization:
          # Create a pipeline with standardization and the model
-         model = Pipeline([
-             ("scaler", StandardScaler()),
-             ("model", model)
-         ])
+         model = Pipeline([("scaler", StandardScaler()), ("model", model)])

      # Handle logic based on the model_type
      if model_type in ["classifier", "regressor"]:
@@ -206,6 +200,7 @@ if __name__ == "__main__":
      with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
          json.dump(feature_list, fp)

+
  #
  # Inference Section
  #
@@ -8,7 +8,7 @@ TEMPLATE_PARAMS = {
      "feature_list": "{{feature_list}}",
      "model_class": "{{model_class}}",
      "model_metrics_s3_path": "{{model_metrics_s3_path}}",
-     "train_all_data": "{{train_all_data}}"
+     "train_all_data": "{{train_all_data}}",
  }

  import awswrangler as wr
@@ -99,10 +99,7 @@ if __name__ == "__main__":
      args = parser.parse_args()

      # Load training data from the specified directory
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
-     ]
+     training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
      all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

      # Check if the DataFrame is empty
@@ -116,10 +113,7 @@ if __name__ == "__main__":

      if needs_standardization:
          # Create a pipeline with standardization and the model
-         model = Pipeline([
-             ("scaler", StandardScaler()),
-             ("model", model)
-         ])
+         model = Pipeline([("scaler", StandardScaler()), ("model", model)])

      # Handle logic based on the model_type
      if model_type in ["classifier", "regressor"]:
@@ -206,6 +200,7 @@ if __name__ == "__main__":
      with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
          json.dump(feature_list, fp)

+
  #
  # Inference Section
  #
@@ -6,7 +6,6 @@ import logging
  from pathlib import Path
  import importlib.util

-
  # Setup the logger
  log = logging.getLogger("workbench")

@@ -93,32 +92,36 @@ def generate_model_script(template_params: dict) -> str:
          template_params (dict): Dictionary containing the parameters:
              - model_imports (str): Import string for the model class
              - model_type (ModelType): The enumerated type of model to generate
+             - model_framework (str): The enumerated model framework to use
              - model_class (str): The model class to use (e.g., "RandomForestRegressor")
              - target_column (str): Column name of the target variable
              - feature_list (list[str]): A list of columns for the features
              - model_metrics_s3_path (str): The S3 path to store the model metrics
              - train_all_data (bool): Whether to train on all (100%) of the data
              - hyperparameters (dict, optional): Hyperparameters for the model (default: None)
+             - child_endpoints (list[str], optional): For META models, list of child endpoint names

      Returns:
          str: The name of the generated model script
      """
-     from workbench.api import ModelType  # Avoid circular import
+     from workbench.api import ModelType, ModelFramework  # Avoid circular import

      # Determine which template to use based on model type
      if template_params.get("model_class"):
-         if template_params["model_class"].lower() == "pytorch":
-             template_name = "pytorch.template"
-             model_script_dir = "pytorch_model"
-         else:
-             template_name = "scikit_learn.template"
-             model_script_dir = "scikit_learn"
-     elif template_params["model_type"] in [ModelType.REGRESSOR, ModelType.CLASSIFIER]:
+         template_name = "scikit_learn.template"
+         model_script_dir = "scikit_learn"
+     elif template_params["model_framework"] == ModelFramework.PYTORCH:
+         template_name = "pytorch.template"
+         model_script_dir = "pytorch_model"
+     elif template_params["model_framework"] == ModelFramework.CHEMPROP:
+         template_name = "chemprop.template"
+         model_script_dir = "chemprop"
+     elif template_params["model_framework"] == ModelFramework.META:
+         template_name = "meta_model.template"
+         model_script_dir = "meta_model"
+     elif template_params["model_type"] in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.CLASSIFIER]:
          template_name = "xgb_model.template"
          model_script_dir = "xgb_model"
-     elif template_params["model_type"] == ModelType.UQ_REGRESSOR:
-         template_name = "quant_regression.template"
-         model_script_dir = "quant_regression"
      elif template_params["model_type"] == ModelType.ENSEMBLE_REGRESSOR:
          template_name = "ensemble_xgb.template"
          model_script_dir = "ensemble_xgb"
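For reference, a minimal sketch of driving the new framework dispatch shown above. This is an illustration, not code from the wheel: the import paths follow the hunk above, and the bucket path and feature names are placeholders.

# Hypothetical usage sketch (unverified against 0.8.227)
from workbench.api import ModelType, ModelFramework
from workbench.model_scripts.script_generation import generate_model_script

template_params = {
    "model_type": ModelType.REGRESSOR,
    "model_framework": ModelFramework.PYTORCH,  # routes to pytorch.template / pytorch_model
    "target_column": "solubility",
    "feature_list": ["molwt", "mollogp", "tpsa"],
    "model_metrics_s3_path": "s3://example-bucket/models/demo/training",  # placeholder path
    "train_all_data": False,
}
script_name = generate_model_script(template_params)  # returns the generated script name

Note that per the new branch order, a truthy "model_class" short-circuits to the scikit-learn template, so the framework-based routing only applies when "model_class" is unset.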
@@ -0,0 +1,248 @@
+ # Model: XGBoost for point predictions + MAPIE UQ Harness for conformalized intervals
+ from xgboost import XGBRegressor
+ from sklearn.model_selection import train_test_split
+
+ import json
+ import argparse
+ import joblib
+ import os
+ import numpy as np
+ import pandas as pd
+
+ # Shared model script utilities
+ from model_script_utils import (
+     check_dataframe,
+     match_features_case_insensitive,
+     convert_categorical_types,
+     decompress_features,
+     input_fn,
+     output_fn,
+     compute_regression_metrics,
+     print_regression_metrics,
+ )
+
+ # UQ Harness for uncertainty quantification
+ from uq_harness import (
+     train_uq_models,
+     save_uq_models,
+     load_uq_models,
+     predict_intervals,
+     compute_confidence,
+ )
+
+ # Template Placeholders
+ TEMPLATE_PARAMS = {
+     "target": "solubility",
+     "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
+     "compressed_features": [],
+     "train_all_data": False,
+     "hyperparameters": {'training_config': {'max_epochs': 150}, 'model_config': {'layers': '128-64-32'}},
+ }
+
+
+ if __name__ == "__main__":
+     # Template Parameters
+     target = TEMPLATE_PARAMS["target"]
+     features = TEMPLATE_PARAMS["features"]
+     orig_features = features.copy()
+     compressed_features = TEMPLATE_PARAMS["compressed_features"]
+     train_all_data = TEMPLATE_PARAMS["train_all_data"]
+     hyperparameters = TEMPLATE_PARAMS["hyperparameters"] or {}
+     validation_split = 0.2
+
+     # Script arguments for input/output directories
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+     parser.add_argument(
+         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
+     )
+     args = parser.parse_args()
+
+     # Read the training data into DataFrames
+     training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
+     print(f"Training Files: {training_files}")
+
+     # Combine files and read them all into a single pandas dataframe
+     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
+
+     # Check if the dataframe is empty
+     check_dataframe(all_df, "training_df")
+
+     # Features/Target output
+     print(f"Target: {target}")
+     print(f"Features: {str(features)}")
+
+     # Convert any features that might be categorical to 'category' type
+     all_df, category_mappings = convert_categorical_types(all_df, features)
+
+     # If we have compressed features, decompress them
+     if compressed_features:
+         print(f"Decompressing features {compressed_features}...")
+         all_df, features = decompress_features(all_df, features, compressed_features)
+
+     # Do we want to train on all the data?
+     if train_all_data:
+         print("Training on ALL of the data")
+         df_train = all_df.copy()
+         df_val = all_df.copy()
+
+     # Does the dataframe have a training column?
+     elif "training" in all_df.columns:
+         print("Found training column, splitting data based on training column")
+         df_train = all_df[all_df["training"]]
+         df_val = all_df[~all_df["training"]]
+     else:
+         # Just do a random training Split
+         print("WARNING: No training column found, splitting data with random state=42")
+         df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
+     print(f"FIT/TRAIN: {df_train.shape}")
+     print(f"VALIDATION: {df_val.shape}")
+
+     # Extract sample weights if present
+     if "sample_weight" in df_train.columns:
+         sample_weights = df_train["sample_weight"]
+         print(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}, mean={sample_weights.mean():.2f}")
+     else:
+         sample_weights = None
+         print("No sample weights found, training with equal weights")
+
+     # Prepare features and targets for training
+     X_train = df_train[features]
+     X_validate = df_val[features]
+     y_train = df_train[target]
+     y_validate = df_val[target]
+
+     # ==========================================
+     # Train XGBoost for point predictions
+     # ==========================================
+     print("\nTraining XGBoost for point predictions...")
+     print(f" Hyperparameters: {hyperparameters}")
+     xgb_model = XGBRegressor(enable_categorical=True, **hyperparameters)
+     xgb_model.fit(X_train, y_train, sample_weight=sample_weights)
+
+     # Evaluate XGBoost performance
+     y_pred_xgb = xgb_model.predict(X_validate)
+     xgb_metrics = compute_regression_metrics(y_validate, y_pred_xgb)
+
+     print(f"\nXGBoost Point Prediction Performance:")
+     print_regression_metrics(xgb_metrics)
+
+     # ==========================================
+     # Train UQ models using the harness
+     # ==========================================
+     uq_models, uq_metadata = train_uq_models(X_train, y_train, X_validate, y_validate)
+
+     print(f"\nOverall Model Performance Summary:")
+     print_regression_metrics(xgb_metrics)
+
+     # ==========================================
+     # Save all models
+     # ==========================================
+     # Save the trained XGBoost model
+     joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
+
+     # Save UQ models using the harness
+     save_uq_models(uq_models, uq_metadata, args.model_dir)
+
+     # Save the feature list
+     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
+         json.dump(features, fp)
+
+     # Save category mappings if any
+     if category_mappings:
+         with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
+             json.dump(category_mappings, fp)
+
+     # Save model configuration
+     model_config = {
+         "model_type": "XGBoost_MAPIE_UQ",
+         "confidence_levels": uq_metadata["confidence_levels"],
+         "n_features": len(features),
+         "target": target,
+         "validation_metrics": {
+             "xgb_rmse": float(xgb_metrics["rmse"]),
+             "xgb_mae": float(xgb_metrics["mae"]),
+             "xgb_r2": float(xgb_metrics["r2"]),
+             "n_validation": len(df_val),
+         },
+     }
+     with open(os.path.join(args.model_dir, "model_config.json"), "w") as fp:
+         json.dump(model_config, fp, indent=2)
+
+     print(f"\nModel training complete!")
+     print(f"Saved XGBoost model and {len(uq_models)} UQ models to {args.model_dir}")
+
+
+ #
+ # Inference Section
+ #
+ def model_fn(model_dir) -> dict:
+     """Load XGBoost and all UQ models from the specified directory."""
+
+     # Load model configuration
+     with open(os.path.join(model_dir, "model_config.json")) as fp:
+         config = json.load(fp)
+
+     # Load XGBoost regressor
+     xgb_path = os.path.join(model_dir, "xgb_model.joblib")
+     xgb_model = joblib.load(xgb_path)
+
+     # Load UQ models using the harness
+     uq_models, uq_metadata = load_uq_models(model_dir)
+
+     # Load category mappings if they exist
+     category_mappings = {}
+     category_path = os.path.join(model_dir, "category_mappings.json")
+     if os.path.exists(category_path):
+         with open(category_path) as fp:
+             category_mappings = json.load(fp)
+
+     return {
+         "xgb_model": xgb_model,
+         "uq_models": uq_models,
+         "uq_metadata": uq_metadata,
+         "category_mappings": category_mappings,
+     }
+
+
+ def predict_fn(df, models) -> pd.DataFrame:
+     """Make predictions using XGBoost for point estimates and UQ harness for intervals.
+
+     Args:
+         df (pd.DataFrame): The input DataFrame
+         models (dict): Dictionary containing XGBoost and UQ models
+
+     Returns:
+         pd.DataFrame: DataFrame with predictions and conformalized intervals
+     """
+     # Grab our feature columns (from training)
+     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+     with open(os.path.join(model_dir, "feature_columns.json")) as fp:
+         model_features = json.load(fp)
+
+     # Match features in a case-insensitive manner
+     matched_df = match_features_case_insensitive(df, model_features)
+
+     # Apply categorical mappings if they exist
+     if models.get("category_mappings"):
+         matched_df, _ = convert_categorical_types(matched_df, model_features, models["category_mappings"])
+
+     # Get features for prediction
+     X = matched_df[model_features]
+
+     # Get XGBoost point predictions
+     df["prediction"] = models["xgb_model"].predict(X)
+
+     # Get prediction intervals using UQ harness
+     df = predict_intervals(df, X, models["uq_models"], models["uq_metadata"])
+
+     # Compute confidence scores
+     df = compute_confidence(
+         df,
+         median_interval_width=models["uq_metadata"]["median_interval_width"],
+         lower_q="q_10",
+         upper_q="q_90",
+     )
+
+     return df
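Usage note: the inference half of this new script follows the standard SageMaker model_fn/predict_fn contract, so it can be smoke-tested outside an endpoint. A rough sketch, assuming the training section already wrote its artifacts to /opt/ml/model and that predict_intervals adds q_10/q_90 columns (as the compute_confidence call above implies):

# Hypothetical local smoke test, not part of the wheel
import json
import pandas as pd

models = model_fn("/opt/ml/model")  # same default dir predict_fn reads via SM_MODEL_DIR
with open("/opt/ml/model/feature_columns.json") as fp:
    cols = json.load(fp)
row = pd.DataFrame([dict.fromkeys(cols, 0.0)])  # dummy row; every training feature must be present
out = predict_fn(row, models)
print(out[["prediction", "q_10", "q_90"]])  # point estimate plus the assumed interval columns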