pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.0.dist-info/RECORD +0 -75
  83. pg_sui-0.2.0.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
pgsui/impute/unsupervised/loss_functions.py
@@ -0,0 +1,261 @@
+ from typing import List, Literal
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class SafeFocalCELoss(nn.Module):
+     """Focal cross-entropy with ignore_index and numeric guards.
+
+     This class implements the focal loss, which addresses class imbalance by down-weighting easy examples and focusing training on hard ones. It also handles ignored indices and guards against numeric instability.
+     """
+
+     def __init__(
+         self,
+         gamma: float,
+         weight: torch.Tensor | None = None,
+         ignore_index: int = -1,
+         eps: float = 1e-8,
+     ):
+         """Initialize the SafeFocalCELoss.
+
+         Sets up the focal loss with the specified focusing parameter, class weights, ignore index, and a small epsilon for numerical stability.
+
+         Args:
+             gamma (float): Focusing parameter.
+             weight (torch.Tensor | None): A manual rescaling weight given to each class. If given, it has to be a Tensor of size C (number of classes). Defaults to None.
+             ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. Defaults to -1.
+             eps (float): Small value to avoid numerical issues. Defaults to 1e-8.
+         """
+         super().__init__()
+         self.gamma = gamma
+         self.weight = weight
+         self.ignore_index = ignore_index
+         self.eps = eps
+
+     def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+         """Calculate the focal loss on pre-flattened tensors.
+
+         Args:
+             logits (torch.Tensor): Logits from the model of shape (N, C), where N is the number of samples and C is the number of classes.
+             targets (torch.Tensor): Ground truth labels of shape (N,).
+
+         Returns:
+             torch.Tensor: The computed scalar loss value.
+         """
+         # logits: (N, C), targets: (N,)
+         valid = targets != self.ignore_index
+
+         if not valid.any():
+             return logits.new_tensor(0.0)
+
+         logits_v = logits[valid]
+         targets_v = targets[valid]
+
+         logp = F.log_softmax(logits_v, dim=-1)  # numerically stable
+         ce = F.nll_loss(logp, targets_v, weight=self.weight, reduction="none")
+
+         # p_t = exp(logp[range, targets])
+         p_t = torch.exp(logp.gather(1, targets_v.unsqueeze(1)).squeeze(1))
+
+         # Focal factor, clamped to avoid 0**gamma and NaNs
+         focal = (1.0 - p_t).clamp_min(self.eps).pow(self.gamma)
+
+         loss_vec = focal * ce
+
+         # Guard any remaining inf/nan values that slipped through
+         loss_vec = torch.nan_to_num(loss_vec, nan=0.0, posinf=1e6, neginf=0.0)
+         return loss_vec.mean()
+
+
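For orientation, a minimal usage sketch of SafeFocalCELoss as defined above; the shapes and values are illustrative, not taken from the package:

    import torch

    criterion = SafeFocalCELoss(gamma=2.0, ignore_index=-1)
    logits = torch.randn(6, 4)                    # (N=6 samples, C=4 classes)
    targets = torch.tensor([0, 3, -1, 2, -1, 1])  # -1 entries are ignored
    loss = criterion(logits, targets)             # scalar mean over valid rows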
+ class WeightedMaskedCCELoss(nn.Module):
+     def __init__(
+         self,
+         alpha: float | List[float] | torch.Tensor | None = None,
+         reduction: Literal["mean", "sum"] = "mean",
+     ):
+         """A weighted, masked Categorical Cross-Entropy loss function.
+
+         This loss computes the categorical cross-entropy while allowing for class weights and masking of invalid (missing) entries. It is particularly useful for sequence data where some positions are missing or should not contribute to the loss.
+
+         Args:
+             alpha (float | List[float] | torch.Tensor | None): A manual rescaling weight given to each class. If given, it has to be a Tensor of size C (number of classes). Defaults to None.
+             reduction (Literal["mean", "sum"]): Specifies the reduction to apply to the output: 'mean' or 'sum'. Defaults to "mean".
+         """
+         super().__init__()
+         self.reduction = reduction
+         self.alpha = alpha
+
+     def forward(
+         self,
+         logits: torch.Tensor,
+         targets: torch.Tensor,
+         valid_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Compute the masked categorical cross-entropy loss.
+
+         Args:
+             logits (torch.Tensor): Logits from the model of shape (batch_size, seq_len, num_classes).
+             targets (torch.Tensor): Ground truth labels of shape (batch_size, seq_len).
+             valid_mask (torch.Tensor, optional): Boolean mask of shape (batch_size, seq_len) where True indicates a valid (observed) value to include in the loss. Defaults to None, in which case all values are considered valid.
+
+         Returns:
+             torch.Tensor: The computed scalar loss value.
+         """
+         # Automatically detect the device from the input tensor
+         device = logits.device
+         num_classes = logits.shape[-1]
+
+         # Ensure targets are on the correct device and are Long type
+         targets = targets.to(device).long()
+
+         # Prepare class weights and pass them directly to the loss function
+         class_weights = None
+         if self.alpha is not None:
+             if not isinstance(self.alpha, torch.Tensor):
+                 class_weights = torch.tensor(
+                     self.alpha, dtype=torch.float, device=device
+                 )
+             else:
+                 class_weights = self.alpha.to(device)
+
+         loss = F.cross_entropy(
+             logits.reshape(-1, num_classes),
+             targets.reshape(-1),
+             weight=class_weights,
+             reduction="none",
+             ignore_index=-1,  # Ignore all targets with the value -1
+         )
+
+         # If a mask is provided, keep only the masked-in (observed) losses
+         if valid_mask is not None:
+             loss = loss[valid_mask.reshape(-1)]
+
+         # If no valid losses remain after masking, return 0
+         if loss.numel() == 0:
+             return torch.tensor(0.0, device=device)
+
+         # Apply the final reduction
+         if self.reduction == "mean":
+             return loss.mean()
+         elif self.reduction == "sum":
+             return loss.sum()
+         else:
+             msg = f"Reduction mode '{self.reduction}' not supported."
+             raise ValueError(msg)
+
+
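A minimal usage sketch of WeightedMaskedCCELoss above, assuming 3D logits and a boolean mask; shapes and values are illustrative only:

    import torch

    criterion = WeightedMaskedCCELoss(alpha=[1.0, 2.0, 2.0, 1.0])  # one weight per class
    logits = torch.randn(2, 5, 4)          # (batch_size=2, seq_len=5, num_classes=4)
    targets = torch.randint(0, 4, (2, 5))  # ground-truth class per position
    valid_mask = torch.rand(2, 5) > 0.2    # True where the value was observed
    loss = criterion(logits, targets, valid_mask=valid_mask)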
+ class MaskedFocalLoss(nn.Module):
+     """Focal loss (gamma > 0) with optional class weights and a boolean valid mask.
+
+     This class implements the focal loss, which addresses class imbalance by down-weighting easy examples and focusing training on hard ones. It also supports masking of invalid (missing) entries, making it suitable for sequence data with missing values.
+     """
+
+     def __init__(
+         self,
+         gamma: float = 2.0,
+         alpha: torch.Tensor | None = None,
+         reduction: Literal["mean", "sum"] = "mean",
+     ):
+         """Initialize the MaskedFocalLoss.
+
+         Sets up the focal loss with the specified focusing parameter, class weights, and reduction mode. Missing data are handled through a valid mask, so only observed entries contribute to the loss.
+
+         Args:
+             gamma (float): Focusing parameter.
+             alpha (torch.Tensor | None): Class weights.
+             reduction (Literal["mean", "sum"]): Reduction mode ('mean' or 'sum').
+         """
+         super().__init__()
+         self.gamma = gamma
+         self.alpha = alpha
+         self.reduction = reduction
+
+     def forward(
+         self,
+         logits: torch.Tensor,  # Expects (N, C) where N = batch * features
+         targets: torch.Tensor,  # Expects (N,)
+         valid_mask: torch.Tensor,  # Expects (N,)
+     ) -> torch.Tensor:
+         """Calculate the focal loss on pre-flattened tensors.
+
+         Args:
+             logits (torch.Tensor): Logits from the model of shape (N, C), where N is the number of samples (batch_size * seq_len) and C is the number of classes.
+             targets (torch.Tensor): Ground truth labels of shape (N,).
+             valid_mask (torch.Tensor): Boolean mask of shape (N,) where True indicates a valid (observed) value to include in the loss.
+
+         Returns:
+             torch.Tensor: The computed scalar loss value.
+         """
+         device = logits.device
+
+         # Standard cross-entropy loss per element (no reduction)
+         ce = F.cross_entropy(
+             logits,
+             targets,
+             weight=(self.alpha.to(device) if self.alpha is not None else None),
+             reduction="none",
+             ignore_index=-1,
+         )
+
+         # Recover p_t from the cross-entropy loss. Note that when class
+         # weights are given, ce is scaled by the weight of the target class,
+         # so exp(-ce) is only an approximation of p_t in that case.
+         pt = torch.exp(-ce)
+         focal = ((1 - pt) ** self.gamma) * ce
+
+         # Apply the valid mask: keep only elements that should contribute.
+         focal = focal[valid_mask]
+
+         # Return early if no valid elements exist, to avoid NaN results
+         if focal.numel() == 0:
+             return torch.tensor(0.0, device=device)
+
+         # Apply reduction
+         if self.reduction == "mean":
+             return focal.mean()
+         elif self.reduction == "sum":
+             return focal.sum()
+         else:
+             msg = f"Reduction mode '{self.reduction}' not supported."
+             raise ValueError(msg)
+
+
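A minimal usage sketch of MaskedFocalLoss above; unlike WeightedMaskedCCELoss, it expects pre-flattened (N, C) inputs. Shapes and values are illustrative only:

    import torch

    criterion = MaskedFocalLoss(gamma=2.0)
    logits = torch.randn(10, 4)           # pre-flattened to (N, C)
    targets = torch.randint(0, 4, (10,))  # (N,)
    valid_mask = torch.ones(10, dtype=torch.bool)
    valid_mask[::3] = False               # drop some positions from the loss
    loss = criterion(logits, targets, valid_mask)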
+ def safe_kl_gauss_unit(
+     mu: torch.Tensor, logvar: torch.Tensor, reduction: str = "mean"
+ ) -> torch.Tensor:
+     """KL divergence between N(mu, exp(logvar)) and N(0, I) with guards."""
+     logvar = logvar.clamp(min=-30.0, max=20.0)
+     kl = -0.5 * (1.0 + logvar - mu.pow(2) - logvar.exp())
+     if reduction == "sum":
+         kl = kl.sum()
+     elif reduction == "mean":
+         kl = kl.mean()
+     return torch.nan_to_num(kl, nan=0.0, posinf=1e6, neginf=0.0)
+
+
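For reference, the closed form computed above is KL(N(mu, sigma^2) || N(0, 1)) = -1/2 * (1 + log(sigma^2) - mu^2 - sigma^2) per element, with logvar = log(sigma^2). A quick sanity check (illustrative values):

    import torch

    # At mu = 0, logvar = 0 the posterior equals the unit prior, so the KL is 0.
    assert safe_kl_gauss_unit(torch.zeros(3, 2), torch.zeros(3, 2)).item() == 0.0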
+ def compute_vae_loss(
+     recon_logits: torch.Tensor,
+     targets: torch.Tensor,
+     *,
+     mu: torch.Tensor,
+     logvar: torch.Tensor,
+     class_weights: torch.Tensor | None,
+     gamma: float,
+     beta: float,
+     ignore_index: int = -1,
+ ) -> torch.Tensor:
+     """Focal reconstruction + beta * KL with normalized class weights."""
+     # Normalize class weights to mean 1 so their overall scale does not
+     # interact with the beta coefficient on the KL term.
+     cw = None
+     if class_weights is not None:
+         cw = class_weights / class_weights.mean().clamp_min(1e-8)
+
+     criterion = SafeFocalCELoss(
+         gamma=gamma,
+         weight=cw,
+         ignore_index=ignore_index,
+     )
+     rec = criterion(recon_logits.view(-1, recon_logits.size(-1)), targets.view(-1))
+     kl = safe_kl_gauss_unit(mu, logvar, reduction="mean")
+     return rec + beta * kl
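Finally, a toy end-to-end sketch of compute_vae_loss; the latent size and all tensor values are illustrative, not taken from the package:

    import torch

    recon_logits = torch.randn(2, 5, 4, requires_grad=True)  # (batch, seq_len, num_classes)
    targets = torch.randint(0, 4, (2, 5))                    # -1 would mark missing entries
    mu = torch.randn(2, 8, requires_grad=True)               # latent mean (latent_dim=8 assumed)
    logvar = torch.randn(2, 8, requires_grad=True)           # latent log-variance
    loss = compute_vae_loss(
        recon_logits,
        targets,
        mu=mu,
        logvar=logvar,
        class_weights=None,
        gamma=2.0,
        beta=1.0,
    )
    loss.backward()  # gradients flow to recon_logits, mu, and logvar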