workbench-0.8.174-py3-none-any.whl → workbench-0.8.227-py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of workbench might be problematic.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
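
The hunk below is shown in full: it is the new 774-line ChemProp model script. Both chemprop.template and its generated_model_script.py twin are +774 lines; the filled-in TEMPLATE_PARAMS values indicate this copy is the generated script.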
@@ -0,0 +1,774 @@
+ # ChemProp Model Template for Workbench
+ #
+ # This template handles molecular property prediction using ChemProp 2.x MPNN with:
+ # - K-fold cross-validation ensemble training (or single train/val split)
+ # - Multi-task regression support
+ # - Hybrid mode (SMILES + extra molecular descriptors)
+ # - Classification (single-target only)
+ #
+ # NOTE: Imports are structured to minimize serverless endpoint startup time.
+ # Heavy imports (lightning, sklearn, awswrangler) are deferred to training time.
+
+ import json
+ import os
+
+ import joblib
+ import numpy as np
+ import pandas as pd
+ import torch
+
+ from chemprop import data, models
+
+ from model_script_utils import (
+     expand_proba_column,
+     input_fn,
+     output_fn,
+ )
+
+ # =============================================================================
+ # Default Hyperparameters
+ # =============================================================================
+ DEFAULT_HYPERPARAMETERS = {
+     # Training
+     "n_folds": 5,
+     "max_epochs": 400,
+     "patience": 50,
+     "batch_size": 32,
+     # Message Passing (ignored when using foundation model)
+     "hidden_dim": 700,
+     "depth": 6,
+     "dropout": 0.1,  # Lower dropout - ensemble provides regularization
+     # FFN
+     "ffn_hidden_dim": 2000,
+     "ffn_num_layers": 2,
+     # Loss function for regression (mae, mse)
+     "criterion": "mae",
+     # Random seed
+     "seed": 42,
+     # Foundation model support
+     # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
+     # - Path to .pt file: Load custom pretrained Chemprop model
+     # - None: Train from scratch (default)
+     "from_foundation": None,
+     # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
+     # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
+     "freeze_mpnn_epochs": 0,
+ }
+
+ # Template parameters (filled in by Workbench)
+ TEMPLATE_PARAMS = {
+     "model_type": "uq_regressor",
+     "targets": ['udm_asy_res_free_percent'],
+     "feature_list": ['smiles'],
+     "id_column": "udm_mol_bat_id",
+     "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/ppb-human-free-reg-chemprop-foundation-1-dt/training",
+     "hyperparameters": {'from_foundation': 'CheMeleon', 'freeze_mpnn_epochs': 10, 'n_folds': 5, 'max_epochs': 100, 'patience': 20, 'ffn_hidden_dim': 512, 'dropout': 0.15},
+ }
+
+
+ # =============================================================================
+ # Helper Functions
+ # =============================================================================
+ def _compute_std_confidence(df: pd.DataFrame, median_std: float, std_col: str = "prediction_std") -> pd.DataFrame:
+     """Compute confidence score from ensemble prediction_std.
+
+     Uses exponential decay: confidence = exp(-std / median_std)
+     - Low std (ensemble agreement) -> high confidence
+     - High std (ensemble disagreement) -> low confidence
+
+     Args:
+         df: DataFrame with prediction_std column
+         median_std: Median std from training validation set (normalization factor)
+         std_col: Name of the std column to use
+
+     Returns:
+         DataFrame with added 'confidence' column (0.0 to 1.0)
+     """
+     df["confidence"] = np.exp(-df[std_col] / median_std)
+     return df
+
+
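For intuition, the decay above maps std = 0 to confidence 1.0, a std equal to the training-median std to 1/e ≈ 0.37, and growing ensemble disagreement toward 0. A minimal sketch with an assumed median:

import numpy as np

median_std = 0.5  # assumed median std from a training validation set
for std in (0.0, 0.5, 1.5):
    print(std, np.exp(-std / median_std))  # 1.0, 0.368 (1/e), 0.050
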
+ def _find_smiles_column(columns: list[str]) -> str:
+     """Find SMILES column (case-insensitive match for 'smiles')."""
+     smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
+     if smiles_col is None:
+         raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
+     return smiles_col
+
+
+ def _create_molecule_datapoints(
+     smiles_list: list[str],
+     targets: np.ndarray | None = None,
+     extra_descriptors: np.ndarray | None = None,
+ ) -> tuple[list[data.MoleculeDatapoint], list[int]]:
+     """Create ChemProp MoleculeDatapoints from SMILES strings."""
+     from rdkit import Chem
+
+     datapoints, valid_indices = [], []
+     targets = np.atleast_2d(np.array(targets)).T if targets is not None and np.array(targets).ndim == 1 else targets
+
+     for i, smi in enumerate(smiles_list):
+         if Chem.MolFromSmiles(smi) is None:
+             continue
+         y = targets[i].tolist() if targets is not None else None
+         x_d = extra_descriptors[i] if extra_descriptors is not None else None
+         datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
+         valid_indices.append(i)
+
+     return datapoints, valid_indices
+
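The valid_indices return value is what lets callers map predictions back to the original row positions after RDKit rejects unparseable SMILES. A quick sketch of that contract (hypothetical inputs):

dps, idx = _create_molecule_datapoints(["CCO", "not-a-smiles", "c1ccccc1"])
# RDKit fails to parse the entry at position 1, so it is silently skipped
assert idx == [0, 2] and len(dps) == 2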
+
+ # =============================================================================
+ # Model Loading (for SageMaker inference)
+ # =============================================================================
+ def model_fn(model_dir: str) -> dict:
+     """Load ChemProp MPNN ensemble from the specified directory.
+
+     Optimized for serverless cold starts - uses direct PyTorch inference
+     instead of Lightning Trainer to minimize startup time.
+     """
+     metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+
+     # Load all ensemble models (keep on CPU for serverless compatibility)
+     # ChemProp handles device placement internally
+     ensemble_models = []
+     for i in range(metadata["n_ensemble"]):
+         model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
+         model.eval()
+         ensemble_models.append(model)
+
+     print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
+     return {
+         "ensemble_models": ensemble_models,
+         "n_ensemble": metadata["n_ensemble"],
+         "target_columns": metadata["target_columns"],
+         "median_std": metadata["median_std"],
+     }
+
+
+ # =============================================================================
+ # Inference (for SageMaker inference)
+ # =============================================================================
+ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+     """Make predictions with ChemProp MPNN ensemble.
+
+     Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
+     """
+     model_type = TEMPLATE_PARAMS["model_type"]
+     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+
+     ensemble_models = model_dict["ensemble_models"]
+     target_columns = model_dict["target_columns"]
+
+     # Load artifacts
+     label_encoder = None
+     encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+     if os.path.exists(encoder_path):
+         label_encoder = joblib.load(encoder_path)
+
+     feature_metadata = None
+     feature_path = os.path.join(model_dir, "feature_metadata.joblib")
+     if os.path.exists(feature_path):
+         feature_metadata = joblib.load(feature_path)
+         print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")
+
+     # Find SMILES column and validate
+     smiles_column = _find_smiles_column(df.columns.tolist())
+     smiles_list = df[smiles_column].tolist()
+
+     valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
+     valid_smiles = [s.strip() for i, s in enumerate(smiles_list) if valid_mask[i]]
+     print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")
+
+     # Initialize output columns
+     if model_type == "classifier":
+         df["prediction"] = pd.Series([None] * len(df), dtype=object)
+     else:
+         for tc in target_columns:
+             df[f"{tc}_pred"] = np.nan
+             df[f"{tc}_pred_std"] = np.nan
+
+     if sum(valid_mask) == 0:
+         return df
+
+     # Prepare extra features (raw, unscaled - model handles scaling)
+     extra_features = None
+     if feature_metadata is not None:
+         extra_cols = feature_metadata["extra_feature_cols"]
+         col_means = np.array(feature_metadata["col_means"])
+         valid_indices = np.where(valid_mask)[0]
+
+         extra_features = np.zeros((len(valid_indices), len(extra_cols)), dtype=np.float32)
+         for j, col in enumerate(extra_cols):
+             if col in df.columns:
+                 values = df.iloc[valid_indices][col].values.astype(np.float32)
+                 values[np.isnan(values)] = col_means[j]
+                 extra_features[:, j] = values
+             else:
+                 extra_features[:, j] = col_means[j]
+
+     # Create datapoints and predict
+     datapoints, rdkit_valid = _create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
+     if len(datapoints) == 0:
+         return df
+
+     dataset = data.MoleculeDataset(datapoints)
+     dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)
+
+     # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
+     all_preds = []
+     for model in ensemble_models:
+         model_preds = []
+         model.eval()
+         with torch.inference_mode():
+             for batch in dataloader:
+                 # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
+                 # For inference we only need bmg, V_d, X_d
+                 bmg, V_d, X_d, *_ = batch
+                 output = model(bmg, V_d, X_d)
+                 model_preds.append(output.detach().cpu().numpy())
+
+         if len(model_preds) == 0:
+             print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
+             continue
+
+         preds = np.concatenate(model_preds, axis=0)
+         if preds.ndim == 3 and preds.shape[1] == 1:
+             preds = preds.squeeze(axis=1)
+         all_preds.append(preds)
+
+     if len(all_preds) == 0:
+         print("Error: No ensemble predictions generated")
+         return df
+
+     preds = np.mean(np.stack(all_preds), axis=0)
+     preds_std = np.std(np.stack(all_preds), axis=0)
+     if preds.ndim == 1:
+         preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
+
+     print(f"Inference complete: {preds.shape[0]} predictions")
+
+     # Map predictions back to valid positions
+     valid_positions = np.where(valid_mask)[0][rdkit_valid]
+     valid_mask = np.zeros(len(df), dtype=bool)
+     valid_mask[valid_positions] = True
+
+     if model_type == "classifier" and label_encoder is not None:
+         if preds.shape[1] > 1:
+             class_preds = np.argmax(preds, axis=1)
+             df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform(class_preds)
+             proba = pd.Series([None] * len(df), dtype=object)
+             proba.loc[valid_mask] = [p.tolist() for p in preds]
+             df["pred_proba"] = proba
+             df = expand_proba_column(df, label_encoder.classes_)
+         else:
+             df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform((preds.flatten() > 0.5).astype(int))
+     else:
+         for t_idx, tc in enumerate(target_columns):
+             df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
+             df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
+         df["prediction"] = df[f"{target_columns[0]}_pred"]
+         df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
+
+     # Compute confidence from ensemble std (or NaN if single model)
+     if model_dict["median_std"] is not None:
+         df = _compute_std_confidence(df, model_dict["median_std"])
+     else:
+         df["confidence"] = np.nan
+
+     return df
+
+
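Because model_fn and predict_fn are plain functions, the serving path can be smoke-tested without SageMaker. A minimal sketch, assuming a local directory that already contains the training artifacts (path and SMILES values are illustrative):

import os
import pandas as pd

model_dir = "/tmp/chemprop_model"       # assumed local copy of the model artifacts
os.environ["SM_MODEL_DIR"] = model_dir  # predict_fn reads this to find optional artifacts

model_dict = model_fn(model_dir)
test_df = pd.DataFrame({"smiles": ["CCO", "c1ccccc1", ""]})  # "" exercises the invalid-SMILES path
out = predict_fn(test_df, model_dict)
print(out[["smiles", "prediction", "prediction_std", "confidence"]])
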
+ # =============================================================================
+ # Training
+ # =============================================================================
+ if __name__ == "__main__":
+     # -------------------------------------------------------------------------
+     # Training-only imports (deferred to reduce serverless startup time)
+     # -------------------------------------------------------------------------
+     import argparse
+     import glob
+
+     import awswrangler as wr
+     from lightning import pytorch as pl
+     from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+     from sklearn.preprocessing import LabelEncoder
+
+     # Enable Tensor Core optimization for GPUs that support it
+     torch.set_float32_matmul_precision("medium")
+
+     from chemprop import nn
+
+     from model_script_utils import (
+         check_dataframe,
+         compute_classification_metrics,
+         compute_regression_metrics,
+         print_classification_metrics,
+         print_confusion_matrix,
+         print_regression_metrics,
+     )
+
+     # -------------------------------------------------------------------------
+     # Training-only helper functions
+     # -------------------------------------------------------------------------
+     def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
+         """Load pretrained MPNN weights from foundation model.
+
+         Args:
+             from_foundation: "CheMeleon" or path to .pt file
+
+         Returns:
+             Tuple of (message_passing, aggregation) modules
+         """
+         import urllib.request
+         from pathlib import Path
+
+         print(f"Loading foundation model: {from_foundation}")
+
+         if from_foundation.lower() == "chemeleon":
+             # Download from Zenodo if not cached
+             cache_dir = Path.home() / ".chemprop" / "foundation"
+             cache_dir.mkdir(parents=True, exist_ok=True)
+             chemeleon_path = cache_dir / "chemeleon_mp.pt"
+
+             if not chemeleon_path.exists():
+                 print(" Downloading CheMeleon weights from Zenodo...")
+                 urllib.request.urlretrieve(
+                     "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
+                 )
+                 print(f" Downloaded to {chemeleon_path}")
+
+             ckpt = torch.load(chemeleon_path, weights_only=True)
+             mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+             mp.load_state_dict(ckpt["state_dict"])
+             print(f" Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
+             return mp, nn.MeanAggregation()
+
+         if not os.path.exists(from_foundation):
+             raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")
+
+         ckpt = torch.load(from_foundation, weights_only=False)
+         if "hyper_parameters" in ckpt and "state_dict" in ckpt:
+             # CheMeleon-style checkpoint
+             mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+             mp.load_state_dict(ckpt["state_dict"])
+             print(f" Loaded custom foundation weights (hidden_dim={mp.output_dim})")
+             return mp, nn.MeanAggregation()
+
+         # Full MPNN model file
+         pretrained = models.MPNN.load_from_file(from_foundation)
+         print(f" Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
+         return pretrained.message_passing, pretrained.agg
+
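Besides "CheMeleon" and full MPNN files, the loader accepts any CheMeleon-style checkpoint, i.e. a dict with hyper_parameters and state_dict keys. A sketch of writing a compatible custom checkpoint (the values shown mirror the template's hidden_dim/depth defaults; the module here is a stand-in for an actually pretrained one):

import torch
from chemprop import nn

mp = nn.BondMessagePassing(d_h=700, depth=6)  # stand-in for a pretrained message-passing module
torch.save(
    {
        "hyper_parameters": {"d_h": 700, "depth": 6},  # kwargs re-passed to BondMessagePassing
        "state_dict": mp.state_dict(),
    },
    "my_foundation.pt",
)
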
+     def _build_ffn(
+         task: str, input_dim: int, hyperparameters: dict,
+         num_classes: int | None, n_targets: int,
+         output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
+     ) -> nn.Predictor:
+         """Build task-specific FFN head."""
+         dropout = hyperparameters["dropout"]
+         ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
+         ffn_num_layers = hyperparameters["ffn_num_layers"]
+
+         if task == "classification" and num_classes is not None:
+             return nn.MulticlassClassificationFFN(
+                 n_classes=num_classes, input_dim=input_dim,
+                 hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
+             )
+
+         from chemprop.nn.metrics import MAE, MSE
+         criterion_map = {"mae": MAE, "mse": MSE}
+         criterion_name = hyperparameters.get("criterion", "mae")
+         if criterion_name not in criterion_map:
+             raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+
+         weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+         return nn.RegressionFFN(
+             input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+             dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
+             task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
+         )
+
+     def build_mpnn_model(
+         hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
+         n_targets: int = 1, n_extra_descriptors: int = 0,
+         x_d_transform: nn.ScaleTransform | None = None,
+         output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
+     ) -> models.MPNN:
+         """Build MPNN model, optionally loading pretrained weights."""
+         from_foundation = hyperparameters.get("from_foundation")
+
+         if from_foundation:
+             mp, agg = _load_foundation_weights(from_foundation)
+             ffn_input_dim = mp.output_dim + n_extra_descriptors
+         else:
+             mp = nn.BondMessagePassing(
+                 d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
+                 dropout=hyperparameters["dropout"],
+             )
+             agg = nn.NormAggregation()
+             ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors
+
+         ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
+         return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
+
+     # -------------------------------------------------------------------------
+     # Setup: Parse arguments and load data
+     # -------------------------------------------------------------------------
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+     parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+     args = parser.parse_args()
+
+     # Extract template parameters
+     target_columns = TEMPLATE_PARAMS["targets"]
+     model_type = TEMPLATE_PARAMS["model_type"]
+     feature_list = TEMPLATE_PARAMS["feature_list"]
+     id_column = TEMPLATE_PARAMS["id_column"]
+     model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
+     hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}
+
+     if not target_columns or not isinstance(target_columns, list):
+         raise ValueError("'targets' must be a non-empty list of target column names")
+     n_targets = len(target_columns)
+
+     smiles_column = _find_smiles_column(feature_list)
+     extra_feature_cols = [f for f in feature_list if f != smiles_column]
+     use_extra_features = len(extra_feature_cols) > 0
+
+     print(f"Target columns ({n_targets}): {target_columns}")
+     print(f"SMILES column: {smiles_column}")
+     print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
+     print(f"Hyperparameters: {hyperparameters}")
+
+     # Log foundation model configuration
+     if hyperparameters.get("from_foundation"):
+         freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+         freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
+         print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
+     else:
+         print("Foundation model: None (training from scratch)")
+
+     # Load training data
+     training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
+     print(f"Training Files: {training_files}")
+     all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
+     check_dataframe(all_df, "training_df")
+
+     # Clean data
+     initial_count = len(all_df)
+     all_df = all_df.dropna(subset=[smiles_column])
+     all_df = all_df[all_df[target_columns].notna().any(axis=1)]
+     if len(all_df) < initial_count:
+         print(f"Dropped {initial_count - len(all_df)} rows with missing SMILES/targets")
+
+     print(f"Data shape: {all_df.shape}")
+     for tc in target_columns:
+         print(f" {tc}: {all_df[tc].notna().sum()} samples")
+
+     # -------------------------------------------------------------------------
+     # Classification setup
+     # -------------------------------------------------------------------------
+     label_encoder = None
+     num_classes = None
+     if model_type == "classifier":
+         if n_targets > 1:
+             raise ValueError("Multi-task classification not supported")
+         label_encoder = LabelEncoder()
+         all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
+         num_classes = len(label_encoder.classes_)
+         print(f"Classification: {num_classes} classes: {label_encoder.classes_}")
+
+     # -------------------------------------------------------------------------
+     # Prepare features
+     # -------------------------------------------------------------------------
+     task = "classification" if model_type == "classifier" else "regression"
+     n_extra = len(extra_feature_cols) if use_extra_features else 0
+
+     all_extra_features, col_means = None, None
+     if use_extra_features:
+         all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
+         col_means = np.nanmean(all_extra_features, axis=0)
+         for i in range(all_extra_features.shape[1]):
+             all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
+
+     all_targets = all_df[target_columns].values.astype(np.float32)
+
+     # Filter invalid SMILES
+     _, valid_indices = _create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
+     all_df = all_df.iloc[valid_indices].reset_index(drop=True)
+     all_targets = all_targets[valid_indices]
+     if all_extra_features is not None:
+         all_extra_features = all_extra_features[valid_indices]
+     print(f"Data after SMILES validation: {all_df.shape}")
+
+     # Task weights for multi-task (inverse sample count)
+     task_weights = None
+     if n_targets > 1 and model_type != "classifier":
+         counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
+         task_weights = (1.0 / counts) / (1.0 / counts).min()
+         print(f"Task weights: {dict(zip(target_columns, task_weights.round(3)))}")
+
+     # -------------------------------------------------------------------------
+     # Cross-validation setup
+     # -------------------------------------------------------------------------
+     n_folds = hyperparameters["n_folds"]
+     batch_size = hyperparameters["batch_size"]
+
+     if n_folds == 1:
+         if "training" in all_df.columns:
+             print("Using 'training' column for train/val split")
+             train_idx = np.where(all_df["training"])[0]
+             val_idx = np.where(~all_df["training"])[0]
+         else:
+             print("WARNING: No 'training' column, using random 80/20 split")
+             train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
+         folds = [(train_idx, val_idx)]
+     else:
+         if model_type == "classifier":
+             kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
+             folds = list(kfold.split(all_df, all_df[target_columns[0]]))
+         else:
+             kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
+             folds = list(kfold.split(all_df))
+
+     print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
+
+     # -------------------------------------------------------------------------
+     # Training loop
+     # -------------------------------------------------------------------------
+     oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
+     oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64) if model_type == "classifier" and num_classes else None
+
+     ensemble_models = []
+     for fold_idx, (train_idx, val_idx) in enumerate(folds):
+         print(f"\n{'='*50}")
+         print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
+         print(f"{'='*50}")
+
+         # Split data (val_extra_raw preserves unscaled features for OOF predictions)
+         df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
+         train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
+         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
+         val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
+         val_extra_raw = val_extra.copy() if val_extra is not None else None
+
+         # Create datasets
+         train_dps, _ = _create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
+         val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
+         train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)
+
+         # Scale features/targets
+         x_d_transform = None
+         if use_extra_features:
+             scaler = train_dataset.normalize_inputs("X_d")
+             val_dataset.normalize_inputs("X_d", scaler)
+             x_d_transform = nn.ScaleTransform.from_standard_scaler(scaler)
+
+         output_transform = None
+         if model_type in ["regressor", "uq_regressor"]:
+             target_scaler = train_dataset.normalize_targets()
+             val_dataset.normalize_targets(target_scaler)
+             output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
+
+         train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
+         val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
+
+         # Build model
+         pl.seed_everything(hyperparameters["seed"] + fold_idx)
+         mpnn = build_mpnn_model(
+             hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
+             n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
+             output_transform=output_transform, task_weights=task_weights,
+         )
+
+         # Train model (with optional two-phase foundation training)
+         freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+         use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0
+
+         def _set_mpnn_frozen(frozen: bool):
+             for param in mpnn.message_passing.parameters():
+                 param.requires_grad = not frozen
+             for param in mpnn.agg.parameters():
+                 param.requires_grad = not frozen
+
+         def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
+             callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
+             if save_checkpoint:
+                 callbacks.append(pl.callbacks.ModelCheckpoint(
+                     dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
+                 ))
+             return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)
+
+         if use_two_phase:
+             # Phase 1: Freeze MPNN, train FFN only
+             print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
+             _set_mpnn_frozen(True)
+             _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)
+
+             # Phase 2: Unfreeze and fine-tune all
+             print("Phase 2: Unfreezing MPNN, continuing training...")
+             _set_mpnn_frozen(False)
+             remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
+             trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
+             trainer.fit(mpnn, train_loader, val_loader)
+         else:
+             trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
+             trainer.fit(mpnn, train_loader, val_loader)
+
+         # Load best checkpoint
+         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
+             checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
+             mpnn.load_state_dict(checkpoint["state_dict"])
+
+         mpnn.eval()
+         ensemble_models.append(mpnn)
+
+         # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
+         val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
+         val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
+
+         with torch.inference_mode():
+             fold_preds = np.concatenate([p.numpy() for p in trainer.predict(mpnn, val_loader_pred)], axis=0)
+         if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
+             fold_preds = fold_preds.squeeze(axis=1)
+
+         if model_type == "classifier" and fold_preds.ndim == 2:
+             oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
+             if oof_proba is not None:
+                 oof_proba[val_idx] = fold_preds
+         else:
+             if fold_preds.ndim == 1:
+                 fold_preds = fold_preds.reshape(-1, 1)
+             oof_predictions[val_idx] = fold_preds
+
+     print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
+
+     # -------------------------------------------------------------------------
+     # Prepare validation results
+     # -------------------------------------------------------------------------
+     if n_folds == 1:
+         val_mask = ~np.isnan(oof_predictions).all(axis=1)
+         df_val = all_df[val_mask].copy()
+         preds = oof_predictions[val_mask]
+         y_validate = all_targets[val_mask]
+         if oof_proba is not None:
+             oof_proba = oof_proba[val_mask]
+         val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
+     else:
+         df_val = all_df.copy()
+         preds = oof_predictions
+         y_validate = all_targets
+         val_extra_features = all_extra_features
+
+     # -------------------------------------------------------------------------
+     # Compute metrics and prepare output
+     # -------------------------------------------------------------------------
+     median_std = None  # Only set for regression models with ensemble
+     if model_type == "classifier":
+         class_preds = preds[:, 0].astype(int)
+         target_name = target_columns[0]
+         y_true_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
+         preds_decoded = label_encoder.inverse_transform(class_preds)
+
+         score_df = compute_classification_metrics(y_true_decoded, preds_decoded, label_encoder.classes_, target_name)
+         print_classification_metrics(score_df, target_name, label_encoder.classes_)
+         print_confusion_matrix(y_true_decoded, preds_decoded, label_encoder.classes_)
+
+         # Decode target column back to string labels (was encoded for training)
+         df_val[target_name] = y_true_decoded
+         df_val["prediction"] = preds_decoded
+         if oof_proba is not None:
+             df_val["pred_proba"] = [p.tolist() for p in oof_proba]
+             df_val = expand_proba_column(df_val, label_encoder.classes_)
+     else:
+         # Compute ensemble std
+         preds_std = None
+         if len(ensemble_models) > 1:
+             print("Computing prediction_std from ensemble...")
+             val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
+             val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
+             trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
+             all_ens_preds = []
+             for m in ensemble_models:
+                 with torch.inference_mode():
+                     ens_preds = np.concatenate([p.numpy() for p in trainer_pred.predict(m, val_loader)], axis=0)
+                 if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+                     ens_preds = ens_preds.squeeze(axis=1)
+                 all_ens_preds.append(ens_preds)
+             preds_std = np.std(np.stack(all_ens_preds), axis=0)
+             if preds_std.ndim == 1:
+                 preds_std = preds_std.reshape(-1, 1)
+
+         print("\n--- Per-target metrics ---")
+         for t_idx, t_name in enumerate(target_columns):
+             valid_mask = ~np.isnan(y_validate[:, t_idx])
+             if valid_mask.sum() > 0:
+                 metrics = compute_regression_metrics(y_validate[valid_mask, t_idx], preds[valid_mask, t_idx])
+                 print_regression_metrics(metrics)
+
+             df_val[f"{t_name}_pred"] = preds[:, t_idx]
+             df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx] if preds_std is not None else 0.0
+
+         df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
+         df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
+
+         # Compute confidence from ensemble std (or NaN for single model)
+         if preds_std is not None:
+             median_std = float(np.median(preds_std[:, 0]))
+             print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+             df_val = _compute_std_confidence(df_val, median_std)
+             print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+         else:
+             # Single model - no ensemble std available, confidence is undefined
+             median_std = None
+             df_val["confidence"] = np.nan
+             print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")
+
+     # -------------------------------------------------------------------------
+     # Save validation predictions to S3
+     # -------------------------------------------------------------------------
+     output_columns = [id_column] if id_column in df_val.columns else []
+     output_columns += target_columns
+     output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
+     output_columns += ["prediction", "prediction_std", "confidence"]
+     output_columns += [c for c in df_val.columns if c.endswith("_proba")]
+     output_columns = [c for c in output_columns if c in df_val.columns]
+
+     wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
+
+     # -------------------------------------------------------------------------
+     # Save model artifacts
+     # -------------------------------------------------------------------------
+     for idx, m in enumerate(ensemble_models):
+         models.save_model(os.path.join(args.model_dir, f"chemprop_model_{idx}.pt"), m)
+     print(f"Saved {len(ensemble_models)} model(s)")
+
+     # Clean up checkpoints
+     for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
+         os.remove(ckpt)
+
+     ensemble_metadata = {
+         "n_ensemble": len(ensemble_models),
+         "n_folds": n_folds,
+         "target_columns": target_columns,
+         "median_std": median_std,  # For confidence calculation during inference
+         # Foundation model provenance (for tracking/reproducibility)
+         "from_foundation": hyperparameters.get("from_foundation", None),
+         "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
+     }
+     joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
+
+     with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
+         json.dump(hyperparameters, f, indent=2)
+
+     if label_encoder:
+         joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
+
+     if use_extra_features:
+         joblib.dump({"extra_feature_cols": extra_feature_cols, "col_means": col_means.tolist()}, os.path.join(args.model_dir, "feature_metadata.joblib"))
+         print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
+
+     print(f"\nModel training complete! Artifacts saved to {args.model_dir}")