pg-sui 1.6.14.dev9-py3-none-any.whl → 1.6.16a3-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -5,6 +5,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import optuna
 import torch
+import torch.nn.functional as F
 from sklearn.exceptions import NotFittedError
 from sklearn.model_selection import train_test_split
 from snpio.analysis.genotype_encoder import GenotypeEncoder
@@ -162,6 +163,7 @@ class ImputeAutoencoder(BaseNNImputer):
         self.verbose = self.cfg.io.verbose
         self.debug = self.cfg.io.debug
         self.rng = np.random.default_rng(self.seed)
+        self.pos_weights_: torch.Tensor | None = None

         # Simulated-missing controls (config defaults with ctor overrides)
         sim_cfg = getattr(self.cfg, "sim", None)
@@ -330,10 +332,12 @@ class ImputeAutoencoder(BaseNNImputer):
             )
         )
         self.ploidy = 1 if self.is_haploid else 2
+        # Scoring still uses 3 labels for diploid (REF/HET/ALT); model head uses 2 logits
         self.num_classes_ = 2 if self.is_haploid else 3
+        self.output_classes_ = 2
         self.logger.info(
             f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
-            f"using {self.num_classes_} classes."
+            f"using {self.num_classes_} classes for scoring and {self.output_classes_} output channels."
         )

         if self.is_haploid:
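For diploid data the model head now emits two logits per locus (REF presence, ALT presence) while scoring keeps the three 0/1/2 labels. The mapping between the two is a multi-hot scheme in which a heterozygote activates both channels; a minimal sketch mirroring the `_encode_multilabel_inputs` helper added further down in this diff:

```python
import torch

# Genotype codes: 0 = REF/REF, 1 = HET, 2 = ALT/ALT, -1 = missing.
# Channel 0 flags "carries REF"; channel 1 flags "carries ALT".
y = torch.tensor([[0, 1, 2, -1]])
out = torch.zeros(y.shape + (2,))
valid = y != -1
out[valid & (y != 2), 0] = 1.0  # REF present for genotypes 0 and 1
out[valid & (y != 0), 1] = 1.0  # ALT present for genotypes 1 and 2
print(out.squeeze(0))
# tensor([[1., 0.],   # 0 -> REF only
#         [1., 1.],   # 1 -> HET sets both channels
#         [0., 1.],   # 2 -> ALT only
#         [0., 0.]])  # -1 -> all-zero (masked out of the loss)
```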
@@ -345,7 +349,7 @@
         # Model params (decoder outputs L * K logits)
         self.model_params = {
             "n_features": self.num_features_,
-            "num_classes": self.num_classes_,
+            "num_classes": self.output_classes_,
             "latent_dim": self.latent_dim,
             "dropout_rate": self.dropout_rate,
             "activation": self.activation,
@@ -369,6 +373,12 @@ class ImputeAutoencoder(BaseNNImputer):
         self.sim_mask_train_ = None
         self.sim_mask_test_ = None

+        # Pos weights for diploid multilabel path (must exist before tuning)
+        if not self.is_haploid:
+            self.pos_weights_ = self._compute_pos_weights(self.X_train_)
+        else:
+            self.pos_weights_ = None
+
         # Plotters/scorers (shared utilities)
         self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()

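The pos weights computed here (see `_compute_pos_weights` further down in this diff) are per-channel negative-to-positive ratios over the non-missing entries, the usual `pos_weight` convention for `binary_cross_entropy_with_logits`. A worked toy example:

```python
import numpy as np

# Toy 0/1/2 genotype matrix with one missing (-1) entry.
X = np.array([[0, 1, 2],
              [1, -1, 2]])
ref_pos = np.count_nonzero((X == 0) | (X == 1))  # 3: genotypes carrying REF
alt_pos = np.count_nonzero((X == 2) | (X == 1))  # 4: genotypes carrying ALT
total_valid = np.count_nonzero(X != -1)          # 5 observed genotypes
pos = np.maximum(np.array([ref_pos, alt_pos], dtype=np.float32), 1.0)
neg = np.maximum(total_valid - pos, 1.0)
print(neg / pos)  # [0.6667 0.25] -> the rarer channel gets the larger weight
```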
@@ -401,7 +411,7 @@ class ImputeAutoencoder(BaseNNImputer):
             X_val=self.X_val_,
             params=self.best_params_,
             prune_metric=self.tune_metric,
-            prune_warmup_epochs=5,
+            prune_warmup_epochs=10,
             eval_interval=1,
             eval_requires_latents=False,
             eval_latent_steps=0,
@@ -509,7 +519,7 @@ class ImputeAutoencoder(BaseNNImputer):
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str = "f1",  # "f1" | "accuracy" | "pr_macro"
-       prune_warmup_epochs: int = 3,
+       prune_warmup_epochs: int = 10,
        eval_interval: int = 1,
        # Evaluation parameters (AE ignores latent refinement knobs)
        eval_requires_latents: bool = False,  # AE: always False
@@ -593,7 +603,7 @@ class ImputeAutoencoder(BaseNNImputer):
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str = "f1",
-       prune_warmup_epochs: int = 3,
+       prune_warmup_epochs: int = 10,
        eval_interval: int = 1,
        # Evaluation parameters (AE ignores latent refinement knobs)
        eval_requires_latents: bool = False,  # AE: False
@@ -761,24 +771,35 @@ class ImputeAutoencoder(BaseNNImputer):
         # Use model.gamma if present, else self.gamma
         gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
         gamma = float(torch.tensor(gamma).clamp(min=0.0, max=10.0))  # sane bound
-        criterion = SafeFocalCELoss(gamma=gamma, weight=class_weights, ignore_index=-1)
+        ce_criterion = SafeFocalCELoss(
+            gamma=gamma, weight=class_weights, ignore_index=-1
+        )

         for _, y_batch in loader:
             optimizer.zero_grad(set_to_none=True)
             y_batch = y_batch.to(self.device, non_blocking=True)

             # Inputs: one-hot with zeros for missing; Targets: long ints with -1 for missing
-            x_ohe = self._one_hot_encode_012(y_batch)  # (B, L, K)
-            logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
-            logits_flat = logits.view(-1, self.num_classes_)
-            targets_flat = y_batch.view(-1).long()
-
-            # Upfront guards on inputs
-            if not torch.isfinite(logits_flat).all():
-                # Skip this batch if model already produced non-finite
-                continue
-
-            loss = criterion(logits_flat, targets_flat)
+            if self.is_haploid:
+                x_in = self._one_hot_encode_012(y_batch)  # (B, L, 2)
+                logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
+                logits_flat = logits.view(-1, self.output_classes_)
+                targets_flat = y_batch.view(-1).long()
+                if not torch.isfinite(logits_flat).all():
+                    continue
+                loss = ce_criterion(logits_flat, targets_flat)
+            else:
+                x_in = self._encode_multilabel_inputs(y_batch)  # (B, L, 2)
+                logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
+                if not torch.isfinite(logits).all():
+                    continue
+                pos_w = getattr(self, "pos_weights_", None)
+                targets = self._multi_hot_targets(y_batch)  # float, same shape
+                bce = F.binary_cross_entropy_with_logits(
+                    logits, targets, pos_weight=pos_w, reduction="none"
+                )
+                mask = (y_batch != -1).unsqueeze(-1).float()
+                loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)

             if l1_penalty > 0:
                 l1 = torch.zeros((), device=self.device)
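The diploid branch above replaces focal CE with a masked multilabel BCE: missing genotypes (-1) contribute nothing because the element-wise losses are zeroed before averaging. A self-contained sketch of that masking logic, with hypothetical toy tensors standing in for the model's logits:

```python
import torch
import torch.nn.functional as F

B, L = 2, 3  # toy batch of 2 samples x 3 loci
y = torch.tensor([[0, 1, -1], [2, 2, 1]])  # 0/1/2 genotypes, -1 = missing
logits = torch.randn(B, L, 2)              # (B, L, 2) REF/ALT logits
targets = torch.zeros(B, L, 2)
valid = y != -1
targets[valid & (y != 2), 0] = 1.0         # REF channel
targets[valid & (y != 0), 1] = 1.0         # ALT channel
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
mask = valid.unsqueeze(-1).float()         # broadcasts over both channels
loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
print(loss)  # two-channel BCE summed per locus, averaged over the 5 observed loci
```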
@@ -842,16 +863,69 @@ class ImputeAutoencoder(BaseNNImputer):
         with torch.no_grad():
             X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
             X_tensor = X_tensor.to(self.device).long()
-            x_ohe = self._one_hot_encode_012(X_tensor)
-            logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
-            probas = torch.softmax(logits, dim=-1)
-            labels = torch.argmax(probas, dim=-1)
+            if self.is_haploid:
+                x_ohe = self._one_hot_encode_012(X_tensor)
+                logits = model(x_ohe).view(-1, self.num_features_, self.output_classes_)
+                probas = torch.softmax(logits, dim=-1)
+                labels = torch.argmax(probas, dim=-1)
+            else:
+                x_in = self._encode_multilabel_inputs(X_tensor)
+                logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
+                probas_2 = torch.sigmoid(logits)
+                p_ref = probas_2[..., 0]
+                p_alt = probas_2[..., 1]
+                p_het = p_ref * p_alt
+                p_ref_only = p_ref * (1 - p_alt)
+                p_alt_only = p_alt * (1 - p_ref)
+                stacked = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
+                stacked = stacked / stacked.sum(dim=-1, keepdim=True).clamp_min(1e-8)
+                probas = stacked
+                labels = torch.argmax(stacked, dim=-1)

         if return_proba:
             return labels.cpu().numpy(), probas.cpu().numpy()

         return labels.cpu().numpy()

+    def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
+        """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
+        if self.is_haploid:
+            return self._one_hot_encode_012(y)
+        y = y.to(self.device)
+        shape = y.shape + (2,)
+        out = torch.zeros(shape, device=self.device, dtype=torch.float32)
+        valid = y != -1
+        ref_mask = valid & (y != 2)
+        alt_mask = valid & (y != 0)
+        out[ref_mask, 0] = 1.0
+        out[alt_mask, 1] = 1.0
+        return out
+
+    def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
+        """Targets aligned with _encode_multilabel_inputs for diploid training."""
+        if self.is_haploid:
+            # One-hot CE path expects integer targets; handled upstream.
+            raise RuntimeError("_multi_hot_targets called for haploid data.")
+        y = y.to(self.device)
+        out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
+        valid = y != -1
+        ref_mask = valid & (y != 2)
+        alt_mask = valid & (y != 0)
+        out[ref_mask, 0] = 1.0
+        out[alt_mask, 1] = 1.0
+        return out
+
+    def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
+        """Balance REF/ALT channels for multilabel BCE."""
+        ref_pos = np.count_nonzero((X == 0) | (X == 1))
+        alt_pos = np.count_nonzero((X == 2) | (X == 1))
+        total_valid = np.count_nonzero(X != -1)
+        pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
+        neg_counts = np.maximum(total_valid - pos_counts, 1.0)
+        pos_counts = np.maximum(pos_counts, 1.0)
+        weights = neg_counts / pos_counts
+        return torch.tensor(weights, device=self.device, dtype=torch.float32)
+
     def _evaluate_model(
         self,
         X_val: np.ndarray,
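At inference, the two independent sigmoid channels are composed back into the three genotype probabilities the scorer expects by treating REF and ALT presence as independent events. The "neither allele" mass, (1 - p_ref) * (1 - p_alt), corresponds to no valid genotype, which is why the three products are renormalized. A quick numeric check of that composition:

```python
import torch

p_ref, p_alt = torch.tensor(0.9), torch.tensor(0.6)  # channel sigmoids
raw = torch.stack([
    p_ref * (1 - p_alt),  # REF-only -> genotype 0: 0.36
    p_ref * p_alt,        # both     -> genotype 1: 0.54
    p_alt * (1 - p_ref),  # ALT-only -> genotype 2: 0.06
])
probas = raw / raw.sum().clamp_min(1e-8)  # drop the "neither" mass (0.04)
print(probas)  # tensor([0.3750, 0.5625, 0.0625]); argmax -> HET
```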
@@ -1090,7 +1164,7 @@ class ImputeAutoencoder(BaseNNImputer):
             X_val=X_val,
             params=params,
             prune_metric=self.tune_metric,
-            prune_warmup_epochs=5,
+            prune_warmup_epochs=10,
             eval_interval=self.tune_eval_interval,
             eval_requires_latents=False,
             eval_latent_steps=0,
@@ -1137,24 +1211,24 @@ class ImputeAutoencoder(BaseNNImputer):
            Dict[str, int | float | str | bool]: Sampled hyperparameters and model_params.
        """
        params = {
-           "latent_dim": trial.suggest_int("latent_dim", 2, 64),
-           "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
-           "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
-           "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
+           "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
+           "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
+           "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
+           "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
            "activation": trial.suggest_categorical(
-               "activation", ["relu", "elu", "selu"]
+               "activation", ["relu", "elu", "selu", "leaky_relu"]
            ),
-           "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
+           "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
            "layer_scaling_factor": trial.suggest_float(
-               "layer_scaling_factor", 2.0, 10.0
+               "layer_scaling_factor", 2.0, 4.0, step=0.5
            ),
            "layer_schedule": trial.suggest_categorical(
-               "layer_schedule", ["pyramid", "constant", "linear"]
+               "layer_schedule", ["pyramid", "linear"]
            ),
        }

        nF: int = self.num_features_
-       nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
+       nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
        input_dim = nF * nC
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=input_dim,
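The narrowed ranges above trade exploration for stability (e.g., learning rate confined to 3e-4 through 1e-3, latent_dim to 4-16). A minimal, hypothetical sketch of how such a sampler plugs into an Optuna study; the `objective` function, the dummy score, and the choice of pruner are assumptions for illustration only:

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    # Same narrowed ranges as the tuned search space above.
    latent_dim = trial.suggest_int("latent_dim", 4, 16, step=2)
    lr = trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True)
    dropout = trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05)
    # ... build and fit the model here; return the validation metric.
    return lr * latent_dim * (1.0 - dropout)  # dummy score for illustration

study = optuna.create_study(
    direction="maximize",
    # Hypothetical pruner; its warmup mirrors the new prune_warmup_epochs=10.
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
)
study.optimize(objective, n_trials=20)
print(study.best_params)
```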
@@ -1173,8 +1247,8 @@ class ImputeAutoencoder(BaseNNImputer):

        params["model_params"] = {
            "n_features": int(self.num_features_),
-           "num_classes": (
-               int(self.num_classes_) if self.num_classes_ is not None else 3
+           "num_classes": int(
+               getattr(self, "output_classes_", self.num_classes_ or 3)
            ),
            "latent_dim": int(params["latent_dim"]),
            "dropout_rate": float(params["dropout_rate"]),
@@ -1227,7 +1301,7 @@ class ImputeAutoencoder(BaseNNImputer):
        self.layer_schedule: str = bp["layer_schedule"]

        nF: int = self.num_features_
-       nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
+       nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=nF * nC,
            n_outputs=nF * nC,
@@ -1261,7 +1335,9 @@ class ImputeAutoencoder(BaseNNImputer):
            Dict[str, int | float | str | list]: Default model parameters.
        """
        nF: int = self.num_features_
-       nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
+       # Use the number of output channels passed to the model (2 for diploid multilabel)
+       # instead of the scoring classes (3) to keep layer shapes aligned.
+       nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
        ls = self.layer_schedule

        if ls not in {"pyramid", "constant", "linear"}: