workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. See the package's registry page for more details.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -1,576 +1,520 @@
1
- # Imports for PyTorch Tabular Model
2
- import os
3
- import awswrangler as wr
4
- import numpy as np
5
-
6
- # PyTorch compatibility: pytorch-tabular saves complex objects, not just tensors
7
- # Use legacy loading behavior for compatibility (recommended by PyTorch docs for this scenario)
8
- os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
9
- from pytorch_tabular import TabularModel
10
- from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
11
- from pytorch_tabular.models import CategoryEmbeddingModelConfig
12
-
13
- # Model Performance Scores
14
- from sklearn.metrics import (
15
- mean_absolute_error,
16
- r2_score,
17
- root_mean_squared_error,
18
- precision_recall_fscore_support,
19
- confusion_matrix,
20
- )
21
-
22
- # Classification Encoder
23
- from sklearn.preprocessing import LabelEncoder
1
+ # PyTorch Tabular Model Template for Workbench
2
+ #
3
+ # This template handles both classification and regression models with:
4
+ # - K-fold cross-validation ensemble training (or single train/val split)
5
+ # - Out-of-fold predictions for validation metrics
6
+ # - Categorical feature embedding via TabularMLP
7
+ # - Compressed feature decompression
8
+ #
9
+ # NOTE: Imports are structured to minimize serverless endpoint startup time.
10
+ # Heavy imports (sklearn, awswrangler) are deferred to training time.
24
11
 
25
- # Scikit Learn Imports
26
- from sklearn.model_selection import train_test_split
27
-
28
- from io import StringIO
29
12
  import json
30
- import argparse
31
- import joblib
32
13
  import os
14
+
15
+ import joblib
16
+ import numpy as np
33
17
  import pandas as pd
34
- from typing import List, Tuple
18
+ import torch
19
+
20
+ from model_script_utils import (
21
+ convert_categorical_types,
22
+ decompress_features,
23
+ expand_proba_column,
24
+ input_fn,
25
+ match_features_case_insensitive,
26
+ output_fn,
27
+ )
28
+ from pytorch_utils import (
29
+ FeatureScaler,
30
+ load_model,
31
+ predict,
32
+ prepare_data,
33
+ )
34
+ from uq_harness import (
35
+ compute_confidence,
36
+ load_uq_models,
37
+ predict_intervals,
38
+ )
35
39
 
36
- # Template Parameters
40
+ # =============================================================================
41
+ # Default Hyperparameters
42
+ # =============================================================================
43
+ DEFAULT_HYPERPARAMETERS = {
44
+ # Training parameters
45
+ "n_folds": 5,
46
+ "max_epochs": 200,
47
+ "early_stopping_patience": 30,
48
+ "batch_size": 128,
49
+ # Model architecture (larger capacity - ensemble provides regularization)
50
+ "layers": "512-256-128",
51
+ "learning_rate": 1e-3,
52
+ "dropout": 0.05,
53
+ "use_batch_norm": True,
54
+ # Loss function for regression (L1Loss=MAE, MSELoss=MSE, HuberLoss, SmoothL1Loss)
55
+ "loss": "L1Loss",
56
+ # Random seed
57
+ "seed": 42,
58
+ }
59
+
60
+ # Template parameters (filled in by Workbench)
37
61
  TEMPLATE_PARAMS = {
38
- "model_type": "classifier",
39
- "target_column": "solubility_class",
40
- "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
62
+ "model_type": "uq_regressor",
63
+ "target": "udm_asy_res_efflux_ratio",
64
+ "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 
'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 
'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
65
+ "id_column": "udm_mol_bat_id",
41
66
  "compressed_features": [],
42
- "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-pytorch-class/training",
43
- "train_all_data": False,
44
- "hyperparameters": {'training_config': {'max_epochs': 150}, 'model_config': {'layers': '256-128-64'}}
67
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-pytorch-260113/training",
68
+ "hyperparameters": {},
45
69
  }
46
70
 
47
71
 
