pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
pgsui/impute/simple_imputers.py
@@ -1,1431 +0,0 @@
1
- import os
2
- import sys
3
- from pathlib import Path
4
- import warnings
5
- from typing import Optional, Union, List, Dict, Tuple, Any, Callable
6
- from copy import deepcopy
7
-
8
- # Third-party imports
9
- import numpy as np
10
- import pandas as pd
11
- import scipy.linalg
12
- import toyplot.pdf
13
- import toyplot as tp
14
- import toytree as tt
15
- from decimal import Decimal
16
-
17
- from sklearn.impute import SimpleImputer
18
-
19
- # Custom imports
20
- try:
21
- from snpio import GenotypeData
22
- from ..utils.misc import isnotebook
23
- except (ModuleNotFoundError, ValueError, ImportError):
24
- from snpio import GenotypeData
25
- from utils.misc import isnotebook
26
-
27
- is_notebook = isnotebook()
28
-
29
- if is_notebook:
30
- from tqdm.notebook import tqdm as progressbar
31
- else:
32
- from tqdm import tqdm as progressbar
33
-
34
- # Pandas on pip gives a performance warning when doing the below code.
35
- # Apparently it's a bug that exists in the pandas version I used here.
36
- # It can be safely ignored.
37
- warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
38
-
39
-
40
- class ImputePhylo:
41
- """Impute missing data using a phylogenetic tree to inform the imputation.
42
-
43
- Args:
44
- genotype_data (GenotypeData instance): GenotypeData instance. Must have the q, tree, and optionally site_rates attributes defined.
45
-
46
- minbr (float or None, optional): Minimum branch length. Defaults to 0.0000000001
47
-
48
- str_encodings (Dict[str, int], optional): Integer encodings used in STRUCTURE-formatted file. Should be a dictionary with keys=nucleotides and values=integer encodings. The missing data encoding should also be included. Argument is ignored if using a PHYLIP-formatted file. Defaults to {"A": 1, "C": 2, "G": 3, "T": 4, "N": -9}
49
-
50
- prefix (str, optional): Prefix to use with output files.
51
-
52
- save_plots (bool, optional): Whether to save PDF files with genotype imputations for each site to disk. It makes one PDF file per locus, so if you have a lot of loci it will make a lot of PDF files. Defaults to False.
53
-
54
- write_output (bool, optional): Whether to save the imputed data to disk. Defaults to True.
55
-
56
- disable_progressbar (bool, optional): Whether to disable the progress bar during the imputation. Defaults to False.
57
-
58
-         kwargs (Dict[str, Any] or None, optional): Additional keyword arguments intended for internal purposes only. Possible arguments: {"column_subset": List[int] or numpy.ndarray[int]}; Subset SNPs by a list of indices for IterativeImputer. Defaults to None.
59
-
60
- Attributes:
61
- imputed (GenotypeData): New GenotypeData instance with imputed data.
62
-
63
- Example:
64
- >>>data = GenotypeData(
65
- >>> filename="test.str",
66
- >>> filetype="structure",
67
- >>> popmapfile="test.popmap",
68
- >>> guidetree="test.tre",
69
- >>> qmatrix_iqtree="test.iqtree",
70
- >>> siterates_iqtree="test.rates",
71
- >>>)
72
- >>>
73
- >>>phylo = ImputePhylo(
74
- >>> genotype_data=data,
75
- >>> save_plots=True,
76
- >>>)
77
- >>> # Get GenotypeData object.
78
- >>>gd_phylo = phylo.imputed
79
- """
80
-
81
- def __init__(
82
- self,
83
- genotype_data: Optional[Any],
84
- minbr: Optional[float] = 0.0000000001,
85
- *,
86
- str_encodings: Dict[str, int] = {
87
- "A": 1,
88
- "C": 2,
89
- "G": 3,
90
- "T": 4,
91
- "N": -9,
92
- },
93
- prefix: str = "output",
94
- save_plots: bool = False,
95
- disable_progressbar: bool = False,
96
- **kwargs: Optional[Dict[str, Any]],
97
- ) -> None:
98
- self.genotype_data = genotype_data
99
- self.alnfile = genotype_data.filename
100
- self.filetype = genotype_data.filetype
101
- self.popmap = genotype_data.popmap
102
- self.str_encodings = str_encodings
103
- self.prefix = prefix
104
- self.minbr = minbr
105
- self.save_plots = save_plots
106
- self.disable_progressbar = disable_progressbar
107
- self.column_subset = kwargs.get("column_subset", None)
108
- self.validation_mode = kwargs.get("validation_mode", False)
109
-
110
- self.valid_sites = None
111
- self.valid_sites_count = None
112
-
113
- self._validate_arguments(genotype_data)
114
- data, tree, q, site_rates = self._parse_arguments(genotype_data)
115
-
116
- if not self.validation_mode:
117
- imputed012 = self.impute_phylo(tree, data, q, site_rates)
118
- genotype_data = deepcopy(genotype_data)
119
- genotype_data.genotypes_012 = imputed012
120
- self.imputed = genotype_data
121
- else:
122
- self.imputed = self.impute_phylo(tree, data, q, site_rates)
123
-
124
- @property
125
- def genotypes_012(self):
126
- return self.imputed.genotypes012
127
-
128
- @property
129
- def snp_data(self):
130
- return self.imputed.snp_data
131
-
132
- @property
133
- def alignment(self):
134
- return self.imputed.alignment
135
-
136
- def impute_phylo(
137
- self,
138
- tree: tt.tree,
139
- genotypes: Dict[str, List[Union[str, int]]],
140
- Q: pd.DataFrame,
141
- site_rates=None,
142
- minbr=0.0000000001,
143
- ) -> pd.DataFrame:
144
- """Imputes genotype values with a guide tree.
145
-
146
- Imputes genotype values by using a provided guide
147
- tree to inform the imputation, assuming maximum parsimony.
148
-
149
- Process Outline:
150
- For each SNP:
151
-                 1) If site_rates, get the site-transformed Q matrix.
152
-
153
- 2) Postorder traversal of tree to compute ancestral
154
- state likelihoods for internal nodes (tips -> root).
155
- If exclude_N==True, then ignore N tips for this step.
156
-
157
- 3) Preorder traversal of tree to populate missing genotypes
158
- with the maximum likelihood state (root -> tips).
159
-
160
- Args:
161
- tree (toytree.tree object): Input tree.
162
-
163
- genotypes (Dict[str, List[Union[str, int]]]): Dictionary with key=sampleids, value=sequences.
164
-
165
- Q (pandas.DataFrame): Rate Matrix Q from .iqtree or separate file.
166
-
167
- site_rates (List): Site-specific substitution rates (used to weight per-site Q)
168
-
169
- minbr (float) : Minimum branch length (those below this value will be treated as == minbr)
170
-
171
- Returns:
172
- pandas.DataFrame: Imputed genotypes.
173
-
174
- Raises:
175
- IndexError: If index does not exist when trying to read genotypes.
176
- AssertionError: Sites must have same lengths.
177
- AssertionError: Missing data still found after imputation.
178
- """
179
- try:
180
- if list(genotypes.values())[0][0][1] == "/":
181
- genotypes = self._str2iupac(genotypes, self.str_encodings)
182
- except IndexError:
183
- if self._is_int(list(genotypes.values())[0][0][0]):
184
- raise
185
-
186
- if self.column_subset is not None:
187
- if isinstance(self.column_subset, np.ndarray):
188
- self.column_subset = self.column_subset.tolist()
189
-
190
- genotypes = {
191
- k: [v[i] for i in self.column_subset]
192
- for k, v in genotypes.items()
193
- }
194
-
195
- # For each SNP:
196
- nsites = list(set([len(v) for v in genotypes.values()]))
197
- assert len(nsites) == 1, "Some sites have different lengths!"
198
-
199
- outdir = f"{self.prefix}_imputation_plots"
200
-
201
- if self.save_plots:
202
- Path(outdir).mkdir(parents=True, exist_ok=True)
203
-
204
- for snp_index in progressbar(
205
- range(nsites[0]),
206
- desc="Feature Progress: ",
207
- leave=True,
208
- disable=self.disable_progressbar,
209
- ):
210
- rate = 1.0
211
- if site_rates is not None:
212
- rate = site_rates[snp_index]
213
-
214
- site_Q = Q.copy(deep=True) * rate
215
-
216
- bads = list()
217
- for samp in genotypes.keys():
218
- if genotypes[samp][snp_index].upper() == "N":
219
- bads.append(samp)
220
-
221
- # postorder traversal to compute likelihood at root
222
- node_lik = dict()
223
- for node in tree.treenode.traverse("postorder"):
224
- if node.is_leaf():
225
- continue
226
-
227
- if node.idx not in node_lik:
228
- node_lik[node.idx] = [1.0, 1.0, 1.0, 1.0]
229
-
230
- for child in node.get_children():
231
- # get branch length to child
232
- # bl = child.edge.length
233
- # get transition probs
234
- d = child.dist
235
- if d < minbr:
236
- d = minbr
237
- pt = self._transition_probs(site_Q, d)
238
- if child.is_leaf():
239
- if child.name in genotypes:
240
- if child.name in bads:
241
- sum = [1.0, 1.0, 1.0, 1.0]
242
- else:
243
- # get genotype data
244
- sum = None
245
- for allele in self._get_iupac_full(
246
- genotypes[child.name][snp_index]
247
- ):
248
- if sum is None:
249
- sum = [
250
- Decimal(x)
251
- for x in list(pt[allele])
252
- ]
253
- else:
254
- sum = [
255
- Decimal(sum[i]) + Decimal(val)
256
- for i, val in enumerate(
257
- list(pt[allele])
258
- )
259
- ]
260
- node_lik[child.idx] = [Decimal(x) for x in sum]
261
-
262
- # add to likelihood for parent node
263
- if node_lik[node.idx] is None:
264
- node_lik[node.idx] = node_lik[child.idx]
265
- else:
266
- node_lik[node.idx] = [
267
- Decimal(node_lik[child.idx][i])
268
- * Decimal(val)
269
- for i, val in enumerate(node_lik[node.idx])
270
- ]
271
- else:
272
- # raise error
273
- sys.exit(
274
- f"Error: Taxon {child.name} not found in "
275
- f"genotypes"
276
- )
277
- else:
278
- l = self._get_internal_lik(pt, node_lik[child.idx])
279
- if node_lik[node.idx] is None:
280
- node_lik[node.idx] = [Decimal(x) for x in l]
281
-
282
- else:
283
- node_lik[node.idx] = [
284
- Decimal(l[i]) * Decimal(val)
285
- for i, val in enumerate(node_lik[node.idx])
286
- ]
287
-
288
- # preorder traversal to get marginal reconstructions at internal
289
- # nodes
290
- marg = node_lik.copy()
291
- for node in tree.treenode.traverse("preorder"):
292
- if node.is_root():
293
- continue
294
- elif node.is_leaf():
295
- continue
296
- lik_arr = marg[node.idx]
297
- parent_arr = marg[node.up.idx]
298
- marg[node.idx] = [
299
- Decimal(lik) * (Decimal(parent_arr[i]) / Decimal(lik))
300
- for i, lik in enumerate(lik_arr)
301
- ]
302
-
303
- # get marginal reconstructions for bad bois
304
- two_pass = dict()
305
- for samp in bads:
306
- # get most likely state for focal tip
307
- node = tree.idx_dict[
308
- tree.get_mrca_idx_from_tip_labels(names=samp)
309
- ]
310
- dist = node.dist
311
- parent = node.up
312
- imputed = None
313
- pt = self._transition_probs(site_Q, dist)
314
- lik = self._get_internal_lik(pt, marg[parent.idx])
315
-
316
- tol = 0.001
317
- imputed = self._get_imputed_nuc(lik)
318
-
319
- # two_pass[samp] = [imputed, lik]
320
- genotypes[samp][snp_index] = imputed
321
-
322
- # DEPRECATED: RE-ROOTING METHOD OF YANG ET AL
323
- # NEW METHOD (ABOVE) IS LINEAR
324
- # reroot=dict()
325
- # for samp in bads:
326
- # #focaltree = tree.drop_tips(names=[x for x in bads if x != samp])
327
- # focaltree = tree.root(names=[samp])
328
- #
329
- # mystyle = {
330
- # "edge_type": "p",
331
- # "edge_style": {
332
- # "stroke-width": 1,
333
- # },
334
- # "tip_labels_align": True,
335
- # "tip_labels_style": {"font-size": "5px"},
336
- # "node_labels": False,
337
- # }
338
- #
339
- # canvas, axes, mark = focaltree.draw()
340
- # toyplot.pdf.render(canvas, "test.pdf")
341
- #
342
- # #postorder traversal to compute likelihood
343
- # node_lik = dict()
344
- # for node in focaltree.treenode.traverse("postorder"):
345
- # if node.is_leaf():
346
- # continue
347
- #
348
- # if node.idx not in node_lik:
349
- # node_lik[node.idx] = None
350
- #
351
- # for child in node.get_children():
352
- # # get branch length to child
353
- # # bl = child.edge.length
354
- # # get transition probs
355
- # pt = self._transition_probs(site_Q, child.dist)
356
- # if child.is_leaf():
357
- # if child.name in genotypes:
358
- # if child.name in bads:
359
- # sum = [1.0, 1.0, 1.0, 1.0]
360
- # else:
361
- # # get genotype data
362
- # sum = None
363
- # for allele in self._get_iupac_full(
364
- # genotypes[child.name][snp_index]
365
- # ):
366
- # if sum is None:
367
- # sum = [Decimal(x) for x in list(pt[allele])]
368
- # else:
369
- # sum = [
370
- # Decimal(sum[i]) + Decimal(val)
371
- # for i, val in enumerate(
372
- # list(pt[allele])
373
- # )
374
- # ]
375
- #
376
- # node_lik[child.idx] = [Decimal(x) for x in sum]
377
- #
378
- # #add to likelihood for parent node
379
- # if node_lik[node.idx] is None:
380
- # node_lik[node.idx] = node_lik[child.idx]
381
- # else:
382
- # node_lik[node.idx] = [
383
- # Decimal(node_lik[child.idx][i]) * Decimal(val)
384
- # for i, val in enumerate(node_lik[node.idx])
385
- # ]
386
- # else:
387
- # # raise error
388
- # sys.exit(
389
- # f"Error: Taxon {child.name} not found in "
390
- # f"genotypes"
391
- # )
392
- # else:
393
- # l = self._get_internal_lik(pt, node_lik[child.idx])
394
- # if node_lik[node.idx] is None:
395
- # node_lik[node.idx] = [Decimal(x) for x in l]
396
- #
397
- # else:
398
- # node_lik[node.idx] = [
399
- # Decimal(l[i]) * Decimal(val)
400
- # for i, val in enumerate(node_lik[node.idx])
401
- # ]
402
- #
403
- # # get most likely state for focal tip
404
- # node = focaltree.idx_dict[
405
- # focaltree.get_mrca_idx_from_tip_labels(names=samp)
406
- # ]
407
- # dist = node.dist
408
- # parent = node.up
409
- # imputed = None
410
- # pt = self._transition_probs(site_Q, dist)
411
- # lik = self._get_internal_lik(pt, node_lik[parent.idx])
412
- # maxpos = lik.index(max(lik))
413
- # if maxpos == 0:
414
- # imputed = "A"
415
- #
416
- # elif maxpos == 1:
417
- # imputed = "C"
418
- #
419
- # elif maxpos == 2:
420
- # imputed = "G"
421
- #
422
- # else:
423
- # imputed = "T"
424
- # reroot[samp] = [imputed, lik]
425
- # check if two methods give same results
426
- # for key in two_pass:
427
- # if two_pass[key][0] != reroot[key][0]:
428
- # print("Two-pass:", two_pass[key][0], "-", two_pass[key][1])
429
- # print("Reroot:", reroot[key][0], "-", reroot[key][1])
430
-
431
- if self.save_plots:
432
- self._draw_imputed_position(
433
- tree,
434
- bads,
435
- genotypes,
436
- snp_index,
437
- f"{outdir}/{self.prefix}_pos{snp_index}.pdf",
438
- )
439
-
440
- df = pd.DataFrame.from_dict(genotypes, orient="index")
441
-
442
- # Make sure no missing data remains in the dataset
443
- assert (
444
- not df.isin([-9]).any().any()
445
- ), "Imputation failed...Missing values found in the imputed dataset"
446
-
447
- (
448
- imp_snps,
449
- self.valid_sites,
450
- self.valid_sites_count,
451
- ) = self.genotype_data.convert_012(
452
- df.to_numpy().tolist(), impute_mode=True
453
- )
454
-
455
- df_imp = pd.DataFrame.from_records(imp_snps)
456
-
457
- return df_imp
458
-
459
- def nbiallelic(self) -> int:
460
- """Get the number of remaining bi-allelic sites after imputation.
461
-
462
- Returns:
463
- int: Number of bi-allelic sites remaining after imputation.
464
- """
465
- return len(self.imputed.columns)
466
-
467
- def _get_imputed_nuc(self, lik_arr):
468
- nucmap = {0: "A", 1: "C", 2: "G", 3: "T"}
469
- maxpos = lik_arr.index(max(lik_arr))
470
- picks = set([maxpos])
471
- # NOT USED:
472
- # Experimenting with ways to impute heterozygotes.
473
- # Note that LRT isn't appropriate (as I used here) because
474
- # the models are not nested & LRTS isn't necessarily expected
475
- # to be chisq distributed.
476
- # Check out Vuong test and read Lewis et al 2011 (doi: 10.1111/j.2041-210X.2010.00063.x)
477
- #
478
- # for index, alt in enumerate(lik_arr):
479
- # if index == maxpos:
480
- # continue
481
- # else:
482
- # lr = lrt(lik_arr[maxpos], alt, loglik=False)
483
- # p = chi2.sf(lr)
484
- # print(nucmap[maxpos], ":", str(lrt(lik_arr[maxpos], alt, loglik=False)), p)
485
- return nucmap[maxpos]
486
-
487
- def _parse_arguments(
488
- self, genotype_data: Any
489
- ) -> Tuple[Dict[str, List[Union[int, str]]], tt.tree, pd.DataFrame]:
490
- """Determine which arguments were specified and set appropriate values.
491
-
492
- Args:
493
- genotype_data (GenotypeData object): Initialized GenotypeData object.
494
-
495
- Returns:
496
- Dict[str, List[Union[int, str]]]: GenotypeData.snpsdict object. If genotype_data is not None, then this value gets set from the GenotypeData.snpsdict object. If alnfile is not None, then the alignment file gets read and the snpsdict object gets set from the alnfile.
497
-
498
- toytree.tree: Input phylogeny, either read from GenotypeData object or supplied with treefile.
499
-
500
- pandas.DataFrame: Q Rate Matrix, either from IQ-TREE file or from its own supplied file.
501
- """
502
- data = genotype_data.snpsdict
503
- tree = genotype_data.tree
504
-
505
- # read (optional) Q-matrix
506
- if genotype_data.q is not None:
507
- q = genotype_data.q
508
- else:
509
- raise TypeError("q must be defined in GenotypeData instance.")
510
-
511
- # read (optional) site-specific substitution rates
512
- if genotype_data.site_rates is not None:
513
- site_rates = genotype_data.site_rates
514
- else:
515
- raise TypeError(
516
- "site rates must be defined in GenotypeData instance."
517
- )
518
-
519
- return data, tree, q, site_rates
520
-
521
- def _validate_arguments(self, genotype_data: Any) -> None:
522
- """Validate that the correct arguments were supplied.
523
-
524
- Args:
525
- genotype_data (GenotypeData object): Input GenotypeData instance.
526
-
527
- Raises:
528
- TypeError: Must define genotype_data.tree in GenotypeData instance.
529
- TypeError: Q rate matrix must be defined in GenotypeData instance.
530
- """
531
-
532
- if genotype_data.tree is None:
533
- raise TypeError("genotype_data.tree must be defined")
534
-
535
- if genotype_data.q is None:
536
- raise TypeError("q must be defined in GenotypeData instance.")
537
-
538
- def _print_q(self, q: pd.DataFrame) -> None:
539
- """Print Rate Matrix Q.
540
-
541
- Args:
542
- q (pandas.DataFrame): Rate Matrix Q.
543
- """
544
- print("Rate matrix Q:")
545
- print("\tA\tC\tG\tT\t")
546
- for nuc1 in ["A", "C", "G", "T"]:
547
- print(nuc1, end="\t")
548
- for nuc2 in ["A", "C", "G", "T"]:
549
- print(q[nuc1][nuc2], end="\t")
550
- print("")
551
-
552
- def _is_int(self, val: Union[str, int]) -> bool:
553
- """Check if value is integer.
554
-
555
- Args:
556
- val (int or str): Value to check.
557
-
558
- Returns:
559
- bool: True if integer, False if string.
560
- """
561
- try:
562
- num = int(val)
563
- except ValueError:
564
- return False
565
- return True
566
-
567
- def _get_nuc_colors(self, nucs: List[str]) -> List[str]:
568
- """Get colors for each nucleotide when plotting.
569
-
570
- Args:
571
- nucs (List[str]): Nucleotides at current site.
572
-
573
- Returns:
574
- List[str]: Hex-code color values for each IUPAC nucleotide.
575
- """
576
- ret = list()
577
- for nuc in nucs:
578
- nuc = nuc.upper()
579
- if nuc == "A":
580
- ret.append("#0000FF") # blue
581
- elif nuc == "C":
582
- ret.append("#FF0000") # red
583
- elif nuc == "G":
584
- ret.append("#00FF00") # green
585
- elif nuc == "T":
586
- ret.append("#FFFF00") # yellow
587
- elif nuc == "R":
588
- ret.append("#0dbaa9") # blue-green
589
- elif nuc == "Y":
590
- ret.append("#FFA500") # orange
591
- elif nuc == "K":
592
- ret.append("#9acd32") # yellow-green
593
- elif nuc == "M":
594
- ret.append("#800080") # purple
595
- elif nuc == "S":
596
- ret.append("#964B00")
597
- elif nuc == "W":
598
- ret.append("#C0C0C0")
599
- else:
600
- ret.append("#000000")
601
- return ret
602
-
603
- def _label_bads(
604
- self, tips: List[str], labels: List[str], bads: List[str]
605
- ) -> List[str]:
606
- """Insert asterisks around bad nucleotides.
607
-
608
- Args:
609
- tips (List[str]): Tip labels (sample IDs).
610
- labels (List[str]): List of nucleotides at current site.
611
- bads (List[str]): List of tips that have missing data at current site.
612
-
613
- Returns:
614
- List[str]: IUPAC Nucleotides with "*" inserted around tips that had missing data.
615
- """
616
- for i, t in enumerate(tips):
617
- if t in bads:
618
- labels[i] = "*" + str(labels[i]) + "*"
619
- return labels
620
-
621
- def _draw_imputed_position(
622
- self,
623
- tree: tt.tree,
624
- bads: List[str],
625
- genotypes: Dict[str, List[str]],
626
- pos: int,
627
- out: str = "tree.pdf",
628
- ) -> None:
629
- """Draw nucleotides at phylogeny tip and saves to file on disk.
630
-
631
- Draws nucleotides as tip labels for the current SNP site. Imputed values have asterisk surrounding the nucleotide label. The tree is converted to a toyplot object and saved to file.
632
-
633
- Args:
634
- tree (toytree.tree): Input tree object.
635
- bads (List[str]): List of sampleIDs that have missing data at the current SNP site.
636
- genotypes (Dict[str, List[str]]): Genotypes at all SNP sites.
637
- pos (int): Current SNP index.
638
- out (str, optional): Output filename for toyplot object.
639
- """
640
-
641
- # print(tree.get_tip_labels())
642
- sizes = [8 if i in bads else 0 for i in tree.get_tip_labels()]
643
- colors = [genotypes[i][pos] for i in tree.get_tip_labels()]
644
- labels = colors
645
-
646
- labels = self._label_bads(tree.get_tip_labels(), labels, bads)
647
-
648
- colors = self._get_nuc_colors(colors)
649
-
650
- mystyle = {
651
- "edge_type": "p",
652
- "edge_style": {
653
- "stroke": tt.colors[0],
654
- "stroke-width": 1,
655
- },
656
- "tip_labels_align": True,
657
- "tip_labels_style": {"font-size": "5px"},
658
- "node_labels": False,
659
- }
660
-
661
- canvas, axes, mark = tree.draw(
662
- tip_labels_colors=colors,
663
- tip_labels=labels,
664
- width=400,
665
- height=600,
666
- **mystyle,
667
- )
668
-
669
- toyplot.pdf.render(canvas, out)
670
-
671
- def _all_missing(
672
- self,
673
- tree: tt.tree,
674
- node_index: int,
675
- snp_index: int,
676
- genotypes: Dict[str, List[str]],
677
- ) -> bool:
678
- """Check if all descendants of a clade have missing data at SNP site.
679
-
680
- Args:
681
- tree (toytree.tree): Input guide tree object.
682
-
683
- node_index (int): Parent node to determine if all descendants have missing data.
684
-
685
- snp_index (int): Index of current SNP site.
686
-
687
- genotypes (Dict[str, List[str]]): Genotypes at all SNP sites.
688
-
689
- Returns:
690
- bool: True if all descendants have missing data, otherwise False.
691
- """
692
- for des in tree.get_tip_labels(idx=node_index):
693
- if genotypes[des][snp_index].upper() not in ["N", "-"]:
694
- return False
695
- return True
696
-
697
- def _get_internal_lik(
698
- self, pt: pd.DataFrame, lik_arr: List[float]
699
- ) -> List[float]:
700
- """Get ancestral state likelihoods for internal nodes of the tree.
701
-
702
- Postorder traversal to calculate internal ancestral state likelihoods (tips -> root).
703
-
704
- Args:
705
- pt (pandas.DataFrame): Transition probabilities calculated from Rate Matrix Q.
706
- lik_arr (List[float]): Likelihoods for nodes or leaves.
707
-
708
- Returns:
709
- List[float]: Internal likelihoods.
710
- """
711
- ret = list()
712
- for i, val in enumerate(lik_arr):
713
- col = list(pt.iloc[:, i])
714
- sum = Decimal(0.0)
715
- for v in col:
716
- sum += Decimal(v) * Decimal(val)
717
- ret.append(sum)
718
- return ret
719
-
720
- def _transition_probs(self, Q: pd.DataFrame, t: float) -> pd.DataFrame:
721
- """Get transition probabilities for tree.
722
-
723
- Args:
724
- Q (pd.DataFrame): Rate Matrix Q.
725
- t (float): Tree distance of child.
726
-
727
- Returns:
728
- pd.DataFrame: Transition probabilities.
729
- """
730
- ret = Q.copy(deep=True)
731
- m = Q.to_numpy()
732
- pt = scipy.linalg.expm(m * t)
733
- ret[:] = pt
734
- return ret
735
-
736
- def _str2iupac(
737
- self, genotypes: Dict[str, List[str]], str_encodings: Dict[str, int]
738
- ) -> Dict[str, List[str]]:
739
- """Convert STRUCTURE-format encodings to IUPAC bases.
740
-
741
- Args:
742
- genotypes (Dict[str, List[str]]): Genotypes at all sites.
743
- str_encodings (Dict[str, int]): Dictionary that maps IUPAC bases (keys) to integer encodings (values).
744
-
745
- Returns:
746
- Dict[str, List[str]]: Genotypes converted to IUPAC format.
747
- """
748
- a = str_encodings["A"]
749
- c = str_encodings["C"]
750
- g = str_encodings["G"]
751
- t = str_encodings["T"]
752
- n = str_encodings["N"]
753
- nuc = {
754
- f"{a}/{a}": "A",
755
- f"{c}/{c}": "C",
756
- f"{g}/{g}": "G",
757
- f"{t}/{t}": "T",
758
- f"{n}/{n}": "N",
759
- f"{a}/{c}": "M",
760
- f"{c}/{a}": "M",
761
- f"{a}/{g}": "R",
762
- f"{g}/{a}": "R",
763
- f"{a}/{t}": "W",
764
- f"{t}/{a}": "W",
765
- f"{c}/{g}": "S",
766
- f"{g}/{c}": "S",
767
- f"{c}/{t}": "Y",
768
- f"{t}/{c}": "Y",
769
- f"{g}/{t}": "K",
770
- f"{t}/{g}": "K",
771
- }
772
-
773
- for k, v in genotypes.items():
774
- for i, gt in enumerate(v):
775
- v[i] = nuc[gt]
776
-
777
- return genotypes
778
-
779
- def _get_iupac_full(self, char: str) -> List[str]:
780
- """Map nucleotide to list of expanded IUPAC encodings.
781
-
782
- Args:
783
- char (str): Current nucleotide.
784
-
785
- Returns:
786
- List[str]: List of nucleotides in ``char`` expanded IUPAC.
787
- """
788
- char = char.upper()
789
- iupac = {
790
- "A": ["A"],
791
- "G": ["G"],
792
- "C": ["C"],
793
- "T": ["T"],
794
- "N": ["A", "C", "T", "G"],
795
- "-": ["A", "C", "T", "G"],
796
- "R": ["A", "G"],
797
- "Y": ["C", "T"],
798
- "S": ["G", "C"],
799
- "W": ["A", "T"],
800
- "K": ["G", "T"],
801
- "M": ["A", "C"],
802
- "B": ["C", "G", "T"],
803
- "D": ["A", "G", "T"],
804
- "H": ["A", "C", "T"],
805
- "V": ["A", "C", "G"],
806
- }
807
-
808
- ret = iupac[char]
809
- return ret
810
-
811
-
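ImputePhylo (above) exponentiates the scaled rate matrix to get per-branch transition probabilities, P(t) = expm(Q * t), then assigns each missing tip the most likely base given the marginal reconstruction at its parent. A minimal sketch of that transition-probability step, assuming a purely illustrative Jukes-Cantor-style Q (this code is not part of the package):

    import numpy as np
    import pandas as pd
    import scipy.linalg

    nucs = ["A", "C", "G", "T"]

    # Illustrative Jukes-Cantor-like rate matrix; rows sum to zero.
    q = np.full((4, 4), 1.0 / 3.0)
    np.fill_diagonal(q, -1.0)
    Q = pd.DataFrame(q, index=nucs, columns=nucs)

    def transition_probs(Q: pd.DataFrame, t: float) -> pd.DataFrame:
        """P(t) = expm(Q * t), keeping the A/C/G/T row and column labels."""
        pt = scipy.linalg.expm(Q.to_numpy() * t)
        return pd.DataFrame(pt, index=Q.index, columns=Q.columns)

    pt = transition_probs(Q, t=0.05)
    # Most likely base at a tip whose parent is reconstructed as "A":
    print(pt.loc["A"].idxmax())

In the deleted code this is what _transition_probs computes, with branch lengths below minbr first clamped to minbr.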
812
- class ImputeAlleleFreq:
813
-     """Impute missing data by allele frequency. If ``by_populations=False`` (the default), imputation uses the global allele frequency; if ``by_populations=True``, imputation is done per population, with population IDs taken from the GenotypeData object (``GenotypeData.populations``).
814
-
815
- Args:
816
- genotype_data (GenotypeData object): GenotypeData instance.
817
-
818
- by_populations (bool, optional): Whether or not to impute by-population or globally. Defaults to False (global allele frequency).
819
-
820
- diploid (bool, optional): When diploid=True, function assumes 0=homozygous ref; 1=heterozygous; 2=homozygous alt. 0-1-2 genotypes are decomposed to compute p (=frequency of ref) and q (=frequency of alt). In this case, p and q alleles are sampled to generate either 0 (hom-p), 1 (het), or 2 (hom-q) genotypes. When diploid=FALSE, 0-1-2 are sampled according to their observed frequency. Defaults to True.
821
-
822
- default (int, optional): Value to set if no alleles sampled at a locus. Defaults to 0.
823
-
824
- missing (int, optional): Missing data value. Defaults to -9.
825
-
826
- verbose (bool, optional): Whether to print status updates. Set to False for no status updates. Defaults to True.
827
-
828
- kwargs (Dict[str, Any]): Additional keyword arguments to supply. Primarily for internal purposes. Options include: {"iterative_mode": bool, validation_mode: bool, gt: List[List[int]]}. "iterative_mode" determines whether ``ImputeAlleleFreq`` is being used as the initial imputer in ``IterativeImputer``\. ``gt`` is used internally for the simple imputers during grid searches and validation. If ``genotype_data is None`` then ``gt`` cannot also be None, and vice versa. Only one of ``gt`` or ``genotype_data`` can be set.
829
-
830
- Raises:
831
- TypeError: genotype_data cannot be NoneType.
832
-
833
- Attributes:
834
- imputed (GenotypeData): New GenotypeData instance with imputed data.
835
-
836
- Example:
837
- >>>data = GenotypeData(
838
- >>> filename="test.str",
839
- >>> filetype="structure2rowPopID",
840
- >>> popmapfile="test.popmap",
841
- >>>)
842
- >>>
843
- >>>afpop = ImputeAlleleFreq(
844
- >>> genotype_data=data,
845
- >>> by_populations=True,
846
- >>>)
847
- >>>
848
- >>>gd_afpop = afpop.imputed
849
- """
850
-
851
- def __init__(
852
- self,
853
- genotype_data: GenotypeData,
854
- *,
855
- by_populations: bool = False,
856
- diploid: bool = True,
857
- default: int = 0,
858
- missing: int = -9,
859
- verbose: bool = True,
860
- **kwargs: Dict[str, Any],
861
- ) -> None:
862
- if genotype_data is None and gt is None:
863
- raise TypeError("GenotypeData instance or gt must be provided.")
864
-
865
- gt = kwargs.get("gt", None)
866
-
867
- if gt is None:
868
- gt_list = genotype_data.genotypes_012(fmt="list")
869
- else:
870
- gt_list = gt
871
-
872
- if by_populations:
873
- if genotype_data.populations is None:
874
- raise TypeError(
875
- "When by_populations is True, GenotypeData instance must have a defined populations attribute"
876
- )
877
-
878
- self.pops = genotype_data.populations
879
-
880
- else:
881
- self.pops = None
882
-
883
- self.diploid = diploid
884
- self.default = default
885
- self.missing = missing
886
- self.verbose = verbose
887
- self.iterative_mode = kwargs.get("iterative_mode", False)
888
- self.validation_mode = kwargs.get("validation_mode", False)
889
-
890
- if not self.validation_mode:
891
- imputed012, self.valid_cols = self.fit_predict(gt_list)
892
- genotype_data = deepcopy(genotype_data)
893
- genotype_data.genotypes_012 = imputed012
894
- self.imputed = genotype_data
895
- else:
896
- self.imputed, self.valid_cols = self.fit_predict(gt_list)
897
-
898
- @property
899
- def genotypes_012(self):
900
- return self.imputed.genotypes_012
901
-
902
- @property
903
- def snp_data(self):
904
- return self.imputed.snp_data
905
-
906
- @property
907
- def alignment(self):
908
- return self.imputed.alignment
909
-
910
- def fit_predict(
911
- self, X: List[List[int]]
912
- ) -> Tuple[
913
- Union[pd.DataFrame, np.ndarray, List[List[Union[int, float]]]],
914
- List[int],
915
- ]:
916
- """Impute missing genotypes using allele frequencies.
917
-
918
- Impute using global or by_population allele frequencies. Missing alleles are primarily coded as negative; usually -9.
919
-
920
- Args:
921
- X (List[List[int]], numpy.ndarray, or pandas.DataFrame): 012-encoded genotypes obtained from the GenotypeData object.
922
-
923
- Returns:
924
- pandas.DataFrame, numpy.ndarray, or List[List[Union[int, float]]]: Imputed genotypes of same shape as data.
925
-
926
- List[int]: Column indexes that were retained.
927
-
928
- Raises:
929
- TypeError: X must be either list, np.ndarray, or pd.DataFrame.
930
- """
931
- if self.pops is not None and self.verbose:
932
- print("\nImputing by population allele frequencies...")
933
- elif self.pops is None and self.verbose:
934
- print("\nImputing by global allele frequency...")
935
-
936
- if isinstance(X, (list, np.ndarray)):
937
- df = pd.DataFrame(X)
938
- elif isinstance(X, pd.DataFrame):
939
- df = X.copy()
940
- else:
941
- raise TypeError(
942
- f"X must be of type list(list(int)), numpy.ndarray, "
943
- f"or pandas.DataFrame, but got {type(X)}"
944
- )
945
-
946
- df = df.astype(int)
947
- df.replace(self.missing, np.nan, inplace=True)
948
-
949
- # Initialize an empty list to hold the columns
950
- columns = []
951
- valid_cols = list()
952
- bad_cnt = 0
953
-
954
- if self.pops is not None:
955
- df = df.copy()
956
-
957
- # Impute per-population mode.
958
- df["pops"] = self.pops
959
- groups = df.groupby(["pops"], sort=False)
960
-
961
- for col in df.columns:
962
- try:
963
- # Instead of appending to the DataFrame, append to the list
964
- columns.append(
965
- groups[col].transform(
966
- lambda x: x.fillna(x.mode().iloc[0])
967
- )
968
- )
969
-
970
- if col != "pops":
971
- valid_cols.append(col)
972
-
973
- except IndexError as e:
974
- if str(e).lower().startswith("single positional indexer"):
975
- bad_cnt += 1
976
- # Impute with global mode, unless globally missing in which case call as 0.0
977
- if df[col].isna().all():
978
- columns.append(df[col].fillna(0.0, inplace=False))
979
- else:
980
- columns.append(
981
- df[col].fillna(df[col].mode().iloc[0])
982
- )
983
- else:
984
- raise
985
-
986
- data = pd.concat(columns, axis=1)
987
-
988
- if bad_cnt > 0 and not self.validation_mode:
989
- UserWarning(
990
- f"\n{bad_cnt} columns were imputed with the "
991
- f"global mode because some of the populations "
992
- f"contained only missing data"
993
- )
994
-
995
- data.drop("pops", axis=1, inplace=True)
996
- else:
997
- # Impute global mode.
998
- imp = SimpleImputer(strategy="most_frequent")
999
-
1000
- # replace any columns that are fully missing
1001
- df.loc[:, df.isna().all()] = df.loc[:, df.isna().all()].fillna(0.0)
1002
-
1003
- data = pd.DataFrame(imp.fit_transform(df))
1004
-
1005
- if self.iterative_mode:
1006
- data = data.astype(dtype="float32")
1007
- else:
1008
- data = data.astype(dtype="Int8")
1009
-
1010
- if self.verbose:
1011
- print("Done!")
1012
-
1013
- if not self.validation_mode:
1014
- return data.values.tolist(), valid_cols
1015
- return data.values, valid_cols
1016
-
1017
- def write2file(
1018
- self, X: Union[pd.DataFrame, np.ndarray, List[List[Union[int, float]]]]
1019
- ) -> None:
1020
- """Write imputed data to file on disk.
1021
-
1022
- Args:
1023
- X (pandas.DataFrame, numpy.ndarray, List[List[Union[int, float]]]): Imputed data to write to file.
1024
-
1025
- Raises:
1026
- TypeError: If X is of unsupported type.
1027
- """
1028
- outfile = os.path.join(
1029
- f"{self.prefix}_output",
1030
- "alignments",
1031
- "Unsupervised",
1032
- "ImputeAlleleFreq",
1033
- )
1034
-
1035
- Path(outfile).mkdir(parents=True, exist_ok=True)
1036
-
1037
- outfile = os.path.join(outfile, "imputed_012.csv")
1038
-
1039
- if isinstance(X, pd.DataFrame):
1040
- df = X
1041
- elif isinstance(X, (np.ndarray, list)):
1042
- df = pd.DataFrame(X)
1043
- else:
1044
- raise TypeError(
1045
- f"Could not write imputed data because it is of incorrect "
1046
- f"type. Got {type(X)}"
1047
- )
1048
-
1049
- df.to_csv(outfile, header=False, index=False)
1050
-
1051
-
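ImputeAlleleFreq (above) fills each missing 012 genotype with the most frequent genotype, either globally or per population, falling back to the global mode (or 0) when a population has no observed genotypes at a site. A minimal sketch of that fill logic with made-up data (not part of the package):

    import numpy as np
    import pandas as pd

    # Illustrative 012 genotypes; -9 marks missing data.
    gt = pd.DataFrame(
        [[0, 2, -9], [0, -9, 1], [2, 2, -9], [-9, 2, 1]], dtype=float
    ).replace(-9, np.nan)
    pops = ["p1", "p1", "p2", "p2"]  # one population label per sample (row)

    filled = gt.copy()
    for col in gt.columns:
        # Per-population mode where the population has any observed genotypes...
        by_pop = gt.groupby(pops)[col].transform(
            lambda s: s.fillna(s.mode().iloc[0]) if s.notna().any() else s
        )
        # ...otherwise fall back to the global mode (0 if the column is all-missing).
        global_mode = gt[col].mode().iloc[0] if gt[col].notna().any() else 0.0
        filled[col] = by_pop.fillna(global_mode)

    print(filled.astype(int))

The deleted class does the equivalent through a temporary pops column and DataFrame.groupby, with SimpleImputer(strategy="most_frequent") handling the global case.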
1052
- class ImputeMF:
1053
-     """Impute missing data using matrix factorization. Missing genotypes are filled from a low-rank reconstruction of the 012-encoded genotype matrix, learned by gradient descent over the observed genotypes.
1054
-
1055
- Args:
1056
- genotype_data (GenotypeData object or None, optional): GenotypeData instance.
1057
-
1058
-         latent_features (int, optional): The number of latent variables used to reduce dimensionality of the data. Defaults to 2.
1059
-
1060
-         learning_rate (float, optional): The learning rate for the optimizers. Adjust if the loss is learning too slowly. Defaults to 0.0002.
1061
-
1062
-         tol (float, optional): Tolerance of the stopping condition. Defaults to 0.1.
1063
-
1064
- missing (int, optional): Missing data value. Defaults to -9.
1065
-
1066
- prefix (str, optional): Prefix for writing output files. Defaults to "output".
1067
-
1068
- verbose (bool, optional): Whether to print status updates. Set to False for no status updates. Defaults to True.
1069
-
1070
-         **kwargs (Dict[str, Any]): Additional keyword arguments to supply. Primarily for internal purposes. Options include: {"iterative_mode": bool}. "iterative_mode" determines whether ``ImputeMF`` is being used as the initial imputer in ``IterativeImputer``.
1071
-
1072
- Attributes:
1073
- imputed (GenotypeData): New GenotypeData instance with imputed data.
1074
-
1075
- Example:
1076
- >>>data = GenotypeData(
1077
- >>> filename="test.str",
1078
- >>> filetype="structure",
1079
- >>> popmapfile="test.popmap",
1080
- >>>)
1081
- >>>
1082
- >>>nmf = ImputeMF(
1083
- >>> genotype_data=data,
1084
- >>> by_populations=True,
1085
- >>>)
1086
- >>>
1087
- >>> # Get GenotypeData instance.
1088
- >>>gd_nmf = nmf.imputed
1089
-
1090
- Raises:
1091
- TypeError: genotype_data cannot be NoneType.
1092
- """
1093
-
1094
- def __init__(
1095
- self,
1096
- genotype_data,
1097
- *,
1098
- latent_features: int = 2,
1099
- max_iter: int = 100,
1100
- learning_rate: float = 0.0002,
1101
- regularization_param: float = 0.02,
1102
- tol: float = 0.1,
1103
- n_fail: int = 20,
1104
- missing: int = -9,
1105
- prefix: str = "output",
1106
- verbose: bool = True,
1107
- **kwargs: Dict[str, Any],
1108
- ) -> None:
1109
- self.max_iter = max_iter
1110
- self.latent_features = latent_features
1111
- self.n_fail = n_fail
1112
- self.learning_rate = learning_rate
1113
- self.tol = tol
1114
- self.regularization_param = regularization_param
1115
- self.missing = missing
1116
- self.prefix = prefix
1117
- self.verbose = verbose
1118
- self.iterative_mode = kwargs.get("iterative_mode", False)
1119
- self.validation_mode = kwargs.get("validation_mode", False)
1120
-
1121
- gt = kwargs.get("gt", None)
1122
-
1123
- if genotype_data is None and gt is None:
1124
- raise TypeError("GenotypeData and gt cannot both be NoneType.")
1125
-
1126
- if gt is None:
1127
- X = genotype_data.genotypes_012(fmt="numpy")
1128
- else:
1129
- X = gt.copy()
1130
- imputed012 = pd.DataFrame(self.fit_predict(X))
1131
- genotype_data = deepcopy(genotype_data)
1132
- genotype_data.genotypes_012 = imputed012
1133
-
1134
- if self.validation_mode:
1135
- self.imputed = imputed012.to_numpy()
1136
- else:
1137
- self.imputed = genotype_data
1138
-
1139
- @property
1140
- def genotypes_012(self):
1141
- return self.imputed.genotypes012
1142
-
1143
- @property
1144
- def snp_data(self):
1145
- return self.imputed.snp_data
1146
-
1147
- @property
1148
- def alignment(self):
1149
- return self.imputed.alignment
1150
-
1151
- def fit_predict(self, X):
1152
- # imputation
1153
- if self.verbose:
1154
- print(f"Doing MF imputation...")
1155
- R = X
1156
- R = R.astype(int)
1157
- R[R == self.missing] = -9
1158
- R = R + 1
1159
- R[R < 0] = 0
1160
- n_row = len(R)
1161
- n_col = len(R[0])
1162
- p = np.random.rand(n_row, self.latent_features)
1163
- q = np.random.rand(n_col, self.latent_features)
1164
- q_t = q.T
1165
- fails = 0
1166
- e_current = None
1167
- for step in range(self.max_iter):
1168
- for i in range(n_row):
1169
- for j in range(n_col):
1170
- if R[i][j] > 0:
1171
- eij = R[i][j] - np.dot(p[i, :], q_t[:, j])
1172
- for k in range(self.latent_features):
1173
- p[i][k] = p[i][k] + self.learning_rate * (
1174
- 2 * eij * q_t[k][j]
1175
- - self.regularization_param * p[i][k]
1176
- )
1177
- q_t[k][j] = q_t[k][j] + self.learning_rate * (
1178
- 2 * eij * p[i][k]
1179
- - self.regularization_param * q_t[k][j]
1180
- )
1181
- e = 0
1182
- for i in range(n_row):
1183
- for j in range(len(R[i])):
1184
- if R[i][j] > 0:
1185
- e = e + pow(R[i][j] - np.dot(p[i, :], q_t[:, j]), 2)
1186
- for k in range(self.latent_features):
1187
- e = e + (self.regularization_param / 2) * (
1188
- pow(p[i][k], 2) + pow(q_t[k][j], 2)
1189
- )
1190
- if e_current is None:
1191
- e_current = e
1192
- else:
1193
- if abs(e_current - e) < self.tol:
1194
- fails += 1
1195
- else:
1196
- fails = 0
1197
- e_current = e
1198
- if fails >= self.n_fail:
1199
- break
1200
- nR = np.dot(p, q_t)
1201
-
1202
- # transform values per-column (i.e., only allowing values found in original)
1203
- tR = self.transform(R, nR)
1204
-
1205
- # get accuracy of re-constructing non-missing genotypes
1206
- accuracy = self.accuracy(X, tR)
1207
-
1208
- # insert imputed values for missing genotypes
1209
- fR = X
1210
- fR[X < 0] = tR[X < 0]
1211
-
1212
- if self.verbose:
1213
- print("Done!")
1214
-
1215
- return fR
1216
-
1217
- def transform(self, original, predicted):
1218
- n_row = len(original)
1219
- n_col = len(original[0])
1220
- tR = predicted
1221
- for j in range(n_col):
1222
- observed = predicted[:, j]
1223
- expected = original[:, j]
1224
- options = np.unique(expected[expected != 0])
1225
- for i in range(n_row):
1226
- transform = min(
1227
- options, key=lambda x: abs(x - predicted[i, j])
1228
- )
1229
- tR[i, j] = transform
1230
- tR = tR - 1
1231
- tR[tR < 0] = -9
1232
- return tR
1233
-
1234
- def accuracy(self, expected, predicted):
1235
- prop_same = np.sum(expected[expected >= 0] == predicted[expected >= 0])
1236
- tot = expected[expected >= 0].size
1237
- accuracy = prop_same / tot
1238
- return accuracy
1239
-
1240
- def write2file(
1241
- self, X: Union[pd.DataFrame, np.ndarray, List[List[Union[int, float]]]]
1242
- ) -> None:
1243
- """Write imputed data to file on disk.
1244
-
1245
- Args:
1246
- X (pandas.DataFrame, numpy.ndarray, List[List[Union[int, float]]]): Imputed data to write to file.
1247
-
1248
- Raises:
1249
- TypeError: If X is of unsupported type.
1250
- """
1251
- outfile = os.path.join(
1252
- f"{self.prefix}_output",
1253
- "alignments",
1254
- "Unsupervised",
1255
- "ImputeMF",
1256
- )
1257
-
1258
- Path(outfile).mkdir(parents=True, exist_ok=True)
1259
-
1260
- outfile = os.path.join(outfile, "imputed_012.csv")
1261
-
1262
- if isinstance(X, pd.DataFrame):
1263
- df = X
1264
- elif isinstance(X, (np.ndarray, list)):
1265
- df = pd.DataFrame(X)
1266
- else:
1267
- raise TypeError(
1268
- f"Could not write imputed data because it is of incorrect "
1269
- f"type. Got {type(X)}"
1270
- )
1271
-
1272
- df.to_csv(outfile, header=False, index=False)
1273
-
1274
-
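ImputeMF (above) learns two low-rank factor matrices by gradient descent over the observed cells, reconstructs the full matrix from them, and then fills the missing entries from that reconstruction. A minimal sketch of the factor-update rule, assuming a toy matrix and hyperparameters chosen purely for illustration (not part of the package):

    import numpy as np

    rng = np.random.default_rng(0)
    R = np.array([[1.0, 2.0, 0.0], [3.0, 0.0, 1.0], [2.0, 2.0, 1.0]])  # 0 = missing
    k, lr, reg = 2, 0.01, 0.02
    p = rng.random((R.shape[0], k))
    q_t = rng.random((k, R.shape[1]))

    for _ in range(5000):
        for i in range(R.shape[0]):
            for j in range(R.shape[1]):
                if R[i, j] > 0:  # only observed cells contribute to the loss
                    eij = R[i, j] - p[i, :] @ q_t[:, j]
                    # Gradient step with L2 regularization on both factors.
                    p[i, :] += lr * (2 * eij * q_t[:, j] - reg * p[i, :])
                    q_t[:, j] += lr * (2 * eij * p[i, :] - reg * q_t[:, j])

    # Reconstruction; the zero (missing) cells above are read from this matrix.
    print(np.round(p @ q_t, 2))

The deleted transform step then maps each reconstructed value to the nearest genotype actually observed in that column, so imputed calls stay within {0, 1, 2}.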
1275
- class ImputeRefAllele:
1276
- """Impute missing data by reference allele.
1277
-
1278
- Args:
1279
- genotype_data (GenotypeData object): GenotypeData instance.
1280
-
1281
- missing (int, optional): Missing data value. Defaults to -9.
1282
-
1283
- verbose (bool, optional): Whether to print status updates. Set to False for no status updates. Defaults to True.
1284
-
1285
- kwargs (Dict[str, Any]): Additional keyword arguments to supply. Primarily for internal purposes. Options include: {"iterative_mode": bool, validation_mode: bool, gt: List[List[int]]}. "iterative_mode" determines whether ``ImputeRefAllele`` is being used as the initial imputer in ``IterativeImputer``\. ``gt`` is used internally for the simple imputers during grid searches and validation. If ``genotype_data is None`` then ``gt`` cannot also be None, and vice versa. Only one of ``gt`` or ``genotype_data`` can be set.
1286
-
1287
- Raises:
1288
- TypeError: genotype_data cannot be NoneType.
1289
-
1290
- Attributes:
1291
- imputed (GenotypeData): New GenotypeData instance with imputed data.
1292
-
1293
- Example:
1294
- >>>data = GenotypeData(
1295
- >>> filename="test.str",
1296
- >>> filetype="structure2rowPopID",
1297
- >>> popmapfile="test.popmap",
1298
- >>>)
1299
- >>>
1300
- >>>refallele = ImputeRefAllele(
1301
- >>> genotype_data=data
1302
- >>>)
1303
- >>>
1304
- >>>gd_refallele = refallele.imputed
1305
- """
1306
-
1307
- def __init__(
1308
- self,
1309
- genotype_data: GenotypeData,
1310
- *,
1311
- missing: int = -9,
1312
- verbose: bool = True,
1313
- **kwargs: Dict[str, Any],
1314
- ) -> None:
1315
- if genotype_data is None:
1316
- raise TypeError("GenotypeData instance must be provided.")
1317
-
1318
- gt = kwargs.get("gt", None)
1319
-
1320
- if gt is None:
1321
- gt_list = genotype_data.genotypes_012(fmt="list")
1322
- else:
1323
- gt_list = gt
1324
-
1325
- self.missing = missing
1326
- self.verbose = verbose
1327
- self.iterative_mode = kwargs.get("iterative_mode", False)
1328
- self.validation_mode = kwargs.get("validation_mode", False)
1329
-
1330
- # Get reference alleles from GenotypeData object
1331
- self.ref_alleles = genotype_data.ref
1332
-
1333
- if not self.validation_mode:
1334
- imputed012 = self.fit_predict(gt_list)
1335
- genotype_data = deepcopy(genotype_data)
1336
- genotype_data.genotypes_012 = imputed012
1337
- self.imputed = genotype_data
1338
- else:
1339
- self.imputed = self.fit_predict(gt_list)
1340
-
1341
- @property
1342
- def genotypes_012(self):
1343
- return self.imputed.genotypes_012
1344
-
1345
- @property
1346
- def snp_data(self):
1347
- return self.imputed.snp_data
1348
-
1349
- @property
1350
- def alignment(self):
1351
- return self.imputed.alignment
1352
-
1353
- def fit_predict(
1354
- self, X: List[List[Union[int, str]]]
1355
- ) -> Union[pd.DataFrame, np.ndarray, List[List[Union[int, str]]]]:
1356
- """Impute missing genotypes using reference alleles.
1357
-
1358
- Impute using reference alleles. Missing alleles are primarily coded as negative; usually -9.
1359
-
1360
- Args:
1361
- X (List[List[Union[int, str]]], numpy.ndarray, or pandas.DataFrame): Genotypes obtained from the GenotypeData object.
1362
-
1363
- Returns:
1364
- pandas.DataFrame, numpy.ndarray, or List[List[Union[int, str]]]: Imputed genotypes of same shape as data.
1365
-
1366
- Raises:
1367
- TypeError: X must be of type list(list(int or str)), numpy.ndarray,
1368
- or pandas.DataFrame, but got {type(X)}
1369
- """
1370
- if self.verbose:
1371
- print("\nImputing missing data with reference alleles...")
1372
-
1373
- if isinstance(X, (list, np.ndarray)):
1374
- df = pd.DataFrame(X)
1375
- elif isinstance(X, pd.DataFrame):
1376
- df = X.copy()
1377
- else:
1378
- raise TypeError(
1379
- f"X must be of type list(list(int or str)), numpy.ndarray, "
1380
- f"or pandas.DataFrame, but got {type(X)}"
1381
- )
1382
-
1383
- df = df.astype(df.dtypes)
1384
- df.replace(self.missing, np.nan, inplace=True)
1385
-
1386
- if df.dtypes[0] == int:
1387
- df.fillna(0, inplace=True)
1388
- else:
1389
- for i, ref in enumerate(self.ref_alleles):
1390
- df[i].fillna(ref, inplace=True)
1391
-
1392
- if self.verbose:
1393
- print("Done!")
1394
-
1395
- if not self.validation_mode:
1396
- return df.values.tolist()
1397
- return df.values
1398
-
1399
- def write2file(
1400
- self, X: Union[pd.DataFrame, np.ndarray, List[List[Union[int, float]]]]
1401
- ) -> None:
1402
- """Write imputed data to file on disk.
1403
-
1404
- Args:
1405
- X (pandas.DataFrame, numpy.ndarray, List[List[Union[int, float]]]): Imputed data to write to file.
1406
-
1407
- Raises:
1408
- TypeError: If X is of unsupported type.
1409
- """
1410
- outfile = os.path.join(
1411
- f"{self.prefix}_output",
1412
- "alignments",
1413
- "Unsupervised",
1414
- "ImputeRefAllele",
1415
- )
1416
-
1417
- Path(outfile).mkdir(parents=True, exist_ok=True)
1418
-
1419
- outfile = os.path.join(outfile, "imputed_012.csv")
1420
-
1421
- if isinstance(X, pd.DataFrame):
1422
- df = X
1423
- elif isinstance(X, (np.ndarray, list)):
1424
- df = pd.DataFrame(X)
1425
- else:
1426
- raise TypeError(
1427
- f"Could not write imputed data because it is of incorrect "
1428
- f"type. Got {type(X)}"
1429
- )
1430
-
1431
- df.to_csv(outfile, header=False, index=False)
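ImputeRefAllele (above) is the simplest of the deleted imputers: every missing call is replaced with that locus's reference allele (or with 0 for 012-encoded input). A minimal sketch with made-up data (not part of the package):

    import numpy as np
    import pandas as pd

    # Illustrative IUPAC genotype calls; "N" marks missing data.
    snps = pd.DataFrame(
        [["A", "N", "G"], ["N", "C", "G"], ["A", "C", "N"]]
    ).replace("N", np.nan)
    ref_alleles = ["A", "C", "G"]  # one reference base per locus (column)

    for i, ref in enumerate(ref_alleles):
        snps[i] = snps[i].fillna(ref)

    print(snps)

Judging from the file list above, this role appears to move to pgsui/impute/deterministic/imputers/ref_allele.py in the 1.6.x layout.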