sai-pg 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. sai/__init__.py +2 -0
  2. sai/__main__.py +6 -3
  3. sai/configs/__init__.py +24 -0
  4. sai/configs/global_config.py +83 -0
  5. sai/configs/ploidy_config.py +94 -0
  6. sai/configs/pop_config.py +82 -0
  7. sai/configs/stat_config.py +220 -0
  8. sai/{utils/generators → generators}/chunk_generator.py +1 -1
  9. sai/{utils/generators → generators}/window_generator.py +81 -37
  10. sai/{utils/multiprocessing → multiprocessing}/mp_manager.py +2 -2
  11. sai/{utils/multiprocessing → multiprocessing}/mp_pool.py +2 -2
  12. sai/parsers/outlier_parser.py +4 -3
  13. sai/parsers/score_parser.py +8 -119
  14. sai/{utils/preprocessors → preprocessors}/chunk_preprocessor.py +21 -15
  15. sai/preprocessors/feature_preprocessor.py +236 -0
  16. sai/registries/__init__.py +22 -0
  17. sai/registries/generic_registry.py +89 -0
  18. sai/registries/stat_registry.py +30 -0
  19. sai/sai.py +124 -220
  20. sai/stats/__init__.py +11 -0
  21. sai/stats/danc_statistic.py +83 -0
  22. sai/stats/dd_statistic.py +77 -0
  23. sai/stats/df_statistic.py +84 -0
  24. sai/stats/dplus_statistic.py +86 -0
  25. sai/stats/fd_statistic.py +92 -0
  26. sai/stats/generic_statistic.py +93 -0
  27. sai/stats/q_statistic.py +104 -0
  28. sai/stats/stat_utils.py +259 -0
  29. sai/stats/u_statistic.py +99 -0
  30. sai/utils/utils.py +213 -142
  31. {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/METADATA +3 -14
  32. sai_pg-1.1.0.dist-info/RECORD +70 -0
  33. {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/WHEEL +1 -1
  34. sai_pg-1.1.0.dist-info/top_level.txt +2 -0
  35. tests/configs/test_global_config.py +163 -0
  36. tests/configs/test_ploidy_config.py +93 -0
  37. tests/configs/test_pop_config.py +90 -0
  38. tests/configs/test_stat_config.py +171 -0
  39. tests/generators/test_chunk_generator.py +51 -0
  40. tests/generators/test_window_generator.py +164 -0
  41. tests/multiprocessing/test_mp_manager.py +92 -0
  42. tests/multiprocessing/test_mp_pool.py +79 -0
  43. tests/parsers/test_argument_validation.py +133 -0
  44. tests/parsers/test_outlier_parser.py +53 -0
  45. tests/parsers/test_score_parser.py +63 -0
  46. tests/preprocessors/test_chunk_preprocessor.py +79 -0
  47. tests/preprocessors/test_feature_preprocessor.py +223 -0
  48. tests/registries/test_registries.py +74 -0
  49. tests/stats/test_danc_statistic.py +51 -0
  50. tests/stats/test_dd_statistic.py +45 -0
  51. tests/stats/test_df_statistic.py +73 -0
  52. tests/stats/test_dplus_statistic.py +79 -0
  53. tests/stats/test_fd_statistic.py +68 -0
  54. tests/stats/test_q_statistic.py +268 -0
  55. tests/stats/test_stat_utils.py +354 -0
  56. tests/stats/test_u_statistic.py +233 -0
  57. tests/test___main__.py +51 -0
  58. tests/test_sai.py +102 -0
  59. tests/utils/test_utils.py +511 -0
  60. sai/parsers/plot_parser.py +0 -152
  61. sai/stats/features.py +0 -302
  62. sai/utils/preprocessors/feature_preprocessor.py +0 -211
  63. sai_pg-1.0.1.dist-info/RECORD +0 -30
  64. sai_pg-1.0.1.dist-info/top_level.txt +0 -1
  65. /sai/{utils/generators → generators}/__init__.py +0 -0
  66. /sai/{utils/generators → generators}/data_generator.py +0 -0
  67. /sai/{utils/multiprocessing → multiprocessing}/__init__.py +0 -0
  68. /sai/{utils/preprocessors → preprocessors}/__init__.py +0 -0
  69. /sai/{utils/preprocessors → preprocessors}/data_preprocessor.py +0 -0
  70. {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/entry_points.txt +0 -0
  71. {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,268 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import pytest
22
+ import numpy as np
23
+ from sai.stats import QStatistic
24
+
25
+
26
def test_QStatistic_compute_basic():
    """Q equals the target frequency at the single matching site, in both allele modes.

    With haploid genotypes (ploidy 1) the target frequencies are
    [2/3, 1/3, 1]. Only the first site meets the matching criteria, so Q is
    2/3 and site 0 is the only candidate position, whether or not ancestral
    alleles are available.
    """
    ref_gts = np.array([[0, 0, 1], [0, 0, 0], [1, 1, 1]])
    tgt_gts = np.array([[0, 1, 1], [0, 0, 1], [1, 1, 1]])
    src_gts = np.array([[1, 1, 1], [0, 1, 1], [1, 1, 1]])
    pos = np.array([0, 1, 2])
    w, y, quantile = 0.5, ("=", 1.0), 0.95

    # Exact value instead of the rounded literal 0.66667, which only passed
    # np.isclose by a narrow margin of its default relative tolerance.
    expected_result = 2 / 3
    expected_positions = np.array([0])

    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )

    for anc_allele_available in (False, True):
        results = q_stat.compute(
            pos=pos,
            w=w,
            y_list=[y],
            quantile=quantile,
            anc_allele_available=anc_allele_available,
        )

        assert results["name"] == "Q"
        assert np.isclose(
            results["value"], expected_result
        ), f"Expected {expected_result}, got {results['value']}"
        assert np.array_equal(results["cdd_pos"], expected_positions)
73
+
74
+
75
def test_QStatistic_compute_no_match():
    """Q is NaN and no candidate positions are returned when no site matches.

    Haploid frequencies: ref = [1/3, 0], tgt = [2/3, 1], src = [1, 1].
    """
    ref_gts = np.array([[0, 0, 1], [0, 0, 0]])
    tgt_gts = np.array([[0, 1, 1], [1, 1, 1]])
    src_gts = np.array([[1, 1, 1], [1, 1, 1]])
    pos = np.array([0, 1])
    # The source frequency is 1.0 at both sites, so it never equals y = 0.0;
    # that alone excludes every site in the orientation given.
    # NOTE(review): with anc_allele_available=False the other allele is
    # presumably checked as well (source frequency 0.0 there); its reference
    # frequencies, 2/3 and 1, are not below w = 0.3, so nothing matches.
    w, y, quantile = 0.3, ("=", 0.0), 0.95

    expected_positions = np.array([])

    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = q_stat.compute(
        pos=pos,
        w=w,
        y_list=[y],
        quantile=quantile,
        anc_allele_available=False,
    )

    assert np.isnan(results["value"]), f"Expected NaN, got {results['value']}"
    assert np.array_equal(results["cdd_pos"], expected_positions)
109
+
110
+
111
def test_QStatistic_compute_different_quantile():
    """A 50% quantile (median) over the matching target frequencies.

    Haploid frequencies: ref = [1/3, 1/3, 1/3], tgt = [2/3, 1, 1],
    src = [0, 1, 1]. Sites 1 and 2 match, both with target frequency 1.0.
    """
    ref_gts = np.array([[0, 0, 1], [1, 0, 0], [0, 0, 1]])
    tgt_gts = np.array([[0, 1, 1], [1, 1, 1], [1, 1, 1]])
    src_gts = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1]])
    pos = np.array([0, 1, 2])
    w, y, quantile = 0.5, ("=", 1.0), 0.5

    # Median of the two matching target frequencies [1.0, 1.0]. Two values,
    # not three: only sites 1 and 2 match, as expected_positions states.
    expected_result = 1.0
    expected_positions = np.array([1, 2])

    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = q_stat.compute(
        pos=pos,
        w=w,
        y_list=[y],
        quantile=quantile,
        anc_allele_available=False,
    )

    assert np.isclose(
        results["value"], expected_result
    ), f"Expected {expected_result}, got {results['value']}"
    assert np.array_equal(results["cdd_pos"], expected_positions)
