sai-pg 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sai/__init__.py +2 -0
- sai/__main__.py +6 -3
- sai/configs/__init__.py +24 -0
- sai/configs/global_config.py +83 -0
- sai/configs/ploidy_config.py +94 -0
- sai/configs/pop_config.py +82 -0
- sai/configs/stat_config.py +220 -0
- sai/{utils/generators → generators}/chunk_generator.py +1 -1
- sai/{utils/generators → generators}/window_generator.py +81 -37
- sai/{utils/multiprocessing → multiprocessing}/mp_manager.py +2 -2
- sai/{utils/multiprocessing → multiprocessing}/mp_pool.py +2 -2
- sai/parsers/outlier_parser.py +4 -3
- sai/parsers/score_parser.py +8 -119
- sai/{utils/preprocessors → preprocessors}/chunk_preprocessor.py +21 -15
- sai/preprocessors/feature_preprocessor.py +236 -0
- sai/registries/__init__.py +22 -0
- sai/registries/generic_registry.py +89 -0
- sai/registries/stat_registry.py +30 -0
- sai/sai.py +124 -220
- sai/stats/__init__.py +11 -0
- sai/stats/danc_statistic.py +83 -0
- sai/stats/dd_statistic.py +77 -0
- sai/stats/df_statistic.py +84 -0
- sai/stats/dplus_statistic.py +86 -0
- sai/stats/fd_statistic.py +92 -0
- sai/stats/generic_statistic.py +93 -0
- sai/stats/q_statistic.py +104 -0
- sai/stats/stat_utils.py +259 -0
- sai/stats/u_statistic.py +99 -0
- sai/utils/utils.py +213 -142
- {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/METADATA +3 -14
- sai_pg-1.1.0.dist-info/RECORD +70 -0
- {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/WHEEL +1 -1
- sai_pg-1.1.0.dist-info/top_level.txt +2 -0
- tests/configs/test_global_config.py +163 -0
- tests/configs/test_ploidy_config.py +93 -0
- tests/configs/test_pop_config.py +90 -0
- tests/configs/test_stat_config.py +171 -0
- tests/generators/test_chunk_generator.py +51 -0
- tests/generators/test_window_generator.py +164 -0
- tests/multiprocessing/test_mp_manager.py +92 -0
- tests/multiprocessing/test_mp_pool.py +79 -0
- tests/parsers/test_argument_validation.py +133 -0
- tests/parsers/test_outlier_parser.py +53 -0
- tests/parsers/test_score_parser.py +63 -0
- tests/preprocessors/test_chunk_preprocessor.py +79 -0
- tests/preprocessors/test_feature_preprocessor.py +223 -0
- tests/registries/test_registries.py +74 -0
- tests/stats/test_danc_statistic.py +51 -0
- tests/stats/test_dd_statistic.py +45 -0
- tests/stats/test_df_statistic.py +73 -0
- tests/stats/test_dplus_statistic.py +79 -0
- tests/stats/test_fd_statistic.py +68 -0
- tests/stats/test_q_statistic.py +268 -0
- tests/stats/test_stat_utils.py +354 -0
- tests/stats/test_u_statistic.py +233 -0
- tests/test___main__.py +51 -0
- tests/test_sai.py +102 -0
- tests/utils/test_utils.py +511 -0
- sai/parsers/plot_parser.py +0 -152
- sai/stats/features.py +0 -302
- sai/utils/preprocessors/feature_preprocessor.py +0 -211
- sai_pg-1.0.1.dist-info/RECORD +0 -30
- sai_pg-1.0.1.dist-info/top_level.txt +0 -1
- /sai/{utils/generators → generators}/__init__.py +0 -0
- /sai/{utils/generators → generators}/data_generator.py +0 -0
- /sai/{utils/multiprocessing → multiprocessing}/__init__.py +0 -0
- /sai/{utils/preprocessors → preprocessors}/__init__.py +0 -0
- /sai/{utils/preprocessors → preprocessors}/data_preprocessor.py +0 -0
- {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/entry_points.txt +0 -0
- {sai_pg-1.0.1.dist-info → sai_pg-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,86 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
from typing import Dict, Any
|
23
|
+
from sai.registries.stat_registry import STAT_REGISTRY
|
24
|
+
from sai.stats import GenericStatistic
|
25
|
+
from sai.stats.stat_utils import calc_four_pops_freq, calc_pattern_sum
|
26
|
+
|
27
|
+
|
28
|
+
@STAT_REGISTRY.register("Dplus")
class DplusStatistic(GenericStatistic):
    """
    D+ statistic (Fang et al. 2024. PLoS Genet).

    D+ augments the ABBA-BABA test with the BAAA and ABAA site patterns:

        D+ = (ABBA - BABA + BAAA - ABAA) / (ABBA + BABA + BAAA + ABAA)

    Pattern letters refer, in order, to the reference, target, source and
    outgroup populations ('a' = 1 - freq, 'b' = freq).
    """

    STAT_NAME = "Dplus"

    def compute(self, **kwargs) -> Dict[str, Any]:
        """
        Evaluate D+ once per source population.

        Parameters
        ----------
        **kwargs : dict
            Ignored; accepted only to match the ``GenericStatistic.compute``
            signature.

        Returns
        -------
        dict
            A dictionary with:
            - 'name' : str
                Always "Dplus".
            - 'value' : list[float]
                One D+ value per entry of ``src_gts_list`` (same order);
                NaN where the denominator is zero.
        """
        values = []

        for idx, src_gts in enumerate(self.src_gts_list):
            freqs = calc_four_pops_freq(
                ref_gts=self.ref_gts,
                tgt_gts=self.tgt_gts,
                src_gts=src_gts,
                out_gts=self.out_gts,
                ref_ploidy=self.ref_ploidy,
                tgt_ploidy=self.tgt_ploidy,
                src_ploidy=self.src_ploidy_list[idx],
                out_ploidy=self.out_ploidy,
            )

            sums = {
                pattern: calc_pattern_sum(*freqs, pattern)
                for pattern in ("abba", "baba", "baaa", "abaa")
            }

            # Keep the left-to-right evaluation order of the arithmetic.
            top = sums["abba"] - sums["baba"] + sums["baaa"] - sums["abaa"]
            bottom = sums["abba"] + sums["baba"] + sums["baaa"] + sums["abaa"]

            if bottom != 0:
                values.append(top / bottom)
            else:
                values.append(np.nan)

        return {"name": self.STAT_NAME, "value": values}
|
@@ -0,0 +1,92 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
from typing import Dict, Any
|
23
|
+
from sai.registries.stat_registry import STAT_REGISTRY
|
24
|
+
from sai.stats import GenericStatistic
|
25
|
+
from sai.stats.stat_utils import calc_four_pops_freq, calc_pattern_sum
|
26
|
+
|
27
|
+
|
28
|
+
@STAT_REGISTRY.register("fd")
class FdStatistic(GenericStatistic):
    """
    Class for computing the dynamic estimator of the proportion of introgression (Martin et al. 2015. Mol Biol Evol).

    The fd statistic is a dynamic estimator of the proportion of introgression
    from a source into the target population, based on ABBA-BABA pattern sums.
    """

    STAT_NAME = "fd"

    def compute(self, **kwargs) -> Dict[str, Any]:
        """
        Computes the fd statistic for each source population.

        For every source population the numerator is ``ABBA - BABA`` computed
        from the observed (ref, tgt, src, out) frequencies. The denominator is
        the same difference with both the target and the source frequencies
        replaced by their per-site maximum.

        Parameters
        ----------
        **kwargs : dict
            Unused. Present to maintain compatibility with the base class interface.

        Returns
        -------
        dict
            A dictionary containing:
            - 'name' : str
                The name of the statistic ("fd").
            - 'value' : list[float]
                List of fd values, one for each source population (in the order of
                ``src_gts_list``). NaN when the denominator is zero.
        """
        fd_results = []

        for i, src_gts in enumerate(self.src_gts_list):
            ref_freq, tgt_freq, src_freq, out_freq = calc_four_pops_freq(
                ref_gts=self.ref_gts,
                tgt_gts=self.tgt_gts,
                src_gts=src_gts,
                out_gts=self.out_gts,
                ref_ploidy=self.ref_ploidy,
                tgt_ploidy=self.tgt_ploidy,
                # Indexing (not zip) keeps the IndexError on a too-short ploidy list.
                src_ploidy=self.src_ploidy_list[i],
                out_ploidy=self.out_ploidy,
            )

            abba_n = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "abba")
            baba_n = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "baba")

            # "Donor" frequency: per-site maximum of target and source frequencies,
            # used for both the target and the source slot of the denominator.
            dnr_freq = np.maximum(tgt_freq, src_freq)

            abba_d = calc_pattern_sum(ref_freq, dnr_freq, dnr_freq, out_freq, "abba")
            baba_d = calc_pattern_sum(ref_freq, dnr_freq, dnr_freq, out_freq, "baba")

            numerator = abba_n - baba_n
            denominator = abba_d - baba_d

            fd = numerator / denominator if denominator != 0 else np.nan
            fd_results.append(fd)

        return {"name": self.STAT_NAME, "value": fd_results}
|
@@ -0,0 +1,93 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
from abc import ABC, abstractmethod
|
23
|
+
from typing import Dict, Any, Optional
|
24
|
+
|
25
|
+
|
26
|
+
class GenericStatistic(ABC):
    """
    Abstract base class shared by the statistics in ``sai.stats``.

    A statistic is built from the genotype matrices of a reference population,
    a target population, one or more source populations and, optionally, an
    outgroup, together with the ploidy of each group. Subclasses implement
    :meth:`compute`.
    """

    def __init__(
        self,
        ref_gts: np.ndarray,
        tgt_gts: np.ndarray,
        ref_ploidy: int,
        tgt_ploidy: int,
        src_gts_list: list[np.ndarray],
        src_ploidy_list: list[int],
        out_gts: Optional[np.ndarray] = None,
        out_ploidy: Optional[int] = None,
    ):
        """
        Store genotype matrices and ploidies for later use by :meth:`compute`.

        The arguments are stored as given: no copies are made and no
        validation is performed.

        Parameters
        ----------
        ref_gts : np.ndarray
            2D array (loci x individuals) for the reference group.
        tgt_gts : np.ndarray
            2D array (loci x individuals) for the target group.
        ref_ploidy : int
            Ploidy of the reference population.
        tgt_ploidy : int
            Ploidy of the target population.
        src_gts_list : list[np.ndarray]
            One 2D array (loci x individuals) per source population.
        src_ploidy_list : list[int]
            Ploidy of each source population. It is expected to have the same
            length and order as ``src_gts_list``; this is not checked here.
        out_gts : Optional[np.ndarray]
            2D array (loci x individuals) for the outgroup. Default: None.
        out_ploidy : Optional[int]
            Ploidy of the outgroup. Default: None.
        """
        # Reference population.
        self.ref_gts = ref_gts
        self.ref_ploidy = ref_ploidy

        # Target population.
        self.tgt_gts = tgt_gts
        self.tgt_ploidy = tgt_ploidy

        # Source populations (genotypes and ploidies aligned by index).
        self.src_gts_list = src_gts_list
        self.src_ploidy_list = src_ploidy_list

        # Optional outgroup.
        self.out_gts = out_gts
        self.out_ploidy = out_ploidy

    @abstractmethod
    def compute(self, **kwargs) -> Dict[str, Any]:
        """
        Compute the statistic from the stored genotype data.

        Parameters
        ----------
        **kwargs : dict
            Statistic-specific keyword arguments.

        Returns
        -------
        dict
            The results of the computation.
        """
|
sai/stats/q_statistic.py
ADDED
@@ -0,0 +1,104 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
from typing import Dict, Any
|
23
|
+
from sai.registries.stat_registry import STAT_REGISTRY
|
24
|
+
from sai.stats import GenericStatistic
|
25
|
+
from sai.stats.stat_utils import compute_matching_loci
|
26
|
+
|
27
|
+
|
28
|
+
@STAT_REGISTRY.register("Q")
class QStatistic(GenericStatistic):
    """
    Class for computing the quantile statistic in the target population (Racimo et al. 2017. Mol Biol Evol),
    conditional on allele frequency patterns in the reference and source populations.
    """

    STAT_NAME = "Q"

    def compute(self, **kwargs) -> Dict[str, Any]:
        """
        Calculates a specified quantile of derived allele frequencies in `tgt_gts` for loci that meet specific conditions
        across reference and multiple source genotypes (as decided by `compute_matching_loci`).

        Parameters
        ----------
        pos : array-like
            A 1D array where each element is the genomic position of the corresponding locus
            (same length as the number of rows in the genotype matrices).
        w : float
            Frequency threshold for the derived allele in `ref_gts`. Only loci with frequencies lower than `w` are included.
            Must be within the range [0, 1].
        y_list : list[tuple[str, float]]
            One (operator, threshold) condition per source population in `src_gts_list`. The operator is one of
            '=', '<', '>', '<=', '>=' and the threshold must be within the range [0, 1].
            Must have the same length as `src_gts_list`.
        quantile : float
            The quantile to compute for the filtered `tgt_gts` frequencies. Must be within the range [0, 1].
        anc_allele_available : bool
            If True, `1` is assumed to represent the derived allele and the conditions are checked directly.
            If False, `compute_matching_loci` also accepts loci that satisfy the conditions evaluated against
            `1 - y` and inverts the reference and target frequencies at those loci.

        Returns
        -------
        dict
            A dictionary containing:
            - 'name' : str
                The name of the statistic ("Q").
            - 'value' : float
                The specified quantile of the target frequencies at the loci meeting the conditions,
                or NaN if no loci meet the criteria.
            - 'cdd_pos' : np.ndarray
                Genomic positions of the loci that meet the conditions AND whose target frequency is
                greater than or equal to the quantile value (empty if no loci meet the criteria).

        Raises
        ------
        ValueError
            If any of the required keyword arguments is missing.
        """
        required_keys = ["pos", "w", "y_list", "anc_allele_available", "quantile"]
        if missing := [k for k in required_keys if k not in kwargs]:
            raise ValueError(f"Missing required argument(s): {', '.join(missing)}")

        # Accept any array-like for positions; boolean-mask indexing below needs an ndarray.
        pos = np.asarray(kwargs["pos"])
        w = kwargs["w"]
        y_list = kwargs["y_list"]
        anc_allele_available = kwargs["anc_allele_available"]
        quantile = kwargs["quantile"]
        # Unpacking works whether src_ploidy_list is a list or a tuple.
        ploidy = [self.ref_ploidy, self.tgt_ploidy, *self.src_ploidy_list]

        ref_freq, tgt_freq, condition = compute_matching_loci(
            self.ref_gts,
            self.tgt_gts,
            self.src_gts_list,
            w,
            y_list,
            ploidy,
            anc_allele_available,
        )

        # Filter `tgt_gts` frequencies based on the combined condition
        filtered_tgt_freq = tgt_freq[condition]
        filtered_positions = pos[condition]

        # Return NaN if no loci meet the criteria
        if filtered_tgt_freq.size == 0:
            threshold = np.nan
            loci_positions = np.array([])
        else:
            threshold = np.nanquantile(filtered_tgt_freq, quantile)
            # Keep only the loci at or above the quantile value.
            loci_positions = filtered_positions[filtered_tgt_freq >= threshold]

        return {"name": self.STAT_NAME, "value": threshold, "cdd_pos": loci_positions}
|
sai/stats/stat_utils.py
ADDED
@@ -0,0 +1,259 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
from typing import Tuple, Optional, Union
|
23
|
+
|
24
|
+
|
25
|
+
def calc_freq(gts: np.ndarray, ploidy: int = 1) -> np.ndarray:
    """
    Calculates allele frequencies, supporting both phased and unphased data.

    The frequency at each locus is the sum over columns divided by
    ``n_columns * ploidy``. Use ploidy=1 for phased data (one haplotype per
    column). Use the organism's ploidy for unphased data, where each column
    holds the allele count of one individual.

    Parameters
    ----------
    gts : array-like
        A 2D array (or nested sequence) where each row represents a locus and each column
        represents an individual/haplotype.
    ploidy : int, optional
        Ploidy level of the organism. Python ints and NumPy integer scalars are accepted.
        Default is 1.

    Returns
    -------
    np.ndarray
        An array of allele frequencies for each locus.

    Raises
    ------
    ValueError
        If ploidy is not a positive integer.
    """
    # NumPy integer scalars (e.g. np.int64 taken from arrays) are not `int` subclasses,
    # so they have to be allowed explicitly.
    if not isinstance(ploidy, (int, np.integer)) or ploidy <= 0:
        raise ValueError("ploidy must be a positive integer.")

    # Coerce so `.shape` works for nested lists as well as ndarrays.
    gts = np.asarray(gts)
    return np.sum(gts, axis=1) / (gts.shape[1] * ploidy)
|
52
|
+
|
53
|
+
|
54
|
+
def compute_matching_loci(
    ref_gts: np.ndarray,
    tgt_gts: np.ndarray,
    src_gts_list: list[np.ndarray],
    w: float,
    y_list: list[tuple[str, float]],
    ploidy: list[int],
    anc_allele_available: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Computes loci that meet specified allele frequency conditions across reference, target, and source genotypes.

    Parameters
    ----------
    ref_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the reference group.
    tgt_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the target group.
    src_gts_list : list of np.ndarray
        A list of 2D numpy arrays for each source population, where each row represents a locus and each column
        represents an individual in that source population.
    w : float
        Threshold for the allele frequency in `ref_gts`. Only loci with frequencies less than `w` are counted.
        Must be within the range [0, 1].
    y_list : list of tuple[str, float]
        List of allele frequency conditions for each source population in `src_gts_list`.
        Each entry is a tuple (operator, threshold), where:
        - `operator` can be '=', '<', '>', '<=', '>='
        - `threshold` is a float within [0, 1]
        The length must match `src_gts_list`.
    ploidy : list[int]
        Ploidy values for reference, target, and one or more source populations (in that order).
    anc_allele_available : bool
        If True, `1` is taken to be the derived allele and every source frequency must satisfy
        ``src_freq <op> y`` directly.
        If False, a locus is also accepted when every source satisfies its condition with the
        allele labels swapped, i.e. ``(1 - src_freq) <op> y``. At those loci the reference and
        target frequencies are inverted (``1 - freq``) so that they refer to the same allele.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, np.ndarray]
        - Adjusted reference allele frequencies (`ref_freq`).
        - Adjusted target allele frequencies (`tgt_freq`).
        - Boolean array indicating loci that meet the specified frequency conditions (`condition`).

    Raises
    ------
    ValueError
        If `w` or any threshold is outside [0, 1], if an operator is not supported,
        or if `src_gts_list` and `y_list` differ in length.
    """
    # Validate input parameters
    if not (0 <= w <= 1):
        raise ValueError("Parameters w must be within the range [0, 1].")

    for op, y in y_list:
        if not (0 <= y <= 1):
            raise ValueError(
                f"Invalid value in y_list: {y}. Must be within the range [0, 1]."
            )
        if op not in ("=", "<", ">", "<=", ">="):
            raise ValueError(
                f"Invalid operator in y_list: {op}. Must be '=', '<', '>', '<=', or '>='."
            )

    if len(src_gts_list) != len(y_list):
        raise ValueError("The length of src_gts_list and y_list must match.")

    # Compute allele frequencies
    ref_freq = calc_freq(ref_gts, ploidy[0])
    tgt_freq = calc_freq(tgt_gts, ploidy[1])
    src_freq_list = [
        calc_freq(src_gts, ploidy_val)
        for src_gts, ploidy_val in zip(src_gts_list, ploidy[2:])
    ]

    op_funcs = {
        "=": lambda src_freq, y: src_freq == y,
        "<": lambda src_freq, y: src_freq < y,
        ">": lambda src_freq, y: src_freq > y,
        "<=": lambda src_freq, y: src_freq <= y,
        ">=": lambda src_freq, y: src_freq >= y,
    }
    # With swapped allele labels the condition is `(1 - f) <op> y`, which is
    # equivalent to `f <mirrored op> (1 - y)`. Using the mirrored form keeps the
    # '=' comparison numerically identical to comparing against `1 - y`.
    mirrored_ops = {"=": "=", "<": ">", ">": "<", "<=": ">=", ">=": "<="}

    match_conditions = [
        op_funcs[op](src_freq, y) for src_freq, (op, y) in zip(src_freq_list, y_list)
    ]
    all_match_y = np.all(match_conditions, axis=0)

    if not anc_allele_available:
        # Check whether all source populations match with the allele labels swapped
        match_conditions_1_minus_y = [
            op_funcs[mirrored_ops[op]](src_freq, 1 - y)
            for src_freq, (op, y) in zip(src_freq_list, y_list)
        ]
        all_match_1_minus_y = np.all(match_conditions_1_minus_y, axis=0)
        all_match = all_match_y | all_match_1_minus_y

        # NOTE(review): loci matching both orientations (e.g. '=' with y = 0.5, or loose
        # inequality thresholds) are inverted as well; confirm this precedence is intended.
        inverted = all_match_1_minus_y

        # Invert frequencies for these loci (calc_freq returned fresh arrays, so this
        # does not touch caller data)
        ref_freq[inverted] = 1 - ref_freq[inverted]
        tgt_freq[inverted] = 1 - tgt_freq[inverted]
    else:
        all_match = all_match_y

    # Final condition: locus must satisfy source matching and have `ref_freq < w`
    condition = all_match & (ref_freq < w)

    return ref_freq, tgt_freq, condition
|
156
|
+
|
157
|
+
|
158
|
+
def calc_four_pops_freq(
    ref_gts: np.ndarray,
    tgt_gts: np.ndarray,
    src_gts: np.ndarray,
    out_gts: Optional[np.ndarray] = None,
    ref_ploidy: int = 1,
    tgt_ploidy: int = 1,
    src_ploidy: int = 1,
    out_ploidy: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Return per-locus allele frequencies for the reference, target, source and
    outgroup populations.

    Parameters
    ----------
    ref_gts, tgt_gts, src_gts : np.ndarray
        Genotype matrices (loci x individuals) of the reference, target and
        source populations.
    out_gts : np.ndarray, optional
        Genotype matrix of the outgroup. When None, the outgroup frequency is
        taken to be 0 at every locus. Default: None.
    ref_ploidy, tgt_ploidy, src_ploidy, out_ploidy : int, optional
        Ploidy of each group, forwarded to ``calc_freq``. Default: 1 (phased data).
        ``out_ploidy`` is ignored when ``out_gts`` is None.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
        Frequencies in the order (ref, tgt, src, out).
    """
    freqs = [
        calc_freq(matrix, level)
        for matrix, level in (
            (ref_gts, ref_ploidy),
            (tgt_gts, tgt_ploidy),
            (src_gts, src_ploidy),
        )
    ]

    if out_gts is not None:
        freqs.append(calc_freq(out_gts, out_ploidy))
    else:
        # No outgroup: treat every locus as ancestral in the outgroup.
        freqs.append(np.zeros_like(freqs[0]))

    return tuple(freqs)
|
205
|
+
|
206
|
+
|
207
|
+
def calc_pattern_sum(
    ref_freq: np.ndarray,
    tgt_freq: np.ndarray,
    src_freq: np.ndarray,
    out_freq: np.ndarray,
    pattern: str,
) -> float:
    """
    Applies an ABBA-like pattern and returns the sum over loci of the transformed frequency products.

    Parameters
    ----------
    ref_freq : array-like
        Allele frequencies for the reference population (no introgression) across loci.
    tgt_freq : array-like
        Allele frequencies for the target population (receive introgression) across loci.
    src_freq : array-like
        Allele frequencies for the source population (provide introgression) across loci.
    out_freq : array-like
        Allele frequencies for the outgroup across loci.
    pattern : str
        A 4-character pattern string (case-insensitive, e.g., 'abba'), one character per
        population in the order (ref, tgt, src, out), where:
        - 'a': use 1 - freq
        - 'b': use freq

    Returns
    -------
    float
        Sum over loci of the product defined by the pattern.

    Raises
    ------
    ValueError
        - If the pattern string is not exactly four characters long.
        - If the pattern contains characters other than 'a' or 'b'.
    """
    if len(pattern) != 4:
        raise ValueError("Pattern must be a four-character string.")

    # Work in float: an integer `ref_freq` would otherwise make the in-place
    # products below fail when multiplied by fractional frequencies, and
    # plain lists do not support `1 - freq`.
    freqs = [
        np.asarray(freq, dtype=float)
        for freq in (ref_freq, tgt_freq, src_freq, out_freq)
    ]
    product = np.ones_like(freqs[0])

    for freq, c in zip(freqs, pattern.lower()):
        if c == "a":
            product *= 1 - freq
        elif c == "b":
            product *= freq
        else:
            raise ValueError(
                f"Invalid character '{c}' in pattern. Only 'a' and 'b' allowed."
            )

    return float(np.sum(product))
|