workbench-0.8.213-py3-none-any.whl → workbench-0.8.217-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (50)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/api/__init__.py +3 -0
  9. workbench/api/endpoint.py +10 -5
  10. workbench/api/feature_set.py +76 -6
  11. workbench/api/meta_model.py +289 -0
  12. workbench/api/model.py +43 -4
  13. workbench/core/artifacts/endpoint_core.py +63 -115
  14. workbench/core/artifacts/feature_set_core.py +1 -1
  15. workbench/core/artifacts/model_core.py +6 -4
  16. workbench/core/pipelines/pipeline_executor.py +1 -1
  17. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  18. workbench/model_script_utils/pytorch_utils.py +11 -1
  19. workbench/model_scripts/chemprop/chemprop.template +145 -69
  20. workbench/model_scripts/chemprop/generated_model_script.py +147 -71
  21. workbench/model_scripts/custom_models/chem_info/fingerprints.py +7 -3
  22. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  23. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  24. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  25. workbench/model_scripts/custom_models/uq_models/meta_uq.template +6 -6
  26. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  27. workbench/model_scripts/meta_model/meta_model.template +209 -0
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +42 -24
  29. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  30. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  31. workbench/model_scripts/script_generation.py +4 -0
  32. workbench/model_scripts/xgb_model/generated_model_script.py +169 -158
  33. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  34. workbench/repl/workbench_shell.py +0 -5
  35. workbench/scripts/endpoint_test.py +2 -2
  36. workbench/utils/chem_utils/fingerprints.py +7 -3
  37. workbench/utils/chemprop_utils.py +23 -5
  38. workbench/utils/meta_model_simulator.py +471 -0
  39. workbench/utils/metrics_utils.py +94 -10
  40. workbench/utils/model_utils.py +91 -9
  41. workbench/utils/pytorch_utils.py +1 -1
  42. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  43. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/METADATA +2 -1
  44. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/RECORD +48 -43
  45. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  46. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  47. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/WHEEL +0 -0
  48. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/entry_points.txt +0 -0
  49. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/licenses/LICENSE +0 -0
  50. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/top_level.txt +0 -0
@@ -7,39 +7,30 @@
 # - Sample weights support
 # - Categorical feature handling
 # - Compressed feature decompression
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (sklearn, awswrangler) are deferred to training time.
 
-import argparse
 import json
 import os
 
-import awswrangler as wr
 import joblib
 import numpy as np
 import pandas as pd
 import xgboost as xgb
-from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
-from sklearn.preprocessing import LabelEncoder
 
 from model_script_utils import (
-    check_dataframe,
-    compute_classification_metrics,
-    compute_regression_metrics,
     convert_categorical_types,
     decompress_features,
     expand_proba_column,
     input_fn,
     match_features_case_insensitive,
     output_fn,
-    print_classification_metrics,
-    print_confusion_matrix,
-    print_regression_metrics,
 )
 from uq_harness import (
     compute_confidence,
     load_uq_models,
     predict_intervals,
-    save_uq_models,
-    train_uq_models,
 )
 
 # =============================================================================
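The NOTE above describes the headline refactor in this hunk: only lightweight imports stay at module level, so a serverless endpoint that merely imports the script and calls the inference handlers never pays for sklearn or awswrangler. A minimal sketch of the pattern, with illustrative file and metadata names not taken from the package:

import json  # light module-level import: the only cost an endpoint cold start pays

def model_fn(model_dir):
    # the inference path touches nothing beyond the module-level imports
    with open(f"{model_dir}/metadata.json") as f:
        return json.load(f)

if __name__ == "__main__":
    # heavy, training-only dependencies are imported when the script runs as a job
    from sklearn.model_selection import KFold
    print(KFold(n_splits=5))

Since SageMaker inference containers import the script as a module and call the handler functions, the __main__ block never executes there and the heavy imports are skipped entirely.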
@@ -49,41 +40,173 @@ DEFAULT_HYPERPARAMETERS = {
     # Training parameters
     "n_folds": 5,  # Number of CV folds (1 = single train/val split)
     # Core tree parameters
-    "n_estimators": 200,
-    "max_depth": 6,
+    "n_estimators": 300,
+    "max_depth": 7,
     "learning_rate": 0.05,
-    # Sampling parameters
-    "subsample": 0.7,
-    "colsample_bytree": 0.6,
-    "colsample_bylevel": 0.8,
-    # Regularization
-    "min_child_weight": 5,
-    "gamma": 0.2,
-    "reg_alpha": 0.5,
-    "reg_lambda": 2.0,
+    # Sampling parameters (less aggressive - ensemble provides regularization)
+    "subsample": 0.8,
+    "colsample_bytree": 0.8,
+    # Regularization (lighter - ensemble averaging reduces overfitting)
+    "min_child_weight": 3,
+    "gamma": 0.1,
+    "reg_alpha": 0.1,
+    "reg_lambda": 1.0,
     # Random seed
-    "random_state": 42,
+    "seed": 42,
 }
 
 # Workbench-specific parameters (not passed to XGBoost)
 WORKBENCH_PARAMS = {"n_folds"}
 
+# Regression-only parameters (filtered out for classifiers)
+REGRESSION_ONLY_PARAMS = {"objective"}
+
 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
-    "model_type": "uq_regressor",
-    "target": "udm_asy_res_efflux_ratio",
- "features": ['smr_vsa4', 'tpsa', 'numhdonors', 'nhohcount', 'nbase', 'vsa_estate3', 'fr_guanido', 'mollogp', 'peoe_vsa8', 'peoe_vsa1', 'fr_imine', 'vsa_estate2', 'estate_vsa10', 'asphericity', 'xc_3dv', 'smr_vsa3', 'charge_centroid_distance', 'c3sp3', 'nitrogen_span', 'estate_vsa2', 'minpartialcharge', 'hba_hbd_ratio', 'slogp_vsa1', 'axp_7d', 'nocount', 'vsa_estate4', 'vsa_estate6', 'estate_vsa4', 'xc_4dv', 'xc_4d', 'num_s_centers', 'vsa_estate9', 'chi2v', 'axp_5d', 'mi', 'mse', 'bcut2d_mrhi', 'smr_vsa6', 'hallkieralpha', 'balabanj', 'amphiphilic_moment', 'type_ii_pattern_count', 'minabsestateindex', 'bcut2d_mwlow', 'axp_0dv', 'slogp_vsa5', 'axp_2d', 'axp_1dv', 'xch_5d', 'peoe_vsa10', 'molecular_asymmetry', 'kappa3', 'estate_vsa3', 'sse', 'bcut2d_logphi', 'fr_imidazole', 'molecular_volume_3d', 'bertzct', 'maxestateindex', 'aromatic_interaction_score', 'axp_3d', 'radius_of_gyration', 'vsa_estate7', 'si', 'axp_5dv', 'molecular_axis_length', 'estate_vsa6', 'fpdensitymorgan1', 'axp_6d', 'estate_vsa9', 'fpdensitymorgan2', 'xp_0dv', 'xp_6dv', 'molmr', 'qed', 'estate_vsa8', 'peoe_vsa9', 'xch_6dv', 'xp_7d', 'slogp_vsa2', 'xp_5dv', 'bcut2d_chghi', 'xch_6d', 'chi0n', 'slogp_vsa3', 'chi1v', 'chi3v', 'bcut2d_chglo', 'axp_1d', 'mp', 'num_defined_stereocenters', 'xp_3dv', 'bcut2d_mrlow', 'fr_al_oh', 'peoe_vsa7', 'chi2n', 'axp_6dv', 'axp_2dv', 'chi4n', 'xc_3d', 'axp_7dv', 'vsa_estate8', 'xch_7d', 'maxpartialcharge', 'chi1n', 'peoe_vsa2', 'axp_3dv', 'bcut2d_logplow', 'mv', 'xpc_5dv', 'kappa2', 'vsa_estate5', 'xp_5d', 'mm', 'maxabspartialcharge', 'axp_4dv', 'maxabsestateindex', 'axp_4d', 'xch_4dv', 'xp_2dv', 'heavyatommolwt', 'numatomstereocenters', 'xp_7dv', 'numsaturatedheterocycles', 'xp_3d', 'kappa1', 'mz', 'axp_0d', 'chi1', 'xch_4d', 'smr_vsa1', 'xp_2d', 'estate_vsa5', 'phi', 'fr_ether', 'xc_5d', 'c1sp3', 'estate_vsa7', 'estate_vsa1', 'vsa_estate1', 'slogp_vsa4', 'avgipc', 'smr_vsa10', 'numvalenceelectrons', 'xc_5dv', 'peoe_vsa12', 'peoe_vsa6', 'xpc_5d', 'xpc_6d', 'minestateindex', 'chi3n', 'smr_vsa5', 'xp_4d', 'numheteroatoms', 'fpdensitymorgan3', 'xpc_4d', 'sps', 'xp_1d', 'sv', 'fr_ar_n', 'slogp_vsa10', 'c2sp3', 'xpc_4dv', 'chi0v', 'xpc_6dv', 'xp_1dv', 'vsa_estate10', 'sare', 'c2sp2', 'mpe', 'xch_7dv', 'chi4v', 'type_i_pattern_count', 'sp', 'slogp_vsa8', 'amide_count', 'num_stereocenters', 'num_r_centers', 'tertiary_amine_count', 'spe', 'xp_4dv', 'numsaturatedrings', 'mare', 'numhacceptors', 'chi0', 'fractioncsp3', 'fr_nh0', 'xch_5dv', 'fr_aniline', 'smr_vsa7', 'labuteasa', 'c3sp2', 'xp_0d', 'xp_6d', 'peoe_vsa11', 'fr_ar_nh', 'molwt', 'intramolecular_hbond_potential', 'peoe_vsa3', 'fr_nhpyrrole', 'numaliphaticrings', 'hybratio', 'smr_vsa9', 'peoe_vsa13', 'bcut2d_mwhi', 'c1sp2', 'slogp_vsa11', 'numrotatablebonds', 'numaliphaticcarbocycles', 'slogp_vsa6', 'peoe_vsa4', 'numunspecifiedatomstereocenters', 'xc_6d', 'xc_6dv', 'num_unspecified_stereocenters', 'sz', 'minabspartialcharge', 'fcsp3', 'c1sp1', 'fr_piperzine', 'numaliphaticheterocycles', 'numamidebonds', 'fr_benzene', 'numaromaticheterocycles', 'sm', 'fr_priamide', 'fr_piperdine', 'fr_methoxy', 'c4sp3', 'fr_c_o_nocoo', 'exactmolwt', 'stereo_complexity', 'fr_hoccn', 'numaromaticcarbocycles', 'fr_nh2', 'numheterocycles', 'fr_morpholine', 'fr_ketone', 'fr_nh1', 'frac_defined_stereo', 'fr_aryl_methyl', 'fr_alkyl_halide', 'fr_phenol', 'fr_al_oh_notert', 'fr_ar_oh', 'fr_pyridine', 'fr_amide', 'slogp_vsa7', 'fr_halogen', 'numsaturatedcarbocycles', 'slogp_vsa12', 'fr_ndealkylation1', 'xch_3d', 'fr_bicyclic', 'naromatom', 'narombond'],
-    "id_column": "udm_mol_bat_id",
+    "model_type": "classifier",
+    "target": "solubility_class",
+    "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
+    "id_column": "id",
     "compressed_features": [],
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-test-log/training",
-    "hyperparameters": {'target_transform': 'log'},
+    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-class/training",
+    "hyperparameters": {},
 }
 
 
+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
+def model_fn(model_dir: str) -> dict:
+    """Load XGBoost ensemble from the specified directory."""
+    # Load ensemble metadata
+    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
+    if os.path.exists(metadata_path):
+        with open(metadata_path) as f:
+            metadata = json.load(f)
+        n_ensemble = metadata["n_ensemble"]
+    else:
+        n_ensemble = 1  # Legacy single model
+
+    # Load ensemble models
+    ensemble_models = []
+    for i in range(n_ensemble):
+        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
+        if not os.path.exists(model_path):
+            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
+        ensemble_models.append(joblib.load(model_path))
+
+    print(f"Loaded {len(ensemble_models)} model(s)")
+
+    # Load label encoder (classifier only)
+    label_encoder = None
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)
+
+    # Load category mappings
+    category_mappings = {}
+    category_path = os.path.join(model_dir, "category_mappings.json")
+    if os.path.exists(category_path):
+        with open(category_path) as f:
+            category_mappings = json.load(f)
+
+    # Load UQ models (regression only)
+    uq_models, uq_metadata = None, None
+    uq_path = os.path.join(model_dir, "uq_metadata.json")
+    if os.path.exists(uq_path):
+        uq_models, uq_metadata = load_uq_models(model_dir)
+
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": n_ensemble,
+        "label_encoder": label_encoder,
+        "category_mappings": category_mappings,
+        "uq_models": uq_models,
+        "uq_metadata": uq_metadata,
+    }
+
+
+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
+def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+    """Make predictions with XGBoost ensemble."""
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+    with open(os.path.join(model_dir, "feature_columns.json")) as f:
+        features = json.load(f)
+    print(f"Model Features: {features}")
+
+    # Extract model components
+    ensemble_models = model_dict["ensemble_models"]
+    label_encoder = model_dict.get("label_encoder")
+    category_mappings = model_dict.get("category_mappings", {})
+    uq_models = model_dict.get("uq_models")
+    uq_metadata = model_dict.get("uq_metadata")
+    compressed_features = TEMPLATE_PARAMS["compressed_features"]
+
+    # Prepare features
+    matched_df = match_features_case_insensitive(df, features)
+    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
+
+    if compressed_features:
+        print("Decompressing features for prediction...")
+        matched_df, features = decompress_features(matched_df, features, compressed_features)
+
+    X = matched_df[features]
+
+    # Collect ensemble predictions
+    all_preds = [m.predict(X) for m in ensemble_models]
+    ensemble_preds = np.stack(all_preds, axis=0)
+
+    if label_encoder is not None:
+        # Classification: average probabilities, then argmax
+        all_probs = [m.predict_proba(X) for m in ensemble_models]
+        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
+        class_preds = np.argmax(avg_probs, axis=1)
+
+        df["prediction"] = label_encoder.inverse_transform(class_preds)
+        df["pred_proba"] = [p.tolist() for p in avg_probs]
+        df = expand_proba_column(df, label_encoder.classes_)
+    else:
+        # Regression: average predictions
+        df["prediction"] = np.mean(ensemble_preds, axis=0)
+        df["prediction_std"] = np.std(ensemble_preds, axis=0)
+
+        # Add UQ intervals if available
+        if uq_models and uq_metadata:
+            df = predict_intervals(df, X, uq_models, uq_metadata)
+            df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
+    return df
+
+
 # =============================================================================
 # Training
 # =============================================================================
 if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+
+    import awswrangler as wr
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+    from uq_harness import (
+        save_uq_models,
+        train_uq_models,
+    )
+
     # -------------------------------------------------------------------------
     # Setup: Parse arguments and load data
     # -------------------------------------------------------------------------
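The relocated predict_fn above averages the per-member probability matrices before taking an argmax, rather than majority-voting hard labels. A self-contained numpy illustration of that reduction (toy numbers, not package data):

import numpy as np

# 3 ensemble members x 4 samples x 2 classes
all_probs = np.stack([
    [[0.7, 0.3], [0.4, 0.6], [0.9, 0.1], [0.2, 0.8]],
    [[0.6, 0.4], [0.5, 0.5], [0.8, 0.2], [0.3, 0.7]],
    [[0.8, 0.2], [0.3, 0.7], [0.7, 0.3], [0.1, 0.9]],
], axis=0)

avg_probs = np.mean(all_probs, axis=0)      # average over the member axis
class_preds = np.argmax(avg_probs, axis=1)  # then argmax per sample
print(class_preds)  # [0 1 0 1]

Averaging probabilities preserves calibration information from each fold's model that hard voting over predicted labels would discard.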
@@ -123,7 +246,7 @@ if __name__ == "__main__":
         all_df, features = decompress_features(all_df, features, compressed_features)
 
     # -------------------------------------------------------------------------
-    # Classification setup: Encode target labels
+    # Classification setup
     # -------------------------------------------------------------------------
     label_encoder = None
     if model_type == "classifier":
@@ -136,6 +259,18 @@ if __name__ == "__main__":
     # -------------------------------------------------------------------------
     n_folds = hyperparameters["n_folds"]
     xgb_params = {k: v for k, v in hyperparameters.items() if k not in WORKBENCH_PARAMS}
+
+    # Map 'seed' to 'random_state' for XGBoost
+    if "seed" in xgb_params:
+        xgb_params["random_state"] = xgb_params.pop("seed")
+
+    # Handle objective: filter regression-only params for classifiers, set default for regressors
+    if model_type == "classifier":
+        xgb_params = {k: v for k, v in xgb_params.items() if k not in REGRESSION_ONLY_PARAMS}
+    else:
+        # Default to MAE (reg:absoluteerror) for regression if not specified
+        xgb_params.setdefault("objective", "reg:absoluteerror")
+
     print(f"XGBoost params: {xgb_params}")
 
     if n_folds == 1:
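The seed-to-random_state mapping and objective defaulting above are plain dict manipulation and can be checked in isolation; this snippet mirrors (but is not) the template code:

REGRESSION_ONLY_PARAMS = {"objective"}
hyperparameters = {"n_folds": 5, "seed": 42, "max_depth": 7}

xgb_params = {k: v for k, v in hyperparameters.items() if k != "n_folds"}
if "seed" in xgb_params:
    xgb_params["random_state"] = xgb_params.pop("seed")

model_type = "regressor"
if model_type == "classifier":
    xgb_params = {k: v for k, v in xgb_params.items() if k not in REGRESSION_ONLY_PARAMS}
else:
    xgb_params.setdefault("objective", "reg:absoluteerror")

print(xgb_params)  # {'max_depth': 7, 'random_state': 42, 'objective': 'reg:absoluteerror'}

The rename lets the user-facing DEFAULT_HYPERPARAMETERS expose "seed" while the underlying XGBoost models are constructed with "random_state".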
@@ -285,12 +420,10 @@ if __name__ == "__main__":
     # -------------------------------------------------------------------------
     # Save model artifacts
     # -------------------------------------------------------------------------
-    # Ensemble models
-    for idx, ens_model in enumerate(ensemble_models):
-        joblib.dump(ens_model, os.path.join(args.model_dir, f"xgb_model_{idx}.joblib"))
-    print(f"Saved {len(ensemble_models)} XGBoost model(s)")
+    for idx, m in enumerate(ensemble_models):
+        joblib.dump(m, os.path.join(args.model_dir, f"xgb_model_{idx}.joblib"))
+    print(f"Saved {len(ensemble_models)} model(s)")
 
-    # Metadata files
     with open(os.path.join(args.model_dir, "ensemble_metadata.json"), "w") as f:
         json.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds}, f)
 
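The ensemble_metadata.json written here is the first artifact the relocated model_fn reads at load time; a quick round-trip sketch (a temporary directory stands in for args.model_dir):

import json, os, tempfile

model_dir = tempfile.mkdtemp()
with open(os.path.join(model_dir, "ensemble_metadata.json"), "w") as f:
    json.dump({"n_ensemble": 5, "n_folds": 5}, f)

with open(os.path.join(model_dir, "ensemble_metadata.json")) as f:
    print(json.load(f)["n_ensemble"])  # 5 -- model_fn falls back to 1 if the file is absent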
@@ -310,125 +443,3 @@ if __name__ == "__main__":
         save_uq_models(uq_models, uq_metadata, args.model_dir)
 
     print(f"\nModel training complete! Artifacts saved to {args.model_dir}")
-
-
-# =============================================================================
-# Model Loading (for SageMaker inference)
-# =============================================================================
-def model_fn(model_dir: str) -> dict:
-    """Load XGBoost ensemble and associated artifacts.
-
-    Args:
-        model_dir: Directory containing model artifacts
-
-    Returns:
-        Dictionary with ensemble_models, label_encoder, category_mappings, uq_models, etc.
-    """
-    # Load ensemble metadata
-    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
-    if os.path.exists(metadata_path):
-        with open(metadata_path) as f:
-            metadata = json.load(f)
-        n_ensemble = metadata["n_ensemble"]
-    else:
-        n_ensemble = 1  # Legacy single model
-
-    # Load ensemble models
-    ensemble_models = []
-    for i in range(n_ensemble):
-        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
-        if not os.path.exists(model_path):
-            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
-        ensemble_models.append(joblib.load(model_path))
-
-    # Load label encoder (classifier only)
-    label_encoder = None
-    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
-    if os.path.exists(encoder_path):
-        label_encoder = joblib.load(encoder_path)
-
-    # Load category mappings
-    category_mappings = {}
-    category_path = os.path.join(model_dir, "category_mappings.json")
-    if os.path.exists(category_path):
-        with open(category_path) as f:
-            category_mappings = json.load(f)
-
-    # Load UQ models (regression only)
-    uq_models, uq_metadata = None, None
-    uq_path = os.path.join(model_dir, "uq_metadata.json")
-    if os.path.exists(uq_path):
-        uq_models, uq_metadata = load_uq_models(model_dir)
-
-    return {
-        "ensemble_models": ensemble_models,
-        "n_ensemble": n_ensemble,
-        "label_encoder": label_encoder,
-        "category_mappings": category_mappings,
-        "uq_models": uq_models,
-        "uq_metadata": uq_metadata,
-    }
-
-
-# =============================================================================
-# Inference (for SageMaker inference)
-# =============================================================================
-def predict_fn(df: pd.DataFrame, models: dict) -> pd.DataFrame:
-    """Make predictions with XGBoost ensemble.
-
-    Args:
-        df: Input DataFrame with features
-        models: Dictionary from model_fn containing ensemble and metadata
-
-    Returns:
-        DataFrame with predictions added
-    """
-    # Load feature columns
-    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-    with open(os.path.join(model_dir, "feature_columns.json")) as f:
-        features = json.load(f)
-    print(f"Model Features: {features}")
-
-    # Extract model components
-    ensemble_models = models["ensemble_models"]
-    label_encoder = models.get("label_encoder")
-    category_mappings = models.get("category_mappings", {})
-    uq_models = models.get("uq_models")
-    uq_metadata = models.get("uq_metadata")
-    compressed_features = TEMPLATE_PARAMS["compressed_features"]
-
-    # Prepare features
-    matched_df = match_features_case_insensitive(df, features)
-    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
-
-    if compressed_features:
-        print("Decompressing features for prediction...")
-        matched_df, features = decompress_features(matched_df, features, compressed_features)
-
-    X = matched_df[features]
-
-    # Collect ensemble predictions
-    all_preds = [m.predict(X) for m in ensemble_models]
-    ensemble_preds = np.stack(all_preds, axis=0)
-
-    if label_encoder is not None:
-        # Classification: average probabilities, then argmax
-        all_probs = [m.predict_proba(X) for m in ensemble_models]
-        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
-        class_preds = np.argmax(avg_probs, axis=1)
-
-        df["prediction"] = label_encoder.inverse_transform(class_preds)
-        df["pred_proba"] = [p.tolist() for p in avg_probs]
-        df = expand_proba_column(df, label_encoder.classes_)
-    else:
-        # Regression: average predictions
-        df["prediction"] = np.mean(ensemble_preds, axis=0)
-        df["prediction_std"] = np.std(ensemble_preds, axis=0)
-
-        # Add UQ intervals if available
-        if uq_models and uq_metadata:
-            df = predict_intervals(df, X, uq_models, uq_metadata)
-            df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
-
-    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
-    return df
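With both handlers now defined near the top of the module (and these trailing copies removed), they can be smoke-tested outside SageMaker against an extracted model.tar.gz. A hedged sketch: the artifact path and feature values are placeholders, the generated script must be importable, the directory must already contain feature_columns.json plus the .joblib files, and every feature listed in feature_columns.json has to be supplied:

import os
import pandas as pd

artifact_dir = "/tmp/aqsol-class-artifacts"  # placeholder: an extracted model.tar.gz
os.environ["SM_MODEL_DIR"] = artifact_dir    # predict_fn reads feature_columns.json from here

from generated_model_script import model_fn, predict_fn

model_dict = model_fn(artifact_dir)
row = {"id": "mol-1", "molwt": 180.16, "mollogp": 1.31}  # ...plus the remaining model features
out = predict_fn(pd.DataFrame([row]), model_dict)
print(out[["id", "prediction"]])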