sai-pg 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sai/__init__.py +2 -0
- sai/__main__.py +6 -3
- sai/configs/__init__.py +24 -0
- sai/configs/global_config.py +83 -0
- sai/configs/ploidy_config.py +94 -0
- sai/configs/pop_config.py +82 -0
- sai/configs/stat_config.py +220 -0
- sai/{utils/generators → generators}/chunk_generator.py +1 -1
- sai/{utils/generators → generators}/window_generator.py +81 -37
- sai/{utils/multiprocessing → multiprocessing}/mp_manager.py +2 -2
- sai/{utils/multiprocessing → multiprocessing}/mp_pool.py +2 -2
- sai/parsers/outlier_parser.py +4 -3
- sai/parsers/score_parser.py +8 -119
- sai/{utils/preprocessors → preprocessors}/chunk_preprocessor.py +21 -15
- sai/preprocessors/feature_preprocessor.py +236 -0
- sai/registries/__init__.py +22 -0
- sai/registries/generic_registry.py +89 -0
- sai/registries/stat_registry.py +30 -0
- sai/sai.py +124 -220
- sai/stats/__init__.py +11 -0
- sai/stats/danc_statistic.py +83 -0
- sai/stats/dd_statistic.py +77 -0
- sai/stats/df_statistic.py +84 -0
- sai/stats/dplus_statistic.py +86 -0
- sai/stats/fd_statistic.py +92 -0
- sai/stats/generic_statistic.py +93 -0
- sai/stats/q_statistic.py +104 -0
- sai/stats/stat_utils.py +259 -0
- sai/stats/u_statistic.py +99 -0
- sai/utils/utils.py +213 -142
- {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/METADATA +3 -14
- sai_pg-1.1.0.dist-info/RECORD +70 -0
- {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/WHEEL +1 -1
- sai_pg-1.1.0.dist-info/top_level.txt +2 -0
- tests/configs/test_global_config.py +163 -0
- tests/configs/test_ploidy_config.py +93 -0
- tests/configs/test_pop_config.py +90 -0
- tests/configs/test_stat_config.py +171 -0
- tests/generators/test_chunk_generator.py +51 -0
- tests/generators/test_window_generator.py +164 -0
- tests/multiprocessing/test_mp_manager.py +92 -0
- tests/multiprocessing/test_mp_pool.py +79 -0
- tests/parsers/test_argument_validation.py +133 -0
- tests/parsers/test_outlier_parser.py +53 -0
- tests/parsers/test_score_parser.py +63 -0
- tests/preprocessors/test_chunk_preprocessor.py +79 -0
- tests/preprocessors/test_feature_preprocessor.py +223 -0
- tests/registries/test_registries.py +74 -0
- tests/stats/test_danc_statistic.py +51 -0
- tests/stats/test_dd_statistic.py +45 -0
- tests/stats/test_df_statistic.py +73 -0
- tests/stats/test_dplus_statistic.py +79 -0
- tests/stats/test_fd_statistic.py +68 -0
- tests/stats/test_q_statistic.py +268 -0
- tests/stats/test_stat_utils.py +354 -0
- tests/stats/test_u_statistic.py +233 -0
- tests/test___main__.py +51 -0
- tests/test_sai.py +102 -0
- tests/utils/test_utils.py +511 -0
- sai/parsers/plot_parser.py +0 -152
- sai/stats/features.py +0 -302
- sai/utils/preprocessors/feature_preprocessor.py +0 -211
- sai_pg-1.0.1.dist-info/RECORD +0 -30
- sai_pg-1.0.1.dist-info/top_level.txt +0 -1
- /sai/{utils/generators → generators}/__init__.py +0 -0
- /sai/{utils/generators → generators}/data_generator.py +0 -0
- /sai/{utils/multiprocessing → multiprocessing}/__init__.py +0 -0
- /sai/{utils/preprocessors → preprocessors}/__init__.py +0 -0
- /sai/{utils/preprocessors → preprocessors}/data_preprocessor.py +0 -0
- {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/entry_points.txt +0 -0
- {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,268 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import pytest
|
22
|
+
import numpy as np
|
23
|
+
from sai.stats import QStatistic
|
24
|
+
|
25
|
+
|
26
|
+
def test_QStatistic_compute_basic():
    """Q on a toy haploid dataset where exactly one site passes the filters.

    Site 0 is the only site whose reference frequency (1/3) is below ``w``
    and whose source frequency equals 1.0, so Q is the target frequency at
    that site (2/3). The same result is expected whether or not ancestral
    alleles are flagged as available.
    """
    ref_gts = np.array([[0, 0, 1], [0, 0, 0], [1, 1, 1]])
    tgt_gts = np.array([[0, 1, 1], [0, 0, 1], [1, 1, 1]])
    src_gts = np.array([[1, 1, 1], [0, 1, 1], [1, 1, 1]])
    pos = np.array([0, 1, 2])
    w, y, quantile = 0.5, ("=", 1.0), 0.95

    # Exact fraction instead of the rounded literal 0.66667, which only
    # passed because np.isclose's default rtol absorbed the rounding error.
    expected_result = 2 / 3
    expected_positions = np.array([0])

    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )

    for anc_allele_available in (False, True):
        results = q_stat.compute(
            pos=pos,
            w=w,
            y_list=[y],
            quantile=quantile,
            anc_allele_available=anc_allele_available,
        )

        assert results["name"] == "Q"
        assert np.isclose(
            results["value"], expected_result
        ), f"Expected {expected_result}, got {results['value']}"
        assert np.array_equal(results["cdd_pos"], expected_positions)
|
73
|
+
|
74
|
+
|
75
|
+
def test_QStatistic_compute_no_match():
    """Q is NaN and no candidate positions are reported when no site matches."""
    ref_gts = np.array([[0, 0, 1], [0, 0, 0]])
    tgt_gts = np.array([[0, 1, 1], [1, 1, 1]])
    src_gts = np.array([[1, 1, 1], [1, 1, 1]])
    pos = np.array([0, 1])
    # The w threshold filters the *reference* frequency: only site 1
    # (frequency 0.0) is below 0.3. Every source frequency is 1.0, so no
    # site also satisfies the y == 0.0 condition.
    w, y, quantile = 0.3, ("=", 0.0), 0.95

    expected_positions = np.array([])

    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = q_stat.compute(
        pos=pos,
        w=w,
        y_list=[y],
        quantile=quantile,
        anc_allele_available=False,
    )

    assert np.isnan(results["value"]), f"Expected NaN, got {results['value']}"
    assert np.array_equal(results["cdd_pos"], expected_positions)
|
109
|
+
|
110
|
+
|
111
|
+
def test_QStatistic_compute_different_quantile():
    """Q evaluated at the 50% quantile (median) instead of the default 95%."""
    ref_gts = np.array([[0, 0, 1], [1, 0, 0], [0, 0, 1]])
    tgt_gts = np.array([[0, 1, 1], [1, 1, 1], [1, 1, 1]])
    src_gts = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1]])
    pos = np.array([0, 1, 2])
    w, y, quantile = 0.5, ("=", 1.0), 0.5

    # Every site has reference frequency 1/3 < w, but only sites 1 and 2 have
    # source frequency 1.0. Their target frequencies are [1.0, 1.0], whose
    # median is 1.0.
    expected_result = 1.0
    expected_positions = np.array([1, 2])

    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = q_stat.compute(
        pos=pos,
        w=w,
        y_list=[y],
        quantile=quantile,
        anc_allele_available=False,
    )

    assert np.isclose(
        results["value"], expected_result
    ), f"Expected {expected_result}, got {results['value']}"
    assert np.array_equal(results["cdd_pos"], expected_positions)
