pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128) hide show
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
@@ -1,694 +1,223 @@
1
+ # Standard library imports
1
2
  import copy
2
- import os
3
3
  import logging
4
- import sys
5
- import warnings
6
-
7
- import numpy as np
8
- import pandas as pd
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Literal, Optional, Tuple
9
6
 
10
7
  # Third-party imports
11
8
  import numpy as np
12
- import pandas as pd
13
-
14
9
  from sklearn.base import BaseEstimator, TransformerMixin
15
- from sklearn.impute import SimpleImputer
16
- from sklearn.metrics import (
17
- roc_auc_score,
18
- precision_recall_fscore_support,
19
- average_precision_score,
20
- )
21
- from sklearn.preprocessing import label_binarize
22
-
23
- # Import tensorflow with reduced warnings.
24
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
25
- logging.getLogger("tensorflow").disabled = True
26
- warnings.filterwarnings("ignore", category=UserWarning)
27
-
28
- # noinspection PyPackageRequirements
29
- import tensorflow as tf
30
-
31
- # Disable can't find cuda .dll errors. Also turns of GPU support.
32
- tf.config.set_visible_devices([], "GPU")
33
-
34
- from tensorflow.python.util import deprecation
35
-
36
- # Disable warnings and info logs.
37
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
38
- tf.get_logger().setLevel(logging.ERROR)
39
-
40
-
41
- # Monkey patching deprecation utils to supress warnings.
42
- # noinspection PyUnusedLocal
43
- def deprecated(
44
- date, instructions, warn_once=True
45
- ): # pylint: disable=unused-argument
46
- def deprecated_wrapper(func):
47
- return func
48
-
49
- return deprecated_wrapper
50
-
51
-
52
- deprecation.deprecated = deprecated
53
-
54
- # Custom Modules
55
- try:
56
- from ..utils import misc
57
-
58
- except (ModuleNotFoundError, ValueError, ImportError):
59
- from pgsui.utils import misc
60
-
61
-
62
- # Pandas on pip gives a performance warning when doing the below code.
63
- # Apparently it's a bug that exists in the pandas version I used here.
64
- # It can be safely ignored.
65
- warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
66
-
67
-
68
- def encode_onehot(X):
69
- """Convert 012-encoded data to one-hot encodings.
70
- Args:
71
- X (numpy.ndarray): Input array with 012-encoded data and -9 as the missing data value.
72
- Returns:
73
- pandas.DataFrame: One-hot encoded data, ignoring missing values (np.nan).
74
- """
75
- Xt = np.zeros(shape=(X.shape[0], X.shape[1], 3))
76
- mappings = {
77
- 0: np.array([1, 0, 0]),
78
- 1: np.array([0, 1, 0]),
79
- 2: np.array([0, 0, 1]),
80
- -9: np.array([np.nan, np.nan, np.nan]),
81
- }
82
- for row in np.arange(X.shape[0]):
83
- Xt[row] = [mappings[enc] for enc in X[row]]
84
- return Xt
85
10
 
11
+ from pgsui.utils.misc import validate_input_type
86
12
 
87
- def mle(row):
88
- """Get the Maximum Likelihood Estimation for the best prediction. Basically, it sets the index of the maxiumum value in a vector (row) to 1.0, since it is one-hot encoded.
13
+ if TYPE_CHECKING:
14
+ from snpio import TreeParser
89
15
 
90
- Args:
91
- row (numpy.ndarray(float)): Row vector with predicted values as floating points.
92
-
93
- Returns:
94
- numpy.ndarray(float): Row vector with the highest prediction set to 1.0 and the others set to 0.0.
95
- """
96
- res = np.zeros(row.shape[0])
97
- res[np.argmax(row)] = 1
98
- return res
99
16
 
17
+ class SimGenotypeDataTransformer:
18
+ """Simulates missing genotypes at the locus level on a 2D integer matrix.
100
19
 
101
- class UBPInputTransformer(BaseEstimator, TransformerMixin):
102
- """Transform input X prior to estimator fitting.
20
+ This transformer masks a proportion of known genotypes in the input matrix X, setting them to a specified missing value. The masking can be done randomly or based on inverse genotype frequencies, with an option to boost the likelihood of masking heterozygous genotypes.
103
21
 
104
22
  Args:
