smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +7 -1
- smftools/cli/hmm_adata.py +902 -244
- smftools/cli/load_adata.py +318 -198
- smftools/cli/preprocess_adata.py +285 -171
- smftools/cli/spatial_adata.py +137 -53
- smftools/cli_entry.py +94 -178
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +22 -17
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +505 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2125 -1426
- smftools/hmm/__init__.py +2 -3
- smftools/hmm/archived/call_hmm_peaks.py +16 -1
- smftools/hmm/call_hmm_peaks.py +173 -193
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +379 -156
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +195 -29
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +347 -168
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +145 -85
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +8 -8
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +103 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +688 -271
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.4.dist-info/RECORD +0 -176
- /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/tools/position_stats.py
CHANGED
|
@@ -1,23 +1,76 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
import anndata as ad
|
|
8
|
+
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
# optional imports
|
|
14
|
+
try:
|
|
15
|
+
from joblib import Parallel, delayed
|
|
16
|
+
|
|
17
|
+
JOBLIB_AVAILABLE = True
|
|
18
|
+
except Exception:
|
|
19
|
+
JOBLIB_AVAILABLE = False
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
from scipy.stats import chi2_contingency
|
|
23
|
+
|
|
24
|
+
SCIPY_STATS_AVAILABLE = True
|
|
25
|
+
except Exception:
|
|
26
|
+
SCIPY_STATS_AVAILABLE = False
|
|
27
|
+
|
|
28
|
+
# -----------------------------
|
|
29
|
+
# Compute positionwise statistic (multi-method + simple site_types)
|
|
30
|
+
# -----------------------------
|
|
31
|
+
import os
|
|
32
|
+
from contextlib import contextmanager
|
|
33
|
+
from itertools import cycle
|
|
34
|
+
|
|
35
|
+
import joblib
|
|
36
|
+
from joblib import Parallel, cpu_count, delayed
|
|
37
|
+
from scipy.stats import chi2_contingency
|
|
38
|
+
from tqdm import tqdm
|
|
39
|
+
|
|
40
|
+
|
|
1
41
|
# ------------------------- Utilities -------------------------
|
|
2
|
-
def random_fill_nans(X):
|
|
42
|
+
def random_fill_nans(X: np.ndarray) -> np.ndarray:
|
|
43
|
+
"""Fill NaNs with random values in-place.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
X: Input array with NaNs.
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
numpy.ndarray: Array with NaNs replaced by random values.
|
|
50
|
+
"""
|
|
3
51
|
import numpy as np
|
|
52
|
+
|
|
4
53
|
nan_mask = np.isnan(X)
|
|
5
54
|
X[nan_mask] = np.random.rand(*X[nan_mask].shape)
|
|
6
55
|
return X
|
|
7
56
|
|
|
8
|
-
def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
|
|
9
|
-
"""
|
|
10
|
-
Perform Bayesian-style methylation vs activity analysis independently within each group.
|
|
11
57
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
58
|
+
def calculate_relative_risk_on_activity(
|
|
59
|
+
adata: "ad.AnnData",
|
|
60
|
+
sites: Sequence[str],
|
|
61
|
+
alpha: float = 0.05,
|
|
62
|
+
groupby: str | Sequence[str] | None = None,
|
|
63
|
+
) -> dict:
|
|
64
|
+
"""Perform methylation vs. activity analysis within each group.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
adata: Annotated data matrix.
|
|
68
|
+
sites: Site keys (e.g., ``["GpC_site", "CpG_site"]``).
|
|
69
|
+
alpha: FDR threshold for significance.
|
|
70
|
+
groupby: Obs column(s) to group by.
|
|
17
71
|
|
|
18
72
|
Returns:
|
|
19
|
-
|
|
20
|
-
results_dict[ref][group_label] = (results_df, sig_df)
|
|
73
|
+
dict: Mapping of reference -> group label -> ``(results_df, sig_df)``.
|
|
21
74
|
"""
|
|
22
75
|
import numpy as np
|
|
23
76
|
import pandas as pd
|
|
@@ -25,30 +78,44 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
|
|
|
25
78
|
from statsmodels.stats.multitest import multipletests
|
|
26
79
|
|
|
27
80
|
def compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values):
|
|
28
|
-
|
|
81
|
+
"""Build result and significant-data DataFrames for a reference.
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
ref: Reference name.
|
|
85
|
+
site_subset: AnnData subset restricted to sites.
|
|
86
|
+
positions_list: Positions tested.
|
|
87
|
+
relative_risks: Relative risk values.
|
|
88
|
+
p_values: Raw p-values.
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
Tuple of (results_df, sig_df).
|
|
92
|
+
"""
|
|
93
|
+
p_adj = multipletests(p_values, method="fdr_bh")[1] if p_values else []
|
|
29
94
|
|
|
30
95
|
genomic_positions = np.array(site_subset.var_names)[positions_list]
|
|
31
96
|
is_gpc_site = site_subset.var[f"{ref}_GpC_site"].values[positions_list]
|
|
32
97
|
is_cpg_site = site_subset.var[f"{ref}_CpG_site"].values[positions_list]
|
|
33
98
|
|
|
34
|
-
results_df = pd.DataFrame(
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
99
|
+
results_df = pd.DataFrame(
|
|
100
|
+
{
|
|
101
|
+
"Feature_Index": positions_list,
|
|
102
|
+
"Genomic_Position": genomic_positions.astype(int),
|
|
103
|
+
"Relative_Risk": relative_risks,
|
|
104
|
+
"Adjusted_P_Value": p_adj,
|
|
105
|
+
"GpC_Site": is_gpc_site,
|
|
106
|
+
"CpG_Site": is_cpg_site,
|
|
107
|
+
}
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
results_df["log2_Relative_Risk"] = np.log2(results_df["Relative_Risk"].replace(0, 1e-300))
|
|
111
|
+
results_df["-log10_Adj_P"] = -np.log10(results_df["Adjusted_P_Value"].replace(0, 1e-300))
|
|
112
|
+
sig_df = results_df[results_df["Adjusted_P_Value"] < alpha]
|
|
46
113
|
return results_df, sig_df
|
|
47
114
|
|
|
48
115
|
results_dict = {}
|
|
49
116
|
|
|
50
|
-
for ref in adata.obs[
|
|
51
|
-
ref_subset = adata[adata.obs[
|
|
117
|
+
for ref in adata.obs["Reference_strand"].unique():
|
|
118
|
+
ref_subset = adata[adata.obs["Reference_strand"] == ref].copy()
|
|
52
119
|
if ref_subset.shape[0] == 0:
|
|
53
120
|
continue
|
|
54
121
|
|
|
@@ -56,20 +123,22 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
|
|
|
56
123
|
if groupby is not None:
|
|
57
124
|
if isinstance(groupby, str):
|
|
58
125
|
groupby = [groupby]
|
|
126
|
+
|
|
59
127
|
def format_group_label(row):
|
|
128
|
+
"""Format a group label string from obs row values."""
|
|
60
129
|
return ",".join([f"{col}={row[col]}" for col in groupby])
|
|
61
130
|
|
|
62
|
-
combined_label =
|
|
131
|
+
combined_label = "__".join(groupby)
|
|
63
132
|
ref_subset.obs[combined_label] = ref_subset.obs.apply(format_group_label, axis=1)
|
|
64
133
|
groups = ref_subset.obs[combined_label].unique()
|
|
65
134
|
else:
|
|
66
135
|
combined_label = None
|
|
67
|
-
groups = [
|
|
136
|
+
groups = ["all"]
|
|
68
137
|
|
|
69
138
|
results_dict[ref] = {}
|
|
70
139
|
|
|
71
140
|
for group in groups:
|
|
72
|
-
if group ==
|
|
141
|
+
if group == "all":
|
|
73
142
|
group_subset = ref_subset
|
|
74
143
|
else:
|
|
75
144
|
group_subset = ref_subset[ref_subset.obs[combined_label] == group]
|
|
@@ -85,7 +154,7 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
|
|
|
85
154
|
|
|
86
155
|
# Matrix and labels
|
|
87
156
|
X = random_fill_nans(site_subset.X.copy())
|
|
88
|
-
y = site_subset.obs[
|
|
157
|
+
y = site_subset.obs["activity_status"].map({"Active": 1, "Silent": 0}).values
|
|
89
158
|
P_active = np.mean(y)
|
|
90
159
|
|
|
91
160
|
# Analysis
|
|
@@ -104,7 +173,9 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
|
|
|
104
173
|
continue
|
|
105
174
|
|
|
106
175
|
P_active_given_methylated = (P_methylated_given_active * P_active) / P_methylated
|
|
107
|
-
P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (
|
|
176
|
+
P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (
|
|
177
|
+
1 - P_methylated
|
|
178
|
+
)
|
|
108
179
|
RR = P_active_given_methylated / P_active_given_unmethylated
|
|
109
180
|
|
|
110
181
|
_, p_value = fisher_exact(table)
|
|
@@ -112,49 +183,13 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
|
|
|
112
183
|
relative_risks.append(RR)
|
|
113
184
|
p_values.append(p_value)
|
|
114
185
|
|
|
115
|
-
results_df, sig_df = compute_risk_df(
|
|
186
|
+
results_df, sig_df = compute_risk_df(
|
|
187
|
+
ref, site_subset, positions_list, relative_risks, p_values
|
|
188
|
+
)
|
|
116
189
|
results_dict[ref][group] = (results_df, sig_df)
|
|
117
190
|
|
|
118
191
|
return results_dict
|
|
119
192
|
|
|
120
|
-
import copy
|
|
121
|
-
import warnings
|
|
122
|
-
from typing import Dict, Any, List, Optional, Tuple, Union
|
|
123
|
-
|
|
124
|
-
import numpy as np
|
|
125
|
-
import pandas as pd
|
|
126
|
-
import matplotlib.pyplot as plt
|
|
127
|
-
|
|
128
|
-
# optional imports
|
|
129
|
-
try:
|
|
130
|
-
from joblib import Parallel, delayed
|
|
131
|
-
JOBLIB_AVAILABLE = True
|
|
132
|
-
except Exception:
|
|
133
|
-
JOBLIB_AVAILABLE = False
|
|
134
|
-
|
|
135
|
-
try:
|
|
136
|
-
from scipy.stats import chi2_contingency
|
|
137
|
-
SCIPY_STATS_AVAILABLE = True
|
|
138
|
-
except Exception:
|
|
139
|
-
SCIPY_STATS_AVAILABLE = False
|
|
140
|
-
|
|
141
|
-
# -----------------------------
|
|
142
|
-
# Compute positionwise statistic (multi-method + simple site_types)
|
|
143
|
-
# -----------------------------
|
|
144
|
-
import numpy as np
|
|
145
|
-
import pandas as pd
|
|
146
|
-
from typing import List, Optional, Sequence, Dict, Any, Tuple
|
|
147
|
-
from contextlib import contextmanager
|
|
148
|
-
from joblib import Parallel, delayed, cpu_count
|
|
149
|
-
import joblib
|
|
150
|
-
from tqdm import tqdm
|
|
151
|
-
from scipy.stats import chi2_contingency
|
|
152
|
-
import warnings
|
|
153
|
-
import matplotlib.pyplot as plt
|
|
154
|
-
from itertools import cycle
|
|
155
|
-
import os
|
|
156
|
-
import warnings
|
|
157
|
-
|
|
158
193
|
|
|
159
194
|
# ---------------------------
|
|
160
195
|
# joblib <-> tqdm integration
|
|
@@ -165,7 +200,10 @@ def tqdm_joblib(tqdm_object: tqdm):
|
|
|
165
200
|
old = joblib.parallel.BatchCompletionCallBack
|
|
166
201
|
|
|
167
202
|
class TqdmBatchCompletionCallback(old): # type: ignore
|
|
203
|
+
"""Joblib callback that updates a tqdm progress bar."""
|
|
204
|
+
|
|
168
205
|
def __call__(self, *args, **kwargs):
|
|
206
|
+
"""Update the progress bar when a batch completes."""
|
|
169
207
|
try:
|
|
170
208
|
tqdm_object.update(n=self.batch_size)
|
|
171
209
|
except Exception:
|
|
@@ -183,6 +221,16 @@ def tqdm_joblib(tqdm_object: tqdm):
|
|
|
183
221
|
# row workers (upper-triangle only)
|
|
184
222
|
# ---------------------------
|
|
185
223
|
def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
|
|
224
|
+
"""Compute chi-squared statistics for one row of a pairwise matrix.
|
|
225
|
+
|
|
226
|
+
Args:
|
|
227
|
+
i: Row index.
|
|
228
|
+
X_bin: Binary matrix.
|
|
229
|
+
min_count_for_pairwise: Minimum count for valid comparison.
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
Tuple of (row_index, row_values).
|
|
233
|
+
"""
|
|
186
234
|
n_pos = X_bin.shape[1]
|
|
187
235
|
row = np.full((n_pos,), np.nan, dtype=float)
|
|
188
236
|
xi = X_bin[:, i]
|
|
@@ -202,7 +250,19 @@ def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tup
|
|
|
202
250
|
return (i, row)
|
|
203
251
|
|
|
204
252
|
|
|
205
|
-
def _relative_risk_row_job(
|
|
253
|
+
def _relative_risk_row_job(
|
|
254
|
+
i: int, X_bin: np.ndarray, min_count_for_pairwise: int
|
|
255
|
+
) -> Tuple[int, np.ndarray]:
|
|
256
|
+
"""Compute relative-risk values for one row of a pairwise matrix.
|
|
257
|
+
|
|
258
|
+
Args:
|
|
259
|
+
i: Row index.
|
|
260
|
+
X_bin: Binary matrix.
|
|
261
|
+
min_count_for_pairwise: Minimum count for valid comparison.
|
|
262
|
+
|
|
263
|
+
Returns:
|
|
264
|
+
Tuple of (row_index, row_values).
|
|
265
|
+
"""
|
|
206
266
|
n_pos = X_bin.shape[1]
|
|
207
267
|
row = np.full((n_pos,), np.nan, dtype=float)
|
|
208
268
|
xi = X_bin[:, i]
|
|
@@ -226,8 +286,9 @@ def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: in
|
|
|
226
286
|
row[j] = np.nan
|
|
227
287
|
return (i, row)
|
|
228
288
|
|
|
289
|
+
|
|
229
290
|
def compute_positionwise_statistics(
|
|
230
|
-
adata,
|
|
291
|
+
adata: "ad.AnnData",
|
|
231
292
|
layer: str,
|
|
232
293
|
methods: Sequence[str] = ("pearson",),
|
|
233
294
|
sample_col: str = "Barcode",
|
|
@@ -238,13 +299,21 @@ def compute_positionwise_statistics(
|
|
|
238
299
|
min_count_for_pairwise: int = 10,
|
|
239
300
|
max_threads: Optional[int] = None,
|
|
240
301
|
reverse_indices_on_store: bool = False,
|
|
241
|
-
):
|
|
242
|
-
"""
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
302
|
+
) -> None:
|
|
303
|
+
"""Compute per-(sample, ref) positionwise matrices for selected methods.
|
|
304
|
+
|
|
305
|
+
Args:
|
|
306
|
+
adata: AnnData object to analyze.
|
|
307
|
+
layer: Layer name to use for statistics.
|
|
308
|
+
methods: Methods to compute (e.g., ``"pearson"``).
|
|
309
|
+
sample_col: Obs column containing sample identifiers.
|
|
310
|
+
ref_col: Obs column containing reference identifiers.
|
|
311
|
+
site_types: Optional site types to subset positions.
|
|
312
|
+
encoding: ``"signed"`` or ``"binary"`` encoding.
|
|
313
|
+
output_key: Key prefix for results stored in ``adata.uns``.
|
|
314
|
+
min_count_for_pairwise: Minimum counts for pairwise comparisons.
|
|
315
|
+
max_threads: Maximum number of threads.
|
|
316
|
+
reverse_indices_on_store: Whether to reverse indices on output storage.
|
|
248
317
|
"""
|
|
249
318
|
if isinstance(methods, str):
|
|
250
319
|
methods = [methods]
|
|
@@ -349,7 +418,10 @@ def compute_positionwise_statistics(
|
|
|
349
418
|
Xc = X_bin - col_mean # nan preserved
|
|
350
419
|
Xc0 = np.nan_to_num(Xc, nan=0.0)
|
|
351
420
|
cov = Xc0.T @ Xc0
|
|
352
|
-
denom = (
|
|
421
|
+
denom = (
|
|
422
|
+
np.sqrt((Xc0**2).sum(axis=0))[:, None]
|
|
423
|
+
* np.sqrt((Xc0**2).sum(axis=0))[None, :]
|
|
424
|
+
)
|
|
353
425
|
with np.errstate(divide="ignore", invalid="ignore"):
|
|
354
426
|
mat = np.where(denom != 0.0, cov / denom, np.nan)
|
|
355
427
|
elif m == "binary_covariance":
|
|
@@ -366,8 +438,12 @@ def compute_positionwise_statistics(
|
|
|
366
438
|
else:
|
|
367
439
|
worker = _relative_risk_row_job
|
|
368
440
|
out = np.full((n_pos, n_pos), np.nan, dtype=float)
|
|
369
|
-
tasks = (
|
|
370
|
-
|
|
441
|
+
tasks = (
|
|
442
|
+
delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos)
|
|
443
|
+
)
|
|
444
|
+
pbar_rows = tqdm(
|
|
445
|
+
total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False
|
|
446
|
+
)
|
|
371
447
|
with tqdm_joblib(pbar_rows):
|
|
372
448
|
results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
|
|
373
449
|
pbar_rows.close()
|
|
@@ -406,6 +482,7 @@ def compute_positionwise_statistics(
|
|
|
406
482
|
# Plotting function
|
|
407
483
|
# ---------------------------
|
|
408
484
|
|
|
485
|
+
|
|
409
486
|
def plot_positionwise_matrices(
|
|
410
487
|
adata,
|
|
411
488
|
methods: List[str],
|
|
@@ -427,9 +504,10 @@ def plot_positionwise_matrices(
|
|
|
427
504
|
"""
|
|
428
505
|
Plot grids of matrices for each method with pagination and rotated sample-row labels.
|
|
429
506
|
|
|
430
|
-
|
|
507
|
+
Args:
|
|
431
508
|
- rows_per_page: how many sample rows per page/figure (pagination)
|
|
432
509
|
- sample_label_rotation: rotation angle (deg) for the sample labels placed in the left margin.
|
|
510
|
+
|
|
433
511
|
Returns:
|
|
434
512
|
dict mapping method -> list of saved filenames (empty list if figures were shown).
|
|
435
513
|
"""
|
|
@@ -474,7 +552,12 @@ def plot_positionwise_matrices(
|
|
|
474
552
|
# try stringified tuple keys (some callers store differently)
|
|
475
553
|
for k in store.keys():
|
|
476
554
|
try:
|
|
477
|
-
if
|
|
555
|
+
if (
|
|
556
|
+
isinstance(k, tuple)
|
|
557
|
+
and len(k) == 2
|
|
558
|
+
and str(k[0]) == str(sample)
|
|
559
|
+
and str(k[1]) == str(ref)
|
|
560
|
+
):
|
|
478
561
|
return store[k]
|
|
479
562
|
if isinstance(k, str) and key_s == k:
|
|
480
563
|
return store[k]
|
|
@@ -486,7 +569,10 @@ def plot_positionwise_matrices(
|
|
|
486
569
|
m = method.lower()
|
|
487
570
|
method_store = adata.uns.get(output_key, {}).get(m, {})
|
|
488
571
|
if not method_store:
|
|
489
|
-
warnings.warn(
|
|
572
|
+
warnings.warn(
|
|
573
|
+
f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.",
|
|
574
|
+
stacklevel=2,
|
|
575
|
+
)
|
|
490
576
|
saved_files_by_method[method] = []
|
|
491
577
|
continue
|
|
492
578
|
|
|
@@ -507,21 +593,41 @@ def plot_positionwise_matrices(
|
|
|
507
593
|
|
|
508
594
|
# decide per-method defaults
|
|
509
595
|
if m == "pearson":
|
|
510
|
-
vmn =
|
|
511
|
-
|
|
596
|
+
vmn = (
|
|
597
|
+
-1.0
|
|
598
|
+
if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
|
|
599
|
+
else (vmin.get(m) if isinstance(vmin, dict) else vmin)
|
|
600
|
+
)
|
|
601
|
+
vmx = (
|
|
602
|
+
1.0
|
|
603
|
+
if (vmax is None or (isinstance(vmax, dict) and m not in vmax))
|
|
604
|
+
else (vmax.get(m) if isinstance(vmax, dict) else vmax)
|
|
605
|
+
)
|
|
512
606
|
vmn = -1.0 if vmn is None else vmn
|
|
513
607
|
vmx = 1.0 if vmx is None else vmx
|
|
514
608
|
elif m == "binary_covariance":
|
|
515
|
-
vmn =
|
|
516
|
-
|
|
609
|
+
vmn = (
|
|
610
|
+
0.0
|
|
611
|
+
if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
|
|
612
|
+
else (vmin.get(m) if isinstance(vmin, dict) else vmin)
|
|
613
|
+
)
|
|
614
|
+
vmx = (
|
|
615
|
+
1.0
|
|
616
|
+
if (vmax is None or (isinstance(vmax, dict) and m not in vmax))
|
|
617
|
+
else (vmax.get(m) if isinstance(vmax, dict) else vmax)
|
|
618
|
+
)
|
|
517
619
|
vmn = 0.0 if vmn is None else vmn
|
|
518
620
|
vmx = 1.0 if vmx is None else vmx
|
|
519
621
|
else:
|
|
520
|
-
vmn =
|
|
622
|
+
vmn = (
|
|
623
|
+
0.0
|
|
624
|
+
if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
|
|
625
|
+
else (vmin.get(m) if isinstance(vmin, dict) else vmin)
|
|
626
|
+
)
|
|
521
627
|
if (vmax is None) or (isinstance(vmax, dict) and m not in vmax):
|
|
522
628
|
vmx = float(np.nanpercentile(allvals, 99.0)) if allvals.size else 1.0
|
|
523
629
|
else:
|
|
524
|
-
vmx =
|
|
630
|
+
vmx = vmax.get(m) if isinstance(vmax, dict) else vmax
|
|
525
631
|
vmn = 0.0 if vmn is None else vmn
|
|
526
632
|
if vmx is None:
|
|
527
633
|
vmx = 1.0
|
|
@@ -536,7 +642,9 @@ def plot_positionwise_matrices(
|
|
|
536
642
|
ncols = max(1, len(references))
|
|
537
643
|
fig_w = ncols * figsize_per_cell[0]
|
|
538
644
|
fig_h = nrows * figsize_per_cell[1]
|
|
539
|
-
fig, axes = plt.subplots(
|
|
645
|
+
fig, axes = plt.subplots(
|
|
646
|
+
nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
|
|
647
|
+
)
|
|
540
648
|
|
|
541
649
|
# leave margin for rotated sample labels
|
|
542
650
|
plt.subplots_adjust(left=0.12, right=0.88, top=0.95, bottom=0.05)
|
|
@@ -548,13 +656,24 @@ def plot_positionwise_matrices(
|
|
|
548
656
|
ax = axes[r_idx][c_idx]
|
|
549
657
|
df = _get_df_from_store(method_store, sample, ref)
|
|
550
658
|
if not isinstance(df, pd.DataFrame) or df.size == 0:
|
|
551
|
-
ax.text(
|
|
659
|
+
ax.text(
|
|
660
|
+
0.5,
|
|
661
|
+
0.5,
|
|
662
|
+
"No data",
|
|
663
|
+
ha="center",
|
|
664
|
+
va="center",
|
|
665
|
+
transform=ax.transAxes,
|
|
666
|
+
fontsize=10,
|
|
667
|
+
color="gray",
|
|
668
|
+
)
|
|
552
669
|
ax.set_xticks([])
|
|
553
670
|
ax.set_yticks([])
|
|
554
671
|
else:
|
|
555
672
|
mat = df.values.astype(float)
|
|
556
673
|
origin = "upper" if flip_display_axes else "lower"
|
|
557
|
-
im = ax.imshow(
|
|
674
|
+
im = ax.imshow(
|
|
675
|
+
mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap
|
|
676
|
+
)
|
|
558
677
|
any_plotted = True
|
|
559
678
|
ax.set_xticks([])
|
|
560
679
|
ax.set_yticks([])
|
|
@@ -569,9 +688,21 @@ def plot_positionwise_matrices(
|
|
|
569
688
|
ax_y0, ax_y1 = ax0.get_position().y0, ax0.get_position().y1
|
|
570
689
|
y_center = 0.5 * (ax_y0 + ax_y1)
|
|
571
690
|
# place text at x=0.01 (just inside left margin); rotation controls orientation
|
|
572
|
-
fig.text(
|
|
573
|
-
|
|
574
|
-
|
|
691
|
+
fig.text(
|
|
692
|
+
0.01,
|
|
693
|
+
y_center,
|
|
694
|
+
str(chunk[r_idx]),
|
|
695
|
+
va="center",
|
|
696
|
+
ha="left",
|
|
697
|
+
rotation=sample_label_rotation,
|
|
698
|
+
fontsize=9,
|
|
699
|
+
)
|
|
700
|
+
|
|
701
|
+
fig.suptitle(
|
|
702
|
+
f"{method} — per-sample x per-reference matrices (page {page_idx + 1}/{n_pages})",
|
|
703
|
+
fontsize=12,
|
|
704
|
+
y=0.99,
|
|
705
|
+
)
|
|
575
706
|
fig.tight_layout(rect=[0.05, 0.02, 0.9, 0.96])
|
|
576
707
|
|
|
577
708
|
# colorbar (shared)
|
|
@@ -587,7 +718,7 @@ def plot_positionwise_matrices(
|
|
|
587
718
|
|
|
588
719
|
# save or show
|
|
589
720
|
if output_dir:
|
|
590
|
-
fname = f"positionwise_{method}_page{page_idx+1}.png"
|
|
721
|
+
fname = f"positionwise_{method}_page{page_idx + 1}.png"
|
|
591
722
|
outpath = os.path.join(output_dir, fname)
|
|
592
723
|
plt.savefig(outpath, bbox_inches="tight")
|
|
593
724
|
saved_files.append(outpath)
|
smftools/tools/read_stats.py
CHANGED
|
@@ -1,36 +1,53 @@
|
|
|
1
1
|
# ------------------------- Utilities -------------------------
|
|
2
|
-
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING, Sequence
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
import anndata as ad
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def random_fill_nans(X: "np.ndarray") -> "np.ndarray":
|
|
12
|
+
"""Fill NaNs with random values in-place.
|
|
13
|
+
|
|
14
|
+
Args:
|
|
15
|
+
X: Input array with NaNs.
|
|
16
|
+
|
|
17
|
+
Returns:
|
|
18
|
+
numpy.ndarray: Array with NaNs replaced by random values.
|
|
19
|
+
"""
|
|
3
20
|
import numpy as np
|
|
21
|
+
|
|
4
22
|
nan_mask = np.isnan(X)
|
|
5
23
|
X[nan_mask] = np.random.rand(*X[nan_mask].shape)
|
|
6
24
|
return X
|
|
7
25
|
|
|
26
|
+
|
|
8
27
|
def calculate_row_entropy(
|
|
9
|
-
adata,
|
|
10
|
-
layer,
|
|
11
|
-
output_key="entropy",
|
|
12
|
-
site_config=None,
|
|
13
|
-
ref_col="Reference_strand",
|
|
14
|
-
encoding="signed",
|
|
15
|
-
max_threads=None
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
|
|
28
|
-
max_threads (int): Number of threads for parallel processing.
|
|
28
|
+
adata: "ad.AnnData",
|
|
29
|
+
layer: str,
|
|
30
|
+
output_key: str = "entropy",
|
|
31
|
+
site_config: dict[str, Sequence[str]] | None = None,
|
|
32
|
+
ref_col: str = "Reference_strand",
|
|
33
|
+
encoding: str = "signed",
|
|
34
|
+
max_threads: int | None = None,
|
|
35
|
+
) -> None:
|
|
36
|
+
"""Add per-read entropy values to ``adata.obs``.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
adata: Annotated data matrix.
|
|
40
|
+
layer: Layer name to use for entropy calculation.
|
|
41
|
+
output_key: Base name for the entropy column in ``adata.obs``.
|
|
42
|
+
site_config: Mapping of reference to site types for masking.
|
|
43
|
+
ref_col: Obs column containing reference strands.
|
|
44
|
+
encoding: ``"signed"`` (1/-1/0) or ``"binary"`` (1/0/NaN).
|
|
45
|
+
max_threads: Number of threads for parallel processing.
|
|
29
46
|
"""
|
|
30
47
|
import numpy as np
|
|
31
48
|
import pandas as pd
|
|
32
|
-
from scipy.stats import entropy
|
|
33
49
|
from joblib import Parallel, delayed
|
|
50
|
+
from scipy.stats import entropy
|
|
34
51
|
from tqdm import tqdm
|
|
35
52
|
|
|
36
53
|
entropy_values = []
|
|
@@ -55,12 +72,14 @@ def calculate_row_entropy(
|
|
|
55
72
|
X_bin = np.where(X == 1, 1, np.where(X == 0, 0, np.nan))
|
|
56
73
|
|
|
57
74
|
def compute_entropy(row):
|
|
75
|
+
"""Compute Shannon entropy for a row with NaNs ignored."""
|
|
58
76
|
counts = pd.Series(row).value_counts(dropna=True).sort_index()
|
|
59
77
|
probs = counts / counts.sum()
|
|
60
78
|
return entropy(probs, base=2)
|
|
61
79
|
|
|
62
80
|
entropies = Parallel(n_jobs=max_threads)(
|
|
63
|
-
delayed(compute_entropy)(X_bin[i, :])
|
|
81
|
+
delayed(compute_entropy)(X_bin[i, :])
|
|
82
|
+
for i in tqdm(range(X_bin.shape[0]), desc=f"Entropy: {ref}")
|
|
64
83
|
)
|
|
65
84
|
|
|
66
85
|
entropy_values.extend(entropies)
|
|
@@ -69,6 +88,7 @@ def calculate_row_entropy(
|
|
|
69
88
|
entropy_key = f"{output_key}_entropy"
|
|
70
89
|
adata.obs.loc[row_indices, entropy_key] = entropy_values
|
|
71
90
|
|
|
91
|
+
|
|
72
92
|
def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sorted=True):
|
|
73
93
|
"""
|
|
74
94
|
Fast autocorrelation over real genomic spacing.
|
|
@@ -125,13 +145,13 @@ def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sor
|
|
|
125
145
|
j += 1
|
|
126
146
|
# consider pairs (i, i+1...j-1)
|
|
127
147
|
if j - i > 1:
|
|
128
|
-
diffs = pos[i+1:j] - pos[i]
|
|
129
|
-
contrib = xc[i] * xc[i+1:j]
|
|
148
|
+
diffs = pos[i + 1 : j] - pos[i] # 1..max_lag
|
|
149
|
+
contrib = xc[i] * xc[i + 1 : j] # contributions for each pair
|
|
130
150
|
# accumulate weighted sums and counts per lag
|
|
131
|
-
lag_sums[:max_lag+1] += np.bincount(diffs, weights=contrib,
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
151
|
+
lag_sums[: max_lag + 1] += np.bincount(diffs, weights=contrib, minlength=max_lag + 1)[
|
|
152
|
+
: max_lag + 1
|
|
153
|
+
]
|
|
154
|
+
lag_counts[: max_lag + 1] += np.bincount(diffs, minlength=max_lag + 1)[: max_lag + 1]
|
|
135
155
|
|
|
136
156
|
autocorr = np.full(max_lag + 1, np.nan, dtype=np.float64)
|
|
137
157
|
nz = lag_counts > 0
|
|
@@ -140,6 +160,7 @@ def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sor
|
|
|
140
160
|
|
|
141
161
|
return autocorr.astype(np.float32, copy=False)
|
|
142
162
|
|
|
163
|
+
|
|
143
164
|
# def binary_autocorrelation_with_spacing(row, positions, max_lag=1000):
|
|
144
165
|
# """
|
|
145
166
|
# Compute autocorrelation within a read using real genomic spacing from `positions`.
|