|
146
|
+
|
147
|
+
|
148
|
+
def test_QStatistic_compute_edge_case():
    """Q with a permissive w (0.95) where site 1 is the only reported candidate."""
    reference = np.array([[0, 0, 1], [0, 0, 0], [1, 1, 1]])
    target = np.array([[0, 1, 1], [1, 1, 1], [0, 0, 0]])
    source = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1]])
    sites = np.array([0, 1, 2])

    # NOTE(review): the only reported site (1) has target frequency 1.0, yet
    # the expected Q equals np.quantile([1/3, 1.0], 0.95). The extra 1/3
    # presumably comes from site 0 after allele flipping when
    # anc_allele_available=False -- confirm against QStatistic.
    expected_result = 0.9666666666666667
    wanted_sites = np.array([1])

    statistic = QStatistic(
        ref_gts=reference,
        tgt_gts=target,
        src_gts_list=[source],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = statistic.compute(
        pos=sites,
        w=0.95,
        y_list=[("=", 1.0)],
        quantile=0.95,
        anc_allele_available=False,
    )

    value_ok = np.isclose(results["value"], expected_result)
    assert value_ok, f"Expected {expected_result}, got {results['value']}"
    assert np.array_equal(results["cdd_pos"], wanted_sites)
|
181
|
+
|
182
|
+
|
183
|
+
def test_QStatistic_compute_with_two_sources():
    """Q is NaN when no site satisfies the conditions of both sources at once."""
    ref_gts = np.array([[1, 1, 0], [0, 1, 1], [1, 1, 1], [0, 0, 1]])
    tgt_gts = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [1, 1, 1]])
    src_gts1 = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 1]])
    src_gts2 = np.array([[1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])
    pos = np.array([0, 1, 2, 3])
    w, y_list, quantile = 0.5, [("=", 1), ("=", 1)], 0.95

    # Only site 3 has reference frequency (1/3) below w, but source 1 has
    # frequency 1/3 there, so no site passes every filter: Q is NaN and no
    # candidate positions are returned.
    expected_positions = np.array([])

    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts1, src_gts2],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1, 1],
    )
    results = q_stat.compute(
        pos=pos,
        w=w,
        y_list=y_list,
        quantile=quantile,
        anc_allele_available=False,
    )

    assert np.isnan(results["value"]), f"Expected NaN, got {results['value']}"
    assert np.array_equal(results["cdd_pos"], expected_positions)
