sai-pg 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sai/__init__.py +18 -0
- sai/__main__.py +73 -0
- sai/parsers/__init__.py +18 -0
- sai/parsers/argument_validation.py +169 -0
- sai/parsers/outlier_parser.py +76 -0
- sai/parsers/plot_parser.py +152 -0
- sai/parsers/score_parser.py +241 -0
- sai/sai.py +315 -0
- sai/stats/__init__.py +18 -0
- sai/stats/features.py +302 -0
- sai/utils/__init__.py +22 -0
- sai/utils/generators/__init__.py +23 -0
- sai/utils/generators/chunk_generator.py +148 -0
- sai/utils/generators/data_generator.py +49 -0
- sai/utils/generators/window_generator.py +250 -0
- sai/utils/genomic_dataclasses.py +46 -0
- sai/utils/multiprocessing/__init__.py +22 -0
- sai/utils/multiprocessing/mp_manager.py +251 -0
- sai/utils/multiprocessing/mp_pool.py +73 -0
- sai/utils/preprocessors/__init__.py +23 -0
- sai/utils/preprocessors/chunk_preprocessor.py +152 -0
- sai/utils/preprocessors/data_preprocessor.py +94 -0
- sai/utils/preprocessors/feature_preprocessor.py +211 -0
- sai/utils/utils.py +689 -0
- sai_pg-1.0.0.dist-info/METADATA +44 -0
- sai_pg-1.0.0.dist-info/RECORD +30 -0
- sai_pg-1.0.0.dist-info/WHEEL +5 -0
- sai_pg-1.0.0.dist-info/entry_points.txt +2 -0
- sai_pg-1.0.0.dist-info/licenses/LICENSE +674 -0
- sai_pg-1.0.0.dist-info/top_level.txt +1 -0
sai/utils/utils.py
ADDED
@@ -0,0 +1,689 @@
|
|
1
|
+
# Copyright 2025 Xin Huang
|
2
|
+
#
|
3
|
+
# GNU General Public License v3.0
|
4
|
+
#
|
5
|
+
# This program is free software: you can redistribute it and/or modify
|
6
|
+
# it under the terms of the GNU General Public License as published by
|
7
|
+
# the Free Software Foundation, either version 3 of the License, or
|
8
|
+
# (at your option) any later version.
|
9
|
+
#
|
10
|
+
# This program is distributed in the hope that it will be useful,
|
11
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
13
|
+
# GNU General Public License for more details.
|
14
|
+
#
|
15
|
+
# You should have received a copy of the GNU General Public License
|
16
|
+
# along with this program. If not, please see
|
17
|
+
#
|
18
|
+
# https://www.gnu.org/licenses/gpl-3.0.en.html
|
19
|
+
|
20
|
+
|
21
|
+
import allel
|
22
|
+
import numpy as np
|
23
|
+
import pandas as pd
|
24
|
+
from typing import Optional, Union
|
25
|
+
from natsort import natsorted
|
26
|
+
from sai.utils.genomic_dataclasses import ChromosomeData
|
27
|
+
|
28
|
+
|
29
|
+
def parse_ind_file(filename: str) -> dict[str, list[str]]:
    """
    Read sample information from a file and organize it by categories.

    Each line is split on whitespace; lines with exactly two fields are read as
    ``<category> <sample>``. Lines with any other number of fields (including
    blank lines) are skipped.

    Parameters
    ----------
    filename : str
        The name of the file containing sample information.

    Returns
    -------
    samples : dict of str to list of str
        A dictionary where the keys represent categories, and the values are lists of samples
        associated with those categories, in the order they appear in the file.

    Raises
    ------
    FileNotFoundError
        If the specified file does not exist.
    ValueError
        If no samples are found in the file.
    """
    samples: dict[str, list[str]] = {}

    # Only opening the file can raise FileNotFoundError; keep the try minimal
    # and chain the original exception so the underlying cause is preserved.
    try:
        f = open(filename, "r")
    except FileNotFoundError as exc:
        raise FileNotFoundError(
            f"File '{filename}' not found. Please check the file path."
        ) from exc

    with f:
        for line in f:
            parts = line.split()
            if len(parts) != 2:
                continue
            category, sample = parts
            samples.setdefault(category, []).append(sample)

    if not samples:
        raise ValueError(f"No samples found in {filename}. Please check your data.")

    return samples
