workbench-0.8.162-py3-none-any.whl → workbench-0.8.220-py3-none-any.whl

This diff shows the changes between two publicly available package versions as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of workbench might be problematic.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,649 @@
+# ChemProp Model Template for Workbench
+#
+# This template handles molecular property prediction using ChemProp 2.x MPNN with:
+# - K-fold cross-validation ensemble training (or single train/val split)
+# - Multi-task regression support
+# - Hybrid mode (SMILES + extra molecular descriptors)
+# - Classification (single-target only)
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (lightning, sklearn, awswrangler) are deferred to training time.
+
+import json
+import os
+
+import joblib
+import numpy as np
+import pandas as pd
+import torch
+
+from chemprop import data, models
+
+from model_script_utils import (
+    expand_proba_column,
+    input_fn,
+    output_fn,
+)
+
+# =============================================================================
+# Default Hyperparameters
+# =============================================================================
+DEFAULT_HYPERPARAMETERS = {
+    # Training
+    "n_folds": 5,
+    "max_epochs": 400,
+    "patience": 50,
+    "batch_size": 32,
+    # Message Passing
+    "hidden_dim": 700,
+    "depth": 6,
+    "dropout": 0.1,  # Lower dropout - ensemble provides regularization
+    # FFN
+    "ffn_hidden_dim": 2000,
+    "ffn_num_layers": 2,
+    # Loss function for regression (mae, mse)
+    "criterion": "mae",
+    # Random seed
+    "seed": 42,
+}
+
+# Template parameters (filled in by Workbench)
+TEMPLATE_PARAMS = {
+    "model_type": "uq_regressor",
+    "targets": ['logd'],
+    "feature_list": ['smiles', 'mollogp', 'fr_halogen', 'nbase', 'peoe_vsa6', 'bcut2d_mrlow', 'peoe_vsa7', 'peoe_vsa9', 'vsa_estate1', 'peoe_vsa1', 'numhdonors', 'vsa_estate5', 'smr_vsa3', 'slogp_vsa1', 'vsa_estate7', 'bcut2d_mwhi', 'axp_2dv', 'axp_3dv', 'mi', 'smr_vsa9', 'vsa_estate3', 'estate_vsa9', 'bcut2d_mwlow', 'tpsa', 'vsa_estate10', 'xch_5dv', 'slogp_vsa2', 'nhohcount', 'bcut2d_logplow', 'hallkieralpha', 'c2sp2', 'bcut2d_chglo', 'smr_vsa4', 'maxabspartialcharge', 'estate_vsa6', 'qed', 'slogp_vsa6', 'vsa_estate2', 'bcut2d_logphi', 'vsa_estate8', 'xch_7dv', 'fpdensitymorgan3', 'xpc_6d', 'smr_vsa10', 'axp_0d', 'fr_nh1', 'axp_4dv', 'peoe_vsa2', 'estate_vsa8', 'peoe_vsa5', 'vsa_estate6'],
+    "id_column": "molecule_name",
+    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-reg-chemprop-hybrid/training",
+    "hyperparameters": {},
+}
+
+
+# =============================================================================
+# Helper Functions
+# =============================================================================
+def _compute_std_confidence(df: pd.DataFrame, median_std: float, std_col: str = "prediction_std") -> pd.DataFrame:
+    """Compute confidence score from ensemble prediction_std.
+
+    Uses exponential decay: confidence = exp(-std / median_std)
+    - Low std (ensemble agreement) -> high confidence
+    - High std (ensemble disagreement) -> low confidence
+
+    Args:
+        df: DataFrame with prediction_std column
+        median_std: Median std from training validation set (normalization factor)
+        std_col: Name of the std column to use
+
+    Returns:
+        DataFrame with added 'confidence' column (0.0 to 1.0)
+    """
+    df["confidence"] = np.exp(-df[std_col] / median_std)
+    return df
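Note: the decay constant is the validation-set median std, so the transform is easy to sanity-check by hand. A minimal sketch with hypothetical std values and an assumed median_std of 0.5:

    import numpy as np
    for std in [0.0, 0.25, 0.5, 1.0, 2.0]:
        print(f"std={std:.2f} -> confidence={np.exp(-std / 0.5):.2f}")
    # std=0.00 -> 1.00   (perfect ensemble agreement)
    # std=0.25 -> 0.61
    # std=0.50 -> 0.37   (std equal to median_std always maps to 1/e)
    # std=1.00 -> 0.14
    # std=2.00 -> 0.02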
+
+
+def _find_smiles_column(columns: list[str]) -> str:
+    """Find SMILES column (case-insensitive match for 'smiles')."""
+    smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
+    if smiles_col is None:
+        raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
+    return smiles_col
+
+
+def _create_molecule_datapoints(
+    smiles_list: list[str],
+    targets: np.ndarray | None = None,
+    extra_descriptors: np.ndarray | None = None,
+) -> tuple[list[data.MoleculeDatapoint], list[int]]:
+    """Create ChemProp MoleculeDatapoints from SMILES strings."""
+    from rdkit import Chem
+
+    datapoints, valid_indices = [], []
+    targets = np.atleast_2d(np.array(targets)).T if targets is not None and np.array(targets).ndim == 1 else targets
+
+    for i, smi in enumerate(smiles_list):
+        if Chem.MolFromSmiles(smi) is None:
+            continue
+        y = targets[i].tolist() if targets is not None else None
+        x_d = extra_descriptors[i] if extra_descriptors is not None else None
+        datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
+        valid_indices.append(i)
+
+    return datapoints, valid_indices
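Note: RDKit silently rejects unparseable SMILES, so the returned valid_indices list is what lets callers map predictions back to the original row positions. A minimal sketch of the filtering behavior (hypothetical inputs):

    dps, idx = _create_molecule_datapoints(["CCO", "not_a_smiles", "c1ccccc1"])
    # Chem.MolFromSmiles("not_a_smiles") returns None, so that entry is skipped:
    # len(dps) == 2, idx == [0, 2]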
+
+
+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
+def model_fn(model_dir: str) -> dict:
+    """Load ChemProp MPNN ensemble from the specified directory."""
+    from lightning import pytorch as pl
+
+    metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+    ensemble_models = []
+    for i in range(metadata["n_ensemble"]):
+        model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
+        model.eval()
+        ensemble_models.append(model)
+
+    # Pre-initialize trainer once during model loading (expensive operation)
+    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
+    print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": metadata["n_ensemble"],
+        "target_columns": metadata["target_columns"],
+        "median_std": metadata["median_std"],
+        "trainer": trainer,
+    }
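Note: model_fn/predict_fn follow the SageMaker inference contract, so the pair can be smoke-tested locally without deploying an endpoint. A sketch under the assumption that a trained artifact directory exists and SM_MODEL_DIR points at it (the path below is hypothetical):

    os.environ["SM_MODEL_DIR"] = "/tmp/chemprop_artifacts"  # hypothetical artifact dir
    model_dict = model_fn(os.environ["SM_MODEL_DIR"])
    out = predict_fn(pd.DataFrame({"smiles": ["CCO"]}), model_dict)
    print(out[["prediction", "prediction_std", "confidence"]])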
+
+
+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
+def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+    """Make predictions with ChemProp MPNN ensemble."""
+    model_type = TEMPLATE_PARAMS["model_type"]
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+
+    ensemble_models = model_dict["ensemble_models"]
+    target_columns = model_dict["target_columns"]
+    trainer = model_dict["trainer"]  # Use pre-initialized trainer
+
+    # Load artifacts
+    label_encoder = None
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)
+
+    feature_metadata = None
+    feature_path = os.path.join(model_dir, "feature_metadata.joblib")
+    if os.path.exists(feature_path):
+        feature_metadata = joblib.load(feature_path)
+        print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")
+
+    # Find SMILES column and validate
+    smiles_column = _find_smiles_column(df.columns.tolist())
+    smiles_list = df[smiles_column].tolist()
+
+    valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
+    valid_smiles = [s.strip() for i, s in enumerate(smiles_list) if valid_mask[i]]
+    print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")
+
+    # Initialize output columns
+    if model_type == "classifier":
+        df["prediction"] = pd.Series([None] * len(df), dtype=object)
+    else:
+        for tc in target_columns:
+            df[f"{tc}_pred"] = np.nan
+            df[f"{tc}_pred_std"] = np.nan
+
+    if sum(valid_mask) == 0:
+        return df
+
+    # Prepare extra features (raw, unscaled - model handles scaling)
+    extra_features = None
+    if feature_metadata is not None:
+        extra_cols = feature_metadata["extra_feature_cols"]
+        col_means = np.array(feature_metadata["col_means"])
+        valid_indices = np.where(valid_mask)[0]
+
+        extra_features = np.zeros((len(valid_indices), len(extra_cols)), dtype=np.float32)
+        for j, col in enumerate(extra_cols):
+            if col in df.columns:
+                values = df.iloc[valid_indices][col].values.astype(np.float32)
+                values[np.isnan(values)] = col_means[j]
+                extra_features[:, j] = values
+            else:
+                extra_features[:, j] = col_means[j]
+
+    # Create datapoints and predict
+    datapoints, rdkit_valid = _create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
+    if len(datapoints) == 0:
+        return df
+
+    dataset = data.MoleculeDataset(datapoints)
+    dataloader = data.build_dataloader(dataset, shuffle=False)
+
+    # Ensemble predictions
+    all_preds = []
+    for model in ensemble_models:
+        with torch.inference_mode():
+            predictions = trainer.predict(model, dataloader)
+        preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+        if preds.ndim == 3 and preds.shape[1] == 1:
+            preds = preds.squeeze(axis=1)
+        all_preds.append(preds)
+
+    preds = np.mean(np.stack(all_preds), axis=0)
+    preds_std = np.std(np.stack(all_preds), axis=0)
+    if preds.ndim == 1:
+        preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
+
+    print(f"Inference complete: {preds.shape[0]} predictions")
+
+    # Map predictions back to valid positions
+    valid_positions = np.where(valid_mask)[0][rdkit_valid]
+    valid_mask = np.zeros(len(df), dtype=bool)
+    valid_mask[valid_positions] = True
+
+    if model_type == "classifier" and label_encoder is not None:
+        if preds.shape[1] > 1:
+            class_preds = np.argmax(preds, axis=1)
+            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform(class_preds)
+            proba = pd.Series([None] * len(df), dtype=object)
+            proba.loc[valid_mask] = [p.tolist() for p in preds]
+            df["pred_proba"] = proba
+            df = expand_proba_column(df, label_encoder.classes_)
+        else:
+            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform((preds.flatten() > 0.5).astype(int))
+    else:
+        for t_idx, tc in enumerate(target_columns):
+            df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
+            df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
+        df["prediction"] = df[f"{target_columns[0]}_pred"]
+        df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
+
+        # Compute confidence from ensemble std
+        df = _compute_std_confidence(df, model_dict["median_std"])
+
+    return df
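Note: prediction and prediction_std above are simply the per-row mean and standard deviation across ensemble members. A worked sketch with hypothetical outputs from a three-model ensemble:

    import numpy as np
    all_preds = [np.array([2.0]), np.array([2.2]), np.array([1.8])]
    stacked = np.stack(all_preds)       # shape (3, 1): models x rows
    print(np.mean(stacked, axis=0))     # [2.0]   -> prediction
    print(np.std(stacked, axis=0))      # [0.163] -> prediction_std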
+
+
+# =============================================================================
+# Training
+# =============================================================================
+if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+    import glob
+
+    import awswrangler as wr
+    from lightning import pytorch as pl
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    # Enable Tensor Core optimization for GPUs that support it
+    torch.set_float32_matmul_precision("medium")
+
+    from chemprop import nn
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+
+    # -------------------------------------------------------------------------
+    # Training-only helper function
+    # -------------------------------------------------------------------------
+    def build_mpnn_model(
+        hyperparameters: dict,
+        task: str = "regression",
+        num_classes: int | None = None,
+        n_targets: int = 1,
+        n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None,
+        task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build an MPNN model with specified hyperparameters."""
+        hidden_dim = hyperparameters["hidden_dim"]
+        depth = hyperparameters["depth"]
+        dropout = hyperparameters["dropout"]
+        ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
+        ffn_num_layers = hyperparameters["ffn_num_layers"]
+
+        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+        agg = nn.NormAggregation()
+        ffn_input_dim = hidden_dim + n_extra_descriptors
+
+        if task == "classification" and num_classes is not None:
+            ffn = nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=ffn_input_dim,
+                hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
+            )
+        else:
+            # Map criterion name to ChemProp metric class (must have .clone() method)
+            from chemprop.nn.metrics import MAE, MSE
+
+            criterion_map = {
+                "mae": MAE,
+                "mse": MSE,
+            }
+            criterion_name = hyperparameters.get("criterion", "mae")
+            if criterion_name not in criterion_map:
+                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+            criterion = criterion_map[criterion_name]()
+
+            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+            ffn = nn.RegressionFFN(
+                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
+                criterion=criterion,
+            )
+
+        return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
+
+    # -------------------------------------------------------------------------
+    # Setup: Parse arguments and load data
+    # -------------------------------------------------------------------------
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+    parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+    args = parser.parse_args()
+
+    # Extract template parameters
+    target_columns = TEMPLATE_PARAMS["targets"]
+    model_type = TEMPLATE_PARAMS["model_type"]
+    feature_list = TEMPLATE_PARAMS["feature_list"]
+    id_column = TEMPLATE_PARAMS["id_column"]
+    model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
+    hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}
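Note: the double-splat merge gives user-supplied hyperparameters key-by-key precedence over the defaults. A minimal illustration with hypothetical values:

    defaults = {"n_folds": 5, "dropout": 0.1}
    overrides = {"dropout": 0.2}
    print({**defaults, **overrides})    # {'n_folds': 5, 'dropout': 0.2} - later dict wins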
+
+    if not target_columns or not isinstance(target_columns, list):
+        raise ValueError("'targets' must be a non-empty list of target column names")
+    n_targets = len(target_columns)
+
+    smiles_column = _find_smiles_column(feature_list)
+    extra_feature_cols = [f for f in feature_list if f != smiles_column]
+    use_extra_features = len(extra_feature_cols) > 0
+
+    print(f"Target columns ({n_targets}): {target_columns}")
+    print(f"SMILES column: {smiles_column}")
+    print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
+    print(f"Hyperparameters: {hyperparameters}")
+
+    # Load training data
+    training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
+    print(f"Training Files: {training_files}")
+    all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
+    check_dataframe(all_df, "training_df")
+
+    # Clean data
+    initial_count = len(all_df)
+    all_df = all_df.dropna(subset=[smiles_column])
+    all_df = all_df[all_df[target_columns].notna().any(axis=1)]
+    if len(all_df) < initial_count:
+        print(f"Dropped {initial_count - len(all_df)} rows with missing SMILES/targets")
+
+    print(f"Data shape: {all_df.shape}")
+    for tc in target_columns:
+        print(f"  {tc}: {all_df[tc].notna().sum()} samples")
+
+    # -------------------------------------------------------------------------
+    # Classification setup
+    # -------------------------------------------------------------------------
+    label_encoder = None
+    num_classes = None
+    if model_type == "classifier":
+        if n_targets > 1:
+            raise ValueError("Multi-task classification not supported")
+        label_encoder = LabelEncoder()
+        all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
+        num_classes = len(label_encoder.classes_)
+        print(f"Classification: {num_classes} classes: {label_encoder.classes_}")
+
+    # -------------------------------------------------------------------------
+    # Prepare features
+    # -------------------------------------------------------------------------
+    task = "classification" if model_type == "classifier" else "regression"
+    n_extra = len(extra_feature_cols) if use_extra_features else 0
+
+    all_extra_features, col_means = None, None
+    if use_extra_features:
+        all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
+        col_means = np.nanmean(all_extra_features, axis=0)
+        for i in range(all_extra_features.shape[1]):
+            all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
+
+    all_targets = all_df[target_columns].values.astype(np.float32)
+
+    # Filter invalid SMILES
+    _, valid_indices = _create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
+    all_df = all_df.iloc[valid_indices].reset_index(drop=True)
+    all_targets = all_targets[valid_indices]
+    if all_extra_features is not None:
+        all_extra_features = all_extra_features[valid_indices]
+    print(f"Data after SMILES validation: {all_df.shape}")
+
+    # Task weights for multi-task (inverse sample count)
+    task_weights = None
+    if n_targets > 1 and model_type != "classifier":
+        counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
+        task_weights = (1.0 / counts) / (1.0 / counts).min()
+        print(f"Task weights: {dict(zip(target_columns, task_weights.round(3)))}")
+
+    # -------------------------------------------------------------------------
+    # Cross-validation setup
+    # -------------------------------------------------------------------------
+    n_folds = hyperparameters["n_folds"]
+    batch_size = hyperparameters["batch_size"]
+
+    if n_folds == 1:
+        if "training" in all_df.columns:
+            print("Using 'training' column for train/val split")
+            train_idx = np.where(all_df["training"])[0]
+            val_idx = np.where(~all_df["training"])[0]
+        else:
+            print("WARNING: No 'training' column, using random 80/20 split")
+            train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
+        folds = [(train_idx, val_idx)]
+    else:
+        if model_type == "classifier":
+            kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
+            folds = list(kfold.split(all_df, all_df[target_columns[0]]))
+        else:
+            kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
+            folds = list(kfold.split(all_df))
+
+    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
+
+    # -------------------------------------------------------------------------
+    # Training loop
+    # -------------------------------------------------------------------------
+    oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
+    oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64) if model_type == "classifier" and num_classes else None
+
+    ensemble_models = []
+    for fold_idx, (train_idx, val_idx) in enumerate(folds):
+        print(f"\n{'='*50}")
+        print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
+        print(f"{'='*50}")
+
+        # Split data
+        df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
+        train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
+        train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
+        val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
+        val_extra_raw = val_extra.copy() if val_extra is not None else None
+
+        # Create datasets
+        train_dps, _ = _create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
+        val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
+        train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)
+
+        # Scale features/targets
+        x_d_transform = None
+        if use_extra_features:
+            scaler = train_dataset.normalize_inputs("X_d")
+            val_dataset.normalize_inputs("X_d", scaler)
+            x_d_transform = nn.ScaleTransform.from_standard_scaler(scaler)
+
+        output_transform = None
+        if model_type in ["regressor", "uq_regressor"]:
+            target_scaler = train_dataset.normalize_targets()
+            val_dataset.normalize_targets(target_scaler)
+            output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
+
+        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
+        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
+
+        # Build and train model
+        pl.seed_everything(hyperparameters["seed"] + fold_idx)
+        mpnn = build_mpnn_model(
+            hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
+            n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
+            output_transform=output_transform, task_weights=task_weights,
+        )
+
+        trainer = pl.Trainer(
+            accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
+            callbacks=[
+                pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
+                pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
+            ],
+        )
+        trainer.fit(mpnn, train_loader, val_loader)
+
+        # Load best checkpoint
+        if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
+            checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
+            mpnn.load_state_dict(checkpoint["state_dict"])
+
+        mpnn.eval()
+        ensemble_models.append(mpnn)
+
+        # Out-of-fold predictions (using raw features)
+        val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
+        val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
+
+        with torch.inference_mode():
+            fold_preds = np.concatenate([p.numpy() for p in trainer.predict(mpnn, val_loader_pred)], axis=0)
+        if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
+            fold_preds = fold_preds.squeeze(axis=1)
+
+        if model_type == "classifier" and fold_preds.ndim == 2:
+            oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
+            if oof_proba is not None:
+                oof_proba[val_idx] = fold_preds
+        else:
+            if fold_preds.ndim == 1:
+                fold_preds = fold_preds.reshape(-1, 1)
+            oof_predictions[val_idx] = fold_preds
+
+    print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
+
+    # -------------------------------------------------------------------------
+    # Prepare validation results
+    # -------------------------------------------------------------------------
+    if n_folds == 1:
+        val_mask = ~np.isnan(oof_predictions).all(axis=1)
+        df_val = all_df[val_mask].copy()
+        preds = oof_predictions[val_mask]
+        y_validate = all_targets[val_mask]
+        if oof_proba is not None:
+            oof_proba = oof_proba[val_mask]
+        val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
+    else:
+        df_val = all_df.copy()
+        preds = oof_predictions
+        y_validate = all_targets
+        val_extra_features = all_extra_features
+
+    # -------------------------------------------------------------------------
+    # Compute metrics and prepare output
+    # -------------------------------------------------------------------------
+    median_std = None  # Only set for regression models with ensemble
+    if model_type == "classifier":
+        class_preds = preds[:, 0].astype(int)
+        target_name = target_columns[0]
+        y_true_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
+        preds_decoded = label_encoder.inverse_transform(class_preds)
+
+        score_df = compute_classification_metrics(y_true_decoded, preds_decoded, label_encoder.classes_, target_name)
+        print_classification_metrics(score_df, target_name, label_encoder.classes_)
+        print_confusion_matrix(y_true_decoded, preds_decoded, label_encoder.classes_)
+
+        # Decode target column back to string labels (was encoded for training)
+        df_val[target_name] = y_true_decoded
+        df_val["prediction"] = preds_decoded
+        if oof_proba is not None:
+            df_val["pred_proba"] = [p.tolist() for p in oof_proba]
+            df_val = expand_proba_column(df_val, label_encoder.classes_)
+    else:
+        # Compute ensemble std
+        preds_std = None
+        if len(ensemble_models) > 1:
+            print("Computing prediction_std from ensemble...")
+            val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
+            val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
+            trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
+            all_ens_preds = []
+            for m in ensemble_models:
+                with torch.inference_mode():
+                    ens_preds = np.concatenate([p.numpy() for p in trainer_pred.predict(m, val_loader)], axis=0)
+                if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+                    ens_preds = ens_preds.squeeze(axis=1)
+                all_ens_preds.append(ens_preds)
+            preds_std = np.std(np.stack(all_ens_preds), axis=0)
+            if preds_std.ndim == 1:
+                preds_std = preds_std.reshape(-1, 1)
+
+        print("\n--- Per-target metrics ---")
+        for t_idx, t_name in enumerate(target_columns):
+            valid_mask = ~np.isnan(y_validate[:, t_idx])
+            if valid_mask.sum() > 0:
+                metrics = compute_regression_metrics(y_validate[valid_mask, t_idx], preds[valid_mask, t_idx])
+                print_regression_metrics(metrics)
+
+            df_val[f"{t_name}_pred"] = preds[:, t_idx]
+            df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx] if preds_std is not None else 0.0
+
+        df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
+        df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
+
+        # Compute confidence from ensemble std (only available when an ensemble was trained,
+        # matching the "Only set for regression models with ensemble" note above; guards the
+        # single-model case, where preds_std is None and np.median would fail)
+        if preds_std is not None:
+            median_std = float(np.median(preds_std[:, 0]))
+            print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+            df_val = _compute_std_confidence(df_val, median_std)
+            print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+
+    # -------------------------------------------------------------------------
+    # Save validation predictions to S3
+    # -------------------------------------------------------------------------
+    output_columns = [id_column] if id_column in df_val.columns else []
+    output_columns += target_columns
+    output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
+    output_columns += ["prediction", "prediction_std", "confidence"]
+    output_columns += [c for c in df_val.columns if c.endswith("_proba")]
+    output_columns = [c for c in output_columns if c in df_val.columns]
+
+    wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
+
+    # -------------------------------------------------------------------------
+    # Save model artifacts
+    # -------------------------------------------------------------------------
+    for idx, m in enumerate(ensemble_models):
+        models.save_model(os.path.join(args.model_dir, f"chemprop_model_{idx}.pt"), m)
+    print(f"Saved {len(ensemble_models)} model(s)")
+
+    # Clean up checkpoints
+    for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
+        os.remove(ckpt)
+
+    ensemble_metadata = {
+        "n_ensemble": len(ensemble_models),
+        "n_folds": n_folds,
+        "target_columns": target_columns,
+        "median_std": median_std,  # For confidence calculation during inference
+    }
+    joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
+
+    with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
+        json.dump(hyperparameters, f, indent=2)
+
+    if label_encoder:
+        joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
+
+    if use_extra_features:
+        joblib.dump({"extra_feature_cols": extra_feature_cols, "col_means": col_means.tolist()}, os.path.join(args.model_dir, "feature_metadata.joblib"))
+        print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
+
+    print(f"\nModel training complete! Artifacts saved to {args.model_dir}")