pg-sui 1.6.14.dev9__py3-none-any.whl → 1.6.16a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.6.16a3.dist-info/METADATA +292 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/RECORD +14 -14
- pgsui/_version.py +2 -2
- pgsui/cli.py +14 -1
- pgsui/data_processing/containers.py +116 -104
- pgsui/impute/unsupervised/base.py +4 -1
- pgsui/impute/unsupervised/imputers/autoencoder.py +111 -35
- pgsui/impute/unsupervised/imputers/nlpca.py +239 -127
- pgsui/impute/unsupervised/imputers/ubp.py +135 -50
- pgsui/impute/unsupervised/imputers/vae.py +134 -46
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/top_level.txt +0 -0
pgsui/impute/unsupervised/imputers/autoencoder.py

@@ -5,6 +5,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import optuna
 import torch
+import torch.nn.functional as F
 from sklearn.exceptions import NotFittedError
 from sklearn.model_selection import train_test_split
 from snpio.analysis.genotype_encoder import GenotypeEncoder
@@ -162,6 +163,7 @@ class ImputeAutoencoder(BaseNNImputer):
         self.verbose = self.cfg.io.verbose
         self.debug = self.cfg.io.debug
         self.rng = np.random.default_rng(self.seed)
+        self.pos_weights_: torch.Tensor | None = None

         # Simulated-missing controls (config defaults with ctor overrides)
         sim_cfg = getattr(self.cfg, "sim", None)
@@ -330,10 +332,12 @@ class ImputeAutoencoder(BaseNNImputer):
             )
         )
         self.ploidy = 1 if self.is_haploid else 2
+        # Scoring still uses 3 labels for diploid (REF/HET/ALT); model head uses 2 logits
         self.num_classes_ = 2 if self.is_haploid else 3
+        self.output_classes_ = 2
         self.logger.info(
             f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
-            f"using {self.num_classes_} classes."
+            f"using {self.num_classes_} classes for scoring and {self.output_classes_} output channels."
         )

         if self.is_haploid:
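In the 0/1/2 genotype convention used throughout this diff (0 = homozygous REF, 1 = heterozygote, 2 = homozygous ALT, -1 = missing), the diploid head now emits two allele channels instead of three mutually exclusive classes. A minimal standalone sketch of that mapping, mirroring the `_encode_multilabel_inputs` helper added later in this diff (illustrative code, not the package's own):

    import torch

    def encode_two_channel(y: torch.Tensor) -> torch.Tensor:
        """Map 0/1/2 genotypes to (REF, ALT) channels; -1 (missing) stays all-zero."""
        out = torch.zeros(y.shape + (2,), dtype=torch.float32)
        valid = y != -1
        out[valid & (y != 2), 0] = 1.0  # REF allele present for genotypes 0 and 1
        out[valid & (y != 0), 1] = 1.0  # ALT allele present for genotypes 1 and 2
        return out

    print(encode_two_channel(torch.tensor([0, 1, 2, -1])))
    # tensor([[1., 0.],   REF/REF
    #         [1., 1.],   heterozygote lights up both channels
    #         [0., 1.],   ALT/ALT
    #         [0., 0.]])  missing

Scoring still collapses predictions back to the three genotype labels, which is why `num_classes_` stays at 3 for diploid data while `output_classes_` is fixed at 2.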
@@ -345,7 +349,7 @@ class ImputeAutoencoder(BaseNNImputer):
         # Model params (decoder outputs L * K logits)
         self.model_params = {
             "n_features": self.num_features_,
-            "num_classes": self.
+            "num_classes": self.output_classes_,
             "latent_dim": self.latent_dim,
             "dropout_rate": self.dropout_rate,
             "activation": self.activation,
@@ -369,6 +373,12 @@ class ImputeAutoencoder(BaseNNImputer):
         self.sim_mask_train_ = None
         self.sim_mask_test_ = None

+        # Pos weights for diploid multilabel path (must exist before tuning)
+        if not self.is_haploid:
+            self.pos_weights_ = self._compute_pos_weights(self.X_train_)
+        else:
+            self.pos_weights_ = None
+
         # Plotters/scorers (shared utilities)
         self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
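The per-channel `pos_weight` handed to `binary_cross_entropy_with_logits` is the usual negative-to-positive count ratio over non-missing genotypes (see `_compute_pos_weights` later in this diff). A small worked example of that ratio on toy counts, purely for illustration:

    import numpy as np

    # Toy genotype matrix: 0/1/2 with -1 for missing.
    X = np.array([[0, 1, 2, -1],
                  [0, 0, 1,  2]])

    ref_pos = np.count_nonzero((X == 0) | (X == 1))  # REF allele present: 5
    alt_pos = np.count_nonzero((X == 2) | (X == 1))  # ALT allele present: 4
    total_valid = np.count_nonzero(X != -1)          # 7 observed genotypes

    pos = np.maximum(np.array([ref_pos, alt_pos], dtype=np.float32), 1.0)
    neg = np.maximum(total_valid - pos, 1.0)
    print(neg / pos)  # [0.4  0.75] -> the rarer channel gets the larger weight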
@@ -401,7 +411,7 @@ class ImputeAutoencoder(BaseNNImputer):
             X_val=self.X_val_,
             params=self.best_params_,
             prune_metric=self.tune_metric,
-            prune_warmup_epochs=
+            prune_warmup_epochs=10,
             eval_interval=1,
             eval_requires_latents=False,
             eval_latent_steps=0,
@@ -509,7 +519,7 @@ class ImputeAutoencoder(BaseNNImputer):
         X_val: np.ndarray | None = None,
         params: dict | None = None,
         prune_metric: str = "f1",  # "f1" | "accuracy" | "pr_macro"
-        prune_warmup_epochs: int =
+        prune_warmup_epochs: int = 10,
         eval_interval: int = 1,
         # Evaluation parameters (AE ignores latent refinement knobs)
         eval_requires_latents: bool = False,  # AE: always False
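The `prune_warmup_epochs` default is now pinned to 10 everywhere it is passed. In Optuna terms this is the standard pattern of reporting an intermediate metric every epoch but only letting the pruner act after a warm-up window. A generic sketch of that pattern (not the package's trainer; the metric here is a made-up placeholder):

    import optuna

    def objective(trial: optuna.Trial) -> float:
        prune_warmup_epochs = 10
        lr = trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True)
        score = 0.0
        for epoch in range(30):
            # Placeholder "validation metric"; the imputer reports macro-F1 / accuracy / PR here.
            score = 1.0 - 1.0 / (epoch + 1) - 0.1 * lr
            trial.report(score, step=epoch)
            # Only allow pruning once the warm-up window has passed.
            if epoch >= prune_warmup_epochs and trial.should_prune():
                raise optuna.TrialPruned()
        return score

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=5)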
@@ -593,7 +603,7 @@ class ImputeAutoencoder(BaseNNImputer):
         X_val: np.ndarray | None = None,
         params: dict | None = None,
         prune_metric: str = "f1",
-        prune_warmup_epochs: int =
+        prune_warmup_epochs: int = 10,
         eval_interval: int = 1,
         # Evaluation parameters (AE ignores latent refinement knobs)
         eval_requires_latents: bool = False,  # AE: False
@@ -761,24 +771,35 @@ class ImputeAutoencoder(BaseNNImputer):
         # Use model.gamma if present, else self.gamma
         gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
         gamma = float(torch.tensor(gamma).clamp(min=0.0, max=10.0))  # sane bound
-        … (old line 764 removed; content not rendered in this view)
+        ce_criterion = SafeFocalCELoss(
+            gamma=gamma, weight=class_weights, ignore_index=-1
+        )

         for _, y_batch in loader:
             optimizer.zero_grad(set_to_none=True)
             y_batch = y_batch.to(self.device, non_blocking=True)

             # Inputs: one-hot with zeros for missing; Targets: long ints with -1 for missing
-            … (old lines 771-781 removed; content not rendered in this view)
+            if self.is_haploid:
+                x_in = self._one_hot_encode_012(y_batch)  # (B, L, 2)
+                logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
+                logits_flat = logits.view(-1, self.output_classes_)
+                targets_flat = y_batch.view(-1).long()
+                if not torch.isfinite(logits_flat).all():
+                    continue
+                loss = ce_criterion(logits_flat, targets_flat)
+            else:
+                x_in = self._encode_multilabel_inputs(y_batch)  # (B, L, 2)
+                logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
+                if not torch.isfinite(logits).all():
+                    continue
+                pos_w = getattr(self, "pos_weights_", None)
+                targets = self._multi_hot_targets(y_batch)  # float, same shape
+                bce = F.binary_cross_entropy_with_logits(
+                    logits, targets, pos_weight=pos_w, reduction="none"
+                )
+                mask = (y_batch != -1).unsqueeze(-1).float()
+                loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)

             if l1_penalty > 0:
                 l1 = torch.zeros((), device=self.device)
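For the diploid path, the per-element BCE is masked so that missing genotypes (-1) contribute nothing, and the sum is normalized by the number of observed elements. A self-contained sketch of that masking on toy tensors (shapes and values made up for illustration):

    import torch
    import torch.nn.functional as F

    y = torch.tensor([[0, 1, -1],
                      [2, -1, 1]])                  # 0/1/2 genotypes, -1 = missing
    logits = torch.randn(2, 3, 2)                   # (batch, loci, 2 allele channels)

    targets = torch.zeros(2, 3, 2)
    valid = y != -1
    targets[valid & (y != 2), 0] = 1.0              # REF channel
    targets[valid & (y != 0), 1] = 1.0              # ALT channel

    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    mask = valid.unsqueeze(-1).float()              # missing genotypes contribute nothing
    loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
    print(loss)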
@@ -842,16 +863,69 @@ class ImputeAutoencoder(BaseNNImputer):
         with torch.no_grad():
             X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
             X_tensor = X_tensor.to(self.device).long()
-            … (old lines 845-848 removed; content not rendered in this view)
+            if self.is_haploid:
+                x_ohe = self._one_hot_encode_012(X_tensor)
+                logits = model(x_ohe).view(-1, self.num_features_, self.output_classes_)
+                probas = torch.softmax(logits, dim=-1)
+                labels = torch.argmax(probas, dim=-1)
+            else:
+                x_in = self._encode_multilabel_inputs(X_tensor)
+                logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
+                probas_2 = torch.sigmoid(logits)
+                p_ref = probas_2[..., 0]
+                p_alt = probas_2[..., 1]
+                p_het = p_ref * p_alt
+                p_ref_only = p_ref * (1 - p_alt)
+                p_alt_only = p_alt * (1 - p_ref)
+                stacked = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
+                stacked = stacked / stacked.sum(dim=-1, keepdim=True).clamp_min(1e-8)
+                probas = stacked
+                labels = torch.argmax(stacked, dim=-1)

         if return_proba:
             return labels.cpu().numpy(), probas.cpu().numpy()

         return labels.cpu().numpy()

+    def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
+        """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
+        if self.is_haploid:
+            return self._one_hot_encode_012(y)
+        y = y.to(self.device)
+        shape = y.shape + (2,)
+        out = torch.zeros(shape, device=self.device, dtype=torch.float32)
+        valid = y != -1
+        ref_mask = valid & (y != 2)
+        alt_mask = valid & (y != 0)
+        out[ref_mask, 0] = 1.0
+        out[alt_mask, 1] = 1.0
+        return out
+
+    def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
+        """Targets aligned with _encode_multilabel_inputs for diploid training."""
+        if self.is_haploid:
+            # One-hot CE path expects integer targets; handled upstream.
+            raise RuntimeError("_multi_hot_targets called for haploid data.")
+        y = y.to(self.device)
+        out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
+        valid = y != -1
+        ref_mask = valid & (y != 2)
+        alt_mask = valid & (y != 0)
+        out[ref_mask, 0] = 1.0
+        out[alt_mask, 1] = 1.0
+        return out
+
+    def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
+        """Balance REF/ALT channels for multilabel BCE."""
+        ref_pos = np.count_nonzero((X == 0) | (X == 1))
+        alt_pos = np.count_nonzero((X == 2) | (X == 1))
+        total_valid = np.count_nonzero(X != -1)
+        pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
+        neg_counts = np.maximum(total_valid - pos_counts, 1.0)
+        pos_counts = np.maximum(pos_counts, 1.0)
+        weights = neg_counts / pos_counts
+        return torch.tensor(weights, device=self.device, dtype=torch.float32)
+
     def _evaluate_model(
         self,
         X_val: np.ndarray,
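The diploid predict path converts the two per-allele sigmoid probabilities into three genotype probabilities by treating the channels as approximately independent: REF-only = p_ref·(1-p_alt), HET = p_ref·p_alt, ALT-only = p_alt·(1-p_ref), then renormalizing (which drops the "neither allele" mass). A standalone numeric sketch of that combination:

    import torch

    p = torch.sigmoid(torch.tensor([[2.0, -1.0],    # strong REF, weak ALT -> homozygous REF
                                    [1.5,  1.2]]))  # both alleles likely  -> heterozygote
    p_ref, p_alt = p[..., 0], p[..., 1]

    stacked = torch.stack([p_ref * (1 - p_alt),     # label 0: homozygous REF
                           p_ref * p_alt,           # label 1: heterozygote
                           p_alt * (1 - p_ref)],    # label 2: homozygous ALT
                          dim=-1)
    probas = stacked / stacked.sum(dim=-1, keepdim=True).clamp_min(1e-8)
    print(probas, probas.argmax(dim=-1))            # argmax recovers the 0/1/2 genotype label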
@@ -1090,7 +1164,7 @@ class ImputeAutoencoder(BaseNNImputer):
             X_val=X_val,
             params=params,
             prune_metric=self.tune_metric,
-            prune_warmup_epochs=
+            prune_warmup_epochs=10,
             eval_interval=self.tune_eval_interval,
             eval_requires_latents=False,
             eval_latent_steps=0,
@@ -1137,24 +1211,24 @@ class ImputeAutoencoder(BaseNNImputer):
            Dict[str, int | float | str | bool]: Sampled hyperparameters and model_params.
        """
        params = {
-            "latent_dim": trial.suggest_int("latent_dim",
-            "lr": trial.suggest_float("learning_rate",
-            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.
-            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1,
+            "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
+            "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
+            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
+            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
             "activation": trial.suggest_categorical(
-                "activation", ["relu", "elu", "selu"]
+                "activation", ["relu", "elu", "selu", "leaky_relu"]
             ),
-            "l1_penalty": trial.suggest_float("l1_penalty", 1e-
+            "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
             "layer_scaling_factor": trial.suggest_float(
-                "layer_scaling_factor", 2.0,
+                "layer_scaling_factor", 2.0, 4.0, step=0.5
             ),
             "layer_schedule": trial.suggest_categorical(
-                "layer_schedule", ["pyramid", "
+                "layer_schedule", ["pyramid", "linear"]
             ),
         }

        nF: int = self.num_features_
-        nC: int = int(self
+        nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
        input_dim = nF * nC
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=input_dim,
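The sampled `num_hidden_layers`, `layer_scaling_factor`, and `layer_schedule` feed `_compute_hidden_layer_sizes`, whose implementation is not shown in this diff. As a rough intuition only, under the assumption that "pyramid" divides each successive layer by the scaling factor while "linear" steps down evenly, the sizes might look like the following hypothetical stand-in (not the package's helper):

    def sketch_hidden_sizes(n_inputs: int, n_layers: int, scale: float, schedule: str) -> list[int]:
        # Hypothetical interpretation of the two schedules for illustration.
        if schedule == "pyramid":
            return [max(int(n_inputs / scale ** (i + 1)), 2) for i in range(n_layers)]
        last = max(int(n_inputs / scale ** n_layers), 2)
        step = (n_inputs - last) / (n_layers + 1)
        return [int(n_inputs - step * (i + 1)) for i in range(n_layers)]

    # e.g. 500 loci x 2 output channels flattened, 3 hidden layers, layer_scaling_factor=2.5
    print(sketch_hidden_sizes(1000, 3, 2.5, "pyramid"))  # [400, 160, 64]
    print(sketch_hidden_sizes(1000, 3, 2.5, "linear"))   # [766, 532, 298]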
@@ -1173,8 +1247,8 @@ class ImputeAutoencoder(BaseNNImputer):

         params["model_params"] = {
             "n_features": int(self.num_features_),
-            "num_classes": (
-            … (old line 1177 removed; content not rendered in this view)
+            "num_classes": int(
+                getattr(self, "output_classes_", self.num_classes_ or 3)
             ),
             "latent_dim": int(params["latent_dim"]),
             "dropout_rate": float(params["dropout_rate"]),
@@ -1227,7 +1301,7 @@ class ImputeAutoencoder(BaseNNImputer):
         self.layer_schedule: str = bp["layer_schedule"]

         nF: int = self.num_features_
-        nC: int = int(self
+        nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
         hidden_layer_sizes = self._compute_hidden_layer_sizes(
             n_inputs=nF * nC,
             n_outputs=nF * nC,
@@ -1261,7 +1335,9 @@ class ImputeAutoencoder(BaseNNImputer):
            Dict[str, int | float | str | list]: Default model parameters.
        """
        nF: int = self.num_features_
-        … (old line 1264 removed; content not rendered in this view)
+        # Use the number of output channels passed to the model (2 for diploid multilabel)
+        # instead of the scoring classes (3) to keep layer shapes aligned.
+        nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
        ls = self.layer_schedule

        if ls not in {"pyramid", "constant", "linear"}:
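Keeping nC at the model's output-channel count matters because the dense head of size nF * nC is later reshaped to (batch, loci, 2) in both the training and predict paths shown above; sizing layers with the 3 scoring classes would break that view. A tiny shape check with hypothetical sizes:

    import torch

    B, L, C = 4, 1000, 2            # hypothetical batch size and locus count; C = output_classes_
    flat = torch.randn(B, L * C)    # what a decoder head sized nF * nC produces
    logits = flat.view(-1, L, C)    # reshaped exactly as in the paths above
    print(logits.shape)             # torch.Size([4, 1000, 2])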