|
74
|
+
|
75
|
+
|
76
|
+
def read_geno_data(
    vcf: str,
    ind_samples: dict[str, list[str]],
    chr_name: str,
    start: int = None,
    end: int = None,
    anc_allele_file: Optional[str] = None,
    filter_missing: bool = True,
) -> tuple[ChromosomeData, list[str], int]:
    """
    Read genotype data from a VCF file efficiently for a specified chromosome.

    Parameters
    ----------
    vcf : str
        The name of the VCF file containing genotype data.
    ind_samples : dict of str to list of str
        A dictionary where keys are categories (e.g., different sample groups), and values are lists of sample names.
        Only the values are used: all listed samples are requested from the VCF.
    chr_name : str
        The name of the chromosome to read.
    start: int, optional
        The starting position (1-based, inclusive) on the chromosome. Default: None.
    end: int, optional
        The ending position (1-based, inclusive) on the chromosome. Default: None.
    anc_allele_file : str, optional
        The name of the BED file containing ancestral allele information, or None if not provided.
    filter_missing : bool, optional
        Indicates whether to filter out variants that have a missing genotype in any sample. Default: True.

    Returns
    -------
    chrom_data: ChromosomeData
        A ChromosomeData instance for the specified chromosome in the VCF, or None if
        ``allel.read_vcf`` returned nothing for the region.
    samples: list
        Sample names as returned by ``allel.read_vcf``; this is the order of the
        sample axis of ``chrom_data.GT``. When no data is read, the flattened list of
        requested samples is returned instead.
    ploidy: int
        Ploidy level of the organism (None when no data is read).

    Raises
    ------
    ValueError
        If reading the VCF fails or essential fields are missing.
    """
    # Flatten the requested samples from every category into one list
    all_samples = [sample for samples in ind_samples.values() for sample in samples]

    # Restrict reading to the chromosome, or to chr:start-end when a range is given.
    # Built before the ``try`` so the error message below can always reference it.
    if (start is None) and (end is None):
        region = f"{chr_name}"
    else:
        region = f"{chr_name}:{start}-{end}"

    try:
        vcf_data = allel.read_vcf(
            vcf,
            fields=[
                "calldata/GT",
                "variants/CHROM",
                "variants/POS",
                "variants/REF",
                "variants/ALT",
                "samples",
            ],
            alt_number=1,
            samples=all_samples,
            region=region,  # Specify the chromosome region
            tabix=None,
        )
    except Exception as e:
        raise ValueError(f"Failed to read VCF file {vcf} from {region}: {e}") from e

    if vcf_data is None:
        return None, all_samples, None

    raw_gt = vcf_data.get("calldata/GT")
    pos = vcf_data.get("variants/POS")
    ref = vcf_data.get("variants/REF")
    alt = vcf_data.get("variants/ALT")

    # Validate before constructing the GenotypeArray, otherwise a missing GT field
    # would fail inside allel instead of producing this message.
    if raw_gt is None or pos is None or ref is None or alt is None:
        raise ValueError("Invalid VCF file: Missing essential genotype data fields.")

    gt = allel.GenotypeArray(raw_gt)
    ploidy = gt.shape[2]

    # Load ancestral allele data if provided
    if anc_allele_file:
        anc_alleles = read_anc_allele(
            anc_allele_file=anc_allele_file,
            chr_name=chr_name,
            start=start,
            end=end,
        )
    else:
        anc_alleles = None

    # GT's sample axis already matches vcf_data["samples"] (both come from the same
    # read_vcf call), which is the sample list returned below, so no reordering is
    # needed. The previous identity "take" built from all_samples.index() could
    # point past the last GT column when a sample name was requested twice.
    chrom_data = ChromosomeData(POS=pos, REF=ref, ALT=alt, GT=gt)

    # Remove missing data if specified
    if filter_missing:
        non_missing_index = chrom_data.GT.count_missing(axis=1) == 0
        num_missing = len(non_missing_index) - np.sum(non_missing_index)
        if num_missing != 0:
            print(
                f"Found {num_missing} variants with missing genotypes, removing them ..."
            )
        chrom_data = filter_geno_data(chrom_data, non_missing_index)

    # Check and incorporate ancestral alleles if the file is provided
    if anc_alleles:
        chrom_data = check_anc_allele(chrom_data, anc_alleles, chr_name)

    return chrom_data, vcf_data.get("samples"), ploidy