146
+
147
+
148
def test_QStatistic_compute_edge_case():
    """High w threshold where only site 1 is reported as a candidate position.

    Haploid frequencies: ref = [1/3, 0, 1], tgt = [2/3, 1, 0], src = [0, 1, 1].
    """
    ref_gts = np.array([[0, 0, 1], [0, 0, 0], [1, 1, 1]])
    tgt_gts = np.array([[0, 1, 1], [1, 1, 1], [0, 0, 0]])
    src_gts = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1]])
    pos = np.array([0, 1, 2])
    w, y, quantile = 0.95, ("=", 1.0), 0.95

    # If a single matching frequency (1.0 at site 1) fed the quantile, Q would
    # be 1.0. Instead the expected value equals
    # np.quantile([1/3, 1.0], 0.95) = 1/3 + 0.95 * (1 - 1/3), so two target
    # frequencies enter the quantile even though only one position is reported.
    # NOTE(review): the 1/3 presumably comes from site 0 evaluated for the
    # other allele (1 - 2/3) because anc_allele_available=False; confirm.
    expected_result = 0.9666666666666667
    expected_positions = np.array([1])

    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = q_stat.compute(
        pos=pos,
        w=w,
        y_list=[y],
        quantile=quantile,
        anc_allele_available=False,
    )

    assert np.isclose(
        results["value"], expected_result
    ), f"Expected {expected_result}, got {results['value']}"
    assert np.array_equal(results["cdd_pos"], expected_positions)
