sai-pg 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. sai/__init__.py +2 -0
  2. sai/__main__.py +6 -3
  3. sai/configs/__init__.py +24 -0
  4. sai/configs/global_config.py +83 -0
  5. sai/configs/ploidy_config.py +94 -0
  6. sai/configs/pop_config.py +82 -0
  7. sai/configs/stat_config.py +220 -0
  8. sai/{utils/generators → generators}/chunk_generator.py +2 -8
  9. sai/{utils/generators → generators}/window_generator.py +82 -37
  10. sai/{utils/multiprocessing → multiprocessing}/mp_manager.py +2 -2
  11. sai/{utils/multiprocessing → multiprocessing}/mp_pool.py +2 -2
  12. sai/parsers/outlier_parser.py +4 -3
  13. sai/parsers/score_parser.py +8 -119
  14. sai/{utils/preprocessors → preprocessors}/chunk_preprocessor.py +21 -15
  15. sai/preprocessors/feature_preprocessor.py +236 -0
  16. sai/registries/__init__.py +22 -0
  17. sai/registries/generic_registry.py +89 -0
  18. sai/registries/stat_registry.py +30 -0
  19. sai/sai.py +124 -220
  20. sai/stats/__init__.py +11 -0
  21. sai/stats/danc_statistic.py +83 -0
  22. sai/stats/dd_statistic.py +77 -0
  23. sai/stats/df_statistic.py +84 -0
  24. sai/stats/dplus_statistic.py +86 -0
  25. sai/stats/fd_statistic.py +92 -0
  26. sai/stats/generic_statistic.py +93 -0
  27. sai/stats/q_statistic.py +104 -0
  28. sai/stats/stat_utils.py +259 -0
  29. sai/stats/u_statistic.py +99 -0
  30. sai/utils/utils.py +220 -143
  31. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/METADATA +3 -14
  32. sai_pg-1.1.0.dist-info/RECORD +70 -0
  33. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/WHEEL +1 -1
  34. sai_pg-1.1.0.dist-info/top_level.txt +2 -0
  35. tests/configs/test_global_config.py +163 -0
  36. tests/configs/test_ploidy_config.py +93 -0
  37. tests/configs/test_pop_config.py +90 -0
  38. tests/configs/test_stat_config.py +171 -0
  39. tests/generators/test_chunk_generator.py +51 -0
  40. tests/generators/test_window_generator.py +164 -0
  41. tests/multiprocessing/test_mp_manager.py +92 -0
  42. tests/multiprocessing/test_mp_pool.py +79 -0
  43. tests/parsers/test_argument_validation.py +133 -0
  44. tests/parsers/test_outlier_parser.py +53 -0
  45. tests/parsers/test_score_parser.py +63 -0
  46. tests/preprocessors/test_chunk_preprocessor.py +79 -0
  47. tests/preprocessors/test_feature_preprocessor.py +223 -0
  48. tests/registries/test_registries.py +74 -0
  49. tests/stats/test_danc_statistic.py +51 -0
  50. tests/stats/test_dd_statistic.py +45 -0
  51. tests/stats/test_df_statistic.py +73 -0
  52. tests/stats/test_dplus_statistic.py +79 -0
  53. tests/stats/test_fd_statistic.py +68 -0
  54. tests/stats/test_q_statistic.py +268 -0
  55. tests/stats/test_stat_utils.py +354 -0
  56. tests/stats/test_u_statistic.py +233 -0
  57. tests/test___main__.py +51 -0
  58. tests/test_sai.py +102 -0
  59. tests/utils/test_utils.py +511 -0
  60. sai/parsers/plot_parser.py +0 -152
  61. sai/stats/features.py +0 -302
  62. sai/utils/preprocessors/feature_preprocessor.py +0 -211
  63. sai_pg-1.0.0.dist-info/RECORD +0 -30
  64. sai_pg-1.0.0.dist-info/top_level.txt +0 -1
  65. /sai/{utils/generators → generators}/__init__.py +0 -0
  66. /sai/{utils/generators → generators}/data_generator.py +0 -0
  67. /sai/{utils/multiprocessing → multiprocessing}/__init__.py +0 -0
  68. /sai/{utils/preprocessors → preprocessors}/__init__.py +0 -0
  69. /sai/{utils/preprocessors → preprocessors}/data_preprocessor.py +0 -0
  70. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/entry_points.txt +0 -0
  71. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,86 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from typing import Dict, Any
