workbench 0.8.201__py3-none-any.whl → 0.8.204__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. workbench/api/df_store.py +17 -108
  2. workbench/api/feature_set.py +41 -7
  3. workbench/api/parameter_store.py +3 -52
  4. workbench/core/artifacts/artifact.py +5 -5
  5. workbench/core/artifacts/df_store_core.py +114 -0
  6. workbench/core/artifacts/endpoint_core.py +184 -75
  7. workbench/core/artifacts/model_core.py +11 -7
  8. workbench/core/artifacts/parameter_store_core.py +98 -0
  9. workbench/core/transforms/features_to_model/features_to_model.py +27 -13
  10. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +11 -0
  11. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  12. workbench/model_scripts/chemprop/chemprop.template +312 -293
  13. workbench/model_scripts/chemprop/generated_model_script.py +316 -297
  14. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +11 -5
  15. workbench/model_scripts/custom_models/uq_models/meta_uq.template +11 -5
  16. workbench/model_scripts/custom_models/uq_models/ngboost.template +11 -5
  17. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +11 -5
  18. workbench/model_scripts/pytorch_model/generated_model_script.py +278 -128
  19. workbench/model_scripts/pytorch_model/pytorch.template +273 -123
  20. workbench/model_scripts/uq_models/generated_model_script.py +20 -11
  21. workbench/model_scripts/uq_models/mapie.template +17 -8
  22. workbench/model_scripts/xgb_model/generated_model_script.py +38 -9
  23. workbench/model_scripts/xgb_model/xgb_model.template +34 -5
  24. workbench/resources/open_source_api.key +1 -1
  25. workbench/utils/chemprop_utils.py +38 -1
  26. workbench/utils/pytorch_utils.py +38 -8
  27. workbench/web_interface/components/model_plot.py +7 -1
  28. {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/METADATA +2 -2
  29. {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/RECORD +33 -33
  30. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  31. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
  32. {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/WHEEL +0 -0
  33. {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/entry_points.txt +0 -0
  34. {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/licenses/LICENSE +0 -0
  35. {workbench-0.8.201.dist-info → workbench-0.8.204.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@
  # - argparse, file loading, S3 writes
  # =============================
 
+ import glob
  import os
  import argparse
  import json
@@ -39,11 +40,13 @@ from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
  from sklearn.preprocessing import LabelEncoder
  from sklearn.metrics import (
  mean_absolute_error,
+ median_absolute_error,
  r2_score,
  root_mean_squared_error,
  precision_recall_fscore_support,
  confusion_matrix,
  )
+ from scipy.stats import spearmanr
  import joblib
 
  # ChemProp imports
@@ -51,12 +54,12 @@ from chemprop import data, models, nn
 
  # Template Parameters
  TEMPLATE_PARAMS = {
- "model_type": "classifier",
- "target": "solubility_class",
- "feature_list": ['smiles'],
- "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-chemprop-class/training",
- "train_all_data": False,
- "hyperparameters": {'max_epochs': 400, 'hidden_dim': 300, 'depth': 3, 'n_folds': 5},
+ "model_type": "uq_regressor",
+ "targets": ['udm_asy_res_efflux_ratio'], # List of target columns (single or multi-task)
+ "feature_list": ['smiles', 'smr_vsa4', 'tpsa', 'nhohcount', 'mollogp', 'peoe_vsa1', 'smr_vsa3', 'nitrogen_span', 'numhdonors', 'minpartialcharge', 'vsa_estate3', 'vsa_estate6', 'tertiary_amine_count', 'hba_hbd_ratio', 'peoe_vsa8', 'estate_vsa4', 'xc_4dv', 'vsa_estate2', 'molmr', 'xp_2dv', 'mi', 'molecular_axis_length', 'vsa_estate4', 'xp_6dv', 'qed', 'estate_vsa8', 'chi1v', 'asphericity', 'axp_1d', 'bcut2d_logphi', 'kappa3', 'axp_7d', 'num_s_centers', 'amphiphilic_moment', 'molecular_asymmetry', 'charge_centroid_distance', 'estate_vsa3', 'vsa_estate8', 'aromatic_interaction_score', 'molecular_volume_3d', 'axp_7dv', 'peoe_vsa3', 'smr_vsa6', 'bcut2d_mrhi', 'radius_of_gyration', 'xpc_4dv', 'minabsestateindex', 'axp_0dv', 'chi4n', 'balabanj', 'bcut2d_mwlow'],
+ "id_column": "udm_mol_bat_id",
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-chemprop-reg-hybrid/training",
+ "hyperparameters": {'n_folds': 5, 'hidden_dim': 700, 'depth': 6, 'dropout': 0.15, 'ffn_hidden_dim': 2000, 'ffn_num_layers': 2},
  }
 
 
@@ -108,14 +111,14 @@ def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFra
 
  def create_molecule_datapoints(
  smiles_list: list[str],
- targets: list[float] | None = None,
+ targets: list[float] | np.ndarray | None = None,
  extra_descriptors: np.ndarray | None = None,
  ) -> tuple[list[data.MoleculeDatapoint], list[int]]:
  """Create ChemProp MoleculeDatapoints from SMILES strings.
 
  Args:
  smiles_list: List of SMILES strings
- targets: Optional list of target values (for training)
+ targets: Optional target values as 2D array (n_samples, n_targets). NaN allowed for missing targets.
  extra_descriptors: Optional array of extra features (n_samples, n_features)
 
  Returns:
@@ -127,6 +130,12 @@ def create_molecule_datapoints(
  valid_indices = []
  invalid_count = 0
 
+ # Convert targets to 2D array if provided
+ if targets is not None:
+ targets = np.atleast_2d(np.array(targets))
+ if targets.shape[0] == 1 and len(smiles_list) > 1:
+ targets = targets.T # Shape was (1, n_samples), transpose to (n_samples, 1)
+
  for i, smi in enumerate(smiles_list):
  # Validate SMILES with RDKit first
  mol = Chem.MolFromSmiles(smi)
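The shape normalization added in this hunk is easy to sanity-check in isolation. A minimal sketch, with illustrative SMILES and target values (not from the package), of how a single-task list and a multi-task array both end up as (n_samples, n_targets):

import numpy as np

smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]      # 3 molecules, illustrative only
single_task = [0.5, 1.2, 0.9]                     # plain list, one value per molecule

t = np.atleast_2d(np.array(single_task))          # shape (1, 3): one row, not one column
if t.shape[0] == 1 and len(smiles_list) > 1:
    t = t.T                                       # transpose to (3, 1): one row per molecule
print(t.shape)                                    # (3, 1)

multi_task = np.array([[0.5, np.nan],             # already (n_samples, n_targets); NaN = missing label
                       [1.2, 2.0],
                       [0.9, 1.1]])
print(np.atleast_2d(multi_task).shape)            # (3, 2), left unchanged

Each row of the resulting array is then what gets passed per molecule as the y value in the datapoint construction below.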
@@ -134,8 +143,9 @@
  invalid_count += 1
  continue
 
- # Build datapoint with optional target and extra descriptors
- y = [targets[i]] if targets is not None else None
+ # Build datapoint with optional target(s) and extra descriptors
+ # For multi-task, y is a list of values (can include NaN for missing targets)
+ y = targets[i].tolist() if targets is not None else None
  x_d = extra_descriptors[i] if extra_descriptors is not None else None
 
  dp = data.MoleculeDatapoint.from_smi(smi, y=y, x_d=x_d)
@@ -152,9 +162,11 @@
  hyperparameters: dict,
  task: str = "regression",
  num_classes: int | None = None,
+ n_targets: int = 1,
  n_extra_descriptors: int = 0,
  x_d_transform: nn.ScaleTransform | None = None,
  output_transform: nn.UnscaleTransform | None = None,
+ task_weights: np.ndarray | None = None,
  ) -> models.MPNN:
  """Build an MPNN model with the specified hyperparameters.
 
@@ -162,19 +174,21 @@
  hyperparameters: Dictionary of model hyperparameters
  task: Either "regression" or "classification"
  num_classes: Number of classes for classification tasks
+ n_targets: Number of target columns (for multi-task regression)
  n_extra_descriptors: Number of extra descriptor features (for hybrid mode)
  x_d_transform: Optional transform for extra descriptors (scaling)
  output_transform: Optional transform for regression output (unscaling targets)
+ task_weights: Optional array of weights for each task (multi-task learning)
 
  Returns:
  Configured MPNN model
  """
  # Model hyperparameters with defaults
- hidden_dim = hyperparameters.get("hidden_dim", 300)
- depth = hyperparameters.get("depth", 3)
- dropout = hyperparameters.get("dropout", 0.0)
- ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 300)
- ffn_num_layers = hyperparameters.get("ffn_num_layers", 1)
+ hidden_dim = hyperparameters.get("hidden_dim", 700)
+ depth = hyperparameters.get("depth", 6)
+ dropout = hyperparameters.get("dropout", 0.15)
+ ffn_hidden_dim = hyperparameters.get("ffn_hidden_dim", 2000)
+ ffn_num_layers = hyperparameters.get("ffn_num_layers", 2)
 
  # Message passing component
  mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
@@ -197,12 +211,20 @@
  )
  else:
  # Regression with optional output transform to unscale predictions
+ # n_tasks controls the number of output heads for multi-task learning
+ # task_weights goes here (in RegressionFFN) to weight loss per task
+ weights_tensor = None
+ if task_weights is not None:
+ weights_tensor = torch.tensor(task_weights, dtype=torch.float32)
+
  ffn = nn.RegressionFFN(
  input_dim=ffn_input_dim,
  hidden_dim=ffn_hidden_dim,
  n_layers=ffn_num_layers,
  dropout=dropout,
+ n_tasks=n_targets,
  output_transform=output_transform,
+ task_weights=weights_tensor,
  )
 
  # Create the MPNN model
@@ -227,31 +249,26 @@ def model_fn(model_dir: str) -> dict:
  Returns:
  Dictionary with ensemble models and metadata
  """
- # Load ensemble metadata
+ # Load ensemble metadata (required)
  ensemble_metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
- if os.path.exists(ensemble_metadata_path):
- ensemble_metadata = joblib.load(ensemble_metadata_path)
- n_ensemble = ensemble_metadata["n_ensemble"]
- else:
- # Backwards compatibility: single model without ensemble metadata
- n_ensemble = 1
+ ensemble_metadata = joblib.load(ensemble_metadata_path)
+ n_ensemble = ensemble_metadata["n_ensemble"]
+ target_columns = ensemble_metadata["target_columns"]
 
  # Load all ensemble models
  ensemble_models = []
  for ens_idx in range(n_ensemble):
  model_path = os.path.join(model_dir, f"chemprop_model_{ens_idx}.pt")
- if not os.path.exists(model_path):
- # Backwards compatibility: try old single model path
- model_path = os.path.join(model_dir, "chemprop_model.pt")
  model = models.MPNN.load_from_file(model_path)
  model.eval()
  ensemble_models.append(model)
 
- print(f"Loaded {len(ensemble_models)} ensemble model(s)")
+ print(f"Loaded {len(ensemble_models)} ensemble model(s), n_targets={len(target_columns)}")
 
  return {
  "ensemble_models": ensemble_models,
  "n_ensemble": n_ensemble,
+ "target_columns": target_columns,
  }
 
 
@@ -297,9 +314,10 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  model_type = TEMPLATE_PARAMS["model_type"]
  model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
 
- # Extract ensemble models
+ # Extract ensemble models and metadata
  ensemble_models = model_dict["ensemble_models"]
  n_ensemble = model_dict["n_ensemble"]
+ target_columns = model_dict["target_columns"]
 
  # Load label encoder if present (classification)
  label_encoder = None
@@ -337,13 +355,14 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  valid_mask = np.array(valid_mask)
  print(f"Valid SMILES: {sum(valid_mask)} / {len(smiles_list)}")
 
- # Initialize prediction column (use object dtype for classifiers to avoid FutureWarning)
+ # Initialize prediction columns (use object dtype for classifiers to avoid FutureWarning)
  if model_type == "classifier":
  df["prediction"] = pd.Series([None] * len(df), dtype=object)
  else:
- df["prediction"] = np.nan
- if n_ensemble > 1:
- df["prediction_std"] = np.nan
+ # Regression: create prediction column for each target
+ for tc in target_columns:
+ df[f"{tc}_pred"] = np.nan
+ df[f"{tc}_pred_std"] = np.nan
 
  if sum(valid_mask) == 0:
  print("Warning: No valid SMILES to predict on")
@@ -408,10 +427,15 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  ens_preds = ens_preds.squeeze(axis=1)
  all_ensemble_preds.append(ens_preds)
 
- # Stack and compute mean/std
+ # Stack and compute mean/std (std is 0 for single model)
  ensemble_preds = np.stack(all_ensemble_preds, axis=0)
  preds = np.mean(ensemble_preds, axis=0)
- preds_std = np.std(ensemble_preds, axis=0) if n_ensemble > 1 else None
+ preds_std = np.std(ensemble_preds, axis=0) # Will be 0s for n_ensemble=1
+
+ # Ensure 2D: (n_samples, n_targets)
+ if preds.ndim == 1:
+ preds = preds.reshape(-1, 1)
+ preds_std = preds_std.reshape(-1, 1)
 
  print(f"Inference: Ensemble predictions shape: {preds.shape}")
 
@@ -440,12 +464,15 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  decoded_preds = label_encoder.inverse_transform(class_preds)
  df.loc[valid_mask, "prediction"] = decoded_preds
  else:
- # Regression: direct predictions
- df.loc[valid_mask, "prediction"] = preds.flatten()
+ # Regression: store predictions for each target
+ for t_idx, tc in enumerate(target_columns):
+ df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
+ df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]
 
- # Add prediction_std for ensemble models
- if preds_std is not None:
- df.loc[valid_mask, "prediction_std"] = preds_std.flatten()
+ # Add prediction/prediction_std aliases for first target
+ first_target = target_columns[0]
+ df["prediction"] = df[f"{first_target}_pred"]
+ df["prediction_std"] = df[f"{first_target}_pred_std"]
 
  return df
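To see the per-target output layout the reworked predict_fn produces, here is a toy sketch of the column handling on a small frame; the target names, SMILES, and values below are hypothetical, only the column-naming pattern mirrors the hunk above:

import numpy as np
import pandas as pd

target_columns = ["logd", "solubility"]                       # illustrative target names
df = pd.DataFrame({"smiles": ["CCO", "bad_smiles", "c1ccccc1"]})
valid_mask = np.array([True, False, True])                    # second SMILES failed validation
preds = np.array([[1.2, -3.1], [0.8, -2.7]])                  # (n_valid, n_targets) ensemble means
preds_std = np.array([[0.05, 0.20], [0.10, 0.15]])            # ensemble disagreement per target

for tc in target_columns:                                     # NaN-initialize so invalid rows stay empty
    df[f"{tc}_pred"] = np.nan
    df[f"{tc}_pred_std"] = np.nan
for t_idx, tc in enumerate(target_columns):
    df.loc[valid_mask, f"{tc}_pred"] = preds[:, t_idx]
    df.loc[valid_mask, f"{tc}_pred_std"] = preds_std[:, t_idx]

# Single-target callers keep working through the aliases for the first target
df["prediction"] = df[f"{target_columns[0]}_pred"]
df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
print(df)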
 
@@ -454,13 +481,18 @@ if __name__ == "__main__":
  """Training script for ChemProp MPNN model"""
 
  # Template Parameters
- target = TEMPLATE_PARAMS["target"]
+ target_columns = TEMPLATE_PARAMS["targets"] # List of target columns
  model_type = TEMPLATE_PARAMS["model_type"]
  feature_list = TEMPLATE_PARAMS["feature_list"]
+ id_column = TEMPLATE_PARAMS["id_column"]
  model_metrics_s3_path = TEMPLATE_PARAMS["model_metrics_s3_path"]
- train_all_data = TEMPLATE_PARAMS["train_all_data"]
  hyperparameters = TEMPLATE_PARAMS["hyperparameters"]
- validation_split = 0.2
+
+ # Validate target_columns
+ if not target_columns or not isinstance(target_columns, list) or len(target_columns) == 0:
+ raise ValueError("'targets' must be a non-empty list of target column names")
+ n_targets = len(target_columns)
+ print(f"Target columns ({n_targets}): {target_columns}")
 
  # Get the SMILES column name from feature_list (user defines this, so we use their exact name)
  smiles_column = find_smiles_column(feature_list)
@@ -502,21 +534,29 @@ if __name__ == "__main__":
 
  check_dataframe(all_df, "training_df")
 
- # Drop rows with missing SMILES or target values
+ # Drop rows with missing SMILES or all target values
  initial_count = len(all_df)
- all_df = all_df.dropna(subset=[smiles_column, target])
+ all_df = all_df.dropna(subset=[smiles_column])
+ # Keep rows that have at least one non-null target (works for single and multi-task)
+ has_any_target = all_df[target_columns].notna().any(axis=1)
+ all_df = all_df[has_any_target]
  dropped = initial_count - len(all_df)
  if dropped > 0:
- print(f"Dropped {dropped} rows with missing SMILES or target values")
+ print(f"Dropped {dropped} rows with missing SMILES or all target values")
 
- print(f"Target: {target}")
+ print(f"Target columns: {target_columns}")
  print(f"Data Shape after cleaning: {all_df.shape}")
+ for tc in target_columns:
+ n_valid = all_df[tc].notna().sum()
+ print(f" {tc}: {n_valid} samples with values")
 
- # Set up label encoder for classification
+ # Set up label encoder for classification (single-target only)
  label_encoder = None
  if model_type == "classifier":
+ if n_targets > 1:
+ raise ValueError("Multi-task classification is not supported. Use regression for multi-task.")
  label_encoder = LabelEncoder()
- all_df[target] = label_encoder.fit_transform(all_df[target])
+ all_df[target_columns[0]] = label_encoder.fit_transform(all_df[target_columns[0]])
  num_classes = len(label_encoder.classes_)
  print(
  f"Classification task with {num_classes} classes: {label_encoder.classes_}"
@@ -528,10 +568,10 @@ if __name__ == "__main__":
  print(f"Hyperparameters: {hyperparameters}")
  task = "classification" if model_type == "classifier" else "regression"
  n_extra = len(extra_feature_cols) if use_extra_features else 0
- max_epochs = hyperparameters.get("max_epochs", 50)
- patience = hyperparameters.get("patience", 10)
- n_folds = hyperparameters.get("n_folds", 1) # Number of CV folds (default: 1 = no CV)
- batch_size = hyperparameters.get("batch_size", min(64, max(16, len(all_df) // 16)))
+ max_epochs = hyperparameters.get("max_epochs", 400)
+ patience = hyperparameters.get("patience", 40)
+ n_folds = hyperparameters.get("n_folds", 5) # Number of CV folds (default: 5)
+ batch_size = hyperparameters.get("batch_size", 16)
 
  # Check extra feature columns exist
  if use_extra_features:
@@ -540,60 +580,108 @@ if __name__ == "__main__":
  raise ValueError(f"Missing extra feature columns in training data: {missing_cols}")
 
  # =========================================================================
- # SINGLE MODEL TRAINING (n_folds=1) - uses train/val split
+ # UNIFIED TRAINING: Works for n_folds=1 (single model) or n_folds>1 (K-fold CV)
  # =========================================================================
+ print(f"Training {'single model' if n_folds == 1 else f'{n_folds}-fold cross-validation ensemble'}...")
+
+ # Prepare extra features and validate SMILES upfront
+ all_extra_features = None
+ col_means = None
+ if use_extra_features:
+ all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
+ col_means = np.nanmean(all_extra_features, axis=0)
+ for i in range(all_extra_features.shape[1]):
+ all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
+
+ # Prepare target array: always 2D (n_samples, n_targets)
+ all_targets = all_df[target_columns].values.astype(np.float32)
+
+ # Filter invalid SMILES from the full dataset
+ _, valid_indices = create_molecule_datapoints(
+ all_df[smiles_column].tolist(), all_targets, all_extra_features
+ )
+ all_df = all_df.iloc[valid_indices].reset_index(drop=True)
+ all_targets = all_targets[valid_indices]
+ if all_extra_features is not None:
+ all_extra_features = all_extra_features[valid_indices]
+ print(f"Data after SMILES validation: {all_df.shape}")
+
+ # Compute dynamic task weights for multi-task regression
+ # Weight = inverse of sample count (normalized so min weight = 1.0)
+ # This gives higher weight to targets with fewer samples
+ task_weights = None
+ if n_targets > 1 and model_type != "classifier":
+ sample_counts = np.array([np.sum(~np.isnan(all_targets[:, t])) for t in range(n_targets)])
+ # Inverse weighting: fewer samples = higher weight
+ inverse_counts = 1.0 / sample_counts
+ # Normalize so minimum weight is 1.0
+ task_weights = inverse_counts / inverse_counts.min()
+ print(f"Task weights (inverse sample count):")
+ for t_idx, t_name in enumerate(target_columns):
+ print(f" {t_name}: {task_weights[t_idx]:.3f} (n={sample_counts[t_idx]})")
+
+ # Create fold splits
  if n_folds == 1:
- print("Training single model (no cross-validation)...")
-
- # Split data
- if train_all_data:
- print("Training on ALL of the data")
- df_train = all_df.copy()
- df_val = all_df.copy()
- elif "training" in all_df.columns:
+ # Single fold: use train/val split from "training" column or random split
+ if "training" in all_df.columns:
  print("Found training column, splitting data based on training column")
- df_train = all_df[all_df["training"]].copy()
- df_val = all_df[~all_df["training"]].copy()
+ train_idx = np.where(all_df["training"])[0]
+ val_idx = np.where(~all_df["training"])[0]
  else:
- print("WARNING: No training column found, splitting data with random state=42")
- df_train, df_val = train_test_split(
- all_df, test_size=validation_split, random_state=42
- )
+ print("WARNING: No training column found, splitting data with random 80/20 split")
+ indices = np.arange(len(all_df))
+ train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)
+ folds = [(train_idx, val_idx)]
+ else:
+ # K-Fold CV
+ if model_type == "classifier":
+ kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
+ split_target = all_df[target_columns[0]]
+ else:
+ kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
+ split_target = None
+ folds = list(kfold.split(all_df, split_target))
+
+ # Initialize storage for out-of-fold predictions: always 2D (n_samples, n_targets)
+ oof_predictions = np.full((len(all_df), n_targets), np.nan, dtype=np.float64)
+ if model_type == "classifier" and num_classes and num_classes > 1:
+ oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
+ else:
+ oof_proba = None
 
- print(f"TRAIN: {df_train.shape}")
- print(f"VALIDATION: {df_val.shape}")
+ ensemble_models = []
 
- # Extract and prepare extra features
- train_extra_features = None
- val_extra_features = None
- col_means = None
+ for fold_idx, (train_idx, val_idx) in enumerate(folds):
+ print(f"\n{'='*50}")
+ print(f"Training Fold {fold_idx + 1}/{len(folds)}")
+ print(f"{'='*50}")
 
- if use_extra_features:
- train_extra_features = df_train[extra_feature_cols].values.astype(np.float32)
- val_extra_features = df_val[extra_feature_cols].values.astype(np.float32)
- col_means = np.nanmean(train_extra_features, axis=0)
- for i in range(train_extra_features.shape[1]):
- train_extra_features[np.isnan(train_extra_features[:, i]), i] = col_means[i]
- val_extra_features[np.isnan(val_extra_features[:, i]), i] = col_means[i]
-
- # Create ChemProp datasets
- train_datapoints, train_valid_idx = create_molecule_datapoints(
- df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra_features
+ # Split data for this fold
+ df_train = all_df.iloc[train_idx].reset_index(drop=True)
+ df_val = all_df.iloc[val_idx].reset_index(drop=True)
+ train_targets = all_targets[train_idx]
+ val_targets = all_targets[val_idx]
+
+ train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
+ val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
+
+ print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
+
+ # Create ChemProp datasets for this fold
+ train_datapoints, _ = create_molecule_datapoints(
+ df_train[smiles_column].tolist(), train_targets, train_extra
  )
- val_datapoints, val_valid_idx = create_molecule_datapoints(
- df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_features
+ val_datapoints, _ = create_molecule_datapoints(
+ df_val[smiles_column].tolist(), val_targets, val_extra
  )
 
- df_train = df_train.iloc[train_valid_idx].reset_index(drop=True)
- df_val = df_val.iloc[val_valid_idx].reset_index(drop=True)
-
  train_dataset = data.MoleculeDataset(train_datapoints)
  val_dataset = data.MoleculeDataset(val_datapoints)
 
- # Save raw validation features for predictions later
- val_extra_raw = val_extra_features[val_valid_idx] if val_extra_features is not None else None
+ # Save raw val features for prediction
+ val_extra_raw = val_extra.copy() if val_extra is not None else None
 
- # Scale features and targets
+ # Scale features and targets for this fold
  x_d_transform = None
  if use_extra_features:
  feature_scaler = train_dataset.normalize_inputs("X_d")
@@ -601,7 +689,7 @@ if __name__ == "__main__":
  x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
 
  output_transform = None
- if model_type == "regressor":
+ if model_type in ["regressor", "uq_regressor"]:
  target_scaler = train_dataset.normalize_targets()
  val_dataset.normalize_targets(target_scaler)
  output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
@@ -609,17 +697,18 @@ if __name__ == "__main__":
  train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
  val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
 
- # Build and train single model
- pl.seed_everything(42)
+ # Build and train model for this fold
+ pl.seed_everything(42 + fold_idx)
  mpnn = build_mpnn_model(
- hyperparameters, task=task, num_classes=num_classes,
+ hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
  n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
+ task_weights=task_weights,
  )
 
  callbacks = [
  pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
  pl.callbacks.ModelCheckpoint(
- dirpath=args.model_dir, filename="best_model_0",
+ dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
  monitor="val_loss", mode="min", save_top_k=1,
  ),
  ]
@@ -636,201 +725,95 @@ if __name__ == "__main__":
  mpnn.load_state_dict(checkpoint["state_dict"])
 
  mpnn.eval()
- ensemble_models = [mpnn]
+ ensemble_models.append(mpnn)
 
- # Make predictions on validation set
+ # Make out-of-fold predictions using raw features
  val_datapoints_raw, _ = create_molecule_datapoints(
- df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_raw
+ df_val[smiles_column].tolist(), val_targets, val_extra_raw
  )
  val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
  val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
 
  with torch.inference_mode():
- val_predictions = trainer.predict(mpnn, val_loader_pred)
- preds = np.concatenate([p.numpy() for p in val_predictions], axis=0)
- if preds.ndim == 3 and preds.shape[1] == 1:
- preds = preds.squeeze(axis=1)
-
- preds_std = None
- y_validate = df_val[target].values
-
- # =========================================================================
- # K-FOLD CROSS-VALIDATION (n_folds > 1) - trains n_folds models
- # =========================================================================
- else:
- print(f"Training {n_folds}-fold cross-validation ensemble...")
-
- # Validate all SMILES upfront and filter invalid ones
- all_extra_features = None
- if use_extra_features:
- all_extra_features = all_df[extra_feature_cols].values.astype(np.float32)
- col_means = np.nanmean(all_extra_features, axis=0)
- for i in range(all_extra_features.shape[1]):
- all_extra_features[np.isnan(all_extra_features[:, i]), i] = col_means[i]
+ fold_predictions = trainer.predict(mpnn, val_loader_pred)
+ fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
+ if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
+ fold_preds = fold_preds.squeeze(axis=1)
+
+ # Store out-of-fold predictions
+ if model_type == "classifier" and fold_preds.ndim == 2:
+ # Store class index in first column for classification
+ oof_predictions[val_idx, 0] = np.argmax(fold_preds, axis=1)
+ if oof_proba is not None:
+ oof_proba[val_idx] = fold_preds
  else:
- col_means = None
+ # Regression: fold_preds shape is (n_val, n_targets) or (n_val,)
+ if fold_preds.ndim == 1:
+ fold_preds = fold_preds.reshape(-1, 1)
+ oof_predictions[val_idx] = fold_preds
 
- # Filter invalid SMILES from the full dataset
- _, valid_indices = create_molecule_datapoints(
- all_df[smiles_column].tolist(), all_df[target].tolist(), all_extra_features
- )
- all_df = all_df.iloc[valid_indices].reset_index(drop=True)
- if all_extra_features is not None:
- all_extra_features = all_extra_features[valid_indices]
- print(f"Data after SMILES validation: {all_df.shape}")
+ print(f"Fold {fold_idx + 1} complete!")
 
- # Set up K-Fold
- if model_type == "classifier":
- kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
- split_target = all_df[target]
- else:
- kfold = KFold(n_splits=n_folds, shuffle=True, random_state=42)
- split_target = None
+ print(f"\nTraining complete! Trained {len(ensemble_models)} model(s).")
 
- # Initialize storage for out-of-fold predictions
- oof_predictions = np.full(len(all_df), np.nan, dtype=np.float64)
- if model_type == "classifier" and num_classes and num_classes > 1:
- oof_proba = np.full((len(all_df), num_classes), np.nan, dtype=np.float64)
- else:
- oof_proba = None
-
- ensemble_models = []
-
- for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(all_df, split_target)):
- print(f"\n{'='*50}")
- print(f"Training Fold {fold_idx + 1}/{n_folds}")
- print(f"{'='*50}")
-
- # Split data for this fold
- df_train = all_df.iloc[train_idx].reset_index(drop=True)
- df_val = all_df.iloc[val_idx].reset_index(drop=True)
-
- train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
- val_extra = all_extra_features[val_idx] if all_extra_features is not None else None
-
- print(f"Fold {fold_idx + 1} - Train: {len(df_train)}, Val: {len(df_val)}")
-
- # Create ChemProp datasets for this fold
- train_datapoints, _ = create_molecule_datapoints(
- df_train[smiles_column].tolist(), df_train[target].tolist(), train_extra
- )
- val_datapoints, _ = create_molecule_datapoints(
- df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra
- )
-
- train_dataset = data.MoleculeDataset(train_datapoints)
- val_dataset = data.MoleculeDataset(val_datapoints)
-
- # Save raw val features for prediction
- val_extra_raw = val_extra.copy() if val_extra is not None else None
-
- # Scale features and targets for this fold
- x_d_transform = None
- if use_extra_features:
- feature_scaler = train_dataset.normalize_inputs("X_d")
- val_dataset.normalize_inputs("X_d", feature_scaler)
- x_d_transform = nn.ScaleTransform.from_standard_scaler(feature_scaler)
-
- output_transform = None
- if model_type == "regressor":
- target_scaler = train_dataset.normalize_targets()
- val_dataset.normalize_targets(target_scaler)
- output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
-
- train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
- val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
-
- # Build and train model for this fold
- pl.seed_everything(42 + fold_idx)
- mpnn = build_mpnn_model(
- hyperparameters, task=task, num_classes=num_classes,
- n_extra_descriptors=n_extra, x_d_transform=x_d_transform, output_transform=output_transform,
- )
-
- callbacks = [
- pl.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min"),
- pl.callbacks.ModelCheckpoint(
- dirpath=args.model_dir, filename=f"best_model_{fold_idx}",
- monitor="val_loss", mode="min", save_top_k=1,
- ),
- ]
-
- trainer = pl.Trainer(
- accelerator="auto", max_epochs=max_epochs, callbacks=callbacks,
- logger=False, enable_progress_bar=True,
- )
-
- trainer.fit(mpnn, train_loader, val_loader)
-
- if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
- checkpoint = torch.load(trainer.checkpoint_callback.best_model_path, weights_only=False)
- mpnn.load_state_dict(checkpoint["state_dict"])
-
- mpnn.eval()
- ensemble_models.append(mpnn)
-
- # Make out-of-fold predictions using raw features
- val_datapoints_raw, _ = create_molecule_datapoints(
- df_val[smiles_column].tolist(), df_val[target].tolist(), val_extra_raw
- )
- val_dataset_raw = data.MoleculeDataset(val_datapoints_raw)
- val_loader_pred = data.build_dataloader(val_dataset_raw, batch_size=batch_size, shuffle=False)
+ # Use out-of-fold predictions for metrics
+ # For n_folds=1, we only have predictions for val_idx, so filter to those rows
+ if n_folds == 1:
+ # oof_predictions is always 2D now: check if any column has a value
+ val_mask = ~np.isnan(oof_predictions).all(axis=1)
+ preds = oof_predictions[val_mask]
+ df_val = all_df[val_mask].copy()
+ y_validate = all_targets[val_mask]
+ if oof_proba is not None:
+ oof_proba = oof_proba[val_mask]
+ val_extra_features = all_extra_features[val_mask] if all_extra_features is not None else None
+ else:
+ preds = oof_predictions
+ df_val = all_df.copy()
+ y_validate = all_targets
+ val_extra_features = all_extra_features
+
+ # Compute prediction_std by running all ensemble models on validation data
+ # For n_folds=1, std will be 0 (only one model). For n_folds>1, std shows ensemble disagreement.
+ preds_std = None
+ if model_type in ["regressor", "uq_regressor"] and len(ensemble_models) > 0:
+ print("Computing prediction_std from ensemble predictions on validation data...")
+ val_datapoints_for_std, _ = create_molecule_datapoints(
+ df_val[smiles_column].tolist(),
+ y_validate,
+ val_extra_features
+ )
+ val_dataset_for_std = data.MoleculeDataset(val_datapoints_for_std)
+ val_loader_for_std = data.build_dataloader(val_dataset_for_std, batch_size=batch_size, shuffle=False)
 
+ all_ensemble_preds_for_std = []
+ trainer_pred = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
+ for ens_model in ensemble_models:
  with torch.inference_mode():
- fold_predictions = trainer.predict(mpnn, val_loader_pred)
- fold_preds = np.concatenate([p.numpy() for p in fold_predictions], axis=0)
- if fold_preds.ndim == 3 and fold_preds.shape[1] == 1:
- fold_preds = fold_preds.squeeze(axis=1)
-
- # Store out-of-fold predictions
- if model_type == "classifier" and fold_preds.ndim == 2:
- oof_predictions[val_idx] = np.argmax(fold_preds, axis=1)
- if oof_proba is not None:
- oof_proba[val_idx] = fold_preds
- else:
- oof_predictions[val_idx] = fold_preds.flatten()
-
- print(f"Fold {fold_idx + 1} complete!")
-
- print(f"\nCross-validation complete! Trained {len(ensemble_models)} models.")
-
- # Use out-of-fold predictions for metrics
- preds = oof_predictions
- preds_std = None # Will compute from ensemble at inference time
- y_validate = all_df[target].values
- df_val = all_df # For saving predictions
+ ens_preds = trainer_pred.predict(ens_model, val_loader_for_std)
+ ens_preds = np.concatenate([p.numpy() for p in ens_preds], axis=0)
+ if ens_preds.ndim == 3 and ens_preds.shape[1] == 1:
+ ens_preds = ens_preds.squeeze(axis=1)
+ all_ensemble_preds_for_std.append(ens_preds)
+
+ # Stack ensemble predictions: shape (n_ensemble, n_samples, n_targets)
+ ensemble_preds_stacked = np.stack(all_ensemble_preds_for_std, axis=0)
+ preds_std = np.std(ensemble_preds_stacked, axis=0)
+ # Ensure 2D
+ if preds_std.ndim == 1:
+ preds_std = preds_std.reshape(-1, 1)
+ print(f"Ensemble prediction_std - mean per target: {np.nanmean(preds_std, axis=0)}")
 
  if model_type == "classifier":
- # Classification metrics - handle multi-class output
- # For CV mode, preds already contains class indices; for single model, preds are probabilities
- if preds.ndim == 2 and preds.shape[1] > 1:
- # Multi-class probabilities: (n_samples, n_classes), take argmax
- class_preds = np.argmax(preds, axis=1)
- has_proba = True
- elif preds.ndim == 1:
- # Either class indices (CV mode) or binary probabilities
- if n_folds > 1:
- # CV mode: preds are already class indices
- class_preds = preds.astype(int)
- has_proba = False
- else:
- # Single model: preds are probabilities
- class_preds = (preds > 0.5).astype(int)
- has_proba = False
- else:
- # Squeeze extra dimensions if needed
- preds = preds.squeeze()
- if preds.ndim == 2:
- class_preds = np.argmax(preds, axis=1)
- has_proba = True
- else:
- class_preds = (preds > 0.5).astype(int)
- has_proba = False
+ # Classification metrics - preds contains class indices in first column from OOF predictions
+ class_preds = preds[:, 0].astype(int)
+ has_proba = oof_proba is not None
 
  print(f"class_preds shape: {class_preds.shape}")
 
- # Decode labels for metrics
- y_validate_decoded = label_encoder.inverse_transform(y_validate.astype(int))
+ # Decode labels for metrics (classification is single-target only)
+ target_name = target_columns[0]
+ y_validate_decoded = label_encoder.inverse_transform(y_validate[:, 0].astype(int))
  preds_decoded = label_encoder.inverse_transform(class_preds)
 
  # Calculate metrics
@@ -841,7 +824,7 @@ if __name__ == "__main__":
 
  score_df = pd.DataFrame(
  {
- target: label_names,
+ target_name: label_names,
  "precision": scores[0],
  "recall": scores[1],
  "f1": scores[2],
@@ -853,7 +836,7 @@ if __name__ == "__main__":
  metrics = ["precision", "recall", "f1", "support"]
  for t in label_names:
  for m in metrics:
- value = score_df.loc[score_df[target] == t, m].iloc[0]
+ value = score_df.loc[score_df[target_name] == t, m].iloc[0]
  print(f"Metrics:{t}:{m} {value}")
 
  # Confusion matrix
@@ -868,34 +851,61 @@ if __name__ == "__main__":
  # Save validation predictions
  df_val = df_val.copy()
  df_val["prediction"] = preds_decoded
- if has_proba and preds.ndim == 2 and preds.shape[1] > 1:
- df_val["pred_proba"] = [p.tolist() for p in preds]
+ if has_proba and oof_proba is not None:
+ df_val["pred_proba"] = [p.tolist() for p in oof_proba]
  df_val = expand_proba_column(df_val, label_names)
 
  else:
- # Regression metrics
- preds_flat = preds.flatten()
- rmse = root_mean_squared_error(y_validate, preds_flat)
- mae = mean_absolute_error(y_validate, preds_flat)
- r2 = r2_score(y_validate, preds_flat)
- print(f"RMSE: {rmse:.3f}")
- print(f"MAE: {mae:.3f}")
- print(f"R2: {r2:.3f}")
- print(f"NumRows: {len(df_val)}")
-
+ # Regression metrics: compute per target (works for single or multi-task)
  df_val = df_val.copy()
- df_val["prediction"] = preds_flat
+ print("\n--- Per-target metrics ---")
+ for t_idx, t_name in enumerate(target_columns):
+ # Get valid (non-NaN) indices for this target
+ target_valid_mask = ~np.isnan(y_validate[:, t_idx])
+ y_true = y_validate[target_valid_mask, t_idx]
+ y_pred = preds[target_valid_mask, t_idx]
+
+ if len(y_true) > 0:
+ rmse = root_mean_squared_error(y_true, y_pred)
+ mae = mean_absolute_error(y_true, y_pred)
+ medae = median_absolute_error(y_true, y_pred)
+ r2 = r2_score(y_true, y_pred)
+ spearman_corr = spearmanr(y_true, y_pred).correlation
+ support = len(y_true)
+ # Print metrics in format expected by SageMaker metric definitions
+ print(f"rmse: {rmse:.3f}")
+ print(f"mae: {mae:.3f}")
+ print(f"medae: {medae:.3f}")
+ print(f"r2: {r2:.3f}")
+ print(f"spearmanr: {spearman_corr:.3f}")
+ print(f"support: {support}")
+
+ # Store predictions in dataframe
+ df_val[f"{t_name}_pred"] = preds[:, t_idx]
+ if preds_std is not None:
+ df_val[f"{t_name}_pred_std"] = preds_std[:, t_idx]
+ else:
+ df_val[f"{t_name}_pred_std"] = 0.0
 
- # Add prediction_std for ensemble models
- if preds_std is not None:
- df_val["prediction_std"] = preds_std.flatten()
- print(f"Ensemble std - mean: {df_val['prediction_std'].mean():.4f}, max: {df_val['prediction_std'].max():.4f}")
+ # Add prediction/prediction_std aliases for first target
+ first_target = target_columns[0]
+ df_val["prediction"] = df_val[f"{first_target}_pred"]
+ df_val["prediction_std"] = df_val[f"{first_target}_pred_std"]
 
  # Save validation predictions to S3
- output_columns = [target, "prediction"]
- if "prediction_std" in df_val.columns:
- output_columns.append("prediction_std")
+ # Include id_column if it exists in df_val
+ output_columns = []
+ if id_column in df_val.columns:
+ output_columns.append(id_column)
+ # Include all target columns and their predictions
+ output_columns += target_columns
+ output_columns += [f"{t}_pred" for t in target_columns]
+ output_columns += [f"{t}_pred_std" for t in target_columns]
+ output_columns += ["prediction", "prediction_std"]
+ # Add proba columns for classifiers
  output_columns += [col for col in df_val.columns if col.endswith("_proba")]
+ # Filter to only columns that exist
+ output_columns = [c for c in output_columns if c in df_val.columns]
  wr.s3.to_csv(
  df_val[output_columns],
  path=f"{model_metrics_s3_path}/validation_predictions.csv",
@@ -908,11 +918,20 @@ if __name__ == "__main__":
  models.save_model(model_path, ens_model)
  print(f"Saved model {model_idx + 1} to {model_path}")
 
+ # Clean up checkpoint files (not needed for inference, reduces artifact size)
+ for ckpt_file in glob.glob(os.path.join(args.model_dir, "best_model_*.ckpt")):
+ os.remove(ckpt_file)
+ print(f"Removed checkpoint: {ckpt_file}")
+
  # Save ensemble metadata (n_ensemble = number of models for inference)
  n_ensemble = len(ensemble_models)
- ensemble_metadata = {"n_ensemble": n_ensemble, "n_folds": n_folds}
+ ensemble_metadata = {
+ "n_ensemble": n_ensemble,
+ "n_folds": n_folds,
+ "target_columns": target_columns,
+ }
  joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
- print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds})")
+ print(f"Saved ensemble metadata (n_ensemble={n_ensemble}, n_folds={n_folds}, targets={target_columns})")
 
  # Save label encoder if classification
  if label_encoder is not None:
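For reference, the per-target validation metrics printed by the new training block reduce to masking out missing labels before scoring. A minimal standalone sketch with made-up numbers, using the same metric functions the script imports:

import numpy as np
from scipy.stats import spearmanr
from sklearn.metrics import mean_absolute_error, median_absolute_error, r2_score, root_mean_squared_error

# Toy out-of-fold results for one target: NaN marks rows with no label for this target
y_validate = np.array([1.0, np.nan, 2.0, 3.5, np.nan, 2.2])
preds = np.array([1.1, 0.4, 1.8, 3.2, 0.9, 2.5])

mask = ~np.isnan(y_validate)                     # score only rows that have a label
y_true, y_pred = y_validate[mask], preds[mask]

print(f"rmse: {root_mean_squared_error(y_true, y_pred):.3f}")
print(f"mae: {mean_absolute_error(y_true, y_pred):.3f}")
print(f"medae: {median_absolute_error(y_true, y_pred):.3f}")
print(f"r2: {r2_score(y_true, y_pred):.3f}")
print(f"spearmanr: {spearmanr(y_true, y_pred).correlation:.3f}")
print(f"support: {mask.sum()}")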