48
- # Function to check if dataframe is empty
49
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
50
- """
51
- Check if the provided dataframe is empty and raise an exception if it is.
52
-
53
- Args:
54
- df (pd.DataFrame): DataFrame to check
55
- df_name (str): Name of the DataFrame
56
- """
57
- if df.empty:
58
- msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
59
- print(msg)
60
- raise ValueError(msg)
61
-
62
-
63
- def expand_proba_column(df: pd.DataFrame, class_labels: List[str]) -> pd.DataFrame:
64
- """
65
- Expands a column in a DataFrame containing a list of probabilities into separate columns.
66
-
67
- Args:
68
- df (pd.DataFrame): DataFrame containing a "pred_proba" column
69
- class_labels (List[str]): List of class labels
70
-
71
- Returns:
72
- pd.DataFrame: DataFrame with the "pred_proba" expanded into separate columns
73
- """
74
-
75
- # Sanity check
76
- proba_column = "pred_proba"
77
- if proba_column not in df.columns:
78
- raise ValueError('DataFrame does not contain a "pred_proba" column')
79
-
80
- # Construct new column names with '_proba' suffix
81
- proba_splits = [f"{label}_proba" for label in class_labels]
82
-
83
- # Expand the proba_column into separate columns for each probability
84
- proba_df = pd.DataFrame(df[proba_column].tolist(), columns=proba_splits)
85
-
86
- # Drop any proba columns and reset the index in prep for the concat
87
- df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
88
- df = df.reset_index(drop=True)
89
-
90
- # Concatenate the new columns with the original DataFrame
91
- df = pd.concat([df, proba_df], axis=1)
92
- print(df)
93
- return df
94
-
95
-
96
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
97
- """
98
- Matches and renames DataFrame columns to match model feature names (case-insensitive).
99
- Prioritizes exact matches, then case-insensitive matches.
100
-
101
- Raises ValueError if any model features cannot be matched.
102
- """
103
- df_columns_lower = {col.lower(): col for col in df.columns}
104
- rename_dict = {}
105
- missing = []
106
-
107
- for feature in model_features:
108
- if feature in df.columns:
109
- continue # Exact match
110
- elif feature.lower() in df_columns_lower:
111
- rename_dict[df_columns_lower[feature.lower()]] = feature
112
- else:
113
- missing.append(feature)
114
-
115
- if missing:
116
- raise ValueError(f"Features not found: {missing}")
117
-
118
- return df.rename(columns=rename_dict)
119
-
120
-
121
- def convert_categorical_types(df: pd.DataFrame, features: list, category_mappings={}) -> tuple:
122
- """
123
- Converts appropriate columns to categorical type with consistent mappings.
124
-
125
- Args:
126
- df (pd.DataFrame): The DataFrame to process.
127
- features (list): List of feature names to consider for conversion.
128
- category_mappings (dict, optional): Existing category mappings. If empty dict, we're in
129
- training mode. If populated, we're in inference mode.
130
-
131
- Returns:
132
- tuple: (processed DataFrame, category mappings dictionary)
133
- """
134
- # Training mode
135
- if category_mappings == {}:
136
- for col in df.select_dtypes(include=["object", "string"]):
137
- if col in features and df[col].nunique() < 20:
138
- print(f"Training mode: Converting {col} to category")
139
- df[col] = df[col].astype("category")
140
- category_mappings[col] = df[col].cat.categories.tolist() # Store category mappings
141
-
142
- # Inference mode
143
- else:
144
- for col, categories in category_mappings.items():
145
- if col in df.columns:
146
- print(f"Inference mode: Applying categorical mapping for {col}")
147
- df[col] = pd.Categorical(df[col], categories=categories) # Apply consistent categorical mapping
148
-
149
- return df, category_mappings
150
-
151
-
152
- def decompress_features(
153
- df: pd.DataFrame, features: List[str], compressed_features: List[str]
154
- ) -> Tuple[pd.DataFrame, List[str]]:
155
- """Prepare features for the model
156
-
157
- Args:
158
- df (pd.DataFrame): The features DataFrame
159
- features (List[str]): Full list of feature names
160
- compressed_features (List[str]): List of feature names to decompress (bitstrings)
161
-
162
- Returns:
163
- pd.DataFrame: DataFrame with the decompressed features
164
- List[str]: Updated list of feature names after decompression
165
-
166
- Raises:
167
- ValueError: If any missing values are found in the specified features
168
- """
169
-
170
- # Check for any missing values in the required features
171
- missing_counts = df[features].isna().sum()
172
- if missing_counts.any():
173
- missing_features = missing_counts[missing_counts > 0]
174
- print(
175
- f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
176
- "WARNING: You might want to remove/replace all NaN values before processing."
177
- )
178
-
179
- # Decompress the specified compressed features
180
- decompressed_features = features
181
- for feature in compressed_features:
182
- if (feature not in df.columns) or (feature not in features):
183
- print(f"Feature '{feature}' not in the features list, skipping decompression.")
184
- continue
185
-
186
- # Remove the feature from the list of features to avoid duplication
187
- decompressed_features.remove(feature)
188
-
189
- # Handle all compressed features as bitstrings
190
- bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
191
- prefix = feature[:3]
192
-
193
- # Create all new columns at once - avoids fragmentation
194
- new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
195
- new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
196
-
197
- # Add to features list
198
- decompressed_features.extend(new_col_names)
199
-
200
- # Drop original column and concatenate new ones
201
- df = df.drop(columns=[feature])
202
- df = pd.concat([df, new_df], axis=1)
203
-
204
- return df, decompressed_features
205
-
206
-
207
- def model_fn(model_dir):
208
-
209
- # Save current working directory
210
- original_cwd = os.getcwd()
211
- try:
212
- # Change to /tmp because Pytorch Tabular needs write access (creates a .pt_tmp directory)
213
- os.chdir('/tmp')
214
-
215
- # Load the model
216
- model_path = os.path.join(model_dir, "tabular_model")
217
- model = TabularModel.load_model(model_path)
218
-
219
- # Restore the original working directory
220
- finally:
221
- os.chdir(original_cwd)
222
-
223
- return model
224
-
225
-
226
- def input_fn(input_data, content_type):
227
- """Parse input data and return a DataFrame."""
228
- if not input_data:
229
- raise ValueError("Empty input data is not supported!")
230
-
231
- # Decode bytes to string if necessary
232
- if isinstance(input_data, bytes):
233
- input_data = input_data.decode("utf-8")
234
-
235
- if "text/csv" in content_type:
236
- return pd.read_csv(StringIO(input_data))
237
- elif "application/json" in content_type:
238
- return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
72
+ # =============================================================================
73
+ # Model Loading (for SageMaker inference)
74
+ # =============================================================================
75
+ def model_fn(model_dir: str) -> dict:
76
+ """Load PyTorch TabularMLP ensemble from the specified directory."""
77
+ # Load ensemble metadata
78
+ metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
79
+ if os.path.exists(metadata_path):
80
+ metadata = joblib.load(metadata_path)
81
+ n_ensemble = metadata["n_ensemble"]
239
82
  else:
