sai-pg 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. sai/__init__.py +2 -0
  2. sai/__main__.py +6 -3
  3. sai/configs/__init__.py +24 -0
  4. sai/configs/global_config.py +83 -0
  5. sai/configs/ploidy_config.py +94 -0
  6. sai/configs/pop_config.py +82 -0
  7. sai/configs/stat_config.py +220 -0
  8. sai/{utils/generators → generators}/chunk_generator.py +2 -8
  9. sai/{utils/generators → generators}/window_generator.py +82 -37
  10. sai/{utils/multiprocessing → multiprocessing}/mp_manager.py +2 -2
  11. sai/{utils/multiprocessing → multiprocessing}/mp_pool.py +2 -2
  12. sai/parsers/outlier_parser.py +4 -3
  13. sai/parsers/score_parser.py +8 -119
  14. sai/{utils/preprocessors → preprocessors}/chunk_preprocessor.py +21 -15
  15. sai/preprocessors/feature_preprocessor.py +236 -0
  16. sai/registries/__init__.py +22 -0
  17. sai/registries/generic_registry.py +89 -0
  18. sai/registries/stat_registry.py +30 -0
  19. sai/sai.py +124 -220
  20. sai/stats/__init__.py +11 -0
  21. sai/stats/danc_statistic.py +83 -0
  22. sai/stats/dd_statistic.py +77 -0
  23. sai/stats/df_statistic.py +84 -0
  24. sai/stats/dplus_statistic.py +86 -0
  25. sai/stats/fd_statistic.py +92 -0
  26. sai/stats/generic_statistic.py +93 -0
  27. sai/stats/q_statistic.py +104 -0
  28. sai/stats/stat_utils.py +259 -0
  29. sai/stats/u_statistic.py +99 -0
  30. sai/utils/utils.py +220 -143
  31. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/METADATA +3 -14
  32. sai_pg-1.1.0.dist-info/RECORD +70 -0
  33. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/WHEEL +1 -1
  34. sai_pg-1.1.0.dist-info/top_level.txt +2 -0
  35. tests/configs/test_global_config.py +163 -0
  36. tests/configs/test_ploidy_config.py +93 -0
  37. tests/configs/test_pop_config.py +90 -0
  38. tests/configs/test_stat_config.py +171 -0
  39. tests/generators/test_chunk_generator.py +51 -0
  40. tests/generators/test_window_generator.py +164 -0
  41. tests/multiprocessing/test_mp_manager.py +92 -0
  42. tests/multiprocessing/test_mp_pool.py +79 -0
  43. tests/parsers/test_argument_validation.py +133 -0
  44. tests/parsers/test_outlier_parser.py +53 -0
  45. tests/parsers/test_score_parser.py +63 -0
  46. tests/preprocessors/test_chunk_preprocessor.py +79 -0
  47. tests/preprocessors/test_feature_preprocessor.py +223 -0
  48. tests/registries/test_registries.py +74 -0
  49. tests/stats/test_danc_statistic.py +51 -0
  50. tests/stats/test_dd_statistic.py +45 -0
  51. tests/stats/test_df_statistic.py +73 -0
  52. tests/stats/test_dplus_statistic.py +79 -0
  53. tests/stats/test_fd_statistic.py +68 -0
  54. tests/stats/test_q_statistic.py +268 -0
  55. tests/stats/test_stat_utils.py +354 -0
  56. tests/stats/test_u_statistic.py +233 -0
  57. tests/test___main__.py +51 -0
  58. tests/test_sai.py +102 -0
  59. tests/utils/test_utils.py +511 -0
  60. sai/parsers/plot_parser.py +0 -152
  61. sai/stats/features.py +0 -302
  62. sai/utils/preprocessors/feature_preprocessor.py +0 -211
  63. sai_pg-1.0.0.dist-info/RECORD +0 -30
  64. sai_pg-1.0.0.dist-info/top_level.txt +0 -1
  65. /sai/{utils/generators → generators}/__init__.py +0 -0
  66. /sai/{utils/generators → generators}/data_generator.py +0 -0
  67. /sai/{utils/multiprocessing → multiprocessing}/__init__.py +0 -0
  68. /sai/{utils/preprocessors → preprocessors}/__init__.py +0 -0
  69. /sai/{utils/preprocessors → preprocessors}/data_preprocessor.py +0 -0
  70. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/entry_points.txt +0 -0
  71. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,223 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import pytest