|
186
|
+
|
187
|
+
|
188
|
+
def filter_geno_data(
    data: ChromosomeData, index: Union[np.ndarray, list[bool]]
) -> ChromosomeData:
    """
    Filter the genotype data based on the provided index.

    Parameters
    ----------
    data : ChromosomeData
        An instance of ChromosomeData containing genotype data, where each attribute corresponds to an array (e.g., POS, REF, ALT, GT).
        The first axis of every array is the variant axis.
    index : np.ndarray or list of bool
        Either a boolean mask with one entry per variant (True = keep), or an
        array/list of integer variant positions to keep.

    Returns
    -------
    ChromosomeData
        A new ChromosomeData instance with filtered data, containing only the rows specified by the index.
    """
    index = np.asarray(index)

    if index.dtype == bool:
        gt = data.GT.compress(index, axis=0)
    else:
        # ``compress`` reads its argument as truth values, so integer positions would
        # mis-filter GT while POS/REF/ALT are fancy-indexed by position. ``take``
        # selects the same rows as the fancy indexing below.
        if index.size == 0:
            # An empty list becomes a float array; cast so it is a valid index.
            index = index.astype(np.intp)
        gt = data.GT.take(index, axis=0)

    return ChromosomeData(
        POS=data.POS[index],
        REF=data.REF[index],
        ALT=data.ALT[index],
        GT=gt,
    )
|
212
|
+
|
213
|
+
|
214
|
+
def read_data(
    vcf_file: str,
    chr_name: str,
    ref_ind_file: Optional[str],
    tgt_ind_file: Optional[str],
    src_ind_file: Optional[str],
    anc_allele_file: Optional[str],
    start: int = None,
    end: int = None,
    is_phased: bool = True,
    filter_ref: bool = True,
    filter_tgt: bool = True,
    filter_src: bool = False,
    filter_missing: bool = True,
) -> tuple[
    Optional[dict[str, ChromosomeData]],
    Optional[dict[str, list[str]]],
    Optional[dict[str, ChromosomeData]],
    Optional[dict[str, list[str]]],
    Optional[dict[str, ChromosomeData]],
    Optional[dict[str, list[str]]],
    Optional[int],
]:
    """
    Helper function for reading data from reference, target, and source populations.

    Parameters
    ----------
    vcf_file : str
        Name of the VCF file containing genotype data.
    chr_name: str
        Name of the chromosome to read.
    ref_ind_file : str or None
        File with reference population sample information. None if not provided.
    tgt_ind_file : str or None
        File with target population sample information. None if not provided.
    src_ind_file : str or None
        File with source population sample information. None if not provided.
    anc_allele_file : str or None
        File with ancestral allele information. None if not provided.
    start: int, optional
        The starting position (1-based, inclusive) on the chromosome. Default: None.
    end: int, optional
        The ending position (1-based, inclusive) on the chromosome. Default: None.
    is_phased : bool, optional
        Whether to use phased genotypes. Default: True.
    filter_ref : bool, optional
        Whether to filter fixed variants for reference data. Default: True.
    filter_tgt : bool, optional
        Whether to filter fixed variants for target data. Default: True.
    filter_src : bool, optional
        Whether to filter fixed variants for source data. Default: False.
    filter_missing : bool, optional
        Whether to filter out missing data. Default: True.

    Returns
    -------
    ref_data : dict or None
        Genotype data from reference populations, keyed by sample category.
    ref_samples : dict or None
        Sample information from reference populations.
    tgt_data : dict or None
        Genotype data from target populations, keyed by sample category.
    tgt_samples : dict or None
        Sample information from target populations.
    src_data : dict or None
        Genotype data from source populations, keyed by sample category.
    src_samples : dict or None
        Sample information from source populations.
    ploidy: int or None
        Ploidy level of the organism.

    Raises
    ------
    FileNotFoundError, ValueError
        Propagated from parsing the individual files.
    ValueError
        If reading the VCF data fails.

    Notes
    -----
    ``ref_data``, ``tgt_data`` and ``src_data`` map each sample category (as read
    from the corresponding individual file) to a ChromosomeData instance for
    ``chr_name`` containing:

    - "POS": variant positions.
    - "REF": reference alleles.
    - "ALT": alternative alleles.
    - "GT": genotypes of the category's samples, flattened to 2-D by
      ``reshape_genotypes`` (one column per haplotype when phased, per-individual
      allele sums otherwise).
    """
    # Parse sample information for each population that was provided
    ref_samples = parse_ind_file(ref_ind_file) if ref_ind_file else None
    tgt_samples = parse_ind_file(tgt_ind_file) if tgt_ind_file else None
    src_samples = parse_ind_file(src_ind_file) if src_ind_file else None

    # Gather every requested sample (in order, without duplicates) so the VCF is read
    # once. Merging the per-population dicts with ``dict.update`` (keyed by category)
    # silently dropped an earlier population's samples whenever two individual files
    # used the same category name.
    requested_samples = list(
        dict.fromkeys(
            sample
            for groups in (ref_samples, tgt_samples, src_samples)
            if groups
            for samples in groups.values()
            for sample in samples
        )
    )

    try:
        # read_geno_data only flattens the values of ``ind_samples``, so one key suffices
        geno_data, vcf_samples, ploidy = read_geno_data(
            vcf=vcf_file,
            ind_samples={"all": requested_samples},
            chr_name=chr_name,
            start=start,
            end=end,
            anc_allele_file=anc_allele_file,
            filter_missing=filter_missing,
        )
    except Exception as e:
        raise ValueError(f"Failed to read VCF data: {e}") from e

    if geno_data is None:
        return None, ref_samples, None, tgt_samples, None, src_samples, None

    # Separate reference, target, and source data
    ref_data = extract_group_data(geno_data, vcf_samples, ref_samples)
    tgt_data = extract_group_data(geno_data, vcf_samples, tgt_samples)
    src_data = extract_group_data(geno_data, vcf_samples, src_samples)

    # Apply fixed variant filtering conditionally
    if filter_ref and ref_data and ref_samples:
        ref_data = filter_fixed_variants(ref_data, ref_samples)
    if filter_tgt and tgt_data and tgt_samples:
        tgt_data = filter_fixed_variants(tgt_data, tgt_samples)
    if filter_src and src_data and src_samples:
        src_data = filter_fixed_variants(src_data, src_samples)

    # Adjust genotypes based on phased/unphased requirement (in place)
    reshape_genotypes(ref_data, is_phased)
    reshape_genotypes(tgt_data, is_phased)
    reshape_genotypes(src_data, is_phased)

    return ref_data, ref_samples, tgt_data, tgt_samples, src_data, src_samples, ploidy