240
- raise ValueError(f"{content_type} not supported!")
83
+ n_ensemble = 1
84
+
85
+ # Determine device
86
+ device = "cuda" if torch.cuda.is_available() else "cpu"
87
+
88
+ # Load ensemble models
89
+ ensemble_models = []
90
+ for i in range(n_ensemble):
91
+ model_path = os.path.join(model_dir, f"model_{i}")
92
+ model = load_model(model_path, device=device)
93
+ ensemble_models.append(model)
94
+
95
+ print(f"Loaded {len(ensemble_models)} model(s)")
96
+
97
+ # Load feature scaler
98
+ scaler = FeatureScaler.load(os.path.join(model_dir, "scaler.joblib"))
99
+
100
+ # Load UQ models (regression only)
101
+ uq_models, uq_metadata = None, None
102
+ uq_path = os.path.join(model_dir, "uq_metadata.json")
103
+ if os.path.exists(uq_path):
104
+ uq_models, uq_metadata = load_uq_models(model_dir)
105
+
106
+ return {
107
+ "ensemble_models": ensemble_models,
108
+ "n_ensemble": n_ensemble,
109
+ "scaler": scaler,
110
+ "uq_models": uq_models,
111
+ "uq_metadata": uq_metadata,
112
+ }
241
113
 
242
114
 
243
- def output_fn(output_df, accept_type):
244
- """Supports both CSV and JSON output formats."""
245
- if "text/csv" in accept_type:
246
- csv_output = output_df.fillna("N/A").to_csv(index=False) # CSV with N/A for missing values
247
- return csv_output, "text/csv"
248
- elif "application/json" in accept_type:
249
- return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
250
- else:
251
- raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
115
+ # =============================================================================
116
+ # Inference (for SageMaker inference)
117
+ # =============================================================================
118
+ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
119
+ """Make predictions with PyTorch TabularMLP ensemble."""
120
+ model_type = TEMPLATE_PARAMS["model_type"]
121
+ compressed_features = TEMPLATE_PARAMS["compressed_features"]
122
+ model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
252
123
 
124
+ # Load artifacts
125
+ ensemble_models = model_dict["ensemble_models"]
126
+ scaler = model_dict["scaler"]
127
+ uq_models = model_dict.get("uq_models")
128
+ uq_metadata = model_dict.get("uq_metadata")
253
129
 
254
- def predict_fn(df, model) -> pd.DataFrame:
255
- """Make Predictions with our PyTorch Tabular Model
130
+ with open(os.path.join(model_dir, "feature_columns.json")) as f:
131
+ features = json.load(f)
132
+ with open(os.path.join(model_dir, "category_mappings.json")) as f:
133
+ category_mappings = json.load(f)
134
+ with open(os.path.join(model_dir, "feature_metadata.json")) as f:
135
+ feature_metadata = json.load(f)
256
136
 
257
- Args:
258
- df (pd.DataFrame): The input DataFrame
259
- model: The TabularModel use for predictions
137
+ continuous_cols = feature_metadata["continuous_cols"]
138
+ categorical_cols = feature_metadata["categorical_cols"]
260
139
 
261
- Returns:
262
- pd.DataFrame: The DataFrame with the predictions added
263
- """
264
- compressed_features = TEMPLATE_PARAMS["compressed_features"]
140
+ label_encoder = None
141
+ encoder_path = os.path.join(model_dir, "label_encoder.joblib")
142
+ if os.path.exists(encoder_path):
143
+ label_encoder = joblib.load(encoder_path)
265
144
 
266
- # Grab our feature columns (from training)
267
- model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
268
- with open(os.path.join(model_dir, "feature_columns.json")) as fp:
269
- features = json.load(fp)
270
145
  print(f"Model Features: {features}")
271
146
 
272
- # Load the category mappings (from training)
273
- with open(os.path.join(model_dir, "category_mappings.json")) as fp:
274
- category_mappings = json.load(fp)
275
-
276
- # Load our Label Encoder if we have one
277
- label_encoder = None
278
- if os.path.exists(os.path.join(model_dir, "label_encoder.joblib")):
279
- label_encoder = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
280
-
281
- # We're going match features in a case-insensitive manner, accounting for all the permutations
282
- # - Model has a feature list that's any case ("Id", "taCos", "cOunT", "likes_tacos")
283
- # - Incoming data has columns that are mixed case ("ID", "Tacos", "Count", "Likes_Tacos")
147
+ # Prepare features
284
148
  matched_df = match_features_case_insensitive(df, features)
285
-
286
- # Detect categorical types in the incoming DataFrame
287
149
  matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
288
150
 