22
+ import numpy as np
23
+ import sai.stats
24
+ from sai.generators import WindowGenerator
25
+ from sai.preprocessors import FeaturePreprocessor
26
+ from sai.configs import PloidyConfig, StatConfig
27
+
28
+
29
@pytest.fixture
def feature_preprocessor(tmp_path):
    """Build a FeaturePreprocessor configured for the U and Q statistics.

    The output file is placed under pytest's per-test ``tmp_path`` so the
    fixture never creates or overwrites ``test_output.tsv`` in the current
    working directory (the previous relative path did exactly that).

    Returns:
        FeaturePreprocessor: writes to ``<tmp_path>/test_output.tsv`` and
        computes U and Q for reference ``ref1``, target ``tgt1`` and sources
        ``src1``/``src2``.
    """
    stat_config = StatConfig(
        {
            "U": {
                "ref": {"ref1": 0.3},
                "tgt": {"tgt1": 0.5},
                "src": {"src1": "=1", "src2": "=1"},
            },
            "Q": {
                "ref": {"ref1": 0.3},
                "tgt": {"tgt1": 0.95},
                "src": {"src1": "=0.2", "src2": "=0.4"},
            },
        }
    )

    return FeaturePreprocessor(
        output_file=str(tmp_path / "test_output.tsv"),
        stat_config=stat_config,
    )
51
+
52
+
53
def test_run(feature_preprocessor):
    """FeaturePreprocessor.run echoes window metadata and reports U and Q.

    A second call with every genotype input set to None must yield NaN for
    both statistics.
    """
    # Window metadata shared by both calls below.
    window = dict(
        chr_name="21",
        ref_pop="ref1",
        tgt_pop="tgt1",
        src_pop_list=["src1", "src2"],
        start=1000,
        end=2000,
        pos=np.array([100, 200, 300]),
    )

    # Small mock genotype matrices (3 sites x 3 haplotypes each).
    genotypes = dict(
        ref_gts=np.array([[0, 0, 1], [1, 1, 0], [0, 1, 1]]),
        tgt_gts=np.array([[0, 1, 1], [1, 1, 1], [0, 0, 1]]),
        src_gts_list=[
            np.array([[0, 0, 0], [1, 0, 0], [1, 1, 1]]),
            np.array([[1, 1, 1], [0, 1, 1], [0, 0, 1]]),
        ],
        ploidy_config=PloidyConfig(
            {
                "ref": {"ref1": 1},
                "tgt": {"tgt1": 1},
                "src": {"src1": 1},
            }
        ),
    )

    record = feature_preprocessor.run(**window, **genotypes)[0]
    for key in ("chr_name", "start", "end", "ref_pop", "tgt_pop", "src_pop_list"):
        assert record[key] == window[key]
    assert "U" in record
    assert "Q" in record

    # Without any genotype data the statistics are undefined.
    no_data = dict.fromkeys(("ref_gts", "tgt_gts", "src_gts_list", "ploidy_config"))
    record = feature_preprocessor.run(**window, **no_data)[0]
    assert np.isnan(record["U"])
    assert np.isnan(record["Q"])
