pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
  2. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +577 -125
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +203 -530
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1269 -534
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
  16. pgsui/impute/unsupervised/imputers/vae.py +931 -787
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
  27. pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
  28. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  29. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  30. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  31. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  32. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  33. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ import numpy as np
1
2
  from snpio.utils.logging import LoggerManager
2
3
 
3
4
  from pgsui.utils.logging_utils import configure_logger
@@ -23,7 +24,7 @@ class EarlyStopping:
23
24
  delta: float = 0.0,
24
25
  verbose: int = 0,
25
26
  mode: str = "min",
26
- min_epochs: int = 100,
27
+ min_epochs: int = 150,
27
28
  prefix: str = "pgsui_output",
28
29
  debug: bool = False,
29
30
  ):
@@ -53,7 +54,7 @@ class EarlyStopping:
53
54
  self.epoch_count = 0
54
55
  self.best_score = float("inf") if mode == "min" else 0.0
55
56
  self.early_stop = False
56
- self.best_model = None
57
+ self.best_state_dict: dict | None = None
57
58
  self.min_epochs = min_epochs
58
59
 
59
60
  is_verbose = verbose >= 2 or debug
@@ -72,45 +73,39 @@ class EarlyStopping:
72
73
  self.logger.error(msg)
73
74
  raise ValueError(msg)
74
75
 
75
- def __call__(self, score, model):
76
- """Checks if early stopping condition is met and checkpoints model accordingly.
76
+ def __call__(self, score, model, *, epoch: int | None = None):
77
+ """Update early stopping state.
77
78
 
78
79
  Args:
79
- score (float): The current metric value (e.g., validation loss/accuracy).
80
- model (torch.nn.Module): The model being trained.
80
+ score: Monitored metric value.
81
+ model: Model to checkpoint.
82
+ epoch: If provided, sets the internal epoch counter to the true epoch number.
81
83
  """
82
- # Increment the epoch count each time we call this function
83
- self.epoch_count += 1
84
+ if epoch is not None:
85
+ self.epoch_count = int(epoch)
86
+ else:
87
+ self.epoch_count += 1
88
+
89
+ # Treat non-finite scores as non-improvements
90
+ try:
91
+ score_f = float(score)
92
+ except Exception:
93
+ score_f = float("inf") if self.mode == "min" else float("-inf")
84
94
 
85
- # If this is the first epoch, initialize best_score and save model
86
- if self.best_score is None:
87
- self.best_score = score
95
+ if not np.isfinite(score_f):
96
+ self.counter += 1
97
+ if self.counter >= self.patience and self.epoch_count >= self.min_epochs:
98
+ self.early_stop = True
88
99
  return
89
100
 
90
- # Check if there is improvement
91
- if self.monitor(score, self.best_score):
92
- # If improved, reset counter and update the best score/model
93
- self.best_score = score
94
- self.best_model = model
101
+ if self.monitor(score_f, self.best_score):
102
+ self.best_score = score_f
103
+ # THIS is the real checkpoint:
104
+ self.best_state_dict = {
105
+ k: v.detach().cpu().clone() for k, v in model.state_dict().items()
106
+ }
95
107
  self.counter = 0
96
108
  else:
97
- # No improvement: increase counter
98
109
  self.counter += 1
99
-
100
- if self.verbose:
101
- self.logger.info(
102
- f"EarlyStopping counter: {self.counter}/{self.patience}"
103
- )
104
-
105
- # Now check if we surpass patience AND have reached min_epochs
106
110
  if self.counter >= self.patience and self.epoch_count >= self.min_epochs:
107
-
108
- if self.best_model is None:
109
- self.best_model = model
110
-
111
111
  self.early_stop = True
112
-
113
- if self.verbose:
114
- self.logger.info(
115
- f"Early stopping triggered at epoch {self.epoch_count}"
116
- )