23
+ from sai.registries.stat_registry import STAT_REGISTRY
24
+ from sai.stats import GenericStatistic
25
+ from sai.stats.stat_utils import calc_four_pops_freq, calc_pattern_sum
26
+
27
+
28
@STAT_REGISTRY.register("Dplus")
class DplusStatistic(GenericStatistic):
    """
    D+ statistic (Fang et al. 2024. PLoS Genet).

    D+ extends the ABBA-BABA test by additionally counting the BAAA and ABAA
    site patterns. Site patterns are written in the population order
    (reference, target, source, outgroup).
    """

    STAT_NAME = "Dplus"

    def compute(self, **kwargs) -> Dict[str, Any]:
        """
        Compute D+ once for every source population.

        For each source, the four population allele frequencies are obtained
        with ``calc_four_pops_freq`` and the pattern sums are combined as

            D+ = (ABBA - BABA + BAAA - ABAA) / (ABBA + BABA + BAAA + ABAA)

        Parameters
        ----------
        **kwargs : dict
            Ignored; accepted only for compatibility with the base class interface.

        Returns
        -------
        dict
            A dictionary containing:
            - 'name' : str
                The statistic name ("Dplus").
            - 'value' : list[float]
                One D+ value per source population, in the order of
                ``src_gts_list``; NaN where the denominator is zero.
        """
        values = []

        for idx, src_gts in enumerate(self.src_gts_list):
            freqs = calc_four_pops_freq(
                ref_gts=self.ref_gts,
                tgt_gts=self.tgt_gts,
                src_gts=src_gts,
                out_gts=self.out_gts,
                ref_ploidy=self.ref_ploidy,
                tgt_ploidy=self.tgt_ploidy,
                src_ploidy=self.src_ploidy_list[idx],
                out_ploidy=self.out_ploidy,
            )

            # freqs is (ref, tgt, src, out), matching calc_pattern_sum's positional order.
            sums = {
                pattern: calc_pattern_sum(*freqs, pattern)
                for pattern in ("abba", "baba", "baaa", "abaa")
            }

            num = sums["abba"] - sums["baba"] + sums["baaa"] - sums["abaa"]
            den = sums["abba"] + sums["baba"] + sums["baaa"] + sums["abaa"]

            # Guard against division by zero when no informative sites exist.
            if den == 0:
                values.append(np.nan)
            else:
                values.append(num / den)

        return {"name": self.STAT_NAME, "value": values}
@@ -0,0 +1,92 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from typing import Dict, Any
23
+ from sai.registries.stat_registry import STAT_REGISTRY
24
+ from sai.stats import GenericStatistic
25
+ from sai.stats.stat_utils import calc_four_pops_freq, calc_pattern_sum
26
+
27
+
28
@STAT_REGISTRY.register("fd")
class FdStatistic(GenericStatistic):
    """
    Class for computing the dynamic estimator of the proportion of introgression (Martin et al. 2015. Mol Biol Evol).

    The fd statistic estimates the proportion of introgression from a source
    into the target population from ABBA-BABA pattern sums. Site patterns are
    written in the population order (reference, target, source, outgroup).
    """

    STAT_NAME = "fd"

    def compute(self, **kwargs) -> Dict[str, Any]:
        """
        Computes the fd statistic for each source population.

        For every source population the four population allele frequencies are
        computed, and fd is evaluated as

            fd = (ABBA - BABA) / (ABBA_d - BABA_d)

        where the denominator patterns use, at every locus, the larger of the
        target and source frequencies (the "donor" frequency) in place of both
        the target and the source frequency.

        Parameters
        ----------
        **kwargs : dict
            Unused. Present to maintain compatibility with the base class interface.

        Returns
        -------
        dict
            A dictionary containing:
            - 'name' : str
                The name of the statistic ("fd").
            - 'value' : list[float]
                One fd value per source population, in the order of
                ``src_gts_list``; NaN where the denominator is zero.
        """
        fd_results = []

        for i, src_gts in enumerate(self.src_gts_list):
            ref_freq, tgt_freq, src_freq, out_freq = calc_four_pops_freq(
                ref_gts=self.ref_gts,
                tgt_gts=self.tgt_gts,
                src_gts=src_gts,
                out_gts=self.out_gts,
                ref_ploidy=self.ref_ploidy,
                tgt_ploidy=self.tgt_ploidy,
                src_ploidy=self.src_ploidy_list[i],
                out_ploidy=self.out_ploidy,
            )

            # Numerator: observed ABBA - BABA between target and source.
            abba_n = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "abba")
            baba_n = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "baba")

            # Denominator: the same patterns with both target and source replaced
            # by the per-locus donor frequency max(target, source).
            dnr_freq = np.maximum(tgt_freq, src_freq)
            abba_d = calc_pattern_sum(ref_freq, dnr_freq, dnr_freq, out_freq, "abba")
            baba_d = calc_pattern_sum(ref_freq, dnr_freq, dnr_freq, out_freq, "baba")

            numerator = abba_n - baba_n
            denominator = abba_d - baba_d

            fd = numerator / denominator if denominator != 0 else np.nan
            fd_results.append(fd)

        return {"name": self.STAT_NAME, "value": fd_results}