105
- n_components (int): Number of principal components currently being used in V.
106
-
107
- V (numpy.ndarray or Dict[str, Any]): If doing grid search, should be a dictionary with current_component: numpy.ndarray. If not doing grid search, then it should be a numpy.ndarray.
23
+ prop_missing (float): Proportion of *known* loci to mask (0..1).
24
+ strategy (Literal): Strategy name.
25
+ missing_val (int): Missing code value (default: -9).
26
+ seed (int | None): RNG seed.
27
+ logger (logging.Logger | None): Logger for messages.
28
+ het_boost (float): Multiplier for heterozygotes in inv-genotype mode.
108
29
  """
109
30
 
110
- def __init__(self, n_components, V):
111
- self.n_components = n_components
112
- self.V = V
113
-
114
- def fit(self, X):
115
- """Fit transformer to input data X.
116
-
117
- Args:
118
- X (numpy.ndarray): Input data to fit. If numpy.ndarray, then should be of shape (n_samples, n_components). If dictionary, then should be component: numpy.ndarray.
119
-
120
- Returns:
121
- self: Class instance.
122
- """
123
- self.n_features_in_ = self.n_components
124
- return self
125
-
126
- def transform(self, X):
127
- """Transform input data X to the needed format.
128
-
129
- Args:
130
- X (numpy.ndarray): Input data to fit. If numpy.ndarray, then should be of shape (n_samples, n_components). If dictionary, then should be component: numpy.ndarray.
131
-
132
- Returns:
133
- numpy.ndarray: Formatted input data with correct component.
134
-
135
- Raises:
136
- TypeError: V must be a dictionary if phase is None or phase == 1.
137
- TypeError: V must be a numpy array if phase is 2 or 3.
138
- """
139
- if not isinstance(self.V, dict):
140
- raise TypeError(f"V must be a dictionary, but got {type(self.V)}")
141
- return self.V[self.n_components]
142
-
143
-
144
- class AutoEncoderFeatureTransformer(BaseEstimator, TransformerMixin):
145
- """Transformer to format autoencoder features and targets before model fitting.
146
-
147
- The input data, X, is encoded to one-hot format, and then missing values are filled to [-1] * num_classes.
148
-
149
- Missing and observed boolean masks are also generated.
150
-
151
- Args:
152
- num_classes (int, optional): The number of classes in the last axis dimention of the input array. Defaults to 3.
153
-
154
- return_int (bool, optional): Whether to return an integer-encoded array (If True) or a one-hot or multi-label encoded array (If False.). Defaults to False.
155
-
156
- activate (str or None, optional): If not None, then does the appropriate activation. Multilabel learning uses sigmoid activation, and multiclass uses softmax. If set to None, then the function assumes that the input has already been activated. Possible values include: {None, 'sigmoid', 'softmax'}. Defaults to None.
157
- """
158
-
159
- def __init__(self, num_classes=3, return_int=False, activate=None):
160
- self.num_classes = num_classes
161
- self.return_int = return_int
162
- self.activate = activate
163
-
164
- def fit(self, X, y=None):
165
- """set attributes used to transform X (input features).
166
-
167
- Args:
168
- X (numpy.ndarray): Input integer-encoded numpy array.
169
-
170
- y (None): Just for compatibility with sklearn API.
171
- """
172
- X = misc.validate_input_type(X, return_type="array")
173
-
174
- self.X_decoded = X
175
-
176
- # VAE uses 4 classes ([A,T,G,C]), SAE uses 3 ([0,1,2]).
177
- if self.num_classes == 3:
178
- enc_func = self.encode_012
179
- elif self.num_classes == 4:
180
- enc_func = self.encode_multilab
181
- elif self.num_classes == 10:
182
- enc_func = self.encode_multiclass
183
- else:
184
- raise ValueError(
185
- f"Invalid value passed to num_classes in "
186
- f"AutoEncoderFeatureTransformer. Only 3 or 4 are supported, "
187
- f"but got {self.num_classes}."
188
- )
189
-
190
- # Encode the data.
191
- self.X_train = enc_func(X)
192
- self.classes_ = np.arange(self.num_classes)
193
- self.n_classes_ = self.num_classes
194
-
195
- # Get missing and observed data boolean masks.
196
- self.missing_mask_, self.observed_mask_ = self._get_masks(self.X_train)
197
-
198
- # To accomodate multiclass-multioutput.
199
- self.n_outputs_expected_ = 1
200
-
201
- self.n_outputs_ = self.X_train.shape[1]
202
-
203
- return self
204
-
205
- def transform(self, X):
206
- """Transform X to one-hot encoded format.
207
-
208
- Accomodates multiclass targets with a 3D shape.
209
-
210
- Args:
211
- X (numpy.ndarray): One-hot encoded target data of shape (n_samples, n_features, num_classes).
212
-
213
- Returns:
214
- numpy.ndarray: Transformed target data in one-hot format of shape (n_samples, n_features, num_classes).
215
- """
216
- if self.return_int:
217
- return X
218
- else:
219
- # X = misc.validate_input_type(X, return_type="array")
220
- return self._fill(self.X_train, self.missing_mask_)
221
-
222
- def inverse_transform(self, y, return_proba=False):
223
- """Transform target to output format.
224
-
225
- Args:
226
- y (numpy.ndarray): Array to inverse transform.
227
-
228
- return_proba (bool): Just for compatibility with scikeras API.
229
- """
230
- try:
231
- if self.activate is None:
232
- y = y.numpy()
233
- elif self.activate == "softmax":
234
- y = tf.nn.softmax(y).numpy()
235
- elif self.activate == "sigmoid":
236
- y = tf.nn.sigmoid(y).numpy()
237
- else:
238
- raise ValueError(
239
- f"Invalid value passed to keyword argument activate. Valid "
240
- f"options include: None, 'softmax', or 'sigmoid', but got "
241
- f"{self.activate}"
242
- )
243
- except AttributeError:
244
- # If numpy array already.
245
- if self.activate is None:
246
- y = y.copy()
247
- elif self.activate == "softmax":
248
- y = tf.nn.softmax(tf.convert_to_tensor(y)).numpy()
249
- elif self.activate == "sigmoid":
250
- y = tf.nn.sigmoid(tf.convert_to_tensor(y)).numpy()
251
- else:
252
- raise ValueError(
253
- f"Invalid value passed to keyword argument activate. Valid "
254
- f"options include: None, 'softmax', or 'sigmoid', but got "
255
- f"{self.activate}"
256
- )
257
- return y
258
-
259
- def encode_012(self, X):
260
- """Convert 012-encoded data to one-hot encodings.
261
- Args:
262
- X (numpy.ndarray): Input array with 012-encoded data and -9 as the missing data value.
263
- Returns:
264
- pandas.DataFrame: One-hot encoded data, ignoring missing values (np.nan).
265
- """
266
- Xt = np.zeros(shape=(X.shape[0], X.shape[1], 3))
267
- mappings = {
268
- 0: np.array([1, 0, 0]),
269
- 1: np.array([0, 1, 0]),
270
- 2: np.array([0, 0, 1]),
271
- -9: np.array([np.nan, np.nan, np.nan]),
272
- }
273
- for row in np.arange(X.shape[0]):
274
- Xt[row] = [mappings[enc] for enc in X[row]]
275
- return Xt
276
-
277
- def encode_multilab(self, X, multilab_value=1.0):
278
- """Encode 0-9 integer data in multi-label one-hot format.
279
- Args:
280
- X (numpy.ndarray): Input array with 012-encoded data and -9 as the missing data value.
281
-
282
- multilab_value (float): Value to use for multilabel target encodings. Defaults to 0.5.
283
- Returns:
284
- pandas.DataFrame: One-hot encoded data, ignoring missing values (np.nan). multi-label categories will be encoded as 0.5. Otherwise, it will be 1.0.
285
- """
286
- Xt = np.zeros(shape=(X.shape[0], X.shape[1], 4))
287
- mappings = {
288
- 0: [1.0, 0.0, 0.0, 0.0],
289
- 1: [0.0, 1.0, 0.0, 0.0],
290
- 2: [0.0, 0.0, 1.0, 0.0],
291
- 3: [0.0, 0.0, 0.0, 1.0],
292
- 4: [multilab_value, multilab_value, 0.0, 0.0],
293
- 5: [multilab_value, 0.0, multilab_value, 0.0],
294
- 6: [multilab_value, 0.0, 0.0, multilab_value],
295
- 7: [0.0, multilab_value, multilab_value, 0.0],
296
- 8: [0.0, multilab_value, 0.0, multilab_value],
297
- 9: [0.0, 0.0, multilab_value, multilab_value],
298
- -9: [np.nan, np.nan, np.nan, np.nan],
299
- }
300
- for row in np.arange(X.shape[0]):
301
- Xt[row] = [mappings[enc] for enc in X[row]]
302
- return Xt
303
-
304
- def decode_multilab(self, X, multilab_value=1.0):
305
- """Decode one-hot format data back to 0-9 integer data.
306
-
307
- Args:
308
- X (numpy.ndarray): Input array with one-hot-encoded data.
309
-
310
- multilab_value (float): Value to use for multilabel target encodings. Defaults to 0.5.
311
-
312
- Returns:
313
- pandas.DataFrame: Decoded data, with multi-label categories decoded to their original integer representation.
314
- """
315
- Xt = np.zeros(shape=(X.shape[0], X.shape[1]))
316
- mappings = {
317
- tuple([1.0, 0.0, 0.0, 0.0]): 0,
318
- tuple([0.0, 1.0, 0.0, 0.0]): 1,
319
- tuple([0.0, 0.0, 1.0, 0.0]): 2,
320
- tuple([0.0, 0.0, 0.0, 1.0]): 3,
321
- tuple([multilab_value, multilab_value, 0.0, 0.0]): 4,
322
- tuple([multilab_value, 0.0, multilab_value, 0.0]): 5,
323
- tuple([multilab_value, 0.0, 0.0, multilab_value]): 6,
324
- tuple([0.0, multilab_value, multilab_value, 0.0]): 7,
325
- tuple([0.0, multilab_value, 0.0, multilab_value]): 8,
326
- tuple([0.0, 0.0, multilab_value, multilab_value]): 9,
327
- tuple([np.nan, np.nan, np.nan, np.nan]): -9,
328
- }
329
- for row in np.arange(X.shape[0]):
330
- Xt[row] = [mappings[tuple(enc)] for enc in X[row]]
331
- return Xt
332
-
333
- def encode_multiclass(self, X, num_classes=10, missing_value=-9):
334
- """Encode 0-9 integer data in multi-class one-hot format.
335
-
336
- Missing values get encoded as ``[np.nan] * num_classes``
337
- Args:
338
- X (numpy.ndarray): Input array with 012-encoded data and ``missing_value`` as the missing data value.
339
-
340
- num_classes (int, optional): Number of classes to use. Defaults to 10.
341
-
342
- missing_value (int, optional): Missing data value to replace with ``[np.nan] * num_classes``\. Defaults to -9.
343
- Returns:
344
- pandas.DataFrame: Multi-class one-hot encoded data, ignoring missing values (np.nan).
345
- """
346
- int_cats, ohe_arr = np.arange(num_classes), np.eye(num_classes)
347
- mappings = dict(zip(int_cats, ohe_arr))
348
- mappings[missing_value] = np.array([np.nan] * num_classes)
349
-
350
- Xt = np.zeros(shape=(X.shape[0], X.shape[1], num_classes))
351
- for row in np.arange(X.shape[0]):
352
- Xt[row] = [mappings[enc] for enc in X[row]]
353
- return Xt
354
-
355
- def _fill(self, data, missing_mask, missing_value=-1):
356
- """Mask missing data as ``missing_value``\.
357
-
358
- Args:
359
- data (numpy.ndarray): Input with missing values of shape (n_samples, n_features, num_classes).
360
-
361
- missing_mask (np.ndarray(bool)): Missing data mask with True corresponding to a missing value.
362
-
363
- missing_value (int): Value to set missing data to. If a list is provided, then its length should equal the number of one-hot classes.
364
- """
365
- if self.num_classes > 1:
366
- missing_value = [missing_value] * self.num_classes
367
- data[missing_mask] = missing_value
368
- return data
369
-
370
- def _get_masks(self, X):
371
- """Format the provided target data for use with UBP/NLPCA.
372
-
373
- Args:
374
- y (numpy.ndarray(float)): Input data that will be used as the target of shape (n_samples, n_features, num_classes).
375
-
376
- Returns:
377
- numpy.ndarray(float): Missing data mask, with missing values encoded as 1's and non-missing as 0's.
378
-
379
- numpy.ndarray(float): Observed data mask, with non-missing values encoded as 1's and missing values as 0's.
380
- """
381
- missing_mask = self._create_missing_mask(X)
382
- observed_mask = ~missing_mask
383
- return missing_mask, observed_mask
384
-
385
- def _create_missing_mask(self, data):
386
- """Creates a missing data mask with boolean values.
387
- Args:
388
- data (numpy.ndarray): Data to generate missing mask from, of shape (n_samples, n_features, n_classes).
389
- Returns:
390
- numpy.ndarray(bool): Boolean mask of missing values of shape (n_samples, n_features), with True corresponding to a missing data point.
391
- """
392
- return np.isnan(data).all(axis=2)
393
-
394
-
395
- class MLPTargetTransformer(BaseEstimator, TransformerMixin):
396
- """Transformer to format UBP / NLPCA target data both before and after model fitting."""
31
+ def __init__(
32
+ self,
33
+ *,
34
+ prop_missing: float = 0.1,
35
+ strategy: Literal["random", "random_inv_genotype"] = "random",
36
+ missing_val: int = -1,
37
+ seed: int | None = None,
38
+ logger: logging.Logger | None = None,
39
+ het_boost: float = 1.0,
40
+ ):
41
+ self.prop_missing = float(prop_missing)
42
+ self.strategy = strategy
43
+ self.missing_val = int(missing_val)
44
+ self.seed = seed
45
+ self.rng = np.random.default_rng(seed)
46
+ self.het_boost = float(het_boost)
47
+ self.logger = logger or logging.getLogger(__name__)
397
48
 
398
- def fit(self, y):
399
- """Fit 012-encoded target data.
49
+ def fit(self, X, y=None) -> "SimGenotypeDataTransformer":
50
+ """Stateless.
400
51
 
401
52
  Args:
402
- y (numpy.ndarray): Target data that is 012-encoded.
403
-
404
- Returns:
405
- self: Class instance.
53
+ X (np.ndarray): (n_samples, n_features), integer codes {0..9} or <0 as missing.
54
+ y: Ignored.
406
55
  """
