pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pg-sui might be problematic. Click here for more details.
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
|
@@ -0,0 +1,691 @@
|
|
|
1
|
+
# Standard library imports
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union
|
|
3
|
+
|
|
4
|
+
# Third-party imports
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from sklearn.exceptions import NotFittedError
|
|
8
|
+
from sklearn.metrics import (
|
|
9
|
+
accuracy_score,
|
|
10
|
+
classification_report,
|
|
11
|
+
f1_score,
|
|
12
|
+
precision_score,
|
|
13
|
+
recall_score,
|
|
14
|
+
)
|
|
15
|
+
from snpio import GenotypeEncoder
|
|
16
|
+
|
|
17
|
+
# Local imports
|
|
18
|
+
from pgsui.utils.plotting import Plotting
|
|
19
|
+
|
|
20
|
+
# Type checking imports
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from snpio.read_input.genotype_data import GenotypeData
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ImputeAlleleFreq:
    """Frequency-based imputer for integer-encoded categorical genotype data with test-only evaluation.

    Missing values are imputed by sampling from the empirical frequency
    distribution of observed integer codes (e.g., 0-9 for nucleotides). A strict
    train/test protocol is followed: distributions are learned on the train
    split, missingness is simulated only on the test split, and all metrics are
    computed exclusively on the simulated test cells.

    The algorithm is as follows:
        1. Split the dataset into train and test sets (row-wise).
        2. For each feature (column), compute the empirical frequency
           distribution of observed values in the train set.
        3. On the test set, simulate additional missing values by randomly
           masking a specified proportion of observed entries.
        4. Impute missing values by sampling from the train-learned distributions.
        5. Evaluate imputation accuracy using various metrics on the test set only.
    """

    # Column index of each base in every (n, 4) per-base boolean table used here.
    _BASE_TO_COL = {"A": 0, "C": 1, "G": 2, "T": 3}

    def __init__(
        self,
        genotype_data: "GenotypeData",
        *,
        prefix: str = "pgsui",
        by_populations: bool = False,
        default: int = 0,
        missing: int = -9,
        verbose: bool = True,
        seed: Optional[int] = None,
        sim_prop_missing: float = 0.30,
        debug: bool = False,
        test_size: float = 0.2,
        test_indices: Optional[Sequence[int]] = None,
        stratify_by_populations: bool = False,
    ) -> None:
        """Initialize ImputeAlleleFreq.

        Args:
            genotype_data: Object with `.genotypes_int`, `.ref`, `.alt`, and optional `.populations`.
            prefix: Output prefix.
            by_populations: Learn separate dists per population.
            default: Default genotype for cold-start loci.
            missing: Integer code for missing values in X_.
            verbose: Verbosity switch.
            seed: RNG seed.
            sim_prop_missing: Fraction of OBSERVED test cells to mask for evaluation.
            debug: Debug switch.
            test_size: Fraction of rows held out for test if `test_indices` not provided.
            test_indices: Explicit test row indices. Overrides `test_size` if given.
            stratify_by_populations: If True and populations are available, create a
                stratified test split per population. Populations are only loaded
                when `by_populations` is True, so this flag has no effect otherwise.

        Raises:
            ValueError: If `sim_prop_missing` is outside [0, 0.95] or the
                population vector length differs from the number of samples.
            TypeError: If `by_populations` is True but no populations are available.
        """
        self.genotype_data = genotype_data
        self.prefix = prefix
        self.by_populations = by_populations
        self.default = int(default)
        self.missing = int(missing)
        self.verbose = verbose
        self.sim_prop_missing = float(sim_prop_missing)
        self.debug = debug

        if not (0.0 <= self.sim_prop_missing <= 0.95):
            raise ValueError("sim_prop_missing must be in [0, 0.95].")

        self.rng = np.random.default_rng(seed)
        self.encoder = GenotypeEncoder(self.genotype_data)
        self.X_ = np.asarray(self.encoder.genotypes_int, dtype=np.int8)
        self.num_features_ = self.X_.shape[1]

        # --- split controls ---
        self.test_size = float(test_size)
        self.test_indices = (
            None if test_indices is None else np.asarray(test_indices, dtype=int)
        )
        self.stratify_by_populations = bool(stratify_by_populations)

        self.pops = None
        if self.by_populations:
            pops = getattr(self.genotype_data, "populations", None)
            if pops is None:
                raise TypeError(
                    "by_populations=True requires genotype_data.populations."
                )
            self.pops = np.asarray(pops)
            if len(self.pops) != self.X_.shape[0]:
                raise ValueError(
                    f"`populations` length ({len(self.pops)}) != number of samples ({self.X_.shape[0]})."
                )

        # --- fitted state (populated by fit() / transform()) ---
        self.is_fit_: bool = False
        self.global_dist_: Dict[int, Tuple[np.ndarray, np.ndarray]] = {}
        self.group_dist_: Dict[
            Union[str, int], Dict[int, Tuple[np.ndarray, np.ndarray]]
        ] = {}
        self.sim_mask_: Optional[np.ndarray] = None
        self.train_idx_: Optional[np.ndarray] = None
        self.test_idx_: Optional[np.ndarray] = None
        self.X_train_: Optional[pd.DataFrame] = None
        self.ground_truths_: Optional[np.ndarray] = None
        self.metrics_: Dict[str, float] = {}
        self.X_imputed_: Optional[np.ndarray] = None

        # VCF ref/alt cache + IUPAC LUT
        self.ref_codes_: Optional[np.ndarray] = None  # (nF,) base column, -1 if unknown
        self.alt_mask_: Optional[np.ndarray] = None  # (nF,4) bool
        self._iupac_presence_lut_: Optional[np.ndarray] = None  # (10,4) bool

        self.plotter = Plotting(
            "ImputeAlleleFreq",
            prefix=self.prefix,
            plot_format=genotype_data.plot_format,
            plot_fontsize=genotype_data.plot_fontsize,
            plot_dpi=genotype_data.plot_dpi,
            title_fontsize=genotype_data.plot_fontsize,
            despine=genotype_data.plot_despine,
            show_plots=genotype_data.show_plots,
            verbose=self.verbose,
            debug=self.debug,
        )

    # ------------------------------------------
    # Helpers for VCF ref/alt and IUPAC mapping
    # ------------------------------------------
    def _map_base_to_int(self, arr) -> Optional[np.ndarray]:
        """Map bases to A/C/G/T -> 0/1/2/3; pass integers through; others -> -1.

        Only the first character of each string entry is considered.

        Args:
            arr: Array-like of bases (str) or integer codes.

        Returns:
            Optional[np.ndarray]: int32 array of codes, or None if the input is
            None or has an unsupported dtype (e.g. float).
        """
        if arr is None:
            return None

        arr = np.asarray(arr)
        kind = arr.dtype.kind
        if kind in ("i", "u"):
            return arr.astype(np.int32, copy=False)
        if kind not in ("U", "S", "O"):
            return None

        first_char = np.char.upper(arr.astype("U1"))
        codes = np.full(first_char.shape, -1, dtype=np.int32)
        for base, col in self._BASE_TO_COL.items():
            codes[first_char == base] = col
        return codes

    def _ref_codes_from_genotype_data(self) -> np.ndarray:
        """Fetch the per-locus reference base from genotype_data.ref.

        Returns:
            np.ndarray: Array of shape (n_features,). Values are 0..3 for
            A/C/G/T; string bases outside A/C/G/T map to -1, and integer input
            is passed through unchanged.

        Raises:
            ValueError: If `ref` is missing, of unsupported dtype, or its first
                dimension does not equal the number of features.
        """
        ref_codes = self._map_base_to_int(getattr(self.genotype_data, "ref", None))
        # Comparing shape[:1] also rejects 0-d (scalar) input without an IndexError.
        if ref_codes is None or ref_codes.shape[:1] != (self.num_features_,):
            msg = (
                "genotype_data.ref missing or wrong length; "
                f"expected ({self.num_features_},) got "
                f"{None if ref_codes is None else ref_codes.shape}"
            )
            raise ValueError(msg)
        return ref_codes

    def _flag_alt_bases(self, mask_row: np.ndarray, value) -> None:
        """Set the entries of `mask_row` (length 4, A/C/G/T) named by `value`.

        Accepts an integer in 0..3, a base string, a comma-separated string of
        bases, or a (possibly nested) list/tuple/array of those. Anything else
        is ignored.
        """
        if value is None:
            return
        if isinstance(value, (int, np.integer)):
            col = int(value)
            if 0 <= col <= 3:
                mask_row[col] = True
            return
        if isinstance(value, str):
            text = value.strip().upper()
            if "," in text:
                for token in text.split(","):
                    self._flag_alt_bases(mask_row, token.strip())
            elif text in self._BASE_TO_COL:
                mask_row[self._BASE_TO_COL[text]] = True
            return
        if isinstance(value, (list, tuple, np.ndarray)):
            for item in value:
                self._flag_alt_bases(mask_row, item)

    def _alt_mask_from_genotype_data(self) -> np.ndarray:
        """Build a per-locus mask of which bases are ALT (supports multi-alt).

        Returns:
            np.ndarray: Boolean array of shape (n_features, 4) indicating presence of A,C,G,T as ALT.
            All False when genotype_data has no `alt` attribute.
        """
        nF = self.num_features_
        alt_mask = np.zeros((nF, 4), dtype=bool)

        alt_raw = getattr(self.genotype_data, "alt", None)
        if alt_raw is None:
            return alt_mask

        # atleast_1d keeps a scalar `alt` from failing on shape[0].
        alt_arr = np.atleast_1d(np.asarray(alt_raw, dtype=object))
        if alt_arr.shape[0] != nF and self.verbose:
            print(
                f"[warn] genotype_data.alt length {alt_arr.shape[0]} != n_features {nF}; truncating."
            )

        for i in range(min(nF, alt_arr.shape[0])):
            self._flag_alt_bases(alt_mask[i], alt_arr[i])
        return alt_mask

    def _build_iupac_presence_lut(self) -> np.ndarray:
        """Create LUT mapping integer codes {0..9} -> allele presence over A,C,G,T.

        Returns:
            np.ndarray: Boolean array of shape (10,4) indicating presence of A,C,G,T for each IUPAC code.
        """
        # Row = integer genotype code, entries = base columns (A=0, C=1, G=2, T=3).
        code_to_cols = (
            (0,),  # 0: A
            (3,),  # 1: T
            (2,),  # 2: G
            (1,),  # 3: C
            (0, 3),  # 4: W
            (0, 2),  # 5: R
            (0, 1),  # 6: M
            (2, 3),  # 7: K
            (1, 3),  # 8: Y
            (1, 2),  # 9: S
        )
        lut = np.zeros((len(code_to_cols), 4), dtype=bool)
        for code, cols in enumerate(code_to_cols):
            lut[code, list(cols)] = True
        return lut

    # -----------------------
    # Fit / Transform
    # -----------------------
    def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
        """Create row-wise train/test split according to init settings.

        Explicit `test_indices` take precedence. Otherwise rows are drawn at
        random, per population when both `by_populations` and
        `stratify_by_populations` are set.

        Returns:
            Tuple[np.ndarray, np.ndarray]: (train indices, test indices) as sorted integer arrays.

        Raises:
            IndexError: If any explicit test index is out of bounds.
        """
        n = self.X_.shape[0]
        all_idx = np.arange(n, dtype=int)
        empty = np.array([], dtype=int)

        if self.test_indices is not None:
            test_idx = np.unique(self.test_indices)
            if np.any((test_idx < 0) | (test_idx >= n)):
                raise IndexError("Some test_indices are out of bounds.")
            return np.setdiff1d(all_idx, test_idx, assume_unique=False), test_idx

        stratified = (
            self.by_populations
            and self.stratify_by_populations
            and (self.pops is not None)
        )
        if stratified:
            buckets = []
            for pop in np.unique(self.pops):
                members = np.flatnonzero(self.pops == pop)
                k = int(round(self.test_size * members.size))
                if k > 0:
                    buckets.append(self.rng.choice(members, size=k, replace=False))
            test_idx = np.sort(np.concatenate(buckets)) if buckets else empty
        else:
            k = int(round(self.test_size * n))
            # Sorted so that test_idx_ is ordered the same way in every branch.
            test_idx = (
                np.sort(self.rng.choice(n, size=k, replace=False)) if k > 0 else empty
            )

        return np.setdiff1d(all_idx, test_idx, assume_unique=False), test_idx

    def fit(self) -> "ImputeAlleleFreq":
        """Learn per-locus distributions on TRAIN rows; simulate missingness on TEST rows.

        Notes:
            The general workflow is:
                1) Split rows into train/test.
                2) Build distributions from TRAIN only.
                3) Simulate missingness on TEST only (stores `sim_mask_`).
                4) Cache ref/alt and IUPAC LUT.

        Returns:
            ImputeAlleleFreq: The fitted instance.
        """
        # 0) Row split
        self.train_idx_, self.test_idx_ = self._make_train_test_split()

        self.ground_truths_ = self.X_.copy()
        values = self.ground_truths_.astype(np.float32)
        values[self.ground_truths_ == self.missing] = np.nan
        df_all = pd.DataFrame(values)

        # 1) TRAIN-only distributions (no simulated holes here)
        df_train = df_all.iloc[self.train_idx_].copy()
        self.global_dist_ = {
            col: self._series_distribution(df_train[col]) for col in df_train.columns
        }

        self.group_dist_ = {}
        if self.by_populations:
            train_pops = self.pops[self.train_idx_]
            for pop, grp in df_train.groupby(train_pops):
                self.group_dist_[pop] = {
                    col: self._series_distribution(grp[col]) for col in grp.columns
                }

        # 2) TEST-only simulated missingness
        obs_mask = ~np.isnan(values)
        sim_mask = np.zeros_like(obs_mask, dtype=bool)
        if self.test_idx_.size > 0 and self.sim_prop_missing > 0.0:
            test_row_mask = np.zeros(obs_mask.shape[0], dtype=bool)
            test_row_mask[self.test_idx_] = True
            # Candidate cells: observed entries on TEST rows, in row-major order.
            coords_test = np.argwhere(obs_mask & test_row_mask[:, None])
            total_obs_test = coords_test.shape[0]
            n_to_mask = int(round(self.sim_prop_missing * total_obs_test))
            if n_to_mask > 0:
                picked = coords_test[
                    self.rng.choice(total_obs_test, size=n_to_mask, replace=False)
                ]
                sim_mask[picked[:, 0], picked[:, 1]] = True

        # Apply the mask on a NumPy copy: writing through DataFrame.values is not
        # guaranteed to reach the frame (and is read-only under Copy-on-Write).
        sim_values = values.copy()
        sim_values[sim_mask] = np.nan

        # Store matrix to be imputed (train rows intact; test rows with simulated NaNs)
        self.sim_mask_ = sim_mask
        self.X_train_ = pd.DataFrame(sim_values)

        # 3) Cache VCF ref/alt + IUPAC LUT once
        self.ref_codes_ = self._ref_codes_from_genotype_data()  # (nF,)
        self.alt_mask_ = self._alt_mask_from_genotype_data()  # (nF,4)
        self._iupac_presence_lut_ = self._build_iupac_presence_lut()

        self.is_fit_ = True
        if self.verbose:
            n_masked = int(sim_mask.sum())
            print(
                f"Fit complete. Train rows: {self.train_idx_.size}, Test rows: {self.test_idx_.size}."
            )
            print(
                f"Simulated {n_masked} missing values (TEST rows only) for evaluation."
            )
        return self

    def transform(self) -> np.ndarray:
        """Impute the matrix and evaluate on the TEST-set simulated cells.

        Returns:
            np.ndarray: Imputed genotypes in the original (IUPAC/int) encoding.

        Raises:
            NotFittedError: If `fit()` has not been called.
        """
        if (
            not self.is_fit_
            or self.X_train_ is None
            or self.sim_mask_ is None
            or self.ground_truths_ is None
        ):
            raise NotFittedError(
                "Model is not fitted. Call `fit()` before `transform()`."
            )

        # ---- Impute using train-learned distributions ----
        X_imp = self._impute_df(self.X_train_).to_numpy(dtype=np.int8)
        self.X_imputed_ = X_imp

        self._evaluate_codes(X_imp)
        self._evaluate_zygosity(X_imp)

        return self.encoder.inverse_int_iupac(X_imp)

    def _evaluate_codes(self, X_imp: np.ndarray) -> None:
        """Score imputed IUPAC/int codes on the simulated TEST cells; resets `metrics_`."""
        sim_mask = self.sim_mask_
        y_true = self.ground_truths_[sim_mask]
        y_pred = X_imp[sim_mask]

        if y_true.size == 0:
            # Start from a clean dict so metrics of an earlier run do not linger.
            self.metrics_ = {"n_masked_test": 0}
            if self.verbose:
                print("No TEST cells were held out for evaluation (n_masked_test=0).")
            return

        labels = np.unique(np.concatenate([y_true, y_pred])).tolist()
        macro = {"average": "macro", "labels": labels, "zero_division": 0}
        self.metrics_ = {
            "n_masked_test": int(y_true.size),
            "accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(y_true, y_pred, **macro),
            "precision": precision_score(y_true, y_pred, **macro),
            "recall": recall_score(y_true, y_pred, **macro),
        }
        if self.verbose:
            print("\n--- TEST-only Evaluation (IUPAC/int) ---")
            for k, v in self.metrics_.items():
                print(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")
            print("\nClassification Report (IUPAC/int, TEST-only):")
            print(
                classification_report(y_true, y_pred, labels=labels, zero_division=0)
            )

        # Confusion matrix (IUPAC/int); only drawn when there is something to plot.
        labels_map = {
            "A": 0,
            "T": 1,
            "G": 2,
            "C": 3,
            "W": 4,
            "R": 5,
            "M": 6,
            "K": 7,
            "Y": 8,
            "S": 9,
            "N": -9,
        }
        self.plotter.plot_confusion_matrix(y_true, y_pred, label_names=labels_map)

    def _evaluate_zygosity(self, X_imp: np.ndarray) -> None:
        """Score hom-ref/het/hom-alt (0/1/2) calls on the simulated TEST cells."""
        r_idx, f_idx = np.nonzero(self.sim_mask_)
        if r_idx.size == 0:
            if self.verbose:
                print("[info] TEST-only evaluation: no masked coordinates found.")
            return

        y_true_3, y_pred_3, skip_reason = self._zygosity_labels(
            self.ground_truths_[r_idx, f_idx], X_imp[r_idx, f_idx], f_idx
        )
        if skip_reason is not None:
            if self.verbose:
                print(skip_reason)
            return

        labels_3 = [0, 1, 2]
        macro = {"average": "macro", "labels": labels_3, "zero_division": 0}
        self.metrics_.update(
            {
                "zyg_n_test": int(y_true_3.shape[0]),
                "zyg_accuracy": accuracy_score(y_true_3, y_pred_3),
                "zyg_f1": f1_score(y_true_3, y_pred_3, **macro),
                "zyg_precision": precision_score(y_true_3, y_pred_3, **macro),
                "zyg_recall": recall_score(y_true_3, y_pred_3, **macro),
            }
        )
        if self.verbose:
            print("\n--- TEST-only Zygosity (0=hom-ref,1=het,2=hom-alt) ---")
            for k in (
                "zyg_n_test",
                "zyg_accuracy",
                "zyg_f1",
                "zyg_precision",
                "zyg_recall",
            ):
                v = self.metrics_[k]
                print(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")
            print("\nClassification Report (zyg, TEST-only):")
            print(
                classification_report(
                    y_true_3,
                    y_pred_3,
                    labels=labels_3,
                    target_names=["hom-ref", "het", "hom-alt"],
                    zero_division=0,
                )
            )
        self.plotter.plot_confusion_matrix(
            y_true_3,
            y_pred_3,
            label_names=["hom-ref", "het", "hom-alt"],
        )

    def _zygosity_labels(
        self, true_codes: np.ndarray, pred_codes: np.ndarray, feat_idx: np.ndarray
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[str]]:
        """Convert paired true/predicted codes into 0/1/2 zygosity classes.

        Entries are dropped when either code is missing or outside the IUPAC
        LUT, when the locus has no usable reference base, or when either
        genotype carries an allele that is neither REF nor ALT at that locus.

        Args:
            true_codes: Ground-truth integer codes, one per masked cell.
            pred_codes: Imputed integer codes, aligned with `true_codes`.
            feat_idx: Column (locus) index of each cell.

        Returns:
            (y_true, y_pred, None) on success, or (None, None, message) when no
            cell survives filtering; `message` names the stage that emptied it.
        """
        true_codes = np.asarray(true_codes).astype(np.int16, copy=False)
        pred_codes = np.asarray(pred_codes).astype(np.int16, copy=False)
        feat_idx = np.asarray(feat_idx)

        observed = (true_codes != self.missing) & (pred_codes != self.missing)
        if not observed.any():
            return (
                None,
                None,
                "[info] Zygosity TEST-only: nothing to score (all masked entries missing).",
            )
        true_codes = true_codes[observed]
        pred_codes = pred_codes[observed]
        feat_idx = feat_idx[observed]

        lut = self._iupac_presence_lut_
        n_codes, n_bases = lut.shape
        ref_cols = self.ref_codes_[feat_idx]
        # A reference code of -1 (non-ACGT base) would otherwise wrap around and
        # be read as the last column (T), so such loci are excluded here.
        usable = (
            (true_codes >= 0)
            & (true_codes < n_codes)
            & (pred_codes >= 0)
            & (pred_codes < n_codes)
            & (ref_cols >= 0)
            & (ref_cols < n_bases)
        )
        if not usable.any():
            return (
                None,
                None,
                "[info] Zygosity TEST-only: no valid rows after code filtering.",
            )
        true_codes = true_codes[usable]
        pred_codes = pred_codes[usable]
        feat_idx = feat_idx[usable]
        ref_cols = ref_cols[usable]

        # Bases allowed at each locus: ALT alleles plus the reference base.
        # (Fancy indexing returns a fresh array, so alt_mask_ is not modified.)
        allowed = self.alt_mask_[feat_idx, :]
        allowed[np.arange(ref_cols.size), ref_cols] = True

        presence_true = lut[true_codes]  # (n,4)
        presence_pred = lut[pred_codes]  # (n,4)
        consistent = (
            presence_true.any(axis=1)
            & presence_pred.any(axis=1)
            & ~(presence_true & ~allowed).any(axis=1)
            & ~(presence_pred & ~allowed).any(axis=1)
        )
        if not consistent.any():
            return (
                None,
                None,
                "[info] Zygosity TEST-only: no valid rows after RA filtering.",
            )

        ref_kept = ref_cols[consistent]
        return (
            self._zygosity_from_presence(presence_true[consistent], ref_kept),
            self._zygosity_from_presence(presence_pred[consistent], ref_kept),
            None,
        )

    @staticmethod
    def _zygosity_from_presence(
        presence: np.ndarray, ref_cols: np.ndarray
    ) -> np.ndarray:
        """Classify rows of an (n,4) allele-presence table as 0=hom-ref, 1=het, 2=hom-alt."""
        single_allele = presence.sum(axis=1) == 1
        has_ref = presence[np.arange(presence.shape[0]), ref_cols]
        zyg = np.where(single_allele & has_ref, 0, np.where(single_allele, 2, 1))
        return zyg.astype(np.int8)

    def fit_transform(self) -> np.ndarray:
        """Convenience method that calls `fit()` then `transform()`."""
        self.fit()
        return self.transform()

    # -----------------------
    # Core frequency model
    # -----------------------
    def _safe_probs(self, probs: np.ndarray) -> np.ndarray:
        """Ensure probs are non-negative and sum to 1; fallback to uniform if invalid.

        The input array is never modified.

        Args:
            probs: Array of weights (not necessarily summing to 1).

        Returns:
            np.ndarray: Valid probability distribution summing to 1.
        """
        weights = np.asarray(probs, dtype=float)
        # np.where builds a new array, so the caller's data is left untouched.
        weights = np.where(weights < 0, 0.0, weights)
        total = weights.sum()

        if not np.isfinite(total) or total <= 0:
            return np.full(weights.size, 1.0 / max(1, weights.size))
        return weights / total

    def _series_distribution(self, s: pd.Series) -> Tuple[np.ndarray, np.ndarray]:
        """Compute empirical (states, probs) for one locus from observed integer codes.

        Args:
            s: One column (locus) as a pandas Series with NaNs for missing.

        Returns:
            Tuple[np.ndarray, np.ndarray]: (states, probs) sorted by state. A
            locus with no observed value yields the single state `self.default`.
        """
        observed = s.dropna().astype(int)
        if observed.empty:
            return np.array([self.default], dtype=int), np.array([1.0])

        freqs = observed.value_counts(normalize=True).sort_index()
        states = freqs.index.to_numpy(dtype=int)
        probs = self._safe_probs(freqs.to_numpy(dtype=float))
        return states, probs

    def _impute_df(self, df_in: pd.DataFrame) -> pd.DataFrame:
        """Impute NaNs in df_in using precomputed TRAIN distributions.

        Args:
            df_in: DataFrame with NaNs to impute.

        Returns:
            DataFrame with NaNs imputed.
        """
        if self.by_populations:
            return self._impute_by_population(df_in)
        return self._impute_global(df_in)

    def _impute_global(self, df_in: pd.DataFrame) -> pd.DataFrame:
        """Impute dataframe globally, preserving original sample order.

        Args:
            df_in: DataFrame with NaNs to impute.

        Returns:
            DataFrame with NaNs imputed, cast to int8.
        """
        df = df_in.copy()
        for col in df.columns:
            holes = df[col].isnull()
            if not holes.any():
                continue
            states, probs = self.global_dist_[col]
            draws = self.rng.choice(states, size=int(holes.sum()), p=probs)
            df.loc[holes, col] = draws
        return df.astype(np.int8)

    def _impute_by_population(self, df_in: pd.DataFrame) -> pd.DataFrame:
        """Impute dataframe by population, preserving original sample order.

        Each population uses its own TRAIN distribution where one exists and the
        global distribution otherwise. Samples without a population label are
        imputed from the global distribution.

        Args:
            df_in: DataFrame with NaNs to impute.

        Returns:
            DataFrame with NaNs imputed, cast to int8.
        """
        out = df_in.copy()
        pops = getattr(self, "pops", None)

        if pops is not None:
            # Group the untouched input and write into the copy, so the frame
            # being iterated is never mutated.
            for pop, grp in df_in.groupby(np.asarray(pops)):
                per_pop_dist = self.group_dist_.get(pop, {})
                for col in grp.columns:
                    holes = grp[col].isnull()
                    if not holes.any():
                        continue
                    states, probs = per_pop_dist.get(col, self.global_dist_[col])
                    draws = self.rng.choice(states, size=int(holes.sum()), p=probs)
                    out.loc[holes[holes].index, col] = draws

        # groupby skips rows whose label is missing; fill those from the global
        # distribution so the integer cast below cannot hit a NaN.
        if out.isnull().to_numpy().any():
            return self._impute_global(out)
        return out.astype(np.int8)
|