|
358
|
+
|
359
|
+
|
360
|
+
def extract_group_data(
    geno_data: ChromosomeData,
    all_samples: list[str],
    sample_groups: Optional[dict[str, list[str]]] = None,
) -> Optional[dict[str, ChromosomeData]]:
    """
    Extract per-group genotype data from a single ChromosomeData instance.

    Parameters
    ----------
    geno_data : ChromosomeData
        Genotype data for all samples. ``geno_data.GT`` is indexed as
        ``GT[variant, sample, allele]``, with its sample axis ordered like ``all_samples``.
    all_samples: list
        Sample names, in the order of the sample axis of ``geno_data.GT``.
    sample_groups : dict of str to list of str, optional
        Contains sample group information, where each key is a group name and the value is a list of samples.
        If None, the function returns None.

    Returns
    -------
    extracted_data : dict or None
        Genotype data organized by sample group, or None if no sample groups are provided.
        The structure is as follows:

        - Keys represent sample group names.
        - Values are ChromosomeData instances whose GT holds only the columns of the
          group's samples. POS/REF/ALT are the same array objects as in ``geno_data``
          (shared, not copied).

        Samples listed in a group but absent from ``all_samples`` are skipped without warning.
    """
    if sample_groups is None:
        return None

    # Map each sample name to its column in geno_data.GT
    sample_indices = {sample: idx for idx, sample in enumerate(all_samples)}

    extracted_data = {}
    for group, samples in sample_groups.items():
        indices = [sample_indices[s] for s in samples if s in sample_indices]

        # Extract ChromosomeData for the selected samples in each group
        extracted_data[group] = ChromosomeData(
            GT=geno_data.GT[:, indices, :],
            POS=geno_data.POS,
            REF=geno_data.REF,
            ALT=geno_data.ALT,
        )

    return extracted_data