407
- y = misc.validate_input_type(y, return_type="array")
408
-
409
- # Original 012-encoded y
410
- self.y_decoded_ = y
411
-
412
- y_train = encode_onehot(y)
413
-
414
- # Get missing and observed data boolean masks.
415
- self.missing_mask_, self.observed_mask_ = self._get_masks(y_train)
416
-
417
- # To accomodate multiclass-multioutput.
418
- self.n_outputs_expected_ = 1
419
-
420
56
  return self
421
57
 
422
- def transform(self, y):
423
- """Transform y_true to one-hot encoded.
424
-
425
- Accomodates multiclass-multioutput targets.
426
-
427
- Args:
428
- y (numpy.ndarray): One-hot encoded target data.
429
-
430
- Returns:
431
- numpy.ndarray: y_true target data.
432
- """
433
- y = misc.validate_input_type(y, return_type="array")
434
- y_train = encode_onehot(y)
435
- return self._fill(y_train, self.missing_mask_)
436
-
437
- def inverse_transform(self, y):
438
- """Decode y_pred from one-hot to 012-based encoding.
439
-
440
- This allows sklearn.metrics to be used.
58
+ def transform(self, X: np.ndarray) -> tuple[np.ndarray, dict]:
59
+ """Apply missing-data simulation on a 2D genotype matrix.
441
60
 
442
61
  Args:
443
- y (numpy.ndarray): One-hot encoded predicted probabilities after model fitting.
62
+ X (np.ndarray): (n_samples, n_features), integer codes {0..9} or <0 as missing.
444
63
 
445
64
  Returns:
446
- numpy.ndarray: y predictions in same format as y_true.
65
+ tuple[np.ndarray, dict]: (X_masked, masks) where masks has keys: 'original': original missing (boolean 2D). 'simulated': loci masked here (boolean 2D). 'all': union of original + simulated (boolean 2D)
447
66
  """
448
- # VAE has tuple output
449
- if isinstance(y, tuple):
450
- y = y[0]
67
+ if X.ndim != 2:
68
+ msg = f"X must be 2D, got shape {X.shape}"
69
+ self.logger.error(msg)
70
+ raise ValueError(msg)
451
71
 
452
- # Return predictions.
453
- return tf.nn.softmax(y).numpy()
72
+ X = np.asarray(X)
73
+ original_mask = X < 0
454
74
 
455
- def _fill(self, data, missing_mask, missing_value=-1, num_classes=3):
456
- """Mask missing data as ``missing_value``\.
75
+ sim_mask = self._simulate_missing_mask(X, original_mask)
76
+ sim_mask = sim_mask & (~original_mask)
77
+ sim_mask = self._validate_mask(sim_mask)
457
78
 
458
- Args:
459
- data (numpy.ndarray): Input with missing values of shape (n_samples, n_features, num_classes).
460
-
461
- missing_mask (np.ndarray(bool)): Missing data mask with True corresponding to a missing value.
462
-
463
- missing_value (int): Value to set missing data to. If a list is provided, then its length should equal the number of one-hot classes. Defaults to -1.
464
-
465
- num_classes (int): Number of classes in dataset. Defaults to 3.
466
- """
467
- if num_classes > 1:
468
- missing_value = [missing_value] * num_classes
469
- data[missing_mask] = missing_value
470
- return data
471
-
472
- def _get_masks(self, X):
473
- """Format the provided target data for use with UBP/NLPCA.
474
-
475
- Args:
476
- X (numpy.ndarray(float)): Input data that will be used as the target.
477
-
478
- Returns:
479
- numpy.ndarray(float): Missing data mask, with missing values encoded as 1's and non-missing as 0's.
480
-
481
- numpy.ndarray(float): Observed data mask, with non-missing values encoded as 1's and missing values as 0's.
482
- """
483
- missing_mask = self._create_missing_mask(X)
484
- observed_mask = ~missing_mask
485
- return missing_mask, observed_mask
486
-
487
- def _create_missing_mask(self, data):
488
- """Creates a missing data mask with boolean values.
489
- Args:
490
- data (numpy.ndarray): Data to generate missing mask from, of shape (n_samples, n_features, n_classes).
491
- Returns:
492
- numpy.ndarray(bool): Boolean mask of missing values of shape (n_samples, n_features), with True corresponding to a missing data point.
493
- """
494
- return np.isnan(data).all(axis=2)
495
-
496
- def _decode(self, y):
497
- """Evaluate UBP / NLPCA predictions by calculating the highest predicted value.
498
-
499
- Calucalates highest predicted value for each row vector and each class, setting the most likely class to 1.0.
500
-
501
- Args:
502
- y (numpy.ndarray): Input one-hot encoded data.
79
+ all_mask = original_mask | sim_mask
80
+ Xt = X.copy()
81
+ Xt[all_mask] = self.missing_val
503
82
 
504
- Returns:
505
- numpy.ndarray: Imputed one-hot encoded values.
506
- """
507
- Xprob = y
508
- Xt = np.apply_along_axis(mle, axis=2, arr=Xprob)
509
- Xpred = np.argmax(Xt, axis=2)
510
- Xtrue = np.argmax(y, axis=2)
511
- Xdecoded = np.zeros((Xpred.shape[0], Xpred.shape[1]))
512
- for idx in np.arange(Xdecoded):
513
- imputed_idx = np.where(self.observed_mask_[idx] == 0)
514
- known_idx = np.nonzero(self.observed_mask_[idx])
515
- Xdecoded[idx, imputed_idx] = Xpred[idx, imputed_idx]
516
- Xdecoded[idx, known_idx] = Xtrue[idx, known_idx]
517
- return Xdecoded.astype("int8")
518
-
519
-
520
- class UBPTargetTransformer(BaseEstimator, TransformerMixin):
521
- """Transformer to format UBP / NLPCA target data both before model fitting.
522
-
523
- Examples:
524
- >>>ubp_tt = UBPTargetTransformer()
525
- >>>y_train = ubp_tt.fit_transform(y)
526
- """
83
+ masks = {"original": original_mask, "simulated": sim_mask, "all": all_mask}
84
+ return Xt, masks
527
85
 
528
- def fit(self, y):
529
- """Fit 012-encoded target data.
86
+ # ---- strategies ----
87
+ def _simulate_missing_mask(
88
+ self, X: np.ndarray, original_mask: np.ndarray
89
+ ) -> np.ndarray:
90
+ """Simulate missingness mask based on the chosen strategy.
530
91
 
531
92
  Args:
532
- y (numpy.ndarray): Target data that is 012-encoded, of shape (n_samples, n_features).
93
+ X (np.ndarray): Input genotype matrix.
94
+ original_mask (np.ndarray): Boolean mask of original missing values.
533
95
 
534
96
  Returns:
535
- self: Class instance.
97
+ np.ndarray: Simulated missing mask.
536
98
  """
537
- y = misc.validate_input_type(y, return_type="array")
538
-
539
- # Original 012-encoded y
540
- self.y_decoded_ = y
541
-
542
- # One-hot encode y.
543
- y_train = encode_onehot(y)
544
-
545
- # Get missing and observed data boolean masks.
546
- self.missing_mask_, self.observed_mask_ = self._get_masks(y_train)
547
-
548
- # To accomodate multiclass-multioutput.
549
- self.n_outputs_expected_ = 1
550
-
551
- return self
99
+ if self.strategy == "random":
100
+ return self._simulate_random(original_mask)
101
+ elif self.strategy == "random_inv_genotype":
102
+ return self._simulate_inv_genotype(X, original_mask)
552
103
 
553
- def transform(self, y):
554
- """Transform 012-encoded target to one-hot encoded format.
104
+ msg = "strategy must be one of {'random','random_inv_genotype'}"
105
+ self.logger.error(msg)
106
+ raise ValueError(msg)
555
107
 
556
- Accomodates multiclass-multioutput targets.
108
+ def _simulate_random(self, original_mask: np.ndarray) -> np.ndarray:
109
+ rows, cols = np.where(~original_mask)
110
+ n_known = len(rows)
111
+ mask = np.zeros_like(original_mask, dtype=bool)
557
112
 
558
- Args:
559
- y (numpy.ndarray): One-hot encoded target data of shape (n_samples, n_features).
113
+ if n_known == 0:
114
+ return mask
560
115
 
561
- Returns:
562
- numpy.ndarray: y_true target data.
563
- """
564
- y = misc.validate_input_type(y, return_type="array")
565
- y_train = encode_onehot(y)
566
- return self._fill(y_train, self.missing_mask_)
116
+ n_to_mask = int(np.floor(self.prop_missing * n_known))
567
117
 