289
- # If we have compressed features, decompress them
290
151
  if compressed_features:
291
152
  print("Decompressing features for prediction...")
292
153
  matched_df, features = decompress_features(matched_df, features, compressed_features)
293
154
 
294
- # Make predictions using the TabularModel
295
- result = model.predict(matched_df[features])
155
+ # Track missing features
156
+ missing_mask = matched_df[features].isna().any(axis=1)
157
+ if missing_mask.any():
158
+ print(f"Warning: {missing_mask.sum()} rows have missing features")
296
159
 
297
- # pytorch-tabular returns predictions using f"{target}_prediction" column
298
- # and classification probabilities in columns ending with "_probability"
299
- target = TEMPLATE_PARAMS["target_column"]
300
- prediction_column = f"{target}_prediction"
301
- if prediction_column in result.columns:
302
- predictions = result[prediction_column].values
303
- else:
304
- raise ValueError(f"Cannot find prediction column in: {result.columns.tolist()}")
160
+ # Initialize output columns
161
+ df["prediction"] = np.nan
162
+ if model_type in ["regressor", "uq_regressor"]:
163
+ df["prediction_std"] = np.nan
305
164
 
306
- # If we have a label encoder, decode the predictions
307
- if label_encoder:
308
- predictions = label_encoder.inverse_transform(predictions.astype(int))
165
+ complete_df = matched_df[~missing_mask].copy()
166
+ if len(complete_df) == 0:
167
+ print("Warning: No complete rows to predict on")
168
+ return df
169
+
170
+ # Prepare data for inference (with standardization)
171
+ x_cont, x_cat, _, _, _ = prepare_data(
172
+ complete_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
173
+ )
309
174
 
310
- # Set the predictions on the DataFrame
311
- df["prediction"] = predictions
175
+ # Collect ensemble predictions
176
+ all_preds = []
177
+ for model in ensemble_models:
178
+ preds = predict(model, x_cont, x_cat)
179
+ all_preds.append(preds)
312
180
 
313
- # For classification, get probabilities
314
- if label_encoder is not None:
315
- prob_cols = [col for col in result.columns if col.endswith("_probability")]
316
- if prob_cols:
317
- probs = result[prob_cols].values
318
- df["pred_proba"] = [p.tolist() for p in probs]
181
+ # Aggregate predictions
182
+ ensemble_preds = np.stack(all_preds, axis=0)
183
+ preds = np.mean(ensemble_preds, axis=0)
184
+ preds_std = np.std(ensemble_preds, axis=0)
319
185
 
320
- # Expand the pred_proba column into separate columns for each class
321
- df = expand_proba_column(df, label_encoder.classes_)
186
+ print(f"Inference complete: {len(preds)} predictions, {len(ensemble_models)} ensemble members")
322
187
 
323
- # All done, return the DataFrame with new columns for the predictions
188
+ if label_encoder is not None:
189
+ # Classification: average probabilities, then argmax
190
+ avg_probs = preds # Already softmax output
191
+ class_preds = np.argmax(avg_probs, axis=1)
192
+ predictions = label_encoder.inverse_transform(class_preds)
193
+
194
+ all_proba = pd.Series([None] * len(df), index=df.index, dtype=object)
195
+ all_proba.loc[~missing_mask] = [p.tolist() for p in avg_probs]
196
+ df["pred_proba"] = all_proba
197
+ df = expand_proba_column(df, label_encoder.classes_)
198
+ else:
199
+ # Regression
200
+ predictions = preds.flatten()
201
+ df.loc[~missing_mask, "prediction_std"] = preds_std.flatten()
202
+
203
+ # Add UQ intervals if available
204
+ if uq_models and uq_metadata:
205
+ X_complete = complete_df[features]
206
+ df_complete = df.loc[~missing_mask].copy()
207
+ df_complete["prediction"] = predictions # Set prediction before compute_confidence
208
+ df_complete = predict_intervals(df_complete, X_complete, uq_models, uq_metadata)
209
+ df_complete = compute_confidence(df_complete, uq_metadata["median_interval_width"], "q_10", "q_90")
210
+ # Copy UQ columns back to main dataframe
211
+ for col in df_complete.columns:
212
+ if col.startswith("q_") or col == "confidence":
213
+ df.loc[~missing_mask, col] = df_complete[col].values
214
+
215
+ df.loc[~missing_mask, "prediction"] = predictions
324
216
  return df
325
217
 
326
218
 
219
+ # =============================================================================
220
+ # Training
221
+ # =============================================================================
327
222
  if __name__ == "__main__":
328
- """The main function is for training the PyTorch Tabular model"""
223
+ # -------------------------------------------------------------------------
224
+ # Training-only imports (deferred to reduce serverless startup time)
225
+ # -------------------------------------------------------------------------
226
+ import argparse
227
+
228
+ import awswrangler as wr
229
+ from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
230
+ from sklearn.preprocessing import LabelEncoder
231
+
232
+ # Enable Tensor Core optimization for GPUs that support it
233
+ torch.set_float32_matmul_precision("medium")
234
+
235
+ from model_script_utils import (
236
+ check_dataframe,
237
+ compute_classification_metrics,
238
+ compute_regression_metrics,
239
+ print_classification_metrics,
240
+ print_confusion_matrix,
241
+ print_regression_metrics,
242
+ )
243
+ from pytorch_utils import (
244
+ create_model,
245
+ save_model,
246
+ train_model,
247
+ )
248
+ from uq_harness import (
249
+ save_uq_models,
250
+ train_uq_models,
251
+ )
329
252
 