181
+
182
+
183
def test_QStatistic_compute_with_two_sources():
    """Q is NaN when no site satisfies the matching conditions for two sources.

    Haploid frequencies: ref = [2/3, 2/3, 1, 1/3], src1 = [0, 1, 1, 1/3],
    src2 = [1, 1, 0, 1].
    """
    ref_gts = np.array([[1, 1, 0], [0, 1, 1], [1, 1, 1], [0, 0, 1]])
    tgt_gts = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [1, 1, 1]])
    src_gts1 = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 1]])
    src_gts2 = np.array([[1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])
    pos = np.array([0, 1, 2, 3])
    w, y_list, quantile = 0.5, [("=", 1), ("=", 1)], 0.95

    # No site satisfies every matching condition, so there is nothing to take
    # a quantile of: Q is NaN and no candidate positions are reported. (The
    # earlier comment promised a 95% quantile, contradicting the NaN check.)
    # Both sources are at 1.0 only at site 1, where ref = 2/3 is not below
    # w = 0.5, presumably the threshold applied to the reference.
    expected_positions = np.array([])

    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts1, src_gts2],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1, 1],
    )
    results = q_stat.compute(
        pos=pos,
        w=w,
        y_list=y_list,
        quantile=quantile,
        anc_allele_available=False,
    )

    assert np.isnan(results["value"]), f"Expected NaN, got {results['value']}"
    assert np.array_equal(results["cdd_pos"], expected_positions)
214
+
215
+
216
def test_QStatistic_compute_with_mixed_ploidy():
    """Q is NaN with diploid ref/target and tetraploid sources when nothing matches.

    The genotype matrices are the same as in the two-source test; only the
    ploidies differ. No site is expected to match, so no position is reported.
    """
    populations = {
        "ref_gts": np.array([[1, 1, 0], [0, 1, 1], [1, 1, 1], [0, 0, 1]]),
        "tgt_gts": np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [1, 1, 1]]),
        "src_gts_list": [
            np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 1]]),
            np.array([[1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]]),
        ],
    }
    ploidies = {"ref_ploidy": 2, "tgt_ploidy": 2, "src_ploidy_list": [4, 4]}
    statistic = QStatistic(**populations, **ploidies)

    outcome = statistic.compute(
        pos=np.array([0, 1, 2, 3]),
        w=0.5,
        y_list=[("=", 1), ("=", 1)],
        quantile=0.95,
        anc_allele_available=False,
    )

    assert np.isnan(outcome["value"]), f"Expected NaN, got {outcome['value']}"
    assert np.array_equal(outcome["cdd_pos"], np.array([]))
244
+
245
+
246
def test_QStatistic_compute_with_missing_keys():
    """compute() raises ValueError when y_list is not supplied."""
    ref_gts = np.array([[1, 1, 0], [0, 1, 1], [1, 1, 1], [0, 0, 1]])
    tgt_gts = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [1, 1, 1]])
    src_gts1 = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 1]])
    src_gts2 = np.array([[1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])
    pos = np.array([0, 1, 2, 3])
    w, quantile = 0.5, 0.95

    # Build the statistic outside pytest.raises. Otherwise a ValueError from
    # the constructor would let this test pass without compute() ever running.
    q_stat = QStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts1, src_gts2],
        ref_ploidy=2,
        tgt_ploidy=2,
        src_ploidy_list=[4, 4],
    )

    # y_list is deliberately omitted.
    with pytest.raises(ValueError):
        q_stat.compute(
            pos=pos,
            w=w,
            quantile=quantile,
            anc_allele_available=False,
        )