568
- def inverse_transform(self, y):
569
- """Decode y_predicted from one-hot to 012-integer encoding.
118
+ if n_to_mask <= 0:
119
+ return mask
570
120
 
571
- Performs a softmax activation for multiclass classification.
121
+ idx = self.rng.choice(n_known, size=n_to_mask, replace=False)
122
+ mask[rows[idx], cols[idx]] = True
123
+ return mask
572
124
 
573
- This allows sklearn.metrics to be used.
125
+ def _simulate_inv_genotype(
126
+ self, X: np.ndarray, original_mask: np.ndarray
127
+ ) -> np.ndarray:
128
+ """Simulate missingness mask inversely proportional to genotype frequencies.
574
129
 
575
130
  Args:
576
- y (numpy.ndarray): One-hot encoded predicted probabilities after model fitting, of shape (n_samples, n_features, num_classes).
131
+ X (np.ndarray): Input genotype matrix.
132
+ original_mask (np.ndarray): Boolean mask of original missing values.
577
133
 
578
134
  Returns:
579
- numpy.ndarray: y predictions in same format as y_true (n_samples, n_features).
135
+ np.ndarray: Simulated missing mask. 0..3: homozygous (0,1,2,3). 4..9: heterozygous (0/1,0/2,0/3,1/2,1/3,2/3).
580
136
  """
581
- return tf.nn.softmax(y).numpy()
582
137
 
583
- def _fill(self, data, missing_mask, missing_value=-1, num_classes=3):
584
- """Mask missing data as ``missing_value``\.
138
+ rows, cols = np.where(~original_mask)
139
+ n_known = len(rows)
140
+ mask = np.zeros_like(original_mask, dtype=bool)
141
+ if n_known == 0:
142
+ return mask
585
143
 
586
- Args:
587
- data (numpy.ndarray): Input with missing values of shape (n_samples, n_features, num_classes).
144
+ # Global genotype frequencies (0..9) from all known
145
+ vals = X[~original_mask].astype(int)
146
+ vals = vals[(vals >= 0) & (vals < 10)]
588
147
 
589
- missing_mask (np.ndarray(bool)): Missing data mask with True corresponding to a missing value, of shape (n_samples, n_features).
148
+ if vals.size == 0:
149
+ return self._simulate_random(original_mask)
590
150
 
591
- missing_value (int, optional): Value to set missing data to. If a list is provided, then its length should equal the number of one-hot classes. Defaults to -1.
151
+ cnt = np.bincount(vals, minlength=10).astype(float)
152
+ freqs = cnt / (cnt.sum() + 1e-12)
592
153
 
593
- num_classes (int, optional): Number of classes to use. Defaults to 3.
594
- """
595
- if num_classes > 1:
596
- missing_value = [missing_value] * num_classes
597
- data[missing_mask] = missing_value
598
- return data
154
+ # Candidate weights
155
+ geno_known = X[rows, cols].astype(int) # (n_known,)
156
+ inv = 1.0 / (freqs[geno_known] + 1e-12)
599
157
 
600
- def _get_masks(self, y):
601
- """Format the provided target data for use with UBP/NLPCA models.
158
+ # Optional het boost (heterozygous codes are 4..9)
159
+ if self.het_boost != 1.0:
160
+ is_het = (geno_known >= 4) & (geno_known <= 9)
161
+ inv = inv * np.where(is_het, self.het_boost, 1.0)
602
162
 
603
- Args:
604
- y (numpy.ndarray(float)): Input data that will be used as the target of shape (n_samples, n_features, num_classes).
163
+ n_to_mask = int(np.floor(self.prop_missing * n_known))
164
+ if n_to_mask <= 0:
165
+ return mask
605
166
 
606
- Returns:
607
- numpy.ndarray(float): Missing data mask, with missing values encoded as 1's and non-missing as 0's.
608
-
609
- numpy.ndarray(float): Observed data mask, with non-missing values encoded as 1's and missing values as 0's.
610
- """
611
- missing_mask = self._create_missing_mask(y)
612
- observed_mask = ~missing_mask
613
- return missing_mask, observed_mask
614
-
615
- def _create_missing_mask(self, data):
616
- """Creates a missing data mask with boolean values.
617
-
618
- Args:
619
- data (numpy.ndarray): Data to generate missing mask from, of shape (n_samples, n_features, n_classes).
620
-
621
- Returns:
622
- numpy.ndarray(bool): Boolean mask of missing values of shape (n_samples, n_features), with True corresponding to a missing data point.
623
- """
624
- return np.isnan(data).all(axis=2)
625
-
626
- def _decode(self, y):
627
- """Evaluate UBP/NLPCA predictions by calculating the argmax.
167
+ probs = inv / (inv.sum() + 1e-12)
168
+ idx = self.rng.choice(n_known, size=n_to_mask, replace=False, p=probs)
169
+ mask[rows[idx], cols[idx]] = True
170
+ return mask
628
171
 
629
- Calucalates highest predicted value for each row vector and each class, setting the most likely class to 1.0.
172
+ def _validate_mask(self, mask: np.ndarray) -> np.ndarray:
173
+ """Avoid fully-masked rows/columns.
630
174
 
631
175
  Args:
632
- y (numpy.ndarray): Input one-hot encoded data of shape (n_samples, n_features, num_classes).
176
+ mask (np.ndarray): Input boolean mask.
633
177
 
634
178
  Returns:
635
- numpy.ndarray: Imputed one-hot encoded values.
636
- """
637
- Xprob = y
638
- Xt = np.apply_along_axis(mle, axis=2, arr=Xprob)
639
- Xpred = np.argmax(Xt, axis=2)
640
- Xtrue = np.argmax(y, axis=2)
641
- Xdecoded = np.zeros((Xpred.shape[0], Xpred.shape[1]))
642
- for idx in np.arange(Xdecoded):
643
- imputed_idx = np.where(self.observed_mask_[idx] == 0)
644
- known_idx = np.nonzero(self.observed_mask_[idx])
645
- Xdecoded[idx, imputed_idx] = Xpred[idx, imputed_idx]
646
- Xdecoded[idx, known_idx] = Xtrue[idx, known_idx]
647
- return Xdecoded.astype("int8")
179
+ np.ndarray: Validated mask.
180
+ """
181
+ rng = self.rng
182
+ # columns
183
+ full_cols = np.where(mask.all(axis=0))[0]
184
+ for c in full_cols:
185
+ r = int(rng.integers(0, mask.shape[0]))
186
+ mask[r, c] = False
187
+ # rows
188
+ full_rows = np.where(mask.all(axis=1))[0]
189
+ for r in full_rows:
190
+ c = int(rng.integers(0, mask.shape[1]))
191
+ mask[r, c] = False
192
+ return mask
648
193
 
649
194
 
650
- class SimGenotypeDataTransformer(BaseEstimator, TransformerMixin):
651
- """Simulate missing data on genotypes read/ encoded in a GenotypeData object.
195
+ class SimMissingTransformer(BaseEstimator, TransformerMixin):
196
+ """Simulate missing data on genotypes encoded as 0/1/2 integers.
652
197
 
653
- Copies metadata from a GenotypeData object and simulates user-specified proportion of missing data
198
+ This transformer is designed to work with genotype data that has been preprocessed into a suitable format. It simulates missing data according to various strategies, allowing for the testing and evaluation of imputation methods. The simulated missing data can be controlled in terms of proportion and distribution across samples and loci.
654
199
 
655
200
  Args:
656
201
  genotype_data (GenotypeData object): GenotypeData instance.
657
-
658
- prop_missing (float, optional): Proportion of missing data desired in output. Defaults to 0.1
659
-
660
- strategy (str, optional): Strategy for simulating missing data. May be one of: "nonrandom", "nonrandom_weighted", "random_weighted", "random_weighted_inv", or "random". When set to "nonrandom", branches from GenotypeData.guidetree will be randomly sampled to generate missing data on descendant nodes. For "nonrandom_weighted", missing data will be placed on nodes proportionally to their branch lengths (e.g., to generate data distributed as might be the case with mutation-disruption of RAD sites). Defaults to "random"
661
-
202
+ prop_missing (float, optional): Proportion of missing data desired in output. Must be in the interval [0, 1]. Defaults to 0.1
203
+ strategy (Literal["nonrandom", "nonrandom_weighted", "random_weighted", "random_weighted_inv", "random"]): Strategy for simulating missing data. "random": Uniformly masks genotypes at random among eligible entries until the target missing proportion is reached. "random_weighted": Masks genotypes at random with probabilities proportional to their observed genotype frequencies in each column (more common genotypes are more likely to be masked). "random_weighted_inv": Masks genotypes at random with probabilities inversely proportional to their observed genotype frequencies in each column (rarer genotypes are more likely to be masked). "nonrandom": Uses the supplied genotype tree to place missing data on clades that are sampled uniformly from internal and/or tip nodes, producing phylogenetically clustered missingness. "nonrandom_weighted": As in "nonrandom", but clades are sampled with probabilities proportional to their branch lengths, concentrating missingness on longer branches (e.g., mimicking locus dropout tied to evolutionary divergence). Defaults to "random".
662
204
  missing_val (int, optional): Value that represents missing data. Defaults to -9.
663
-
664
205
  mask_missing (bool, optional): True if you want to skip original missing values when simulating new missing data, False otherwise. Defaults to True.
665
-
666
206
  verbose (bool, optional): Verbosity level. Defaults to 0.
667
-
668
207
  tol (float): Tolerance to reach proportion specified in self.prop_missing. Defaults to 1/num_snps*num_inds
669
-
670
208
  max_tries (int): Maximum number of tries to reach targeted missing data proportion within specified tol. If None, num_inds will be used. Defaults to None.
671
209
 
672
210
  Attributes:
673
-
674
211
  original_missing_mask_ (numpy.ndarray): Array with boolean mask for original missing locations.
675
-
676
212
  simulated_missing_mask_ (numpy.ndarray): Array with boolean mask for simulated missing locations, excluding the original ones.
677
-
678
213
  all_missing_mask_ (numpy.ndarray): Array with boolean mask for all missing locations, including both simulated and original.
679
-
680
- Properties:
681
- missing_count (int): Number of genotypes masked by chosen missing data strategy
682
-
683
- prop_missing_real (float): True proportion of missing data generated using chosen strategy
684
-
685
- mask (numpy.ndarray): 2-dimensional array tracking the indices of sampled missing data sites (n_samples, n_sites)
686
214
  """
