sai-pg 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sai/stats/features.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+
23
+
24
def calc_freq(gts: np.ndarray, ploidy: int = 1) -> np.ndarray:
    """
    Computes the per-locus allele frequency from a genotype matrix.

    Parameters
    ----------
    gts : np.ndarray
        A 2D numpy array with one row per locus and one column per individual.
    ploidy : int, optional
        Number of allele copies represented by each column entry. With the default
        of 1 the result is simply the row mean; for larger values the row sum is
        divided by ``number of columns * ploidy``.

    Returns
    -------
    np.ndarray
        A 1D array with one allele frequency per locus.
    """
    # Total number of allele copies sampled at every locus.
    n_alleles = gts.shape[1] * ploidy
    allele_counts = np.sum(gts, axis=1)
    return allele_counts / n_alleles
43
+
44
+
45
def compute_matching_loci(
    ref_gts: np.ndarray,
    tgt_gts: np.ndarray,
    src_gts_list: list[np.ndarray],
    w: float,
    y_list: list[tuple[str, float]],
    ploidy: int,
    anc_allele_available: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Computes loci that meet specified allele frequency conditions across reference, target, and source genotypes.

    Parameters
    ----------
    ref_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the reference group.
    tgt_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the target group.
    src_gts_list : list of np.ndarray
        A list of 2D numpy arrays for each source population, where each row represents a locus and each column
        represents an individual in that source population.
    w : float
        Threshold for the allele frequency in `ref_gts`. Only loci with frequencies less than `w` are counted.
        Must be within the range [0, 1].
    y_list : list of tuple[str, float]
        List of allele frequency conditions for each source population in `src_gts_list`.
        Each entry is a tuple (operator, threshold), where:
        - `operator` can be '=', '<', '>', '<=', '>='
        - `threshold` is a float within [0, 1]
        The length must match `src_gts_list`.
    ploidy : int
        The ploidy level of the organism.
    anc_allele_available : bool
        If True, checks only for matches with `y` (assuming `1` represents the derived allele).
        If False, additionally checks the same condition for the opposite allele, i.e. whether
        `1 - src_freq` satisfies the condition. For `=` this is `src_freq = 1 - y`; for an inequality
        the operator is mirrored (`src_freq > y` becomes `src_freq < 1 - y`, and so on).
        Loci matching the opposite-allele condition have their reference and target frequencies
        replaced by `1 - freq`.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, np.ndarray]
        - Adjusted reference allele frequencies (`ref_freq`).
        - Adjusted target allele frequencies (`tgt_freq`).
        - Boolean array indicating loci that meet the specified frequency conditions (`condition`).

    Raises
    ------
    ValueError
        If `w` or any threshold in `y_list` is outside [0, 1], if an operator is not supported,
        or if `src_gts_list` and `y_list` differ in length.
    """
    # Validate input parameters
    if not (0 <= w <= 1):
        raise ValueError("Parameters w must be within the range [0, 1].")

    for op, y in y_list:
        if not (0 <= y <= 1):
            raise ValueError(
                f"Invalid value in y_list: {y}. Must be within the range [0, 1]."
            )
        if op not in ("=", "<", ">", "<=", ">="):
            raise ValueError(
                f"Invalid operator in y_list: {op}. Must be '=', '<', '>', '<=', or '>='."
            )

    if len(src_gts_list) != len(y_list):
        raise ValueError("The length of src_gts_list and y_list must match.")

    # Compute allele frequencies
    ref_freq = calc_freq(ref_gts, ploidy)
    tgt_freq = calc_freq(tgt_gts, ploidy)
    src_freq_list = [calc_freq(src_gts, ploidy) for src_gts in src_gts_list]

    # For every operator: (comparison applied to `src_freq` vs `y`,
    #                      comparison applied to `src_freq` vs `1 - y` for the opposite allele).
    # `1 - f  op  y` is equivalent to `f  mirrored_op  1 - y`, hence inequalities flip direction.
    comparators = {
        "=": (np.equal, np.equal),
        "<": (np.less, np.greater),
        ">": (np.greater, np.less),
        "<=": (np.less_equal, np.greater_equal),
        ">=": (np.greater_equal, np.less_equal),
    }

    # Check match for each `y`
    match_conditions = [
        comparators[op][0](src_freq, y)
        for src_freq, (op, y) in zip(src_freq_list, y_list)
    ]
    all_match_y = np.all(match_conditions, axis=0)

    if not anc_allele_available:
        # Check whether the opposite allele satisfies the condition in every source population
        match_conditions_flipped = [
            comparators[op][1](src_freq, 1 - y)
            for src_freq, (op, y) in zip(src_freq_list, y_list)
        ]
        all_match_flipped = np.all(match_conditions_flipped, axis=0)
        all_match = all_match_y | all_match_flipped

        # Loci matched through the opposite allele are re-polarised so that the
        # returned frequencies refer to the allele that satisfied the condition.
        inverted = all_match_flipped
        ref_freq[inverted] = 1 - ref_freq[inverted]
        tgt_freq[inverted] = 1 - tgt_freq[inverted]
    else:
        all_match = all_match_y

    # Final condition: locus must satisfy source matching and have `ref_freq < w`
    condition = all_match & (ref_freq < w)

    return ref_freq, tgt_freq, condition
144
+
145
+
146
def calc_u(
    ref_gts: np.ndarray,
    tgt_gts: np.ndarray,
    src_gts_list: list[np.ndarray],
    pos: np.ndarray,
    w: float,
    x: float,
    y_list: list[tuple[str, float]],
    ploidy: int = 1,
    anc_allele_available: bool = False,
) -> tuple[int, np.ndarray]:
    """
    Calculates the count of genetic loci that meet specified allele frequency conditions
    across reference, target, and multiple source genotypes, with adjustments based on src_freq consistency.

    Parameters
    ----------
    ref_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the reference group.
    tgt_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the target group.
    src_gts_list : list of np.ndarray
        A list of 2D numpy arrays for each source population, where each row represents a locus and each column
        represents an individual in that source population.
    pos : np.ndarray
        A 1D numpy array where each element represents the genomic position.
    w : float
        Threshold for the allele frequency in `ref_gts`. Only loci with frequencies less than `w` are counted.
        Must be within the range [0, 1].
    x : float
        Threshold for the allele frequency in `tgt_gts`. Only loci with frequencies greater than `x` are counted.
        Must be within the range [0, 1].
    y_list : list of tuple[str, float]
        List of allele frequency conditions, one per source population in `src_gts_list`.
        Each entry is a tuple (operator, threshold) where `operator` is one of
        '=', '<', '>', '<=', '>=' and `threshold` is a float within [0, 1].
        Must have the same length as `src_gts_list`.
    ploidy : int, optional
        The ploidy level of the organism. Default is 1, which assumes phased data.
    anc_allele_available : bool
        If True, checks only for matches with `y` (assuming `1` represents the derived allele).
        If False, checks both matches with `y` and `1 - y`, taking the major allele in the source as the reference.

    Returns
    -------
    tuple[int, np.ndarray]
        - The count of loci that meet all specified frequency conditions.
        - A 1D numpy array containing the genomic positions of the loci that meet the conditions.

    Raises
    ------
    ValueError
        If `x` is outside the range [0, 1], or if `compute_matching_loci` rejects `w` or `y_list`.
    """
    # Validate input parameters
    if not (0 <= x <= 1):
        raise ValueError("Parameter x must be within the range [0, 1].")

    # The adjusted reference frequencies are already folded into `condition`.
    _, tgt_freq, condition = compute_matching_loci(
        ref_gts,
        tgt_gts,
        src_gts_list,
        w,
        y_list,
        ploidy,
        anc_allele_available,
    )

    # Additionally require the target frequency to exceed `x`
    condition = condition & (tgt_freq > x)

    loci_indices = np.where(condition)[0]
    loci_positions = pos[loci_indices]
    count = loci_indices.size

    return count, loci_positions
221
+
222
+
223
def calc_q(
    ref_gts: np.ndarray,
    tgt_gts: np.ndarray,
    src_gts_list: list[np.ndarray],
    pos: np.ndarray,
    w: float,
    y_list: list[tuple[str, float]],
    quantile: float = 0.95,
    ploidy: int = 1,
    anc_allele_available: bool = False,
) -> tuple[float, np.ndarray]:
    """
    Calculates a specified quantile of derived allele frequencies in `tgt_gts` for loci that meet specific conditions
    across reference and multiple source genotypes, with adjustments based on src_freq consistency.

    Parameters
    ----------
    ref_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the reference group.
    tgt_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the target group.
    src_gts_list : list of np.ndarray
        A list of 2D numpy arrays for each source population, where each row represents a locus and each column
        represents an individual in that source population.
    pos : np.ndarray
        A 1D numpy array where each element represents the genomic position.
    w : float
        Frequency threshold for the derived allele in `ref_gts`. Only loci with frequencies lower than `w` are included.
        Must be within the range [0, 1].
    y_list : list of tuple[str, float]
        List of allele frequency conditions, one per source population in `src_gts_list`.
        Each entry is a tuple (operator, threshold) where `operator` is one of
        '=', '<', '>', '<=', '>=' and `threshold` is a float within [0, 1].
        Must have the same length as `src_gts_list`.
    quantile : float, optional
        The quantile to compute for the filtered `tgt_gts` frequencies. Must be within the range [0, 1].
        Default is 0.95 (95% quantile).
    ploidy : int, optional
        The ploidy level of the organism. Default is 1, which assumes phased data.
    anc_allele_available : bool
        If True, checks only for matches with `y` (assuming `1` represents the derived allele).
        If False, checks both matches with `y` and `1 - y`, taking the major allele in the source as the reference.

    Returns
    -------
    tuple[float, np.ndarray]
        - The specified quantile of the derived allele frequencies in `tgt_gts` for loci meeting the specified conditions,
          or NaN if no loci meet the criteria.
        - A 1D numpy array containing the genomic positions of the matching loci whose target frequency is
          greater than or equal to that quantile (an empty array if no loci meet the criteria).

    Raises
    ------
    ValueError
        If `quantile` is outside the range [0, 1], or if `compute_matching_loci` rejects `w` or `y_list`.
    """
    # Validate input parameters
    if not (0 <= quantile <= 1):
        raise ValueError("Parameter quantile must be within the range [0, 1].")

    # The adjusted reference frequencies are already folded into `condition`.
    _, tgt_freq, condition = compute_matching_loci(
        ref_gts,
        tgt_gts,
        src_gts_list,
        w,
        y_list,
        ploidy,
        anc_allele_available,
    )

    # Filter `tgt_gts` frequencies based on the combined condition
    filtered_tgt_freq = tgt_freq[condition]
    filtered_positions = pos[condition]

    # Return NaN if no loci meet the criteria
    if filtered_tgt_freq.size == 0:
        return np.nan, np.array([])

    # Quantile of the retained target frequencies, ignoring NaN entries
    threshold = np.nanquantile(filtered_tgt_freq, quantile)
    loci_positions = filtered_positions[filtered_tgt_freq >= threshold]

    return threshold, loci_positions
sai/utils/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ from .genomic_dataclasses import *
22
+ from .utils import *
@@ -0,0 +1,23 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ from .data_generator import DataGenerator
22
+ from .chunk_generator import ChunkGenerator
23
+ from .window_generator import WindowGenerator
@@ -0,0 +1,148 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import pysam
22
+ from typing import Iterator
23
+ from sai.utils import split_genome
24
+ from sai.utils.generators import DataGenerator
25
+
26
+
27
class ChunkGenerator(DataGenerator):
    """
    Generates genome chunks from VCF windows for parallel processing.

    This class splits genomic windows into non-overlapping chunks assigned to workers,
    based on the VCF file length and a user-defined window and step size.
    """

    def __init__(
        self,
        vcf_file: str,
        chr_name: str,
        step_size: int,
        window_size: int,
        num_chunks: int,
    ):
        """
        Initializes a new instance of ChunkGenerator.

        Parameters
        ----------
        vcf_file : str
            Path to the VCF file to process.
        chr_name: str
            Name of the chromosome to process.
        step_size : int
            Step size for generating windows.
        window_size : int
            Window size for generating windows.
        num_chunks : int
            Number of chunks to split the windows among. Must be at least 1.

        Raises
        ------
        ValueError
            If the specified chromosome is not found in the VCF file,
            or if `num_chunks` is smaller than 1.
        """
        with pysam.VariantFile(vcf_file) as vcf:
            first_pos = last_pos = None
            for rec in vcf:
                if rec.chrom != chr_name:
                    # Stop as soon as another chromosome follows the requested one.
                    # NOTE(review): this assumes records of one chromosome are contiguous in the file.
                    if first_pos is not None:
                        break
                    continue
                if first_pos is None:
                    first_pos = rec.pos
                last_pos = rec.pos

        if first_pos is None:
            raise ValueError(f"Chromosome {chr_name} not found in VCF.")

        windows = split_genome([first_pos, last_pos], window_size, step_size)

        self.chunks = self._split_windows_ranges(windows, num_chunks)
        # Fewer chunks than requested are produced when there are fewer windows than chunks.
        self.num_chunks = len(self.chunks)
        self.chr_name = chr_name

    def get(self) -> Iterator[dict]:
        """
        Yields a dictionary describing the chunk assigned to each worker.

        Yields
        ------
        dict
            A dictionary with the keys ``"chr_name"``, ``"start"`` and ``"end"``
            describing the range assigned to one worker.
        """
        for chunk in self.chunks:
            yield {
                "chr_name": self.chr_name,
                "start": chunk[0],
                "end": chunk[1],
            }

    def __len__(self) -> int:
        """
        Returns the number of chunks.

        Returns
        -------
        int
            Number of chunks.
        """
        return self.num_chunks

    def _split_windows_ranges(self, windows: list, num_chunks: int) -> list:
        """
        Splits the list of windows into ranges assigned to each chunk.

        Each range is defined by the first window's start and the last window's end
        within that split. When a range would start before the end of the previous
        range, its start is moved to one past the previous end.

        Parameters
        ----------
        windows : list of tuple
            List of (start, end) tuples representing windows.
        num_chunks : int
            Number of chunks to divide the windows among. Must be at least 1.

        Returns
        -------
        list of tuple
            List of (start, end) tuples representing the ranges for each chunk.
            Chunks that would receive no window are omitted.

        Raises
        ------
        ValueError
            If `num_chunks` is smaller than 1.
        """
        if num_chunks < 1:
            raise ValueError("num_chunks must be a positive integer.")

        # The first `remainder` chunks receive one extra window each.
        avg, remainder = divmod(len(windows), num_chunks)
        result = []
        start_idx = 0
        prev_end = None

        for i in range(num_chunks):
            end_idx = start_idx + avg + (1 if i < remainder else 0)
            sub = windows[start_idx:end_idx]
            if sub:
                min_start = sub[0][0]
                max_end = sub[-1][1]
                if (prev_end is not None) and (min_start < prev_end):
                    min_start = prev_end + 1
                result.append((min_start, max_end))
                prev_end = max_end
            start_idx = end_idx

        return result
@@ -0,0 +1,49 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ from abc import ABC, abstractmethod
22
+
23
+
24
class DataGenerator(ABC):
    """
    Common interface shared by all data generators.

    Concrete generators derive from this class and supply their own ``get``
    implementation. Because ``get`` is abstract, this class cannot be
    instantiated directly.
    """

    @abstractmethod
    def get(self, **kwargs):
        """
        Produces data according to the supplied keyword arguments.

        Parameters
        ----------
        **kwargs
            Implementation-specific options interpreted by the subclass.

        Returns
        -------
        object
            The generated data; its type and format are defined by the subclass.
            This base implementation returns None.
        """
        return None