330
- # Harness Template Parameters
331
- target = TEMPLATE_PARAMS["target_column"]
253
+ # -------------------------------------------------------------------------
254
+ # Setup: Parse arguments and load data
255
+ # -------------------------------------------------------------------------
256
+ parser = argparse.ArgumentParser()
257
+ parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
258
+ parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
259
+ parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
260
+ args = parser.parse_args()
261
+
262
+ # Extract template parameters
263
+ target = TEMPLATE_PARAMS["target"]
332
264
  features = TEMPLATE_PARAMS["features"]
333
265
  orig_features = features.copy()
266
+ id_column = TEMPLATE_PARAMS["id_column"]
334
267
  compressed_features = TEMPLATE_PARAMS["compressed_features"]
335
268
  model_type = TEMPLATE_PARAMS["model_type"]
336
269
  model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
337
- train_all_data = TEMPLATE_PARAMS["train_all_data"]
338
- hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
339
- validation_split = 0.2
340
-
341
- # Script arguments for input/output directories
342
- parser = argparse.ArgumentParser()
343
- parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
344
- parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
345
- parser.add_argument(
346
- "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
347
- )
348
- args = parser.parse_args()
270
+ hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}
349
271
 
350
- # Read the training data into DataFrames
351
- training_files = [
352
- os.path.join(args.train, file)
353
- for file in os.listdir(args.train)
354
- if file.endswith(".csv")
355
- ]
272
+ # Load training data
273
+ training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
356
274
  print(f"Training Files: {training_files}")
357
-
358
- # Combine files and read them all into a single pandas dataframe
359
- all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
360
-
361
- # Check if the dataframe is empty
275
+ all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
362
276
  check_dataframe(all_df, "training_df")
363
277
 
364
- # Features/Target output
278
+ # Drop rows with missing features
279
+ initial_count = len(all_df)
280
+ all_df = all_df.dropna(subset=features)
281
+ if len(all_df) < initial_count:
282
+ print(f"Dropped {initial_count - len(all_df)} rows with missing features")
283
+
365
284
  print(f"Target: {target}")
366
- print(f"Features: {str(features)}")
285
+ print(f"Features: {features}")
286
+ print(f"Hyperparameters: {hyperparameters}")
367
287
 
368
- # Convert any features that might be categorical to 'category' type
288
+ # -------------------------------------------------------------------------
289
+ # Preprocessing
290
+ # -------------------------------------------------------------------------
369
291
  all_df, category_mappings = convert_categorical_types(all_df, features)
370
292
 
371
- # If we have compressed features, decompress them
372
293
  if compressed_features:
373
- print(f"Decompressing features {compressed_features}...")
294
+ print(f"Decompressing features: {compressed_features}")
374
295
  all_df, features = decompress_features(all_df, features, compressed_features)
375
296
 
376
- # Do we want to train on all the data?
377
- if train_all_data:
378
- print("Training on ALL of the data")
379
- df_train = all_df.copy()
380
- df_val = all_df.copy()
381
-
382
- # Does the dataframe have a training column?
383
- elif "training" in all_df.columns:
384
- print("Found training column, splitting data based on training column")
385
- df_train = all_df[all_df["training"]]
386
- df_val = all_df[~all_df["training"]]
387
- else:
388
- # Just do a random training Split
389
- print("WARNING: No training column found, splitting data with random state=42")
390
- df_train, df_val = train_test_split(all_df, test_size=validation_split, random_state=42)
391
- print(f"FIT/TRAIN: {df_train.shape}")
392
- print(f"VALIDATION: {df_val.shape}")
393
-
394
- # Determine categorical and continuous columns
395
- categorical_cols = [col for col in features if df_train[col].dtype.name == "category"]
396
- continuous_cols = [col for col in features if col not in categorical_cols]
397
-
398
- print(f"Categorical columns: {categorical_cols}")
399
- print(f"Continuous columns: {continuous_cols}")
400
-
401
- # Set up PyTorch Tabular configuration
402
- data_config = DataConfig(
403
- target=[target],
404
- continuous_cols=continuous_cols,
405
- categorical_cols=categorical_cols,
406
- )
297
+ # Determine categorical vs continuous columns
298
+ categorical_cols = [c for c in features if all_df[c].dtype.name == "category"]
299
+ continuous_cols = [c for c in features if c not in categorical_cols]
300
+ all_df[continuous_cols] = all_df[continuous_cols].astype("float64")
301
+ print(f"Categorical: {categorical_cols}")
302
+ print(f"Continuous: {len(continuous_cols)} columns")
407
303
 
408
- # Choose the 'task' based on model type also set up the label encoder if needed
304
+ # -------------------------------------------------------------------------
305
+ # Classification setup
306
+ # -------------------------------------------------------------------------
307
+ label_encoder = None
308
+ n_outputs = 1
409
309
  if model_type == "classifier":
410
- task = "classification"
411
- # Encode the target column
412
310
  label_encoder = LabelEncoder()
