workbench 0.8.219__py3-none-any.whl → 0.8.231__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +2 -0
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  5. workbench/algorithms/dataframe/projection_2d.py +8 -2
  6. workbench/algorithms/dataframe/proximity.py +3 -0
  7. workbench/algorithms/dataframe/smart_aggregator.py +161 -0
  8. workbench/algorithms/sql/column_stats.py +0 -1
  9. workbench/algorithms/sql/correlations.py +0 -1
  10. workbench/algorithms/sql/descriptive_stats.py +0 -1
  11. workbench/api/feature_set.py +0 -1
  12. workbench/api/meta.py +0 -1
  13. workbench/cached/cached_meta.py +0 -1
  14. workbench/cached/cached_model.py +37 -7
  15. workbench/core/artifacts/endpoint_core.py +12 -2
  16. workbench/core/artifacts/feature_set_core.py +238 -225
  17. workbench/core/cloud_platform/cloud_meta.py +0 -1
  18. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  19. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  20. workbench/model_script_utils/model_script_utils.py +30 -0
  21. workbench/model_script_utils/uq_harness.py +0 -1
  22. workbench/model_scripts/chemprop/chemprop.template +196 -68
  23. workbench/model_scripts/chemprop/generated_model_script.py +197 -72
  24. workbench/model_scripts/chemprop/model_script_utils.py +30 -0
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  26. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  27. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
  29. workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
  30. workbench/model_scripts/pytorch_model/pytorch.template +47 -29
  31. workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
  32. workbench/model_scripts/script_generation.py +0 -1
  33. workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
  34. workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
  35. workbench/model_scripts/xgb_model/uq_harness.py +0 -1
  36. workbench/scripts/ml_pipeline_sqs.py +71 -2
  37. workbench/themes/dark/custom.css +85 -8
  38. workbench/themes/dark/plotly.json +6 -6
  39. workbench/themes/light/custom.css +172 -64
  40. workbench/themes/light/plotly.json +9 -9
  41. workbench/themes/midnight_blue/custom.css +82 -29
  42. workbench/themes/midnight_blue/plotly.json +1 -1
  43. workbench/utils/aws_utils.py +0 -1
  44. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  45. workbench/utils/chem_utils/projections.py +16 -6
  46. workbench/utils/chem_utils/vis.py +137 -27
  47. workbench/utils/clientside_callbacks.py +41 -0
  48. workbench/utils/markdown_utils.py +57 -0
  49. workbench/utils/model_utils.py +0 -1
  50. workbench/utils/pipeline_utils.py +0 -1
  51. workbench/utils/plot_utils.py +52 -36
  52. workbench/utils/theme_manager.py +95 -30
  53. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  54. workbench/web_interface/components/model_plot.py +2 -0
  55. workbench/web_interface/components/plugin_unit_test.py +0 -1
  56. workbench/web_interface/components/plugins/ag_table.py +2 -4
  57. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  58. workbench/web_interface/components/plugins/model_details.py +10 -6
  59. workbench/web_interface/components/plugins/scatter_plot.py +184 -85
  60. workbench/web_interface/components/settings_menu.py +185 -0
  61. workbench/web_interface/page_views/main_page.py +0 -1
  62. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
  63. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
  64. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
  65. workbench/themes/quartz/base_css.url +0 -1
  66. workbench/themes/quartz/custom.css +0 -117
  67. workbench/themes/quartz/plotly.json +0 -642
  68. workbench/themes/quartz_dark/base_css.url +0 -1
  69. workbench/themes/quartz_dark/custom.css +0 -131
  70. workbench/themes/quartz_dark/plotly.json +0 -642
  71. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
  72. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
  73. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,7 @@ DEFAULT_HYPERPARAMETERS = {
  "max_epochs": 400,
  "patience": 50,
  "batch_size": 32,
- # Message Passing
+ # Message Passing (ignored when using foundation model)
  "hidden_dim": 700,
  "depth": 6,
  "dropout": 0.1, # Lower dropout - ensemble provides regularization
@@ -45,15 +45,23 @@ DEFAULT_HYPERPARAMETERS = {
  "criterion": "mae",
  # Random seed
  "seed": 42,
+ # Foundation model support
+ # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
+ # - Path to .pt file: Load custom pretrained Chemprop model
+ # - None: Train from scratch (default)
+ "from_foundation": None,
+ # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
+ # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
+ "freeze_mpnn_epochs": 0,
  }

  # Template parameters (filled in by Workbench)
  TEMPLATE_PARAMS = {
  "model_type": "uq_regressor",
- "targets": ['logd'],
- "feature_list": ['smiles', 'mollogp', 'fr_halogen', 'nbase', 'peoe_vsa6', 'bcut2d_mrlow', 'peoe_vsa7', 'peoe_vsa9', 'vsa_estate1', 'peoe_vsa1', 'numhdonors', 'vsa_estate5', 'smr_vsa3', 'slogp_vsa1', 'vsa_estate7', 'bcut2d_mwhi', 'axp_2dv', 'axp_3dv', 'mi', 'smr_vsa9', 'vsa_estate3', 'estate_vsa9', 'bcut2d_mwlow', 'tpsa', 'vsa_estate10', 'xch_5dv', 'slogp_vsa2', 'nhohcount', 'bcut2d_logplow', 'hallkieralpha', 'c2sp2', 'bcut2d_chglo', 'smr_vsa4', 'maxabspartialcharge', 'estate_vsa6', 'qed', 'slogp_vsa6', 'vsa_estate2', 'bcut2d_logphi', 'vsa_estate8', 'xch_7dv', 'fpdensitymorgan3', 'xpc_6d', 'smr_vsa10', 'axp_0d', 'fr_nh1', 'axp_4dv', 'peoe_vsa2', 'estate_vsa8', 'peoe_vsa5', 'vsa_estate6'],
- "id_column": "molecule_name",
- "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-reg-chemprop-hybrid/training",
+ "targets": ['udm_asy_res_value'],
+ "feature_list": ['smiles'],
+ "id_column": "udm_mol_bat_id",
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/logd-value-reg-chemprop-1-dt/training",
  "hyperparameters": {},
  }

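The two new keys are ordinary entries in the hyperparameters dict, so they can be overridden through the TEMPLATE_PARAMS["hyperparameters"] slot like any other ChemProp setting. A minimal sketch of such an override dict (values are illustrative; how Workbench merges it over DEFAULT_HYPERPARAMETERS is not shown in this diff):

    # Hypothetical user-supplied overrides passed via TEMPLATE_PARAMS["hyperparameters"]
    hyperparameters = {
        "from_foundation": "CheMeleon",  # or a path to a pretrained .pt file, or None
        "freeze_mpnn_epochs": 10,        # phase 1: train only the FFN head for up to 10 epochs
    }
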
@@ -114,26 +122,27 @@ def _create_molecule_datapoints(
  # Model Loading (for SageMaker inference)
  # =============================================================================
  def model_fn(model_dir: str) -> dict:
- """Load ChemProp MPNN ensemble from the specified directory."""
- from lightning import pytorch as pl
+ """Load ChemProp MPNN ensemble from the specified directory.

+ Optimized for serverless cold starts - uses direct PyTorch inference
+ instead of Lightning Trainer to minimize startup time.
+ """
  metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+
+ # Load all ensemble models (keep on CPU for serverless compatibility)
+ # ChemProp handles device placement internally
  ensemble_models = []
  for i in range(metadata["n_ensemble"]):
  model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
  model.eval()
  ensemble_models.append(model)

- # Pre-initialize trainer once during model loading (expensive operation)
- trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
-
  print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
  return {
  "ensemble_models": ensemble_models,
  "n_ensemble": metadata["n_ensemble"],
  "target_columns": metadata["target_columns"],
  "median_std": metadata["median_std"],
- "trainer": trainer,
  }

@@ -141,13 +150,15 @@ def model_fn(model_dir: str) -> dict:
  # Inference (for SageMaker inference)
  # =============================================================================
  def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
- """Make predictions with ChemProp MPNN ensemble."""
+ """Make predictions with ChemProp MPNN ensemble.
+
+ Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
+ """
  model_type = TEMPLATE_PARAMS["model_type"]
  model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")

  ensemble_models = model_dict["ensemble_models"]
  target_columns = model_dict["target_columns"]
- trainer = model_dict["trainer"] # Use pre-initialized trainer

  # Load artifacts
  label_encoder = None
@@ -202,18 +213,34 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  return df

  dataset = data.MoleculeDataset(datapoints)
- dataloader = data.build_dataloader(dataset, shuffle=False)
+ dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)

- # Ensemble predictions
+ # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
  all_preds = []
  for model in ensemble_models:
+ model_preds = []
+ model.eval()
  with torch.inference_mode():
- predictions = trainer.predict(model, dataloader)
- preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+ for batch in dataloader:
+ # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
+ # For inference we only need bmg, V_d, X_d
+ bmg, V_d, X_d, *_ = batch
+ output = model(bmg, V_d, X_d)
+ model_preds.append(output.detach().cpu().numpy())
+
+ if len(model_preds) == 0:
+ print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
+ continue
+
+ preds = np.concatenate(model_preds, axis=0)
  if preds.ndim == 3 and preds.shape[1] == 1:
  preds = preds.squeeze(axis=1)
  all_preds.append(preds)

+ if len(all_preds) == 0:
+ print("Error: No ensemble predictions generated")
+ return df
+
  preds = np.mean(np.stack(all_preds), axis=0)
  preds_std = np.std(np.stack(all_preds), axis=0)
  if preds.ndim == 1:
@@ -243,8 +270,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  df["prediction"] = df[f"{target_columns[0]}_pred"]
  df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]

- # Compute confidence from ensemble std
- df = _compute_std_confidence(df, model_dict["median_std"])
+ # Compute confidence from ensemble std (or NaN if single model)
+ if model_dict["median_std"] is not None:
+ df = _compute_std_confidence(df, model_dict["median_std"])
+ else:
+ df["confidence"] = np.nan

  return df

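Together, model_fn and predict_fn follow the standard SageMaker inference-script contract (load once, then predict per request), so the Trainer-free path can be smoke-tested locally. A rough sketch, assuming a model archive has been extracted to a local directory and using the id/smiles columns from TEMPLATE_PARAMS above (the path and input row are illustrative):

    import os
    import pandas as pd

    os.environ["SM_MODEL_DIR"] = "/tmp/chemprop_model"   # predict_fn reads artifacts from here
    model_dict = model_fn("/tmp/chemprop_model")          # loads the ensemble, no Lightning Trainer
    smiles_df = pd.DataFrame({"udm_mol_bat_id": ["X-001"], "smiles": ["CCO"]})
    out = predict_fn(smiles_df, model_dict)
    print(out[["prediction", "prediction_std", "confidence"]])
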
@@ -279,54 +309,107 @@ if __name__ == "__main__":
  )

  # -------------------------------------------------------------------------
- # Training-only helper function
+ # Training-only helper functions
  # -------------------------------------------------------------------------
- def build_mpnn_model(
- hyperparameters: dict,
- task: str = "regression",
- num_classes: int | None = None,
- n_targets: int = 1,
- n_extra_descriptors: int = 0,
- x_d_transform: nn.ScaleTransform | None = None,
- output_transform: nn.UnscaleTransform | None = None,
- task_weights: np.ndarray | None = None,
- ) -> models.MPNN:
- """Build an MPNN model with specified hyperparameters."""
- hidden_dim = hyperparameters["hidden_dim"]
- depth = hyperparameters["depth"]
+ def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
+ """Load pretrained MPNN weights from foundation model.
+
+ Args:
+ from_foundation: "CheMeleon" or path to .pt file
+
+ Returns:
+ Tuple of (message_passing, aggregation) modules
+ """
+ import urllib.request
+ from pathlib import Path
+
+ print(f"Loading foundation model: {from_foundation}")
+
+ if from_foundation.lower() == "chemeleon":
+ # Download from Zenodo if not cached
+ cache_dir = Path.home() / ".chemprop" / "foundation"
+ cache_dir.mkdir(parents=True, exist_ok=True)
+ chemeleon_path = cache_dir / "chemeleon_mp.pt"
+
+ if not chemeleon_path.exists():
+ print(" Downloading CheMeleon weights from Zenodo...")
+ urllib.request.urlretrieve(
+ "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
+ )
+ print(f" Downloaded to {chemeleon_path}")
+
+ ckpt = torch.load(chemeleon_path, weights_only=True)
+ mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+ mp.load_state_dict(ckpt["state_dict"])
+ print(f" Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
+ return mp, nn.MeanAggregation()
+
+ if not os.path.exists(from_foundation):
+ raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")
+
+ ckpt = torch.load(from_foundation, weights_only=False)
+ if "hyper_parameters" in ckpt and "state_dict" in ckpt:
+ # CheMeleon-style checkpoint
+ mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+ mp.load_state_dict(ckpt["state_dict"])
+ print(f" Loaded custom foundation weights (hidden_dim={mp.output_dim})")
+ return mp, nn.MeanAggregation()
+
+ # Full MPNN model file
+ pretrained = models.MPNN.load_from_file(from_foundation)
+ print(f" Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
+ return pretrained.message_passing, pretrained.agg
+
+ def _build_ffn(
+ task: str, input_dim: int, hyperparameters: dict,
+ num_classes: int | None, n_targets: int,
+ output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
+ ) -> nn.Predictor:
+ """Build task-specific FFN head."""
  dropout = hyperparameters["dropout"]
  ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
  ffn_num_layers = hyperparameters["ffn_num_layers"]

- mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
- agg = nn.NormAggregation()
- ffn_input_dim = hidden_dim + n_extra_descriptors
-
  if task == "classification" and num_classes is not None:
- ffn = nn.MulticlassClassificationFFN(
- n_classes=num_classes, input_dim=ffn_input_dim,
+ return nn.MulticlassClassificationFFN(
+ n_classes=num_classes, input_dim=input_dim,
  hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
  )
+
+ from chemprop.nn.metrics import MAE, MSE
+ criterion_map = {"mae": MAE, "mse": MSE}
+ criterion_name = hyperparameters.get("criterion", "mae")
+ if criterion_name not in criterion_map:
+ raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+
+ weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+ return nn.RegressionFFN(
+ input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+ dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
+ task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
+ )
+
+ def build_mpnn_model(
+ hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
+ n_targets: int = 1, n_extra_descriptors: int = 0,
+ x_d_transform: nn.ScaleTransform | None = None,
+ output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
+ ) -> models.MPNN:
+ """Build MPNN model, optionally loading pretrained weights."""
+ from_foundation = hyperparameters.get("from_foundation")
+
+ if from_foundation:
+ mp, agg = _load_foundation_weights(from_foundation)
+ ffn_input_dim = mp.output_dim + n_extra_descriptors
  else:
- # Map criterion name to ChemProp metric class (must have .clone() method)
- from chemprop.nn.metrics import MAE, MSE
-
- criterion_map = {
- "mae": MAE,
- "mse": MSE,
- }
- criterion_name = hyperparameters.get("criterion", "mae")
- if criterion_name not in criterion_map:
- raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
- criterion = criterion_map[criterion_name]()
-
- weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
- ffn = nn.RegressionFFN(
- input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
- dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
- criterion=criterion,
+ mp = nn.BondMessagePassing(
+ d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
+ dropout=hyperparameters["dropout"],
  )
+ agg = nn.NormAggregation()
+ ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors

+ ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
  return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)

  # -------------------------------------------------------------------------
@@ -359,6 +442,14 @@ if __name__ == "__main__":
  print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
  print(f"Hyperparameters: {hyperparameters}")

+ # Log foundation model configuration
+ if hyperparameters.get("from_foundation"):
+ freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+ freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
+ print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
+ else:
+ print("Foundation model: None (training from scratch)")
+
  # Load training data
  training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
  print(f"Training Files: {training_files}")
@@ -456,7 +547,7 @@ if __name__ == "__main__":
  print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
  print(f"{'='*50}")

- # Split data
+ # Split data (val_extra_raw preserves unscaled features for OOF predictions)
  df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
  train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
  train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
@@ -484,7 +575,7 @@ if __name__ == "__main__":
  train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
  val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

- # Build and train model
+ # Build model
  pl.seed_everything(hyperparameters["seed"] + fold_idx)
  mpnn = build_mpnn_model(
  hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
@@ -492,14 +583,39 @@ if __name__ == "__main__":
  output_transform=output_transform, task_weights=task_weights,
  )

- trainer = pl.Trainer(
- accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
- callbacks=[
- pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
- pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
- ],
- )
- trainer.fit(mpnn, train_loader, val_loader)
+ # Train model (with optional two-phase foundation training)
+ freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+ use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0
+
+ def _set_mpnn_frozen(frozen: bool):
+ for param in mpnn.message_passing.parameters():
+ param.requires_grad = not frozen
+ for param in mpnn.agg.parameters():
+ param.requires_grad = not frozen
+
+ def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
+ callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
+ if save_checkpoint:
+ callbacks.append(pl.callbacks.ModelCheckpoint(
+ dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
+ ))
+ return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)
+
+ if use_two_phase:
+ # Phase 1: Freeze MPNN, train FFN only
+ print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
+ _set_mpnn_frozen(True)
+ _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)
+
+ # Phase 2: Unfreeze and fine-tune all
+ print("Phase 2: Unfreezing MPNN, continuing training...")
+ _set_mpnn_frozen(False)
+ remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
+ trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
+ trainer.fit(mpnn, train_loader, val_loader)
+ else:
+ trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
+ trainer.fit(mpnn, train_loader, val_loader)

  # Load best checkpoint
  if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
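Worth noting: in the two-phase path the epoch budget is split rather than doubled. With the defaults above (max_epochs=400) and, say, freeze_mpnn_epochs=10, phase 1 runs at most 10 epochs with the message passing and aggregation frozen (only the FFN head trainable), and phase 2 runs at most max(1, 400 - 10) = 390 epochs with everything unfrozen. Early stopping with the same patience applies in each phase, and only the phase-2 trainer writes the best_{fold_idx} checkpoint.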
@@ -509,7 +625,7 @@ if __name__ == "__main__":
  mpnn.eval()
  ensemble_models.append(mpnn)

- # Out-of-fold predictions (using raw features)
+ # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
  val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
  val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)

@@ -599,11 +715,17 @@ if __name__ == "__main__":
  df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
  df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]

- # Compute confidence from ensemble std
- median_std = float(np.median(preds_std[:, 0]))
- print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
- df_val = _compute_std_confidence(df_val, median_std)
- print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+ # Compute confidence from ensemble std (or NaN for single model)
+ if preds_std is not None:
+ median_std = float(np.median(preds_std[:, 0]))
+ print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+ df_val = _compute_std_confidence(df_val, median_std)
+ print(f" Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+ else:
+ # Single model - no ensemble std available, confidence is undefined
+ median_std = None
+ df_val["confidence"] = np.nan
+ print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")

  # -------------------------------------------------------------------------
  # Save validation predictions to S3
@@ -633,6 +755,9 @@ if __name__ == "__main__":
  "n_folds": n_folds,
  "target_columns": target_columns,
  "median_std": median_std, # For confidence calculation during inference
+ # Foundation model provenance (for tracking/reproducibility)
+ "from_foundation": hyperparameters.get("from_foundation", None),
+ "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
  }
  joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))

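Since the foundation settings now travel with the trained artifact, provenance can be checked without digging through training logs. A small sketch (the local path is illustrative):

    import joblib

    meta = joblib.load("/tmp/chemprop_model/ensemble_metadata.joblib")
    print(meta["from_foundation"], meta["freeze_mpnn_epochs"])  # e.g. "CheMeleon", 10
    print(meta["n_ensemble"], meta["target_columns"], meta["median_std"])
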
@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
  raise RuntimeError(f"{accept_type} accept type is not supported by this script.")


+ def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+ """Cap extreme outliers in prediction_std using IQR method.
+
+ Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+ This prevents unreasonably large std values while preserving the
+ relative ordering and keeping meaningful high-uncertainty signals.
+
+ Args:
+ std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+ Returns:
+ Array with outliers capped at the upper fence
+ """
+ if std_array.ndim == 1:
+ std_array = std_array.reshape(-1, 1)
+ squeeze = True
+ else:
+ squeeze = False
+
+ capped = std_array.copy()
+ for col in range(capped.shape[1]):
+ col_data = capped[:, col]
+ q1, q3 = np.percentile(col_data, [25, 75])
+ iqr = q3 - q1
+ upper_bound = q3 + 1.5 * iqr
+ capped[:, col] = np.minimum(col_data, upper_bound)
+
+ return capped.squeeze() if squeeze else capped
+
+
  def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
  """Compute standard regression metrics.

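A quick worked example of the IQR fence, assuming cap_std_outliers is in scope (it is added to the model_script_utils helpers in this release):

    import numpy as np

    std = np.array([0.10, 0.12, 0.15, 0.20, 5.0])
    # Q1 = 0.12, Q3 = 0.20, IQR = 0.08, upper fence = 0.20 + 1.5 * 0.08 = 0.32
    capped = cap_std_outliers(std)
    # Only the extreme 5.0 is capped (to 0.32); the ordering of the rest is unchanged
    print(capped)  # [0.1  0.12 0.15 0.2  0.32]
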
@@ -99,7 +99,6 @@ from rdkit.ML.Descriptors import MoleculeDescriptors
  from mordred import Calculator as MordredCalculator
  from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes

-
  logger = logging.getLogger("workbench")
  logger.setLevel(logging.DEBUG)

@@ -15,7 +15,6 @@ import json
  from mol_standardize import standardize
  from mol_descriptors import compute_descriptors

-
  # TRAINING SECTION
  #
  # This section (__main__) is where SageMaker will execute the training job
@@ -17,7 +17,6 @@ import json
  # Local imports
  from fingerprints import compute_morgan_fingerprints

-
  # TRAINING SECTION
  #
  # This section (__main__) is where SageMaker will execute the training job
@@ -59,12 +59,12 @@ DEFAULT_HYPERPARAMETERS = {

  # Template parameters (filled in by Workbench)
  TEMPLATE_PARAMS = {
- "model_type": "uq_regressor",
- "target": "udm_asy_res_efflux_ratio",
- "features": ['fingerprint'],
+ "model_type": "classifier",
+ "target": "class",
+ "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
  "id_column": "udm_mol_bat_id",
- "compressed_features": ['fingerprint'],
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-fp-pytorch/training",
+ "compressed_features": [],
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-class-pytorch-1-fr/training",
  "hyperparameters": {},
  }

@@ -152,24 +152,30 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  print("Decompressing features for prediction...")
  matched_df, features = decompress_features(matched_df, features, compressed_features)

- # Track missing features
- missing_mask = matched_df[features].isna().any(axis=1)
- if missing_mask.any():
- print(f"Warning: {missing_mask.sum()} rows have missing features")
+ # Impute missing values (categorical with mode, continuous handled by scaler)
+ missing_counts = matched_df[features].isna().sum()
+ if missing_counts.any():
+ missing_features = missing_counts[missing_counts > 0]
+ print(f"Imputing missing values: {missing_features.to_dict()}")
+
+ # Load categorical imputation values if available
+ impute_path = os.path.join(model_dir, "categorical_impute.json")
+ if os.path.exists(impute_path):
+ with open(impute_path) as f:
+ cat_impute_values = json.load(f)
+ for col in categorical_cols:
+ if col in cat_impute_values and matched_df[col].isna().any():
+ matched_df[col] = matched_df[col].fillna(cat_impute_values[col])
+ # Continuous features are imputed by FeatureScaler.transform() using column means

  # Initialize output columns
  df["prediction"] = np.nan
  if model_type in ["regressor", "uq_regressor"]:
  df["prediction_std"] = np.nan

- complete_df = matched_df[~missing_mask].copy()
- if len(complete_df) == 0:
- print("Warning: No complete rows to predict on")
- return df
-
- # Prepare data for inference (with standardization)
+ # Prepare data for inference (with standardization and continuous imputation)
  x_cont, x_cat, _, _, _ = prepare_data(
- complete_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
+ matched_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
  )

  # Collect ensemble predictions
@@ -191,28 +197,20 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  class_preds = np.argmax(avg_probs, axis=1)
  predictions = label_encoder.inverse_transform(class_preds)

- all_proba = pd.Series([None] * len(df), index=df.index, dtype=object)
- all_proba.loc[~missing_mask] = [p.tolist() for p in avg_probs]
- df["pred_proba"] = all_proba
+ df["pred_proba"] = [p.tolist() for p in avg_probs]
  df = expand_proba_column(df, label_encoder.classes_)
  else:
  # Regression
  predictions = preds.flatten()
- df.loc[~missing_mask, "prediction_std"] = preds_std.flatten()
+ df["prediction_std"] = preds_std.flatten()

  # Add UQ intervals if available
  if uq_models and uq_metadata:
- X_complete = complete_df[features]
- df_complete = df.loc[~missing_mask].copy()
- df_complete["prediction"] = predictions # Set prediction before compute_confidence
- df_complete = predict_intervals(df_complete, X_complete, uq_models, uq_metadata)
- df_complete = compute_confidence(df_complete, uq_metadata["median_interval_width"], "q_10", "q_90")
- # Copy UQ columns back to main dataframe
- for col in df_complete.columns:
- if col.startswith("q_") or col == "confidence":
- df.loc[~missing_mask, col] = df_complete[col].values
-
- df.loc[~missing_mask, "prediction"] = predictions
+ df["prediction"] = predictions # Set prediction before compute_confidence
+ df = predict_intervals(df, matched_df[features], uq_models, uq_metadata)
+ df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+ df["prediction"] = predictions
  return df

@@ -275,11 +273,11 @@ if __name__ == "__main__":
  all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
  check_dataframe(all_df, "training_df")

- # Drop rows with missing features
+ # Drop rows with missing target (required for training)
  initial_count = len(all_df)
- all_df = all_df.dropna(subset=features)
+ all_df = all_df.dropna(subset=[target])
  if len(all_df) < initial_count:
- print(f"Dropped {initial_count - len(all_df)} rows with missing features")
+ print(f"Dropped {initial_count - len(all_df)} rows with missing target")

  print(f"Target: {target}")
  print(f"Features: {features}")
@@ -301,6 +299,23 @@ if __name__ == "__main__":
  print(f"Categorical: {categorical_cols}")
  print(f"Continuous: {len(continuous_cols)} columns")

+ # Report and handle missing values in features
+ # Compute categorical imputation values (mode) for use at inference time
+ cat_impute_values = {}
+ for col in categorical_cols:
+ mode_val = all_df[col].mode().iloc[0] if not all_df[col].mode().empty else all_df[col].cat.categories[0]
+ cat_impute_values[col] = str(mode_val) # Convert to string for JSON serialization
+
+ missing_counts = all_df[features].isna().sum()
+ if missing_counts.any():
+ missing_features = missing_counts[missing_counts > 0]
+ print(f"Missing values in features (will be imputed): {missing_features.to_dict()}")
+ # Impute categorical features with mode (most frequent value)
+ for col in categorical_cols:
+ if all_df[col].isna().any():
+ all_df[col] = all_df[col].fillna(cat_impute_values[col])
+ # Continuous features are imputed by FeatureScaler.transform() using column means
+
  # -------------------------------------------------------------------------
  # Classification setup
@@ -506,6 +521,9 @@ if __name__ == "__main__":
  with open(os.path.join(args.model_dir, "feature_metadata.json"), "w") as f:
  json.dump({"continuous_cols": continuous_cols, "categorical_cols": categorical_cols}, f)

+ with open(os.path.join(args.model_dir, "categorical_impute.json"), "w") as f:
+ json.dump(cat_impute_values, f)
+
  with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
  json.dump(hyperparameters, f, indent=2)