sai-pg 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. sai/__init__.py +2 -0
  2. sai/__main__.py +6 -3
  3. sai/configs/__init__.py +24 -0
  4. sai/configs/global_config.py +83 -0
  5. sai/configs/ploidy_config.py +94 -0
  6. sai/configs/pop_config.py +82 -0
  7. sai/configs/stat_config.py +220 -0
  8. sai/{utils/generators → generators}/chunk_generator.py +2 -8
  9. sai/{utils/generators → generators}/window_generator.py +82 -37
  10. sai/{utils/multiprocessing → multiprocessing}/mp_manager.py +2 -2
  11. sai/{utils/multiprocessing → multiprocessing}/mp_pool.py +2 -2
  12. sai/parsers/outlier_parser.py +4 -3
  13. sai/parsers/score_parser.py +8 -119
  14. sai/{utils/preprocessors → preprocessors}/chunk_preprocessor.py +21 -15
  15. sai/preprocessors/feature_preprocessor.py +236 -0
  16. sai/registries/__init__.py +22 -0
  17. sai/registries/generic_registry.py +89 -0
  18. sai/registries/stat_registry.py +30 -0
  19. sai/sai.py +124 -220
  20. sai/stats/__init__.py +11 -0
  21. sai/stats/danc_statistic.py +83 -0
  22. sai/stats/dd_statistic.py +77 -0
  23. sai/stats/df_statistic.py +84 -0
  24. sai/stats/dplus_statistic.py +86 -0
  25. sai/stats/fd_statistic.py +92 -0
  26. sai/stats/generic_statistic.py +93 -0
  27. sai/stats/q_statistic.py +104 -0
  28. sai/stats/stat_utils.py +259 -0
  29. sai/stats/u_statistic.py +99 -0
  30. sai/utils/utils.py +220 -143
  31. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/METADATA +3 -14
  32. sai_pg-1.1.0.dist-info/RECORD +70 -0
  33. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/WHEEL +1 -1
  34. sai_pg-1.1.0.dist-info/top_level.txt +2 -0
  35. tests/configs/test_global_config.py +163 -0
  36. tests/configs/test_ploidy_config.py +93 -0
  37. tests/configs/test_pop_config.py +90 -0
  38. tests/configs/test_stat_config.py +171 -0
  39. tests/generators/test_chunk_generator.py +51 -0
  40. tests/generators/test_window_generator.py +164 -0
  41. tests/multiprocessing/test_mp_manager.py +92 -0
  42. tests/multiprocessing/test_mp_pool.py +79 -0
  43. tests/parsers/test_argument_validation.py +133 -0
  44. tests/parsers/test_outlier_parser.py +53 -0
  45. tests/parsers/test_score_parser.py +63 -0
  46. tests/preprocessors/test_chunk_preprocessor.py +79 -0
  47. tests/preprocessors/test_feature_preprocessor.py +223 -0
  48. tests/registries/test_registries.py +74 -0
  49. tests/stats/test_danc_statistic.py +51 -0
  50. tests/stats/test_dd_statistic.py +45 -0
  51. tests/stats/test_df_statistic.py +73 -0
  52. tests/stats/test_dplus_statistic.py +79 -0
  53. tests/stats/test_fd_statistic.py +68 -0
  54. tests/stats/test_q_statistic.py +268 -0
  55. tests/stats/test_stat_utils.py +354 -0
  56. tests/stats/test_u_statistic.py +233 -0
  57. tests/test___main__.py +51 -0
  58. tests/test_sai.py +102 -0
  59. tests/utils/test_utils.py +511 -0
  60. sai/parsers/plot_parser.py +0 -152
  61. sai/stats/features.py +0 -302
  62. sai/utils/preprocessors/feature_preprocessor.py +0 -211
  63. sai_pg-1.0.0.dist-info/RECORD +0 -30
  64. sai_pg-1.0.0.dist-info/top_level.txt +0 -1
  65. /sai/{utils/generators → generators}/__init__.py +0 -0
  66. /sai/{utils/generators → generators}/data_generator.py +0 -0
  67. /sai/{utils/multiprocessing → multiprocessing}/__init__.py +0 -0
  68. /sai/{utils/preprocessors → preprocessors}/__init__.py +0 -0
  69. /sai/{utils/preprocessors → preprocessors}/data_preprocessor.py +0 -0
  70. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/entry_points.txt +0 -0
  71. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,99 @@