|
214
|
+
|
215
|
+
|
216
|
+
def test_QStatistic_compute_with_mixed_ploidy():
    """Q is NaN when diploid ref/target and tetraploid sources yield no match.

    With source ploidy 4 and three samples of 0/1 calls, a source frequency
    can be at most 3/12, so the y == 1 condition is never satisfied.
    """
    genotypes = {
        "ref": np.array([[1, 1, 0], [0, 1, 1], [1, 1, 1], [0, 0, 1]]),
        "tgt": np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [1, 1, 1]]),
        "src1": np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 1]]),
        "src2": np.array([[1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]]),
    }

    statistic = QStatistic(
        ref_gts=genotypes["ref"],
        tgt_gts=genotypes["tgt"],
        src_gts_list=[genotypes["src1"], genotypes["src2"]],
        ref_ploidy=2,
        tgt_ploidy=2,
        src_ploidy_list=[4, 4],
    )
    results = statistic.compute(
        pos=np.array([0, 1, 2, 3]),
        w=0.5,
        y_list=[("=", 1), ("=", 1)],
        quantile=0.95,
        anc_allele_available=False,
    )

    assert np.isnan(results["value"]), f"Expected NaN, got {results['value']}"
    assert np.array_equal(results["cdd_pos"], np.array([]))
|
244
|
+
|
245
|
+
|
246
|
+
def test_QStatistic_compute_with_missing_keys():
    """QStatistic.compute raises ValueError when the required y_list is omitted."""
    ref_gts = np.array([[1, 1, 0], [0, 1, 1], [1, 1, 1], [0, 0, 1]])
    tgt_gts = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [1, 1, 1]])
    src_gts1 = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 1]])
    src_gts2 = np.array([[1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])
    pos = np.array([0, 1, 2, 3])
    w, quantile = 0.5, 0.95

    # Build the statistic outside pytest.raises: a ValueError raised by the
    # constructor must not be mistaken for the compute() validation under test.
    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts1, src_gts2],
        ref_ploidy=2,
        tgt_ploidy=2,
        src_ploidy_list=[4, 4],
    )

    with pytest.raises(ValueError):
        q_stat.compute(
            pos=pos,
            w=w,
            quantile=quantile,
            anc_allele_available=False,
        )