687
215
 
688
216
  def __init__(
689
217
  self,
690
218
  genotype_data,
691
219
  *,
220
+ tree_parser: Optional["TreeParser"] = None,
692
221
  prop_missing=0.1,
693
222
  strategy="random",
694
223
  missing_val=-9,
@@ -696,8 +225,10 @@ class SimGenotypeDataTransformer(BaseEstimator, TransformerMixin):
696
225
  verbose=0,
697
226
  tol=None,
698
227
  max_tries=None,
228
+ logger: logging.Logger | None = None,
699
229
  ) -> None:
700
230
  self.genotype_data = genotype_data
231
+ self.tree_parser = tree_parser
701
232
  self.prop_missing = prop_missing
702
233
  self.strategy = strategy
703
234
  self.missing_val = missing_val
@@ -705,396 +236,470 @@ class SimGenotypeDataTransformer(BaseEstimator, TransformerMixin):
705
236
  self.verbose = verbose
706
237
  self.tol = tol
707
238
  self.max_tries = max_tries
239
+ self.logger = logger or logging.getLogger(__name__)
708
240
 
709
- def fit(self, X):
241
+ def fit(self, X: np.ndarray, y=None) -> "SimMissingTransformer":
710
242
  """Fit to input data X by simulating missing data.
711
243
 
712
244
  Missing data will be simulated in varying ways depending on the ``strategy`` setting.
713
245
 
714
246
  Args:
715
- X (pandas.DataFrame, numpy.ndarray, or List[List[int]]): Data with which to simulate missing data. It should have already been imputed with one of the non-machine learning simple imputers, and there should be no missing data present in X.
247
+ X (np.ndarray): Data with which to simulate missing data. It should have already been imputed with one of the non-machine learning simple imputers, and there should be no missing data present in X.
716
248
 
717
249
  Raises:
718
- TypeError: SimGenotypeData.tree must not be NoneType when using strategy="nonrandom" or "nonrandom_weighted".
719
-
250
+ TypeError: ``SimGenotypeDataTreeTransformer.tree`` must not be NoneType when using strategy="nonrandom" or "nonrandom_weighted".
720
251
  ValueError: Invalid ``strategy`` parameter provided.
721
252
  """
722
- X = misc.validate_input_type(X, return_type="array").astype("float32")
253
+ X = np.asarray(validate_input_type(X, return_type="array")).astype("float32")
723
254
 
724
- if self.verbose > 0:
725
- print(
726
- f"\nAdding {self.prop_missing} missing data per column "
727
- f"using strategy: {self.strategy}"
728
- )
255
+ self.logger.info(
256
+ f"Adding {self.prop_missing} missing data per column using strategy: {self.strategy}"
257
+ )
729
258
 
730
- if np.all(np.isnan(np.array([self.missing_val])) == False):
259
+ if not np.isnan(self.missing_val):
260
+ X = X.copy()
731
261
  X[X == self.missing_val] = np.nan
732
262
 
733
263
  self.original_missing_mask_ = np.isnan(X)
734
264
 
735
265
  if self.strategy == "random":
736
- if self.mask_missing:
737
- # Get indexes where non-missing (Xobs) and missing (Xmiss).
738
- Xobs = np.where(~self.original_missing_mask_.ravel())[0]
739
- Xmiss = np.where(self.original_missing_mask_.ravel())[0]
740
-
741
- # Generate mask of 0's (non-missing) and 1's (missing).
742
- obs_mask = np.random.choice(
743
- [0, 1],
744
- size=Xobs.size,
745
- p=((1 - self.prop_missing), self.prop_missing),
746
- ).astype(bool)
747
-
748
- # Make missing data mask.
749
- mask = np.zeros(X.size)
750
- mask[Xobs] = obs_mask
751
- mask[Xmiss] = 1
752
-
753
- # Reshape from raveled to 2D.
754
- # With strategy=="random", mask_ is equal to all_missing_.
755
- self.mask_ = np.reshape(mask, X.shape)
266
+ present = ~self.original_missing_mask_
267
+ self.mask_ = np.zeros_like(X, dtype=bool)
268
+
269
+ # sample only over present sites
270
+ draws = np.random.random(X.shape)
271
+ self.mask_[present] = draws[present] < self.prop_missing
756
272
 
273
+ if self.mask_missing:
274
+ # keep original-missing as not simulated
275
+ pass
757
276
  else:
758
- # Generate mask of 0's (non-missing) and 1's (missing).
759
- self.mask_ = np.random.choice(
760
- [0, 1],
761
- size=X.shape,
762
- p=((1 - self.prop_missing), self.prop_missing),
763
- ).astype(bool)
277
+ # optionally also include original-missing as masked (no-op in
278
+ # transform anyway)
279
+ self.mask_[~present] = True
764
280
 
765
- # Make sure no entirely missing columns were simulated.
766
- self._validate_mask()
281
+ self._validate_mask(use_non_original_only=True)
767
282
 
768
283
  elif self.strategy == "random_weighted":
769
- self.mask_ = self.random_weighted_missing_data(X, inv=False)
284
+ self.mask_ = self.random_weighted_missing_data(
285
+ X, inv=False, target_rate=self.prop_missing
286
+ )
770
287
 
771
288
  elif self.strategy == "random_weighted_inv":
772
- self.mask_ = self.random_weighted_missing_data(X, inv=True)
773
-
774
- elif (
775
- self.strategy == "nonrandom"
776
- or self.strategy == "nonrandom_weighted"
777
- ):
778
- if self.genotype_data.tree is None:
779
- raise TypeError(
780
- "SimGenotypeData.tree cannot be NoneType when "
781
- "strategy='nonrandom' or 'nonrandom_weighted'"
782
- )
783
-
784
- mask = np.full_like(X, 0.0, dtype=bool)
785
-
786
- if self.strategy == "nonrandom_weighted":
787
- weighted = True
788
- else:
789
- weighted = False
790
-
791
- sample_map = dict()
792
- for i, sample in enumerate(self.genotype_data.samples):
793
- sample_map[sample] = i
794
-
795
- # if no tolerance provided, set to 1 snp position
796
- if self.tol is None:
797
- self.tol = 1.0 / mask.size
289
+ self.mask_ = self.random_weighted_missing_data(
290
+ X, inv=True, target_rate=self.prop_missing
291
+ )
798
292
 
799
- # if no max_tries provided, set to # inds
800
- if self.max_tries is None:
801
- self.max_tries = mask.shape[0]
293
+ elif self.strategy.startswith("nonrandom"):
294
+ if self.strategy not in {"nonrandom", "nonrandom_weighted"}:
295
+ msg = f"strategy must be one of {{'nonrandom','nonrandom_weighted'}}, got: {self.strategy}"
296
+ self.logger.error(msg)
297
+ raise ValueError(msg)
298
+
299
+ if self.tree_parser is None or not hasattr(self.tree_parser, "tree"):
300
+ msg = "SimMissingTransformer.tree cannot be NoneType when strategy='nonrandom' or strategy='nonrandom_weighted'"
301
+ self.logger.error(msg)
302
+ raise TypeError(msg)
303
+
304
+ rng = np.random.default_rng()
305
+ skip_root = True
306
+ weighted = self.strategy == "nonrandom_weighted"
307
+
308
+ # working mask
309
+ mask = np.zeros_like(X, dtype=bool)
310
+
311
+ # eligible cells
312
+ present = (
313
+ ~self.original_missing_mask_
314
+ if self.mask_missing
315
+ else np.ones_like(mask, dtype=bool)
316
+ )
802
317
 
