pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import numpy as np
|
|
1
2
|
from snpio.utils.logging import LoggerManager
|
|
2
3
|
|
|
3
4
|
from pgsui.utils.logging_utils import configure_logger
|
|
@@ -23,7 +24,7 @@ class EarlyStopping:
|
|
|
23
24
|
delta: float = 0.0,
|
|
24
25
|
verbose: int = 0,
|
|
25
26
|
mode: str = "min",
|
|
26
|
-
min_epochs: int =
|
|
27
|
+
min_epochs: int = 150,
|
|
27
28
|
prefix: str = "pgsui_output",
|
|
28
29
|
debug: bool = False,
|
|
29
30
|
):
|
|
@@ -53,7 +54,7 @@ class EarlyStopping:
|
|
|
53
54
|
self.epoch_count = 0
|
|
54
55
|
self.best_score = float("inf") if mode == "min" else 0.0
|
|
55
56
|
self.early_stop = False
|
|
56
|
-
self.
|
|
57
|
+
self.best_state_dict: dict | None = None
|
|
57
58
|
self.min_epochs = min_epochs
|
|
58
59
|
|
|
59
60
|
is_verbose = verbose >= 2 or debug
|
|
@@ -72,45 +73,39 @@ class EarlyStopping:
|
|
|
72
73
|
self.logger.error(msg)
|
|
73
74
|
raise ValueError(msg)
|
|
74
75
|
|
|
75
|
-
def __call__(self, score, model):
|
|
76
|
-
"""
|
|
76
|
+
def __call__(self, score, model, *, epoch: int | None = None):
|
|
77
|
+
"""Update early stopping state.
|
|
77
78
|
|
|
78
79
|
Args:
|
|
79
|
-
score
|
|
80
|
-
model
|
|
80
|
+
score: Monitored metric value.
|
|
81
|
+
model: Model to checkpoint.
|
|
82
|
+
epoch: If provided, sets the internal epoch counter to the true epoch number.
|
|
81
83
|
"""
|
|
82
|
-
|
|
83
|
-
|
|
84
|
+
if epoch is not None:
|
|
85
|
+
self.epoch_count = int(epoch)
|
|
86
|
+
else:
|
|
87
|
+
self.epoch_count += 1
|
|
88
|
+
|
|
89
|
+
# Treat non-finite scores as non-improvements
|
|
90
|
+
try:
|
|
91
|
+
score_f = float(score)
|
|
92
|
+
except Exception:
|
|
93
|
+
score_f = float("inf") if self.mode == "min" else float("-inf")
|
|
84
94
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
self.
|
|
95
|
+
if not np.isfinite(score_f):
|
|
96
|
+
self.counter += 1
|
|
97
|
+
if self.counter >= self.patience and self.epoch_count >= self.min_epochs:
|
|
98
|
+
self.early_stop = True
|
|
88
99
|
return
|
|
89
100
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
#
|
|
93
|
-
self.
|
|
94
|
-
|
|
101
|
+
if self.monitor(score_f, self.best_score):
|
|
102
|
+
self.best_score = score_f
|
|
103
|
+
# THIS is the real checkpoint:
|
|
104
|
+
self.best_state_dict = {
|
|
105
|
+
k: v.detach().cpu().clone() for k, v in model.state_dict().items()
|
|
106
|
+
}
|
|
95
107
|
self.counter = 0
|
|
96
108
|
else:
|
|
97
|
-
# No improvement: increase counter
|
|
98
109
|
self.counter += 1
|
|
99
|
-
|
|
100
|
-
if self.verbose:
|
|
101
|
-
self.logger.info(
|
|
102
|
-
f"EarlyStopping counter: {self.counter}/{self.patience}"
|
|
103
|
-
)
|
|
104
|
-
|
|
105
|
-
# Now check if we surpass patience AND have reached min_epochs
|
|
106
110
|
if self.counter >= self.patience and self.epoch_count >= self.min_epochs:
|
|
107
|
-
|
|
108
|
-
if self.best_model is None:
|
|
109
|
-
self.best_model = model
|
|
110
|
-
|
|
111
111
|
self.early_stop = True
|
|
112
|
-
|
|
113
|
-
if self.verbose:
|
|
114
|
-
self.logger.info(
|
|
115
|
-
f"Early stopping triggered at epoch {self.epoch_count}"
|
|
116
|
-
)
|