1
+ # Copyright 2025 Xin Huang
2
+ #
3
+ # GNU General Public License v3.0
4
+ #
5
+ # This program is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # This program is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, please see
17
+ #
18
+ # https://www.gnu.org/licenses/gpl-3.0.en.html
19
+
20
+
21
+ import numpy as np
22
+ from typing import Dict, Any
23
+ from sai.registries.stat_registry import STAT_REGISTRY
24
+ from sai.stats import GenericStatistic
25
+ from sai.stats.stat_utils import compute_matching_loci
26
+
27
+
28
@STAT_REGISTRY.register("U")
class UStatistic(GenericStatistic):
    """
    Class for computing the number of uniquely shared sites between the target and source populations
    (Racimo et al. 2017. Mol Biol Evol), conditional on allele frequency patterns in the reference
    and source populations.
    """

    STAT_NAME = "U"

    def compute(self, **kwargs) -> Dict[str, Any]:
        """
        Computes the count of genetic loci that meet specified allele frequency conditions
        across reference, target, and multiple source genotypes, with adjustments based on src_freq consistency.

        The genotypes and ploidies are read from the instance (`ref_gts`, `tgt_gts`, `src_gts_list`,
        `ref_ploidy`, `tgt_ploidy`, `src_ploidy_list`); the thresholds are passed as keyword arguments.

        Parameters
        ----------
        pos : np.ndarray
            A 1D numpy array where each element represents the genomic position.
        w : float
            Threshold for the allele frequency in `ref_gts`. Only loci with frequencies less than `w` are counted.
            Must be within the range [0, 1].
        x : float
            Threshold for the allele frequency in `tgt_gts`. Only loci with frequencies greater than `x` are counted.
            Must be within the range [0, 1].
        y_list : list[float]
            List of exact allele frequency thresholds for each source population in `src_gts_list`.
            Must be within the range [0, 1] and have the same length as `src_gts_list`.
        anc_allele_available : bool
            If True, checks only for matches with `y` (assuming `1` represents the derived allele).
            If False, checks both matches with `y` and `1 - y`, taking the major allele in the source as the reference.

        Returns
        -------
        dict
            A dictionary containing:
            - 'name' : str
                The name of the statistic ("U").
            - 'value' : int
                The count of loci that meet all specified frequency conditions.
            - 'cdd_pos' : np.ndarray
                A 1D numpy array containing the genomic positions of the loci that meet the conditions.

        Raises
        ------
        ValueError
            If any of `pos`, `w`, `x`, `y_list`, or `anc_allele_available` is not supplied.
        """
        required_keys = ["pos", "w", "x", "y_list", "anc_allele_available"]
        if missing := [k for k in required_keys if k not in kwargs]:
            raise ValueError(f"Missing required argument(s): {', '.join(missing)}")

        pos = kwargs["pos"]
        w = kwargs["w"]
        x = kwargs["x"]
        y_list = kwargs["y_list"]
        anc_allele_available = kwargs["anc_allele_available"]

        # Ploidy order follows the genotype arguments below: reference, target, then each source.
        ploidy = [self.ref_ploidy, self.tgt_ploidy] + self.src_ploidy_list

        # The helper also returns the reference frequencies; they are not needed
        # here because only the target frequency is thresholded in this method.
        _, tgt_freq, condition = compute_matching_loci(
            self.ref_gts,
            self.tgt_gts,
            self.src_gts_list,
            w,
            y_list,
            ploidy,
            anc_allele_available,
        )

        # Apply final conditions: keep only loci whose target frequency exceeds x.
        condition &= tgt_freq > x

        loci_indices = np.where(condition)[0]
        loci_positions = pos[loci_indices]
        count = loci_indices.size

        # Return count of matching loci together with their positions
        return {"name": self.STAT_NAME, "value": count, "cdd_pos": loci_positions}