|
@@ -0,0 +1,354 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import pytest
|
22
|
+
import numpy as np
|
23
|
+
from sai.stats import calc_freq
|
24
|
+
from sai.stats import compute_matching_loci
|
25
|
+
from sai.stats import calc_four_pops_freq
|
26
|
+
from sai.stats import calc_pattern_sum
|
27
|
+
|
28
|
+
|
29
|
+
def test_phased_data():
    """calc_freq returns the per-site mean of 0/1 haplotype calls at ploidy 1."""
    haplotypes = np.array([[1, 0, 0, 1], [0, 0, 0, 0], [1, 1, 1, 1]])
    observed = calc_freq(haplotypes, ploidy=1)
    np.testing.assert_array_almost_equal(
        observed,
        np.array([0.5, 0.0, 1.0]),
        decimal=6,
        err_msg="Phased data test failed.",
    )
|
37
|
+
|
38
|
+
|
39
|
+
def test_unphased_diploid_data():
    """calc_freq divides summed diploid dosages (0..2) by 2 * n_samples."""
    dosages = np.array([[1, 1], [0, 0], [2, 2]])
    wanted = np.array([0.5, 0.0, 1.0])  # 2/4, 0/4, 4/4
    np.testing.assert_array_almost_equal(
        calc_freq(dosages, ploidy=2),
        wanted,
        decimal=6,
        err_msg="Unphased diploid data test failed.",
    )
|
50
|
+
|
51
|
+
|
52
|
+
def test_unphased_triploid_data():
    """calc_freq divides summed triploid dosages (0..3) by 3 * n_samples."""
    gts = np.array([[1, 2, 3], [0, 0, 0], [3, 3, 3]])
    # Site 0 sums to 6 of 9 possible alleles. Using the exact fraction instead
    # of the rounded 0.6667 lets the check tighten from 4 to 6 decimals.
    expected_frequency = np.array([2 / 3, 0.0, 1.0])
    result = calc_freq(gts, ploidy=3)
    np.testing.assert_array_almost_equal(
        result,
        expected_frequency,
        decimal=6,
        err_msg="Unphased triploid data test failed.",
    )
|
63
|
+
|
64
|
+
|
65
|
+
def test_unphased_tetraploid_data():
    """calc_freq divides summed tetraploid dosages (0..4) by 4 * n_samples."""
    dosages = np.array([[2, 2, 2, 2], [1, 3, 0, 4], [0, 0, 0, 0]])
    # Sites 0 and 1 both sum to 8 of 16 possible alleles; site 2 sums to 0.
    wanted = np.array([0.5, 0.5, 0.0])
    observed = calc_freq(dosages, ploidy=4)
    np.testing.assert_array_almost_equal(
        observed,
        wanted,
        decimal=6,
        err_msg="Unphased tetraploid data test failed.",
    )
|
76
|
+
|
77
|
+
|
78
|
+
def test_invalid_ploidy():
    """calc_freq rejects a missing, non-integer, or negative ploidy with ValueError."""
    dosages = np.array([[1, 2, 3], [0, 0, 0], [3, 3, 3]])

    for bad_ploidy in (None, 9.9, -100):
        with pytest.raises(ValueError):
            calc_freq(dosages, ploidy=bad_ploidy)