@@ -0,0 +1,354 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import pytest
22
+ import numpy as np
23
+ from sai.stats import calc_freq
24
+ from sai.stats import compute_matching_loci
25
+ from sai.stats import calc_four_pops_freq
26
+ from sai.stats import calc_pattern_sum
27
+
28
+
29
def test_phased_data():
    """Haploid (phased) genotypes: each site's frequency is its share of 1s."""
    haplotypes = np.array([[1, 0, 0, 1], [0, 0, 0, 0], [1, 1, 1, 1]])
    observed = calc_freq(haplotypes, ploidy=1)
    np.testing.assert_array_almost_equal(
        observed,
        np.array([0.5, 0.0, 1.0]),
        decimal=6,
        err_msg="Phased data test failed.",
    )
37
+
38
+
39
def test_unphased_diploid_data():
    """Diploid dosages: allele count divided by 2 * number of samples."""
    dosages = np.array([[1, 1], [0, 0], [2, 2]])
    # Rows sum to 2, 0 and 4 out of 4 alleles.
    want = np.array([0.5, 0.0, 1.0])
    got = calc_freq(dosages, ploidy=2)
    np.testing.assert_array_almost_equal(
        got, want, decimal=6, err_msg="Unphased diploid data test failed."
    )
50
+
51
+
52
def test_unphased_triploid_data():
    """Triploid dosages: allele count divided by 3 * number of samples."""
    gts = np.array([[1, 2, 3], [0, 0, 0], [3, 3, 3]])
    # First site: (1 + 2 + 3) / (3 * 3) = 2/3. Spelled exactly, rather than as
    # the rounded 0.6667, so the same decimal=6 tolerance as the sibling
    # ploidy tests applies.
    expected_frequency = np.array([2 / 3, 0.0, 1.0])
    result = calc_freq(gts, ploidy=3)
    np.testing.assert_array_almost_equal(
        result,
        expected_frequency,
        decimal=6,
        err_msg="Unphased triploid data test failed.",
    )
63
+
64
+
65
def test_unphased_tetraploid_data():
    """Tetraploid dosages: allele count divided by 4 * number of samples."""
    dosages = np.array([[2, 2, 2, 2], [1, 3, 0, 4], [0, 0, 0, 0]])
    # Row sums 8, 8 and 0 over 16 alleles each.
    frequencies = calc_freq(dosages, ploidy=4)
    np.testing.assert_array_almost_equal(
        frequencies,
        np.array([0.5, 0.5, 0.0]),
        decimal=6,
        err_msg="Unphased tetraploid data test failed.",
    )
76
+
77
+
78
def test_invalid_ploidy():
    """calc_freq rejects a ploidy that is None, non-integer, or negative."""
    gts = np.array([[1, 2, 3], [0, 0, 0], [3, 3, 3]])

    for bad_ploidy in (None, 9.9, -100):
        with pytest.raises(ValueError):
            calc_freq(gts, ploidy=bad_ploidy)
