workbench-0.8.213-py3-none-any.whl → workbench-0.8.219-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/algorithms/sql/outliers.py +3 -3
  9. workbench/api/__init__.py +3 -0
  10. workbench/api/endpoint.py +10 -5
  11. workbench/api/feature_set.py +76 -6
  12. workbench/api/meta_model.py +289 -0
  13. workbench/api/model.py +43 -4
  14. workbench/core/artifacts/endpoint_core.py +65 -117
  15. workbench/core/artifacts/feature_set_core.py +3 -3
  16. workbench/core/artifacts/model_core.py +6 -4
  17. workbench/core/pipelines/pipeline_executor.py +1 -1
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  19. workbench/model_script_utils/model_script_utils.py +15 -11
  20. workbench/model_script_utils/pytorch_utils.py +11 -1
  21. workbench/model_scripts/chemprop/chemprop.template +147 -71
  22. workbench/model_scripts/chemprop/generated_model_script.py +151 -75
  23. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  24. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  25. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  27. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  28. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  29. workbench/model_scripts/meta_model/meta_model.template +209 -0
  30. workbench/model_scripts/pytorch_model/generated_model_script.py +45 -27
  31. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  32. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  33. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  34. workbench/model_scripts/script_generation.py +4 -0
  35. workbench/model_scripts/xgb_model/generated_model_script.py +167 -156
  36. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  37. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  38. workbench/repl/workbench_shell.py +0 -5
  39. workbench/scripts/endpoint_test.py +2 -2
  40. workbench/scripts/meta_model_sim.py +35 -0
  41. workbench/utils/chem_utils/fingerprints.py +87 -46
  42. workbench/utils/chemprop_utils.py +23 -5
  43. workbench/utils/meta_model_simulator.py +499 -0
  44. workbench/utils/metrics_utils.py +94 -10
  45. workbench/utils/model_utils.py +91 -9
  46. workbench/utils/pytorch_utils.py +1 -1
  47. workbench/utils/shap_utils.py +1 -55
  48. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  49. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/METADATA +2 -1
  50. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/RECORD +54 -50
  51. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/entry_points.txt +1 -0
  52. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  53. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  54. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  55. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  56. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/WHEEL +0 -0
  57. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/licenses/LICENSE +0 -0
  58. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/top_level.txt +0 -0
@@ -5,51 +5,36 @@
 # - Out-of-fold predictions for validation metrics
 # - Categorical feature embedding via TabularMLP
 # - Compressed feature decompression
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (sklearn, awswrangler) are deferred to training time.

-import argparse
 import json
 import os

-import awswrangler as wr
 import joblib
 import numpy as np
 import pandas as pd
 import torch
-from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
-from sklearn.preprocessing import LabelEncoder
-
-# Enable Tensor Core optimization for GPUs that support it
-torch.set_float32_matmul_precision("medium")

 from model_script_utils import (
-    check_dataframe,
-    compute_classification_metrics,
-    compute_regression_metrics,
     convert_categorical_types,
     decompress_features,
     expand_proba_column,
     input_fn,
     match_features_case_insensitive,
     output_fn,
-    print_classification_metrics,
-    print_confusion_matrix,
-    print_regression_metrics,
 )
 from pytorch_utils import (
     FeatureScaler,
-    create_model,
     load_model,
     predict,
     prepare_data,
-    save_model,
-    train_model,
 )
 from uq_harness import (
     compute_confidence,
     load_uq_models,
     predict_intervals,
-    save_uq_models,
-    train_uq_models,
 )

 # =============================================================================
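Note: this hunk (and its XGBoost twin further down) restructures the script so that module scope only pays for inference-time imports, with training-only dependencies loaded under the __main__ guard. A minimal sketch of the pattern, with illustrative names that are not from this package:

import json  # lightweight: needed on the inference path

def handler(payload: str) -> dict:
    # Serverless inference imports nothing heavy, so cold starts stay fast
    return {"echo": json.loads(payload)}

if __name__ == "__main__":
    # Heavy, training-only imports run here and never on endpoint startup
    import awswrangler as wr  # noqa: F401
    from sklearn.model_selection import train_test_split  # noqa: F401
    print("training path: heavy dependencies loaded")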
@@ -59,13 +44,15 @@ DEFAULT_HYPERPARAMETERS = {
     # Training parameters
     "n_folds": 5,
     "max_epochs": 200,
-    "early_stopping_patience": 20,
+    "early_stopping_patience": 30,
     "batch_size": 128,
-    # Model architecture
-    "layers": "256-128-64",
+    # Model architecture (larger capacity - ensemble provides regularization)
+    "layers": "512-256-128",
     "learning_rate": 1e-3,
-    "dropout": 0.1,
+    "dropout": 0.05,
     "use_batch_norm": True,
+    # Loss function for regression (L1Loss=MAE, MSELoss=MSE, HuberLoss, SmoothL1Loss)
+    "loss": "L1Loss",
     # Random seed
     "seed": 42,
 }
@@ -86,7 +73,7 @@ TEMPLATE_PARAMS = {
 # Model Loading (for SageMaker inference)
 # =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load TabularMLP ensemble from the specified directory."""
+    """Load PyTorch TabularMLP ensemble from the specified directory."""
     # Load ensemble metadata
     metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
     if os.path.exists(metadata_path):
@@ -129,7 +116,7 @@ def model_fn(model_dir: str) -> dict:
 # Inference (for SageMaker inference)
 # =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make predictions with TabularMLP ensemble."""
+    """Make predictions with PyTorch TabularMLP ensemble."""
     model_type = TEMPLATE_PARAMS["model_type"]
     compressed_features = TEMPLATE_PARAMS["compressed_features"]
     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
@@ -233,6 +220,36 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
 # Training
 # =============================================================================
 if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+
+    import awswrangler as wr
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    # Enable Tensor Core optimization for GPUs that support it
+    torch.set_float32_matmul_precision("medium")
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+    from pytorch_utils import (
+        create_model,
+        save_model,
+        train_model,
+    )
+    from uq_harness import (
+        save_uq_models,
+        train_uq_models,
+    )
+
     # -------------------------------------------------------------------------
     # Setup: Parse arguments and load data
     # -------------------------------------------------------------------------
@@ -377,6 +394,7 @@ if __name__ == "__main__":
377
394
  patience=hyperparameters["early_stopping_patience"],
378
395
  batch_size=hyperparameters["batch_size"],
379
396
  learning_rate=hyperparameters["learning_rate"],
397
+ loss=hyperparameters.get("loss", "L1Loss"),
380
398
  device=device,
381
399
  )
382
400
  ensemble_models.append(model)
@@ -245,6 +245,7 @@ def train_model(
     patience: int = 20,
     batch_size: int = 128,
     learning_rate: float = 1e-3,
+    loss: str = "L1Loss",
     device: str = "cpu",
 ) -> tuple[TabularMLP, dict]:
     """Train the model with early stopping.
@@ -272,7 +273,16 @@
     if task == "classification":
         criterion = nn.CrossEntropyLoss()
     else:
-        criterion = nn.MSELoss()
+        # Map loss name to PyTorch loss class
+        loss_map = {
+            "L1Loss": nn.L1Loss,
+            "MSELoss": nn.MSELoss,
+            "HuberLoss": nn.HuberLoss,
+            "SmoothL1Loss": nn.SmoothL1Loss,
+        }
+        if loss not in loss_map:
+            raise ValueError(f"Unknown loss '{loss}'. Supported: {list(loss_map.keys())}")
+        criterion = loss_map[loss]()

     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

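Note: with this mapping the regression criterion becomes selectable by name via the "loss" hyperparameter. A quick check of the dispatch (toy tensors; nn.HuberLoss keeps its default delta=1.0):

import torch
import torch.nn as nn

loss_map = {"L1Loss": nn.L1Loss, "MSELoss": nn.MSELoss,
            "HuberLoss": nn.HuberLoss, "SmoothL1Loss": nn.SmoothL1Loss}
criterion = loss_map["HuberLoss"]()
print(criterion(torch.tensor([2.5]), torch.tensor([2.0])))  # tensor(0.1250) = 0.5 * 0.5**2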
@@ -100,6 +100,7 @@ def generate_model_script(template_params: dict) -> str:
             - model_metrics_s3_path (str): The S3 path to store the model metrics
             - train_all_data (bool): Whether to train on all (100%) of the data
             - hyperparameters (dict, optional): Hyperparameters for the model (default: None)
+            - child_endpoints (list[str], optional): For META models, list of child endpoint names

     Returns:
         str: The name of the generated model script
@@ -116,6 +117,9 @@ def generate_model_script(template_params: dict) -> str:
     elif template_params["model_framework"] == ModelFramework.CHEMPROP:
         template_name = "chemprop.template"
         model_script_dir = "chemprop"
+    elif template_params["model_framework"] == ModelFramework.META:
+        template_name = "meta_model.template"
+        model_script_dir = "meta_model"
     elif template_params["model_type"] in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.CLASSIFIER]:
         template_name = "xgb_model.template"
         model_script_dir = "xgb_model"
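Note: with this branch, generating a META model script would look roughly like the sketch below. Only ModelFramework.META, child_endpoints, and the template/directory names are confirmed by this diff; the endpoint names are hypothetical and the remaining required keys are elided.

params = {
    "model_framework": ModelFramework.META,  # routes to meta_model.template
    "child_endpoints": ["child-endpoint-1", "child-endpoint-2"],  # hypothetical names
    "model_metrics_s3_path": "s3://my-bucket/models/my-meta-model/training",
    "train_all_data": True,
    "hyperparameters": None,
    # ...plus the other documented keys (target, features, etc.)
}
script_path = generate_model_script(params)  # fills in meta_model.template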
@@ -7,39 +7,30 @@
 # - Sample weights support
 # - Categorical feature handling
 # - Compressed feature decompression
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (sklearn, awswrangler) are deferred to training time.

-import argparse
 import json
 import os

-import awswrangler as wr
 import joblib
 import numpy as np
 import pandas as pd
 import xgboost as xgb
-from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
-from sklearn.preprocessing import LabelEncoder

 from model_script_utils import (
-    check_dataframe,
-    compute_classification_metrics,
-    compute_regression_metrics,
     convert_categorical_types,
     decompress_features,
     expand_proba_column,
     input_fn,
     match_features_case_insensitive,
     output_fn,
-    print_classification_metrics,
-    print_confusion_matrix,
-    print_regression_metrics,
 )
 from uq_harness import (
     compute_confidence,
     load_uq_models,
     predict_intervals,
-    save_uq_models,
-    train_uq_models,
 )

 # =============================================================================
@@ -49,41 +40,173 @@ DEFAULT_HYPERPARAMETERS = {
     # Training parameters
     "n_folds": 5,  # Number of CV folds (1 = single train/val split)
     # Core tree parameters
-    "n_estimators": 200,
-    "max_depth": 6,
+    "n_estimators": 300,
+    "max_depth": 7,
     "learning_rate": 0.05,
-    # Sampling parameters
-    "subsample": 0.7,
-    "colsample_bytree": 0.6,
-    "colsample_bylevel": 0.8,
-    # Regularization
-    "min_child_weight": 5,
-    "gamma": 0.2,
-    "reg_alpha": 0.5,
-    "reg_lambda": 2.0,
+    # Sampling parameters (less aggressive - ensemble provides regularization)
+    "subsample": 0.8,
+    "colsample_bytree": 0.8,
+    # Regularization (lighter - ensemble averaging reduces overfitting)
+    "min_child_weight": 3,
+    "gamma": 0.1,
+    "reg_alpha": 0.1,
+    "reg_lambda": 1.0,
     # Random seed
-    "random_state": 42,
+    "seed": 42,
 }

 # Workbench-specific parameters (not passed to XGBoost)
 WORKBENCH_PARAMS = {"n_folds"}

+# Regression-only parameters (filtered out for classifiers)
+REGRESSION_ONLY_PARAMS = {"objective"}
+
 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
     "target": "udm_asy_res_efflux_ratio",
-    "features": ['smr_vsa4', 'tpsa', 'numhdonors', 'nhohcount', 'nbase', 'vsa_estate3', 'fr_guanido', 'mollogp', 'peoe_vsa8', 'peoe_vsa1', 'fr_imine', 'vsa_estate2', 'estate_vsa10', 'asphericity', 'xc_3dv', 'smr_vsa3', 'charge_centroid_distance', 'c3sp3', 'nitrogen_span', 'estate_vsa2', 'minpartialcharge', 'hba_hbd_ratio', 'slogp_vsa1', 'axp_7d', 'nocount', 'vsa_estate4', 'vsa_estate6', 'estate_vsa4', 'xc_4dv', 'xc_4d', 'num_s_centers', 'vsa_estate9', 'chi2v', 'axp_5d', 'mi', 'mse', 'bcut2d_mrhi', 'smr_vsa6', 'hallkieralpha', 'balabanj', 'amphiphilic_moment', 'type_ii_pattern_count', 'minabsestateindex', 'bcut2d_mwlow', 'axp_0dv', 'slogp_vsa5', 'axp_2d', 'axp_1dv', 'xch_5d', 'peoe_vsa10', 'molecular_asymmetry', 'kappa3', 'estate_vsa3', 'sse', 'bcut2d_logphi', 'fr_imidazole', 'molecular_volume_3d', 'bertzct', 'maxestateindex', 'aromatic_interaction_score', 'axp_3d', 'radius_of_gyration', 'vsa_estate7', 'si', 'axp_5dv', 'molecular_axis_length', 'estate_vsa6', 'fpdensitymorgan1', 'axp_6d', 'estate_vsa9', 'fpdensitymorgan2', 'xp_0dv', 'xp_6dv', 'molmr', 'qed', 'estate_vsa8', 'peoe_vsa9', 'xch_6dv', 'xp_7d', 'slogp_vsa2', 'xp_5dv', 'bcut2d_chghi', 'xch_6d', 'chi0n', 'slogp_vsa3', 'chi1v', 'chi3v', 'bcut2d_chglo', 'axp_1d', 'mp', 'num_defined_stereocenters', 'xp_3dv', 'bcut2d_mrlow', 'fr_al_oh', 'peoe_vsa7', 'chi2n', 'axp_6dv', 'axp_2dv', 'chi4n', 'xc_3d', 'axp_7dv', 'vsa_estate8', 'xch_7d', 'maxpartialcharge', 'chi1n', 'peoe_vsa2', 'axp_3dv', 'bcut2d_logplow', 'mv', 'xpc_5dv', 'kappa2', 'vsa_estate5', 'xp_5d', 'mm', 'maxabspartialcharge', 'axp_4dv', 'maxabsestateindex', 'axp_4d', 'xch_4dv', 'xp_2dv', 'heavyatommolwt', 'numatomstereocenters', 'xp_7dv', 'numsaturatedheterocycles', 'xp_3d', 'kappa1', 'mz', 'axp_0d', 'chi1', 'xch_4d', 'smr_vsa1', 'xp_2d', 'estate_vsa5', 'phi', 'fr_ether', 'xc_5d', 'c1sp3', 'estate_vsa7', 'estate_vsa1', 'vsa_estate1', 'slogp_vsa4', 'avgipc', 'smr_vsa10', 'numvalenceelectrons', 'xc_5dv', 'peoe_vsa12', 'peoe_vsa6', 'xpc_5d', 'xpc_6d', 'minestateindex', 'chi3n', 'smr_vsa5', 'xp_4d', 'numheteroatoms', 'fpdensitymorgan3', 'xpc_4d', 'sps', 'xp_1d', 'sv', 'fr_ar_n', 'slogp_vsa10', 'c2sp3', 'xpc_4dv', 'chi0v', 'xpc_6dv', 'xp_1dv', 'vsa_estate10', 'sare', 'c2sp2', 'mpe', 'xch_7dv', 'chi4v', 'type_i_pattern_count', 'sp', 'slogp_vsa8', 'amide_count', 'num_stereocenters', 'num_r_centers', 'tertiary_amine_count', 'spe', 'xp_4dv', 'numsaturatedrings', 'mare', 'numhacceptors', 'chi0', 'fractioncsp3', 'fr_nh0', 'xch_5dv', 'fr_aniline', 'smr_vsa7', 'labuteasa', 'c3sp2', 'xp_0d', 'xp_6d', 'peoe_vsa11', 'fr_ar_nh', 'molwt', 'intramolecular_hbond_potential', 'peoe_vsa3', 'fr_nhpyrrole', 'numaliphaticrings', 'hybratio', 'smr_vsa9', 'peoe_vsa13', 'bcut2d_mwhi', 'c1sp2', 'slogp_vsa11', 'numrotatablebonds', 'numaliphaticcarbocycles', 'slogp_vsa6', 'peoe_vsa4', 'numunspecifiedatomstereocenters', 'xc_6d', 'xc_6dv', 'num_unspecified_stereocenters', 'sz', 'minabspartialcharge', 'fcsp3', 'c1sp1', 'fr_piperzine', 'numaliphaticheterocycles', 'numamidebonds', 'fr_benzene', 'numaromaticheterocycles', 'sm', 'fr_priamide', 'fr_piperdine', 'fr_methoxy', 'c4sp3', 'fr_c_o_nocoo', 'exactmolwt', 'stereo_complexity', 'fr_hoccn', 'numaromaticcarbocycles', 'fr_nh2', 'numheterocycles', 'fr_morpholine', 'fr_ketone', 'fr_nh1', 'frac_defined_stereo', 'fr_aryl_methyl', 'fr_alkyl_halide', 'fr_phenol', 'fr_al_oh_notert', 'fr_ar_oh', 'fr_pyridine', 'fr_amide', 'slogp_vsa7', 'fr_halogen', 'numsaturatedcarbocycles', 'slogp_vsa12', 'fr_ndealkylation1', 'xch_3d', 'fr_bicyclic', 'naromatom', 'narombond'],
+    "features": ['fingerprint'],
     "id_column": "udm_mol_bat_id",
-    "compressed_features": [],
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-test-log/training",
-    "hyperparameters": {'target_transform': 'log'},
+    "compressed_features": ['fingerprint'],
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-fp/training",
+    "hyperparameters": {},
 }


+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
+def model_fn(model_dir: str) -> dict:
+    """Load XGBoost ensemble from the specified directory."""
+    # Load ensemble metadata
+    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
+    if os.path.exists(metadata_path):
+        with open(metadata_path) as f:
+            metadata = json.load(f)
+        n_ensemble = metadata["n_ensemble"]
+    else:
+        n_ensemble = 1  # Legacy single model
+
+    # Load ensemble models
+    ensemble_models = []
+    for i in range(n_ensemble):
+        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
+        if not os.path.exists(model_path):
+            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
+        ensemble_models.append(joblib.load(model_path))
+
+    print(f"Loaded {len(ensemble_models)} model(s)")
+
+    # Load label encoder (classifier only)
+    label_encoder = None
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)
+
+    # Load category mappings
+    category_mappings = {}
+    category_path = os.path.join(model_dir, "category_mappings.json")
+    if os.path.exists(category_path):
+        with open(category_path) as f:
+            category_mappings = json.load(f)
+
+    # Load UQ models (regression only)
+    uq_models, uq_metadata = None, None
+    uq_path = os.path.join(model_dir, "uq_metadata.json")
+    if os.path.exists(uq_path):
+        uq_models, uq_metadata = load_uq_models(model_dir)
+
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": n_ensemble,
+        "label_encoder": label_encoder,
+        "category_mappings": category_mappings,
+        "uq_models": uq_models,
+        "uq_metadata": uq_metadata,
+    }
+
+
+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
+def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+    """Make predictions with XGBoost ensemble."""
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+    with open(os.path.join(model_dir, "feature_columns.json")) as f:
+        features = json.load(f)
+    print(f"Model Features: {features}")
+
+    # Extract model components
+    ensemble_models = model_dict["ensemble_models"]
+    label_encoder = model_dict.get("label_encoder")
+    category_mappings = model_dict.get("category_mappings", {})
+    uq_models = model_dict.get("uq_models")
+    uq_metadata = model_dict.get("uq_metadata")
+    compressed_features = TEMPLATE_PARAMS["compressed_features"]
+
+    # Prepare features
+    matched_df = match_features_case_insensitive(df, features)
+    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
+
+    if compressed_features:
+        print("Decompressing features for prediction...")
+        matched_df, features = decompress_features(matched_df, features, compressed_features)
+
+    X = matched_df[features]
+
+    # Collect ensemble predictions
+    all_preds = [m.predict(X) for m in ensemble_models]
+    ensemble_preds = np.stack(all_preds, axis=0)
+
+    if label_encoder is not None:
+        # Classification: average probabilities, then argmax
+        all_probs = [m.predict_proba(X) for m in ensemble_models]
+        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
+        class_preds = np.argmax(avg_probs, axis=1)
+
+        df["prediction"] = label_encoder.inverse_transform(class_preds)
+        df["pred_proba"] = [p.tolist() for p in avg_probs]
+        df = expand_proba_column(df, label_encoder.classes_)
+    else:
+        # Regression: average predictions
+        df["prediction"] = np.mean(ensemble_preds, axis=0)
+        df["prediction_std"] = np.std(ensemble_preds, axis=0)
+
+    # Add UQ intervals if available
+    if uq_models and uq_metadata:
+        df = predict_intervals(df, X, uq_models, uq_metadata)
+        df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
+    return df
+
+
 # =============================================================================
 # Training
 # =============================================================================
 if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+
+    import awswrangler as wr
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+    from uq_harness import (
+        save_uq_models,
+        train_uq_models,
+    )
+
     # -------------------------------------------------------------------------
     # Setup: Parse arguments and load data
     # -------------------------------------------------------------------------
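Note: the regression branch of the inlined predict_fn reduces the ensemble members with a plain mean and standard deviation. In toy numbers (illustrative values only):

import numpy as np

# Three hypothetical ensemble members, each predicting two rows
ensemble_preds = np.stack([[1.0, 2.0], [1.2, 1.8], [0.8, 2.2]], axis=0)
print(np.mean(ensemble_preds, axis=0))  # [1. 2.]          -> df["prediction"]
print(np.std(ensemble_preds, axis=0))   # ≈ [0.1633 0.1633] -> df["prediction_std"]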
@@ -123,7 +246,7 @@ if __name__ == "__main__":
123
246
  all_df, features = decompress_features(all_df, features, compressed_features)
124
247
 
125
248
  # -------------------------------------------------------------------------
126
- # Classification setup: Encode target labels
249
+ # Classification setup
127
250
  # -------------------------------------------------------------------------
128
251
  label_encoder = None
129
252
  if model_type == "classifier":
@@ -136,6 +259,18 @@
     # -------------------------------------------------------------------------
     n_folds = hyperparameters["n_folds"]
     xgb_params = {k: v for k, v in hyperparameters.items() if k not in WORKBENCH_PARAMS}
+
+    # Map 'seed' to 'random_state' for XGBoost
+    if "seed" in xgb_params:
+        xgb_params["random_state"] = xgb_params.pop("seed")
+
+    # Handle objective: filter regression-only params for classifiers, set default for regressors
+    if model_type == "classifier":
+        xgb_params = {k: v for k, v in xgb_params.items() if k not in REGRESSION_ONLY_PARAMS}
+    else:
+        # Default to MAE (reg:absoluteerror) for regression if not specified
+        xgb_params.setdefault("objective", "reg:absoluteerror")
+
     print(f"XGBoost params: {xgb_params}")

     if n_folds == 1:
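Note: tracing the normalization above with the new defaults (a sketch; only the three relevant keys shown):

hyperparameters = {"n_folds": 5, "n_estimators": 300, "seed": 42}
xgb_params = {k: v for k, v in hyperparameters.items() if k not in {"n_folds"}}
xgb_params["random_state"] = xgb_params.pop("seed")      # XGBoost's name for the seed
xgb_params.setdefault("objective", "reg:absoluteerror")  # MAE default, regressors only
print(xgb_params)
# {'n_estimators': 300, 'random_state': 42, 'objective': 'reg:absoluteerror'}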
@@ -285,12 +420,10 @@
     # -------------------------------------------------------------------------
     # Save model artifacts
     # -------------------------------------------------------------------------
-    # Ensemble models
-    for idx, ens_model in enumerate(ensemble_models):
-        joblib.dump(ens_model, os.path.join(args.model_dir, f"xgb_model_{idx}.joblib"))
-    print(f"Saved {len(ensemble_models)} XGBoost model(s)")
+    for idx, m in enumerate(ensemble_models):
+        joblib.dump(m, os.path.join(args.model_dir, f"xgb_model_{idx}.joblib"))
+    print(f"Saved {len(ensemble_models)} model(s)")

-    # Metadata files
     with open(os.path.join(args.model_dir, "ensemble_metadata.json"), "w") as f:
         json.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds}, f)

@@ -310,125 +443,3 @@
         save_uq_models(uq_models, uq_metadata, args.model_dir)

     print(f"\nModel training complete! Artifacts saved to {args.model_dir}")
-
-
-# =============================================================================
-# Model Loading (for SageMaker inference)
-# =============================================================================
-def model_fn(model_dir: str) -> dict:
-    """Load XGBoost ensemble and associated artifacts.
-
-    Args:
-        model_dir: Directory containing model artifacts
-
-    Returns:
-        Dictionary with ensemble_models, label_encoder, category_mappings, uq_models, etc.
-    """
-    # Load ensemble metadata
-    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
-    if os.path.exists(metadata_path):
-        with open(metadata_path) as f:
-            metadata = json.load(f)
-        n_ensemble = metadata["n_ensemble"]
-    else:
-        n_ensemble = 1  # Legacy single model
-
-    # Load ensemble models
-    ensemble_models = []
-    for i in range(n_ensemble):
-        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
-        if not os.path.exists(model_path):
-            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
-        ensemble_models.append(joblib.load(model_path))
-
-    # Load label encoder (classifier only)
-    label_encoder = None
-    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
-    if os.path.exists(encoder_path):
-        label_encoder = joblib.load(encoder_path)
-
-    # Load category mappings
-    category_mappings = {}
-    category_path = os.path.join(model_dir, "category_mappings.json")
-    if os.path.exists(category_path):
-        with open(category_path) as f:
-            category_mappings = json.load(f)
-
-    # Load UQ models (regression only)
-    uq_models, uq_metadata = None, None
-    uq_path = os.path.join(model_dir, "uq_metadata.json")
-    if os.path.exists(uq_path):
-        uq_models, uq_metadata = load_uq_models(model_dir)
-
-    return {
-        "ensemble_models": ensemble_models,
-        "n_ensemble": n_ensemble,
-        "label_encoder": label_encoder,
-        "category_mappings": category_mappings,
-        "uq_models": uq_models,
-        "uq_metadata": uq_metadata,
-    }
-
-
-# =============================================================================
-# Inference (for SageMaker inference)
-# =============================================================================
-def predict_fn(df: pd.DataFrame, models: dict) -> pd.DataFrame:
-    """Make predictions with XGBoost ensemble.
-
-    Args:
-        df: Input DataFrame with features
-        models: Dictionary from model_fn containing ensemble and metadata
-
-    Returns:
-        DataFrame with predictions added
-    """
-    # Load feature columns
-    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-    with open(os.path.join(model_dir, "feature_columns.json")) as f:
-        features = json.load(f)
-    print(f"Model Features: {features}")
-
-    # Extract model components
-    ensemble_models = models["ensemble_models"]
-    label_encoder = models.get("label_encoder")
-    category_mappings = models.get("category_mappings", {})
-    uq_models = models.get("uq_models")
-    uq_metadata = models.get("uq_metadata")
-    compressed_features = TEMPLATE_PARAMS["compressed_features"]
-
-    # Prepare features
-    matched_df = match_features_case_insensitive(df, features)
-    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
-
-    if compressed_features:
-        print("Decompressing features for prediction...")
-        matched_df, features = decompress_features(matched_df, features, compressed_features)
-
-    X = matched_df[features]
-
-    # Collect ensemble predictions
-    all_preds = [m.predict(X) for m in ensemble_models]
-    ensemble_preds = np.stack(all_preds, axis=0)
-
-    if label_encoder is not None:
-        # Classification: average probabilities, then argmax
-        all_probs = [m.predict_proba(X) for m in ensemble_models]
-        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
-        class_preds = np.argmax(avg_probs, axis=1)
-
-        df["prediction"] = label_encoder.inverse_transform(class_preds)
-        df["pred_proba"] = [p.tolist() for p in avg_probs]
-        df = expand_proba_column(df, label_encoder.classes_)
-    else:
-        # Regression: average predictions
-        df["prediction"] = np.mean(ensemble_preds, axis=0)
-        df["prediction_std"] = np.std(ensemble_preds, axis=0)
-
-    # Add UQ intervals if available
-    if uq_models and uq_metadata:
-        df = predict_intervals(df, X, uq_models, uq_metadata)
-        df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
-
-    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
-    return df
@@ -148,12 +148,16 @@ def convert_categorical_types(
 def decompress_features(
     df: pd.DataFrame, features: list[str], compressed_features: list[str]
 ) -> tuple[pd.DataFrame, list[str]]:
-    """Decompress bitstring features into individual bit columns.
+    """Decompress compressed features (bitstrings or count vectors) into individual columns.
+
+    Supports two formats (auto-detected):
+    - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
+    - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)

     Args:
         df: The features DataFrame
         features: Full list of feature names
-        compressed_features: List of feature names to decompress (bitstrings)
+        compressed_features: List of feature names to decompress

     Returns:
         Tuple of (DataFrame with decompressed features, updated feature list)
@@ -178,18 +182,18 @@
         # Remove the feature from the list to avoid duplication
         decompressed_features.remove(feature)

-        # Handle all compressed features as bitstrings
-        bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-        prefix = feature[:3]
+        # Auto-detect format and parse: comma-separated counts or bitstring
+        sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
+        parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
+        feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)

-        # Create all new columns at once - avoids fragmentation
-        new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-        new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
+        # Create new columns with prefix from feature name
+        prefix = feature[:3]
+        new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
+        new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)

-        # Add to features list
+        # Update features list and dataframe
         decompressed_features.extend(new_col_names)
-
-        # Drop original column and concatenate new ones
         df = df.drop(columns=[feature])
         df = pd.concat([df, new_df], axis=1)