120
+
121
+
122
def test_process_items(feature_preprocessor, tmp_path):
    """process_items writes exactly one tab-separated row for a single item.

    The expected row joins the source populations with commas and does not
    include the ``cdd_pos`` entry of the item.
    """
    out_path = tmp_path / "test_output.tsv"
    feature_preprocessor.output_file = str(out_path)

    item = {
        "chr_name": "21",
        "start": 1000,
        "end": 2000,
        "ref_pop": "ref1",
        "tgt_pop": "tgt1",
        "src_pop_list": ["src1", "src2"],
        "nsnps": 10,
        "U": 5,
        "Q": 0.8,
        "cdd_pos": {
            "U": np.array([]),
            "Q": np.array([]),
        },
    }

    feature_preprocessor.process_items([item])

    with open(out_path, "r") as handle:
        written = handle.readlines()
    assert written == ["21\t1000\t2000\tref1\ttgt1\tsrc1,src2\t10\t5\t0.8\n"]
153
+
154
+
155
@pytest.fixture
def example_data():
    """Attach paths of the bundled example inputs to the ``pytest`` module.

    The paths are relative to the current working directory, so tests using
    this fixture are expected to be run from the repository root.
    """
    data_dir = "./tests/data"
    for attr_name, file_name in (
        ("example_vcf", "example.vcf"),
        ("example_ref_ind_list", "example.ref.ind.list"),
        ("example_tgt_ind_list", "example.tgt.ind.list"),
        ("example_src_ind_list", "example.src.ind.list"),
    ):
        setattr(pytest, attr_name, f"{data_dir}/{file_name}")
161
+
162
+
163
def test_run_from_file(example_data, tmp_path):
    """End-to-end test: windows from the example VCF are scored and written out.

    Every window's first item must have U == 3, and the output file must
    end up non-empty.
    """
    generator = WindowGenerator(
        vcf_file=pytest.example_vcf,
        chr_name=21,
        ref_ind_file=pytest.example_ref_ind_list,
        tgt_ind_file=pytest.example_tgt_ind_list,
        src_ind_file=pytest.example_src_ind_list,
        out_ind_file=None,
        win_len=6666,
        win_step=6666,
        ploidy_config=PloidyConfig(
            {
                "ref": {"AFR": 2},
                "tgt": {"CHB": 2},
                "src": {"Nean": 2},
            }
        ),
    )

    out_file = tmp_path / "output.tsv"
    preprocessor = FeaturePreprocessor(
        output_file=str(out_file),
        stat_config=StatConfig(
            {"U": {"ref": {"AFR": 0.3}, "tgt": {"CHB": 0.5}, "src": {"Nean": "=1"}}}
        ),
    )

    # Keys of each generated window that are forwarded to FeaturePreprocessor.run.
    run_keys = (
        "chr_name",
        "ref_pop",
        "tgt_pop",
        "src_pop_list",
        "start",
        "end",
        "pos",
        "ref_gts",
        "tgt_gts",
        "src_gts_list",
        "ploidy_config",
    )

    for window_data in generator.get():
        items = preprocessor.run(**{key: window_data[key] for key in run_keys})
        preprocessor.process_items(items)
        assert items[0]["U"] == 3

    assert out_file.exists()
    assert out_file.stat().st_size > 0

    with open(out_file, "r") as handle:
        assert len(handle.readlines()) > 0
@@ -0,0 +1,74 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import pytest
22
+ from sai.registries import (
23
+ StatRegistry,
24
+ )
25
+
26
+
27
@pytest.mark.parametrize("RegistryClass", [StatRegistry])
def test_register_and_get(RegistryClass):
    """A class registered under a name is listed and returned by get()."""
    registry = RegistryClass()

    class DummyClass:
        pass

    register_as_dummy = registry.register("dummy")
    register_as_dummy(DummyClass)

    assert "dummy" in registry.list_registered()
    assert registry.get("dummy") is DummyClass
44
+
45
+
46
@pytest.mark.parametrize("RegistryClass", [StatRegistry])
def test_duplicate_registration(RegistryClass):
    """Registering under a name that is already taken raises ValueError."""
    registry = RegistryClass()

    class DummyClass:
        pass

    first_registration = registry.register("dummy")
    first_registration(DummyClass)

    with pytest.raises(ValueError):
        registry.register("dummy")(DummyClass)
