pg-sui 1.6.14.dev9__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -110,15 +110,18 @@ class ImputeNLPCA(BaseNNImputer):
  tree_parser: Optional["TreeParser"] = None,
  config: NLPCAConfig | dict | str | None = None,
  overrides: dict | None = None,
- simulate_missing: bool = False,
- sim_strategy: Literal[
- "random",
- "random_weighted",
- "random_weighted_inv",
- "nonrandom",
- "nonrandom_weighted",
- ] = "random",
- sim_prop: float = 0.10,
+ simulate_missing: bool | None = None,
+ sim_strategy: (
+ Literal[
+ "random",
+ "random_weighted",
+ "random_weighted_inv",
+ "nonrandom",
+ "nonrandom_weighted",
+ ]
+ | None
+ ) = None,
+ sim_prop: float | None = None,
  sim_kwargs: dict | None = None,
  ):
  """Initializes the ImputeNLPCA imputer with genotype data and configuration.
@@ -130,10 +133,10 @@ class ImputeNLPCA(BaseNNImputer):
  tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for population-specific modes.
  config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
  overrides (dict | None): Dot-key overrides (e.g. {'model.latent_dim': 4}).
- simulate_missing (bool): Whether to simulate missing data during training.
- sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Strategy for simulating missing data.
- sim_prop (float): Proportion of data to simulate as missing.
- sim_kwargs (dict | None): Additional keyword arguments for missing data simulation.
+ simulate_missing (bool | None): Whether to simulate missing data during training. If None, uses config defaults.
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"] | None): Strategy for simulating missing data. If None, uses config default.
+ sim_prop (float | None): Proportion of data to simulate as missing. If None, uses config default.
+ sim_kwargs (dict | None): Additional keyword arguments for missing data simulation (overrides config kwargs).
  """
  self.model_name = "ImputeNLPCA"
  self.genotype_data = genotype_data
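The `overrides` parameter documented above takes flat dot-keys and is applied to the nested config by `apply_dot_overrides` (next hunk). For readers unfamiliar with the pattern, here is a minimal sketch; the dict-based helper below is purely illustrative and is not pg-sui's actual implementation, which operates on the `NLPCAConfig` dataclass.

    # Illustrative only: apply {'model.latent_dim': 4}-style overrides to a nested mapping.
    def apply_dot_overrides_sketch(cfg: dict, overrides: dict) -> dict:
        for dotted_key, value in overrides.items():
            node = cfg
            *parents, leaf = dotted_key.split(".")
            for key in parents:
                node = node.setdefault(key, {})
            node[leaf] = value  # explicit overrides take highest precedence
        return cfg

    cfg = {"model": {"latent_dim": 2}, "io": {"verbose": True}}
    cfg = apply_dot_overrides_sketch(cfg, {"model.latent_dim": 4})
    assert cfg["model"]["latent_dim"] == 4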
@@ -141,6 +144,7 @@ class ImputeNLPCA(BaseNNImputer):
 
  # Normalize config first, then apply overrides (highest precedence)
  cfg = ensure_nlpca_config(config)
+
  if overrides:
  cfg = apply_dot_overrides(cfg, overrides)
 
@@ -153,9 +157,7 @@ class ImputeNLPCA(BaseNNImputer):
  verbose=self.cfg.io.verbose,
  )
  self.logger = configure_logger(
- logman.get_logger(),
- verbose=self.cfg.io.verbose,
- debug=self.cfg.io.debug,
+ logman.get_logger(), verbose=self.cfg.io.verbose, debug=self.cfg.io.debug
  )
 
  # Initialize BaseNNImputer with device/dirs/logging from config
@@ -178,6 +180,7 @@ class ImputeNLPCA(BaseNNImputer):
  self.debug = self.cfg.io.debug
 
  self.rng = np.random.default_rng(self.seed)
+ self.pos_weights_: torch.Tensor | None = None
 
  # Model/train hyperparams
  self.latent_dim = self.cfg.model.latent_dim
@@ -242,10 +245,31 @@ class ImputeNLPCA(BaseNNImputer):
  self.num_classes_ = 3
  self.model_params: Dict[str, Any] = {}
 
- self.simulate_missing = simulate_missing
- self.sim_strategy = sim_strategy
- self.sim_prop = float(sim_prop)
- self.sim_kwargs = sim_kwargs or {}
+ sim_cfg = getattr(self.cfg, "sim", None)
+ sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
+
+ if sim_kwargs:
+ sim_cfg_kwargs.update(sim_kwargs)
+
+ if sim_cfg is None:
+ default_strategy = "random"
+ default_prop = 0.10
+ else:
+ default_strategy = sim_cfg.sim_strategy
+ default_prop = sim_cfg.sim_prop
+
+ self.simulate_missing = (
+ (
+ sim_cfg.simulate_missing
+ if simulate_missing is None
+ else bool(simulate_missing)
+ )
+ if sim_cfg is not None
+ else bool(simulate_missing)
+ )
+ self.sim_strategy = sim_strategy or default_strategy
+ self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
+ self.sim_kwargs = sim_cfg_kwargs
 
  if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
  msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
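The block above boils down to a precedence rule for the simulation settings: an explicit constructor argument wins, then the config's `sim` section, then the hard-coded fallbacks (strategy "random", proportion 0.10). A minimal sketch of that rule for `sim_prop`; `resolve_sim_prop` is a hypothetical helper written only to illustrate the precedence and is not part of pg-sui.

    def resolve_sim_prop(sim_prop: float | None, cfg_sim_prop: float | None) -> float:
        if sim_prop is not None:      # explicit constructor argument wins
            return float(sim_prop)
        if cfg_sim_prop is not None:  # otherwise fall back to config.sim.sim_prop
            return float(cfg_sim_prop)
        return 0.10                   # otherwise the hard-coded default

    assert resolve_sim_prop(None, None) == 0.10
    assert resolve_sim_prop(None, 0.25) == 0.25
    assert resolve_sim_prop(0.05, 0.25) == 0.05

Note that `sim_kwargs` behaves slightly differently: the config's kwargs are deep-copied and then updated with any explicitly passed kwargs, so the two dictionaries are merged rather than replaced.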
@@ -319,10 +343,11 @@ class ImputeNLPCA(BaseNNImputer):
  self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
  else:
  self.num_classes_ = 3
-
- self.logger.info(
- "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
- )
+ # Model head uses two channels; scoring uses num_classes_
+ self.output_classes_ = 2
+ self.logger.info(
+ "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2) for scoring; 2 output channels with sigmoid for training."
+ )
 
  n_samples, self.num_features_ = X_for_model.shape
 
@@ -332,7 +357,7 @@ class ImputeNLPCA(BaseNNImputer):
  "dropout_rate": self.dropout_rate,
  "activation": self.activation,
  "gamma": self.gamma,
- "num_classes": self.num_classes_,
+ "num_classes": self.output_classes_,
  }
 
  # --- Train/Test Split ---
@@ -354,6 +379,11 @@ class ImputeNLPCA(BaseNNImputer):
  else:
  self.sim_mask_train_ = None
  self.sim_mask_test_ = None
+ # pos weights for multilabel diploid path
+ if not self.is_haploid:
+ self.pos_weights_ = self._compute_pos_weights(self.X_train_)
+ else:
+ self.pos_weights_ = None
 
  # Tuning, model setup, training (unchanged except DataLoader input)
  self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
@@ -519,17 +549,24 @@ class ImputeNLPCA(BaseNNImputer):
 
  # Forward
  z = latent_vectors[batch_indices].to(self.device)
- logits = decoder(z).view(len(batch_indices), nF, self.num_classes_)
+ logits = decoder(z).view(len(batch_indices), nF, self.output_classes_)
 
  # Guard upstream explosions
  if not torch.isfinite(logits).all():
  # Skip batch if model already produced non-finite values
  continue
 
- logits_flat = logits.view(-1, self.num_classes_)
- targets_flat = y_batch.view(-1)
-
- loss = criterion(logits_flat, targets_flat)
+ if self.is_haploid:
+ logits_flat = logits.view(-1, self.output_classes_)
+ targets_flat = y_batch.view(-1)
+ loss = criterion(logits_flat, targets_flat)
+ else:
+ targets = self._multi_hot_targets(y_batch)
+ bce = F.binary_cross_entropy_with_logits(
+ logits, targets, pos_weight=self.pos_weights_, reduction="none"
+ )
+ mask = (y_batch != -1).unsqueeze(-1).float()
+ loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
 
  # L1 on model weights only (exclude latents)
  if l1_penalty > 0:
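In the diploid branch above, per-locus cross-entropy is replaced by a class-balanced, missing-masked BCE over two channels (REF, ALT). A self-contained sketch of that arithmetic on toy tensors; the shapes, the -1 missing code, and the masking follow the diff, while the decoder is replaced by random logits and `pos_weight` by ones.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch, n_loci = 4, 6
    y_batch = torch.randint(-1, 3, (batch, n_loci))  # 0/1/2 genotypes, -1 = missing
    logits = torch.randn(batch, n_loci, 2)           # two channels: REF, ALT

    # Multi-hot targets: REF channel on for genotypes 0 and 1, ALT channel on for 1 and 2.
    valid = y_batch != -1
    targets = torch.zeros(batch, n_loci, 2)
    targets[..., 0] = (valid & (y_batch != 2)).float()
    targets[..., 1] = (valid & (y_batch != 0)).float()

    pos_weights = torch.ones(2)  # stand-in for self.pos_weights_
    bce = F.binary_cross_entropy_with_logits(
        logits, targets, pos_weight=pos_weights, reduction="none"
    )
    mask = valid.unsqueeze(-1).float()  # missing loci contribute nothing to the mean
    loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
    print(float(loss))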
@@ -605,10 +642,21 @@ class ImputeNLPCA(BaseNNImputer):
 
  with torch.no_grad():
  logits = model.phase23_decoder(latent_vectors.to(self.device)).view(
- len(latent_vectors), nF, self.num_classes_
+ len(latent_vectors), nF, self.output_classes_
  )
- probas = torch.softmax(logits, dim=-1)
- labels = torch.argmax(probas, dim=-1)
+ if self.is_haploid:
+ probas = torch.softmax(logits, dim=-1)
+ labels = torch.argmax(probas, dim=-1)
+ else:
+ probas2 = torch.sigmoid(logits)
+ p_ref = probas2[..., 0]
+ p_alt = probas2[..., 1]
+ p_het = p_ref * p_alt
+ p_ref_only = p_ref * (1 - p_alt)
+ p_alt_only = p_alt * (1 - p_ref)
+ probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
+ probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
+ labels = torch.argmax(probas, dim=-1)
 
  return labels.cpu().numpy(), probas.cpu().numpy()
 
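Because the diploid head is trained as two independent Bernoulli channels, the prediction path above recombines the two sigmoid probabilities into three genotype probabilities (REF-only, HET, ALT-only) and renormalizes. A worked numeric example with one locus and illustrative channel values p_ref = 0.9, p_alt = 0.6:

    import torch

    p_ref, p_alt = torch.tensor(0.9), torch.tensor(0.6)

    p_het = p_ref * p_alt             # 0.54
    p_ref_only = p_ref * (1 - p_alt)  # 0.36
    p_alt_only = p_alt * (1 - p_ref)  # 0.06

    probas = torch.stack([p_ref_only, p_het, p_alt_only])
    probas = probas / probas.sum().clamp_min(1e-8)
    print(probas)  # tensor([0.3750, 0.5625, 0.0625]) -> argmax = class 1 (HET)

The leftover mass where both channels are off, (1 - p_ref) * (1 - p_alt) = 0.04 here, has no genotype interpretation and is discarded by the renormalization.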
@@ -798,6 +846,44 @@ class ImputeNLPCA(BaseNNImputer):
  dataset, batch_size=self.batch_size, shuffle=True
  )
 
+ def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
+ """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
+ if self.is_haploid:
+ return self._one_hot_encode_012(y)
+ y = y.to(self.device)
+ shape = y.shape + (2,)
+ out = torch.zeros(shape, device=self.device, dtype=torch.float32)
+ valid = y != -1
+ ref_mask = valid & (y != 2)
+ alt_mask = valid & (y != 0)
+ out[ref_mask, 0] = 1.0
+ out[alt_mask, 1] = 1.0
+ return out
+
+ def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
+ """Targets aligned with _encode_multilabel_inputs for diploid training."""
+ if self.is_haploid:
+ raise RuntimeError("_multi_hot_targets called for haploid data.")
+ y = y.to(self.device)
+ out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
+ valid = y != -1
+ ref_mask = valid & (y != 2)
+ alt_mask = valid & (y != 0)
+ out[ref_mask, 0] = 1.0
+ out[alt_mask, 1] = 1.0
+ return out
+
+ def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
+ """Balance REF/ALT channels for multilabel BCE."""
+ ref_pos = np.count_nonzero((X == 0) | (X == 1))
+ alt_pos = np.count_nonzero((X == 2) | (X == 1))
+ total_valid = np.count_nonzero(X != -1)
+ pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
+ neg_counts = np.maximum(total_valid - pos_counts, 1.0)
+ pos_counts = np.maximum(pos_counts, 1.0)
+ weights = neg_counts / pos_counts
+ return torch.tensor(weights, device=self.device, dtype=torch.float32)
+
  def _create_latent_space(
  self,
  params: dict,
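The helpers above fix the diploid encoding contract (genotype 0 lights only the REF channel, 2 only the ALT channel, 1 both, -1 neither) and derive `pos_weight` as the per-channel negative-to-positive ratio. A quick standalone check of the `_compute_pos_weights` formula on a toy genotype matrix, not calling the class itself:

    import numpy as np

    X = np.array([[0, 1, 2, -1],
                  [0, 0, 1, 2]])
    ref_pos = np.count_nonzero((X == 0) | (X == 1))  # calls carrying REF: 5
    alt_pos = np.count_nonzero((X == 2) | (X == 1))  # calls carrying ALT: 4
    total_valid = np.count_nonzero(X != -1)          # 7 non-missing calls
    pos = np.maximum(np.array([ref_pos, alt_pos], dtype=np.float32), 1.0)
    neg = np.maximum(total_valid - pos, 1.0)
    print(neg / pos)  # [0.4 0.75] -> BCE pos_weight per channel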
@@ -894,71 +980,80 @@ class ImputeNLPCA(BaseNNImputer):
  Returns:
  float: The value of the tuning metric to be minimized or maximized.
  """
- self._prepare_tuning_artifacts()
- trial_params = self._sample_hyperparameters(trial)
- model_params = trial_params["model_params"]
-
- nfeat = self._tune_num_features
- if self.tune and self.tune_fast:
- model_params["n_features"] = nfeat
-
- lr = trial_params["lr"]
- l1_penalty = trial_params["l1_penalty"]
- lr_input_fac = trial_params["lr_input_factor"]
-
- X_train_trial = self._tune_X_train
- X_test_trial = self._tune_X_test
- class_weights = self._tune_class_weights
- train_loader = self._tune_loader
-
- train_latents = self._create_latent_space(
- model_params, len(X_train_trial), X_train_trial, trial_params["latent_init"]
- )
+ try:
+ self._prepare_tuning_artifacts()
+ trial_params = self._sample_hyperparameters(trial)
+ model_params = trial_params["model_params"]
+
+ nfeat = self._tune_num_features
+ if self.tune and self.tune_fast:
+ model_params["n_features"] = nfeat
+
+ lr = trial_params["lr"]
+ l1_penalty = trial_params["l1_penalty"]
+ lr_input_fac = trial_params["lr_input_factor"]
+
+ X_train_trial = self._tune_X_train
+ X_test_trial = self._tune_X_test
+ class_weights = self._tune_class_weights
+ train_loader = self._tune_loader
+
+ train_latents = self._create_latent_space(
+ model_params,
+ len(X_train_trial),
+ X_train_trial,
+ trial_params["latent_init"],
+ )
 
- model = self.build_model(self.Model, model_params)
- model.n_features = model_params["n_features"]
- model.apply(self.initialize_weights)
+ model = self.build_model(self.Model, model_params)
+ model.n_features = model_params["n_features"]
+ model.apply(self.initialize_weights)
 
- _, model, __ = self._train_and_validate_model(
- model=model,
- loader=train_loader,
- lr=lr,
- l1_penalty=l1_penalty,
- trial=trial,
- latent_vectors=train_latents,
- lr_input_factor=lr_input_fac,
- class_weights=class_weights,
- X_val=X_test_trial,
- params=model_params,
- prune_metric=self.tune_metric,
- prune_warmup_epochs=5,
- eval_interval=self.tune_eval_interval,
- eval_latent_steps=self.eval_latent_steps,
- eval_latent_lr=self.eval_latent_lr,
- eval_latent_weight_decay=self.eval_latent_weight_decay,
- )
+ _, model, __ = self._train_and_validate_model(
+ model=model,
+ loader=train_loader,
+ lr=lr,
+ l1_penalty=l1_penalty,
+ trial=trial,
+ latent_vectors=train_latents,
+ lr_input_factor=lr_input_fac,
+ class_weights=class_weights,
+ X_val=X_test_trial,
+ params=model_params,
+ prune_metric=self.tune_metric,
+ prune_warmup_epochs=10,
+ eval_interval=self.tune_eval_interval,
+ eval_latent_steps=self.eval_latent_steps,
+ eval_latent_lr=self.eval_latent_lr,
+ eval_latent_weight_decay=self.eval_latent_weight_decay,
+ )
 
- # --- simulate-only eval mask for tuning ---
- eval_mask = None
- if (
- self.simulate_missing
- and getattr(self, "sim_mask_global_", None) is not None
- ):
- if hasattr(self, "_tune_test_idx") and self.sim_mask_global_ is not None:
- eval_mask = self.sim_mask_global_[self._tune_test_idx]
- elif getattr(self, "sim_mask_test_", None) is not None:
- eval_mask = self.sim_mask_test_
-
- metrics = self._evaluate_model(
- X_test_trial,
- model,
- model_params,
- objective_mode=True,
- eval_mask_override=eval_mask,
- )
+ # --- simulate-only eval mask for tuning ---
+ eval_mask = None
+ if (
+ self.simulate_missing
+ and getattr(self, "sim_mask_global_", None) is not None
+ ):
+ if (
+ hasattr(self, "_tune_test_idx")
+ and self.sim_mask_global_ is not None
+ ):
+ eval_mask = self.sim_mask_global_[self._tune_test_idx]
+ elif getattr(self, "sim_mask_test_", None) is not None:
+ eval_mask = self.sim_mask_test_
+
+ metrics = self._evaluate_model(
+ X_test_trial,
+ model,
+ model_params,
+ objective_mode=True,
+ eval_mask_override=eval_mask,
+ )
 
- self._clear_resources(model, train_loader, latent_vectors=train_latents)
- return metrics[self.tune_metric]
+ self._clear_resources(model, train_loader, latent_vectors=train_latents)
+ return metrics[self.tune_metric]
+ except Exception as e:
+ raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
 
  def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
  """Samples hyperparameters for the simplified NLPCA model.
@@ -972,23 +1067,23 @@ class ImputeNLPCA(BaseNNImputer):
  Dict[str, int | float | str | list]: A dictionary of sampled hyperparameters.
  """
  params = {
- "latent_dim": trial.suggest_int("latent_dim", 2, 32),
- "lr": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
- "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.05),
- "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 16),
+ "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
+ "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30),
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 4),
  "activation": trial.suggest_categorical(
- "activation", ["relu", "elu", "selu", "leaky_relu"]
+ "activation", ["relu", "elu", "selu"]
  ),
- "gamma": trial.suggest_float("gamma", 0.1, 5.0, step=0.1),
+ "gamma": trial.suggest_float("gamma", 0.5, 3.0, step=0.5),
  "lr_input_factor": trial.suggest_float(
- "lr_input_factor", 0.1, 10.0, log=True
+ "lr_input_factor", 0.3, 3.0, log=True
  ),
  "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
  "layer_scaling_factor": trial.suggest_float(
- "layer_scaling_factor", 2.0, 10.0
+ "layer_scaling_factor", 2.0, 4.0, step=0.5
  ),
  "layer_schedule": trial.suggest_categorical(
- "layer_schedule", ["pyramid", "constant", "linear"]
+ "layer_schedule", ["pyramid", "linear"]
  ),
  "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
  }
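The tightened ranges above shrink the search space substantially. For reference, a throwaway snippet showing what the new `latent_dim`, `gamma`, and `learning_rate` suggestions can return under Optuna's step/log semantics:

    import optuna

    def probe(trial: optuna.Trial) -> float:
        trial.suggest_int("latent_dim", 4, 16, step=2)              # {4, 6, 8, 10, 12, 14, 16}
        trial.suggest_float("gamma", 0.5, 3.0, step=0.5)            # {0.5, 1.0, 1.5, 2.0, 2.5, 3.0}
        trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True)  # log-uniform on [3e-4, 1e-3]
        return 0.0

    study = optuna.create_study()
    study.optimize(probe, n_trials=3)
    print(study.trials[0].params)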
@@ -1006,7 +1101,7 @@ class ImputeNLPCA(BaseNNImputer):
 
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
  n_inputs=params["latent_dim"],
- n_outputs=use_n_features * self.num_classes_,
+ n_outputs=use_n_features * self.output_classes_,
  n_samples=use_n_samples,
  n_hidden=params["num_hidden_layers"],
  alpha=params["layer_scaling_factor"],
@@ -1015,7 +1110,7 @@ class ImputeNLPCA(BaseNNImputer):
 
  params["model_params"] = {
  "n_features": use_n_features,
- "num_classes": self.num_classes_,
+ "num_classes": self.output_classes_,
  "latent_dim": params["latent_dim"],
  "dropout_rate": params["dropout_rate"],
  "hidden_layer_sizes": hidden_layer_sizes,
@@ -1043,10 +1138,11 @@ class ImputeNLPCA(BaseNNImputer):
  self.lr_input_factor = best_params["lr_input_factor"]
  self.l1_penalty = best_params["l1_penalty"]
  self.activation = best_params["activation"]
+ self.latent_init = best_params["latent_init"]
 
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
  n_inputs=self.latent_dim,
- n_outputs=self.num_features_ * self.num_classes_,
+ n_outputs=self.num_features_ * self.output_classes_,
  n_samples=len(self.train_idx_),
  n_hidden=best_params["num_hidden_layers"],
  alpha=best_params["layer_scaling_factor"],
@@ -1060,7 +1156,7 @@ class ImputeNLPCA(BaseNNImputer):
  "dropout_rate": self.dropout_rate,
  "activation": self.activation,
  "gamma": self.gamma,
- "num_classes": self.num_classes_,
+ "num_classes": self.output_classes_,
  }
 
  def _set_best_params_default(self) -> Dict[str, int | float | str | list]:
@@ -1073,7 +1169,7 @@ class ImputeNLPCA(BaseNNImputer):
  """
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
  n_inputs=self.latent_dim,
- n_outputs=self.num_features_ * self.num_classes_,
+ n_outputs=self.num_features_ * self.output_classes_,
  n_samples=len(self.ground_truth_),
  n_hidden=self.num_hidden_layers,
  alpha=self.layer_scaling_factor,
@@ -1087,7 +1183,7 @@ class ImputeNLPCA(BaseNNImputer):
  "dropout_rate": self.dropout_rate,
  "activation": self.activation,
  "gamma": self.gamma,
- "num_classes": self.num_classes_,
+ "num_classes": self.output_classes_,
  }
 
  def _train_and_validate_model(
@@ -1105,7 +1201,7 @@ class ImputeNLPCA(BaseNNImputer):
  X_val: np.ndarray | None = None,
  params: dict | None = None,
  prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
- prune_warmup_epochs: int = 3,
+ prune_warmup_epochs: int = 10,
  eval_interval: int = 1,
  eval_latent_steps: int = 50,
  eval_latent_lr: float = 1e-2,
@@ -1222,7 +1318,7 @@ class ImputeNLPCA(BaseNNImputer):
  X_val=self.X_test_,
  params=best_params,
  prune_metric=self.tune_metric,
- prune_warmup_epochs=5,
+ prune_warmup_epochs=10,
  eval_interval=1,
  eval_latent_steps=self.eval_latent_steps,
  eval_latent_lr=self.eval_latent_lr,
@@ -1255,7 +1351,7 @@ class ImputeNLPCA(BaseNNImputer):
  X_val: np.ndarray | None = None,
  params: dict | None = None,
  prune_metric: str | None = None,
- prune_warmup_epochs: int = 3,
+ prune_warmup_epochs: int = 10,
  eval_interval: int = 1,
  eval_latent_steps: int = 50,
  eval_latent_lr: float = 1e-2,
@@ -1435,17 +1531,25 @@ class ImputeNLPCA(BaseNNImputer):
  self.logger.error(msg)
  raise TypeError(msg)
 
- logits = decoder(z).view(len(X_new), nF, self.num_classes_)
+ logits = decoder(z).view(len(X_new), nF, self.output_classes_)
 
  if not torch.isfinite(logits).all():
  break
 
- loss = F.cross_entropy(
- logits.view(-1, self.num_classes_),
- y.view(-1),
- ignore_index=-1,
- reduction="mean",
- )
+ if self.is_haploid:
+ loss = F.cross_entropy(
+ logits.view(-1, self.output_classes_),
+ y.view(-1),
+ ignore_index=-1,
+ reduction="mean",
+ )
+ else:
+ targets = self._multi_hot_targets(y)
+ bce = F.binary_cross_entropy_with_logits(
+ logits, targets, pos_weight=self.pos_weights_, reduction="none"
+ )
+ mask = (y != -1).unsqueeze(-1).float()
+ loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
  if not torch.isfinite(loss):
  break
 
@@ -1499,7 +1603,7 @@ class ImputeNLPCA(BaseNNImputer):
  y = torch.from_numpy(X_val).long().to(self.device)
 
  latent_dim = self._first_linear_in_features(model)
- cache_key = f"{self.prefix}_nlpca_val_latents_z{latent_dim}_L{self.num_features_}_K{self.num_classes_}"
+ cache_key = f"{self.prefix}_nlpca_val_latents_z{latent_dim}_L{self.num_features_}_K{self.output_classes_}"
 
  if cache is not None and cache_key in cache:
  z = cache[cache_key].detach().clone().requires_grad_(True)
@@ -1523,17 +1627,25 @@ class ImputeNLPCA(BaseNNImputer):
  self.logger.error(msg)
  raise TypeError(msg)
 
- logits = decoder(z).view(X_val.shape[0], nF, self.num_classes_)
+ logits = decoder(z).view(X_val.shape[0], nF, self.output_classes_)
 
  if not torch.isfinite(logits).all():
  break
 
- loss = F.cross_entropy(
- logits.view(-1, self.num_classes_),
- y.view(-1),
- ignore_index=-1,
- reduction="mean",
- )
+ if self.is_haploid:
+ loss = F.cross_entropy(
+ logits.view(-1, self.output_classes_),
+ y.view(-1),
+ ignore_index=-1,
+ reduction="mean",
+ )
+ else:
+ targets = self._multi_hot_targets(y)
+ bce = F.binary_cross_entropy_with_logits(
+ logits, targets, pos_weight=self.pos_weights_, reduction="none"
+ )
+ mask = (y != -1).unsqueeze(-1).float()
+ loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
 
  if not torch.isfinite(loss):
  break