89
+
90
+
91
def test_compute_matching_loci():
    """compute_matching_loci returns well-formed outputs and validates its inputs.

    For every supported comparison operator, the returned reference and
    target frequencies and the site mask must have one entry per site, and
    the frequencies must lie in [0, 1]. Invalid w, y values, operators and
    mismatched list lengths must each raise ValueError with a specific message.
    """
    ref_gts = np.array([[0, 1, 0], [1, 1, 0], [0, 0, 1]])
    tgt_gts = np.array([[1, 1, 0], [0, 1, 1], [1, 1, 1]])
    src_gts_list = [
        np.array([[0, 0, 1], [1, 1, 0], [0, 1, 1]]),  # src1
        np.array([[1, 1, 0], [1, 0, 0], [1, 1, 0]]),  # src2
    ]

    # One condition per supported operator.
    conditions = [("=", 0.5), ("<", 0.4), (">", 0.3), ("<=", 0.6), (">=", 0.2)]
    ploidy = [2, 2, 2]
    anc_allele_available = False

    for y_condition in conditions:
        # Apply the same condition to both sources.
        y_list = [y_condition, y_condition]

        ref_freq, tgt_freq, condition = compute_matching_loci(
            ref_gts,
            tgt_gts,
            src_gts_list,
            0.5,
            y_list,
            ploidy,
            anc_allele_available,
        )

        assert ref_freq.shape == (3,)
        assert tgt_freq.shape == (3,)
        assert condition.shape == (3,)
        assert np.all((ref_freq >= 0) & (ref_freq <= 1))
        assert np.all((tgt_freq >= 0) & (tgt_freq <= 1))
        # `condition == True | condition == False` also holds for any 0/1
        # numeric array; checking the dtype actually verifies a boolean mask.
        assert condition.dtype == np.bool_

    # A valid y_list with one entry per source, so each check below trips only
    # the error under test (previously the value leaked from the loop above).
    valid_y_list = [("=", 0.5), ("=", 0.5)]

    # Invalid w values.
    w_error = r"Parameters w must be within the range \[0, 1\]"
    for bad_w in (-0.1, 1.1):
        with pytest.raises(ValueError, match=w_error):
            compute_matching_loci(
                ref_gts,
                tgt_gts,
                src_gts_list,
                bad_w,
                valid_y_list,
                ploidy,
                anc_allele_available,
            )

    # Invalid y values. Two entries keep the length check satisfied, so only
    # the value validation can fire.
    for bad_y in (-0.1, 1.1):
        with pytest.raises(ValueError, match="Invalid value in y_list"):
            compute_matching_loci(
                ref_gts,
                tgt_gts,
                src_gts_list,
                0.5,
                [("=", bad_y), ("=", 0.5)],
                ploidy,
                anc_allele_available,
            )

    # Invalid operator, again with one entry per source.
    with pytest.raises(ValueError, match="Invalid operator in y_list"):
        compute_matching_loci(
            ref_gts,
            tgt_gts,
            src_gts_list,
            0.5,
            [("invalid", 0.5), ("=", 0.5)],
            ploidy,
            anc_allele_available,
        )

    # Mismatched src_gts_list and y_list lengths.
    with pytest.raises(
        ValueError, match="The length of src_gts_list and y_list must match"
    ):
        compute_matching_loci(
            ref_gts,
            tgt_gts,
            src_gts_list,
            0.5,
            [("=", 0.5)],
            ploidy,
            anc_allele_available,
        )
202
+
203
+
204
def test_calc_four_pops_freq_basic():
    """Without explicit ploidies, each population's frequency is its share of 1s per site."""
    ref, tgt, src, out = calc_four_pops_freq(
        np.array([[0, 1], [1, 1]]),
        np.array([[1, 0], [0, 0]]),
        np.array([[1, 1], [1, 0]]),
        np.array([[0, 0], [0, 1]]),
    )

    expected = {
        "ref": [0.5, 1.0],
        "tgt": [0.5, 0.0],
        "src": [1.0, 0.5],
        "out": [0.0, 0.5],
    }
    for observed, wanted in zip((ref, tgt, src, out), expected.values()):
        np.testing.assert_array_almost_equal(observed, np.array(wanted))
221
+
222
+
223
def test_calc_four_pops_freq_no_outgroup():
    """Passing out_gts=None yields an outgroup frequency of 0 at every site."""
    ref, tgt, src, out = calc_four_pops_freq(
        np.array([[0, 1]]),
        np.array([[1, 0]]),
        np.array([[1, 1]]),
        out_gts=None,
    )

    # The last expectation is the zero default for the missing outgroup.
    for observed, wanted in zip((ref, tgt, src, out), (0.5, 0.5, 1.0, 0.0)):
        np.testing.assert_array_equal(observed, np.array([wanted]))
234
+
235
+
236
def test_calc_four_pops_freq_diploid():
    """With ploidy 2 everywhere, frequency = allele count / (2 * samples)."""
    diploid = {"ref_ploidy": 2, "tgt_ploidy": 2, "src_ploidy": 2, "out_ploidy": 2}

    ref, tgt, src, out = calc_four_pops_freq(
        ref_gts=np.array([[0, 2]]),
        tgt_gts=np.array([[1, 1]]),
        src_gts=np.array([[2, 0]]),
        out_gts=np.array([[1, 1]]),
        **diploid,
    )

    # Every population carries 2 alleles out of 2 * 2 = 4, i.e. 0.5.
    for observed in (ref, tgt, src, out):
        np.testing.assert_array_equal(observed, np.array([0.5]))