@@ -0,0 +1,93 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from abc import ABC, abstractmethod
23
+ from typing import Dict, Any, Optional
24
+
25
+
26
class GenericStatistic(ABC):
    """
    Abstract base class shared by all statistics.

    The constructor stores the genotype matrices and ploidy levels of the
    reference, target, source (one or more) and optional outgroup populations;
    subclasses implement ``compute`` on top of these attributes.
    """

    def __init__(
        self,
        ref_gts: np.ndarray,
        tgt_gts: np.ndarray,
        ref_ploidy: int,
        tgt_ploidy: int,
        src_gts_list: list[np.ndarray],
        src_ploidy_list: list[int],
        out_gts: Optional[np.ndarray] = None,
        out_ploidy: Optional[int] = None,
    ):
        """
        Store genotype data and ploidy levels for later use by ``compute``.

        Parameters
        ----------
        ref_gts : np.ndarray
            2D genotype matrix of the reference group (rows: loci, columns: individuals).
        tgt_gts : np.ndarray
            2D genotype matrix of the target group (rows: loci, columns: individuals).
        ref_ploidy : int
            Ploidy level of the reference population.
        tgt_ploidy : int
            Ploidy level of the target population.
        src_gts_list : list[np.ndarray]
            One 2D genotype matrix per source population (rows: loci, columns: individuals).
        src_ploidy_list : list[int]
            Ploidy level of each source population, index-aligned with ``src_gts_list``.
            The two lengths are not checked here.
        out_gts : Optional[np.ndarray]
            2D genotype matrix of the outgroup (rows: loci, columns: individuals). Default: None.
        out_ploidy : Optional[int]
            Ploidy level of the outgroup. Default: None.
        """
        # Reference population
        self.ref_gts = ref_gts
        self.ref_ploidy = ref_ploidy

        # Target population
        self.tgt_gts = tgt_gts
        self.tgt_ploidy = tgt_ploidy

        # Source populations (parallel lists)
        self.src_gts_list = src_gts_list
        self.src_ploidy_list = src_ploidy_list

        # Optional outgroup
        self.out_gts = out_gts
        self.out_ploidy = out_ploidy

    @abstractmethod
    def compute(self, **kwargs) -> Dict[str, Any]:
        """
        Compute the statistic from the stored genotype data.

        Parameters
        ----------
        **kwargs : dict
            Statistic-specific keyword arguments.

        Returns
        -------
        dict
            The results of the statistic computation.
        """
