workbench 0.8.219__py3-none-any.whl → 0.8.231__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/feature_set.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +238 -225
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +196 -68
- workbench/model_scripts/chemprop/generated_model_script.py +197 -72
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -64
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +82 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/model_utils.py +0 -1
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +52 -36
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +10 -6
- workbench/web_interface/components/plugins/scatter_plot.py +184 -85
- workbench/web_interface/components/settings_menu.py +185 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,7 @@ DEFAULT_HYPERPARAMETERS = {
     "max_epochs": 400,
     "patience": 50,
     "batch_size": 32,
-    # Message Passing
+    # Message Passing (ignored when using foundation model)
     "hidden_dim": 700,
     "depth": 6,
     "dropout": 0.1,  # Lower dropout - ensemble provides regularization
@@ -45,15 +45,23 @@ DEFAULT_HYPERPARAMETERS = {
     "criterion": "mae",
     # Random seed
     "seed": 42,
+    # Foundation model support
+    # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
+    # - Path to .pt file: Load custom pretrained Chemprop model
+    # - None: Train from scratch (default)
+    "from_foundation": None,
+    # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
+    # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
+    "freeze_mpnn_epochs": 0,
 }

 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
-    "targets": ['
-    "feature_list": ['smiles'
-    "id_column": "
-    "model_metrics_s3_path": "s3://
+    "targets": ['udm_asy_res_value'],
+    "feature_list": ['smiles'],
+    "id_column": "udm_mol_bat_id",
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/logd-value-reg-chemprop-1-dt/training",
     "hyperparameters": {},
 }
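The two new keys are plain hyperparameter entries, so enabling CheMeleon fine-tuning is just a dict override. A minimal sketch, assuming the usual Workbench flow where user-supplied hyperparameters are merged over `DEFAULT_HYPERPARAMETERS` at training time (that merge and how the dict reaches `TEMPLATE_PARAMS["hyperparameters"]` are not shown in this diff):

```python
# Sketch only: enabling the new foundation-model options.
hyperparameters = {
    "from_foundation": "CheMeleon",  # or a path to a .pt file, or None
    "freeze_mpnn_epochs": 10,        # phase 1: train only the FFN head for 10 epochs
}
```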
@@ -114,26 +122,27 @@ def _create_molecule_datapoints(
 # Model Loading (for SageMaker inference)
 # =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load ChemProp MPNN ensemble from the specified directory."""
-    from lightning import pytorch as pl
+    """Load ChemProp MPNN ensemble from the specified directory.

+    Optimized for serverless cold starts - uses direct PyTorch inference
+    instead of Lightning Trainer to minimize startup time.
+    """
     metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+
+    # Load all ensemble models (keep on CPU for serverless compatibility)
+    # ChemProp handles device placement internally
     ensemble_models = []
     for i in range(metadata["n_ensemble"]):
         model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
         model.eval()
         ensemble_models.append(model)

-    # Pre-initialize trainer once during model loading (expensive operation)
-    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
-
     print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
     return {
         "ensemble_models": ensemble_models,
         "n_ensemble": metadata["n_ensemble"],
         "target_columns": metadata["target_columns"],
         "median_std": metadata["median_std"],
-        "trainer": trainer,
     }
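The removed `pl.Trainer` construction (and the `lightning` import) was the main cold-start cost this change targets. A rough local timing sketch, assuming a model directory laid out the way the script expects:

```python
import time

t0 = time.perf_counter()
model_dict = model_fn("/opt/ml/model")  # loads the ensemble; no Trainer is built
print(f"model_fn: {time.perf_counter() - t0:.2f}s")
# The returned dict no longer carries a "trainer" key; predict_fn now runs the
# models directly under torch.inference_mode().
```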
@@ -141,13 +150,15 @@ def model_fn(model_dir: str) -> dict:
 # Inference (for SageMaker inference)
 # =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make predictions with ChemProp MPNN ensemble."""
+    """Make predictions with ChemProp MPNN ensemble.
+
+    Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
+    """
     model_type = TEMPLATE_PARAMS["model_type"]
     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

     ensemble_models = model_dict["ensemble_models"]
     target_columns = model_dict["target_columns"]
-    trainer = model_dict["trainer"]  # Use pre-initialized trainer

     # Load artifacts
     label_encoder = None
@@ -202,18 +213,34 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         return df

     dataset = data.MoleculeDataset(datapoints)
-    dataloader = data.build_dataloader(dataset, shuffle=False)
+    dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)

-    # Ensemble predictions
+    # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
    all_preds = []
    for model in ensemble_models:
+        model_preds = []
+        model.eval()
        with torch.inference_mode():
-
-
+            for batch in dataloader:
+                # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
+                # For inference we only need bmg, V_d, X_d
+                bmg, V_d, X_d, *_ = batch
+                output = model(bmg, V_d, X_d)
+                model_preds.append(output.detach().cpu().numpy())
+
+        if len(model_preds) == 0:
+            print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
+            continue
+
+        preds = np.concatenate(model_preds, axis=0)
        if preds.ndim == 3 and preds.shape[1] == 1:
            preds = preds.squeeze(axis=1)
        all_preds.append(preds)

+    if len(all_preds) == 0:
+        print("Error: No ensemble predictions generated")
+        return df
+
    preds = np.mean(np.stack(all_preds), axis=0)
    preds_std = np.std(np.stack(all_preds), axis=0)
    if preds.ndim == 1:
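The loop above collects one `(n_samples, n_targets)` array per ensemble member; the mean across members becomes the point prediction and the across-member std feeds the UQ columns. A standalone NumPy illustration with made-up numbers:

```python
import numpy as np

per_model = [np.array([[1.0], [2.0]]), np.array([[1.2], [1.8]]), np.array([[0.8], [2.2]])]
stacked = np.stack(per_model)   # (n_models, n_samples, n_targets)
pred = stacked.mean(axis=0)     # -> "<target>_pred"
pred_std = stacked.std(axis=0)  # -> "<target>_pred_std"
print(pred.ravel(), pred_std.ravel())  # [1. 2.] [0.1633 0.1633] (approx)
```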
@@ -243,8 +270,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
     df["prediction"] = df[f"{target_columns[0]}_pred"]
     df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]

-    # Compute confidence from ensemble std
-
+    # Compute confidence from ensemble std (or NaN if single model)
+    if model_dict["median_std"] is not None:
+        df = _compute_std_confidence(df, model_dict["median_std"])
+    else:
+        df["confidence"] = np.nan

     return df
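`_compute_std_confidence` itself is not shown in this diff. As a purely hypothetical stand-in for intuition, one monotone mapping from ensemble std to a 0-1 confidence relative to the stored `median_std` could look like the sketch below; the template's actual formula may differ:

```python
import pandas as pd

def _compute_std_confidence_sketch(df: pd.DataFrame, median_std: float) -> pd.DataFrame:
    # Hypothetical: confidence decays as prediction_std grows relative to the
    # training-time median ensemble std. Not the template's actual formula.
    ratio = df["prediction_std"] / max(median_std, 1e-12)
    df["confidence"] = 1.0 / (1.0 + ratio)
    return df
```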
@@ -279,54 +309,107 @@ if __name__ == "__main__":
     )

     # -------------------------------------------------------------------------
-    # Training-only helper
+    # Training-only helper functions
     # -------------------------------------------------------------------------
-    def
-
-
-
-
-
-
-
-
-
-
-
-
+    def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
+        """Load pretrained MPNN weights from foundation model.
+
+        Args:
+            from_foundation: "CheMeleon" or path to .pt file
+
+        Returns:
+            Tuple of (message_passing, aggregation) modules
+        """
+        import urllib.request
+        from pathlib import Path
+
+        print(f"Loading foundation model: {from_foundation}")
+
+        if from_foundation.lower() == "chemeleon":
+            # Download from Zenodo if not cached
+            cache_dir = Path.home() / ".chemprop" / "foundation"
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            chemeleon_path = cache_dir / "chemeleon_mp.pt"
+
+            if not chemeleon_path.exists():
+                print("  Downloading CheMeleon weights from Zenodo...")
+                urllib.request.urlretrieve(
+                    "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
+                )
+                print(f"  Downloaded to {chemeleon_path}")
+
+            ckpt = torch.load(chemeleon_path, weights_only=True)
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f"  Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        if not os.path.exists(from_foundation):
+            raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")
+
+        ckpt = torch.load(from_foundation, weights_only=False)
+        if "hyper_parameters" in ckpt and "state_dict" in ckpt:
+            # CheMeleon-style checkpoint
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f"  Loaded custom foundation weights (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        # Full MPNN model file
+        pretrained = models.MPNN.load_from_file(from_foundation)
+        print(f"  Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
+        return pretrained.message_passing, pretrained.agg
+
+    def _build_ffn(
+        task: str, input_dim: int, hyperparameters: dict,
+        num_classes: int | None, n_targets: int,
+        output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
+    ) -> nn.Predictor:
+        """Build task-specific FFN head."""
        dropout = hyperparameters["dropout"]
        ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
        ffn_num_layers = hyperparameters["ffn_num_layers"]

-        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-        agg = nn.NormAggregation()
-        ffn_input_dim = hidden_dim + n_extra_descriptors
-
        if task == "classification" and num_classes is not None:
-
-            n_classes=num_classes, input_dim=
+            return nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=input_dim,
                hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
            )
+
+        from chemprop.nn.metrics import MAE, MSE
+        criterion_map = {"mae": MAE, "mse": MSE}
+        criterion_name = hyperparameters.get("criterion", "mae")
+        if criterion_name not in criterion_map:
+            raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+
+        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+        return nn.RegressionFFN(
+            input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+            dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
+            task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
+        )
+
+    def build_mpnn_model(
+        hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
+        n_targets: int = 1, n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build MPNN model, optionally loading pretrained weights."""
+        from_foundation = hyperparameters.get("from_foundation")
+
+        if from_foundation:
+            mp, agg = _load_foundation_weights(from_foundation)
+            ffn_input_dim = mp.output_dim + n_extra_descriptors
        else:
-
-
-
-            criterion_map = {
-                "mae": MAE,
-                "mse": MSE,
-            }
-            criterion_name = hyperparameters.get("criterion", "mae")
-            if criterion_name not in criterion_map:
-                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
-            criterion = criterion_map[criterion_name]()
-
-            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
-            ffn = nn.RegressionFFN(
-                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
-                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
-                criterion=criterion,
+            mp = nn.BondMessagePassing(
+                d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
+                dropout=hyperparameters["dropout"],
            )
+            agg = nn.NormAggregation()
+            ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors

+        ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
        return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)

    # -------------------------------------------------------------------------
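A usage sketch for the helpers added above, with signatures taken from the diff (the surrounding training-script context such as transforms and `task` detection is omitted):

```python
hp = {**DEFAULT_HYPERPARAMETERS, "from_foundation": "CheMeleon", "freeze_mpnn_epochs": 10}
mpnn = build_mpnn_model(hp, task="regression", n_targets=1, n_extra_descriptors=0)
# With from_foundation set, "hidden_dim"/"depth" are ignored: the FFN input
# width is taken from the pretrained MPNN's output_dim instead.
```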
@@ -359,6 +442,14 @@ if __name__ == "__main__":
     print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
     print(f"Hyperparameters: {hyperparameters}")

+    # Log foundation model configuration
+    if hyperparameters.get("from_foundation"):
+        freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
+        print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
+    else:
+        print("Foundation model: None (training from scratch)")
+
     # Load training data
     training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
     print(f"Training Files: {training_files}")
@@ -456,7 +547,7 @@ if __name__ == "__main__":
         print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
         print(f"{'='*50}")

-        # Split data
+        # Split data (val_extra_raw preserves unscaled features for OOF predictions)
         df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
         train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
@@ -484,7 +575,7 @@ if __name__ == "__main__":
         train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
         val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

-        # Build
+        # Build model
         pl.seed_everything(hyperparameters["seed"] + fold_idx)
         mpnn = build_mpnn_model(
             hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
@@ -492,14 +583,39 @@ if __name__ == "__main__":
             output_transform=output_transform, task_weights=task_weights,
         )

-
-
-
-
-
-
-
-
+        # Train model (with optional two-phase foundation training)
+        freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0
+
+        def _set_mpnn_frozen(frozen: bool):
+            for param in mpnn.message_passing.parameters():
+                param.requires_grad = not frozen
+            for param in mpnn.agg.parameters():
+                param.requires_grad = not frozen
+
+        def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
+            callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
+            if save_checkpoint:
+                callbacks.append(pl.callbacks.ModelCheckpoint(
+                    dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
+                ))
+            return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)
+
+        if use_two_phase:
+            # Phase 1: Freeze MPNN, train FFN only
+            print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
+            _set_mpnn_frozen(True)
+            _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)
+
+            # Phase 2: Unfreeze and fine-tune all
+            print("Phase 2: Unfreezing MPNN, continuing training...")
+            _set_mpnn_frozen(False)
+            remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
+            trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
+        else:
+            trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)

         # Load best checkpoint
         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
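A quick sanity check for the freeze helper above: counting trainable parameters should show the drop during phase 1 (only `message_passing` and `agg` are frozen; the FFN head and batch norm stay trainable):

```python
def n_trainable(model) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

_set_mpnn_frozen(True)
print(f"phase 1 trainable: {n_trainable(mpnn):,}")  # FFN head (+ batch norm) only
_set_mpnn_frozen(False)
print(f"phase 2 trainable: {n_trainable(mpnn):,}")  # full model
```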
@@ -509,7 +625,7 @@ if __name__ == "__main__":
         mpnn.eval()
         ensemble_models.append(mpnn)

-        # Out-of-fold predictions (using
+        # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
         val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
         val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)

@@ -599,11 +715,17 @@ if __name__ == "__main__":
     df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
     df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]

-    # Compute confidence from ensemble std
-
-
-
-
+    # Compute confidence from ensemble std (or NaN for single model)
+    if preds_std is not None:
+        median_std = float(np.median(preds_std[:, 0]))
+        print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+        df_val = _compute_std_confidence(df_val, median_std)
+        print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+    else:
+        # Single model - no ensemble std available, confidence is undefined
+        median_std = None
+        df_val["confidence"] = np.nan
+        print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")

     # -------------------------------------------------------------------------
     # Save validation predictions to S3
@@ -633,6 +755,9 @@ if __name__ == "__main__":
         "n_folds": n_folds,
         "target_columns": target_columns,
         "median_std": median_std,  # For confidence calculation during inference
+        # Foundation model provenance (for tracking/reproducibility)
+        "from_foundation": hyperparameters.get("from_foundation", None),
+        "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
     }
     joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))

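The provenance fields ride along in the same `ensemble_metadata.joblib` that `model_fn` already loads, so they can be inspected offline. A small sketch (the path combines the script's default model dir and the filename used above):

```python
import joblib

meta = joblib.load("/opt/ml/model/ensemble_metadata.joblib")
print(meta["n_ensemble"], meta["target_columns"], meta["median_std"])
print(meta.get("from_foundation"), meta.get("freeze_mpnn_epochs"))
```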
@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
     raise RuntimeError(f"{accept_type} accept type is not supported by this script.")


+def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+    """Cap extreme outliers in prediction_std using IQR method.
+
+    Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+    This prevents unreasonably large std values while preserving the
+    relative ordering and keeping meaningful high-uncertainty signals.
+
+    Args:
+        std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+    Returns:
+        Array with outliers capped at the upper fence
+    """
+    if std_array.ndim == 1:
+        std_array = std_array.reshape(-1, 1)
+        squeeze = True
+    else:
+        squeeze = False
+
+    capped = std_array.copy()
+    for col in range(capped.shape[1]):
+        col_data = capped[:, col]
+        q1, q3 = np.percentile(col_data, [25, 75])
+        iqr = q3 - q1
+        upper_bound = q3 + 1.5 * iqr
+        capped[:, col] = np.minimum(col_data, upper_bound)
+
+    return capped.squeeze() if squeeze else capped
+
+
 def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
     """Compute standard regression metrics.

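A quick demonstration of the new `cap_std_outliers` helper on a 1-D array with one extreme value:

```python
import numpy as np

stds = np.array([0.1, 0.12, 0.11, 0.13, 5.0])
print(cap_std_outliers(stds))
# q1=0.11, q3=0.13 -> iqr=0.02 -> upper fence = 0.13 + 1.5*0.02 = 0.16,
# so only the 5.0 is capped: [0.1, 0.12, 0.11, 0.13, 0.16] (approx)
```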
@@ -99,7 +99,6 @@ from rdkit.ML.Descriptors import MoleculeDescriptors
 from mordred import Calculator as MordredCalculator
 from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes

-
 logger = logging.getLogger("workbench")
 logger.setLevel(logging.DEBUG)

@@ -59,12 +59,12 @@ DEFAULT_HYPERPARAMETERS = {

 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
-    "model_type": "
-    "target": "
-    "features": ['
+    "model_type": "classifier",
+    "target": "class",
+    "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "id_column": "udm_mol_bat_id",
-    "compressed_features": [
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-
+    "compressed_features": [],
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-class-pytorch-1-fr/training",
     "hyperparameters": {},
 }

@@ -152,24 +152,30 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         print("Decompressing features for prediction...")
         matched_df, features = decompress_features(matched_df, features, compressed_features)

-    #
-
-    if
-
+    # Impute missing values (categorical with mode, continuous handled by scaler)
+    missing_counts = matched_df[features].isna().sum()
+    if missing_counts.any():
+        missing_features = missing_counts[missing_counts > 0]
+        print(f"Imputing missing values: {missing_features.to_dict()}")
+
+        # Load categorical imputation values if available
+        impute_path = os.path.join(model_dir, "categorical_impute.json")
+        if os.path.exists(impute_path):
+            with open(impute_path) as f:
+                cat_impute_values = json.load(f)
+            for col in categorical_cols:
+                if col in cat_impute_values and matched_df[col].isna().any():
+                    matched_df[col] = matched_df[col].fillna(cat_impute_values[col])
+        # Continuous features are imputed by FeatureScaler.transform() using column means

     # Initialize output columns
     df["prediction"] = np.nan
     if model_type in ["regressor", "uq_regressor"]:
         df["prediction_std"] = np.nan

-
-    if len(complete_df) == 0:
-        print("Warning: No complete rows to predict on")
-        return df
-
-    # Prepare data for inference (with standardization)
+    # Prepare data for inference (with standardization and continuous imputation)
     x_cont, x_cat, _, _, _ = prepare_data(
-
+        matched_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
     )

     # Collect ensemble predictions
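A round-trip sketch of the new imputation contract between training and inference, using a toy column name (`assay_format` is illustrative, not from the diff):

```python
import json
import pandas as pd

# Training side: compute the mode and persist it
train = pd.DataFrame({"assay_format": ["A", "A", "B", None]})
cat_impute_values = {"assay_format": str(train["assay_format"].mode().iloc[0])}
with open("categorical_impute.json", "w") as f:
    json.dump(cat_impute_values, f)

# Inference side: reload and fill categorical NaNs before encoding
infer = pd.DataFrame({"assay_format": [None, "B"]})
with open("categorical_impute.json") as f:
    loaded = json.load(f)
infer["assay_format"] = infer["assay_format"].fillna(loaded["assay_format"])
print(infer["assay_format"].tolist())  # ['A', 'B']
```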
@@ -191,28 +197,20 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         class_preds = np.argmax(avg_probs, axis=1)
         predictions = label_encoder.inverse_transform(class_preds)

-
-        all_proba.loc[~missing_mask] = [p.tolist() for p in avg_probs]
-        df["pred_proba"] = all_proba
+        df["pred_proba"] = [p.tolist() for p in avg_probs]
         df = expand_proba_column(df, label_encoder.classes_)
     else:
         # Regression
         predictions = preds.flatten()
-        df
+        df["prediction_std"] = preds_std.flatten()

         # Add UQ intervals if available
         if uq_models and uq_metadata:
-
-
-
-
-
-            # Copy UQ columns back to main dataframe
-            for col in df_complete.columns:
-                if col.startswith("q_") or col == "confidence":
-                    df.loc[~missing_mask, col] = df_complete[col].values
-
-    df.loc[~missing_mask, "prediction"] = predictions
+            df["prediction"] = predictions  # Set prediction before compute_confidence
+            df = predict_intervals(df, matched_df[features], uq_models, uq_metadata)
+            df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+    df["prediction"] = predictions
     return df

@@ -275,11 +273,11 @@ if __name__ == "__main__":
     all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
     check_dataframe(all_df, "training_df")

-    # Drop rows with missing
+    # Drop rows with missing target (required for training)
     initial_count = len(all_df)
-    all_df = all_df.dropna(subset=
+    all_df = all_df.dropna(subset=[target])
     if len(all_df) < initial_count:
-        print(f"Dropped {initial_count - len(all_df)} rows with missing
+        print(f"Dropped {initial_count - len(all_df)} rows with missing target")

     print(f"Target: {target}")
     print(f"Features: {features}")
@@ -301,6 +299,23 @@ if __name__ == "__main__":
     print(f"Categorical: {categorical_cols}")
     print(f"Continuous: {len(continuous_cols)} columns")

+    # Report and handle missing values in features
+    # Compute categorical imputation values (mode) for use at inference time
+    cat_impute_values = {}
+    for col in categorical_cols:
+        mode_val = all_df[col].mode().iloc[0] if not all_df[col].mode().empty else all_df[col].cat.categories[0]
+        cat_impute_values[col] = str(mode_val)  # Convert to string for JSON serialization
+
+    missing_counts = all_df[features].isna().sum()
+    if missing_counts.any():
+        missing_features = missing_counts[missing_counts > 0]
+        print(f"Missing values in features (will be imputed): {missing_features.to_dict()}")
+        # Impute categorical features with mode (most frequent value)
+        for col in categorical_cols:
+            if all_df[col].isna().any():
+                all_df[col] = all_df[col].fillna(cat_impute_values[col])
+        # Continuous features are imputed by FeatureScaler.transform() using column means
+
     # -------------------------------------------------------------------------
     # Classification setup
     # -------------------------------------------------------------------------
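The fallback to `cat.categories[0]` in the mode computation above guards the all-NaN case, where pandas returns an empty mode. A minimal illustration:

```python
import pandas as pd

s = pd.Series([None, None], dtype="category").cat.add_categories(["A", "B"])
print(s.mode().empty)       # True -> .mode().iloc[0] would raise IndexError
print(s.cat.categories[0])  # 'A' would be used as the imputation value
```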
@@ -506,6 +521,9 @@ if __name__ == "__main__":
     with open(os.path.join(args.model_dir, "feature_metadata.json"), "w") as f:
         json.dump({"continuous_cols": continuous_cols, "categorical_cols": categorical_cols}, f)

+    with open(os.path.join(args.model_dir, "categorical_impute.json"), "w") as f:
+        json.dump(cat_impute_values, f)
+
     with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
         json.dump(hyperparameters, f, indent=2)