|
406
|
+
|
407
|
+
|
408
|
+
def filter_fixed_variants(
    data: dict[str, ChromosomeData], samples: dict[str, list[str]]
) -> dict[str, ChromosomeData]:
    """
    Filter out fixed variants for each population in the given data.

    A variant is fixed in a category when every individual in that category's GT
    is homozygous for the reference allele, or every individual is homozygous for
    the alternative allele.

    Parameters
    ----------
    data : dict of str to ChromosomeData
        Genotype data organized by category, where each category is represented by a ChromosomeData instance.
    samples : dict of str to list of str
        Sample information corresponding to each category. Kept for backward
        compatibility; the number of individuals is taken from each GT array itself.

    Returns
    -------
    filtered_data : dict of str to ChromosomeData
        Genotype data with fixed variants filtered out for each category.
    """
    filtered_data = {}
    for cat, geno in data.items():
        # Test "all individuals" against the columns actually present in GT rather
        # than len(samples[cat]). GT can hold fewer columns than the sample list
        # (e.g. listed samples absent from the genotype matrix), and a count mismatch
        # would stop any variant from ever being recognised as fixed.
        ref_fixed_variants = np.all(geno.GT.is_hom_ref(), axis=1)
        alt_fixed_variants = np.all(geno.GT.is_hom_alt(), axis=1)
        index = np.logical_not(np.logical_or(ref_fixed_variants, alt_fixed_variants))
        filtered_data[cat] = filter_geno_data(geno, index)

    return filtered_data
|
436
|
+
|
437
|
+
|
438
|
+
def reshape_genotypes(
    data: Optional[dict[str, ChromosomeData]], is_phased: bool
) -> None:
    """
    Flatten each group's 3-D genotype array into a 2-D matrix, in place.

    Parameters
    ----------
    data : dict of str to ChromosomeData or None
        Genotype data keyed by sample group; each ``GT`` must be 3-D
        (variants, individuals, ploidy). Nothing happens when ``data`` is None.
    is_phased : bool
        True:  ``GT`` becomes (variants, individuals * ploidy), with each
               individual's alleles in adjacent columns.
        False: ``GT`` becomes (variants, individuals), the allele codes summed
               over the ploidy axis.

    Notes
    -----
    Each ChromosomeData is updated by rebinding its ``GT`` attribute; nothing is returned.
    NOTE(review): allele codes are summed as stored, so missing-call codes (negative
    in scikit-allel) would be folded into unphased sums; confirm missing data is
    filtered upstream when ``is_phased`` is False.
    """
    if data is not None:
        for chrom_data in data.values():
            n_var, n_ind, n_ploidy = chrom_data.GT.shape
            chrom_data.GT = (
                np.reshape(chrom_data.GT, (n_var, n_ind * n_ploidy))
                if is_phased
                else np.sum(chrom_data.GT, axis=2)
            )
|
460
|
+
|
461
|
+
|
462
|
+
def get_ref_alt_allele(
    ref: list[str], alt: list[str], pos: list[int]
) -> tuple[dict[int, str], dict[int, str]]:
    """
    Build position-keyed lookups for the REF and ALT alleles.

    Parameters
    ----------
    ref : list of str
        REF alleles, aligned with ``pos``.
    alt : list of str
        ALT alleles, aligned with ``pos``.
    pos : list of int
        Genomic positions.

    Returns
    -------
    tuple of (dict of int to str, dict of int to str)
        ``(ref_by_pos, alt_by_pos)``. If a position repeats, the last occurrence wins;
        pairing stops at the shorter of the two sequences being zipped.
    """
    ref_by_pos = dict(zip(pos, ref))
    alt_by_pos = dict(zip(pos, alt))
    return ref_by_pos, alt_by_pos