@@ -0,0 +1,104 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from typing import Dict, Any
23
+ from sai.registries.stat_registry import STAT_REGISTRY
24
+ from sai.stats import GenericStatistic
25
+ from sai.stats.stat_utils import compute_matching_loci
26
+
27
+
28
@STAT_REGISTRY.register("Q")
class QStatistic(GenericStatistic):
    """
    Class for computing the quantile statistic in the target population (Racimo et al. 2017. Mol Biol Evol),
    conditional on allele frequency patterns in the reference and source populations.
    """

    STAT_NAME = "Q"

    def compute(self, **kwargs) -> Dict[str, Any]:
        """
        Calculates a quantile of target allele frequencies at loci that satisfy the
        reference and source frequency conditions (see ``compute_matching_loci``).

        Parameters
        ----------
        pos : array-like
            1D sequence of genomic positions, one per locus (row) of the genotype matrices.
        w : float
            Frequency threshold for the reference population; only loci with reference
            frequency strictly below ``w`` are included. Must be within [0, 1].
        y_list : list[tuple[str, float]]
            One ``(operator, threshold)`` condition per source population in ``src_gts_list``,
            where ``operator`` is one of '=', '<', '>', '<=', '>=' and ``threshold`` lies in
            [0, 1]. Must have the same length as ``src_gts_list``.
        quantile : float
            The quantile of the filtered target frequencies to compute. Must be within [0, 1].
        anc_allele_available : bool
            If True, source frequencies are compared against ``y`` only. If False, loci where
            every source satisfies its condition against ``1 - y`` are accepted as well, with
            the reference and target frequencies flipped to ``1 - freq`` at those loci.

        Returns
        -------
        dict
            A dictionary containing:
            - 'name' : str
                The name of the statistic ("Q").
            - 'value' : float
                The requested quantile of the target frequencies at the matching loci,
                or NaN if no locus matches.
            - 'cdd_pos' : np.ndarray
                Positions of the matching loci whose target frequency is at or above the
                computed quantile; an empty array if no locus matches.

        Raises
        ------
        ValueError
            If a required keyword argument is missing. Invalid ``w``/``y_list`` values
            raise ValueError from ``compute_matching_loci``.
        """
        required_keys = ["pos", "w", "y_list", "anc_allele_available", "quantile"]
        if missing := [k for k in required_keys if k not in kwargs]:
            raise ValueError(f"Missing required argument(s): {', '.join(missing)}")

        # asarray so that list inputs support boolean-mask indexing below.
        pos = np.asarray(kwargs["pos"])
        w = kwargs["w"]
        y_list = kwargs["y_list"]
        anc_allele_available = kwargs["anc_allele_available"]
        quantile = kwargs["quantile"]
        # Order expected by compute_matching_loci: reference, target, then each source.
        ploidy = [self.ref_ploidy, self.tgt_ploidy, *self.src_ploidy_list]

        ref_freq, tgt_freq, condition = compute_matching_loci(
            self.ref_gts,
            self.tgt_gts,
            self.src_gts_list,
            w,
            y_list,
            ploidy,
            anc_allele_available,
        )

        # Filter target frequencies and positions by the combined condition
        filtered_tgt_freq = tgt_freq[condition]
        filtered_positions = pos[condition]

        # Return NaN if no loci meet the criteria
        if filtered_tgt_freq.size == 0:
            threshold = np.nan
            loci_positions = np.array([])
        else:
            threshold = np.nanquantile(filtered_tgt_freq, quantile)
            loci_positions = filtered_positions[filtered_tgt_freq >= threshold]

        return {"name": self.STAT_NAME, "value": threshold, "cdd_pos": loci_positions}
@@ -0,0 +1,259 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from typing import Tuple, Optional, Union
23
+
24
+
25
def calc_freq(gts: np.ndarray, ploidy: int = 1) -> np.ndarray:
    """
    Calculates allele frequencies for each locus.

    The frequency at a locus is the row sum divided by the total number of
    alleles, ``gts.shape[1] * ploidy``. With ``ploidy=1`` every column counts as
    one allele (e.g. phased haplotypes), so this equals the row mean; with a
    larger ploidy every column is presumably an individual's allele count
    (0..ploidy) for unphased data — the encoding is not checked here.

    Parameters
    ----------
    gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual.
    ploidy : int, optional
        Ploidy level of the organism. Python ``int`` and NumPy integer scalars
        (e.g. ``np.int64``) are accepted. Default is 1.

    Returns
    -------
    np.ndarray
        An array of allele frequencies for each locus.

    Raises
    ------
    ValueError
        If ploidy is not a positive integer.
    """
    # NumPy integer scalars are not subclasses of int, but are valid ploidy values
    # (e.g. when ploidies are read from a NumPy array).
    if not isinstance(ploidy, (int, np.integer)) or ploidy <= 0:
        raise ValueError("ploidy must be a positive integer.")

    return np.sum(gts, axis=1) / (gts.shape[1] * ploidy)