803
- filled = False
804
- while not filled:
805
- # Get list of samples from tree
806
- samples = self._sample_tree(
807
- internal_only=False, skip_root=True, weighted=weighted
318
+ total_eligible = int(present.sum())
319
+ if total_eligible == 0:
320
+ self.mask_ = mask
321
+ self._validate_mask(use_non_original_only=self.mask_missing)
322
+ self.all_missing_mask_ = np.logical_or(
323
+ self.mask_, self.original_missing_mask_
324
+ )
325
+ self.sim_missing_mask_ = np.logical_and(
326
+ self.all_missing_mask_, ~self.original_missing_mask_
808
327
  )
328
+ return self
329
+
330
+ target = int(round(self.prop_missing * total_eligible))
331
+ tol = int(
332
+ max(
333
+ 1,
334
+ (self.tol if self.tol is not None else 1.0 / mask.size)
335
+ * total_eligible,
336
+ )
337
+ )
809
338
 
810
- # Convert to row indices
811
- rows = [sample_map[i] for i in samples]
339
+ # map tip labels -> row indices
340
+ name_to_idx = {name: i for i, name in enumerate(self.genotype_data.samples)}
812
341
 
813
- # Randomly sample a column
814
- col_idx = np.random.randint(0, mask.shape[1])
815
- sampled_col = copy.copy(mask[:, col_idx])
816
- miss_mask = copy.copy(self.original_missing_mask_[:, col_idx])
342
+ max_outer = (
343
+ self.max_tries
344
+ if self.max_tries is not None
345
+ else max(10_000, mask.shape[0] * 10)
346
+ )
347
+ placed = int(mask.sum())
348
+ best_delta = abs(placed - target)
349
+ tries = 0
350
+
351
+ # simple per-locus quota to distribute hits
352
+ col_quota = np.full(
353
+ mask.shape[1],
354
+ max(1, int(np.ceil(target / max(1, mask.shape[1])))),
355
+ dtype=int,
356
+ )
817
357
 
818
- # Mask column
819
- sampled_col[rows] = True
358
+ while tries < max_outer and abs(placed - target) > tol:
359
+ tries += 1
360
+
361
+ # >>> Call _sample_tree here <<<
362
+ try:
363
+ tips = self._sample_tree(
364
+ internal_only=False,
365
+ tips_only=False,
366
+ skip_root=skip_root,
367
+ weighted=weighted,
368
+ rng=rng,
369
+ )
370
+ except ValueError:
371
+ # no eligible nodes or no tips intersect samples; try again
372
+ continue
820
373
 
821
- # If original was missing, set back to False.
822
- if self.mask_missing:
823
- sampled_col[miss_mask] = False
374
+ # Convert to row indices; skip labels not in matrix
375
+ rows = [name_to_idx[t] for t in tips if t in name_to_idx]
376
+ if not rows:
377
+ continue
378
+
379
+ # choose a column to edit
380
+ cols_left = np.flatnonzero(col_quota > 0)
381
+ if cols_left.size == 0:
382
+ cols_left = np.arange(mask.shape[1])
383
+ j = int(rng.choice(cols_left))
824
384
 
825
- # check that column is not 100% missing now
826
- # if yes, sample again
827
- if np.sum(sampled_col) == sampled_col.size:
385
+ # only edit eligible cells in this column
386
+ eligible_rows = np.fromiter(
387
+ (r for r in rows if present[r, j]), dtype=int
388
+ )
389
+ if eligible_rows.size == 0:
828
390
  continue
829
391
 
830
- # if not, set values in mask matrix
831
- else:
832
- mask[:, col_idx] = sampled_col
833
-
834
- # if this addition pushes missing % > self.prop_missing,
835
- # check previous prop_missing, remove masked samples from
836
- # this column until closest to target prop_missing
837
- current_prop = np.sum(mask) / mask.size
838
- if abs(current_prop - self.prop_missing) <= self.tol:
839
- filled = True
840
- break
841
- elif current_prop > self.prop_missing:
842
- tries = 0
843
- while (
844
- abs(current_prop - self.prop_missing) > self.tol
845
- and tries < self.max_tries
846
- ):
847
- r = np.random.randint(0, mask.shape[0])
848
- c = np.random.randint(0, mask.shape[1])
849
- mask[r, c] = False
850
- tries += 1
851
- current_prop = np.sum(mask) / mask.size
852
-
853
- filled = True
392
+ if placed < target:
393
+ prev_col = mask[:, j].copy()
394
+ mask[eligible_rows, j] = True
395
+
396
+ # avoid fully missing column among observed
397
+ col_after = mask[present[:, j], j]
398
+ if col_after.all():
399
+ idx_present = np.flatnonzero(present[:, j])
400
+ k = int(rng.choice(idx_present))
401
+ mask[k, j] = False
402
+
403
+ new_placed = int(mask.sum())
404
+ delta = abs(new_placed - target)
405
+ if delta <= best_delta:
406
+ best_delta = delta
407
+ placed = new_placed
408
+ col_quota[j] = max(0, col_quota[j] - 1)
854
409
  else:
410
+ mask[:, j] = prev_col
411
+ else:
412
+ # remove within the same clade and column
413
+ prev_col = mask[:, j].copy()
414
+ col_idxs = eligible_rows[mask[eligible_rows, j]]
415
+ if col_idxs.size == 0:
855
416
  continue
417
+ need = min(col_idxs.size, max(1, placed - target))
418
+ to_clear = rng.choice(col_idxs, size=need, replace=False)
419
+ mask[to_clear, j] = False
420
+
421
+ new_placed = int(mask.sum())
422
+ delta = abs(new_placed - target)
423
+ if delta <= best_delta:
424
+ best_delta = delta
425
+ placed = new_placed
426
+ else:
427
+ mask[:, j] = prev_col
856
428
 
857
- # With strategy=="nonrandom" or "nonrandom_weighted",
858
- # mask_ is equal to sim_missing_mask_ if mask_missing is True.
859
- # Otherwise it is equal to all_missing_.
860
429
  self.mask_ = mask
861
-
862
- self._validate_mask()
863
-
430
+ self._validate_mask(use_non_original_only=self.mask_missing)
864
431
  else:
865
- raise ValueError(
866
- "Invalid SimGenotypeData.strategy value:", self.strategy
867
- )
432
+ msg = f"Invalid SimMissingTransformer.strategy value: {self.strategy}"
433
+ self.logger.error(msg)
434
+ raise ValueError(msg)
868
435
 
869
436
  # Get all missing values.
870
- self.all_missing_mask_ = np.logical_or(
871
- self.mask_, self.original_missing_mask_
872
- )
437
+ self.all_missing_mask_ = np.logical_or(self.mask_, self.original_missing_mask_)
438
+
873
439
  # Get values where original value was not missing and simulated.
874
440
  # data is missing.
875
441
  self.sim_missing_mask_ = np.logical_and(
876
442
  self.all_missing_mask_, self.original_missing_mask_ == False
877
443
  )
878
444
 
879
- self._validate_mask(mask=self.mask_missing)
445
+ self._validate_mask(use_non_original_only=self.mask_missing)
880
446
 
881
447
  return self
882
448
 
883
- def transform(self, X):
449
+ def transform(self, X: np.ndarray) -> np.ndarray:
884
450
  """Function to generate masked sites in a SimGenotypeData object
885
451
 
886
452
  Args:
887
- X (pandas.DataFrame, numpy.ndarray, or List[List[int]]): Data to transform. No missing data should be present in X. It should have already been imputed with one of the non-machine learning simple imputers.
453
+ X (np.ndarray): Data to transform. No missing data should be present in X. It should have already been imputed with one of the non-machine learning simple imputers.
888
454
 
889
455
  Returns:
890
- numpy.ndarray: Transformed data with missing data added.
456
+ np.ndarray: Transformed data with missing data added.
891
457
  """
892
- X = misc.validate_input_type(X, return_type="array")
458
+ X = np.asarray(validate_input_type(X, return_type="array")).astype("float32")
893
459
 
894
460
  # mask 012-encoded and one-hot encoded genotypes.
895
461
  return self._mask_snps(X)
896
462
 
897
- def accuracy(self, X_true, X_pred):
898
- """Calculate imputation accuracy of the simulated genotypes.
899
-
900
- Args:
901
- X_true (np.ndarray): True values.
902
-
903
- X_pred (np.ndarray): Imputed values.
904
-
905
- Returns:
906
- float: Accuracy score between X_true and X_pred.
907
- '"""
908
- masked_sites = np.sum(self.sim_missing_mask_)
909
- num_correct = np.sum(
910
- X_true[self.sim_missing_mask_] == X_pred[self.sim_missing_mask_]
911
- )
912
- return num_correct / masked_sites
913
-
914
- def auc_roc_pr_ap(self, X_true, X_pred):
915
- """Calcuate AUC-ROC, Precision-Recall, and Average Precision (AP).
463
+ def sqrt_transform(self, proportions: np.ndarray) -> np.ndarray:
464
+ """Apply the square root transformation to an array of proportions.
916
465
 
917
466
  Args:
918
- X_true (np.ndarray): True values.
919
-
920
- X_pred (np.ndarray): Imputed values.
467
+ proportions (np.ndarray): An array of proportions.
921
468
 
922
469
  Returns:
923
- List[float]: List of AUC-ROC scores in order of: 0,1,2.
924
- List[float]: List of precision scores in order of: 0,1,2.
925
- List[float]: List of recall scores in order of: 0,1,2.
926
- List[float]: List of average precision scores in order of 0,1,2.
927
-
470
+ np.ndarray: The transformed proportions.
928
471
  """
929
- y_true = X_true[self.sim_missing_mask_]
930
- y_pred = X_pred[self.sim_missing_mask_]
931
-
932
- # Binarize the output
933
- y_true_bin = label_binarize(y_true, classes=[0, 1, 2])
934
- y_pred_bin = label_binarize(y_pred, classes=[0, 1, 2])
935
-
936
- # Initialize lists to hold the scores for each class
937
- auc_roc_scores = []
938
- precision_scores = []
939
- recall_scores = []
940
- avg_precision_scores = []
941
-
942
- for i in range(y_true_bin.shape[1]):
943
- # AUC-ROC score
944
- auc_roc = roc_auc_score(
945
- y_true_bin[:, i], y_pred_bin[:, i], average="weighted"
946
- )
947
- auc_roc_scores.append(auc_roc)
948
-
949
- # Precision-recall score
950
- precision, recall, _, _ = precision_recall_fscore_support(
951
- y_true_bin[:, i], y_pred_bin[:, i], average="weighted"
952
- )
953
- precision_scores.append(precision)
954
- recall_scores.append(recall)
472
+ return np.sqrt(proportions)
955
473
 
956
- # Average precision score
957
- avg_precision = average_precision_score(
958
- y_true_bin[:, i], y_pred_bin[:, i], average="weighted"
959
- )
960
- avg_precision_scores.append(avg_precision)
961
-
962
- return (
963
- auc_roc_scores,
964
- precision_scores,
965
- recall_scores,
966
- avg_precision_scores,
967
- )
474
+ def random_weighted_missing_data(
475
+ self,
476
+ X: np.ndarray,
477
+ transform_fn: Literal["sqrt", "exp"] = "sqrt",
478
+ power: float = 0.5,
479
+ inv: bool = False,
480
+ rng: np.random.Generator | None = None,
481
+ target_rate: float | None = None, # if None, use realized draw
482
+ ) -> np.ndarray:
483
+ """Simulate missing data proportional or inversely proportional to genotype frequencies.
968
484
 
969
- def random_weighted_missing_data(self, X, inv=False):
970
- """Choose values for which to simulate missing data by biasing towards the minority or majority alleles, depending on whether inv is True or False.
485
+ This method simulates missing data in a genotype matrix based on genotype frequencies. It allows for different transformation functions to be applied to the base probabilities, and can optionally use inverse genotype frequencies.
971
486
 
972
487
  Args:
973
- X (np.ndarray): True values.
974
-
975
- inv (bool, optional): If True, then biases towards choosing majority alleles. If False, then biases towards choosing minority alleles. Defaults to False.
488
+ X (np.ndarray): Input genotype matrix.
489
+ transform_fn (Literal["sqrt", "exp"]): Transformation function to apply to base probabilities.
490
+ power (float): Exponent to raise transformed probabilities.
491
+ inv (bool): If True, use inverse genotype frequencies. If False, use direct frequencies to weight missingness.
492
+ rng (np.random.Generator | None): Optional NumPy Generator for reproducibility.
493
+ target_rate (float | None): If provided, scales the probabilities to achieve this target missing rate.
976
494
 
977
495
  Returns:
978
- np.ndarray: X with simulated missing values.
979
-
980
- """
981
- # Get unique classes and their counts
982
- classes, counts = np.unique(X, return_counts=True)
983
- # Compute class weights
984
- if inv:
985
- class_weights = 1 / counts
986
- else:
987
- class_weights = counts
988
- # Normalize class weights
989
- class_weights = class_weights / sum(class_weights)
990
-
991
- # Compute mask
992
- if self.mask_missing:
993
- # Get indexes where non-missing (Xobs) and missing (Xmiss)
994
- Xobs = np.where(~self.original_missing_mask_.ravel())[0]
995
- Xmiss = np.where(self.original_missing_mask_.ravel())[0]
996
-
997
- # Generate mask of 0's (non-missing) and 1's (missing)
998
- obs_mask = np.random.choice(
999
- classes, size=Xobs.size, p=class_weights
496
+ np.ndarray: Simulated missing mask.
497
+ """
498
+ tf = transform_fn.lower()
499
+ if tf not in {"sqrt", "exp"}:
500
+ msg = f"transform_fn must be 'sqrt' or 'exp', got: {transform_fn}"
501
+ self.logger.error(msg)
502
+ raise ValueError(msg)
503
+
504
+ rng = np.random.default_rng() if rng is None else rng
505
+ eps = 1e-12
506
+
507
+ def _tf(arr: np.ndarray) -> np.ndarray:
508
+ arr = np.clip(arr, eps, None)
509
+ return np.sqrt(arr) if tf == "sqrt" else np.exp(-arr)
510
+
511
+ n_samples, n_snps = X.shape
512
+ out_mask = np.zeros((n_samples, n_snps), dtype=bool)
513
+
514
+ for j in range(n_snps):
515
+ col = X[:, j]
516
+ present = ~np.isnan(col)
517
+ if not np.any(present):
518
+ continue
519
+
520
+ vals = col[present]
521
+ classes, counts = np.unique(vals, return_counts=True)
522
+ if classes.size == 1: # never wipe entire column
523
+ continue
524
+
525
+ p = counts.astype(float) / counts.sum()
526
+ base = 1.0 / np.clip(p, eps, None) if inv else p
527
+ w = _tf(base)
528
+ w = np.clip(w, 0.0, None) ** power
529
+ s = w.sum()
530
+ w = (
531
+ np.full_like(w, 1.0 / w.size, dtype=float)
532
+ if (s <= 0 or ~np.isfinite(s))
533
+ else (w / s)
1000
534
  )
1001
- obs_mask = (obs_mask == classes[:, None]).argmax(axis=0)
1002
535
 
1003
- # Make missing data mask
1004
- mask = np.zeros(X.size, dtype=bool)
1005
- mask[Xobs] = obs_mask
1006
- mask[Xmiss] = 1
536
+ probs = np.zeros(n_samples, dtype=float)
537
+ for c, pw in zip(classes, w):
538
+ probs[present & (col == c)] = pw
1007
539
 
1008
- # Reshape from raveled to 2D
1009
- mask = mask.reshape(X.shape)
1010
- else:
1011
- # Generate mask of 0's (non-missing) and 1's (missing)
1012
- mask = np.random.choice(classes, size=X.size, p=class_weights)
1013
- mask = (mask == classes[:, None]).argmax(axis=0).reshape(X.shape)
540
+ if target_rate is not None:
541
+ probs *= float(target_rate) # scale global intensity
1014
542
 
1015
- # Assign mask to self before validation
1016
- self.mask_ = mask
543
+ draws = rng.random(n_samples)
544
+ out_mask[:, j] = draws < probs
545
+ out_mask[~present, j] = False # never alter already-missing
1017
546
 
1018
- self._validate_mask()
547
+ # guard against accidentally wiping this column (using only non-original-missing)
548
+ col_after = out_mask[present, j]
549
+ if col_after.sum() == col_after.size:
550
+ # clear a random observed index
551
+ k = rng.integers(0, col_after.size)
552
+ out_mask[np.flatnonzero(present)[k], j] = False
1019
553
 
1020
- return mask
554
+ return out_mask
1021
555
 
1022
556
  def _sample_tree(
1023
557
  self,
1024
- internal_only=False,
1025
- tips_only=False,
1026
- skip_root=True,
1027
- weighted=False,
1028
- ):
1029
- """Function for randomly sampling clades from SimGenotypeData.tree.
1030
-
1031
- Args:
1032
- internal_only (bool): Only sample from NON-TIPS. Defaults to False.
1033
-
1034
- tips_only (bool): Only sample from tips. Defaults to False.
558
+ internal_only: bool = False,
559
+ tips_only: bool = False,
560
+ skip_root: bool = True,
561
+ weighted: bool = False,
562
+ rng: np.random.Generator | None = None,
563
+ ) -> list[str]:
564
+ """Sample a node and return descendant tip labels.
1035
565
 
1036
- skip_root (bool): Exclude sampling of root node. Defaults to True.
566
+ This method samples a node from the genotype tree and retrieves the tip labels of all descendant nodes. The sampling can be restricted to internal nodes, tip nodes, or can exclude the root node. Additionally, the sampling can be weighted by branch lengths.
1037
567
 
1038
- weighted (bool): Weight sampling by branch length. Defaults to False.
568
+ Args:
569
+ internal_only: Sample only internal nodes.
570
+ tips_only: Sample only tip nodes.
571
+ skip_root: Exclude the root from sampling.
572
+ weighted: Weight node sampling by branch length.
573
+ rng: Optional NumPy Generator for reproducibility.
1039
574
 
1040
575
  Returns:
1041
- List[str]: List of descendant tips from the sampled node.
576
+ List[str]: Tip labels under the sampled node.
1042
577
 
1043
578
  Raises:
1044
- ValueError: ``tips_only`` and ``internal_only`` cannot both be True.
579
+ ValueError: If no eligible nodes exist or both tips_only and internal_only are True.
1045
580
  """
1046
-
1047
581
  if tips_only and internal_only:
1048
- raise ValueError("internal_only and tips_only cannot both be true")
1049
-
1050
- # to only sample internal nodes add if not i.is_leaf()
1051
- node_dict = dict()
1052
-
1053
- for node in self.genotype_data.tree.treenode.traverse("preorder"):
1054
- ## node.idx is node indexes.
1055
- ## node.dist is branch lengths.
1056
- if skip_root:
1057
- # If root node.
1058
- if node.idx == self.genotype_data.tree.nnodes - 1:
1059
- continue
1060
-
1061
- if tips_only and internal_only:
1062
- raise ValueError(
1063
- "tips_only and internal_only cannot both be True"
582
+ msg = "tips_only and internal_only cannot both be True"
583
+ self.logger.error(msg)
584
+ raise ValueError(msg)
585
+
586
+ rng = np.random.default_rng() if rng is None else rng
587
+
588
+ node_dict: dict[int | object, float] = {}
589
+
590
+ if self.tree_parser is None or not hasattr(self.tree_parser, "tree"):
591
+ msg = "SimMissingTransformer.tree cannot be NoneType when strategy='nonrandom' or strategy='nonrandom_weighted'"
592
+ self.logger.error(msg)
593
+ raise TypeError(msg)
594
+
595
+ # Traverse using the tree backend you have; be tolerant of API differences.
596
+ for node in self.tree_parser.tree.treenode.traverse("preorder"):
597
+ # Robust root detection: prefer is_root(), then fall back to parent None, finally fall back to idx==nnodes-1 only if needed.
598
+ is_root = False
599
+ if hasattr(node, "is_root"):
600
+ is_root = bool(node.is_root())
601
+ elif getattr(node, "up", None) is None:
602
+ is_root = True
603
+ elif hasattr(self.tree_parser.tree, "nnodes") and hasattr(node, "idx"):
604
+ is_root = node.idx == self.tree_parser.tree.nnodes - 1
605
+
606
+ if skip_root and is_root:
607
+ continue
608
+
609
+ if tips_only and not node.is_leaf():
610
+ continue
611
+ if internal_only and node.is_leaf():
612
+ continue
613
+
614
+ # Branch length; coerce invalid to 0
615
+ dist = float(getattr(node, "dist", 0.0) or 0.0)
616
+ if not np.isfinite(dist):
617
+ dist = 0.0
618
+
619
+ # Use node.idx if stable, else the node object as key
620
+ key = getattr(node, "idx", node)
621
+ node_dict[key] = dist
622
+
623
+ if not node_dict:
624
+ msg = "No eligible nodes found to sample from the tree."
625
+ self.logger.error(msg)
626
+ raise ValueError(msg)
627
+
628
+ keys = np.array(list(node_dict.keys()), dtype=object)
629
+ weights = np.asarray(list(node_dict.values()), dtype=float)
630
+ weights[~np.isfinite(weights)] = 0.0
631
+ sample_set = set(self.genotype_data.samples)
632
+
633
+ def _choose_key() -> object:
634
+ if weighted and weights.sum() > 0.0:
635
+ p = weights / weights.sum()
636
+ return rng.choice(keys, p=p)
637
+ return rng.choice(keys)
638
+
639
+ tree = self.tree_parser.tree
640
+ last_error: Optional[Exception] = None
641
+ max_attempts = max(1, len(keys) * 3)
642
+
643
+ for _ in range(max_attempts):
644
+ chosen_key = _choose_key()
645
+
646
+ # 1. Resolve chosen_key to a Node object
647
+ try:
648
+ if isinstance(chosen_key, (int, np.integer)):
649
+ node = tree[int(chosen_key)]
650
+ else:
651
+ node = chosen_key
652
+ except Exception as e:
653
+ last_error = e
654
+ continue
655
+
656
+ # 2. Retrieve leaves for this specific node
657
+ if not hasattr(node, "get_leaves"):
658
+ last_error = TypeError(
659
+ f"Object {type(node)} does not have a get_leaves method."
1064
660
  )
661
+ continue
662
+
663
+ try:
664
+ tips = [leaf.name for leaf in node.get_leaves()] # type: ignore
665
+ except Exception as e:
666
+ last_error = e
667
+ continue
668
+
669
+ # Filter to sample IDs present in the matrix
670
+ tips = [t for t in tips if t in sample_set]
671
+ if tips:
672
+ return tips
673
+
674
+ msg = (
675
+ "No sampled clades contain tips present in genotype_data.samples. "
676
+ "Check that tree tip names match the genotype_data samples."
677
+ )
678
+ self.logger.error(msg)
679
+ if last_error:
680
+ raise ValueError(msg) from last_error
681
+ raise ValueError(msg)
1065
682
 
1066
- if tips_only:
1067
- if not node.is_leaf():
1068
- continue
1069
- elif internal_only:
1070
- if node.is_leaf():
1071
- continue
1072
- node_dict[node.idx] = node.dist
1073
- if weighted:
1074
- s = sum(list(node_dict.values()))
1075
- # Node index / sum of node distances.
1076
- p = [i / s for i in list(node_dict.values())]
1077
- node_idx = np.random.choice(list(node_dict.keys()), size=1, p=p)[0]
1078
- else:
1079
- # Get missing choice from random clade.
1080
- node_idx = np.random.choice(list(node_dict.keys()), size=1)[0]
1081
- return self.genotype_data.tree.get_tip_labels(idx=node_idx)
1082
-
1083
- def _validate_mask(self, mask=False):
1084
- """Make sure no entirely missing columns are simulated."""
1085
- if mask is None:
1086
- mask = self.mask_
1087
- for i, column in enumerate(self.mask_.T):
1088
- if mask:
1089
- miss_mask = self.original_missing_mask_[:, i]
1090
- col = column[~miss_mask]
1091
- obs_idx = np.where(~miss_mask)
1092
- idx = obs_idx[np.random.choice(np.arange(len(obs_idx)))]
683
+ def _validate_mask(self, use_non_original_only: bool = False) -> None:
684
+ """Ensure no column is entirely masked on observed entries.
685
+
686
+ Args:
687
+ use_non_original_only (bool): If True, only consider non-original-missing entries when validating. Defaults to False.
688
+ """
689
+ m = self.mask_
690
+ for j in range(m.shape[1]):
691
+ if use_non_original_only:
692
+ obs = ~self.original_missing_mask_[:, j]
1093
693
  else:
1094
- col = column
1095
- idx = np.random.choice(np.arange(col.shape[0]))
1096
- if np.sum(col) == col.size:
1097
- self.mask_[idx, i] = False
694
+ obs = np.ones(m.shape[0], dtype=bool)
695
+ if not np.any(obs):
696
+ continue
697
+ col = m[obs, j]
698
+ if col.size and col.all():
699
+ # clear one random observed index
700
+ idxs = np.flatnonzero(obs)
701
+ k = np.random.randint(0, idxs.size)
702
+ self.mask_[idxs[k], j] = False
1098
703
 
1099
704
  def _mask_snps(self, X):
1100
705
  """Mask positions in SimGenotypeData.snps and SimGenotypeData.onehot"""
@@ -1112,6 +717,51 @@ class SimGenotypeDataTransformer(BaseEstimator, TransformerMixin):
1112
717
  Xt[mask_boolean] = mask_val
1113
718
  return Xt
1114
719
 
720
+ def write_mask(self, filename_prefix: str):
721
+ """Write mask to file.
722
+
723
+ Args:
724
+ filename_prefix (str): Prefix for the filenames to write to.
725
+ """
726
+ np.save(filename_prefix + "_mask.npy", self.mask_)
727
+ np.save(
728
+ filename_prefix + "_original_missing_mask.npy",
729
+ self.original_missing_mask_,
730
+ )
731
+
732
+ def read_mask(
733
+ self, filename_prefix: str
734
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
735
+ """Read mask from file.
736
+
737
+ Args:
738
+ filename_prefix (str): Prefix for the filenames to read from.
739
+
740
+ Returns:
741
+ Tuple[np.ndarray, np.ndarray, np.ndarray]: The read masks. (mask, original_missing_mask, all_missing_mask).
742
+ """
743
+ # Check if files exist
744
+ if not Path(filename_prefix + "_mask.npy").is_file():
745
+ msg = filename_prefix + "_mask.npy" + " does not exist."
746
+ self.logger.error(msg)
747
+ raise FileNotFoundError(msg)
748
+
749
+ if not Path(filename_prefix + "_original_missing_mask.npy").is_file():
750
+ msg = filename_prefix + "_original_missing_mask.npy" + " does not exist."
751
+ self.logger.error(msg)
752
+ raise FileNotFoundError(msg)
753
+
754
+ # Load mask from file
755
+ self.mask_ = np.load(filename_prefix + "_mask.npy")
756
+ self.original_missing_mask_ = np.load(
757
+ filename_prefix + "_original_missing_mask.npy"
758
+ )
759
+
760
+ # Recalculate all_missing_mask_ from mask_ and original_missing_mask_
761
+ self.all_missing_mask_ = np.logical_or(self.mask_, self.original_missing_mask_)
762
+
763
+ return self.mask_, self.original_missing_mask_, self.all_missing_mask_
764
+
1115
765
  @property
1116
766
  def missing_count(self) -> int:
1117
767
  """Count of masked genotypes in SimGenotypeData.mask