sai-pg 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. sai/__init__.py +2 -0
  2. sai/__main__.py +6 -3
  3. sai/configs/__init__.py +24 -0
  4. sai/configs/global_config.py +83 -0
  5. sai/configs/ploidy_config.py +94 -0
  6. sai/configs/pop_config.py +82 -0
  7. sai/configs/stat_config.py +220 -0
  8. sai/{utils/generators → generators}/chunk_generator.py +2 -8
  9. sai/{utils/generators → generators}/window_generator.py +82 -37
  10. sai/{utils/multiprocessing → multiprocessing}/mp_manager.py +2 -2
  11. sai/{utils/multiprocessing → multiprocessing}/mp_pool.py +2 -2
  12. sai/parsers/outlier_parser.py +4 -3
  13. sai/parsers/score_parser.py +8 -119
  14. sai/{utils/preprocessors → preprocessors}/chunk_preprocessor.py +21 -15
  15. sai/preprocessors/feature_preprocessor.py +236 -0
  16. sai/registries/__init__.py +22 -0
  17. sai/registries/generic_registry.py +89 -0
  18. sai/registries/stat_registry.py +30 -0
  19. sai/sai.py +124 -220
  20. sai/stats/__init__.py +11 -0
  21. sai/stats/danc_statistic.py +83 -0
  22. sai/stats/dd_statistic.py +77 -0
  23. sai/stats/df_statistic.py +84 -0
  24. sai/stats/dplus_statistic.py +86 -0
  25. sai/stats/fd_statistic.py +92 -0
  26. sai/stats/generic_statistic.py +93 -0
  27. sai/stats/q_statistic.py +104 -0
  28. sai/stats/stat_utils.py +259 -0
  29. sai/stats/u_statistic.py +99 -0
  30. sai/utils/utils.py +220 -143
  31. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/METADATA +3 -14
  32. sai_pg-1.1.0.dist-info/RECORD +70 -0
  33. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/WHEEL +1 -1
  34. sai_pg-1.1.0.dist-info/top_level.txt +2 -0
  35. tests/configs/test_global_config.py +163 -0
  36. tests/configs/test_ploidy_config.py +93 -0
  37. tests/configs/test_pop_config.py +90 -0
  38. tests/configs/test_stat_config.py +171 -0
  39. tests/generators/test_chunk_generator.py +51 -0
  40. tests/generators/test_window_generator.py +164 -0
  41. tests/multiprocessing/test_mp_manager.py +92 -0
  42. tests/multiprocessing/test_mp_pool.py +79 -0
  43. tests/parsers/test_argument_validation.py +133 -0
  44. tests/parsers/test_outlier_parser.py +53 -0
  45. tests/parsers/test_score_parser.py +63 -0
  46. tests/preprocessors/test_chunk_preprocessor.py +79 -0
  47. tests/preprocessors/test_feature_preprocessor.py +223 -0
  48. tests/registries/test_registries.py +74 -0
  49. tests/stats/test_danc_statistic.py +51 -0
  50. tests/stats/test_dd_statistic.py +45 -0
  51. tests/stats/test_df_statistic.py +73 -0
  52. tests/stats/test_dplus_statistic.py +79 -0
  53. tests/stats/test_fd_statistic.py +68 -0
  54. tests/stats/test_q_statistic.py +268 -0
  55. tests/stats/test_stat_utils.py +354 -0
  56. tests/stats/test_u_statistic.py +233 -0
  57. tests/test___main__.py +51 -0
  58. tests/test_sai.py +102 -0
  59. tests/utils/test_utils.py +511 -0
  60. sai/parsers/plot_parser.py +0 -152
  61. sai/stats/features.py +0 -302
  62. sai/utils/preprocessors/feature_preprocessor.py +0 -211
  63. sai_pg-1.0.0.dist-info/RECORD +0 -30
  64. sai_pg-1.0.0.dist-info/top_level.txt +0 -1
  65. /sai/{utils/generators → generators}/__init__.py +0 -0
  66. /sai/{utils/generators → generators}/data_generator.py +0 -0
  67. /sai/{utils/multiprocessing → multiprocessing}/__init__.py +0 -0
  68. /sai/{utils/preprocessors → preprocessors}/__init__.py +0 -0
  69. /sai/{utils/preprocessors → preprocessors}/data_preprocessor.py +0 -0
  70. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/entry_points.txt +0 -0
  71. {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/licenses/LICENSE +0 -0
sai/sai.py CHANGED
@@ -20,28 +20,23 @@

import os
import warnings
+import yaml
import pandas as pd
-import matplotlib.pyplot as plt
-from matplotlib.ticker import MaxNLocator
-from sai.utils.generators import ChunkGenerator
-from sai.utils.preprocessors import ChunkPreprocessor
+from pathlib import Path
+from sai.generators import ChunkGenerator
+from sai.preprocessors import ChunkPreprocessor
+from sai.configs import GlobalConfig
from sai.utils.utils import natsorted_df


def score(
    vcf_file: str,
    chr_name: str,
-    ref_ind_file: str,
-    tgt_ind_file: str,
-    src_ind_file: str,
    win_len: int,
    win_step: int,
-    num_src: int,
    anc_allele_file: str,
-    w: float,
-    y: list[float],
    output_file: str,
-    stat_type: str,
+    config: str,
    num_workers: int,
) -> None:
    """
@@ -53,55 +48,76 @@ def score(
        Path to the VCF file containing variant data.
    chr_name : str
        The chromosome name to be analyzed from the VCF file.
-    ref_ind_file : str
-        Path to the file containing reference population identifiers.
-    tgt_ind_file : str
-        Path to the file containing target population identifiers.
-    src_ind_file : str
-        Path to the file containing source population identifiers.
    win_len : int
        Length of each genomic window in base pairs.
    win_step : int
        Step size in base pairs between consecutive windows.
-    num_src : int
-        Number of source populations to include in each windowed analysis.
    anc_allele_file : str
        Path to the file containing ancestral allele information.
-    w : float
-        Frequency threshold for calculating feature vectors.
-    y : list[float]
-        List of frequency thresholds used for various calculations in feature vector processing.
    output_file : str
        File path to save the output of processed feature vectors.
-    stat_type: str
-        Specifies the type of statistic to compute.
+    config: str
+        Path to the YAML configuration file specifying the statistics and ploidies to compute.
    num_workers : int
        Number of parallel processes for multiprocessing.
    """
+    try:
+        with open(config, "r") as f:
+            config_dict = yaml.safe_load(f)
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Configuration file '{config}' not found.")
+    except yaml.YAMLError as e:
+        raise ValueError(f"Error parsing YAML configuration file '{config}': {e}")
+
+    required_fields = ["statistics", "ploidies", "populations"]
+    missing_fields = [field for field in required_fields if field not in config_dict]
+
+    if missing_fields:
+        raise ValueError(
+            f"Missing required fields in configuration file '{config}': {', '.join(missing_fields)}"
+        )
+
+    global_config = GlobalConfig(**config_dict)
+
+    stat_config = global_config.statistics
+    ploidy_config = global_config.ploidies
+    pop_config = global_config.populations
+
    generator = ChunkGenerator(
        vcf_file=vcf_file,
        chr_name=chr_name,
        window_size=win_len,
        step_size=win_step,
-        num_chunks=num_workers * 8,
+        # num_chunks=num_workers * 8,
+        num_chunks=1,
    )

    preprocessor = ChunkPreprocessor(
        vcf_file=vcf_file,
-        ref_ind_file=ref_ind_file,
-        tgt_ind_file=tgt_ind_file,
-        src_ind_file=src_ind_file,
+        ref_ind_file=pop_config.get_population("ref"),
+        tgt_ind_file=pop_config.get_population("tgt"),
+        src_ind_file=pop_config.get_population("src"),
+        out_ind_file=pop_config.get_population("outgroup"),
        win_len=win_len,
        win_step=win_step,
-        w=w,
-        y=y,
        output_file=output_file,
-        stat_type=stat_type,
+        ploidy_config=ploidy_config,
+        stat_config=stat_config,
        anc_allele_file=anc_allele_file,
-        num_src=num_src,
    )

-    header = f"Chrom\tStart\tEnd\tRef\tTgt\tSrc\tN(Variants)\t{stat_type}(w<{w},y=({','.join(f'{op}{val}' for op, val in y)}))\tCandidate\n"
+    src_pops = list(ploidy_config.root["src"].keys())
+
+    header_parts = ["Chrom", "Start", "End", "Ref", "Tgt", "Src", "N(Variants)"]
+
+    for stat_name in stat_config.root.keys():
+        if stat_name in ("U", "Q") or len(src_pops) <= 1:
+            header_parts.append(stat_name)
+        else:
+            for sp in src_pops:
+                header_parts.append(f"{stat_name}.{sp}")
+
+    header = "\t".join(header_parts) + "\n"

    directory = os.path.dirname(output_file)
    if directory:
@@ -109,6 +125,13 @@ def score(
    with open(output_file, "w") as f:
        f.write(header)

+    for key in ("U", "Q"):
+        if key in stat_config.root:
+            path = Path(output_file)
+            log_file = path.with_suffix(f".{key}.log")
+            with open(log_file, "w") as f:
+                f.write(f"Chrom\tStart\tEnd\t{key}_SNP\n")
+
    items = []

    for params in generator.get():
@@ -117,199 +140,80 @@
    preprocessor.process_items(items)


-def outlier(score_file: str, output: str, quantile: float) -> None:
+def outlier(score_file: str, output_prefix: str, quantile: float) -> None:
    """
-    Outputs rows exceeding the specified quantile for the chosen column ('U' or 'Q'),
-    sorted by Start and then End columns.
+    Identifies outlier windows for each statistic column in a score file and
+    writes them to separate output files.
+
+    This function reads a tab-delimited score file, determines which columns
+    contain statistics (e.g., U, Q, D+, etc.), computes the specified quantile
+    threshold for each statistic, and outputs rows exceeding that threshold.
+    Results for each statistic are written to a separate TSV file, sorted by
+    Chrom, Start, and End when available.

    Parameters
    ----------
    score_file : str
-        Path to the input file, in CSV format.
-    output : str
-        Path to the output file.
+        Path to the input score file (tab-delimited).
+    output_prefix : str
+        Prefix for the output files. Each output file is named
+        "{output_prefix}.{stat}.{quantile}.outliers.tsv".
    quantile : float
-        Quantile threshold to filter rows.
+        Quantile threshold (between 0 and 1) used to define outliers.
    """
-    # Read the input data file
-    data = pd.read_csv(
-        score_file,
-        sep="\t",
-        na_values=["nan"],
-        dtype={"Candidate": str},
-        index_col=False,
-    )
+    df = pd.read_csv(score_file, sep="\t", na_values=["nan"], index_col=False)

-    column = data.columns[-2]
-
-    # Convert column to numeric for computation
-    data[column] = pd.to_numeric(data[column], errors="coerce")
-
-    # Calculate quantile threshold for the chosen column
-    threshold = data[column].quantile(quantile)
-
-    if data[column].nunique() == 1:
-        warnings.warn(
-            f"Column '{column}' contains only one unique value ({threshold}), making quantile filtering meaningless.",
-            UserWarning,
-        )
-        outliers = pd.DataFrame(columns=data.columns)
-    elif (threshold == 1) and (column.startswith("Q")):
-        outliers = data[data[column] >= threshold]
-    else:
-        outliers = data[data[column] > threshold]
-
-    # Sort the filtered data by 'Chrom', 'Start', 'End' columns
-    if not outliers.empty:
-        outliers = outliers.reset_index(drop=True)
-        outliers_sorted = natsorted_df(outliers)
+    cols = list(df.columns)
+    if "N(Variants)" in cols:
+        start_idx = cols.index("N(Variants)") + 1
+        metric_cols = cols[start_idx:]
    else:
-        outliers_sorted = outliers
-
-    # Convert all columns to string before saving
-    outliers_sorted = outliers_sorted.astype(str)
-
-    # Save the sorted filtered data to the output file
-    outliers_sorted.to_csv(output, index=False, sep="\t")
-
-
-def plot(
-    u_file: str,
-    q_file: str,
-    output: str,
-    xlabel: str,
-    ylabel: str,
-    title: str,
-    figsize_x: float = 6,
-    figsize_y: float = 6,
-    dpi: int = 300,
-    alpha: float = 0.6,
-    marker_size: float = 20,
-    marker_color: str = "blue",
-    marker_style: str = "o",
-) -> None:
-    """
-    Reads two score/outlier files (U and Q), finds common candidate positions, and plots U vs. Q.
-
-    Parameters
-    ----------
-    u_file : str
-        Path to the input file containing U score/outlier data.
-    q_file : str
-        Path to the input file containing Q score/outlier data.
-    output : str
-        Path to save the output plot.
-    xlabel : str
-        Label for the X-axis.
-    ylabel : str
-        Label for the Y-axis.
-    title : str
-        Title of the plot.
-    figsize_x : float, optional
-        Width of the figure (default: 6).
-    figsize_y : float, optional
-        Height of the figure (default: 6).
-    dpi : int, optional
-        Resolution of the saved plot (default: 300).
-    alpha : float, optional
-        Transparency level of scatter points (default: 0.6).
-    marker_size : float, optional
-        Size of the scatter plot markers (default: 20).
-    marker_color : str, optional
-        Color of the markers (default: "blue").
-    marker_style : str, optional
-        Shape of the marker (default: "o").
-    """
-    u_data = pd.read_csv(u_file, sep="\t")
-    q_data = pd.read_csv(q_file, sep="\t")
-
-    u_column = u_data.columns[-2]
-    q_column = q_data.columns[-2]
-
-    u_data["interval"] = (
-        u_data["Chrom"].astype(str)
-        + ":"
-        + u_data["Start"].astype(str)
-        + "-"
-        + u_data["End"].astype(str)
-    )
-    q_data["interval"] = (
-        q_data["Chrom"].astype(str)
-        + ":"
-        + q_data["Start"].astype(str)
-        + "-"
-        + q_data["End"].astype(str)
-    )
-
-    u_data[u_column] = pd.to_numeric(u_data[u_column], errors="coerce")
-    q_data[q_column] = pd.to_numeric(q_data[q_column], errors="coerce")
-    u_data = u_data.dropna(subset=[u_column])
-    q_data = q_data.dropna(subset=[q_column])
-
-    u_interval_dict = {row["interval"]: row[u_column] for _, row in u_data.iterrows()}
-    q_interval_dict = {row["interval"]: row[q_column] for _, row in q_data.iterrows()}
-    u_candidate_dict = {
-        row["interval"]: set(str(row["Candidate"]).split(","))
-        for _, row in u_data.iterrows()
-    }
-    q_candidate_dict = {
-        row["interval"]: set(str(row["Candidate"]).split(","))
-        for _, row in q_data.iterrows()
-    }
-
-    common_intervals = set(u_interval_dict.keys()) & set(q_interval_dict.keys())
-    if not common_intervals:
-        raise ValueError(
-            "No common genomic intervals found between U and Q score/outlier files."
+        # fallback: exclude common non-metric columns, keep numeric ones
+        non_metrics = {"Chrom", "Start", "End", "Ref", "Tgt", "Src"}
+        candidate = [c for c in cols if c not in non_metrics]
+        metric_cols = [
+            c for c in candidate if pd.to_numeric(df[c], errors="coerce").notna().any()
+        ]
+
+    if not metric_cols:
+        raise ValueError("No metric columns found.")
+
+    for col in metric_cols:
+        s_num = pd.to_numeric(df[col], errors="coerce").dropna()
+
+        if s_num.empty:
+            warnings.warn(
+                f"Column '{col}' has no numeric values; writing empty result.",
+                UserWarning,
+            )
+            out_sorted = pd.DataFrame(columns=df.columns)
+        elif s_num.nunique() == 1:
+            thr = s_num.iloc[0]
+            warnings.warn(
+                f"Column '{col}' has only one unique value ({thr}); writing empty result.",
+                UserWarning,
+            )
+            out_sorted = pd.DataFrame(columns=df.columns)
+        else:
+            thr = s_num.quantile(quantile)
+            col_num = pd.to_numeric(df[col], errors="coerce")
+            if not col.startswith("U"):
+                out = df[col_num >= thr]
+            else:
+                out = df[col_num > thr]
+
+            if not out.empty:
+                out = out.reset_index(drop=True)
+                try:
+                    out_sorted = natsorted_df(out)  # natural sort by Chrom/Start/End
+                except NameError:
+                    keys = [k for k in ("Chrom", "Start", "End") if k in out.columns]
+                    out_sorted = (
+                        out.sort_values(by=keys, kind="mergesort") if keys else out
+                    )
+            else:
+                out_sorted = out
+
+        out_sorted.astype(str).to_csv(
+            f"{output_prefix}.{col}.{quantile}.outliers.tsv", index=False, sep="\t"
        )
-
-    # Helper: get candidate overlap or "."
-    def get_candidate_overlap(interval):
-        u_set = u_candidate_dict.get(interval, set())
-        q_set = q_candidate_dict.get(interval, set())
-        overlap = sorted(u_set & q_set)
-        return ",".join(overlap) if overlap else "NA"
-
-    overlap_df = pd.DataFrame(
-        {
-            "Chrom": [interval.split(":")[0] for interval in common_intervals],
-            "Start": [
-                int(interval.split(":")[1].split("-")[0])
-                for interval in common_intervals
-            ],
-            "End": [
-                int(interval.split(":")[1].split("-")[1])
-                for interval in common_intervals
-            ],
-            u_column: [u_interval_dict[c] for c in common_intervals],
-            q_column: [q_interval_dict[c] for c in common_intervals],
-            "Overlapping Candidate": [
-                get_candidate_overlap(c) for c in common_intervals
-            ],
-        }
-    )
-
-    overlap_df_sorted = natsorted_df(overlap_df)
-    overlap_output = os.path.splitext(output)[0] + ".overlap.tsv"
-    pd.DataFrame(overlap_df_sorted).to_csv(overlap_output, sep="\t", index=False)
-
-    plt.figure(figsize=(figsize_x, figsize_y))
-    plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))
-    plt.scatter(
-        x=overlap_df[q_column],
-        y=overlap_df[u_column],
-        alpha=alpha,
-        s=marker_size,
-        c=marker_color,
-        marker=marker_style,
-    )
-    xmin, xmax = plt.gca().get_xlim()
-    ymin, ymax = plt.gca().get_ylim()
-    plt.xlim(left=max(0, xmin))
-    plt.ylim(bottom=max(0, ymin))
-    plt.xlabel(xlabel)
-    plt.ylabel(ylabel)
-    plt.title(title)
-    plt.grid(alpha=0.5, linestyle="--")
-    plt.savefig(output, bbox_inches="tight", dpi=dpi)
-    plt.close()
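
Note: the rewritten score() replaces the old per-population CLI flags with a single YAML file. Below is a minimal sketch of a configuration that would pass the required-field check above. Only the three top-level keys (statistics, ploidies, populations) are confirmed by this diff; the nested layout is an assumption inferred from the pop_config.get_population(...) calls and the ploidy_config.root["src"] / stat_config.root accesses, not a documented schema.

    import yaml

    # Hypothetical config for the new score() signature; every nested key is
    # an inference from the diff above, not a documented schema.
    config_text = """
    populations:
      ref: data/ref.ind.list
      tgt: data/tgt.ind.list
      src: data/src.ind.list
      outgroup: data/outgroup.ind.list
    ploidies:
      ref: {REF: 2}
      tgt: {TGT: 2}
      src: {SRC1: 2, SRC2: 2}
      outgroup: {OUT: 2}
    statistics:
      Danc: {}  # statistic-specific parameters omitted; see sai/configs/stat_config.py
      DD: {}
    """

    config_dict = yaml.safe_load(config_text)

    # The same required-field check score() performs before building GlobalConfig.
    missing = [f for f in ("statistics", "ploidies", "populations") if f not in config_dict]
    assert not missing, f"missing fields: {missing}"

With two src populations as above, the header loop in score() would emit per-source columns such as "Danc.SRC1" and "Danc.SRC2", while U and Q keep a single column.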
sai/stats/__init__.py CHANGED
@@ -16,3 +16,14 @@
# along with this program. If not, please see
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
+
+
+from .generic_statistic import GenericStatistic
+from .danc_statistic import DancStatistic
+from .dd_statistic import DdStatistic
+from .df_statistic import DfStatistic
+from .dplus_statistic import DplusStatistic
+from .fd_statistic import FdStatistic
+from .q_statistic import QStatistic
+from .u_statistic import UStatistic
+from .stat_utils import *
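
Note: these eager imports matter because each statistic module registers its class with STAT_REGISTRY at import time via the @STAT_REGISTRY.register(...) decorator seen in the new files below. sai/registries/generic_registry.py itself is not shown in this diff; a registry compatible with that decorator usage could look roughly like this sketch (names and error handling are assumptions):

    from typing import Callable, Dict, Type

    class GenericRegistry:
        """Minimal name -> class registry (an assumed shape, not sai's actual code)."""

        def __init__(self) -> None:
            self._entries: Dict[str, Type] = {}

        def register(self, name: str) -> Callable[[Type], Type]:
            # Used as a class decorator: @STAT_REGISTRY.register("Danc")
            def decorator(cls: Type) -> Type:
                if name in self._entries:
                    raise KeyError(f"'{name}' is already registered")
                self._entries[name] = cls
                return cls

            return decorator

        def get(self, name: str) -> Type:
            return self._entries[name]

    STAT_REGISTRY = GenericRegistry()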
sai/stats/danc_statistic.py ADDED
@@ -0,0 +1,83 @@
+# Copyright 2025 Xin Huang
+#
+# GNU General Public License v3.0
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, please see
+#
+# https://www.gnu.org/licenses/gpl-3.0.en.html
+
+
+import numpy as np
+from typing import Dict, Any
+from sai.registries.stat_registry import STAT_REGISTRY
+from sai.stats import GenericStatistic
+from sai.stats.stat_utils import calc_four_pops_freq, calc_pattern_sum
+
+
+@STAT_REGISTRY.register("Danc")
+class DancStatistic(GenericStatistic):
+    """
+    Class for computing the Danc statistic (Fang et al. 2024. PLoS Genet).
+
+    The Danc statistic detects asymmetric ancestry contribution by comparing
+    excess BAAA and ABAA site patterns in a four-population framework.
+    """
+
+    STAT_NAME = "Danc"
+
+    def compute(self, **kwargs) -> Dict[str, Any]:
+        """
+        Computes the Danc statistic for each source population.
+
+        This method computes the statistic per source population using four-population
+        site pattern counts.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Unused. Present to maintain compatibility with the base class interface.
+
+        Returns
+        -------
+        dict
+            A dictionary containing:
+            - 'name' : str
+                The name of the statistic ("Danc").
+            - 'value' : list[float]
+                A list of Danc values, one for each source population.
+        """
+        danc_results = []
+
+        for i in range(len(self.src_gts_list)):
+            ref_freq, tgt_freq, src_freq, out_freq = calc_four_pops_freq(
+                ref_gts=self.ref_gts,
+                tgt_gts=self.tgt_gts,
+                src_gts=self.src_gts_list[i],
+                out_gts=self.out_gts,
+                ref_ploidy=self.ref_ploidy,
+                tgt_ploidy=self.tgt_ploidy,
+                src_ploidy=self.src_ploidy_list[i],
+                out_ploidy=self.out_ploidy,
+            )
+
+            baaa = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "baaa")
+            abaa = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "abaa")
+
+            numerator = baaa - abaa
+            denominator = baaa + abaa
+
+            danc = numerator / denominator if denominator != 0 else np.nan
+            danc_results.append(danc)
+
+        return {"name": self.STAT_NAME, "value": danc_results}
sai/stats/dd_statistic.py ADDED
@@ -0,0 +1,77 @@
+# Copyright 2025 Xin Huang
+#
+# GNU General Public License v3.0
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, please see
+#
+# https://www.gnu.org/licenses/gpl-3.0.en.html
+
+
+import numpy as np
+from scipy.spatial.distance import cdist
+from typing import Dict, Any
+from sai.registries.stat_registry import STAT_REGISTRY
+from sai.stats import GenericStatistic
+
+
+@STAT_REGISTRY.register("DD")
+class DdStatistic(GenericStatistic):
+    """
+    Class for computing the average difference in sequence divergence (DD).
+
+    The DD statistic quantifies the difference in average pairwise sequence
+    divergence between a source population and each of the reference and
+    target populations, using Manhattan (cityblock) distance.
+    """
+
+    STAT_NAME = "DD"
+
+    def compute(self, **kwargs) -> Dict[str, Any]:
+        """
+        Computes the DD statistic for each source population.
+
+        For each source population, the method calculates pairwise Manhattan distances
+        between the source and both the target and reference populations, averages the
+        distances per genome, and computes the difference in mean divergence.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Unused. Present to maintain compatibility with the base class interface.
+
+        Returns
+        -------
+        dict
+            A dictionary containing:
+            - 'name' : str
+                The name of the statistic ("DD").
+            - 'value' : list[float]
+                A list of DD values, one for each source population.
+        """
+        dd_results = []
+
+        for i in range(len(self.src_gts_list)):
+            # pairwise Manhattan (cityblock) distances between genomes
+            src_gts = self.src_gts_list[i]
+            seq_divs_src_tgt = cdist(src_gts.T, self.tgt_gts.T, metric="cityblock")
+            seq_divs_src_ref = cdist(src_gts.T, self.ref_gts.T, metric="cityblock")
+
+            # mean divergence of each source genome to the other population
+            mean_src_tgt = np.mean(seq_divs_src_tgt, axis=1)
+            mean_src_ref = np.mean(seq_divs_src_ref, axis=1)
+
+            dd = np.mean(mean_src_ref - mean_src_tgt)
+            dd_results.append(dd)
+
+        return {"name": self.STAT_NAME, "value": dd_results}
sai/stats/df_statistic.py ADDED
@@ -0,0 +1,84 @@
+# Copyright 2025 Xin Huang
+#
+# GNU General Public License v3.0
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, please see
+#
+# https://www.gnu.org/licenses/gpl-3.0.en.html
+
+
+import numpy as np
+from typing import Dict, Any
+from sai.registries.stat_registry import STAT_REGISTRY
+from sai.stats import GenericStatistic
+from sai.stats.stat_utils import calc_four_pops_freq, calc_pattern_sum
+
+
+@STAT_REGISTRY.register("df")
+class DfStatistic(GenericStatistic):
+    """
+    Class for computing the distance fraction (df) statistic (Pfeifer and Kapan. 2019. BMC Bioinformatics).
+
+    The df statistic quantifies the relative excess of shared derived alleles
+    using ABBA, BABA, and BBAA site patterns across a four-population test.
+    """
+
+    STAT_NAME = "df"
+
+    def compute(self, **kwargs) -> Dict[str, Any]:
+        """
+        Computes the df statistic for each source population.
+
+        This method computes df for each source population using site pattern
+        counts based on allele frequency input.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Unused. Present to maintain compatibility with the base class interface.
+
+        Returns
+        -------
+        dict
+            A dictionary containing:
+            - 'name' : str
+                The name of the statistic ("df").
+            - 'value' : list[float]
+                A list of df values, one per source population.
+        """
+        df_results = []
+
+        for i in range(len(self.src_gts_list)):
+            ref_freq, tgt_freq, src_freq, out_freq = calc_four_pops_freq(
+                ref_gts=self.ref_gts,
+                tgt_gts=self.tgt_gts,
+                src_gts=self.src_gts_list[i],
+                out_gts=self.out_gts,
+                ref_ploidy=self.ref_ploidy,
+                tgt_ploidy=self.tgt_ploidy,
+                src_ploidy=self.src_ploidy_list[i],
+                out_ploidy=self.out_ploidy,
+            )
+
+            abba = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "abba")
+            baba = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "baba")
+            bbaa = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "bbaa")
+
+            numerator = abba - baba
+            denominator = abba + baba + 2 * bbaa
+
+            df = numerator / denominator if denominator != 0 else np.nan
+            df_results.append(df)
+
+        return {"name": self.STAT_NAME, "value": df_results}