52
+
53
+
54
def compute_matching_loci(
    ref_gts: np.ndarray,
    tgt_gts: np.ndarray,
    src_gts_list: list[np.ndarray],
    w: float,
    y_list: list[tuple[str, float]],
    ploidy: list[int],
    anc_allele_available: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Computes loci that meet specified allele frequency conditions across reference, target, and source genotypes.

    Parameters
    ----------
    ref_gts : np.ndarray
        2D genotype matrix of the reference group (rows: loci, columns: individuals).
    tgt_gts : np.ndarray
        2D genotype matrix of the target group (rows: loci, columns: individuals).
    src_gts_list : list of np.ndarray
        One 2D genotype matrix per source population (rows: loci, columns: individuals).
    w : float
        Threshold for the reference allele frequency; only loci with frequency strictly
        below ``w`` are kept. Must be within the range [0, 1].
    y_list : list of tuple[str, float]
        One ``(operator, threshold)`` condition per source population in ``src_gts_list``:
        - ``operator`` is one of '=', '<', '>', '<=', '>='
        - ``threshold`` is a float within [0, 1]
        The length must match ``src_gts_list``.
    ploidy : list[int]
        Ploidy values for the reference, the target, and each source population, in that
        order. Must contain at least ``2 + len(src_gts_list)`` values; extra values are ignored.
    anc_allele_available : bool
        If True, each source frequency is compared against ``y`` only.
        If False, loci where every source satisfies its condition against ``1 - y`` are
        accepted as well, and at those loci the reference and target frequencies are
        replaced by ``1 - freq``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, np.ndarray]
        - Reference allele frequencies (``ref_freq``), flipped where applicable.
        - Target allele frequencies (``tgt_freq``), flipped where applicable.
        - Boolean mask (``condition``) of loci that satisfy all source conditions and
          have a reference frequency below ``w``.

    Raises
    ------
    ValueError
        If ``w`` or any threshold in ``y_list`` is outside [0, 1], an operator is not
        supported, ``y_list`` and ``src_gts_list`` differ in length, or ``ploidy`` does
        not provide a value for every population.
    """
    # Element-wise comparison for each supported operator.
    op_funcs = {
        "=": np.equal,
        "<": np.less,
        ">": np.greater,
        "<=": np.less_equal,
        ">=": np.greater_equal,
    }

    # Validate input parameters
    if not (0 <= w <= 1):
        raise ValueError("Parameters w must be within the range [0, 1].")

    for op, y in y_list:
        if not (0 <= y <= 1):
            raise ValueError(
                f"Invalid value in y_list: {y}. Must be within the range [0, 1]."
            )
        if op not in ("=", "<", ">", "<=", ">="):
            raise ValueError(
                f"Invalid operator in y_list: {op}. Must be '=', '<', '>', '<=', or '>='."
            )

    if len(src_gts_list) != len(y_list):
        raise ValueError("The length of src_gts_list and y_list must match.")

    # zip() below would silently drop any source lacking a ploidy value, and with it
    # that source's y_list condition, so reject a short ploidy list up front.
    if len(ploidy) < 2 + len(src_gts_list):
        raise ValueError(
            "ploidy must contain values for the reference, the target, "
            "and every source population."
        )

    # Compute allele frequencies
    ref_freq = calc_freq(ref_gts, ploidy[0])
    tgt_freq = calc_freq(tgt_gts, ploidy[1])
    src_freq_list = [
        calc_freq(src_gts, src_ploidy)
        for src_gts, src_ploidy in zip(src_gts_list, ploidy[2:])
    ]

    # Loci where every source satisfies its (operator, y) condition
    all_match_y = np.all(
        [
            op_funcs[op](src_freq, y)
            for src_freq, (op, y) in zip(src_freq_list, y_list)
        ],
        axis=0,
    )

    if anc_allele_available:
        all_match = all_match_y
    else:
        # Also accept loci where every source satisfies its condition against 1 - y
        all_match_1_minus_y = np.all(
            [
                op_funcs[op](src_freq, 1 - y)
                for src_freq, (op, y) in zip(src_freq_list, y_list)
            ],
            axis=0,
        )
        all_match = all_match_y | all_match_1_minus_y

        # At those loci, express reference and target frequencies for the other allele
        inverted = all_match_1_minus_y
        ref_freq[inverted] = 1 - ref_freq[inverted]
        tgt_freq[inverted] = 1 - tgt_freq[inverted]

    # Final condition: locus must satisfy source matching and have `ref_freq < w`
    condition = all_match & (ref_freq < w)

    return ref_freq, tgt_freq, condition
156
+
157
+
158
def calc_four_pops_freq(
    ref_gts: np.ndarray,
    tgt_gts: np.ndarray,
    src_gts: np.ndarray,
    out_gts: Optional[np.ndarray] = None,
    ref_ploidy: int = 1,
    tgt_ploidy: int = 1,
    src_ploidy: int = 1,
    out_ploidy: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Allele frequencies of the reference, target, source and outgroup populations.

    Each frequency vector is produced by ``calc_freq`` with the matching ploidy.

    Parameters
    ----------
    ref_gts : np.ndarray
        Genotype matrix of the reference population.
    tgt_gts : np.ndarray
        Genotype matrix of the target population.
    src_gts : np.ndarray
        Genotype matrix of the source population.
    out_gts : Optional[np.ndarray]
        Genotype matrix of the outgroup. When None, the outgroup frequency is
        taken to be 0 at every locus. Default: None.
    ref_ploidy : int, optional
        Ploidy of the reference genomes. Default: 1 (phased data).
    tgt_ploidy : int, optional
        Ploidy of the target genomes. Default: 1 (phased data).
    src_ploidy : int, optional
        Ploidy of the source genomes. Default: 1 (phased data).
    out_ploidy : int, optional
        Ploidy of the outgroup genomes; only used when ``out_gts`` is given.
        Default: 1 (phased data).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
        Allele frequencies in the order (ref, tgt, src, out).
    """
    ref_freq, tgt_freq, src_freq = (
        calc_freq(gts, ploidy)
        for gts, ploidy in (
            (ref_gts, ref_ploidy),
            (tgt_gts, tgt_ploidy),
            (src_gts, src_ploidy),
        )
    )

    # Without an outgroup, every locus gets outgroup frequency 0.
    out_freq = (
        calc_freq(out_gts, out_ploidy)
        if out_gts is not None
        else np.zeros_like(ref_freq)
    )

    return ref_freq, tgt_freq, src_freq, out_freq