260
+
261
+
262
def test_calc_four_pops_freq_mixed_ploidy():
    """Each population's frequency uses its own ploidy in the denominator."""
    ref, tgt, src, out = calc_four_pops_freq(
        ref_gts=np.array([[0, 2]]),
        tgt_gts=np.array([[1, 1]]),
        src_gts=np.array([[2, 0]]),
        out_gts=np.array([[1, 1]]),
        ref_ploidy=2,
        tgt_ploidy=1,
        src_ploidy=4,
        out_ploidy=4,
    )

    checks = [
        (ref, 0.5),  # (0 + 2) / (2 * 2)
        (tgt, 1),  # (1 + 1) / (1 * 2)
        (src, 0.25),  # (2 + 0) / (4 * 2)
        (out, 0.25),  # (1 + 1) / (4 * 2)
    ]
    for observed, wanted in checks:
        np.testing.assert_array_equal(observed, np.array([wanted]))
283
+
284
+
285
def test_calc_pattern_sum_abba():
    """'abba' sums (1 - ref) * tgt * src * (1 - out) over sites."""
    ref, tgt, src, out = (
        np.array([0.1, 0.8]),
        np.array([0.9, 0.2]),
        np.array([0.5, 0.5]),
        np.array([0.0, 1.0]),
    )

    # site 0: 0.9 * 0.9 * 0.5 * 1.0 = 0.405
    # site 1: 0.2 * 0.2 * 0.5 * 0.0 = 0.0  (out = 1.0 zeroes this site)
    total = calc_pattern_sum(ref, tgt, src, out, "abba")
    assert np.isclose(total, 0.405)
298
+
299
+
300
def test_calc_pattern_sum_baba():
    """'baba' sums ref * (1 - tgt) * src * (1 - out) over sites."""
    ref, tgt, src, out = (
        np.array([0.1, 0.8]),
        np.array([0.9, 0.2]),
        np.array([0.5, 0.5]),
        np.array([0.0, 1.0]),
    )

    # site 0: 0.1 * 0.1 * 0.5 * 1.0 = 0.005
    # site 1: 0.8 * 0.8 * 0.5 * 0.0 = 0.0  (out = 1.0 zeroes this site)
    total = calc_pattern_sum(ref, tgt, src, out, "baba")
    assert np.isclose(total, 0.005)
313
+
314
+
315
def test_calc_pattern_sum_baaa():
    """'baaa' sums ref * (1 - tgt) * (1 - src) * (1 - out) over sites."""
    ref, tgt, src, out = (
        np.array([0.1, 0.8]),
        np.array([0.9, 0.2]),
        np.array([0.5, 0.5]),
        np.array([0.0, 1.0]),
    )

    # site 0: 0.1 * 0.1 * 0.5 * 1.0 = 0.005
    # site 1: 0.8 * 0.8 * 0.5 * 0.0 = 0.0  (out = 1.0 zeroes this site)
    total = calc_pattern_sum(ref, tgt, src, out, "baaa")
    assert np.isclose(total, 0.005)
328
+
329
+
330
def test_calc_pattern_sum_abaa():
    """'abaa' sums (1 - ref) * tgt * (1 - src) * (1 - out) over sites."""
    ref, tgt, src, out = (
        np.array([0.1, 0.8]),
        np.array([0.9, 0.2]),
        np.array([0.5, 0.5]),
        np.array([0.0, 1.0]),
    )

    # site 0: 0.9 * 0.9 * 0.5 * 1.0 = 0.405
    # site 1: 0.2 * 0.2 * 0.5 * 0.0 = 0.0  (out = 1.0 zeroes this site)
    total = calc_pattern_sum(ref, tgt, src, out, "abaa")
    assert np.isclose(total, 0.405)
343
+
344
+
345
def test_invalid_pattern_length():
    """A pattern that is not four characters long is rejected."""
    freqs = np.array([0.1, 0.2])
    with pytest.raises(ValueError, match="four-character"):
        calc_pattern_sum(freqs, freqs, freqs, freqs, "ab")
349
+
350
+
351
def test_invalid_pattern_char():
    """A pattern containing a character other than 'a'/'b' is rejected."""
    freqs = np.array([0.1, 0.2])
    with pytest.raises(ValueError, match="Invalid character"):
        calc_pattern_sum(freqs, freqs, freqs, freqs, "abxa")