sai-pg 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sai/__init__.py +2 -0
- sai/__main__.py +6 -3
- sai/configs/__init__.py +24 -0
- sai/configs/global_config.py +83 -0
- sai/configs/ploidy_config.py +94 -0
- sai/configs/pop_config.py +82 -0
- sai/configs/stat_config.py +220 -0
- sai/{utils/generators → generators}/chunk_generator.py +2 -8
- sai/{utils/generators → generators}/window_generator.py +82 -37
- sai/{utils/multiprocessing → multiprocessing}/mp_manager.py +2 -2
- sai/{utils/multiprocessing → multiprocessing}/mp_pool.py +2 -2
- sai/parsers/outlier_parser.py +4 -3
- sai/parsers/score_parser.py +8 -119
- sai/{utils/preprocessors → preprocessors}/chunk_preprocessor.py +21 -15
- sai/preprocessors/feature_preprocessor.py +236 -0
- sai/registries/__init__.py +22 -0
- sai/registries/generic_registry.py +89 -0
- sai/registries/stat_registry.py +30 -0
- sai/sai.py +124 -220
- sai/stats/__init__.py +11 -0
- sai/stats/danc_statistic.py +83 -0
- sai/stats/dd_statistic.py +77 -0
- sai/stats/df_statistic.py +84 -0
- sai/stats/dplus_statistic.py +86 -0
- sai/stats/fd_statistic.py +92 -0
- sai/stats/generic_statistic.py +93 -0
- sai/stats/q_statistic.py +104 -0
- sai/stats/stat_utils.py +259 -0
- sai/stats/u_statistic.py +99 -0
- sai/utils/utils.py +220 -143
- {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/METADATA +3 -14
- sai_pg-1.1.0.dist-info/RECORD +70 -0
- {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/WHEEL +1 -1
- sai_pg-1.1.0.dist-info/top_level.txt +2 -0
- tests/configs/test_global_config.py +163 -0
- tests/configs/test_ploidy_config.py +93 -0
- tests/configs/test_pop_config.py +90 -0
- tests/configs/test_stat_config.py +171 -0
- tests/generators/test_chunk_generator.py +51 -0
- tests/generators/test_window_generator.py +164 -0
- tests/multiprocessing/test_mp_manager.py +92 -0
- tests/multiprocessing/test_mp_pool.py +79 -0
- tests/parsers/test_argument_validation.py +133 -0
- tests/parsers/test_outlier_parser.py +53 -0
- tests/parsers/test_score_parser.py +63 -0
- tests/preprocessors/test_chunk_preprocessor.py +79 -0
- tests/preprocessors/test_feature_preprocessor.py +223 -0
- tests/registries/test_registries.py +74 -0
- tests/stats/test_danc_statistic.py +51 -0
- tests/stats/test_dd_statistic.py +45 -0
- tests/stats/test_df_statistic.py +73 -0
- tests/stats/test_dplus_statistic.py +79 -0
- tests/stats/test_fd_statistic.py +68 -0
- tests/stats/test_q_statistic.py +268 -0
- tests/stats/test_stat_utils.py +354 -0
- tests/stats/test_u_statistic.py +233 -0
- tests/test___main__.py +51 -0
- tests/test_sai.py +102 -0
- tests/utils/test_utils.py +511 -0
- sai/parsers/plot_parser.py +0 -152
- sai/stats/features.py +0 -302
- sai/utils/preprocessors/feature_preprocessor.py +0 -211
- sai_pg-1.0.0.dist-info/RECORD +0 -30
- sai_pg-1.0.0.dist-info/top_level.txt +0 -1
- /sai/{utils/generators → generators}/__init__.py +0 -0
- /sai/{utils/generators → generators}/data_generator.py +0 -0
- /sai/{utils/multiprocessing → multiprocessing}/__init__.py +0 -0
- /sai/{utils/preprocessors → preprocessors}/__init__.py +0 -0
- /sai/{utils/preprocessors → preprocessors}/data_preprocessor.py +0 -0
- {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/entry_points.txt +0 -0
- {sai_pg-1.0.0.dist-info → sai_pg-1.1.0.dist-info}/licenses/LICENSE +0 -0
sai/sai.py
CHANGED
@@ -20,28 +20,23 @@
 
 import os
 import warnings
+import yaml
 import pandas as pd
-
-from
-from sai.
-from sai.
+from pathlib import Path
+from sai.generators import ChunkGenerator
+from sai.preprocessors import ChunkPreprocessor
+from sai.configs import GlobalConfig
 from sai.utils.utils import natsorted_df
 
 
 def score(
     vcf_file: str,
     chr_name: str,
-    ref_ind_file: str,
-    tgt_ind_file: str,
-    src_ind_file: str,
     win_len: int,
     win_step: int,
-    num_src: int,
     anc_allele_file: str,
-    w: float,
-    y: list[float],
     output_file: str,
-
+    config: str,
     num_workers: int,
 ) -> None:
     """
@@ -53,55 +48,76 @@ def score(
         Path to the VCF file containing variant data.
     chr_name : str
         The chromosome name to be analyzed from the VCF file.
-    ref_ind_file : str
-        Path to the file containing reference population identifiers.
-    tgt_ind_file : str
-        Path to the file containing target population identifiers.
-    src_ind_file : str
-        Path to the file containing source population identifiers.
     win_len : int
         Length of each genomic window in base pairs.
     win_step : int
         Step size in base pairs between consecutive windows.
-    num_src : int
-        Number of source populations to include in each windowed analysis.
     anc_allele_file : str
        Path to the file containing ancestral allele information.
-    w : float
-        Frequency threshold for calculating feature vectors.
-    y : list[float]
-        List of frequency thresholds used for various calculations in feature vector processing.
     output_file : str
         File path to save the output of processed feature vectors.
-
-
+    config : str
+        Path to the YAML configuration file specifying the statistics and ploidies to compute.
     num_workers : int
         Number of parallel processes for multiprocessing.
     """
+    try:
+        with open(config, "r") as f:
+            config_dict = yaml.safe_load(f)
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Configuration file '{config}' not found.")
+    except yaml.YAMLError as e:
+        raise ValueError(f"Error parsing YAML configuration file '{config}': {e}")
+
+    required_fields = ["statistics", "ploidies", "populations"]
+    missing_fields = [field for field in required_fields if field not in config_dict]
+
+    if missing_fields:
+        raise ValueError(
+            f"Missing required fields in configuration file '{config}': {', '.join(missing_fields)}"
+        )
+
+    global_config = GlobalConfig(**config_dict)
+
+    stat_config = global_config.statistics
+    ploidy_config = global_config.ploidies
+    pop_config = global_config.populations
+
     generator = ChunkGenerator(
         vcf_file=vcf_file,
         chr_name=chr_name,
         window_size=win_len,
         step_size=win_step,
-        num_chunks=num_workers * 8,
+        # num_chunks=num_workers * 8,
+        num_chunks=1,
     )
 
     preprocessor = ChunkPreprocessor(
         vcf_file=vcf_file,
-        ref_ind_file=
-        tgt_ind_file=
-        src_ind_file=
+        ref_ind_file=pop_config.get_population("ref"),
+        tgt_ind_file=pop_config.get_population("tgt"),
+        src_ind_file=pop_config.get_population("src"),
+        out_ind_file=pop_config.get_population("outgroup"),
         win_len=win_len,
         win_step=win_step,
-        w=w,
-        y=y,
         output_file=output_file,
-
+        ploidy_config=ploidy_config,
+        stat_config=stat_config,
         anc_allele_file=anc_allele_file,
-        num_src=num_src,
     )
 
-
+    src_pops = list(ploidy_config.root["src"].keys())
+
+    header_parts = ["Chrom", "Start", "End", "Ref", "Tgt", "Src", "N(Variants)"]
+
+    for stat_name in stat_config.root.keys():
+        if stat_name in ("U", "Q") or len(src_pops) <= 1:
+            header_parts.append(stat_name)
+        else:
+            for sp in src_pops:
+                header_parts.append(f"{stat_name}.{sp}")
+
+    header = "\t".join(header_parts) + "\n"
 
     directory = os.path.dirname(output_file)
     if directory:
@@ -109,6 +125,13 @@ def score(
     with open(output_file, "w") as f:
         f.write(header)
 
+    for key in ("U", "Q"):
+        if key in stat_config.root:
+            path = Path(output_file)
+            log_file = path.with_suffix(f".{key}.log")
+            with open(log_file, "w") as f:
+                f.write(f"Chrom\tStart\tEnd\t{key}_SNP\n")
+
     items = []
 
     for params in generator.get():
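A side note on the log-file naming above: `Path.with_suffix()` replaces the path's final suffix, so the per-statistic logs share the output file's stem rather than stacking suffixes. A quick standalone check:

```python
from pathlib import Path

# with_suffix() swaps the final suffix; since f".{key}.log" starts with a dot,
# "scores.tsv" becomes "scores.U.log", not "scores.tsv.U.log".
path = Path("results/scores.tsv")
for key in ("U", "Q"):
    print(path.with_suffix(f".{key}.log"))
# results/scores.U.log
# results/scores.Q.log
```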
@@ -117,199 +140,80 @@ def score(
     preprocessor.process_items(items)
 
 
-def outlier(score_file: str,
+def outlier(score_file: str, output_prefix: str, quantile: float) -> None:
     """
-
-
+    Identifies outlier windows for each statistic column in a score file and
+    writes them to separate output files.
+
+    This function reads a tab-delimited score file, determines which columns
+    contain statistics (e.g., U, Q, D+), computes the specified quantile
+    threshold for each statistic, and outputs rows exceeding that threshold.
+    Results for each statistic are written to a separate TSV file, sorted by
+    Chrom, Start, and End when available.
 
     Parameters
     ----------
     score_file : str
-        Path to the input file
-
-
+        Path to the input score file (tab-delimited).
+    output_prefix : str
+        Prefix for the output files. Each output file is named
+        "{output_prefix}.{stat}.{quantile}.outliers.tsv".
     quantile : float
-        Quantile threshold to
+        Quantile threshold (between 0 and 1) used to define outliers.
     """
-
-    data = pd.read_csv(
-        score_file,
-        sep="\t",
-        na_values=["nan"],
-        dtype={"Candidate": str},
-        index_col=False,
-    )
+    df = pd.read_csv(score_file, sep="\t", na_values=["nan"], index_col=False)
 
-
-
-
-
-
-    # Calculate quantile threshold for the chosen column
-    threshold = data[column].quantile(quantile)
-
-    if data[column].nunique() == 1:
-        warnings.warn(
-            f"Column '{column}' contains only one unique value ({threshold}), making quantile filtering meaningless.",
-            UserWarning,
-        )
-        outliers = pd.DataFrame(columns=data.columns)
-    elif (threshold == 1) and (column.startswith("Q")):
-        outliers = data[data[column] >= threshold]
-    else:
-        outliers = data[data[column] > threshold]
-
-    # Sort the filtered data by 'Chrom', 'Start', 'End' columns
-    if not outliers.empty:
-        outliers = outliers.reset_index(drop=True)
-        outliers_sorted = natsorted_df(outliers)
+    cols = list(df.columns)
+    if "N(Variants)" in cols:
+        start_idx = cols.index("N(Variants)") + 1
+        metric_cols = cols[start_idx:]
     else:
-    [... 48 removed lines (old 167-214) not rendered in this diff view ...]
-        Transparency level of scatter points (default: 0.6).
-    marker_size : float, optional
-        Size of the scatter plot markers (default: 20).
-    marker_color : str, optional
-        Color of the markers (default: "blue").
-    marker_style : str, optional
-        Shape of the marker (default: "o").
-    """
-    u_data = pd.read_csv(u_file, sep="\t")
-    q_data = pd.read_csv(q_file, sep="\t")
-
-    u_column = u_data.columns[-2]
-    q_column = q_data.columns[-2]
-
-    u_data["interval"] = (
-        u_data["Chrom"].astype(str)
-        + ":"
-        + u_data["Start"].astype(str)
-        + "-"
-        + u_data["End"].astype(str)
-    )
-    q_data["interval"] = (
-        q_data["Chrom"].astype(str)
-        + ":"
-        + q_data["Start"].astype(str)
-        + "-"
-        + q_data["End"].astype(str)
-    )
-
-    u_data[u_column] = pd.to_numeric(u_data[u_column], errors="coerce")
-    q_data[q_column] = pd.to_numeric(q_data[q_column], errors="coerce")
-    u_data = u_data.dropna(subset=[u_column])
-    q_data = q_data.dropna(subset=[q_column])
-
-    u_interval_dict = {row["interval"]: row[u_column] for _, row in u_data.iterrows()}
-    q_interval_dict = {row["interval"]: row[q_column] for _, row in q_data.iterrows()}
-    u_candidate_dict = {
-        row["interval"]: set(str(row["Candidate"]).split(","))
-        for _, row in u_data.iterrows()
-    }
-    q_candidate_dict = {
-        row["interval"]: set(str(row["Candidate"]).split(","))
-        for _, row in q_data.iterrows()
-    }
-
-    common_intervals = set(u_interval_dict.keys()) & set(q_interval_dict.keys())
-    if not common_intervals:
-        raise ValueError(
-            "No common genomic intervals found between U and Q score/outlier files."
+        # fallback: exclude common non-metric columns, keep numeric ones
+        non_metrics = {"Chrom", "Start", "End", "Ref", "Tgt", "Src"}
+        candidate = [c for c in cols if c not in non_metrics]
+        metric_cols = [
+            c for c in candidate if pd.to_numeric(df[c], errors="coerce").notna().any()
+        ]
+
+    if not metric_cols:
+        raise ValueError("No metric columns found.")
+
+    for col in metric_cols:
+        s_num = pd.to_numeric(df[col], errors="coerce").dropna()
+
+        if s_num.empty:
+            warnings.warn(
+                f"Column '{col}' has no numeric values; writing empty result.",
+                UserWarning,
+            )
+            out_sorted = pd.DataFrame(columns=df.columns)
+        elif s_num.nunique() == 1:
+            thr = s_num.iloc[0]
+            warnings.warn(
+                f"Column '{col}' has only one unique value ({thr}); writing empty result.",
+                UserWarning,
+            )
+            out_sorted = pd.DataFrame(columns=df.columns)
+        else:
+            thr = s_num.quantile(quantile)
+            col_num = pd.to_numeric(df[col], errors="coerce")
+            if not col.startswith("U"):
+                out = df[col_num >= thr]
+            else:
+                out = df[col_num > thr]
+
+            if not out.empty:
+                out = out.reset_index(drop=True)
+                try:
+                    out_sorted = natsorted_df(out)  # natural sort by Chrom/Start/End
+                except NameError:
+                    keys = [k for k in ("Chrom", "Start", "End") if k in out.columns]
+                    out_sorted = (
+                        out.sort_values(by=keys, kind="mergesort") if keys else out
+                    )
+            else:
+                out_sorted = out
+
+        out_sorted.astype(str).to_csv(
+            f"{output_prefix}.{col}.{quantile}.outliers.tsv", index=False, sep="\t"
        )
-
-    # Helper: get candidate overlap or "."
-    def get_candidate_overlap(interval):
-        u_set = u_candidate_dict.get(interval, set())
-        q_set = q_candidate_dict.get(interval, set())
-        overlap = sorted(u_set & q_set)
-        return ",".join(overlap) if overlap else "NA"
-
-    overlap_df = pd.DataFrame(
-        {
-            "Chrom": [interval.split(":")[0] for interval in common_intervals],
-            "Start": [
-                int(interval.split(":")[1].split("-")[0])
-                for interval in common_intervals
-            ],
-            "End": [
-                int(interval.split(":")[1].split("-")[1])
-                for interval in common_intervals
-            ],
-            u_column: [u_interval_dict[c] for c in common_intervals],
-            q_column: [q_interval_dict[c] for c in common_intervals],
-            "Overlapping Candidate": [
-                get_candidate_overlap(c) for c in common_intervals
-            ],
-        }
-    )
-
-    overlap_df_sorted = natsorted_df(overlap_df)
-    overlap_output = os.path.splitext(output)[0] + ".overlap.tsv"
-    pd.DataFrame(overlap_df_sorted).to_csv(overlap_output, sep="\t", index=False)
-
-    plt.figure(figsize=(figsize_x, figsize_y))
-    plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))
-    plt.scatter(
-        x=overlap_df[q_column],
-        y=overlap_df[u_column],
-        alpha=alpha,
-        s=marker_size,
-        c=marker_color,
-        marker=marker_style,
-    )
-    xmin, xmax = plt.gca().get_xlim()
-    ymin, ymax = plt.gca().get_ylim()
-    plt.xlim(left=max(0, xmin))
-    plt.ylim(bottom=max(0, ymin))
-    plt.xlabel(xlabel)
-    plt.ylabel(ylabel)
-    plt.title(title)
-    plt.grid(alpha=0.5, linestyle="--")
-    plt.savefig(output, bbox_inches="tight", dpi=dpi)
-    plt.close()
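The rewritten `outlier()` keeps rows at or above the per-column quantile threshold, except for columns whose names start with `U`, which use a strict greater-than. A small worked example of the thresholding step:

```python
import pandas as pd

# pandas interpolates quantiles linearly: for these five scores the 0.99
# quantile sits between the two largest values, 0.4 + 0.96 * (1.0 - 0.4) ≈ 0.976.
s = pd.Series([0.1, 0.2, 0.3, 0.4, 1.0])
thr = s.quantile(0.99)
print(round(thr, 3))   # 0.976
print(s[s >= thr])     # only the 1.0 row survives (non-U columns)
print(s[s > thr])      # same result here; ties at the threshold would be dropped (U columns)
```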
sai/stats/__init__.py
CHANGED
@@ -16,3 +16,14 @@
 # along with this program. If not, please see
 #
 # https://www.gnu.org/licenses/gpl-3.0.en.html
+
+
+from .generic_statistic import GenericStatistic
+from .danc_statistic import DancStatistic
+from .dd_statistic import DdStatistic
+from .df_statistic import DfStatistic
+from .dplus_statistic import DplusStatistic
+from .fd_statistic import FdStatistic
+from .q_statistic import QStatistic
+from .u_statistic import UStatistic
+from .stat_utils import *
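The statistic modules below register their classes with `STAT_REGISTRY` via a decorator. The registry itself lives in the new `sai/registries/generic_registry.py` and `sai/registries/stat_registry.py`, which this excerpt does not show; the following is only a minimal sketch of the pattern, with every name and behavior assumed for illustration:

```python
from typing import Callable, Dict

class GenericRegistry:
    """Minimal name-to-class registry sketch; not the package's actual code."""

    def __init__(self) -> None:
        self._items: Dict[str, type] = {}

    def register(self, name: str) -> Callable[[type], type]:
        # Used as a class decorator: @STAT_REGISTRY.register("Danc")
        def decorator(cls: type) -> type:
            self._items[name] = cls
            return cls
        return decorator

    def get(self, name: str) -> type:
        return self._items[name]

STAT_REGISTRY = GenericRegistry()

@STAT_REGISTRY.register("Danc")
class DancStatistic:
    pass

assert STAT_REGISTRY.get("Danc") is DancStatistic
```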
sai/stats/danc_statistic.py
ADDED
@@ -0,0 +1,83 @@
+# Copyright 2025 Xin Huang
+#
+# GNU General Public License v3.0
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, please see
+#
+# https://www.gnu.org/licenses/gpl-3.0.en.html
+
+
+import numpy as np
+from typing import Dict, Any
+from sai.registries.stat_registry import STAT_REGISTRY
+from sai.stats import GenericStatistic
+from sai.stats.stat_utils import calc_four_pops_freq, calc_pattern_sum
+
+
+@STAT_REGISTRY.register("Danc")
+class DancStatistic(GenericStatistic):
+    """
+    Class for computing the Danc statistic (Fang et al. 2024. PLoS Genet).
+
+    The Danc statistic detects asymmetric ancestry contribution by comparing
+    excess BAAA and ABAA site patterns in a four-population framework.
+    """
+
+    STAT_NAME = "Danc"
+
+    def compute(self, **kwargs) -> Dict[str, Any]:
+        """
+        Computes the Danc statistic for each source population.
+
+        This method computes the statistic per source population using four-population
+        site pattern counts.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Unused. Present to maintain compatibility with the base class interface.
+
+        Returns
+        -------
+        dict
+            A dictionary containing:
+            - 'name' : str
+                The name of the statistic ("Danc").
+            - 'value' : list[float]
+                A list of Danc values, one for each source population.
+        """
+        danc_results = []
+
+        for i in range(len(self.src_gts_list)):
+            ref_freq, tgt_freq, src_freq, out_freq = calc_four_pops_freq(
+                ref_gts=self.ref_gts,
+                tgt_gts=self.tgt_gts,
+                src_gts=self.src_gts_list[i],
+                out_gts=self.out_gts,
+                ref_ploidy=self.ref_ploidy,
+                tgt_ploidy=self.tgt_ploidy,
+                src_ploidy=self.src_ploidy_list[i],
+                out_ploidy=self.out_ploidy,
+            )
+
+            baaa = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "baaa")
+            abaa = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "abaa")
+
+            numerator = baaa - abaa
+            denominator = baaa + abaa
+
+            danc = numerator / denominator if denominator != 0 else np.nan
+            danc_results.append(danc)
+
+        return {"name": self.STAT_NAME, "value": danc_results}
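In equation form, the loop above computes, for each source population,

```latex
D_{\mathrm{anc}} =
  \frac{\sum_{\mathrm{sites}} \mathrm{BAAA} - \sum_{\mathrm{sites}} \mathrm{ABAA}}
       {\sum_{\mathrm{sites}} \mathrm{BAAA} + \sum_{\mathrm{sites}} \mathrm{ABAA}}
```

returning `np.nan` when the denominator is zero.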
sai/stats/dd_statistic.py
ADDED
@@ -0,0 +1,77 @@
+# Copyright 2025 Xin Huang
+#
+# GNU General Public License v3.0
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, please see
+#
+# https://www.gnu.org/licenses/gpl-3.0.en.html
+
+
+import numpy as np
+from scipy.spatial.distance import cdist
+from typing import Dict, Any
+from sai.registries.stat_registry import STAT_REGISTRY
+from sai.stats import GenericStatistic
+
+
+@STAT_REGISTRY.register("DD")
+class DdStatistic(GenericStatistic):
+    """
+    Class for computing the average difference of the sequence divergence.
+
+    The DD statistic quantifies the difference in average pairwise sequence
+    divergence between a source population and two target populations (reference
+    and target), using Manhattan (cityblock) distance.
+    """
+
+    STAT_NAME = "DD"
+
+    def compute(self, **kwargs) -> Dict[str, Any]:
+        """
+        Computes the DD statistic for each source population.
+
+        For each source population, the method calculates pairwise Manhattan distances
+        between the source and both the target and reference populations, averages the
+        distances per genome, and computes the difference in mean divergence.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Unused. Present to maintain compatibility with the base class interface.
+
+        Returns
+        -------
+        dict
+            A dictionary containing:
+            - 'name' : str
+                The name of the statistic ("DD").
+            - 'value' : list[float]
+                A list of DD values, one for each source population.
+        """
+        dd_results = []
+
+        for i in range(len(self.src_gts_list)):
+            # pairwise distances
+            src_gts = self.src_gts_list[i]
+            seq_divs_src_tgt = cdist(src_gts.T, self.tgt_gts.T, metric="cityblock")
+            seq_divs_src_ref = cdist(src_gts.T, self.ref_gts.T, metric="cityblock")
+
+            # mean of each row
+            mean_src_tgt = np.mean(seq_divs_src_tgt, axis=1)
+            mean_src_ref = np.mean(seq_divs_src_ref, axis=1)
+
+            dd = np.mean(mean_src_ref - mean_src_tgt)
+            dd_results.append(dd)
+
+        return {"name": self.STAT_NAME, "value": dd_results}
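Written out, with d1(x, y) the Manhattan (cityblock) distance between two genomes and n_s, n_r, n_t the numbers of source, reference, and target genomes, the loop above computes

```latex
\mathrm{DD} = \frac{1}{n_s} \sum_{i=1}^{n_s}
  \left( \frac{1}{n_r} \sum_{j=1}^{n_r} d_1(s_i, r_j)
       - \frac{1}{n_t} \sum_{k=1}^{n_t} d_1(s_i, t_k) \right)
```

so positive values indicate that the source is, on average, closer to the target than to the reference.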
sai/stats/df_statistic.py
ADDED
@@ -0,0 +1,84 @@
+# Copyright 2025 Xin Huang
+#
+# GNU General Public License v3.0
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, please see
+#
+# https://www.gnu.org/licenses/gpl-3.0.en.html
+
+
+import numpy as np
+from typing import Dict, Any
+from sai.registries.stat_registry import STAT_REGISTRY
+from sai.stats import GenericStatistic
+from sai.stats.stat_utils import calc_four_pops_freq, calc_pattern_sum
+
+
+@STAT_REGISTRY.register("df")
+class DfStatistic(GenericStatistic):
+    """
+    Class for computing the distance fraction (df) statistic (Pfeifer and Kapan. 2019. BMC Bioinformatics).
+
+    The df statistic quantifies the relative excess of shared derived alleles
+    using ABBA, BABA, and BBAA site patterns across a four-population test.
+    """
+
+    STAT_NAME = "df"
+
+    def compute(self, **kwargs) -> Dict[str, Any]:
+        """
+        Computes the df statistic for each source population.
+
+        This method computes df for each source population using site pattern
+        counts based on allele frequency input.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Unused. Present to maintain compatibility with the base class interface.
+
+        Returns
+        -------
+        dict
+            A dictionary containing:
+            - 'name' : str
+                The name of the statistic ("df").
+            - 'value' : list[float]
+                A list of df values, one per source population.
+        """
+        df_results = []
+
+        for i in range(len(self.src_gts_list)):
+            ref_freq, tgt_freq, src_freq, out_freq = calc_four_pops_freq(
+                ref_gts=self.ref_gts,
+                tgt_gts=self.tgt_gts,
+                src_gts=self.src_gts_list[i],
+                out_gts=self.out_gts,
+                ref_ploidy=self.ref_ploidy,
+                tgt_ploidy=self.tgt_ploidy,
+                src_ploidy=self.src_ploidy_list[i],
+                out_ploidy=self.out_ploidy,
+            )
+
+            abba = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "abba")
+            baba = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "baba")
+            bbaa = calc_pattern_sum(ref_freq, tgt_freq, src_freq, out_freq, "bbaa")
+
+            numerator = abba - baba
+            denominator = abba + baba + 2 * bbaa
+
+            df = numerator / denominator if denominator != 0 else np.nan
+            df_results.append(df)
+
+        return {"name": self.STAT_NAME, "value": df_results}
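Per source population, the computation above reduces to

```latex
d_f = \frac{\sum \mathrm{ABBA} - \sum \mathrm{BABA}}
           {\sum \mathrm{ABBA} + \sum \mathrm{BABA} + 2\sum \mathrm{BBAA}}
```

with `np.nan` returned when the denominator is zero.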