|
482
|
+
|
483
|
+
|
484
|
+
def read_anc_allele(
    anc_allele_file: str, chr_name: str, start: int = None, end: int = None
) -> dict[str, dict[int, str]]:
    """
    Reads ancestral allele information from a BED file for a specified chromosome,
    optionally within a specified position range.

    Each data line is split on whitespace:

    - column 1 is the chromosome;
    - column 3 (the BED end coordinate) is used as the 1-based position;
    - column 4 is the ancestral allele.

    Blank lines, ``#`` comment lines and BED ``track``/``browser`` header lines
    are skipped.

    Parameters
    ----------
    anc_allele_file : str
        Path to the BED file containing ancestral allele information.
    chr_name : str
        Name of the chromosome to read.
    start : int, optional
        Start position (1-based, inclusive) of the region to filter. If None, no lower bound.
    end : int, optional
        End position (1-based, inclusive) of the region to filter. If None, no upper bound.

    Returns
    -------
    dict of {str: dict of {int: str}}
        Chromosome-level dictionary mapping genomic positions to ancestral alleles.

    Raises
    ------
    FileNotFoundError
        If the ancestral allele file is not found.
    ValueError
        If a line for ``chr_name`` has fewer than four columns or a non-integer
        position, or if no ancestral allele information is found for the specified
        chromosome (and region if specified).
    """
    anc_alleles: dict[str, dict[int, str]] = {}
    try:
        with open(anc_allele_file, "r") as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                # Skip blank lines and BED comment/header lines
                if not fields or fields[0].startswith("#") or fields[0] in ("track", "browser"):
                    continue
                chrom = fields[0]
                # Check the chromosome first so lines for other chromosomes are not parsed
                if chrom != chr_name:
                    continue
                if len(fields) < 4:
                    raise ValueError(
                        f"Malformed line {line_no} in {anc_allele_file}: expected at least 4 columns."
                    )
                try:
                    pos = int(fields[2])
                except ValueError as exc:
                    raise ValueError(
                        f"Malformed line {line_no} in {anc_allele_file}: non-integer position {fields[2]!r}."
                    ) from exc
                if (start is not None and pos < start) or (end is not None and pos > end):
                    continue
                anc_alleles.setdefault(chrom, {})[pos] = fields[3]
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"File {anc_allele_file} not found.") from exc

    if not anc_alleles:
        if start is not None or end is not None:
            raise ValueError(
                f"No ancestral allele is found for chromosome {chr_name} in the region {start}-{end}."
            )
        else:
            raise ValueError(f"No ancestral allele is found for chromosome {chr_name}.")

    return anc_alleles
|
539
|
+
|
540
|
+
|
541
|
+
def check_anc_allele(
    data: ChromosomeData, anc_allele: dict[str, dict[int, str]], c: str
) -> ChromosomeData:
    """
    Checks whether the REF or ALT allele is the ancestral allele and updates genotype data.

    Only variants whose position also appears in ``anc_allele[c]`` are kept.
    Of those:

    - variants whose ancestral allele is neither REF nor ALT are removed;
    - variants whose ancestral allele equals ALT are flipped via ``flip_snps``;
    - all other kept variants are left unchanged.

    Parameters
    ----------
    data : ChromosomeData
        Genotype data for checking ancestral allele information.
    anc_allele : dict of {str: dict of {int: str}}
        Dictionary with ancestral allele information, keyed by chromosome then position.
    c : str
        Chromosome name.

    Returns
    -------
    ChromosomeData
        Genotype data with updated alleles after ancestral allele checking.
    """
    ref_allele, alt_allele = get_ref_alt_allele(data.REF, data.ALT, data.POS)
    chrom_anc = anc_allele.get(c, {})

    # Positions present in both the genotype data and the ancestral allele map
    intersect_snps = np.intersect1d(list(ref_allele.keys()), list(chrom_anc.keys()))
    removed_snps, flipped_snps = [], []

    for v in intersect_snps:
        if chrom_anc[v] not in {ref_allele[v], alt_allele[v]}:
            removed_snps.append(v)
        elif chrom_anc[v] == alt_allele[v]:
            flipped_snps.append(v)

    # Keep intersecting positions that were not marked for removal. A single mask
    # replaces two successive filters; np.isin replaces np.in1d, which is
    # deprecated in NumPy 2.0.
    keep = np.isin(data.POS, intersect_snps) & ~np.isin(data.POS, removed_snps)
    data = filter_geno_data(data, keep)

    # Flip alleles in SNPs where ALT allele is ancestral
    flip_snps(data, flipped_snps)

    return data