62
+
63
+
64
@pytest.mark.parametrize("RegistryClass", [StatRegistry])
def test_get_unregistered(RegistryClass):
    """Looking up a name that was never registered raises KeyError."""
    empty_registry = RegistryClass()

    with pytest.raises(KeyError):
        empty_registry.get("nonexistent")
@@ -0,0 +1,51 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from sai.stats import DancStatistic
23
+
24
+
25
def test_DancStatistic_compute():
    """D_anc on three haploid sites with one source and no outgroup.

    Rows are sites and columns are haplotypes (ploidy 1), so the allele
    frequencies are:
        ref = [0, 0, 1], tgt = [0.5, 0.5, 0.5], src = [0.5, 0.5, 0.5],
        out = [0, 0, 0] (no outgroup provided)

    baaa = ref * (1 - tgt) * (1 - src)  -> 0 + 0 + 0.25     = 0.25
    abaa = (1 - ref) * tgt * (1 - src)  -> 0.25 + 0.25 + 0  = 0.5
    D_anc = (baaa - abaa) / (baaa + abaa) = -0.25 / 0.75 = -1/3
    """
    danc_stat = DancStatistic(
        ref_gts=np.array([[0, 0], [0, 0], [1, 1]]),  # reference population
        tgt_gts=np.array([[1, 0], [0, 1], [0, 1]]),  # target population
        src_gts_list=[np.array([[0, 1], [1, 0], [1, 0]])],  # source population
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = danc_stat.compute()
    expected_result = -1 / 3

    assert results["name"] == "Danc"
    assert np.isclose(
        results["value"][0], expected_result
    ), f"Expected {expected_result}, but got {results['value'][0]}"
@@ -0,0 +1,45 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from sai.stats import DdStatistic
23
+
24
+
25
def test_DdStatistic_compute():
    """DD statistic on two haploid sites with a single source population.

    Rows are sites and columns are haplotypes (ploidy 1), so the allele
    frequencies are ref = [1, 0], tgt = [0.5, 0.5], src = [0.5, 1].
    The expected value is 0.5.
    """
    dd_stat = DdStatistic(
        ref_gts=np.array([[1, 1], [0, 0]]),
        tgt_gts=np.array([[1, 0], [0, 1]]),
        src_gts_list=[np.array([[0, 1], [1, 1]])],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = dd_stat.compute()
    expected_result = 0.5

    assert results["name"] == "DD"
    assert np.isclose(
        results["value"][0], expected_result
    ), f"Expected {expected_result}, but got {results['value'][0]}"
@@ -0,0 +1,73 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from sai.stats import DfStatistic
23
+
24
+
25
def test_DfStatistic_compute():
    """df on three haploid sites with one source and no outgroup.

    Allele frequencies (rows = sites, ploidy 1):
        ref = [0, 0, 1], tgt = [0.5, 0.5, 0.5], src = [0.5, 0.5, 0.5],
        out = [0, 0, 0] (no outgroup provided)

    Pattern sums over sites:
        abba = (1 - ref) * tgt * src * (1 - out)        -> 0.25 + 0.25 + 0 = 0.5
        baba = ref * (1 - tgt) * src * (1 - out)        -> 0 + 0 + 0.25    = 0.25
        bbaa = ref * tgt * (1 - src) * (1 - out)        -> 0 + 0 + 0.25    = 0.25

    df = (abba - baba) / (abba + baba + 2 * bbaa) = 0.25 / 1.25 = 0.2
    """
    df_stat = DfStatistic(
        ref_gts=np.array([[0, 0], [0, 0], [1, 1]]),  # reference population
        tgt_gts=np.array([[1, 0], [0, 1], [0, 1]]),  # target population
        src_gts_list=[np.array([[0, 1], [1, 0], [1, 0]])],  # source population
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = df_stat.compute()
    expected_result = 0.2

    assert results["name"] == "df"
    assert np.isclose(
        results["value"][0], expected_result
    ), f"Expected {expected_result}, but got {results['value'][0]}"
@@ -0,0 +1,79 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from sai.stats import DplusStatistic
23
+
24
+
25
def test_DplusStatistic_compute():
    """D+ on three haploid sites with one source and no outgroup.

    Allele frequencies (rows = sites, ploidy 1):
        ref = [0, 0, 1], tgt = [0.5, 0.5, 0.5], src = [0.5, 0.5, 0.5],
        out = [0, 0, 0] (no outgroup provided)

    Pattern sums over sites:
        abba = (1 - ref) * tgt * src * (1 - out)              -> 0.25 + 0.25 + 0 = 0.5
        baba = ref * (1 - tgt) * src * (1 - out)              -> 0 + 0 + 0.25    = 0.25
        baaa = ref * (1 - tgt) * (1 - src) * (1 - out)        -> 0 + 0 + 0.25    = 0.25
        abaa = (1 - ref) * tgt * (1 - src) * (1 - out)        -> 0.25 + 0.25 + 0 = 0.5

    D+ = (abba - baba + baaa - abaa) / (abba + baba + baaa + abaa)
       = (0.5 - 0.25 + 0.25 - 0.5) / 1.5 = 0
    """
    dplus_stat = DplusStatistic(
        ref_gts=np.array([[0, 0], [0, 0], [1, 1]]),  # reference population
        tgt_gts=np.array([[1, 0], [0, 1], [0, 1]]),  # target population
        src_gts_list=[np.array([[0, 1], [1, 0], [1, 0]])],  # source population
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )
    results = dplus_stat.compute()
    expected_result = 0

    assert results["name"] == "Dplus"
    assert np.isclose(
        results["value"][0], expected_result
    ), f"Expected {expected_result}, but got {results['value'][0]}"
@@ -0,0 +1,68 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from sai.stats import FdStatistic
23
+
24
+
25
def test_FdStatistic_compute():
    """fd on three haploid sites with one source and no outgroup.

    Allele frequencies (rows = sites, ploidy 1):
        ref = [0.5, 0.5, 0.5], tgt = [0.5, 0.5, 0.5], src = [1, 1, 1],
        out = [0, 0, 0] (no outgroup provided)

    Numerator (observed patterns):
        abba = (1 - ref) * tgt * src * (1 - out)  -> 0.25 * 3 = 0.75
        baba = ref * (1 - tgt) * src * (1 - out)  -> 0.25 * 3 = 0.75
        abba - baba = 0

    Denominator (donor patterns, dnr = [1, 1, 1]):
        abba_d = (1 - ref) * dnr * dnr * (1 - out)  -> 0.5 * 3 = 1.5
        baba_d = ref * (1 - dnr) * dnr * (1 - out)  -> 0
        abba_d - baba_d = 1.5

    fd = (abba - baba) / (abba_d - baba_d) = 0 / 1.5 = 0
    """
    ref_gts = np.array([[0, 1], [1, 0], [0, 1]])  # Reference population
    tgt_gts = np.array([[1, 0], [0, 1], [1, 0]])  # Target population
    src_gts = np.array([[1, 1], [1, 1], [1, 1]])  # Source population

    fd_stat = FdStatistic(
        ref_gts=ref_gts,
        tgt_gts=tgt_gts,
        src_gts_list=[src_gts],
        ref_ploidy=1,
        tgt_ploidy=1,
        src_ploidy_list=[1],
    )

    results = fd_stat.compute()

    # The numerator vanishes (abba == baba), so fd is 0 despite a non-zero
    # denominator.
    expected_result = 0
    assert results["name"] == "fd"
    assert np.isclose(
        results["value"][0], expected_result
    ), f"Expected {expected_result}, but got {results['value'][0]}"