205
+
206
+
207
def calc_pattern_sum(
    ref_freq: np.ndarray,
    tgt_freq: np.ndarray,
    src_freq: np.ndarray,
    out_freq: np.ndarray,
    pattern: str,
) -> float:
    """
    Applies an ABBA-like pattern and returns the sum over loci of the transformed frequency products.

    Parameters
    ----------
    ref_freq:
        Allele frequencies for the reference population (no introgression) across loci.
    tgt_freq:
        Allele frequencies for the target population (receive introgression) across loci.
    src_freq:
        Allele frequencies for the source population (provide introgression) across loci.
    out_freq:
        Allele frequencies for the outgroup across loci.
    pattern : str
        A 4-character pattern string (case-insensitive, e.g. 'abba'), one character per
        population in the order (ref, tgt, src, out), where:
        - 'a': use 1 - freq
        - 'b': use freq

    All frequency inputs may be NumPy arrays (of any numeric dtype) or sequences;
    they are converted to float arrays.

    Returns
    -------
    float
        Sum over loci of the product defined by the pattern.

    Raises
    ------
    ValueError
        - If the pattern string is not exactly four characters long.
        - If the pattern contains characters other than 'a' or 'b'.
    """
    if len(pattern) != 4:
        raise ValueError("Pattern must be a four-character string.")

    pattern = pattern.lower()
    # Validate the whole pattern before doing any arithmetic.
    for c in pattern:
        if c not in ("a", "b"):
            raise ValueError(
                f"Invalid character '{c}' in pattern. Only 'a' and 'b' allowed."
            )

    # Work in float64: an integer accumulator cannot take float products in place
    # (NumPy casting error), and `1 - list` is a TypeError for plain sequences.
    freqs = [
        np.asarray(freq, dtype=float)
        for freq in (ref_freq, tgt_freq, src_freq, out_freq)
    ]
    product = np.ones_like(freqs[0])

    for c, freq in zip(pattern, freqs):
        product *= (1 - freq) if c == "a" else freq

    return float(np.sum(product))