pg-sui 1.6.14.dev9__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
@@ -1,6 +1,7 @@
  import copy
  from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple

+ from fastapi import params
  import numpy as np
  import optuna
  import torch
@@ -160,6 +161,7 @@ class ImputeUBP(BaseNNImputer):
  self.verbose = self.cfg.io.verbose
  self.debug = self.cfg.io.debug
  self.rng = np.random.default_rng(self.seed)
+ self.pos_weights_: torch.Tensor | None = None

  # Simulated-missing controls (config defaults w/ overrides)
  sim_cfg = getattr(self.cfg, "sim", None)
@@ -320,15 +322,17 @@ class ImputeUBP(BaseNNImputer):
  else:
  self.num_classes_ = 3
  self.logger.info(
- "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
+ "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2) for scoring."
  )
+ # Model head always uses two channels; scoring uses num_classes_
+ self.output_classes_ = 2

  n_samples, self.num_features_ = X_for_model.shape

  # --- model params (decoder: Z -> L * num_classes) ---
  self.model_params = {
  "n_features": self.num_features_,
- "num_classes": self.num_classes_,
+ "num_classes": self.output_classes_,
  "latent_dim": self.latent_dim,
  "dropout_rate": self.dropout_rate,
  "activation": self.activation,
@@ -353,6 +357,12 @@ class ImputeUBP(BaseNNImputer):
  self.sim_mask_train_ = None
  self.sim_mask_test_ = None

+ # pos weights for diploid multilabel path
+ if not self.is_haploid:
+ self.pos_weights_ = self._compute_pos_weights(self.X_train_)
+ else:
+ self.pos_weights_ = None
+
  # --- plotting/scorers & tuning ---
  self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
  if self.tune:
@@ -485,14 +495,22 @@ class ImputeUBP(BaseNNImputer):
  y = y_batch.to(self.device, non_blocking=True).long()

  logits = decoder(z).view(
- len(batch_indices), self.num_features_, self.num_classes_
+ len(batch_indices), self.num_features_, self.output_classes_
  )

  # Guard upstream explosions
  if not torch.isfinite(logits).all():
  continue

- loss = criterion(logits.view(-1, self.num_classes_), y.view(-1))
+ if self.is_haploid:
+ loss = criterion(logits.view(-1, self.output_classes_), y.view(-1))
+ else:
+ targets = self._multi_hot_targets(y)
+ bce = F.binary_cross_entropy_with_logits(
+ logits, targets, pos_weight=self.pos_weights_, reduction="none"
+ )
+ mask = (y != -1).unsqueeze(-1).float()
+ loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)

  if l1_penalty > 0:
  l1 = torch.zeros((), device=self.device)
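In the diploid branch above, the categorical criterion is replaced by a per-channel BCE that is masked so missing genotypes (coded -1) contribute nothing to the average. A self-contained sketch of that masking pattern on toy tensors; the shapes, the -1 missing code, and the pos_weight values below are illustrative assumptions, not values from the package:

import torch
import torch.nn.functional as F

batch, n_loci = 4, 6
logits = torch.randn(batch, n_loci, 2)        # REF/ALT presence logits
y = torch.randint(-1, 3, (batch, n_loci))     # genotypes in {-1, 0, 1, 2}

# Multi-hot targets: REF channel on for genotypes 0/1, ALT channel on for 1/2.
targets = torch.zeros(batch, n_loci, 2)
targets[..., 0] = ((y == 0) | (y == 1)).float()
targets[..., 1] = ((y == 1) | (y == 2)).float()

pos_weight = torch.tensor([1.0, 2.5])         # illustrative per-channel weights
bce = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=pos_weight, reduction="none"
)

# Zero out missing loci, then average over observed entries only.
mask = (y != -1).unsqueeze(-1).float()
loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
print(float(loss))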
@@ -555,10 +573,21 @@ class ImputeUBP(BaseNNImputer):
  raise TypeError(msg)

  logits = decoder(latent_vectors.to(self.device)).view(
- len(latent_vectors), nF, self.num_classes_
+ len(latent_vectors), nF, self.output_classes_
  )
- probas = torch.softmax(logits, dim=-1)
- labels = torch.argmax(probas, dim=-1)
+ if self.is_haploid:
+ probas = torch.softmax(logits, dim=-1)
+ labels = torch.argmax(probas, dim=-1)
+ else:
+ probas2 = torch.sigmoid(logits)
+ p_ref = probas2[..., 0]
+ p_alt = probas2[..., 1]
+ p_het = p_ref * p_alt
+ p_ref_only = p_ref * (1 - p_alt)
+ p_alt_only = p_alt * (1 - p_ref)
+ probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
+ probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
+ labels = torch.argmax(probas, dim=-1)

  return labels.cpu().numpy(), probas.cpu().numpy()
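In the prediction path, the diploid branch composes three genotype probabilities from two independent sigmoid channels, treating HET as "both alleles present" and discarding the "neither allele" outcome before renormalizing. A small sketch of just that composition on random logits; the tensor names are local to the example:

import torch

logits = torch.randn(2, 3, 2)           # (samples, loci, [REF, ALT]) presence logits
p = torch.sigmoid(logits)
p_ref, p_alt = p[..., 0], p[..., 1]

p_het = p_ref * p_alt                   # both alleles present
p_ref_only = p_ref * (1 - p_alt)        # homozygous REF
p_alt_only = p_alt * (1 - p_ref)        # homozygous ALT

probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
# Renormalize because the (1 - p_ref) * (1 - p_alt) "no allele" mass is dropped.
probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
labels = probas.argmax(dim=-1)          # 0=REF, 1=HET, 2=ALT
print(probas.sum(dim=-1))               # ~1.0 everywhere after renormalization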
@@ -752,9 +781,18 @@ class ImputeUBP(BaseNNImputer):
  )

  def _objective(self, trial: optuna.Trial) -> float:
- """Optuna objective using the UBP training loop."""
+ """Optuna objective using the UBP training loop.
+
+ This method defines the objective function for hyperparameter tuning using Optuna. It prepares the necessary artifacts for tuning, samples a set of hyperparameters for the current trial, and trains the UBP model using these hyperparameters. The model is evaluated on a validation set, and the specified tuning metric is returned as the objective value. If any exception occurs during the process, the trial is pruned.
+ """
  try:
- params = self._sample_hyperparameters(trial)
+ self._prepare_tuning_artifacts()
+ trial_params = self._sample_hyperparameters(trial)
+ model_params = trial_params["model_params"]
+
+ nfeat = self._tune_num_features
+ if self.tune and self.tune_fast:
+ model_params["n_features"] = nfeat

  X_train_trial = getattr(
  self, "X_train_", self.ground_truth_[self.train_idx_]
@@ -764,31 +802,38 @@ class ImputeUBP(BaseNNImputer):
  class_weights = self._normalize_class_weights(
  self._class_weights_from_zygosity(X_train_trial)
  )
+ if not self.is_haploid:
+ self.pos_weights_ = self._compute_pos_weights(X_train_trial)
+ else:
+ self.pos_weights_ = None
  train_loader = self._get_data_loaders(X_train_trial)

  train_latent_vectors = self._create_latent_space(
- params, len(X_train_trial), X_train_trial, params["latent_init"]
+ model_params,
+ len(X_train_trial),
+ X_train_trial,
+ trial_params["latent_init"],
  )

- model = self.build_model(self.Model, params["model_params"])
- model.n_features = params["model_params"]["n_features"]
+ model = self.build_model(self.Model, model_params)
+ model.n_features = model_params["n_features"]
  model.apply(self.initialize_weights)

  _, model, __ = self._train_and_validate_model(
  model=model,
  loader=train_loader,
- lr=params["lr"],
- l1_penalty=params["l1_penalty"],
+ lr=trial_params["lr"],
+ l1_penalty=trial_params["l1_penalty"],
  trial=trial,
  return_history=False,
  latent_vectors=train_latent_vectors,
- lr_input_factor=params["lr_input_factor"],
+ lr_input_factor=trial_params["lr_input_factor"],
  class_weights=class_weights,
  X_val=X_test_trial,
- params=params,
+ params=model_params,
  prune_metric=self.tune_metric,
- prune_warmup_epochs=5,
- eval_interval=1,
+ prune_warmup_epochs=10,
+ eval_interval=self.tune_eval_interval,
  eval_requires_latents=True,
  eval_latent_steps=self.eval_latent_steps,
  eval_latent_lr=self.eval_latent_lr,
@@ -806,7 +851,7 @@ class ImputeUBP(BaseNNImputer):
  metrics = self._evaluate_model(
  X_test_trial,
  model,
- params,
+ model_params,
  objective_mode=True,
  eval_mask_override=eval_mask,
  )
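The refactored _objective now separates trial-level settings (trial_params) from the decoder constructor arguments (model_params) and returns a single validation metric to Optuna, pruning on failure. A hedged sketch of how an objective of that shape is typically driven; the study setup and dummy metric below are illustrative and are not the package's own tuning entry point:

import optuna

def objective(trial: optuna.Trial) -> float:
    # Stand-in for ImputeUBP._objective: sample, train, evaluate, return the metric.
    lr = trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True)
    latent_dim = trial.suggest_int("latent_dim", 4, 16, step=2)
    try:
        score = 1.0 / (1.0 + abs(latent_dim - 8)) - lr  # dummy "validation metric"
    except Exception:
        # Mirrors the docstring: any failure inside a trial prunes it.
        raise optuna.TrialPruned()
    return score

study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),  # illustrative pruner
)
study.optimize(objective, n_trials=20)
print(study.best_params)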
@@ -829,30 +874,30 @@ class ImputeUBP(BaseNNImputer):
  Dict[str, int | float | str | list]: Sampled hyperparameters.
  """
  params = {
- "latent_dim": trial.suggest_int("latent_dim", 2, 32),
- "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
- "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
- "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
+ "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
+ "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
  "activation": trial.suggest_categorical(
- "activation", ["relu", "elu", "selu"]
+ "activation", ["relu", "elu", "selu", "leaky_relu"]
  ),
- "gamma": trial.suggest_float("gamma", 0.0, 5.0),
+ "gamma": trial.suggest_float("gamma", 0.5, 3.0, step=0.5),
  "lr_input_factor": trial.suggest_float(
- "lr_input_factor", 0.1, 10.0, log=True
+ "lr_input_factor", 0.3, 3.0, log=True
  ),
- "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
  "layer_scaling_factor": trial.suggest_float(
- "layer_scaling_factor", 2.0, 10.0
+ "layer_scaling_factor", 2.0, 4.0, step=0.5
  ),
  "layer_schedule": trial.suggest_categorical(
- "layer_schedule", ["pyramid", "constant", "linear"]
+ "layer_schedule", ["pyramid", "linear"]
  ),
  "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
  }

  hidden_layer_sizes = self._compute_hidden_layer_sizes(
  n_inputs=params["latent_dim"],
- n_outputs=self.num_features_ * self.num_classes_,
+ n_outputs=self.num_features_ * self.output_classes_,
  n_samples=len(self.train_idx_),
  n_hidden=params["num_hidden_layers"],
  alpha=params["layer_scaling_factor"],
@@ -866,7 +911,7 @@ class ImputeUBP(BaseNNImputer):

  params["model_params"] = {
  "n_features": self.num_features_,
- "num_classes": self.num_classes_,
+ "num_classes": self.output_classes_,
  "latent_dim": params["latent_dim"],
  "dropout_rate": params["dropout_rate"],
  "hidden_layer_sizes": hidden_only,
@@ -900,7 +945,7 @@ class ImputeUBP(BaseNNImputer):

  hidden_layer_sizes = self._compute_hidden_layer_sizes(
  n_inputs=self.latent_dim,
- n_outputs=self.num_features_ * self.num_classes_,
+ n_outputs=self.num_features_ * self.output_classes_,
  n_samples=len(self.train_idx_),
  n_hidden=best_params["num_hidden_layers"],
  alpha=best_params["layer_scaling_factor"],
@@ -916,7 +961,7 @@ class ImputeUBP(BaseNNImputer):
  "dropout_rate": self.dropout_rate,
  "activation": self.activation,
  "gamma": self.gamma,
- "num_classes": self.num_classes_,
+ "num_classes": self.output_classes_,
  }

  def _set_best_params_default(self) -> dict:
@@ -929,7 +974,7 @@ class ImputeUBP(BaseNNImputer):
  """
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
  n_inputs=self.latent_dim,
- n_outputs=self.num_features_ * self.num_classes_,
+ n_outputs=self.num_features_ * self.output_classes_,
  n_samples=len(self.ground_truth_),
  n_hidden=self.num_hidden_layers,
  alpha=self.layer_scaling_factor,
@@ -945,7 +990,7 @@ class ImputeUBP(BaseNNImputer):
  "dropout_rate": self.dropout_rate,
  "activation": self.activation,
  "gamma": self.gamma,
- "num_classes": self.num_classes_,
+ "num_classes": self.output_classes_,
  }

  def _train_and_validate_model(
@@ -963,7 +1008,7 @@ class ImputeUBP(BaseNNImputer):
  X_val: np.ndarray | None = None,
  params: dict | None = None,
  prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
- prune_warmup_epochs: int = 3,
+ prune_warmup_epochs: int = 10,
  eval_interval: int = 1,
  eval_requires_latents: bool = True, # UBP needs latent eval
  eval_latent_steps: int = 50,
@@ -1073,7 +1118,7 @@ class ImputeUBP(BaseNNImputer):
  X_val=self.X_test_,
  params=best_params,
  prune_metric=self.tune_metric,
- prune_warmup_epochs=5,
+ prune_warmup_epochs=10,
  eval_interval=1,
  eval_requires_latents=True,
  eval_latent_steps=self.eval_latent_steps,
@@ -1105,7 +1150,7 @@ class ImputeUBP(BaseNNImputer):
  X_val: np.ndarray | None = None,
  params: dict | None = None,
  prune_metric: str | None = None,
- prune_warmup_epochs: int = 3,
+ prune_warmup_epochs: int = 10,
  eval_interval: int = 1,
  eval_requires_latents: bool = True,
  eval_latent_steps: int = 50,
@@ -1152,7 +1197,7 @@ class ImputeUBP(BaseNNImputer):
  # Schema-aware latent cache for eval
  _latent_cache: dict = {}
  nF = getattr(model, "n_features", self.num_features_)
- cache_key_root = f"{self.prefix}_ubp_val_latents_L{nF}_K{self.num_classes_}"
+ cache_key_root = f"{self.prefix}_ubp_val_latents_L{nF}_K{self.output_classes_}"

  E = int(self.epochs)
  phase_epochs = {
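The validation-latent cache key now records the decoder head width (K{self.output_classes_}) alongside the feature count, so latents optimized against the old 3-class head are not reused with the new 2-channel head. A tiny sketch of that schema-keyed caching idea; the key format copies the diff, while the cache dict, prefix, and shapes are made up for illustration:

import torch

cache: dict[str, torch.Tensor] = {}
prefix, n_features = "run1", 5      # illustrative values

def latents_for(output_classes: int, zdim: int = 4, n_val: int = 3) -> torch.Tensor:
    key = f"{prefix}_ubp_val_latents_z{zdim}_L{n_features}_K{output_classes}"
    if key not in cache:
        # Cache miss: a schema change (e.g. K3 -> K2) forces fresh latents.
        cache[key] = torch.randn(n_val, zdim)
    return cache[key]

a = latents_for(output_classes=3)   # old 3-class schema
b = latents_for(output_classes=2)   # new 2-channel schema -> separate entry
print(len(cache))                   # 2: the two schemas never collide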
@@ -1349,14 +1394,22 @@ class ImputeUBP(BaseNNImputer):
  raise TypeError(msg)

  opt.zero_grad(set_to_none=True)
- logits = decoder(z).view(len(X_new), nF, self.num_classes_)
+ logits = decoder(z).view(len(X_new), nF, self.output_classes_)

  if not torch.isfinite(logits).all():
  break

- loss = F.cross_entropy(
- logits.view(-1, self.num_classes_), y.view(-1), ignore_index=-1
- )
+ if self.is_haploid:
+ loss = F.cross_entropy(
+ logits.view(-1, self.output_classes_), y.view(-1), ignore_index=-1
+ )
+ else:
+ targets = self._multi_hot_targets(y)
+ bce = F.binary_cross_entropy_with_logits(
+ logits, targets, pos_weight=self.pos_weights_, reduction="none"
+ )
+ mask = (y != -1).unsqueeze(-1).float()
+ loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)

  if not torch.isfinite(loss):
  break
@@ -1457,8 +1510,31 @@ class ImputeUBP(BaseNNImputer):
  else:
  latents = torch.empty(n_samples, latent_dim, device=self.device)
  torch.nn.init.xavier_uniform_(latents)
+ return torch.nn.Parameter(latents, requires_grad=True)

- return torch.nn.Parameter(latents, requires_grad=True)
+ def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
+ """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
+ if self.is_haploid:
+ raise RuntimeError("_multi_hot_targets called for haploid data.")
+ y = y.to(self.device)
+ out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
+ valid = y != -1
+ ref_mask = valid & (y != 2)
+ alt_mask = valid & (y != 0)
+ out[ref_mask, 0] = 1.0
+ out[alt_mask, 1] = 1.0
+ return out
+
+ def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
+ """Balance REF/ALT channels for multilabel BCE."""
+ ref_pos = np.count_nonzero((X == 0) | (X == 1))
+ alt_pos = np.count_nonzero((X == 2) | (X == 1))
+ total_valid = np.count_nonzero(X != -1)
+ pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
+ neg_counts = np.maximum(total_valid - pos_counts, 1.0)
+ pos_counts = np.maximum(pos_counts, 1.0)
+ weights = neg_counts / pos_counts
+ return torch.tensor(weights, device=self.device, dtype=torch.float32)
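The two new helpers encode diploid genotypes as REF/ALT presence channels and derive per-channel pos_weight values from negative-to-positive count ratios. A worked example of the same arithmetic on a toy genotype matrix; the 0/1/2/-1 coding follows the diff, and the numbers are made up:

import numpy as np

# 0=REF, 1=HET, 2=ALT, -1=missing
X = np.array([[0, 1, 2, -1],
              [0, 0, 1,  2]])

# Presence counts per channel: REF is present for genotypes 0/1, ALT for 1/2.
ref_pos = np.count_nonzero((X == 0) | (X == 1))   # 5
alt_pos = np.count_nonzero((X == 2) | (X == 1))   # 4
total_valid = np.count_nonzero(X != -1)           # 7

pos_counts = np.maximum(np.array([ref_pos, alt_pos], dtype=np.float32), 1.0)
neg_counts = np.maximum(total_valid - pos_counts, 1.0)
print(neg_counts / pos_counts)                    # [0.4, 0.75]

# Multi-hot targets for the same matrix: HET entries light up both channels,
# missing entries stay all-zero and are masked out of the loss anyway.
targets = np.zeros(X.shape + (2,), dtype=np.float32)
targets[..., 0] = ((X == 0) | (X == 1)).astype(np.float32)
targets[..., 1] = ((X == 2) | (X == 1)).astype(np.float32)
print(targets[0])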

  def _reset_weights(self, model: torch.nn.Module) -> None:
  """Selectively resets only the weights of the phase 2/3 decoder.
@@ -1527,7 +1603,9 @@ class ImputeUBP(BaseNNImputer):
  y = torch.from_numpy(X_val).long().to(self.device)

  zdim = self._first_linear_in_features(model)
- schema_key = f"{self.prefix}_ubp_val_latents_z{zdim}_L{nF}_K{self.num_classes_}"
+ schema_key = (
+ f"{self.prefix}_ubp_val_latents_z{zdim}_L{nF}_K{self.output_classes_}"
+ )

  if cache is not None and schema_key in cache:
  z = cache[schema_key].detach().clone().requires_grad_(True)
@@ -1548,13 +1626,20 @@ class ImputeUBP(BaseNNImputer):
  self.logger.error(msg)
  raise TypeError(msg)

- logits = decoder(z).view(X_val.shape[0], nF, self.num_classes_)
+ logits = decoder(z).view(X_val.shape[0], nF, self.output_classes_)
  if not torch.isfinite(logits).all():
  break
-
- loss = F.cross_entropy(
- logits.view(-1, self.num_classes_), y.view(-1), ignore_index=-1
- )
+ if self.is_haploid:
+ loss = F.cross_entropy(
+ logits.view(-1, self.output_classes_), y.view(-1), ignore_index=-1
+ )
+ else:
+ targets = self._multi_hot_targets(y)
+ bce = F.binary_cross_entropy_with_logits(
+ logits, targets, pos_weight=self.pos_weights_, reduction="none"
+ )
+ mask = (y != -1).unsqueeze(-1).float()
+ loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)

  if not torch.isfinite(loss):
  break