|
89
|
+
|
90
|
+
|
91
|
+
def test_compute_matching_loci():
    """Check compute_matching_loci outputs for every operator and its input validation."""
    ref_gts = np.array([[0, 1, 0], [1, 1, 0], [0, 0, 1]])
    tgt_gts = np.array([[1, 1, 0], [0, 1, 1], [1, 1, 1]])
    src_gts_list = [
        np.array([[0, 0, 1], [1, 1, 0], [0, 1, 1]]),  # src1
        np.array([[1, 1, 0], [1, 0, 0], [1, 1, 0]]),  # src2
    ]
    ploidy = [2, 2, 2]
    anc_allele_available = False

    def run(w, y_list):
        # Only w and y_list vary between the calls in this test.
        return compute_matching_loci(
            ref_gts,
            tgt_gts,
            src_gts_list,
            w,
            y_list,
            ploidy,
            anc_allele_available,
        )

    # Every supported comparison operator, applied to both sources.
    conditions = [("=", 0.5), ("<", 0.4), (">", 0.3), ("<=", 0.6), (">=", 0.2)]
    for y_condition in conditions:
        ref_freq, tgt_freq, condition = run(0.5, [y_condition, y_condition])

        assert ref_freq.shape == (3,)
        assert tgt_freq.shape == (3,)
        assert condition.shape == (3,)
        assert np.all((ref_freq >= 0) & (ref_freq <= 1))
        assert np.all((tgt_freq >= 0) & (tgt_freq <= 1))
        # A real boolean mask; the old `== True / == False` check also
        # accepted integer 0/1 arrays.
        assert condition.dtype == bool

    # A y_list with one entry per source, so the errors below can only come
    # from the value under test, never from a length mismatch.
    valid_y = [("=", 0.5), ("=", 0.5)]

    # Invalid w values.
    for bad_w in (-0.1, 1.1):
        with pytest.raises(
            ValueError, match=r"Parameters w must be within the range \[0, 1\]."
        ):
            run(bad_w, valid_y)

    # Invalid y values.
    for bad_y in (-0.1, 1.1):
        with pytest.raises(ValueError, match="Invalid value in y_list"):
            run(0.5, [("=", bad_y), ("=", 0.5)])

    # Invalid operator.
    with pytest.raises(ValueError, match="Invalid operator in y_list"):
        run(0.5, [("invalid", 0.5), ("=", 0.5)])

    # Mismatched src_gts_list and y_list lengths.
    with pytest.raises(
        ValueError, match="The length of src_gts_list and y_list must match"
    ):
        run(0.5, [("=", 0.5)])
|
202
|
+
|
203
|
+
|
204
|
+
def test_calc_four_pops_freq_basic():
    """Without explicit ploidy arguments, each frequency is the per-site mean of 0/1 calls."""
    ref_freq, tgt_freq, src_freq, out_freq = calc_four_pops_freq(
        np.array([[0, 1], [1, 1]]),
        np.array([[1, 0], [0, 0]]),
        np.array([[1, 1], [1, 0]]),
        np.array([[0, 0], [0, 1]]),
    )

    checks = [
        (ref_freq, [0.5, 1.0]),
        (tgt_freq, [0.5, 0.0]),
        (src_freq, [1.0, 0.5]),
        (out_freq, [0.0, 0.5]),
    ]
    for observed, wanted in checks:
        np.testing.assert_array_almost_equal(observed, np.array(wanted))
|
221
|
+
|
222
|
+
|
223
|
+
def test_calc_four_pops_freq_no_outgroup():
    """With out_gts=None, the outgroup frequency is reported as zeros."""
    freqs = calc_four_pops_freq(
        np.array([[0, 1]]),
        np.array([[1, 0]]),
        np.array([[1, 1]]),
        out_gts=None,
    )
    ref_freq, tgt_freq, src_freq, out_freq = freqs

    np.testing.assert_array_equal(ref_freq, np.array([0.5]))
    np.testing.assert_array_equal(tgt_freq, np.array([0.5]))
    np.testing.assert_array_equal(src_freq, np.array([1.0]))
    np.testing.assert_array_equal(out_freq, np.array([0.0]))  # defaults to 0s
|
234
|
+
|
235
|
+
|
236
|
+
def test_calc_four_pops_freq_diploid():
    """At ploidy 2 for every population, frequency = dosage sum / (2 * N)."""
    ploidies = dict(ref_ploidy=2, tgt_ploidy=2, src_ploidy=2, out_ploidy=2)
    ref_freq, tgt_freq, src_freq, out_freq = calc_four_pops_freq(
        ref_gts=np.array([[0, 2]]),
        tgt_gts=np.array([[1, 1]]),
        src_gts=np.array([[2, 0]]),
        out_gts=np.array([[1, 1]]),
        **ploidies,
    )

    # Every population sums to 2 over 2 diploid samples: 2 / 4 = 0.5.
    half = np.array([0.5])
    np.testing.assert_array_equal(ref_freq, half)
    np.testing.assert_array_equal(tgt_freq, half)
    np.testing.assert_array_equal(src_freq, half)
    np.testing.assert_array_equal(out_freq, half)
