pg-sui 1.6.14.dev9__py3-none-any.whl → 1.6.16a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.6.16a3.dist-info/METADATA +292 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/RECORD +14 -14
- pgsui/_version.py +2 -2
- pgsui/cli.py +14 -1
- pgsui/data_processing/containers.py +116 -104
- pgsui/impute/unsupervised/base.py +4 -1
- pgsui/impute/unsupervised/imputers/autoencoder.py +111 -35
- pgsui/impute/unsupervised/imputers/nlpca.py +239 -127
- pgsui/impute/unsupervised/imputers/ubp.py +135 -50
- pgsui/impute/unsupervised/imputers/vae.py +134 -46
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/top_level.txt +0 -0
The hunks below are from pgsui/impute/unsupervised/imputers/ubp.py (class ImputeUBP).

@@ -1,6 +1,7 @@
 import copy
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
 
+from fastapi import params
 import numpy as np
 import optuna
 import torch
@@ -160,6 +161,7 @@ class ImputeUBP(BaseNNImputer):
         self.verbose = self.cfg.io.verbose
         self.debug = self.cfg.io.debug
         self.rng = np.random.default_rng(self.seed)
+        self.pos_weights_: torch.Tensor | None = None
 
         # Simulated-missing controls (config defaults w/ overrides)
         sim_cfg = getattr(self.cfg, "sim", None)
@@ -320,15 +322,17 @@
         else:
             self.num_classes_ = 3
             self.logger.info(
-                "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
+                "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2) for scoring."
             )
+        # Model head always uses two channels; scoring uses num_classes_
+        self.output_classes_ = 2
 
         n_samples, self.num_features_ = X_for_model.shape
 
         # --- model params (decoder: Z -> L * num_classes) ---
         self.model_params = {
             "n_features": self.num_features_,
-            "num_classes": self.num_classes_,
+            "num_classes": self.output_classes_,
             "latent_dim": self.latent_dim,
             "dropout_rate": self.dropout_rate,
             "activation": self.activation,
@@ -353,6 +357,12 @@
         self.sim_mask_train_ = None
         self.sim_mask_test_ = None
 
+        # pos weights for diploid multilabel path
+        if not self.is_haploid:
+            self.pos_weights_ = self._compute_pos_weights(self.X_train_)
+        else:
+            self.pos_weights_ = None
+
         # --- plotting/scorers & tuning ---
         self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
         if self.tune:
@@ -485,14 +495,22 @@
                 y = y_batch.to(self.device, non_blocking=True).long()
 
                 logits = decoder(z).view(
-                    len(batch_indices), self.num_features_, self.num_classes_
+                    len(batch_indices), self.num_features_, self.output_classes_
                 )
 
                 # Guard upstream explosions
                 if not torch.isfinite(logits).all():
                     continue
 
-
+                if self.is_haploid:
+                    loss = criterion(logits.view(-1, self.output_classes_), y.view(-1))
+                else:
+                    targets = self._multi_hot_targets(y)
+                    bce = F.binary_cross_entropy_with_logits(
+                        logits, targets, pos_weight=self.pos_weights_, reduction="none"
+                    )
+                    mask = (y != -1).unsqueeze(-1).float()
+                    loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
 
                 if l1_penalty > 0:
                     l1 = torch.zeros((), device=self.device)
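For readers following the diploid branch added above: the new lines compute a per-element binary cross-entropy over two allele channels and average only over observed genotypes. The snippet below is a minimal, self-contained sketch of that masked two-channel BCE; the tensor shapes, the -1 missing code, and the 0/1/2 genotype encoding are taken from the surrounding diff, but the example does not use pg-sui's classes.

import torch
import torch.nn.functional as F

# Toy shapes: 4 samples, 5 loci, 2 output channels (REF, ALT).
logits = torch.randn(4, 5, 2)
y = torch.randint(-1, 3, (4, 5))  # genotypes: -1 missing, 0 REF, 1 HET, 2 ALT

# Multi-hot targets: REF channel on for 0/1, ALT channel on for 1/2.
targets = torch.zeros(y.shape + (2,))
valid = y != -1
targets[..., 0] = (valid & (y != 2)).float()
targets[..., 1] = (valid & (y != 0)).float()

# Per-element BCE, then mask out missing genotypes before averaging,
# mirroring the (bce * mask).sum() / mask.sum() reduction in the hunk above.
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
mask = valid.unsqueeze(-1).float()
loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
print(loss)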
@@ -555,10 +573,21 @@
            raise TypeError(msg)
 
        logits = decoder(latent_vectors.to(self.device)).view(
-            len(latent_vectors), nF, self.num_classes_
+            len(latent_vectors), nF, self.output_classes_
        )
-
-
+        if self.is_haploid:
+            probas = torch.softmax(logits, dim=-1)
+            labels = torch.argmax(probas, dim=-1)
+        else:
+            probas2 = torch.sigmoid(logits)
+            p_ref = probas2[..., 0]
+            p_alt = probas2[..., 1]
+            p_het = p_ref * p_alt
+            p_ref_only = p_ref * (1 - p_alt)
+            p_alt_only = p_alt * (1 - p_ref)
+            probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
+            probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
+            labels = torch.argmax(probas, dim=-1)
 
        return labels.cpu().numpy(), probas.cpu().numpy()
 
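A small sketch of the probability mapping the diploid branch performs: the two independent REF/ALT sigmoid channels are combined into REF-only, HET, and ALT-only probabilities and renormalized to sum to one. Values are illustrative only; only the arithmetic mirrors the hunk above.

import torch

logits = torch.tensor([[2.0, -1.0], [0.5, 0.6], [-2.0, 3.0]])  # (3 loci, 2 channels)
p = torch.sigmoid(logits)            # independent P(REF present), P(ALT present)
p_ref, p_alt = p[..., 0], p[..., 1]

p_het = p_ref * p_alt                # both alleles present -> heterozygote
p_ref_only = p_ref * (1 - p_alt)
p_alt_only = p_alt * (1 - p_ref)

probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)  # renormalize
labels = probas.argmax(dim=-1)       # 0=REF, 1=HET, 2=ALT
print(probas, labels)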
@@ -752,9 +781,18 @@
        )
 
    def _objective(self, trial: optuna.Trial) -> float:
-        """Optuna objective using the UBP training loop."""
+        """Optuna objective using the UBP training loop.
+
+        This method defines the objective function for hyperparameter tuning using Optuna. It prepares the necessary artifacts for tuning, samples a set of hyperparameters for the current trial, and trains the UBP model using these hyperparameters. The model is evaluated on a validation set, and the specified tuning metric is returned as the objective value. If any exception occurs during the process, the trial is pruned.
+        """
        try:
-
+            self._prepare_tuning_artifacts()
+            trial_params = self._sample_hyperparameters(trial)
+            model_params = trial_params["model_params"]
+
+            nfeat = self._tune_num_features
+            if self.tune and self.tune_fast:
+                model_params["n_features"] = nfeat
 
            X_train_trial = getattr(
                self, "X_train_", self.ground_truth_[self.train_idx_]
@@ -764,31 +802,38 @@
            class_weights = self._normalize_class_weights(
                self._class_weights_from_zygosity(X_train_trial)
            )
+            if not self.is_haploid:
+                self.pos_weights_ = self._compute_pos_weights(X_train_trial)
+            else:
+                self.pos_weights_ = None
            train_loader = self._get_data_loaders(X_train_trial)
 
            train_latent_vectors = self._create_latent_space(
-
+                model_params,
+                len(X_train_trial),
+                X_train_trial,
+                trial_params["latent_init"],
            )
 
-            model = self.build_model(self.Model,
-            model.n_features =
+            model = self.build_model(self.Model, model_params)
+            model.n_features = model_params["n_features"]
            model.apply(self.initialize_weights)
 
            _, model, __ = self._train_and_validate_model(
                model=model,
                loader=train_loader,
-                lr=
-                l1_penalty=
+                lr=trial_params["lr"],
+                l1_penalty=trial_params["l1_penalty"],
                trial=trial,
                return_history=False,
                latent_vectors=train_latent_vectors,
-                lr_input_factor=
+                lr_input_factor=trial_params["lr_input_factor"],
                class_weights=class_weights,
                X_val=X_test_trial,
-                params=
+                params=model_params,
                prune_metric=self.tune_metric,
-                prune_warmup_epochs=
-                eval_interval=
+                prune_warmup_epochs=10,
+                eval_interval=self.tune_eval_interval,
                eval_requires_latents=True,
                eval_latent_steps=self.eval_latent_steps,
                eval_latent_lr=self.eval_latent_lr,
@@ -806,7 +851,7 @@
            metrics = self._evaluate_model(
                X_test_trial,
                model,
-
+                model_params,
                objective_mode=True,
                eval_mask_override=eval_mask,
            )
@@ -829,30 +874,30 @@
            Dict[str, int | float | str | list]: Sampled hyperparameters.
        """
        params = {
-            "latent_dim": trial.suggest_int("latent_dim",
-            "lr": trial.suggest_float("learning_rate",
-            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.
-            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1,
+            "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
+            "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
+            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
+            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
            "activation": trial.suggest_categorical(
-                "activation", ["relu", "elu", "selu"]
+                "activation", ["relu", "elu", "selu", "leaky_relu"]
            ),
-            "gamma": trial.suggest_float("gamma", 0.0, 5
+            "gamma": trial.suggest_float("gamma", 0.5, 3.0, step=0.5),
            "lr_input_factor": trial.suggest_float(
-                "lr_input_factor", 0.
+                "lr_input_factor", 0.3, 3.0, log=True
            ),
-            "l1_penalty": trial.suggest_float("l1_penalty", 1e-
+            "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
            "layer_scaling_factor": trial.suggest_float(
-                "layer_scaling_factor", 2.0,
+                "layer_scaling_factor", 2.0, 4.0, step=0.5
            ),
            "layer_schedule": trial.suggest_categorical(
-                "layer_schedule", ["pyramid", "
+                "layer_schedule", ["pyramid", "linear"]
            ),
            "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
        }
 
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=params["latent_dim"],
-            n_outputs=self.num_features_ * self.num_classes_,
+            n_outputs=self.num_features_ * self.output_classes_,
            n_samples=len(self.train_idx_),
            n_hidden=params["num_hidden_layers"],
            alpha=params["layer_scaling_factor"],
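For context on the retuned search space above, here is a minimal Optuna study that exercises the same kinds of suggest calls. The toy objective body is a stand-in invented for illustration, not pg-sui's validation metric.

import optuna

def objective(trial: optuna.Trial) -> float:
    # Suggest calls matching the ranges shown in the new hunk.
    latent_dim = trial.suggest_int("latent_dim", 4, 16, step=2)
    lr = trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True)
    dropout = trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05)
    n_hidden = trial.suggest_int("num_hidden_layers", 1, 6)
    # Toy score standing in for the real validation metric.
    return lr * latent_dim - dropout * n_hidden

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(study.best_params)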
@@ -866,7 +911,7 @@
 
        params["model_params"] = {
            "n_features": self.num_features_,
-            "num_classes": self.num_classes_,
+            "num_classes": self.output_classes_,
            "latent_dim": params["latent_dim"],
            "dropout_rate": params["dropout_rate"],
            "hidden_layer_sizes": hidden_only,
@@ -900,7 +945,7 @@
 
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=self.latent_dim,
-            n_outputs=self.num_features_ * self.num_classes_,
+            n_outputs=self.num_features_ * self.output_classes_,
            n_samples=len(self.train_idx_),
            n_hidden=best_params["num_hidden_layers"],
            alpha=best_params["layer_scaling_factor"],
@@ -916,7 +961,7 @@
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "gamma": self.gamma,
-            "num_classes": self.num_classes_,
+            "num_classes": self.output_classes_,
        }
 
    def _set_best_params_default(self) -> dict:
@@ -929,7 +974,7 @@
        """
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=self.latent_dim,
-            n_outputs=self.num_features_ * self.num_classes_,
+            n_outputs=self.num_features_ * self.output_classes_,
            n_samples=len(self.ground_truth_),
            n_hidden=self.num_hidden_layers,
            alpha=self.layer_scaling_factor,
@@ -945,7 +990,7 @@
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "gamma": self.gamma,
-            "num_classes": self.num_classes_,
+            "num_classes": self.output_classes_,
        }
 
    def _train_and_validate_model(
@@ -963,7 +1008,7 @@
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str | None = None,  # "f1" | "accuracy" | "pr_macro"
-        prune_warmup_epochs: int =
+        prune_warmup_epochs: int = 10,
        eval_interval: int = 1,
        eval_requires_latents: bool = True,  # UBP needs latent eval
        eval_latent_steps: int = 50,
@@ -1073,7 +1118,7 @@
            X_val=self.X_test_,
            params=best_params,
            prune_metric=self.tune_metric,
-            prune_warmup_epochs=
+            prune_warmup_epochs=10,
            eval_interval=1,
            eval_requires_latents=True,
            eval_latent_steps=self.eval_latent_steps,
@@ -1105,7 +1150,7 @@
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str | None = None,
-        prune_warmup_epochs: int =
+        prune_warmup_epochs: int = 10,
        eval_interval: int = 1,
        eval_requires_latents: bool = True,
        eval_latent_steps: int = 50,
@@ -1152,7 +1197,7 @@
        # Schema-aware latent cache for eval
        _latent_cache: dict = {}
        nF = getattr(model, "n_features", self.num_features_)
-        cache_key_root = f"{self.prefix}_ubp_val_latents_L{nF}_K{self.num_classes_}"
+        cache_key_root = f"{self.prefix}_ubp_val_latents_L{nF}_K{self.output_classes_}"
 
        E = int(self.epochs)
        phase_epochs = {
@@ -1349,14 +1394,22 @@
                raise TypeError(msg)
 
            opt.zero_grad(set_to_none=True)
-            logits = decoder(z).view(len(X_new), nF, self.num_classes_)
+            logits = decoder(z).view(len(X_new), nF, self.output_classes_)
 
            if not torch.isfinite(logits).all():
                break
 
-
-
-
+            if self.is_haploid:
+                loss = F.cross_entropy(
+                    logits.view(-1, self.output_classes_), y.view(-1), ignore_index=-1
+                )
+            else:
+                targets = self._multi_hot_targets(y)
+                bce = F.binary_cross_entropy_with_logits(
+                    logits, targets, pos_weight=self.pos_weights_, reduction="none"
+                )
+                mask = (y != -1).unsqueeze(-1).float()
+                loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
 
            if not torch.isfinite(loss):
                break
@@ -1457,8 +1510,31 @@
        else:
            latents = torch.empty(n_samples, latent_dim, device=self.device)
            torch.nn.init.xavier_uniform_(latents)
+        return torch.nn.Parameter(latents, requires_grad=True)
 
-
+    def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
+        """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
+        if self.is_haploid:
+            raise RuntimeError("_multi_hot_targets called for haploid data.")
+        y = y.to(self.device)
+        out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
+        valid = y != -1
+        ref_mask = valid & (y != 2)
+        alt_mask = valid & (y != 0)
+        out[ref_mask, 0] = 1.0
+        out[alt_mask, 1] = 1.0
+        return out
+
+    def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
+        """Balance REF/ALT channels for multilabel BCE."""
+        ref_pos = np.count_nonzero((X == 0) | (X == 1))
+        alt_pos = np.count_nonzero((X == 2) | (X == 1))
+        total_valid = np.count_nonzero(X != -1)
+        pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
+        neg_counts = np.maximum(total_valid - pos_counts, 1.0)
+        pos_counts = np.maximum(pos_counts, 1.0)
+        weights = neg_counts / pos_counts
+        return torch.tensor(weights, device=self.device, dtype=torch.float32)
 
    def _reset_weights(self, model: torch.nn.Module) -> None:
        """Selectively resets only the weights of the phase 2/3 decoder.
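To see what the new helpers above produce numerically, here is a quick check on a toy genotype matrix. The -1 missing code, the 0/1/2 encoding, and the neg/pos weighting formula follow the diff; the matrix itself and the printed values are illustrative only.

import numpy as np
import torch

X = np.array([[0, 1, 2, -1],
              [2, 2, 1, 0]])  # genotypes: -1 missing, 0 REF, 1 HET, 2 ALT

# pos_weight per channel = negatives / positives, mirroring _compute_pos_weights.
ref_pos = np.count_nonzero((X == 0) | (X == 1))   # REF allele present: 4
alt_pos = np.count_nonzero((X == 2) | (X == 1))   # ALT allele present: 5
total_valid = np.count_nonzero(X != -1)           # 7 observed genotypes
pos = np.maximum(np.array([ref_pos, alt_pos], dtype=np.float32), 1.0)
neg = np.maximum(total_valid - pos, 1.0)
print(neg / pos)                                  # -> [0.75, 0.4]

# Multi-hot targets for one sample, mirroring _multi_hot_targets.
y = torch.tensor(X[0])
valid = y != -1
targets = torch.zeros(y.shape + (2,))
targets[..., 0] = (valid & (y != 2)).float()      # REF channel
targets[..., 1] = (valid & (y != 0)).float()      # ALT channel
print(targets)                                    # HET rows set both channels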
@@ -1527,7 +1603,9 @@
            y = torch.from_numpy(X_val).long().to(self.device)
 
            zdim = self._first_linear_in_features(model)
-            schema_key =
+            schema_key = (
+                f"{self.prefix}_ubp_val_latents_z{zdim}_L{nF}_K{self.output_classes_}"
+            )
 
            if cache is not None and schema_key in cache:
                z = cache[schema_key].detach().clone().requires_grad_(True)
@@ -1548,13 +1626,20 @@
                self.logger.error(msg)
                raise TypeError(msg)
 
-            logits = decoder(z).view(X_val.shape[0], nF, self.num_classes_)
+            logits = decoder(z).view(X_val.shape[0], nF, self.output_classes_)
            if not torch.isfinite(logits).all():
                break
-
-
-
-
+            if self.is_haploid:
+                loss = F.cross_entropy(
+                    logits.view(-1, self.output_classes_), y.view(-1), ignore_index=-1
+                )
+            else:
+                targets = self._multi_hot_targets(y)
+                bce = F.binary_cross_entropy_with_logits(
+                    logits, targets, pos_weight=self.pos_weights_, reduction="none"
+                )
+                mask = (y != -1).unsqueeze(-1).float()
+                loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
 
            if not torch.isfinite(loss):
                break