smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
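
Two changes recur throughout the per-file diffs below: heavy optional dependencies are now resolved through the new smftools/optional_imports.py helper, require(), and most public functions gained type hints and Google-style docstrings. The require() implementation itself is not part of this excerpt; the following is a minimal sketch consistent only with the call sites that do appear (for example require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")), assuming it returns the imported module and names a pip extra on failure. The real helper may differ in detail.

import importlib

def require(module: str, extra: str | None = None, purpose: str = "this feature"):
    # Hypothetical stand-in for smftools.optional_imports.require:
    # import the module, or raise an error that names the pip extra to install.
    try:
        return importlib.import_module(module)
    except ImportError as exc:
        hint = f" Try: pip install 'smftools[{extra}]'" if extra else ""
        raise ImportError(f"{module} is required for {purpose}.{hint}") from exc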
smftools/tools/position_stats.py
CHANGED
@@ -1,23 +1,61 @@
+from __future__ import annotations
+
+import os
+import warnings
+from contextlib import contextmanager
+from itertools import cycle
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
+
+import numpy as np
+import pandas as pd
+from scipy.stats import chi2_contingency
+from tqdm import tqdm
+
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
+
+# -----------------------------
+# Compute positionwise statistic (multi-method + simple site_types)
+# -----------------------------
+
+
 # ------------------------- Utilities -------------------------
-def random_fill_nans(X):
+def random_fill_nans(X: np.ndarray) -> np.ndarray:
+    """Fill NaNs with random values in-place.
+
+    Args:
+        X: Input array with NaNs.
+
+    Returns:
+        numpy.ndarray: Array with NaNs replaced by random values.
+    """
     import numpy as np
+
     nan_mask = np.isnan(X)
     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
     return X
 
-def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
-    """
-    Perform Bayesian-style methylation vs activity analysis independently within each group.
 
-    …
+def calculate_relative_risk_on_activity(
+    adata: "ad.AnnData",
+    sites: Sequence[str],
+    alpha: float = 0.05,
+    groupby: str | Sequence[str] | None = None,
+) -> dict:
+    """Perform methylation vs. activity analysis within each group.
+
+    Args:
+        adata: Annotated data matrix.
+        sites: Site keys (e.g., ``["GpC_site", "CpG_site"]``).
+        alpha: FDR threshold for significance.
+        groupby: Obs column(s) to group by.
 
     Returns:
-        …
-        results_dict[ref][group_label] = (results_df, sig_df)
+        dict: Mapping of reference -> group label -> ``(results_df, sig_df)``.
     """
     import numpy as np
     import pandas as pd
@@ -25,30 +63,44 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
     from statsmodels.stats.multitest import multipletests
 
     def compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values):
-        …
+        """Build result and significant-data DataFrames for a reference.
+
+        Args:
+            ref: Reference name.
+            site_subset: AnnData subset restricted to sites.
+            positions_list: Positions tested.
+            relative_risks: Relative risk values.
+            p_values: Raw p-values.
+
+        Returns:
+            Tuple of (results_df, sig_df).
+        """
+        p_adj = multipletests(p_values, method="fdr_bh")[1] if p_values else []
 
         genomic_positions = np.array(site_subset.var_names)[positions_list]
        is_gpc_site = site_subset.var[f"{ref}_GpC_site"].values[positions_list]
        is_cpg_site = site_subset.var[f"{ref}_CpG_site"].values[positions_list]
 
-        results_df = pd.DataFrame(…
-        …
+        results_df = pd.DataFrame(
+            {
+                "Feature_Index": positions_list,
+                "Genomic_Position": genomic_positions.astype(int),
+                "Relative_Risk": relative_risks,
+                "Adjusted_P_Value": p_adj,
+                "GpC_Site": is_gpc_site,
+                "CpG_Site": is_cpg_site,
+            }
+        )
+
+        results_df["log2_Relative_Risk"] = np.log2(results_df["Relative_Risk"].replace(0, 1e-300))
+        results_df["-log10_Adj_P"] = -np.log10(results_df["Adjusted_P_Value"].replace(0, 1e-300))
+        sig_df = results_df[results_df["Adjusted_P_Value"] < alpha]
         return results_df, sig_df
 
     results_dict = {}
 
-    for ref in adata.obs['Reference_strand'].unique():
-        ref_subset = adata[adata.obs['Reference_strand'] == ref].copy()
+    for ref in adata.obs["Reference_strand"].unique():
+        ref_subset = adata[adata.obs["Reference_strand"] == ref].copy()
         if ref_subset.shape[0] == 0:
             continue
 
@@ -56,20 +108,22 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
         if groupby is not None:
             if isinstance(groupby, str):
                 groupby = [groupby]
+
             def format_group_label(row):
+                """Format a group label string from obs row values."""
                 return ",".join([f"{col}={row[col]}" for col in groupby])
 
-            combined_label = '__'.join(groupby)
+            combined_label = "__".join(groupby)
             ref_subset.obs[combined_label] = ref_subset.obs.apply(format_group_label, axis=1)
             groups = ref_subset.obs[combined_label].unique()
         else:
             combined_label = None
-            groups = ['all']
+            groups = ["all"]
 
         results_dict[ref] = {}
 
         for group in groups:
-            if group == 'all':
+            if group == "all":
                 group_subset = ref_subset
             else:
                 group_subset = ref_subset[ref_subset.obs[combined_label] == group]
@@ -85,7 +139,7 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
 
             # Matrix and labels
             X = random_fill_nans(site_subset.X.copy())
-            y = site_subset.obs['activity_status'].map({'Active': 1, 'Silent': 0}).values
+            y = site_subset.obs["activity_status"].map({"Active": 1, "Silent": 0}).values
             P_active = np.mean(y)
 
             # Analysis
@@ -104,7 +158,9 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
                     continue
 
                 P_active_given_methylated = (P_methylated_given_active * P_active) / P_methylated
-                P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (1 - P_methylated)
+                P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (
+                    1 - P_methylated
+                )
                 RR = P_active_given_methylated / P_active_given_unmethylated
 
                 _, p_value = fisher_exact(table)
@@ -112,49 +168,13 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
                 relative_risks.append(RR)
                 p_values.append(p_value)
 
-            results_df, sig_df = compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values)
+            results_df, sig_df = compute_risk_df(
+                ref, site_subset, positions_list, relative_risks, p_values
+            )
             results_dict[ref][group] = (results_df, sig_df)
 
     return results_dict
 
-import copy
-import warnings
-from typing import Dict, Any, List, Optional, Tuple, Union
-
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
-
-# optional imports
-try:
-    from joblib import Parallel, delayed
-    JOBLIB_AVAILABLE = True
-except Exception:
-    JOBLIB_AVAILABLE = False
-
-try:
-    from scipy.stats import chi2_contingency
-    SCIPY_STATS_AVAILABLE = True
-except Exception:
-    SCIPY_STATS_AVAILABLE = False
-
-# -----------------------------
-# Compute positionwise statistic (multi-method + simple site_types)
-# -----------------------------
-import numpy as np
-import pandas as pd
-from typing import List, Optional, Sequence, Dict, Any, Tuple
-from contextlib import contextmanager
-from joblib import Parallel, delayed, cpu_count
-import joblib
-from tqdm import tqdm
-from scipy.stats import chi2_contingency
-import warnings
-import matplotlib.pyplot as plt
-from itertools import cycle
-import os
-import warnings
-
 
 # ---------------------------
 # joblib <-> tqdm integration
@@ -162,10 +182,15 @@ import warnings
 @contextmanager
 def tqdm_joblib(tqdm_object: tqdm):
     """Context manager to patch joblib to update a tqdm progress bar."""
+    joblib = require("joblib", extra="ml-base", purpose="parallel position statistics")
+
     old = joblib.parallel.BatchCompletionCallBack
 
     class TqdmBatchCompletionCallback(old):  # type: ignore
+        """Joblib callback that updates a tqdm progress bar."""
+
         def __call__(self, *args, **kwargs):
+            """Update the progress bar when a batch completes."""
             try:
                 tqdm_object.update(n=self.batch_size)
             except Exception:
@@ -183,6 +208,16 @@ def tqdm_joblib(tqdm_object: tqdm):
 # row workers (upper-triangle only)
 # ---------------------------
 def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+    """Compute chi-squared statistics for one row of a pairwise matrix.
+
+    Args:
+        i: Row index.
+        X_bin: Binary matrix.
+        min_count_for_pairwise: Minimum count for valid comparison.
+
+    Returns:
+        Tuple of (row_index, row_values).
+    """
     n_pos = X_bin.shape[1]
     row = np.full((n_pos,), np.nan, dtype=float)
     xi = X_bin[:, i]
@@ -202,7 +237,19 @@ def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
     return (i, row)
 
 
-def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+def _relative_risk_row_job(
+    i: int, X_bin: np.ndarray, min_count_for_pairwise: int
+) -> Tuple[int, np.ndarray]:
+    """Compute relative-risk values for one row of a pairwise matrix.
+
+    Args:
+        i: Row index.
+        X_bin: Binary matrix.
+        min_count_for_pairwise: Minimum count for valid comparison.
+
+    Returns:
+        Tuple of (row_index, row_values).
+    """
     n_pos = X_bin.shape[1]
     row = np.full((n_pos,), np.nan, dtype=float)
     xi = X_bin[:, i]
@@ -226,8 +273,9 @@ def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
             row[j] = np.nan
     return (i, row)
 
+
 def compute_positionwise_statistics(
-    adata,
+    adata: "ad.AnnData",
     layer: str,
     methods: Sequence[str] = ("pearson",),
     sample_col: str = "Barcode",
@@ -238,14 +286,24 @@ def compute_positionwise_statistics(
     min_count_for_pairwise: int = 10,
     max_threads: Optional[int] = None,
     reverse_indices_on_store: bool = False,
-):
+) -> None:
+    """Compute per-(sample, ref) positionwise matrices for selected methods.
+
+    Args:
+        adata: AnnData object to analyze.
+        layer: Layer name to use for statistics.
+        methods: Methods to compute (e.g., ``"pearson"``).
+        sample_col: Obs column containing sample identifiers.
+        ref_col: Obs column containing reference identifiers.
+        site_types: Optional site types to subset positions.
+        encoding: ``"signed"`` or ``"binary"`` encoding.
+        output_key: Key prefix for results stored in ``adata.uns``.
+        min_count_for_pairwise: Minimum counts for pairwise comparisons.
+        max_threads: Maximum number of threads.
+        reverse_indices_on_store: Whether to reverse indices on output storage.
     """
-
+    joblib = require("joblib", extra="ml-base", purpose="parallel position statistics")
 
-    Results stored at:
-        adata.uns[output_key][method][ (sample, ref) ] = DataFrame
-        adata.uns[output_key + "_n"][method][ (sample, ref) ] = int(n_reads)
-    """
     if isinstance(methods, str):
         methods = [methods]
     methods = [m.lower() for m in methods]
@@ -256,7 +314,7 @@ def compute_positionwise_statistics(
 
     # workers
     if max_threads is None or max_threads <= 0:
-        n_jobs = max(1, cpu_count() or 1)
+        n_jobs = max(1, joblib.cpu_count() or 1)
     else:
         n_jobs = max(1, int(max_threads))
 
@@ -349,7 +407,10 @@ def compute_positionwise_statistics(
             Xc = X_bin - col_mean  # nan preserved
             Xc0 = np.nan_to_num(Xc, nan=0.0)
             cov = Xc0.T @ Xc0
-            denom = (np.sqrt((Xc0**2).sum(axis=0))[:, None] * np.sqrt((Xc0**2).sum(axis=0))[None, :])
+            denom = (
+                np.sqrt((Xc0**2).sum(axis=0))[:, None]
+                * np.sqrt((Xc0**2).sum(axis=0))[None, :]
+            )
             with np.errstate(divide="ignore", invalid="ignore"):
                 mat = np.where(denom != 0.0, cov / denom, np.nan)
         elif m == "binary_covariance":
@@ -366,10 +427,15 @@ def compute_positionwise_statistics(
             else:
                 worker = _relative_risk_row_job
             out = np.full((n_pos, n_pos), np.nan, dtype=float)
-            tasks = (delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos))
-            …
+            tasks = (
+                joblib.delayed(worker)(i, X_bin, min_count_for_pairwise)
+                for i in range(n_pos)
+            )
+            pbar_rows = tqdm(
+                total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False
+            )
             with tqdm_joblib(pbar_rows):
-                results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
+                results = joblib.Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
             pbar_rows.close()
             for i, row in results:
                 out[int(i), :] = row
@@ -406,6 +472,7 @@ def compute_positionwise_statistics(
 # Plotting function
 # ---------------------------
 
+
 def plot_positionwise_matrices(
     adata,
     methods: List[str],
@@ -427,9 +494,10 @@ def plot_positionwise_matrices(
     """
     Plot grids of matrices for each method with pagination and rotated sample-row labels.
 
-    …
+    Args:
     - rows_per_page: how many sample rows per page/figure (pagination)
     - sample_label_rotation: rotation angle (deg) for the sample labels placed in the left margin.
+
     Returns:
         dict mapping method -> list of saved filenames (empty list if figures were shown).
     """
@@ -474,7 +542,12 @@ def plot_positionwise_matrices(
         # try stringified tuple keys (some callers store differently)
         for k in store.keys():
             try:
-                if isinstance(k, tuple) and len(k) == 2 and str(k[0]) == str(sample) and str(k[1]) == str(ref):
+                if (
+                    isinstance(k, tuple)
+                    and len(k) == 2
+                    and str(k[0]) == str(sample)
+                    and str(k[1]) == str(ref)
+                ):
                     return store[k]
                 if isinstance(k, str) and key_s == k:
                     return store[k]
@@ -486,7 +559,10 @@ def plot_positionwise_matrices(
         m = method.lower()
         method_store = adata.uns.get(output_key, {}).get(m, {})
         if not method_store:
-            warnings.warn(f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.")
+            warnings.warn(
+                f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.",
+                stacklevel=2,
+            )
             saved_files_by_method[method] = []
             continue
 
@@ -507,21 +583,41 @@ def plot_positionwise_matrices(
 
         # decide per-method defaults
         if m == "pearson":
-            vmn = -1.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
-            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = (
+                -1.0
+                if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
+                else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            )
+            vmx = (
+                1.0
+                if (vmax is None or (isinstance(vmax, dict) and m not in vmax))
+                else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            )
             vmn = -1.0 if vmn is None else vmn
             vmx = 1.0 if vmx is None else vmx
         elif m == "binary_covariance":
-            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
-            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = (
+                0.0
+                if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
+                else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            )
+            vmx = (
+                1.0
+                if (vmax is None or (isinstance(vmax, dict) and m not in vmax))
+                else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            )
             vmn = 0.0 if vmn is None else vmn
             vmx = 1.0 if vmx is None else vmx
         else:
-            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            vmn = (
+                0.0
+                if (vmin is None or (isinstance(vmin, dict) and m not in vmin))
+                else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            )
             if (vmax is None) or (isinstance(vmax, dict) and m not in vmax):
                 vmx = float(np.nanpercentile(allvals, 99.0)) if allvals.size else 1.0
             else:
-                vmx = …
+                vmx = vmax.get(m) if isinstance(vmax, dict) else vmax
             vmn = 0.0 if vmn is None else vmn
             if vmx is None:
                 vmx = 1.0
@@ -536,7 +632,9 @@ def plot_positionwise_matrices(
             ncols = max(1, len(references))
             fig_w = ncols * figsize_per_cell[0]
             fig_h = nrows * figsize_per_cell[1]
-            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
+            fig, axes = plt.subplots(
+                nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+            )
 
             # leave margin for rotated sample labels
             plt.subplots_adjust(left=0.12, right=0.88, top=0.95, bottom=0.05)
@@ -548,13 +646,24 @@ def plot_positionwise_matrices(
                     ax = axes[r_idx][c_idx]
                     df = _get_df_from_store(method_store, sample, ref)
                     if not isinstance(df, pd.DataFrame) or df.size == 0:
-                        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
+                        ax.text(
+                            0.5,
+                            0.5,
+                            "No data",
+                            ha="center",
+                            va="center",
+                            transform=ax.transAxes,
+                            fontsize=10,
+                            color="gray",
+                        )
                         ax.set_xticks([])
                         ax.set_yticks([])
                     else:
                         mat = df.values.astype(float)
                         origin = "upper" if flip_display_axes else "lower"
-                        im = ax.imshow(mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap)
+                        im = ax.imshow(
+                            mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap
+                        )
                         any_plotted = True
                         ax.set_xticks([])
                         ax.set_yticks([])
@@ -569,9 +678,21 @@ def plot_positionwise_matrices(
                 ax_y0, ax_y1 = ax0.get_position().y0, ax0.get_position().y1
                 y_center = 0.5 * (ax_y0 + ax_y1)
                 # place text at x=0.01 (just inside left margin); rotation controls orientation
-                fig.text(0.01, y_center, str(chunk[r_idx]), va="center", ha="left", rotation=sample_label_rotation, fontsize=9)
-
-            fig.suptitle(f"{method} — per-sample x per-reference matrices (page {page_idx+1}/{n_pages})", fontsize=12, y=0.99)
+                fig.text(
+                    0.01,
+                    y_center,
+                    str(chunk[r_idx]),
+                    va="center",
+                    ha="left",
+                    rotation=sample_label_rotation,
+                    fontsize=9,
+                )
+
+            fig.suptitle(
+                f"{method} — per-sample x per-reference matrices (page {page_idx + 1}/{n_pages})",
+                fontsize=12,
+                y=0.99,
+            )
             fig.tight_layout(rect=[0.05, 0.02, 0.9, 0.96])
 
             # colorbar (shared)
@@ -587,7 +708,7 @@ def plot_positionwise_matrices(
 
             # save or show
             if output_dir:
-                fname = f"positionwise_{method}_page{page_idx+1}.png"
+                fname = f"positionwise_{method}_page{page_idx + 1}.png"
                 outpath = os.path.join(output_dir, fname)
                 plt.savefig(outpath, bbox_inches="tight")
                 saved_files.append(outpath)
smftools/tools/read_stats.py
CHANGED
@@ -1,36 +1,53 @@
 # ------------------------- Utilities -------------------------
-def random_fill_nans(X):
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+if TYPE_CHECKING:
+    import anndata as ad
+    import numpy as np
+
+
+def random_fill_nans(X: "np.ndarray") -> "np.ndarray":
+    """Fill NaNs with random values in-place.
+
+    Args:
+        X: Input array with NaNs.
+
+    Returns:
+        numpy.ndarray: Array with NaNs replaced by random values.
+    """
     import numpy as np
+
     nan_mask = np.isnan(X)
     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
     return X
 
+
 def calculate_row_entropy(
-    adata,
-    layer,
-    output_key="entropy",
-    site_config=None,
-    ref_col="Reference_strand",
-    encoding="signed",
-    max_threads=None
-    …
-        encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
-        max_threads (int): Number of threads for parallel processing.
+    adata: "ad.AnnData",
+    layer: str,
+    output_key: str = "entropy",
+    site_config: dict[str, Sequence[str]] | None = None,
+    ref_col: str = "Reference_strand",
+    encoding: str = "signed",
+    max_threads: int | None = None,
+) -> None:
+    """Add per-read entropy values to ``adata.obs``.
+
+    Args:
+        adata: Annotated data matrix.
+        layer: Layer name to use for entropy calculation.
+        output_key: Base name for the entropy column in ``adata.obs``.
+        site_config: Mapping of reference to site types for masking.
+        ref_col: Obs column containing reference strands.
+        encoding: ``"signed"`` (1/-1/0) or ``"binary"`` (1/0/NaN).
+        max_threads: Number of threads for parallel processing.
     """
     import numpy as np
     import pandas as pd
-    from scipy.stats import entropy
     from joblib import Parallel, delayed
+    from scipy.stats import entropy
     from tqdm import tqdm
 
     entropy_values = []
@@ -55,12 +72,14 @@ def calculate_row_entropy(
         X_bin = np.where(X == 1, 1, np.where(X == 0, 0, np.nan))
 
         def compute_entropy(row):
+            """Compute Shannon entropy for a row with NaNs ignored."""
             counts = pd.Series(row).value_counts(dropna=True).sort_index()
             probs = counts / counts.sum()
             return entropy(probs, base=2)
 
         entropies = Parallel(n_jobs=max_threads)(
-            delayed(compute_entropy)(X_bin[i, :]) for i in tqdm(range(X_bin.shape[0]), desc=f"Entropy: {ref}")
+            delayed(compute_entropy)(X_bin[i, :])
+            for i in tqdm(range(X_bin.shape[0]), desc=f"Entropy: {ref}")
         )
 
         entropy_values.extend(entropies)
@@ -69,6 +88,7 @@ def calculate_row_entropy(
     entropy_key = f"{output_key}_entropy"
     adata.obs.loc[row_indices, entropy_key] = entropy_values
 
+
 def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sorted=True):
     """
     Fast autocorrelation over real genomic spacing.
@@ -125,13 +145,13 @@ def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sorted=True):
             j += 1
         # consider pairs (i, i+1...j-1)
         if j - i > 1:
-            diffs = pos[i+1:j] - pos[i]
-            contrib = xc[i] * xc[i+1:j]
+            diffs = pos[i + 1 : j] - pos[i]  # 1..max_lag
+            contrib = xc[i] * xc[i + 1 : j]  # contributions for each pair
             # accumulate weighted sums and counts per lag
-            lag_sums[:max_lag+1] += np.bincount(diffs, weights=contrib,
-                …
+            lag_sums[: max_lag + 1] += np.bincount(diffs, weights=contrib, minlength=max_lag + 1)[
+                : max_lag + 1
+            ]
+            lag_counts[: max_lag + 1] += np.bincount(diffs, minlength=max_lag + 1)[: max_lag + 1]
 
     autocorr = np.full(max_lag + 1, np.nan, dtype=np.float64)
     nz = lag_counts > 0
@@ -140,6 +160,7 @@ def binary_autocorrelation_with_spacing(row, positions, max_lag=1000, assume_sorted=True):
 
     return autocorr.astype(np.float32, copy=False)
 
+
 # def binary_autocorrelation_with_spacing(row, positions, max_lag=1000):
 # """
 #     Compute autocorrelation within a read using real genomic spacing from `positions`.
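
The reflowed bincount lines above are the core of binary_autocorrelation_with_spacing: for each anchor site, the products of centered calls against every partner within max_lag are binned by genomic spacing in one vectorized np.bincount call rather than a Python loop over lags. A self-contained sketch of that accumulation on toy data; the variable names are local to this example, and the real function additionally offers the assume_sorted fast path and NaN handling not shown here.

import numpy as np

# Toy read: binary methylation calls at irregular genomic positions.
pos = np.array([10, 12, 15, 25, 26])
x = np.array([1.0, 0.0, 1.0, 1.0, 0.0])
xc = x - x.mean()  # center so pairwise products act like covariance terms

max_lag = 20
lag_sums = np.zeros(max_lag + 1)
lag_counts = np.zeros(max_lag + 1)

for i in range(len(pos) - 1):
    # advance j to the first partner more than max_lag away
    j = i + 1
    while j < len(pos) and pos[j] - pos[i] <= max_lag:
        j += 1
    if j - i > 1:
        diffs = pos[i + 1 : j] - pos[i]  # genomic lags in 1..max_lag
        contrib = xc[i] * xc[i + 1 : j]  # one product per pair
        # bin every pair for this anchor by its lag in a single call
        lag_sums += np.bincount(diffs, weights=contrib, minlength=max_lag + 1)[: max_lag + 1]
        lag_counts += np.bincount(diffs, minlength=max_lag + 1)[: max_lag + 1]

# per-lag mean product; NaN where no pair had that spacing
with np.errstate(divide="ignore", invalid="ignore"):
    autocorr = np.where(lag_counts > 0, lag_sums / lag_counts, np.nan)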
|