|
260
|
+
|
261
|
+
|
262
|
+
def test_calc_four_pops_freq_mixed_ploidy():
    """Each population's frequency is normalised by its own ploidy."""
    result = calc_four_pops_freq(
        ref_gts=np.array([[0, 2]]),
        tgt_gts=np.array([[1, 1]]),
        src_gts=np.array([[2, 0]]),
        out_gts=np.array([[1, 1]]),
        ref_ploidy=2,
        tgt_ploidy=1,
        src_ploidy=4,
        out_ploidy=4,
    )
    ref_freq, tgt_freq, src_freq, out_freq = result

    np.testing.assert_array_equal(ref_freq, np.array([0.5]))  # 2 / (2 * 2)
    np.testing.assert_array_equal(tgt_freq, np.array([1]))  # 2 / (1 * 2)
    np.testing.assert_array_equal(src_freq, np.array([0.25]))  # 2 / (4 * 2)
    np.testing.assert_array_equal(out_freq, np.array([0.25]))  # 2 / (4 * 2)
|
283
|
+
|
284
|
+
|
285
|
+
def test_calc_pattern_sum_abba():
    """'abba' sums (1-ref) * tgt * src * (1-out) over sites."""
    ref, tgt = np.array([0.1, 0.8]), np.array([0.9, 0.2])
    src, out = np.array([0.5, 0.5]), np.array([0.0, 1.0])

    # site 0: 0.9 * 0.9 * 0.5 * 1.0 = 0.405
    # site 1: 0.2 * 0.2 * 0.5 * 0.0 = 0.0
    total = calc_pattern_sum(ref, tgt, src, out, "abba")
    assert np.isclose(total, 0.405)
|
298
|
+
|
299
|
+
|
300
|
+
def test_calc_pattern_sum_baba():
    """'baba' sums ref * (1-tgt) * src * (1-out) over sites."""
    ref, tgt = np.array([0.1, 0.8]), np.array([0.9, 0.2])
    src, out = np.array([0.5, 0.5]), np.array([0.0, 1.0])

    # site 0: 0.1 * 0.1 * 0.5 * 1.0 = 0.005
    # site 1: 0.8 * 0.8 * 0.5 * 0.0 = 0.0
    total = calc_pattern_sum(ref, tgt, src, out, "baba")
    assert np.isclose(total, 0.005)
|
313
|
+
|
314
|
+
|
315
|
+
def test_calc_pattern_sum_baaa():
    """'baaa' sums ref * (1-tgt) * (1-src) * (1-out) over sites."""
    ref, tgt = np.array([0.1, 0.8]), np.array([0.9, 0.2])
    src, out = np.array([0.5, 0.5]), np.array([0.0, 1.0])

    # site 0: 0.1 * 0.1 * 0.5 * 1.0 = 0.005
    # site 1: 0.8 * 0.8 * 0.5 * 0.0 = 0.0
    total = calc_pattern_sum(ref, tgt, src, out, "baaa")
    assert np.isclose(total, 0.005)
|
328
|
+
|
329
|
+
|
330
|
+
def test_calc_pattern_sum_abaa():
    """'abaa' sums (1-ref) * tgt * (1-src) * (1-out) over sites."""
    ref, tgt = np.array([0.1, 0.8]), np.array([0.9, 0.2])
    src, out = np.array([0.5, 0.5]), np.array([0.0, 1.0])

    # site 0: 0.9 * 0.9 * 0.5 * 1.0 = 0.405
    # site 1: 0.2 * 0.2 * 0.5 * 0.0 = 0.0
    total = calc_pattern_sum(ref, tgt, src, out, "abaa")
    assert np.isclose(total, 0.405)
|
343
|
+
|
344
|
+
|
345
|
+
def test_invalid_pattern_length():
    """A pattern that is not four characters long is rejected."""
    freqs = np.array([0.1, 0.2])
    with pytest.raises(ValueError, match="four-character"):
        calc_pattern_sum(freqs, freqs, freqs, freqs, "ab")
|
349
|
+
|
350
|
+
|
351
|
+
def test_invalid_pattern_char():
    """A pattern containing a character other than 'a'/'b' is rejected."""
    freqs = np.array([0.1, 0.2])
    with pytest.raises(ValueError, match="Invalid character"):
        calc_pattern_sum(freqs, freqs, freqs, freqs, "abxa")
|