workbench-0.8.162-py3-none-any.whl → workbench-0.8.220-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/model_scripts/pytorch_model/generated_model_script.py

@@ -4,11 +4,11 @@ None
 # Template Placeholders
 TEMPLATE_PARAMS = {
     "model_type": "regressor",
-    "target_column": "solubility",
-    "feature_list": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
+    "target_column": "udm_asy_res_efflux_ratio",
+    "feature_list": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo', 'tertiary_amine_count', 'type_i_pattern_count', 'type_ii_pattern_count', 'aromatic_interaction_score', 'molecular_axis_length', 'molecular_asymmetry', 'molecular_volume_3d', 'radius_of_gyration', 'asphericity', 'charge_centroid_distance', 'nitrogen_span', 'amide_count', 'hba_hbd_ratio', 'intramolecular_hbond_potential', 'amphiphilic_moment'],
     "model_class": PyTorch,
-    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-pytorch-reg/training",
-    "train_all_data": False
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-pytorch-test/training",
+    "train_all_data": False,
 }

 import awswrangler as wr
@@ -99,10 +99,7 @@ if __name__ == "__main__":
     args = parser.parse_args()

     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

     # Check if the DataFrame is empty
@@ -116,10 +113,7 @@ if __name__ == "__main__":

     if needs_standardization:
         # Create a pipeline with standardization and the model
-        model = Pipeline([
-            ("scaler", StandardScaler()),
-            ("model", model)
-        ])
+        model = Pipeline([("scaler", StandardScaler()), ("model", model)])

     # Handle logic based on the model_type
     if model_type in ["classifier", "regressor"]:
@@ -206,6 +200,7 @@ if __name__ == "__main__":
     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
         json.dump(feature_list, fp)

+
 #
 # Inference Section
 #
workbench/model_scripts/pytorch_model/pytorch.template

@@ -8,7 +8,7 @@ TEMPLATE_PARAMS = {
     "feature_list": "{{feature_list}}",
     "model_class": "{{model_class}}",
     "model_metrics_s3_path": "{{model_metrics_s3_path}}",
-    "train_all_data": "{{train_all_data}}"
+    "train_all_data": "{{train_all_data}}",
 }

 import awswrangler as wr
@@ -99,10 +99,7 @@ if __name__ == "__main__":
     args = parser.parse_args()

     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

     # Check if the DataFrame is empty
@@ -116,10 +113,7 @@ if __name__ == "__main__":

     if needs_standardization:
         # Create a pipeline with standardization and the model
-        model = Pipeline([
-            ("scaler", StandardScaler()),
-            ("model", model)
-        ])
+        model = Pipeline([("scaler", StandardScaler()), ("model", model)])

     # Handle logic based on the model_type
     if model_type in ["classifier", "regressor"]:
@@ -206,6 +200,7 @@ if __name__ == "__main__":
     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
         json.dump(feature_list, fp)

+
 #
 # Inference Section
 #
workbench/model_scripts/script_generation.py

@@ -70,6 +70,11 @@ def fill_template(template_path: str, params: dict, output_script: str) -> str:
     # Sanity check to ensure all placeholders were replaced
     if "{{" in template and "}}" in template:
         msg = "Not all template placeholders were replaced. Please check your params."
+
+        # Show which placeholders are still present
+        start = template.index("{{")
+        end = template.index("}}", start) + 2
+        msg += f" Unreplaced placeholder: {template[start:end]}"
         log.critical(msg)
         raise ValueError(msg)

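Note that the new sanity check above reports only the first unreplaced placeholder it finds. A minimal standalone sketch of the same idea (our own illustration, not code from the package; the regex variant collects every leftover placeholder rather than just the first):

import re

def report_unreplaced(template: str) -> None:
    # Collect every surviving {{placeholder}} instead of just the first one
    leftovers = re.findall(r"\{\{.*?\}\}", template)
    if leftovers:
        raise ValueError(
            "Not all template placeholders were replaced. "
            f"Unreplaced placeholders: {', '.join(leftovers)}"
        )

# Example: a params dict that never supplied target_column
report_unreplaced('"target_column": "{{target_column}}",')
# -> ValueError: ... Unreplaced placeholders: {{target_column}}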
@@ -88,32 +93,36 @@ def generate_model_script(template_params: dict) -> str:
         template_params (dict): Dictionary containing the parameters:
             - model_imports (str): Import string for the model class
             - model_type (ModelType): The enumerated type of model to generate
+            - model_framework (str): The enumerated model framework to use
             - model_class (str): The model class to use (e.g., "RandomForestRegressor")
             - target_column (str): Column name of the target variable
             - feature_list (list[str]): A list of columns for the features
             - model_metrics_s3_path (str): The S3 path to store the model metrics
             - train_all_data (bool): Whether to train on all (100%) of the data
             - hyperparameters (dict, optional): Hyperparameters for the model (default: None)
+            - child_endpoints (list[str], optional): For META models, list of child endpoint names

     Returns:
         str: The name of the generated model script
     """
-    from workbench.api import ModelType  # Avoid circular import
+    from workbench.api import ModelType, ModelFramework  # Avoid circular import

     # Determine which template to use based on model type
     if template_params.get("model_class"):
-        if template_params["model_class"].lower() == "pytorch":
-            template_name = "pytorch.template"
-            model_script_dir = "pytorch_model"
-        else:
-            template_name = "scikit_learn.template"
-            model_script_dir = "scikit_learn"
-    elif template_params["model_type"] in [ModelType.REGRESSOR, ModelType.CLASSIFIER]:
+        template_name = "scikit_learn.template"
+        model_script_dir = "scikit_learn"
+    elif template_params["model_framework"] == ModelFramework.PYTORCH:
+        template_name = "pytorch.template"
+        model_script_dir = "pytorch_model"
+    elif template_params["model_framework"] == ModelFramework.CHEMPROP:
+        template_name = "chemprop.template"
+        model_script_dir = "chemprop"
+    elif template_params["model_framework"] == ModelFramework.META:
+        template_name = "meta_model.template"
+        model_script_dir = "meta_model"
+    elif template_params["model_type"] in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.CLASSIFIER]:
         template_name = "xgb_model.template"
         model_script_dir = "xgb_model"
-    elif template_params["model_type"] == ModelType.UQ_REGRESSOR:
-        template_name = "quant_regression.template"
-        model_script_dir = "quant_regression"
    elif template_params["model_type"] == ModelType.ENSEMBLE_REGRESSOR:
        template_name = "ensemble_xgb.template"
        model_script_dir = "ensemble_xgb"
workbench/model_scripts/uq_models/generated_model_script.py (new file)

@@ -0,0 +1,248 @@
+# Model: XGBoost for point predictions + MAPIE UQ Harness for conformalized intervals
+from xgboost import XGBRegressor
+from sklearn.model_selection import train_test_split
+
+import json
+import argparse
+import joblib
+import os
+import numpy as np
+import pandas as pd
+
+# Shared model script utilities
+from model_script_utils import (
+    check_dataframe,
+    match_features_case_insensitive,
+    convert_categorical_types,
+    decompress_features,
+    input_fn,
+    output_fn,
+    compute_regression_metrics,
+    print_regression_metrics,
+)
+
+# UQ Harness for uncertainty quantification
+from uq_harness import (
+    train_uq_models,
+    save_uq_models,
+    load_uq_models,
+    predict_intervals,
+    compute_confidence,
+)
+
+# Template Placeholders
+TEMPLATE_PARAMS = {
+    "target": "solubility",
+    "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
+    "compressed_features": [],
+    "train_all_data": False,
+    "hyperparameters": {'training_config': {'max_epochs': 150}, 'model_config': {'layers': '128-64-32'}},
+}
+
+
+if __name__ == "__main__":
+    # Template Parameters
+    target = TEMPLATE_PARAMS["target"]
+    features = TEMPLATE_PARAMS["features"]
+    orig_features = features.copy()
+    compressed_features = TEMPLATE_PARAMS["compressed_features"]
+    train_all_data = TEMPLATE_PARAMS["train_all_data"]
+    hyperparameters = TEMPLATE_PARAMS["hyperparameters"] or {}
+    validation_split = 0.2
+
+    # Script arguments for input/output directories
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+    parser.add_argument(
+        "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
+    )
+    args = parser.parse_args()
+
+    # Read the training data into DataFrames
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
+    print(f"Training Files: {training_files}")
+
+    # Combine files and read them all into a single pandas dataframe
+    all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
+
+    # Check if the dataframe is empty
+    check_dataframe(all_df, "training_df")
+
+    # Features/Target output
+    print(f"Target: {target}")
+    print(f"Features: {str(features)}")
+
+    # Convert any features that might be categorical to 'category' type
+    all_df, category_mappings = convert_categorical_types(all_df, features)
+
+    # If we have compressed features, decompress them
+    if compressed_features:
+        print(f"Decompressing features {compressed_features}...")
+        all_df, features = decompress_features(all_df, features, compressed_features)
+
+    # Do we want to train on all the data?
+    if train_all_data:
+        print("Training on ALL of the data")
+        df_train = all_df.copy()
+        df_val = all_df.copy()
+
+    # Does the dataframe have a training column?
+    elif "training" in all_df.columns:
+        print("Found training column, splitting data based on training column")
+        df_train = all_df[all_df["training"]]
+        df_val = all_df[~all_df["training"]]
+    else:
+        # Just do a random training Split
+        print("WARNING: No training column found, splitting data with random state=42")
+        df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
+    print(f"FIT/TRAIN: {df_train.shape}")
+    print(f"VALIDATION: {df_val.shape}")
+
+    # Extract sample weights if present
+    if "sample_weight" in df_train.columns:
+        sample_weights = df_train["sample_weight"]
+        print(f"Using sample weights: min={sample_weights.min():.2f}, max={sample_weights.max():.2f}, mean={sample_weights.mean():.2f}")
+    else:
+        sample_weights = None
+        print("No sample weights found, training with equal weights")
+
+    # Prepare features and targets for training
+    X_train = df_train[features]
+    X_validate = df_val[features]
+    y_train = df_train[target]
+    y_validate = df_val[target]
+
+    # ==========================================
+    # Train XGBoost for point predictions
+    # ==========================================
+    print("\nTraining XGBoost for point predictions...")
+    print(f" Hyperparameters: {hyperparameters}")
+    xgb_model = XGBRegressor(enable_categorical=True, **hyperparameters)
+    xgb_model.fit(X_train, y_train, sample_weight=sample_weights)
+
+    # Evaluate XGBoost performance
+    y_pred_xgb = xgb_model.predict(X_validate)
+    xgb_metrics = compute_regression_metrics(y_validate, y_pred_xgb)
+
+    print(f"\nXGBoost Point Prediction Performance:")
+    print_regression_metrics(xgb_metrics)
+
+    # ==========================================
+    # Train UQ models using the harness
+    # ==========================================
+    uq_models, uq_metadata = train_uq_models(X_train, y_train, X_validate, y_validate)
+
+    print(f"\nOverall Model Performance Summary:")
+    print_regression_metrics(xgb_metrics)
+
+    # ==========================================
+    # Save all models
+    # ==========================================
+    # Save the trained XGBoost model
+    joblib.dump(xgb_model, os.path.join(args.model_dir, "xgb_model.joblib"))
+
+    # Save UQ models using the harness
+    save_uq_models(uq_models, uq_metadata, args.model_dir)
+
+    # Save the feature list
+    with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
+        json.dump(features, fp)
+
+    # Save category mappings if any
+    if category_mappings:
+        with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
+            json.dump(category_mappings, fp)
+
+    # Save model configuration
+    model_config = {
+        "model_type": "XGBoost_MAPIE_UQ",
+        "confidence_levels": uq_metadata["confidence_levels"],
+        "n_features": len(features),
+        "target": target,
+        "validation_metrics": {
+            "xgb_rmse": float(xgb_metrics["rmse"]),
+            "xgb_mae": float(xgb_metrics["mae"]),
+            "xgb_r2": float(xgb_metrics["r2"]),
+            "n_validation": len(df_val),
+        },
+    }
+    with open(os.path.join(args.model_dir, "model_config.json"), "w") as fp:
+        json.dump(model_config, fp, indent=2)
+
+    print(f"\nModel training complete!")
+    print(f"Saved XGBoost model and {len(uq_models)} UQ models to {args.model_dir}")
+
+
+#
+# Inference Section
+#
+def model_fn(model_dir) -> dict:
+    """Load XGBoost and all UQ models from the specified directory."""
+
+    # Load model configuration
+    with open(os.path.join(model_dir, "model_config.json")) as fp:
+        config = json.load(fp)
+
+    # Load XGBoost regressor
+    xgb_path = os.path.join(model_dir, "xgb_model.joblib")
+    xgb_model = joblib.load(xgb_path)
+
+    # Load UQ models using the harness
+    uq_models, uq_metadata = load_uq_models(model_dir)
+
+    # Load category mappings if they exist
+    category_mappings = {}
+    category_path = os.path.join(model_dir, "category_mappings.json")
+    if os.path.exists(category_path):
+        with open(category_path) as fp:
+            category_mappings = json.load(fp)
+
+    return {
+        "xgb_model": xgb_model,
+        "uq_models": uq_models,
+        "uq_metadata": uq_metadata,
+        "category_mappings": category_mappings,
+    }
+
+
+def predict_fn(df, models) -> pd.DataFrame:
+    """Make predictions using XGBoost for point estimates and UQ harness for intervals.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame
+        models (dict): Dictionary containing XGBoost and UQ models
+
+    Returns:
+        pd.DataFrame: DataFrame with predictions and conformalized intervals
+    """
+    # Grab our feature columns (from training)
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+    with open(os.path.join(model_dir, "feature_columns.json")) as fp:
+        model_features = json.load(fp)
+
+    # Match features in a case-insensitive manner
+    matched_df = match_features_case_insensitive(df, model_features)
+
+    # Apply categorical mappings if they exist
+    if models.get("category_mappings"):
+        matched_df, _ = convert_categorical_types(matched_df, model_features, models["category_mappings"])
+
+    # Get features for prediction
+    X = matched_df[model_features]
+
+    # Get XGBoost point predictions
+    df["prediction"] = models["xgb_model"].predict(X)
+
+    # Get prediction intervals using UQ harness
+    df = predict_intervals(df, X, models["uq_models"], models["uq_metadata"])
+
+    # Compute confidence scores
+    df = compute_confidence(
+        df,
+        median_interval_width=models["uq_metadata"]["median_interval_width"],
+        lower_q="q_10",
+        upper_q="q_90",
+    )
+
+    return df
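To tie the new inference section together: at serving time SageMaker calls model_fn once per container start and predict_fn per request. A rough local smoke test of that flow, assuming a model directory already produced by the training section above and a model trained on a small three-feature list (the path and feature rows are invented; the exact interval/confidence column names compute_confidence adds are an assumption beyond the q_10/q_90 arguments shown):

import os
import pandas as pd

# Hypothetical path to artifacts written by the training section above
os.environ["SM_MODEL_DIR"] = "/tmp/uq_model"
models = model_fn("/tmp/uq_model")

# Invented feature rows; columns must match feature_columns.json
df = pd.DataFrame({"molwt": [180.2, 46.1], "mollogp": [1.2, -0.3], "tpsa": [63.6, 20.2]})
out = predict_fn(df, models)

# Point predictions plus the q_10/q_90 conformalized interval columns
print(out[["prediction", "q_10", "q_90"]])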