pg-sui 1.6.14.dev9-py3-none-any.whl → 1.6.16a3-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- pg_sui-1.6.16a3.dist-info/METADATA +292 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/RECORD +14 -14
- pgsui/_version.py +2 -2
- pgsui/cli.py +14 -1
- pgsui/data_processing/containers.py +116 -104
- pgsui/impute/unsupervised/base.py +4 -1
- pgsui/impute/unsupervised/imputers/autoencoder.py +111 -35
- pgsui/impute/unsupervised/imputers/nlpca.py +239 -127
- pgsui/impute/unsupervised/imputers/ubp.py +135 -50
- pgsui/impute/unsupervised/imputers/vae.py +134 -46
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.6.16a3.dist-info}/top_level.txt +0 -0
@@ -110,15 +110,18 @@ class ImputeNLPCA(BaseNNImputer):
         tree_parser: Optional["TreeParser"] = None,
         config: NLPCAConfig | dict | str | None = None,
         overrides: dict | None = None,
-        simulate_missing: bool =
-        sim_strategy:
-
-
-
-
-
-
-
+        simulate_missing: bool | None = None,
+        sim_strategy: (
+            Literal[
+                "random",
+                "random_weighted",
+                "random_weighted_inv",
+                "nonrandom",
+                "nonrandom_weighted",
+            ]
+            | None
+        ) = None,
+        sim_prop: float | None = None,
         sim_kwargs: dict | None = None,
     ):
         """Initializes the ImputeNLPCA imputer with genotype data and configuration.
@@ -130,10 +133,10 @@ class ImputeNLPCA(BaseNNImputer):
             tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for population-specific modes.
             config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
             overrides (dict | None): Dot-key overrides (e.g. {'model.latent_dim': 4}).
-            simulate_missing (bool): Whether to simulate missing data during training.
-            sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Strategy for simulating missing data.
-            sim_prop (float): Proportion of data to simulate as missing.
-            sim_kwargs (dict | None): Additional keyword arguments for missing data simulation.
+            simulate_missing (bool | None): Whether to simulate missing data during training. If None, uses config defaults.
+            sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"] | None): Strategy for simulating missing data. If None, uses config default.
+            sim_prop (float | None): Proportion of data to simulate as missing. If None, uses config default.
+            sim_kwargs (dict | None): Additional keyword arguments for missing data simulation (overrides config kwargs).
         """
         self.model_name = "ImputeNLPCA"
         self.genotype_data = genotype_data
@@ -141,6 +144,7 @@ class ImputeNLPCA(BaseNNImputer):
 
         # Normalize config first, then apply overrides (highest precedence)
         cfg = ensure_nlpca_config(config)
+
         if overrides:
             cfg = apply_dot_overrides(cfg, overrides)
 
@@ -153,9 +157,7 @@ class ImputeNLPCA(BaseNNImputer):
             verbose=self.cfg.io.verbose,
         )
         self.logger = configure_logger(
-            logman.get_logger(),
-            verbose=self.cfg.io.verbose,
-            debug=self.cfg.io.debug,
+            logman.get_logger(), verbose=self.cfg.io.verbose, debug=self.cfg.io.debug
         )
 
         # Initialize BaseNNImputer with device/dirs/logging from config
@@ -178,6 +180,7 @@ class ImputeNLPCA(BaseNNImputer):
         self.debug = self.cfg.io.debug
 
         self.rng = np.random.default_rng(self.seed)
+        self.pos_weights_: torch.Tensor | None = None
 
         # Model/train hyperparams
         self.latent_dim = self.cfg.model.latent_dim
@@ -242,10 +245,31 @@ class ImputeNLPCA(BaseNNImputer):
         self.num_classes_ = 3
         self.model_params: Dict[str, Any] = {}
 
-        self.
-
-
-
+        sim_cfg = getattr(self.cfg, "sim", None)
+        sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
+
+        if sim_kwargs:
+            sim_cfg_kwargs.update(sim_kwargs)
+
+        if sim_cfg is None:
+            default_strategy = "random"
+            default_prop = 0.10
+        else:
+            default_strategy = sim_cfg.sim_strategy
+            default_prop = sim_cfg.sim_prop
+
+        self.simulate_missing = (
+            (
+                sim_cfg.simulate_missing
+                if simulate_missing is None
+                else bool(simulate_missing)
+            )
+            if sim_cfg is not None
+            else bool(simulate_missing)
+        )
+        self.sim_strategy = sim_strategy or default_strategy
+        self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
+        self.sim_kwargs = sim_cfg_kwargs
 
         if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
             msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
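
Net effect of the block above: an explicit constructor argument wins over the `cfg.sim` value, which wins over the hard-coded fallbacks (`"random"`, `0.10`), and user-supplied `sim_kwargs` are merged over a deep copy of the config's kwargs. A minimal usage sketch (the import path, YAML file, and `genotype_data` object are assumptions, not part of this diff):

```python
# Hypothetical call showing the new precedence: explicit arg > config > default.
from pgsui import ImputeNLPCA  # import path assumed

imputer = ImputeNLPCA(
    genotype_data,            # SNPio GenotypeData object (assumed)
    config="nlpca.yaml",      # may define sim.simulate_missing / sim_strategy / sim_prop
    sim_prop=0.2,             # explicit value overrides the config's sim_prop
    sim_kwargs={"seed": 42},  # merged over a deep copy of the config's sim_kwargs
)
```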
@@ -319,10 +343,11 @@ class ImputeNLPCA(BaseNNImputer):
             self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
         else:
             self.num_classes_ = 3
-
-
-
-            )
+            # Model head uses two channels; scoring uses num_classes_
+            self.output_classes_ = 2
+            self.logger.info(
+                "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2) for scoring; 2 output channels with sigmoid for training."
+            )
 
         n_samples, self.num_features_ = X_for_model.shape
 
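
This branch splits model-head width from scoring classes for diploid data: two sigmoid channels (REF, ALT) feed the loss, while metrics still use three genotype classes. The implied channel encoding, spelled out (this mapping is inferred from `_multi_hot_targets` later in the diff):

```python
# Genotype -> (REF channel, ALT channel); -1 (missing) is masked out of the loss.
encoding = {
    0: (1.0, 0.0),  # homozygous REF
    1: (1.0, 1.0),  # heterozygous: both alleles present
    2: (0.0, 1.0),  # homozygous ALT
}
```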
@@ -332,7 +357,7 @@ class ImputeNLPCA(BaseNNImputer):
             "dropout_rate": self.dropout_rate,
             "activation": self.activation,
             "gamma": self.gamma,
-            "num_classes": self.
+            "num_classes": self.output_classes_,
         }
 
         # --- Train/Test Split ---
@@ -354,6 +379,11 @@ class ImputeNLPCA(BaseNNImputer):
         else:
             self.sim_mask_train_ = None
             self.sim_mask_test_ = None
+        # pos weights for multilabel diploid path
+        if not self.is_haploid:
+            self.pos_weights_ = self._compute_pos_weights(self.X_train_)
+        else:
+            self.pos_weights_ = None
 
         # Tuning, model setup, training (unchanged except DataLoader input)
         self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
@@ -519,17 +549,24 @@ class ImputeNLPCA(BaseNNImputer):
 
                 # Forward
                 z = latent_vectors[batch_indices].to(self.device)
-                logits = decoder(z).view(len(batch_indices), nF, self.
+                logits = decoder(z).view(len(batch_indices), nF, self.output_classes_)
 
                 # Guard upstream explosions
                 if not torch.isfinite(logits).all():
                     # Skip batch if model already produced non-finite values
                     continue
 
-
-
-
-
+                if self.is_haploid:
+                    logits_flat = logits.view(-1, self.output_classes_)
+                    targets_flat = y_batch.view(-1)
+                    loss = criterion(logits_flat, targets_flat)
+                else:
+                    targets = self._multi_hot_targets(y_batch)
+                    bce = F.binary_cross_entropy_with_logits(
+                        logits, targets, pos_weight=self.pos_weights_, reduction="none"
+                    )
+                    mask = (y_batch != -1).unsqueeze(-1).float()
+                    loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
 
                 # L1 on model weights only (exclude latents)
                 if l1_penalty > 0:
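
A self-contained sketch of the masked multilabel loss introduced above, on toy tensors (shapes, logits, and the pos_weight values are illustrative only):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 5, 2)                  # (batch, loci, [REF, ALT] channels)
y = torch.randint(-1, 3, (4, 5))               # 012-encoded genotypes, -1 = missing
valid = y != -1
targets = torch.zeros(4, 5, 2)
targets[..., 0] = (valid & (y != 2)).float()   # REF allele present for 0 and 1
targets[..., 1] = (valid & (y != 0)).float()   # ALT allele present for 1 and 2
bce = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=torch.tensor([1.0, 1.5]), reduction="none"
)
mask = valid.unsqueeze(-1).float()             # drop missing loci from the average
loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
```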
@@ -605,10 +642,21 @@ class ImputeNLPCA(BaseNNImputer):
 
         with torch.no_grad():
             logits = model.phase23_decoder(latent_vectors.to(self.device)).view(
-                len(latent_vectors), nF, self.
+                len(latent_vectors), nF, self.output_classes_
             )
-
-
+            if self.is_haploid:
+                probas = torch.softmax(logits, dim=-1)
+                labels = torch.argmax(probas, dim=-1)
+            else:
+                probas2 = torch.sigmoid(logits)
+                p_ref = probas2[..., 0]
+                p_alt = probas2[..., 1]
+                p_het = p_ref * p_alt
+                p_ref_only = p_ref * (1 - p_alt)
+                p_alt_only = p_alt * (1 - p_ref)
+                probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
+                probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
+                labels = torch.argmax(probas, dim=-1)
 
         return labels.cpu().numpy(), probas.cpu().numpy()
 
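
Numeric sketch of the two-channel to three-class composition used above: with both channel logits high the heterozygote probability dominates; with REF high and ALT low the REF-homozygote does.

```python
import torch

p = torch.sigmoid(torch.tensor([[2.0, 2.0], [3.0, -3.0]]))  # per-locus P(REF), P(ALT)
p_ref, p_alt = p[:, 0], p[:, 1]
p_het = p_ref * p_alt                    # both alleles called present
p_ref_only = p_ref * (1 - p_alt)
p_alt_only = p_alt * (1 - p_ref)
probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
print(probas)  # row 0 -> HET dominates; row 1 -> REF-homozygote dominates
```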
@@ -798,6 +846,44 @@ class ImputeNLPCA(BaseNNImputer):
             dataset, batch_size=self.batch_size, shuffle=True
         )
 
+    def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
+        """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
+        if self.is_haploid:
+            return self._one_hot_encode_012(y)
+        y = y.to(self.device)
+        shape = y.shape + (2,)
+        out = torch.zeros(shape, device=self.device, dtype=torch.float32)
+        valid = y != -1
+        ref_mask = valid & (y != 2)
+        alt_mask = valid & (y != 0)
+        out[ref_mask, 0] = 1.0
+        out[alt_mask, 1] = 1.0
+        return out
+
+    def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
+        """Targets aligned with _encode_multilabel_inputs for diploid training."""
+        if self.is_haploid:
+            raise RuntimeError("_multi_hot_targets called for haploid data.")
+        y = y.to(self.device)
+        out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
+        valid = y != -1
+        ref_mask = valid & (y != 2)
+        alt_mask = valid & (y != 0)
+        out[ref_mask, 0] = 1.0
+        out[alt_mask, 1] = 1.0
+        return out
+
+    def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
+        """Balance REF/ALT channels for multilabel BCE."""
+        ref_pos = np.count_nonzero((X == 0) | (X == 1))
+        alt_pos = np.count_nonzero((X == 2) | (X == 1))
+        total_valid = np.count_nonzero(X != -1)
+        pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
+        neg_counts = np.maximum(total_valid - pos_counts, 1.0)
+        pos_counts = np.maximum(pos_counts, 1.0)
+        weights = neg_counts / pos_counts
+        return torch.tensor(weights, device=self.device, dtype=torch.float32)
+
     def _create_latent_space(
         self,
         params: dict,
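
A quick numpy-only check of `_compute_pos_weights`' counting logic on a toy 012 matrix (values chosen for illustration):

```python
import numpy as np

X = np.array([[0, 1, 2, -1], [2, 2, 1, 0]])      # 012 genotypes, -1 = missing
ref_pos = np.count_nonzero((X == 0) | (X == 1))  # genotypes carrying a REF allele -> 4
alt_pos = np.count_nonzero((X == 2) | (X == 1))  # genotypes carrying an ALT allele -> 5
total_valid = np.count_nonzero(X != -1)          # 7 non-missing calls
pos = np.maximum(np.array([ref_pos, alt_pos], dtype=np.float32), 1.0)
neg = np.maximum(total_valid - pos, 1.0)
print(neg / pos)                                 # BCE pos_weight per channel: [0.75, 0.4]
```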
@@ -894,71 +980,80 @@ class ImputeNLPCA(BaseNNImputer):
         Returns:
             float: The value of the tuning metric to be minimized or maximized.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            self._prepare_tuning_artifacts()
+            trial_params = self._sample_hyperparameters(trial)
+            model_params = trial_params["model_params"]
+
+            nfeat = self._tune_num_features
+            if self.tune and self.tune_fast:
+                model_params["n_features"] = nfeat
+
+            lr = trial_params["lr"]
+            l1_penalty = trial_params["l1_penalty"]
+            lr_input_fac = trial_params["lr_input_factor"]
+
+            X_train_trial = self._tune_X_train
+            X_test_trial = self._tune_X_test
+            class_weights = self._tune_class_weights
+            train_loader = self._tune_loader
+
+            train_latents = self._create_latent_space(
+                model_params,
+                len(X_train_trial),
+                X_train_trial,
+                trial_params["latent_init"],
+            )
 
-
-
-
+            model = self.build_model(self.Model, model_params)
+            model.n_features = model_params["n_features"]
+            model.apply(self.initialize_weights)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            _, model, __ = self._train_and_validate_model(
+                model=model,
+                loader=train_loader,
+                lr=lr,
+                l1_penalty=l1_penalty,
+                trial=trial,
+                latent_vectors=train_latents,
+                lr_input_factor=lr_input_fac,
+                class_weights=class_weights,
+                X_val=X_test_trial,
+                params=model_params,
+                prune_metric=self.tune_metric,
+                prune_warmup_epochs=10,
+                eval_interval=self.tune_eval_interval,
+                eval_latent_steps=self.eval_latent_steps,
+                eval_latent_lr=self.eval_latent_lr,
+                eval_latent_weight_decay=self.eval_latent_weight_decay,
+            )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # --- simulate-only eval mask for tuning ---
+            eval_mask = None
+            if (
+                self.simulate_missing
+                and getattr(self, "sim_mask_global_", None) is not None
+            ):
+                if (
+                    hasattr(self, "_tune_test_idx")
+                    and self.sim_mask_global_ is not None
+                ):
+                    eval_mask = self.sim_mask_global_[self._tune_test_idx]
+                elif getattr(self, "sim_mask_test_", None) is not None:
+                    eval_mask = self.sim_mask_test_
+
+            metrics = self._evaluate_model(
+                X_test_trial,
+                model,
+                model_params,
+                objective_mode=True,
+                eval_mask_override=eval_mask,
+            )
 
-
-
+            self._clear_resources(model, train_loader, latent_vectors=train_latents)
+            return metrics[self.tune_metric]
+        except Exception as e:
+            raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
 
     def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
         """Samples hyperparameters for the simplified NLPCA model.
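
The rewritten objective wraps the whole trial in `try/except` and converts any failure into a pruned trial, so one bad hyperparameter draw no longer aborts the study. The pattern in isolation (the inner score is a stand-in, not pg-sui's real training loop):

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    try:
        lr = trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True)
        return 1.0 / lr  # placeholder for train-then-evaluate
    except Exception as e:
        # Mark the trial pruned instead of crashing the study.
        raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=5)
```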
@@ -972,23 +1067,23 @@ class ImputeNLPCA(BaseNNImputer):
             Dict[str, int | float | str | list]: A dictionary of sampled hyperparameters.
         """
         params = {
-            "latent_dim": trial.suggest_int("latent_dim",
-            "lr": trial.suggest_float("learning_rate",
-            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.
-            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1,
+            "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
+            "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
+            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30),
+            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 4),
             "activation": trial.suggest_categorical(
-                "activation", ["relu", "elu", "selu"
+                "activation", ["relu", "elu", "selu"]
             ),
-            "gamma": trial.suggest_float("gamma", 0.
+            "gamma": trial.suggest_float("gamma", 0.5, 3.0, step=0.5),
             "lr_input_factor": trial.suggest_float(
-                "lr_input_factor", 0.
+                "lr_input_factor", 0.3, 3.0, log=True
             ),
             "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
             "layer_scaling_factor": trial.suggest_float(
-                "layer_scaling_factor", 2.0,
+                "layer_scaling_factor", 2.0, 4.0, step=0.5
             ),
             "layer_schedule": trial.suggest_categorical(
-                "layer_schedule", ["pyramid", "
+                "layer_schedule", ["pyramid", "linear"]
             ),
             "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
         }
@@ -1006,7 +1101,7 @@ class ImputeNLPCA(BaseNNImputer):
 
         hidden_layer_sizes = self._compute_hidden_layer_sizes(
             n_inputs=params["latent_dim"],
-            n_outputs=use_n_features * self.
+            n_outputs=use_n_features * self.output_classes_,
             n_samples=use_n_samples,
             n_hidden=params["num_hidden_layers"],
             alpha=params["layer_scaling_factor"],
@@ -1015,7 +1110,7 @@ class ImputeNLPCA(BaseNNImputer):
 
         params["model_params"] = {
             "n_features": use_n_features,
-            "num_classes": self.
+            "num_classes": self.output_classes_,
             "latent_dim": params["latent_dim"],
             "dropout_rate": params["dropout_rate"],
             "hidden_layer_sizes": hidden_layer_sizes,
@@ -1043,10 +1138,11 @@ class ImputeNLPCA(BaseNNImputer):
         self.lr_input_factor = best_params["lr_input_factor"]
         self.l1_penalty = best_params["l1_penalty"]
         self.activation = best_params["activation"]
+        self.latent_init = best_params["latent_init"]
 
         hidden_layer_sizes = self._compute_hidden_layer_sizes(
             n_inputs=self.latent_dim,
-            n_outputs=self.num_features_ * self.
+            n_outputs=self.num_features_ * self.output_classes_,
             n_samples=len(self.train_idx_),
             n_hidden=best_params["num_hidden_layers"],
             alpha=best_params["layer_scaling_factor"],
@@ -1060,7 +1156,7 @@ class ImputeNLPCA(BaseNNImputer):
             "dropout_rate": self.dropout_rate,
             "activation": self.activation,
             "gamma": self.gamma,
-            "num_classes": self.
+            "num_classes": self.output_classes_,
         }
 
     def _set_best_params_default(self) -> Dict[str, int | float | str | list]:
@@ -1073,7 +1169,7 @@ class ImputeNLPCA(BaseNNImputer):
         """
         hidden_layer_sizes = self._compute_hidden_layer_sizes(
             n_inputs=self.latent_dim,
-            n_outputs=self.num_features_ * self.
+            n_outputs=self.num_features_ * self.output_classes_,
             n_samples=len(self.ground_truth_),
             n_hidden=self.num_hidden_layers,
             alpha=self.layer_scaling_factor,
@@ -1087,7 +1183,7 @@ class ImputeNLPCA(BaseNNImputer):
             "dropout_rate": self.dropout_rate,
             "activation": self.activation,
             "gamma": self.gamma,
-            "num_classes": self.
+            "num_classes": self.output_classes_,
         }
 
     def _train_and_validate_model(
@@ -1105,7 +1201,7 @@ class ImputeNLPCA(BaseNNImputer):
         X_val: np.ndarray | None = None,
         params: dict | None = None,
         prune_metric: str | None = None,  # "f1" | "accuracy" | "pr_macro"
-        prune_warmup_epochs: int =
+        prune_warmup_epochs: int = 10,
         eval_interval: int = 1,
         eval_latent_steps: int = 50,
         eval_latent_lr: float = 1e-2,
@@ -1222,7 +1318,7 @@ class ImputeNLPCA(BaseNNImputer):
             X_val=self.X_test_,
             params=best_params,
             prune_metric=self.tune_metric,
-            prune_warmup_epochs=
+            prune_warmup_epochs=10,
             eval_interval=1,
             eval_latent_steps=self.eval_latent_steps,
             eval_latent_lr=self.eval_latent_lr,
@@ -1255,7 +1351,7 @@ class ImputeNLPCA(BaseNNImputer):
         X_val: np.ndarray | None = None,
         params: dict | None = None,
         prune_metric: str | None = None,
-        prune_warmup_epochs: int =
+        prune_warmup_epochs: int = 10,
         eval_interval: int = 1,
         eval_latent_steps: int = 50,
         eval_latent_lr: float = 1e-2,
@@ -1435,17 +1531,25 @@ class ImputeNLPCA(BaseNNImputer):
             self.logger.error(msg)
             raise TypeError(msg)
 
-            logits = decoder(z).view(len(X_new), nF, self.
+            logits = decoder(z).view(len(X_new), nF, self.output_classes_)
 
             if not torch.isfinite(logits).all():
                 break
 
-
-
-
-
-
-
+            if self.is_haploid:
+                loss = F.cross_entropy(
+                    logits.view(-1, self.output_classes_),
+                    y.view(-1),
+                    ignore_index=-1,
+                    reduction="mean",
+                )
+            else:
+                targets = self._multi_hot_targets(y)
+                bce = F.binary_cross_entropy_with_logits(
+                    logits, targets, pos_weight=self.pos_weights_, reduction="none"
+                )
+                mask = (y != -1).unsqueeze(-1).float()
+                loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
             if not torch.isfinite(loss):
                 break
 
@@ -1499,7 +1603,7 @@ class ImputeNLPCA(BaseNNImputer):
         y = torch.from_numpy(X_val).long().to(self.device)
 
         latent_dim = self._first_linear_in_features(model)
-        cache_key = f"{self.prefix}_nlpca_val_latents_z{latent_dim}_L{self.num_features_}_K{self.
+        cache_key = f"{self.prefix}_nlpca_val_latents_z{latent_dim}_L{self.num_features_}_K{self.output_classes_}"
 
         if cache is not None and cache_key in cache:
             z = cache[cache_key].detach().clone().requires_grad_(True)
@@ -1523,17 +1627,25 @@ class ImputeNLPCA(BaseNNImputer):
             self.logger.error(msg)
             raise TypeError(msg)
 
-            logits = decoder(z).view(X_val.shape[0], nF, self.
+            logits = decoder(z).view(X_val.shape[0], nF, self.output_classes_)
 
             if not torch.isfinite(logits).all():
                 break
 
-
-
-
-
-
-
+            if self.is_haploid:
+                loss = F.cross_entropy(
+                    logits.view(-1, self.output_classes_),
+                    y.view(-1),
+                    ignore_index=-1,
+                    reduction="mean",
+                )
+            else:
+                targets = self._multi_hot_targets(y)
+                bce = F.binary_cross_entropy_with_logits(
+                    logits, targets, pos_weight=self.pos_weights_, reduction="none"
+                )
+                mask = (y != -1).unsqueeze(-1).float()
+                loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
 
             if not torch.isfinite(loss):
                 break