sai/utils/utils.py CHANGED
@@ -19,11 +19,13 @@
19
19
 
20
20
 
21
21
  import allel
22
+ import warnings
22
23
  import numpy as np
23
24
  import pandas as pd
24
- from typing import Optional, Union
25
25
  from natsort import natsorted
26
+ from typing import Optional, Union
26
27
  from sai.utils.genomic_dataclasses import ChromosomeData
28
+ from sai.configs import PloidyConfig
27
29
 
28
30
 
29
31
  def parse_ind_file(filename: str) -> dict[str, list[str]]:
@@ -37,7 +39,7 @@ def parse_ind_file(filename: str) -> dict[str, list[str]]:
37
39
 
38
40
  Returns
39
41
  -------
40
- samples : dict of str to list of str
42
+ samples : dict[str, list[str]]
41
43
  A dictionary where the keys represent categories, and the values are lists of samples
42
44
  associated with those categories.
43
45
 
@@ -77,11 +79,12 @@ def read_geno_data(
77
79
  vcf: str,
78
80
  ind_samples: dict[str, list[str]],
79
81
  chr_name: str,
82
+ ploidy: int = 2,
80
83
  start: int = None,
81
84
  end: int = None,
82
85
  anc_allele_file: Optional[str] = None,
83
86
  filter_missing: bool = True,
84
- ) -> tuple[ChromosomeData, list[str], int]:
87
+ ) -> dict[str, ChromosomeData]:
85
88
  """
86
89
  Read genotype data from a VCF file efficiently for a specified chromosome.
87
90
 
@@ -93,9 +96,11 @@ def read_geno_data(
93
96
  A dictionary where keys are categories (e.g., different sample groups), and values are lists of sample names.
94
97
  chr_name : str
95
98
  The name of the chromosome to read.
96
- start: int, optional
99
+ ploidy : int, optional
100
+ Ploidy level of the genome.
101
+ start : int, optional
97
102
  The starting position (1-based, inclusive) on the chromosome. Default: None.
98
- end: int, optional
103
+ end : int, optional
99
104
  The ending position (1-based, inclusive) on the chromosome. Default: None.
100
105
  anc_allele_file : str, optional
101
106
  The name of the BED file containing ancestral allele information, or None if not provided.
@@ -104,12 +109,7 @@ def read_geno_data(
104
109
 
105
110
  Returns
106
111
  -------
107
- chrom_data: ChromosomeData
108
- A ChromosomeData instance for the specified chromosome in the VCF.
109
- samples: list
110
- A list of samples in the data.
111
- ploidy: int
112
- Ploidy level of the organism.
112
+ A dictionary mapping each population name to its ChromosomeData.
113
113
  """
114
114
  try:
115
115
  # Load all samples from the VCF file
@@ -132,21 +132,21 @@ def read_geno_data(
132
132
  ],
133
133
  alt_number=1,
134
134
  samples=all_samples,
135
+ numbers={"GT": ploidy},
135
136
  region=region, # Specify the chromosome region
136
137
  tabix=None,
137
138
  )
138
139
  except Exception as e:
139
140
  raise ValueError(f"Failed to read VCF file {vcf} from {region}: {e}") from e
140
141
 
141
- # Convert genotype data to a more efficient GenotypeArray
142
142
  if vcf_data is None:
143
- return None, all_samples, None
143
+ return None
144
144
 
145
145
  gt = allel.GenotypeArray(vcf_data.get("calldata/GT"))