413
- df_train[target] = label_encoder.fit_transform(df_train[target])
414
- df_val[target] = label_encoder.transform(df_val[target])
311
+ all_df[target] = label_encoder.fit_transform(all_df[target])
312
+ n_outputs = len(label_encoder.classes_)
313
+ print(f"Class labels: {label_encoder.classes_.tolist()}")
314
+
315
+ # -------------------------------------------------------------------------
316
+ # Cross-validation setup
317
+ # -------------------------------------------------------------------------
318
+ n_folds = hyperparameters["n_folds"]
319
+ task = "classification" if model_type == "classifier" else "regression"
320
+ hidden_layers = [int(x) for x in hyperparameters["layers"].split("-")]
321
+
322
+ # Get categorical cardinalities
323
+ categorical_cardinalities = [len(category_mappings.get(col, {})) for col in categorical_cols]
324
+
325
+ if n_folds == 1:
326
+ if "training" in all_df.columns:
327
+ print("Using 'training' column for train/val split")
328
+ train_idx = np.where(all_df["training"])[0]
329
+ val_idx = np.where(~all_df["training"])[0]
330
+ else:
331
+ print("WARNING: No 'training' column found, using random 80/20 split")
332
+ train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
333
+ folds = [(train_idx, val_idx)]
415
334
  else:
416
- task = "regression"
417
- label_encoder = None
335
+ if model_type == "classifier":
336
+ kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
337
+ folds = list(kfold.split(all_df, all_df[target]))
338
+ else:
339
+ kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
340
+ folds = list(kfold.split(all_df))
418
341
 
419
- # Use any hyperparameters to set up both the trainer and model configurations
420
- print(f"Hyperparameters: {hyperparameters}")
342
+ print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
421
343
 