|
587
|
+
|
588
|
+
|
589
|
+
def flip_snps(data: ChromosomeData, flipped_snps: list[int]) -> None:
    """
    Flips the genotypes for SNPs where the ALT allele is the ancestral allele.

    ``data.GT`` is modified in place; nothing is returned.

    Parameters
    ----------
    data : ChromosomeData
        Genotype data; ``GT`` rows are aligned with ``POS``.
    flipped_snps : list of int
        List of positions where the ALT allele is ancestral.
    """
    # Create a boolean mask for positions that need to be flipped
    is_flipped = np.isin(data.POS, flipped_snps)
    if not np.any(is_flipped):
        return

    genotypes = np.asarray(data.GT[is_flipped])
    # Swap allele codes with |g - 1| (0 <-> 1), but leave negative codes untouched:
    # scikit-allel encodes missing calls as -1, which |g - 1| would turn into a
    # spurious allele 2.
    data.GT[is_flipped] = np.where(genotypes < 0, genotypes, np.abs(genotypes - 1))
|
605
|
+
|
606
|
+
|
607
|
+
def split_genome(
    pos: np.ndarray,
    window_size: int,
    step_size: int,
) -> list[tuple]:
    """
    Creates sliding windows along the genome based on variant positions.

    Windows use 1-based, inclusive coordinates. The first window ends at the
    smallest multiple of ``step_size`` that is >= ``pos[0]``, with its start
    clamped to 1. Later windows advance by ``step_size`` until a window would
    start after ``pos[-1]``.

    Parameters
    ----------
    pos : np.ndarray
        Array of positions for the variants, assumed sorted in ascending order
        (only the first and last entries are used).
    window_size : int
        Length of each sliding window.
    step_size : int
        Step size of the sliding windows.

    Returns
    -------
    list of tuple
        List of sliding windows, where each entry is a tuple (start_position, end_position)
        representing the start and end positions of each window.

    Raises
    ------
    ValueError
        - If `step_size` or `window_size` are non-positive
        - If `step_size` is greater than `window_size`
        - If the `pos` array is empty
    """
    # Validate inputs
    if step_size <= 0 or window_size <= 0:
        raise ValueError("`step_size` and `window_size` must be positive integers.")
    if step_size > window_size:
        raise ValueError("`step_size` cannot be greater than `window_size`.")
    if len(pos) == 0:
        raise ValueError("`pos` array must not be empty.")

    # Round pos[0] UP to a multiple of step_size so the first window always covers
    # the first variant. The previous (pos[0] + step_size) // step_size * step_size
    # overshot by a full step whenever pos[0] was itself a multiple of step_size,
    # e.g. window = step = 100 and pos[0] = 100 produced no window containing 100.
    first_end = (pos[0] + step_size - 1) // step_size * step_size
    win_start = max(first_end - window_size + 1, 1)

    # Create windows based on step size and window size
    window_positions = []
    while win_start <= pos[-1]:
        window_positions.append((win_start, win_start + window_size - 1))
        win_start += step_size

    return window_positions
|
656
|
+
|
657
|
+
|
658
|
+
def natsorted_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Sorts a DataFrame naturally by "Chrom", "Start", and "End" columns.

    "Chrom" values use natural ordering (via ``natsort.natsorted``), and "Start"
    and "End" are cast to int before sorting. The input DataFrame is not modified.

    Parameters
    ----------
    df : pd.DataFrame
        The DataFrame to be sorted.

    Returns
    -------
    pd.DataFrame
        A new, naturally sorted DataFrame with a fresh 0..n-1 index and "Start"/"End"
        cast to int.

    Raises
    ------
    ValueError
        If the required columns "Chrom", "Start", or "End" are missing.
    """
    required_columns = {"Chrom", "Start", "End"}

    if missing_columns := required_columns - set(df.columns):
        raise ValueError(f"Missing required columns: {', '.join(missing_columns)}")

    # Work on a copy so the caller's DataFrame keeps its original column dtypes
    out = df.copy()
    out["Start"] = out["Start"].astype(int)
    out["End"] = out["End"].astype(int)

    # Sort row positions rather than index labels: label-based lookups
    # (df.at / df.loc) are ambiguous when the index contains duplicate labels.
    keys = list(zip(out["Chrom"], out["Start"], out["End"]))
    order = natsorted(range(len(keys)), key=keys.__getitem__)

    return out.iloc[order].reset_index(drop=True)
|