146
146
  pos = vcf_data.get("variants/POS")
147
147
  ref = vcf_data.get("variants/REF")
148
148
  alt = vcf_data.get("variants/ALT")
149
- ploidy = gt.shape[2]
149
+ sample_names = list(vcf_data.get("samples"))
150
150
 
151
151
  if gt is None or pos is None or ref is None or alt is None:
152
152
  raise ValueError("Invalid VCF file: Missing essential genotype data fields.")
@@ -162,27 +162,33 @@ def read_geno_data(
162
162
  else:
163
163
  anc_alleles = None
164
164
 
165
- sample_indices = [all_samples.index(s) for s in all_samples]
165
+ chrom_data_dict = {}
166
166
 
167
- chrom_data = ChromosomeData(
168
- POS=pos, REF=ref, ALT=alt, GT=gt.take(sample_indices, axis=1)
169
- )
167
+ for pop, pop_samples in ind_samples.items():
168
+ indices = [sample_names.index(s) for s in pop_samples]
169
+ pop_gt = gt.take(indices, axis=1)
170
170
 
171
- # Remove missing data if specified
172
- if filter_missing:
173
- non_missing_index = chrom_data.GT.count_missing(axis=1) == 0
174
- num_missing = len(non_missing_index) - np.sum(non_missing_index)
175
- if num_missing != 0:
176
- print(
177
- f"Found {num_missing} variants with missing genotypes, removing them ..."
178
- )
179
- chrom_data = filter_geno_data(chrom_data, non_missing_index)
171
+ chrom_data = ChromosomeData(
172
+ POS=pos.copy(), REF=ref.copy(), ALT=alt.copy(), GT=pop_gt
173
+ )
180
174
 
181
- # Check and incorporate ancestral alleles if the file is provided
182
- if anc_alleles:
183
- chrom_data = check_anc_allele(chrom_data, anc_alleles, chr_name)
175
+ missing_mask = chrom_data.GT.count_missing(axis=1) != 0
184
176
 
185
- return chrom_data, vcf_data.get("samples"), ploidy
177
+ if filter_missing:
178
+ if np.any(missing_mask):
179
+ chrom_data = filter_geno_data(chrom_data, ~missing_mask)
180
+ else:
181
+ if np.any(missing_mask):
182
+ raise ValueError(
183
+ "Missing data is found. Please remove variants with missing data or enable filtering."
184
+ )
185
+
186
+ if anc_alleles:
187
+ chrom_data = check_anc_allele(chrom_data, anc_alleles, chr_name)
188
+
189
+ chrom_data_dict[pop] = chrom_data
190
+
191
+ return chrom_data_dict
186
192
 
187
193
 
188
194
  def filter_geno_data(
@@ -214,9 +220,11 @@ def filter_geno_data(
214
220
  def read_data(
215
221
  vcf_file: str,
216
222
  chr_name: str,
223
+ ploidy_config: PloidyConfig,
217
224
  ref_ind_file: Optional[str],
218
225
  tgt_ind_file: Optional[str],
219
226
  src_ind_file: Optional[str],
227
+ out_ind_file: Optional[str],
220
228
  anc_allele_file: Optional[str],
221
229
  start: int = None,
222
230
  end: int = None,
@@ -224,15 +232,18 @@ def read_data(
224
232
  filter_ref: bool = True,
225
233
  filter_tgt: bool = True,
226
234
  filter_src: bool = False,
235
+ filter_out: bool = False,
227
236
  filter_missing: bool = True,
228
- ) -> tuple[
229
- Optional[dict[str, dict[str, ChromosomeData]]],
230
- Optional[dict[str, list[str]]],
231
- Optional[dict[str, dict[str, ChromosomeData]]],
232
- Optional[dict[str, list[str]]],
233
- Optional[dict[str, dict[str, ChromosomeData]]],
234
- Optional[dict[str, list[str]]],
235
- Optional[int],
237
+ ) -> dict[
238
+ str,
239
+ tuple[
240
+ Optional[dict[str, dict[str, ChromosomeData]]],
241
+ Optional[dict[str, list[str]]],
242
+ Optional[dict[str, dict[str, ChromosomeData]]],
243
+ Optional[dict[str, list[str]]],
244
+ Optional[dict[str, dict[str, ChromosomeData]]],
245
+ Optional[dict[str, list[str]]],
246
+ ],
236
247
  ]:
237
248
  """
238
249
  Helper function for reading data from reference, target, and source populations.
@@ -241,14 +252,18 @@ def read_data(
241
252
  ----------
242
253
  vcf_file : str
243
254
  Name of the VCF file containing genotype data.
244
- chr_name: str
255
+ chr_name : str
245
256
  Name of the chromosome to read.
257
+ ploidy_config : PloidyConfig
258
+ Configuration specifying ploidy levels for each population involved in the analysis.
246
259
  ref_ind_file : str or None
247
260
  File with reference population sample information. None if not provided.
248
261
  tgt_ind_file : str or None
249
262
  File with target population sample information. None if not provided.
250
263
  src_ind_file : str or None
251
264
  File with source population sample information. None if not provided.
265
+ out_ind_file : str or None
266
+ File with outgroup population sample information. None if not provided.
252
267
  anc_allele_file : str or None
253
268
  File with ancestral allele information. None if not provided.
254
269
  start: int, optional
@@ -263,11 +278,21 @@ def read_data(
263
278
  Whether to filter fixed variants for target data. Default: True.
264
279
  filter_src : bool, optional
265
280
  Whether to filter fixed variants for source data. Default: False.
281
+ filter_out : bool, optional
282
+ Whether to filter fixed variants for outgroup data. Default: False.
266
283
  filter_missing : bool, optional
267
284
  Whether to filter out missing data. Default: True.
268
285
 
269
286
  Returns
270
287
  -------
288
+ result : dict
289
+ {
290
+ "ref": (ref_data, ref_samples),
291
+ "tgt": (tgt_data, tgt_samples),
292
+ "src": (src_data, src_samples),
293
+ "outgroup": (out_data, out_samples) # optional
294
+ }
295
+
271
296
  ref_data : dict or None
272
297
  Genotype data from reference populations, organized by category and chromosome.
273
298
  ref_samples : dict or None
@@ -280,12 +305,14 @@ def read_data(
280
305
  Genotype data from source populations, organized by category and chromosome.
281
306
  src_samples : dict or None
282
307
  Sample information from source populations.
283
- ploidy: int or None
284
- Ploidy level of the organism.
308
+ out_data : dict or None
309
+ Genotype data from outgroup populations, organized by category and chromosome.
310
+ out_samples : dict or None
311
+ Sample information from outgroup populations.
285
312
 
286
313
  Notes
287
314
  -----
288
- The `ref_data`, `tgt_data`, and `src_data` are organized as nested dictionaries where:
315
+ The `ref_data`, `tgt_data`, `src_data`, `out_data` are organized as nested dictionaries where:
289
316
 
290
317
  - The outermost keys represent different populations or sample categories.
291
318
  - The second-level keys represent different chromosomes.
@@ -298,111 +325,40 @@ def read_data(
298
325
  This organization allows easy access and manipulation of genotype data by category and chromosome,
299
326
  enabling flexible processing across different populations and chromosomal regions.
300
327
  """
301
- ref_data = ref_samples = tgt_data = tgt_samples = src_data = src_samples = None
302
-
303
- # Parse sample information
304
- if ref_ind_file:
305
- ref_samples = parse_ind_file(ref_ind_file)
306
-
307
- if tgt_ind_file:
308
- tgt_samples = parse_ind_file(tgt_ind_file)
309
-
310
- if src_ind_file:
311
- src_samples = parse_ind_file(src_ind_file)
312
-
313
- # Combine all samples for a single VCF read
314
- all_samples = {}
315
- if ref_samples:
316
- all_samples.update(ref_samples)
317
- if tgt_samples:
318
- all_samples.update(tgt_samples)
319
- if src_samples:
320
- all_samples.update(src_samples)
321
-
322
- try:
323
- # Read VCF data
324
- geno_data, all_samples, ploidy = read_geno_data(
325
- vcf=vcf_file,
326
- ind_samples=all_samples,
328
+ group_params = [
329
+ ("ref", ref_ind_file, filter_ref),
330
+ ("tgt", tgt_ind_file, filter_tgt),
331
+ ("src", src_ind_file, filter_src),
332
+ ("outgroup", out_ind_file, filter_out),
333
+ ]
334
+
335
+ results = {}
336
+
337
+ for group, ind_file, filter_flag in group_params:
338
+ if ind_file is None:
339
+ results[group] = (None, None)
340
+ continue
341
+
342
+ if group == "outgroup" and group not in ploidy_config.root:
343
+ results[group] = (None, None)
344
+ continue
345
+
346
+ data, samples = _load_population_data(
347
+ vcf_file=vcf_file,
327
348
  chr_name=chr_name,
349
+ sample_file=ind_file,
350
+ anc_allele_file=anc_allele_file,
328
351
  start=start,
329
352
  end=end,
330
- anc_allele_file=anc_allele_file,
353
+ is_phased=is_phased,
354
+ filter_flag=filter_flag,
331
355
  filter_missing=filter_missing,
356
+ ploidy_config=ploidy_config,
357
+ group=group,
332
358
  )
333
- except Exception as e:
334
- raise ValueError(f"Failed to read VCF data: {e}")
335
-
336
- if geno_data is None:
337
- return None, ref_samples, None, tgt_samples, None, src_samples, None
338
-
339
- # Separate reference, target, and source data
340
- ref_data = extract_group_data(geno_data, all_samples, ref_samples)
341
- tgt_data = extract_group_data(geno_data, all_samples, tgt_samples)
342
- src_data = extract_group_data(geno_data, all_samples, src_samples)
343
-
344
- # Apply fixed variant filtering conditionally
345
- if filter_ref and ref_data and ref_samples:
346
- ref_data = filter_fixed_variants(ref_data, ref_samples)
347
- if filter_tgt and tgt_data and tgt_samples:
348
- tgt_data = filter_fixed_variants(tgt_data, tgt_samples)
349
- if filter_src and src_data and src_samples:
350
- src_data = filter_fixed_variants(src_data, src_samples)
351
-
352
- # Adjust genotypes based on phased/unphased requirement
353
- reshape_genotypes(ref_data, is_phased)
354
- reshape_genotypes(tgt_data, is_phased)
355
- reshape_genotypes(src_data, is_phased)
356
-
357
- return ref_data, ref_samples, tgt_data, tgt_samples, src_data, src_samples, ploidy
358
-
359
-
360
- def extract_group_data(
361
- geno_data: dict[str, ChromosomeData],
362
- all_samples: list[str],
363
- sample_groups: Optional[dict[str, list[str]]] = None,
364
- ) -> Optional[dict[str, ChromosomeData]]:
365
- """
366
- Extract genotype data from geno_data based on the sample groups.
367
-
368
- Parameters
369
- ----------
370
- geno_data : dict of str to ChromosomeData
371
- Contains genotype data, where each value is a ChromosomeData instance.
372
- all_samples: list
373
- A list of all sample names in the dataset.
374
- sample_groups : dict of str to list of str, optional
375
- Contains sample group information, where each key is a group name and the value is a list of samples.
376
- If None, the function returns None.
377
-
378
- Returns
379
- -------
380
- extracted_data : dict or None
381
- Genotype data organized by sample group, or None if no sample groups are provided.
382
- The structure is as follows:
383
-
384
- - Keys represent sample group names.
385
- - Values are ChromosomeData instances, filtered to include only the samples in the specified group.
386
- """
387
- if sample_groups is None:
388
- return None
359
+ results[group] = (data, samples)
389
360
 
390
- sample_indices = {sample: idx for idx, sample in enumerate(all_samples)}
391
-
392
- extracted_data = {}
393
-
394
- for group, samples in sample_groups.items():
395
- indices = [sample_indices[s] for s in samples if s in sample_indices]
396
-
397
- # Extract ChromosomeData for the selected samples in each group
398
- extracted_data[group] = ChromosomeData(
399
- GT=geno_data.GT[:, indices, :],
400
- POS=geno_data.POS,
401
- REF=geno_data.REF,
402
- ALT=geno_data.ALT,
403
- )
404
-
405
- return extracted_data
361
+ return results
406
362
 
407
363
 
408
364
  def filter_fixed_variants(
@@ -608,6 +564,7 @@ def split_genome(
608
564
  pos: np.ndarray,
609
565
  window_size: int,
610
566
  step_size: int,
567
+ start: int = None,
611
568
  ) -> list[tuple]:
612
569
  """
613
570
  Creates sliding windows along the genome based on variant positions.
@@ -620,6 +577,9 @@ def split_genome(
620
577
  Length of each sliding window.
621
578
  step_size : int
622
579
  Step size of the sliding windows.
580
+ start: int, optional
581
+ Minimum starting coordinate for the first window. The first window will start
582
+ no smaller than this value. Default is None.
623
583
 
624
584
  Returns
625
585
  -------
@@ -644,7 +604,9 @@ def split_genome(
644
604
 
645
605
  window_positions = []
646
606
  win_start = (pos[0] + step_size) // step_size * step_size - window_size + 1
647
- win_start = max(win_start, 1)
607
+ if start is None:
608
+ start = 1
609
+ win_start = max(win_start, start)
648
610
 
649
611
  # Create windows based on step size and window size
650
612
  while win_start <= pos[-1]:
@@ -687,3 +649,118 @@ def natsorted_df(df: pd.DataFrame) -> pd.DataFrame:
687
649
  )
688
650
 
689
651
  return df.loc[sorted_indices].reset_index(drop=True)
652
+
653
+
654
def _load_population_data(
    vcf_file: str,
    chr_name: str,
    sample_file: Optional[str],
    anc_allele_file: Optional[str],
    start: Optional[int],
    end: Optional[int],
    is_phased: bool,
    filter_flag: bool,
    filter_missing: bool,
    ploidy_config: PloidyConfig,
    group: str,  # e.g., "ref", "tgt", "src"
) -> tuple[
    Optional[dict[str, dict[str, ChromosomeData]]], Optional[dict[str, list[str]]]
]:
    """
    Loads genotype data and sample information for a population group (e.g., reference) from a VCF file,
    handling multiple populations with potentially different ploidy.

    Each population listed in the sample file is read separately with the ploidy
    configured for it under `ploidy_config.root[group]`.

    Parameters
    ----------
    vcf_file : str
        Path to the VCF file containing variant data.
    chr_name : str
        Chromosome name to extract from the VCF.
    sample_file : str or None
        Path to the file containing sample IDs grouped by population.
        If None, no data is loaded.
    anc_allele_file : str or None
        Path to the BED file with ancestral allele annotations.
    start : int or None
        Start position on the chromosome (1-based, inclusive). If None, starts at the beginning.
    end : int or None
        End position on the chromosome (1-based, inclusive). If None, reads to the end.
    is_phased : bool
        Whether the genotypes are phased.
    filter_flag : bool
        Whether to remove variants fixed in all samples of each population.
    filter_missing : bool
        Whether to filter out variants with missing genotypes.
    ploidy_config : PloidyConfig
        Configuration containing ploidy for all populations in all groups.
    group : str
        The group label (e.g., "ref", "tgt", "src") used to extract populations from ploidy_config.

    Returns
    -------
    data : dict or None
        Dictionary keyed by population name. Each value is the entry that `read_geno_data`
        produced for that population, after optional fixed-variant filtering and genotype reshaping.
        None if `sample_file` is None or no population yielded any data.
        NOTE(review): the return annotation describes a nested population -> chromosome mapping,
        while the code stores a single entry per population -- confirm which is intended.
    samples : dict[str, list[str]] or None
        Dictionary mapping population -> list of sample IDs, as parsed from `sample_file`.
        None if `sample_file` is None.

    Raises
    ------
    ValueError
        If `group` is absent from `ploidy_config.root`, if a population configured for the group
        is missing from the sample file, or if reading the VCF fails for a population.

    Warns
    -----
    RuntimeWarning
        If a population appears in the sample file but has no ploidy configured for `group`;
        that population is skipped.
    """
    if sample_file is None:
        return None, None

    samples = parse_ind_file(sample_file)

    if group not in ploidy_config.root:
        raise ValueError(f"Ploidy configuration missing group '{group}'.")

    group_ploidies = ploidy_config.root[group]

    # Ensure all populations in ploidy_config[group] are in sample_file
    for population in group_ploidies:
        if population not in samples:
            raise ValueError(
                f"Population '{population}' in ploidy_config[{group}] not found in sample file: {sample_file}"
            )

    data: dict[str, ChromosomeData] = {}

    for population, sample_list in samples.items():
        if population not in group_ploidies:
            warnings.warn(
                f"Population '{population}' found in sample file but not in ploidy_config[{group}]; skipping.",
                RuntimeWarning,
            )
            continue

        ploidy = group_ploidies[population]

        try:
            geno_data = read_geno_data(
                vcf=vcf_file,
                ind_samples={population: sample_list},
                chr_name=chr_name,
                start=start,
                end=end,
                anc_allele_file=anc_allele_file,
                filter_missing=filter_missing,
                ploidy=ploidy,
            )
        except Exception as e:
            # Chain explicitly so the underlying failure stays attached as the cause.
            raise ValueError(
                f"Failed to read VCF data for {sample_file}, population '{population}': {e}"
            ) from e

        # read_geno_data yields None when nothing could be read for the region.
        if geno_data is None:
            continue

        if filter_flag:
            geno_data = filter_fixed_variants(geno_data, {population: sample_list})

        reshape_genotypes(geno_data, is_phased)

        # geno_data is keyed by population; keep only this population's entry.
        data[population] = geno_data[population]

    if not data:
        return None, samples

    return data, samples
@@ -1,11 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sai-pg
3
- Version: 1.0.0
3
+ Version: 1.1.0
4
4
  Summary: A Python Package for Statistics for Adaptive Introgression
5
- Home-page: https://github.com/xin-huang/sai
6
- Author: Xin Huang
7
- Author-email: xinhuang.res@gmail.com
5
+ Author-email: Xin Huang <xinhuang.res@gmail.com>
8
6
  License: GPLv3
7
+ Project-URL: Homepage, https://github.com/xin-huang/sai
9
8
  Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
10
9
  Classifier: Programming Language :: Python
11
10
  Classifier: Programming Language :: Python :: 3.9
@@ -19,17 +18,7 @@ Requires-Dist: pandas==2.2.1
19
18
  Requires-Dist: pysam==0.23.0
20
19
  Requires-Dist: scikit-allel==1.3.7
21
20
  Requires-Dist: scipy==1.12.0
22
- Dynamic: author
23
- Dynamic: author-email
24
- Dynamic: classifier
25
- Dynamic: description
26
- Dynamic: description-content-type
27
- Dynamic: home-page
28
- Dynamic: license
29
21
  Dynamic: license-file
30
- Dynamic: requires-dist
31
- Dynamic: requires-python
32
- Dynamic: summary
33
22
 
34
23
  # SAI
35
24