422
- # Set up PyTorch Tabular configuration with defaults
423
- trainer_defaults = {
424
- "auto_lr_find": True,
425
- "batch_size": min(1024, max(32, len(df_train) // 4)),
426
- "max_epochs": 100,
427
- "early_stopping": "valid_loss",
428
- "early_stopping_patience": 15,
429
- "checkpoints": "valid_loss",
430
- "accelerator": "auto",
431
- "progress_bar": "none",
432
- "gradient_clip_val": 1.0,
433
- }
344
+ # Fit scaler on all training data (used across all folds)
345
+ scaler = FeatureScaler()
346
+ scaler.fit(all_df, continuous_cols)
347
+ print(f"Fitted scaler on {len(continuous_cols)} continuous features")
434
348
 
435
- # Override defaults with training_config if present
436
- training_overrides = {k: v for k, v in hyperparameters.get('training_config', {}).items()
437
- if k in trainer_defaults}
438
- # Print overwrites
439
- for key, value in training_overrides.items():
440
- print(f"TRAINING CONFIG Override: {key}: {trainer_defaults[key]} → {value}")
441
- trainer_params = {**trainer_defaults, **training_overrides}
442
- trainer_config = TrainerConfig(**trainer_params)
443
-
444
- # Model config defaults
445
- model_defaults = {
446
- "layers": "1024-512-512",
447
- "activation": "ReLU",
448
- "learning_rate": 1e-3,
449
- "dropout": 0.1,
450
- "use_batch_norm": True,
451
- "initialization": "kaiming",
452
- }
453
- # Override defaults with model_config if present
454
- model_overrides = {k: v for k, v in hyperparameters.get('model_config', {}).items()
455
- if k in model_defaults}
456
- # Print overwrites
457
- for key, value in model_overrides.items():
458
- print(f"MODEL CONFIG Override: {key}: {model_defaults[key]} → {value}")
459
- model_params = {**model_defaults, **model_overrides}
460
-
461
- # Use CategoryEmbedding model configuration for general-purpose tabular modeling.
462
- # Works effectively for both regression and classification as the foundational
463
- # architecture in PyTorch Tabular
464
- model_config = CategoryEmbeddingModelConfig(
465
- task=task,
466
- **model_params
467
- )
468
- optimizer_config = OptimizerConfig()
469
-
470
- #####################################
471
- # Create and train the TabularModel #
472
- #####################################
473
- tabular_model = TabularModel(
474
- data_config=data_config,
475
- model_config=model_config,
476
- optimizer_config=optimizer_config,
477
- trainer_config=trainer_config,
478
- )
479
- tabular_model.fit(train=df_train, validation=df_val)
349
+ # Determine device
350
+ device = "cuda" if torch.cuda.is_available() else "cpu"
351
+ print(f"Using device: {device}")
480
352
 
481
- # Make Predictions on the Validation Set
482
- print("Making Predictions on Validation Set...")
483
- result = tabular_model.predict(df_val, include_input_features=False)
353
+ # -------------------------------------------------------------------------
354
+ # Training loop
355
+ # -------------------------------------------------------------------------
356
+ oof_predictions = np.full((len(all_df), n_outputs), np.nan, dtype=np.float64)
484
357
 
485
- # pytorch-tabular returns predictions using f"{target}_prediction" column
486
- # and classification probabilities in columns ending with "_probability"
487
- if model_type == "classifier":
488
- preds = result[f"{target}_prediction"].values
358
+ ensemble_models = []
359
+ for fold_idx, (train_idx, val_idx) in enumerate(folds):
360
+ print(f"\n{'='*50}")
361
+ print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
362
+ print(f"{'='*50}")
363
+
364
+ df_train = all_df.iloc[train_idx].reset_index(drop=True)
365
+ df_val = all_df.iloc[val_idx].reset_index(drop=True)
366
+
367
+ # Prepare data (using pre-fitted scaler)
368
+ train_x_cont, train_x_cat, train_y, _, _ = prepare_data(
369
+ df_train, continuous_cols, categorical_cols, target, category_mappings, scaler=scaler
370
+ )
371
+ val_x_cont, val_x_cat, val_y, _, _ = prepare_data(
372
+ df_val, continuous_cols, categorical_cols, target, category_mappings, scaler=scaler
373
+ )
374
+
375
+ # Create model
376
+ torch.manual_seed(hyperparameters["seed"] + fold_idx)
377
+ model = create_model(
378
+ n_continuous=len(continuous_cols),
379
+ categorical_cardinalities=categorical_cardinalities,
380
+ hidden_layers=hidden_layers,
381
+ n_outputs=n_outputs,
382
+ task=task,
383
+ dropout=hyperparameters["dropout"],
384
+ use_batch_norm=hyperparameters["use_batch_norm"],
385
+ )
386
+
387
+ # Train
388
+ model, history = train_model(
389
+ model,
390
+ train_x_cont, train_x_cat, train_y,
391
+ val_x_cont, val_x_cat, val_y,
392
+ task=task,
393
+ max_epochs=hyperparameters["max_epochs"],
394
+ patience=hyperparameters["early_stopping_patience"],
395
+ batch_size=hyperparameters["batch_size"],
396
+ learning_rate=hyperparameters["learning_rate"],
397
+ loss=hyperparameters.get("loss", "L1Loss"),
398
+ device=device,
399
+ )
400
+ ensemble_models.append(model)
401
+
402
+ # Out-of-fold predictions
403
+ fold_preds = predict(model, val_x_cont, val_x_cat)
404
+ oof_predictions[val_idx] = fold_preds
405
+
406
+ print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
407
+
408
+ # -------------------------------------------------------------------------
409
+ # Prepare validation results
410
+ # -------------------------------------------------------------------------
411
+ if n_folds == 1:
412
+ val_mask = ~np.isnan(oof_predictions[:, 0])
413
+ df_val = all_df[val_mask].copy()
414
+ predictions = oof_predictions[val_mask]
489
415
  else:
490
- # Regression: use the target column name
491
- preds = result[f"{target}_prediction"].values
416
+ df_val = all_df.copy()
417
+ predictions = oof_predictions
492
418
 
419
+ # Decode labels for classification
493
420
  if model_type == "classifier":
494
- # Get probabilities for classification
495
- print("Processing Probabilities...")
496
- prob_cols = [col for col in result.columns if col.endswith("_probability")]
497
- if prob_cols:
498
- probs = result[prob_cols].values
499
- df_val["pred_proba"] = [p.tolist() for p in probs]
500
-
501
- # Expand the pred_proba column into separate columns for each class
502
- print(df_val.columns)
503
- df_val = expand_proba_column(df_val, label_encoder.classes_)
504
- print(df_val.columns)
505
-
506
- # Decode the target and prediction labels
507
- y_validate = label_encoder.inverse_transform(df_val[target])
508
- preds = label_encoder.inverse_transform(preds.astype(int))
421
+ class_preds = np.argmax(predictions, axis=1)
422
+ df_val[target] = label_encoder.inverse_transform(df_val[target].astype(int))
423
+ df_val["prediction"] = label_encoder.inverse_transform(class_preds)
424
+ df_val["pred_proba"] = [p.tolist() for p in predictions]
425
+ df_val = expand_proba_column(df_val, label_encoder.classes_)
509
426
  else:
510
- y_validate = df_val[target].values
511
-
512
- # Save predictions to S3 (just the target, prediction, and '_probability' columns)
513
- df_val["prediction"] = preds
514
- output_columns = [target, "prediction"]
515
- output_columns += [col for col in df_val.columns if col.endswith("_probability")]
516
- wr.s3.to_csv(
517
- df_val[output_columns],
518
- path=f"{model_metrics_s3_path}/validation_predictions.csv",
519
- index=False,
520
- )
427
+ df_val["prediction"] = predictions.flatten()
428
+
429
+ # -------------------------------------------------------------------------
430
+ # Compute and print metrics
431
+ # -------------------------------------------------------------------------
432
+ y_true = df_val[target].values
433
+ y_pred = df_val["prediction"].values
521
434
 
522
- # Report Performance Metrics
523
435
  if model_type == "classifier":
524
- # Get the label names and their integer mapping
525
- label_names = label_encoder.classes_
526
-
527
- # Calculate various model performance metrics
528
- scores = precision_recall_fscore_support(y_validate, preds, average=None, labels=label_names)
529
-
530
- # Put the scores into a dataframe
531
- score_df = pd.DataFrame(
532
- {
533
- target: label_names,
534
- "precision": scores[0],
535
- "recall": scores[1],
536
- "fscore": scores[2],
537
- "support": scores[3],
538
- }
436
+ score_df = compute_classification_metrics(y_true, y_pred, label_encoder.classes_, target)
437
+ print_classification_metrics(score_df, target, label_encoder.classes_)
438
+ print_confusion_matrix(y_true, y_pred, label_encoder.classes_)
439
+ else:
440
+ metrics = compute_regression_metrics(y_true, y_pred)
441
+ print_regression_metrics(metrics)
442
+
443
+ # Compute ensemble prediction_std
444
+ if n_folds > 1:
445
+ # Re-run inference with all models to get std
446
+ x_cont, x_cat, _, _, _ = prepare_data(
447
+ df_val, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
448
+ )
449
+ all_preds = [predict(m, x_cont, x_cat).flatten() for m in ensemble_models]
450
+ df_val["prediction_std"] = np.std(np.stack(all_preds), axis=0)
451
+ print(f"Ensemble std - mean: {df_val['prediction_std'].mean():.4f}, max: {df_val['prediction_std'].max():.4f}")
452
+ else:
453
+ df_val["prediction_std"] = 0.0
454
+
455
+ # Train UQ models for uncertainty quantification
456
+ print("\n" + "=" * 50)
457
+ print("Training UQ Models")
458
+ print("=" * 50)
459
+ uq_models, uq_metadata = train_uq_models(
460
+ all_df[features], all_df[target], df_val[features], y_true
539
461
  )
462
+ df_val = predict_intervals(df_val, df_val[features], uq_models, uq_metadata)
463
+ df_val = compute_confidence(df_val, uq_metadata["median_interval_width"])
464
+
465
+ # -------------------------------------------------------------------------
466
+ # Save validation predictions to S3
467
+ # -------------------------------------------------------------------------
468
+ output_columns = []
469
+ if id_column in df_val.columns:
470
+ output_columns.append(id_column)
471
+ output_columns += [target, "prediction"]
472
+
473
+ if model_type != "classifier":
474
+ output_columns.append("prediction_std")
475
+ output_columns += [c for c in df_val.columns if c.startswith("q_") or c == "confidence"]
476
+
477
+ output_columns += [c for c in df_val.columns if c.endswith("_proba")]
478
+
479
+ wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
480
+
481
+ # -------------------------------------------------------------------------
482
+ # Save model artifacts
483
+ # -------------------------------------------------------------------------
484
+ model_config = {
485
+ "n_continuous": len(continuous_cols),
486
+ "categorical_cardinalities": categorical_cardinalities,
487
+ "hidden_layers": hidden_layers,
488
+ "n_outputs": n_outputs,
489
+ "task": task,
490
+ "dropout": hyperparameters["dropout"],
491
+ "use_batch_norm": hyperparameters["use_batch_norm"],
492
+ }
540
493
 
541
- # We need to get creative with the Classification Metrics
542
- metrics = ["precision", "recall", "fscore", "support"]
543
- for t in label_names:
544
- for m in metrics:
545
- value = score_df.loc[score_df[target] == t, m].iloc[0]
546
- print(f"Metrics:{t}:{m} {value}")
494
+ for idx, m in enumerate(ensemble_models):
495
+ save_model(m, os.path.join(args.model_dir, f"model_{idx}"), model_config)
496
+ print(f"Saved {len(ensemble_models)} model(s)")
547
497
 
548
- # Compute and output the confusion matrix
549
- conf_mtx = confusion_matrix(y_validate, preds, labels=label_names)
550
- for i, row_name in enumerate(label_names):
551
- for j, col_name in enumerate(label_names):
552
- value = conf_mtx[i, j]
553
- print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
498
+ joblib.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds}, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
499
+
500
+ with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as f:
501
+ json.dump(orig_features, f)
502
+
503
+ with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as f:
504
+ json.dump(category_mappings, f)
505
+
506
+ with open(os.path.join(args.model_dir, "feature_metadata.json"), "w") as f:
507
+ json.dump({"continuous_cols": continuous_cols, "categorical_cols": categorical_cols}, f)
508
+
509
+ with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
510
+ json.dump(hyperparameters, f, indent=2)
511
+
512
+ scaler.save(os.path.join(args.model_dir, "scaler.joblib"))
554
513
 
555
- else:
556
- # Calculate various model performance metrics (regression)
557
- rmse = root_mean_squared_error(y_validate, preds)
558
- mae = mean_absolute_error(y_validate, preds)
559
- r2 = r2_score(y_validate, preds)
560
- print(f"RMSE: {rmse:.3f}")
561
- print(f"MAE: {mae:.3f}")
562
- print(f"R2: {r2:.3f}")
563
- print(f"NumRows: {len(df_val)}")
564
-
565
- # Save the model to the standard place/name
566
- tabular_model.save_model(os.path.join(args.model_dir, "tabular_model"))
567
514
  if label_encoder:
568
515
  joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
569
516
 
570
- # Save the features (this will validate input during predictions)
571
- with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
572
- json.dump(orig_features, fp) # We save the original features, not the decompressed ones
517
+ if model_type != "classifier":
518
+ save_uq_models(uq_models, uq_metadata, args.model_dir)
573
519
 
574
- # Save the category mappings
575
- with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
576
- json.dump(category_mappings, fp)
520
+ print(f"\nModel training complete! Artifacts saved to {args.model_dir}")