workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of workbench might be problematic.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,649 @@
+# ChemProp Model Template for Workbench
+#
+# This template handles molecular property prediction using ChemProp 2.x MPNN with:
+# - K-fold cross-validation ensemble training (or single train/val split)
+# - Multi-task regression support
+# - Hybrid mode (SMILES + extra molecular descriptors)
+# - Classification (single-target only)
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (lightning, sklearn, awswrangler) are deferred to training time.
+
+import json
+import os
+
+import joblib
+import numpy as np
+import pandas as pd
+import torch
+
+from chemprop import data, models
+
+from model_script_utils import (
+    expand_proba_column,
+    input_fn,
+    output_fn,
+)
+
+# =============================================================================
+# Default Hyperparameters
+# =============================================================================
+DEFAULT_HYPERPARAMETERS = {
+    # Training
+    "n_folds": 5,
+    "max_epochs": 400,
+    "patience": 50,
+    "batch_size": 32,
+    # Message Passing
+    "hidden_dim": 700,
+    "depth": 6,
+    "dropout": 0.1,  # Lower dropout - ensemble provides regularization
+    # FFN
+    "ffn_hidden_dim": 2000,
+    "ffn_num_layers": 2,
+    # Loss function for regression (mae, mse)
+    "criterion": "mae",
+    # Random seed
+    "seed": 42,
+}
+
+# Template parameters (filled in by Workbench)
+TEMPLATE_PARAMS = {
+    "model_type": "uq_regressor",
+    "targets": ['logd'],
+    "feature_list": ['smiles', 'mollogp', 'fr_halogen', 'nbase', 'peoe_vsa6', 'bcut2d_mrlow', 'peoe_vsa7', 'peoe_vsa9', 'vsa_estate1', 'peoe_vsa1', 'numhdonors', 'vsa_estate5', 'smr_vsa3', 'slogp_vsa1', 'vsa_estate7', 'bcut2d_mwhi', 'axp_2dv', 'axp_3dv', 'mi', 'smr_vsa9', 'vsa_estate3', 'estate_vsa9', 'bcut2d_mwlow', 'tpsa', 'vsa_estate10', 'xch_5dv', 'slogp_vsa2', 'nhohcount', 'bcut2d_logplow', 'hallkieralpha', 'c2sp2', 'bcut2d_chglo', 'smr_vsa4', 'maxabspartialcharge', 'estate_vsa6', 'qed', 'slogp_vsa6', 'vsa_estate2', 'bcut2d_logphi', 'vsa_estate8', 'xch_7dv', 'fpdensitymorgan3', 'xpc_6d', 'smr_vsa10', 'axp_0d', 'fr_nh1', 'axp_4dv', 'peoe_vsa2', 'estate_vsa8', 'peoe_vsa5', 'vsa_estate6'],
+    "id_column": "molecule_name",
+    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-reg-chemprop-hybrid/training",
+    "hyperparameters": {},
+}
+
+
+# =============================================================================
+# Helper Functions
+# =============================================================================
+def _compute_std_confidence(df: pd.DataFrame, median_std: float, std_col: str = "prediction_std") -> pd.DataFrame:
+    """Compute confidence score from ensemble prediction_std.
+
+    Uses exponential decay: confidence = exp(-std / median_std)
+    - Low std (ensemble agreement) -> high confidence
+    - High std (ensemble disagreement) -> low confidence
+
+    Args:
+        df: DataFrame with prediction_std column
+        median_std: Median std from training validation set (normalization factor)
+        std_col: Name of the std column to use
+
+    Returns:
+        DataFrame with added 'confidence' column (0.0 to 1.0)
+    """
+    df["confidence"] = np.exp(-df[std_col] / median_std)
+    return df
+
+
+def _find_smiles_column(columns: list[str]) -> str:
+    """Find SMILES column (case-insensitive match for 'smiles')."""
+    smiles_col = next((c for c in columns if c.lower() == "smiles"), None)
+    if smiles_col is None:
+        raise ValueError("Column list must contain a 'smiles' column (case-insensitive)")
+    return smiles_col
+
+
+def _create_molecule_datapoints(
+    smiles_list: list[str],
+    targets: np.ndarray | None = None,
+    extra_descriptors: np.ndarray | None = None,
+) -> tuple[list[data.MoleculeDatapoint], list[int]]:
+    """Create ChemProp MoleculeDatapoints from SMILES strings."""
+    from rdkit import Chem
+
+    datapoints, valid_indices = [], []
+    targets = np.atleast_2d(np.array(targets)).T if targets is not None and np.array(targets).ndim == 1 else targets
+
+    for i, smi in enumerate(smiles_list):
+        if Chem.MolFromSmiles(smi) is None:
+            continue
+        y = targets[i].tolist() if targets is not None else None
+        x_d = extra_descriptors[i] if extra_descriptors is not None else None
+        datapoints.append(data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d))
+        valid_indices.append(i)
+
+    return datapoints, valid_indices
+
+
+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
+def model_fn(model_dir: str) -> dict:
+    """Load ChemProp MPNN ensemble from the specified directory."""
+    from lightning import pytorch as pl
+
+    metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+    ensemble_models = []
+    for i in range(metadata["n_ensemble"]):
+        model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
+        model.eval()
+        ensemble_models.append(model)
+
+    # Pre-initialize trainer once during model loading (expensive operation)
+    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
+    print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": metadata["n_ensemble"],
+        "target_columns": metadata["target_columns"],
+        "median_std": metadata["median_std"],
+        "trainer": trainer,
+    }
+
+
+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
+def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+    """Make predictions with ChemProp MPNN ensemble."""
+    model_type = TEMPLATE_PARAMS["model_type"]
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+
+    ensemble_models = model_dict["ensemble_models"]
+    target_columns = model_dict["target_columns"]
+    trainer = model_dict["trainer"]  # Use pre-initialized trainer
+
+    # Load artifacts
+    label_encoder = None
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)
+
+    feature_metadata = None
+    feature_path = os.path.join(model_dir, "feature_metadata.joblib")
+    if os.path.exists(feature_path):
+        feature_metadata = joblib.load(feature_path)
+        print(f"Hybrid mode: {len(feature_metadata['extra_feature_cols'])} extra features")
+
+    # Find SMILES column and validate
+    smiles_column = _find_smiles_column(df.columns.tolist())
+    smiles_list = df[smiles_column].tolist()
+
+    valid_mask = np.array([bool(s and isinstance(s, str) and s.strip()) for s in smiles_list])
+    valid_smiles = [s.strip() for i, s in enumerate(smiles_list) if valid_mask[i]]
+    print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")
+
+    # Initialize output columns
+    if model_type == "classifier":
+        df["prediction"] = pd.Series([None] * len(df), dtype=object)
+    else:
+        for tc in target_columns:
+            df[f"{tc}_pred"] = np.nan
+            df[f"{tc}_pred_std"] = np.nan
+
+    if sum(valid_mask) == 0:
+        return df
+
+    # Prepare extra features (raw, unscaled - model handles scaling)
+    extra_features = None
+    if feature_metadata is not None:
+        extra_cols = feature_metadata["extra_feature_cols"]
+        col_means = np.array(feature_metadata["col_means"])
+        valid_indices = np.where(valid_mask)[0]
+
+        extra_features = np.zeros((len(valid_indices), len(extra_cols)), dtype=np.float32)
+        for j, col in enumerate(extra_cols):
+            if col in df.columns:
+                values = df.iloc[valid_indices][col].values.astype(np.float32)
+                values[np.isnan(values)] = col_means[j]
+                extra_features[:, j] = values
+            else:
+                extra_features[:, j] = col_means[j]
+
+    # Create datapoints and predict
+    datapoints, rdkit_valid = _create_molecule_datapoints(valid_smiles, extra_descriptors=extra_features)
+    if len(datapoints) == 0:
+        return df
+
+    dataset = data.MoleculeDataset(datapoints)
+    dataloader = data.build_dataloader(dataset, shuffle=False)
+
+    # Ensemble predictions
+    all_preds = []
+    for model in ensemble_models:
+        with torch.inference_mode():
+            predictions = trainer.predict(model, dataloader)
+        preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+        if preds.ndim == 3 and preds.shape[1] == 1:
+            preds = preds.squeeze(axis=1)
+        all_preds.append(preds)
+
+    preds = np.mean(np.stack(all_preds), axis=0)
+    preds_std = np.std(np.stack(all_preds), axis=0)
+    if preds.ndim == 1:
+        preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
+
+    print(f"Inference complete: {preds.shape[0]} predictions")
+
+    # Map predictions back to valid positions
+    valid_positions = np.where(valid_mask)[0][rdkit_valid]
+    valid_mask = np.zeros(len(df), dtype=bool)
+    valid_mask[valid_positions] = True
+
+    if model_type == "classifier" and label_encoder is not None:
+        if preds.shape[1] > 1:
+            class_preds = np.argmax(preds, axis=1)
+            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform(class_preds)
+            proba = pd.Series([None] * len(df), dtype=object)
+            proba.loc[valid_mask] = [p.tolist() for p in preds]
+            df["pred_proba"] = proba
+            df = expand_proba_column(df, label_encoder.classes_)
+        else:
+            df.loc[valid_mask, "prediction"] = label_encoder.inverse_transform((preds.flatten() > 0.5).astype(int))
+    else:
+        for t_idx, tc in enumerate(target_columns):
+            df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
+            df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
+        df["prediction"] = df[f"{target_columns[0]}_pred"]
+        df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
+
+        # Compute confidence from ensemble std
+        df = _compute_std_confidence(df, model_dict["median_std"])
+
+    return df
+
+
+# =============================================================================
+# Training
+# =============================================================================
+if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+    import glob
+
+    import awswrangler as wr
+    from lightning import pytorch as pl
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    # Enable Tensor Core optimization for GPUs that support it
+    torch.set_float32_matmul_precision("medium")
+
+    from chemprop import nn
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+
+    # -------------------------------------------------------------------------
+    # Training-only helper function
+    # -------------------------------------------------------------------------
+    def build_mpnn_model(
+        hyperparameters: dict,
+        task: str = "regression",
+        num_classes: int | None = None,
+        n_targets: int = 1,
+        n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None,
+        task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build an MPNN model with specified hyperparameters."""
+        hidden_dim = hyperparameters["hidden_dim"]
+        depth = hyperparameters["depth"]
+        dropout = hyperparameters["dropout"]
+        ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
+        ffn_num_layers = hyperparameters["ffn_num_layers"]
+
+        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
+        agg = nn.NormAggregation()
+        ffn_input_dim = hidden_dim + n_extra_descriptors
+
+        if task == "classification" and num_classes is not None:
+            ffn = nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=ffn_input_dim,
+                hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
+            )
+        else:
+            # Map criterion name to ChemProp metric class (must have .clone() method)
+            from chemprop.nn.metrics import MAE, MSE
+
+            criterion_map = {
+                "mae": MAE,
+                "mse": MSE,
+            }
+            criterion_name = hyperparameters.get("criterion", "mae")
+            if criterion_name not in criterion_map:
+                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+            criterion = criterion_map[criterion_name]()
+
+            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+            ffn = nn.RegressionFFN(
+                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
+                criterion=criterion,
+            )
+
+        return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
+
+    # -------------------------------------------------------------------------
+    # Setup: Parse arguments and load data
+    # -------------------------------------------------------------------------
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+    parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+    args = parser.parse_args()
+
+    # Extract template parameters
+    target_columns = TEMPLATE_PARAMS["targets"]
+    model_type = TEMPLATE_PARAMS["model_type"]
+    feature_list = TEMPLATE_PARAMS["feature_list"]
+    id_column = TEMPLATE_PARAMS["id_column"]
+    model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
+    hyperparameters = {**DEFAULT_HYPERPARAMETERS, **(TEMPLATE_PARAMS["hyperparameters"] or {})}
+
+    if not target_columns or not isinstance(target_columns, list):
+        raise ValueError("'targets' must be a non-empty list of target column names")
+    n_targets = len(target_columns)
+
+    smiles_column = _find_smiles_column(feature_list)
+    extra_feature_cols = [f for f in feature_list if f != smiles_column]
+    use_extra_features = len(extra_feature_cols) > 0
+
+    print(f"Target columns ({n_targets}): {target_columns}")
+    print(f"SMILES column: {smiles_column}")
+    print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
+    print(f"Hyperparameters: {hyperparameters}")
+
+    # Load training data
+    training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
+    print(f"Training Files: {training_files}")
+    all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
+    check_dataframe(all_df, "training_df")
+
+    # Clean data
+    initial_count = len(all_df)
+    all_df = all_df.dropna(subset=[smiles_column])
+    all_df = all_df[all_df[target_columns].notna().any(axis=1)]
+    if len(all_df) < initial_count:
+        print(f"Dropped {initial_count - len(all_df)} rows with missing SMILES/targets")
+
+    print(f"Data shape: {all_df.shape}")
+    for tc in target_columns:
+        print(f"  {tc}: {all_df[tc].notna().sum()} samples")
+
+    # -------------------------------------------------------------------------
+    # Classification setup
+    # -------------------------------------------------------------------------
+    label_encoder = None
+    num_classes = None
+    if model_type == "classifier":
+        if n_targets > 1:
+            raise ValueError("Multi-task classification not supported")
+        label_encoder = LabelEncoder()
+        all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
+        num_classes = len(label_encoder.classes_)
+        print(f"Classification: {num_classes} classes: {label_encoder.classes_}")
+
+    # -------------------------------------------------------------------------
+    # Prepare features
+    # -------------------------------------------------------------------------
+    task = "classification" if model_type == "classifier" else "regression"
+    n_extra = len(extra_feature_cols) if use_extra_features else 0
+
+    all_extra_features, col_means = None, None
+    if use_extra_features:
+        all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
+        col_means = np.nanmean(all_extra_features, axis=0)
+        for i in range(all_extra_features.shape[1]):
+            all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
+
+    all_targets = all_df[target_columns].values.astype(np.float32)
+
+    # Filter invalid SMILES
+    _, valid_indices = _create_molecule_datapoints(all_df[smiles_column].tolist(), all_targets, all_extra_features)
+    all_df = all_df.iloc[valid_indices].reset_index(drop=True)
+    all_targets = all_targets[valid_indices]
+    if all_extra_features is not None:
+        all_extra_features = all_extra_features[valid_indices]
+    print(f"Data after SMILES validation: {all_df.shape}")
+
+    # Task weights for multi-task (inverse sample count)
+    task_weights = None
+    if n_targets > 1 and model_type != "classifier":
+        counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
+        task_weights = (1.0 / counts) / (1.0 / counts).min()
+        print(f"Task weights: {dict(zip(target_columns, task_weights.round(3)))}")
+
+    # -------------------------------------------------------------------------
+    # Cross-validation setup
+    # -------------------------------------------------------------------------
+    n_folds = hyperparameters["n_folds"]
+    batch_size = hyperparameters["batch_size"]
+
+    if n_folds == 1:
+        if "training" in all_df.columns:
+            print("Using 'training' column for train/val split")
+            train_idx = np.where(all_df["training"])[0]
+            val_idx = np.where(~all_df["training"])[0]
+        else:
+            print("WARNING: No 'training' column, using random 80/20 split")
+            train_idx, val_idx = train_test_split(np.arange(len(all_df)), test_size=0.2, random_state=42)
+        folds = [(train_idx, val_idx)]
+    else:
+        if model_type == "classifier":
+            kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
+            folds = list(kfold.split(all_df, all_df[target_columns[0]]))
+        else:
+            kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
+            folds = list(kfold.split(all_df))
+
+    print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold ensemble'}...")
+
+    # -------------------------------------------------------------------------
+    # Training loop
+    # -------------------------------------------------------------------------
+    oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
+    oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64) if model_type == "classifier" and num_classes else None
+
+    ensemble_models = []
+    for fold_idx, (train_idx, val_idx) in enumerate(folds):
+        print(f"\n{'='*50}")
+        print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
+        print(f"{'='*50}")
+
+        # Split data
+        df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
+        train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
+        train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
+        val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
+        val_extra_raw = val_extra.copy() if val_extra is not None else None
+
+        # Create datasets
+        train_dps, _ = _create_molecule_datapoints(df_train[smiles_column].tolist(), train_targets, train_extra)
+        val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra)
+        train_dataset, val_dataset = data.MoleculeDataset(train_dps), data.MoleculeDataset(val_dps)
+
+        # Scale features/targets
+        x_d_transform = None
+        if use_extra_features:
+            scaler = train_dataset.normalize_inputs("X_d")
+            val_dataset.normalize_inputs("X_d", scaler)
+            x_d_transform = nn.ScaleTransform.from_standard_scaler(scaler)
+
+        output_transform = None
+        if model_type in ["regressor", "uq_regressor"]:
+            target_scaler = train_dataset.normalize_targets()
+            val_dataset.normalize_targets(target_scaler)
+            output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
+
+        train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
+        val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
+
+        # Build and train model
+        pl.seed_everything(hyperparameters["seed"] + fold_idx)
+        mpnn = build_mpnn_model(
+            hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
+            n_extra_descriptors=n_extra, x_d_transform=x_d_transform,
+            output_transform=output_transform, task_weights=task_weights,
+        )
+
+        trainer = pl.Trainer(
+            accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
+            callbacks=[
+                pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
+                pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
+            ],
+        )
+        trainer.fit(mpnn, train_loader, val_loader)
+
+        # Load best checkpoint
+        if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
+            checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
+            mpnn.load_state_dict(checkpoint["state_dict"])
+
+        mpnn.eval()
+        ensemble_models.append(mpnn)
+
+        # Out-of-fold predictions (using raw features)
+        val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
+        val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
+
+        with torch.inference_mode():
+            fold_preds = np.concatenate([p.numpy() for p in trainer.predict(mpnn, val_loader_pred)], axis=0)
+        if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
+            fold_preds = fold_preds.squeeze(axis=1)
+
+        if model_type == "classifier" and fold_preds.ndim == 2:
+            oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
+            if oof_proba is not None:
+                oof_proba[val_idx] = fold_preds
+        else:
+            if fold_preds.ndim == 1:
+                fold_preds = fold_preds.reshape(-1, 1)
+            oof_predictions[val_idx] = fold_preds
+
+    print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
+
+    # -------------------------------------------------------------------------
+    # Prepare validation results
+    # -------------------------------------------------------------------------
+    if n_folds == 1:
+        val_mask = ~np.isnan(oof_predictions).all(axis=1)
+        df_val = all_df[val_mask].copy()
+        preds = oof_predictions[val_mask]
+        y_validate = all_targets[val_mask]
+        if oof_proba is not None:
+            oof_proba = oof_proba[val_mask]
+        val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
+    else:
+        df_val = all_df.copy()
+        preds = oof_predictions
+        y_validate = all_targets
+        val_extra_features = all_extra_features
+
+    # -------------------------------------------------------------------------
+    # Compute metrics and prepare output
+    # -------------------------------------------------------------------------
+    median_std = None  # Only set for regression models with ensemble
+    if model_type == "classifier":
+        class_preds = preds[:, 0].astype(int)
+        target_name = target_columns[0]
+        y_true_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
+        preds_decoded = label_encoder.inverse_transform(class_preds)
+
+        score_df = compute_classification_metrics(y_true_decoded, preds_decoded, label_encoder.classes_, target_name)
+        print_classification_metrics(score_df, target_name, label_encoder.classes_)
+        print_confusion_matrix(y_true_decoded, preds_decoded, label_encoder.classes_)
+
+        # Decode target column back to string labels (was encoded for training)
+        df_val[target_name] = y_true_decoded
+        df_val["prediction"] = preds_decoded
+        if oof_proba is not None:
+            df_val["pred_proba"] = [p.tolist() for p in oof_proba]
+            df_val = expand_proba_column(df_val, label_encoder.classes_)
+    else:
+        # Compute ensemble std
+        preds_std = None
+        if len(ensemble_models) > 1:
+            print("Computing prediction_std from ensemble...")
+            val_dps, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), y_validate, val_extra_features)
+            val_loader = data.build_dataloader(data.MoleculeDataset(val_dps), batch_size=batch_size, shuffle=False)
+            trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+
+            all_ens_preds = []
+            for m in ensemble_models:
+                with torch.inference_mode():
+                    ens_preds = np.concatenate([p.numpy() for p in trainer_pred.predict(m, val_loader)], axis=0)
+                if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+                    ens_preds = ens_preds.squeeze(axis=1)
+                all_ens_preds.append(ens_preds)
+            preds_std = np.std(np.stack(all_ens_preds), axis=0)
+            if preds_std.ndim == 1:
+                preds_std = preds_std.reshape(-1, 1)
+
+        print("\n--- Per-target metrics ---")
+        for t_idx, t_name in enumerate(target_columns):
+            valid_mask = ~np.isnan(y_validate[:, t_idx])
+            if valid_mask.sum() > 0:
+                metrics = compute_regression_metrics(y_validate[valid_mask, t_idx], preds[valid_mask, t_idx])
+                print_regression_metrics(metrics)
+
+            df_val[f"{t_name}_pred"] = preds[:, t_idx]
+            df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx] if preds_std is not None else 0.0
+
+        df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
+        df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
+
+        # Compute confidence from ensemble std
+        median_std = float(np.median(preds_std[:, 0]))
+        print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+        df_val = _compute_std_confidence(df_val, median_std)
+        print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+
+    # -------------------------------------------------------------------------
+    # Save validation predictions to S3
+    # -------------------------------------------------------------------------
+    output_columns = [id_column] if id_column in df_val.columns else []
+    output_columns += target_columns
+    output_columns += [f"{t}_pred" for t in target_columns] + [f"{t}_pred_std" for t in target_columns]
+    output_columns += ["prediction", "prediction_std", "confidence"]
+    output_columns += [c for c in df_val.columns if c.endswith("_proba")]
+    output_columns = [c for c in output_columns if c in df_val.columns]
+
+    wr.s3.to_csv(df_val[output_columns], f"{model_metrics_s3_path}/validation_predictions.csv", index=False)
+
+    # -------------------------------------------------------------------------
+    # Save model artifacts
+    # -------------------------------------------------------------------------
+    for idx, m in enumerate(ensemble_models):
+        models.save_model(os.path.join(args.model_dir, f"chemprop_model_{idx}.pt"), m)
+    print(f"Saved {len(ensemble_models)} model(s)")
+
+    # Clean up checkpoints
+    for ckpt in glob.glob(os.path.join(args.model_dir, "best_*.ckpt")):
+        os.remove(ckpt)
+
+    ensemble_metadata = {
+        "n_ensemble": len(ensemble_models),
+        "n_folds": n_folds,
+        "target_columns": target_columns,
+        "median_std": median_std,  # For confidence calculation during inference
+    }
+    joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
+
+    with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
+        json.dump(hyperparameters, f, indent=2)
+
+    if label_encoder:
+        joblib.dump(label_encoder, os.path.join(args.model_dir, "label_encoder.joblib"))
+
+    if use_extra_features:
+        joblib.dump({"extra_feature_cols": extra_feature_cols, "col_means": col_means.tolist()}, os.path.join(args.model_dir, "feature_metadata.joblib"))
+        print(f"Saved feature metadata for {len(extra_feature_cols)} extra features")
+
+    print(f"\nModel training complete! Artifacts saved to {args.model_dir}")
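A few editorial notes on the new ChemProp template shown in the hunk above; the snippets below are standalone illustrations, not code that ships in the wheel. The template follows the standard SageMaker script-mode contract: `model_fn` loads the ensemble and its metadata once, and `predict_fn` takes a DataFrame containing a `smiles` column (plus any hybrid-mode descriptor columns, which fall back to training-set means when absent) and returns it with `prediction`, `prediction_std`, and `confidence` appended. A minimal local smoke-test sketch, assuming the training block has already written its artifacts and that the template module's functions are in scope (the path is hypothetical):

```python
import os
import pandas as pd

model_dir = "/tmp/chemprop_model"       # hypothetical dir holding the training artifacts
os.environ["SM_MODEL_DIR"] = model_dir  # predict_fn resolves encoder/feature files here

model_dict = model_fn(model_dir)        # loads ensemble_metadata.joblib + chemprop_model_*.pt
df = pd.DataFrame({
    "molecule_name": ["ethanol", "benzene"],
    "smiles": ["CCO", "c1ccccc1"],
    # hybrid mode: omitted descriptor columns are imputed with training means
})
out = predict_fn(df, model_dict)
print(out[["molecule_name", "prediction", "prediction_std", "confidence"]])
```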
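Both `predict_fn` and the validation pass aggregate the K fold models the same way: stack each model's `(n_samples, n_targets)` prediction array, then take the element-wise mean (the ensemble prediction) and standard deviation (the uncertainty signal). The NumPy pattern in isolation, with made-up numbers:

```python
import numpy as np

# Predictions from three hypothetical fold models: 4 samples, 1 target
all_preds = [
    np.array([[2.1], [0.4], [3.0], [1.2]]),
    np.array([[2.0], [0.5], [2.8], [1.1]]),
    np.array([[2.2], [0.3], [3.4], [1.3]]),
]
stacked = np.stack(all_preds)        # shape (n_models, n_samples, n_targets)
preds = np.mean(stacked, axis=0)     # ensemble prediction
preds_std = np.std(stacked, axis=0)  # ensemble spread -> prediction_std
print(preds.ravel())                 # ~[2.1, 0.4, 3.07, 1.2]
print(preds_std.ravel())             # largest for sample 3, where the folds disagree most
```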
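The `confidence` column is then a pure function of that spread: `confidence = exp(-std / median_std)`, where `median_std` is the median ensemble spread on the training validation set and is stored in `ensemble_metadata.joblib` for reuse at inference time. At `std == median_std` this gives `exp(-1) ≈ 0.368`, tending to 1.0 as the ensemble converges and toward 0.0 as it diverges; note the template only computes `median_std` when more than one fold model exists, so it appears to assume `n_folds > 1` for UQ regression. A standalone check of the formula (values are illustrative):

```python
import numpy as np
import pandas as pd

median_std = 0.25  # hypothetical median spread from the validation set
df = pd.DataFrame({"prediction_std": [0.0, 0.25, 0.5, 1.0]})
df["confidence"] = np.exp(-df["prediction_std"] / median_std)
print(df)
# confidence: 1.0 at std=0 (perfect agreement), exp(-1) ~= 0.368 at the median,
# exp(-2) ~= 0.135 at twice the median, exp(-4) ~= 0.018 at four times
```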