smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +34 -0
- smftools/_settings.py +20 -0
- smftools/_version.py +1 -0
- smftools/cli.py +184 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +33 -0
- smftools/config/deaminase.yaml +56 -0
- smftools/config/default.yaml +253 -0
- smftools/config/direct.yaml +17 -0
- smftools/config/experiment_config.py +1191 -0
- smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
- smftools/datasets/F1_sample_sheet.csv +5 -0
- smftools/datasets/__init__.py +9 -0
- smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
- smftools/datasets/datasets.py +28 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/hmm/apply_hmm_batched.py +242 -0
- smftools/hmm/calculate_distances.py +18 -0
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/hmm/display_hmm.py +18 -0
- smftools/hmm/hmm_readwrite.py +16 -0
- smftools/hmm/nucleosome_hmm_refinement.py +104 -0
- smftools/hmm/train_hmm.py +78 -0
- smftools/informatics/__init__.py +14 -0
- smftools/informatics/archived/bam_conversion.py +59 -0
- smftools/informatics/archived/bam_direct.py +63 -0
- smftools/informatics/archived/basecalls_to_adata.py +71 -0
- smftools/informatics/archived/conversion_smf.py +132 -0
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/direct_smf.py +137 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/fast5_to_pod5.py +24 -0
- smftools/informatics/helpers/__init__.py +73 -0
- smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
- smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
- smftools/informatics/helpers/archived/informatics.py +260 -0
- smftools/informatics/helpers/archived/load_adata.py +516 -0
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/bed_to_bigwig.py +39 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
- smftools/informatics/helpers/canoncall.py +34 -0
- smftools/informatics/helpers/complement_base_list.py +21 -0
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
- smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
- smftools/informatics/helpers/count_aligned_reads.py +43 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/discover_input_files.py +100 -0
- smftools/informatics/helpers/extract_base_identities.py +70 -0
- smftools/informatics/helpers/extract_mods.py +83 -0
- smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
- smftools/informatics/helpers/find_conversion_sites.py +51 -0
- smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
- smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
- smftools/informatics/helpers/get_native_references.py +28 -0
- smftools/informatics/helpers/index_fasta.py +12 -0
- smftools/informatics/helpers/make_dirs.py +21 -0
- smftools/informatics/helpers/make_modbed.py +27 -0
- smftools/informatics/helpers/modQC.py +27 -0
- smftools/informatics/helpers/modcall.py +36 -0
- smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
- smftools/informatics/helpers/ohe_batching.py +76 -0
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +57 -0
- smftools/informatics/helpers/plot_bed_histograms.py +269 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
- smftools/informatics/helpers/split_and_index_BAM.py +32 -0
- smftools/informatics/readwrite.py +106 -0
- smftools/informatics/subsample_fasta_from_bed.py +47 -0
- smftools/informatics/subsample_pod5.py +104 -0
- smftools/load_adata.py +1346 -0
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/data/preprocessing.py +6 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/__init__.py +9 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/machine_learning/models/positional.py +18 -0
- smftools/machine_learning/models/rnn.py +17 -0
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/models/wrappers.py +20 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +10 -0
- smftools/machine_learning/utils/grl.py +14 -0
- smftools/plotting/__init__.py +18 -0
- smftools/plotting/autocorrelation_plotting.py +611 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +682 -0
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +38 -0
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/archives/mark_duplicates.py +146 -0
- smftools/preprocessing/archives/preprocessing.py +614 -0
- smftools/preprocessing/archives/remove_duplicates.py +21 -0
- smftools/preprocessing/binarize_on_Youden.py +45 -0
- smftools/preprocessing/binary_layers_to_ohe.py +40 -0
- smftools/preprocessing/calculate_complexity.py +72 -0
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_consensus.py +47 -0
- smftools/preprocessing/calculate_coverage.py +51 -0
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
- smftools/preprocessing/calculate_position_Youden.py +115 -0
- smftools/preprocessing/calculate_read_length_stats.py +79 -0
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +62 -0
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1351 -0
- smftools/preprocessing/invert_adata.py +37 -0
- smftools/preprocessing/load_sample_sheet.py +53 -0
- smftools/preprocessing/make_dirs.py +21 -0
- smftools/preprocessing/min_non_diagonal.py +25 -0
- smftools/preprocessing/recipes.py +127 -0
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +1004 -0
- smftools/tools/__init__.py +20 -0
- smftools/tools/archived/apply_hmm.py +202 -0
- smftools/tools/archived/classifiers.py +787 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/position_stats.py +601 -0
- smftools/tools/read_stats.py +184 -0
- smftools/tools/spatial_autocorrelation.py +562 -0
- smftools/tools/subset_adata.py +28 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
- smftools-0.2.1.dist-info/RECORD +161 -0
- smftools-0.2.1.dist-info/entry_points.txt +2 -0
- smftools-0.1.6.dist-info/RECORD +0 -4
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,601 @@
|
|
|
1
|
+
# ------------------------- Utilities -------------------------
|
|
2
|
+
def random_fill_nans(X):
    """Replace every NaN entry of X in place with a uniform random draw from [0, 1).

    Note: X is mutated; the same array is also returned for convenience.
    """
    import numpy as np
    missing = np.isnan(X)
    # X[missing] is a flat view of the NaN slots, so one rand() call of the
    # matching length fills them all.
    X[missing] = np.random.rand(int(missing.sum()))
    return X
|
|
7
|
+
|
|
8
|
+
def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
    """
    Perform Bayesian-style methylation vs activity analysis independently within each group.

    For every reference strand (``adata.obs['Reference_strand']``) and every group
    (optionally defined by ``groupby``), each selected position is tested for
    association between its binarized methylation state and the binary activity
    label in ``adata.obs['activity_status']`` ('Active'/'Silent'). A relative risk
    of activity given methylation is derived via Bayes' rule, and a Fisher exact
    test p-value is computed per position, then BH-adjusted across positions.

    Requirements (not validated here):
        - obs columns 'Reference_strand' and 'activity_status' must exist.
        - var columns f"{ref}_{site}" for every entry of `sites` must exist, and
          f"{ref}_GpC_site" / f"{ref}_CpG_site" must exist regardless of `sites`
          (they are always reported in the output frame).
        - statsmodels must be installed (imported lazily below).

    Parameters:
        adata (AnnData): Annotated data matrix.
        sites (list of str): List of site keys (e.g., ['GpC_site', 'CpG_site']).
        alpha (float): FDR threshold for significance.
        groupby (str or list of str): Column(s) in adata.obs to group by.

    Returns:
        results_dict (dict): Dictionary with structure:
            results_dict[ref][group_label] = (results_df, sig_df)
    """
    import numpy as np
    import pandas as pd
    from scipy.stats import fisher_exact
    from statsmodels.stats.multitest import multipletests

    def compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values):
        # Assemble the per-position results frame; BH correction is applied only
        # when at least one position survived the filters (else p_adj is empty).
        p_adj = multipletests(p_values, method='fdr_bh')[1] if p_values else []

        genomic_positions = np.array(site_subset.var_names)[positions_list]
        is_gpc_site = site_subset.var[f"{ref}_GpC_site"].values[positions_list]
        is_cpg_site = site_subset.var[f"{ref}_CpG_site"].values[positions_list]

        results_df = pd.DataFrame({
            'Feature_Index': positions_list,
            'Genomic_Position': genomic_positions.astype(int),
            'Relative_Risk': relative_risks,
            'Adjusted_P_Value': p_adj,
            'GpC_Site': is_gpc_site,
            'CpG_Site': is_cpg_site
        })

        # Replace exact zeros with a tiny constant so the log transforms stay finite.
        results_df['log2_Relative_Risk'] = np.log2(results_df['Relative_Risk'].replace(0, 1e-300))
        results_df['-log10_Adj_P'] = -np.log10(results_df['Adjusted_P_Value'].replace(0, 1e-300))
        sig_df = results_df[results_df['Adjusted_P_Value'] < alpha]
        return results_df, sig_df

    results_dict = {}

    for ref in adata.obs['Reference_strand'].unique():
        ref_subset = adata[adata.obs['Reference_strand'] == ref].copy()
        if ref_subset.shape[0] == 0:
            continue

        # Normalize groupby to list
        # NOTE: the `groupby` parameter itself is rebound here on the first ref
        # iteration (str -> list); later iterations see the normalized list.
        if groupby is not None:
            if isinstance(groupby, str):
                groupby = [groupby]
            def format_group_label(row):
                # e.g. "Barcode=bc01,Condition=ctrl" for groupby=['Barcode','Condition']
                return ",".join([f"{col}={row[col]}" for col in groupby])

            combined_label = '__'.join(groupby)
            ref_subset.obs[combined_label] = ref_subset.obs.apply(format_group_label, axis=1)
            groups = ref_subset.obs[combined_label].unique()
        else:
            combined_label = None
            groups = ['all']

        results_dict[ref] = {}

        for group in groups:
            if group == 'all':
                group_subset = ref_subset
            else:
                group_subset = ref_subset[ref_subset.obs[combined_label] == group]

            if group_subset.shape[0] == 0:
                continue

            # Build site mask: union over the requested site-type var columns.
            site_mask = np.zeros(group_subset.shape[1], dtype=bool)
            for site in sites:
                site_mask |= group_subset.var[f"{ref}_{site}"]
            site_subset = group_subset[:, site_mask].copy()

            # Matrix and labels.
            # NaNs are imputed with uniform random noise (see random_fill_nans),
            # so results are stochastic for reads with missing calls.
            X = random_fill_nans(site_subset.X.copy())
            y = site_subset.obs['activity_status'].map({'Active': 1, 'Silent': 0}).values
            P_active = np.mean(y)

            # Analysis: per-position 2x2 association between methylation and activity.
            positions_list, relative_risks, p_values = [], [], []
            for pos in range(X.shape[1]):
                # Binarize: any positive value counts as methylated.
                methylation_state = (X[:, pos] > 0).astype(int)
                table = pd.crosstab(methylation_state, y)
                # Skip degenerate positions (a constant state or label collapses the table).
                if table.shape != (2, 2):
                    continue

                P_methylated = np.mean(methylation_state)
                P_methylated_given_active = np.mean(methylation_state[y == 1])
                P_methylated_given_inactive = np.mean(methylation_state[y == 0])

                # Guard against divisions by zero in the Bayes inversion below.
                if P_methylated_given_inactive == 0 or P_methylated in [0, 1]:
                    continue

                # Bayes' rule: P(active | methylated) vs P(active | unmethylated).
                P_active_given_methylated = (P_methylated_given_active * P_active) / P_methylated
                P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (1 - P_methylated)
                RR = P_active_given_methylated / P_active_given_unmethylated

                _, p_value = fisher_exact(table)
                positions_list.append(pos)
                relative_risks.append(RR)
                p_values.append(p_value)

            results_df, sig_df = compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values)
            results_dict[ref][group] = (results_df, sig_df)

    return results_dict
|
|
119
|
+
|
|
120
|
+
import copy
|
|
121
|
+
import warnings
|
|
122
|
+
from typing import Dict, Any, List, Optional, Tuple, Union
|
|
123
|
+
|
|
124
|
+
import numpy as np
|
|
125
|
+
import pandas as pd
|
|
126
|
+
import matplotlib.pyplot as plt
|
|
127
|
+
|
|
128
|
+
# optional imports
|
|
129
|
+
try:
|
|
130
|
+
from joblib import Parallel, delayed
|
|
131
|
+
JOBLIB_AVAILABLE = True
|
|
132
|
+
except Exception:
|
|
133
|
+
JOBLIB_AVAILABLE = False
|
|
134
|
+
|
|
135
|
+
try:
|
|
136
|
+
from scipy.stats import chi2_contingency
|
|
137
|
+
SCIPY_STATS_AVAILABLE = True
|
|
138
|
+
except Exception:
|
|
139
|
+
SCIPY_STATS_AVAILABLE = False
|
|
140
|
+
|
|
141
|
+
# -----------------------------
|
|
142
|
+
# Compute positionwise statistic (multi-method + simple site_types)
|
|
143
|
+
# -----------------------------
|
|
144
|
+
import numpy as np
|
|
145
|
+
import pandas as pd
|
|
146
|
+
from typing import List, Optional, Sequence, Dict, Any, Tuple
|
|
147
|
+
from contextlib import contextmanager
|
|
148
|
+
from joblib import Parallel, delayed, cpu_count
|
|
149
|
+
import joblib
|
|
150
|
+
from tqdm import tqdm
|
|
151
|
+
from scipy.stats import chi2_contingency
|
|
152
|
+
import warnings
|
|
153
|
+
import matplotlib.pyplot as plt
|
|
154
|
+
from itertools import cycle
|
|
155
|
+
import os
|
|
156
|
+
import warnings
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# ---------------------------
|
|
160
|
+
# joblib <-> tqdm integration
|
|
161
|
+
# ---------------------------
|
|
162
|
+
@contextmanager
def tqdm_joblib(tqdm_object: tqdm):
    """Context manager to patch joblib to update a tqdm progress bar.

    Temporarily replaces ``joblib.parallel.BatchCompletionCallBack`` with a
    subclass that ticks `tqdm_object` as batches finish, and restores the
    original class on exit (even on error).

    NOTE(review): this mutates joblib module-level state, so it is not safe
    to use from multiple threads concurrently — confirm callers are serial.
    """
    old = joblib.parallel.BatchCompletionCallBack

    class TqdmBatchCompletionCallback(old):  # type: ignore
        def __call__(self, *args, **kwargs):
            try:
                # Advance by the joblib batch size when available.
                tqdm_object.update(n=self.batch_size)
            except Exception:
                # Fallback: assume one task per callback.
                tqdm_object.update(1)
            return super().__call__(*args, **kwargs)

    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # Always restore the original callback class.
        joblib.parallel.BatchCompletionCallBack = old
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
# ---------------------------
|
|
183
|
+
# row workers (upper-triangle only)
|
|
184
|
+
# ---------------------------
|
|
185
|
+
def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
|
|
186
|
+
n_pos = X_bin.shape[1]
|
|
187
|
+
row = np.full((n_pos,), np.nan, dtype=float)
|
|
188
|
+
xi = X_bin[:, i]
|
|
189
|
+
for j in range(i, n_pos):
|
|
190
|
+
xj = X_bin[:, j]
|
|
191
|
+
mask = (~np.isnan(xi)) & (~np.isnan(xj))
|
|
192
|
+
if int(mask.sum()) < int(min_count_for_pairwise):
|
|
193
|
+
continue
|
|
194
|
+
try:
|
|
195
|
+
table = pd.crosstab(xi[mask], xj[mask])
|
|
196
|
+
if table.shape != (2, 2):
|
|
197
|
+
continue
|
|
198
|
+
chi2, _, _, _ = chi2_contingency(table, correction=False)
|
|
199
|
+
row[j] = float(chi2)
|
|
200
|
+
except Exception:
|
|
201
|
+
row[j] = np.nan
|
|
202
|
+
return (i, row)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
|
|
206
|
+
n_pos = X_bin.shape[1]
|
|
207
|
+
row = np.full((n_pos,), np.nan, dtype=float)
|
|
208
|
+
xi = X_bin[:, i]
|
|
209
|
+
for j in range(i, n_pos):
|
|
210
|
+
xj = X_bin[:, j]
|
|
211
|
+
mask = (~np.isnan(xi)) & (~np.isnan(xj))
|
|
212
|
+
if int(mask.sum()) < int(min_count_for_pairwise):
|
|
213
|
+
continue
|
|
214
|
+
a = np.sum((xi[mask] == 1) & (xj[mask] == 1))
|
|
215
|
+
b = np.sum((xi[mask] == 1) & (xj[mask] == 0))
|
|
216
|
+
c = np.sum((xi[mask] == 0) & (xj[mask] == 1))
|
|
217
|
+
d = np.sum((xi[mask] == 0) & (xj[mask] == 0))
|
|
218
|
+
try:
|
|
219
|
+
if (a + b) > 0 and (c + d) > 0 and (c > 0):
|
|
220
|
+
p1 = a / float(a + b)
|
|
221
|
+
p2 = c / float(c + d)
|
|
222
|
+
row[j] = float(p1 / p2) if p2 > 0 else np.nan
|
|
223
|
+
else:
|
|
224
|
+
row[j] = np.nan
|
|
225
|
+
except Exception:
|
|
226
|
+
row[j] = np.nan
|
|
227
|
+
return (i, row)
|
|
228
|
+
|
|
229
|
+
def compute_positionwise_statistics(
    adata,
    layer: str,
    methods: Sequence[str] = ("pearson",),
    sample_col: str = "Barcode",
    ref_col: str = "Reference_strand",
    site_types: Optional[Sequence[str]] = None,
    encoding: str = "signed",
    output_key: str = "positionwise_result",
    min_count_for_pairwise: int = 10,
    max_threads: Optional[int] = None,
    reverse_indices_on_store: bool = False,
):
    """
    Compute per-(sample, ref) positionwise matrices for methods in `methods`.

    Supported methods: "pearson", "binary_covariance", "chi_squared",
    "relative_risk" (the last two are parallelized per row via joblib and
    the module-level worker functions).

    Parameters:
        adata: AnnData with reads as obs and genomic positions as var.
        layer (str): layer to read the read-by-position matrix from; falls
            back to ``adata.X`` when absent.
        methods: one method name or a sequence (case-insensitive).
        sample_col / ref_col (str): obs columns defining the (sample, ref) grid.
        site_types: optional site-type suffixes; var columns named
            f"{ref}_{site_type}" select positions. If none of the named masks
            match, all positions are used.
        encoding (str): "signed" maps 1 -> methylated and -1 -> unmethylated;
            otherwise 1 -> methylated and 0 -> unmethylated. Everything else
            becomes NaN.
        output_key (str): adata.uns key to store results under.
        min_count_for_pairwise (int): minimum jointly observed reads per pair
            for the chi-squared / relative-risk workers.
        max_threads: joblib worker count (None or <= 0 -> all CPUs).
        reverse_indices_on_store (bool): reverse position order before storing.

    Results stored at:
        adata.uns[output_key][method][(sample, ref)] = DataFrame
        adata.uns[output_key + "_n"][method][(sample, ref)] = int(n_reads)
    """
    if isinstance(methods, str):
        methods = [methods]
    methods = [m.lower() for m in methods]

    # prepare containers
    adata.uns[output_key] = {m: {} for m in methods}
    adata.uns[output_key + "_n"] = {m: {} for m in methods}

    # workers
    if max_threads is None or max_threads <= 0:
        n_jobs = max(1, cpu_count() or 1)
    else:
        n_jobs = max(1, int(max_threads))

    # samples / refs: iterate categories in canonical order
    sseries = adata.obs[sample_col]
    if not pd.api.types.is_categorical_dtype(sseries):
        sseries = sseries.astype("category")
    samples = list(sseries.cat.categories)

    rseries = adata.obs[ref_col]
    if not pd.api.types.is_categorical_dtype(rseries):
        rseries = rseries.astype("category")
    references = list(rseries.cat.categories)

    total_tasks = len(samples) * len(references)
    pbar_outer = tqdm(total=total_tasks, desc="positionwise (sample x ref)", unit="cell")

    for sample in samples:
        for ref in references:
            label = (sample, ref)
            try:
                mask = (adata.obs[sample_col] == sample) & (adata.obs[ref_col] == ref)
                subset = adata[mask]
                n_reads = subset.shape[0]

                # nothing to do -> store empty placeholders.
                # BUGFIX: the progress bar is advanced exactly once per cell by
                # the `finally` clause below; the previous version also called
                # pbar_outer.update(1) before each early `continue`, which
                # double-counted empty/degenerate cells.
                if n_reads == 0:
                    for m in methods:
                        adata.uns[output_key][m][label] = pd.DataFrame()
                        adata.uns[output_key + "_n"][m][label] = 0
                    continue

                # select var columns based on site_types and reference
                if site_types:
                    col_mask = np.zeros(subset.shape[1], dtype=bool)
                    for st in site_types:
                        colname = f"{ref}_{st}"
                        if colname in subset.var.columns:
                            col_mask |= np.asarray(subset.var[colname].values, dtype=bool)
                        else:
                            # mask column absent for this ref; keep scanning the
                            # remaining site types (user may pass generic names)
                            pass
                    if not col_mask.any():
                        # no mask matched -> fall back to all positions
                        selected_var_idx = np.arange(subset.shape[1])
                    else:
                        selected_var_idx = np.nonzero(col_mask)[0]
                else:
                    selected_var_idx = np.arange(subset.shape[1])

                if selected_var_idx.size == 0:
                    for m in methods:
                        adata.uns[output_key][m][label] = pd.DataFrame()
                        adata.uns[output_key + "_n"][m][label] = int(n_reads)
                    continue

                # extract matrix (prefer the requested layer, fall back to X)
                if (layer in subset.layers) and (subset.layers[layer] is not None):
                    X = subset.layers[layer]
                else:
                    X = subset.X
                X = np.asarray(X, dtype=float)
                X = X[:, selected_var_idx]  # (n_reads, n_pos)

                # binary encoding: map to {1.0, 0.0, NaN}
                if encoding == "signed":
                    X_bin = np.where(X == 1, 1.0, np.where(X == -1, 0.0, np.nan))
                else:
                    X_bin = np.where(X == 1, 1.0, np.where(X == 0, 0.0, np.nan))

                n_pos = X_bin.shape[1]
                if n_pos == 0:
                    for m in methods:
                        adata.uns[output_key][m][label] = pd.DataFrame()
                        adata.uns[output_key + "_n"][m][label] = int(n_reads)
                    continue

                var_names = list(subset.var_names[selected_var_idx])

                # compute per-method
                for method in methods:
                    m = method.lower()
                    if m == "pearson":
                        # pairwise Pearson with column demean (nan-aware approximation:
                        # NaNs are zeroed after demeaning, so they drop out of the sums)
                        with np.errstate(invalid="ignore"):
                            col_mean = np.nanmean(X_bin, axis=0)
                        Xc = X_bin - col_mean  # nan preserved
                        Xc0 = np.nan_to_num(Xc, nan=0.0)
                        cov = Xc0.T @ Xc0
                        denom = (np.sqrt((Xc0**2).sum(axis=0))[:, None] * np.sqrt((Xc0**2).sum(axis=0))[None, :])
                        with np.errstate(divide="ignore", invalid="ignore"):
                            mat = np.where(denom != 0.0, cov / denom, np.nan)
                    elif m == "binary_covariance":
                        # fraction of jointly valid reads methylated at both positions
                        binary = (X_bin == 1).astype(float)
                        valid = (~np.isnan(X_bin)).astype(float)
                        with np.errstate(divide="ignore", invalid="ignore"):
                            numerator = binary.T @ binary
                            denominator = valid.T @ valid
                            mat = np.true_divide(numerator, denominator)
                            mat[~np.isfinite(mat)] = 0.0
                    elif m in ("chi_squared", "relative_risk"):
                        if m == "chi_squared":
                            worker = _chi2_row_job
                        else:
                            worker = _relative_risk_row_job
                        out = np.full((n_pos, n_pos), np.nan, dtype=float)
                        # workers fill the upper triangle only; mirror below
                        tasks = (delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos))
                        pbar_rows = tqdm(total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False)
                        with tqdm_joblib(pbar_rows):
                            results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
                        pbar_rows.close()
                        for i, row in results:
                            out[int(i), :] = row
                        iu = np.triu_indices(n_pos, k=1)
                        out[iu[1], iu[0]] = out[iu]
                        mat = out
                    else:
                        raise ValueError(f"Unsupported method: {method}")

                    # optionally reverse order at store-time
                    if reverse_indices_on_store:
                        mat_store = np.flip(np.flip(mat, axis=0), axis=1)
                        idx_names = var_names[::-1]
                    else:
                        mat_store = mat
                        idx_names = var_names

                    # make dataframe with labels
                    df = pd.DataFrame(mat_store, index=idx_names, columns=idx_names)

                    adata.uns[output_key][m][label] = df
                    adata.uns[output_key + "_n"][m][label] = int(n_reads)

            except Exception as exc:
                warnings.warn(f"Failed computing positionwise for {sample}__{ref}: {exc}")
            finally:
                # exactly one tick per (sample, ref) cell, whatever path was taken
                pbar_outer.update(1)

    pbar_outer.close()
    return None
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
# ---------------------------
|
|
406
|
+
# Plotting function
|
|
407
|
+
# ---------------------------
|
|
408
|
+
|
|
409
|
+
def plot_positionwise_matrices(
    adata,
    methods: List[str],
    cmaps: Optional[List[str]] = None,
    sample_col: str = "Barcode",
    ref_col: str = "Reference_strand",
    output_dir: Optional[str] = None,
    vmin: Optional[Dict[str, float]] = None,
    vmax: Optional[Dict[str, float]] = None,
    figsize_per_cell: Tuple[float, float] = (3.5, 3.5),
    dpi: int = 160,
    cbar_shrink: float = 0.9,
    output_key: str = "positionwise_result",
    show_colorbar: bool = True,
    flip_display_axes: bool = False,
    rows_per_page: int = 6,
    sample_label_rotation: float = 90.0,
):
    """
    Plot grids of matrices for each method with pagination and rotated sample-row labels.

    Reads the DataFrames stored by compute_positionwise_statistics from
    ``adata.uns[output_key][method]``, laying out one page per chunk of
    `rows_per_page` samples, with one column per reference.

    New args:
    - rows_per_page: how many sample rows per page/figure (pagination)
    - sample_label_rotation: rotation angle (deg) for the sample labels placed in the left margin.
    Returns:
        dict mapping method -> list of saved file paths; when figures are shown
        instead of saved (no output_dir), the list holds "" placeholders.
    """
    if isinstance(methods, str):
        methods = [methods]
    if cmaps is None:
        cmaps = ["viridis"] * len(methods)
    # cycle() lets a shorter cmaps list wrap around the methods list
    cmap_cycle = cycle(cmaps)

    # canonicalize sample/ref order (same category order as the compute step)
    sseries = adata.obs[sample_col]
    if not pd.api.types.is_categorical_dtype(sseries):
        sseries = sseries.astype("category")
    samples = list(sseries.cat.categories)

    rseries = adata.obs[ref_col]
    if not pd.api.types.is_categorical_dtype(rseries):
        rseries = rseries.astype("category")
    references = list(rseries.cat.categories)

    # ensure directories
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    saved_files_by_method = {}

    def _get_df_from_store(store, sample, ref):
        """
        try multiple key formats: (sample, ref) tuple, 'sample__ref' string,
        or str(sample)+'__'+str(ref). Return None if not found.
        """
        if store is None:
            return None
        # try tuple key
        key_t = (sample, ref)
        if key_t in store:
            return store[key_t]
        # try string key
        key_s = f"{sample}__{ref}"
        if key_s in store:
            return store[key_s]
        # try stringified tuple keys (some callers store differently)
        for k in store.keys():
            try:
                if isinstance(k, tuple) and len(k) == 2 and str(k[0]) == str(sample) and str(k[1]) == str(ref):
                    return store[k]
                if isinstance(k, str) and key_s == k:
                    return store[k]
            except Exception:
                continue
        return None

    for method, cmap in zip(methods, cmap_cycle):
        m = method.lower()
        method_store = adata.uns.get(output_key, {}).get(m, {})
        if not method_store:
            warnings.warn(f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.", stacklevel=2)
            saved_files_by_method[method] = []
            continue

        # gather numeric values to pick sensible vmin/vmax when not provided
        vals = []
        for s in samples:
            for r in references:
                df = _get_df_from_store(method_store, s, r)
                if isinstance(df, pd.DataFrame) and df.size > 0:
                    a = df.values
                    a = a[np.isfinite(a)]
                    if a.size:
                        vals.append(a)
        if vals:
            allvals = np.concatenate(vals)
        else:
            allvals = np.array([])

        # decide per-method defaults; vmin/vmax may be None, a dict keyed by
        # method name, or (despite the annotation) a plain scalar
        if m == "pearson":
            # correlations: fixed [-1, 1] scale unless overridden
            vmn = -1.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
            vmn = -1.0 if vmn is None else vmn
            vmx = 1.0 if vmx is None else vmx
        elif m == "binary_covariance":
            # co-methylation fractions: fixed [0, 1] scale unless overridden
            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
            vmn = 0.0 if vmn is None else vmn
            vmx = 1.0 if vmx is None else vmx
        else:
            # unbounded statistics: clip the top at the 99th percentile of observed values
            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
            if (vmax is None) or (isinstance(vmax, dict) and m not in vmax):
                vmx = float(np.nanpercentile(allvals, 99.0)) if allvals.size else 1.0
            else:
                vmx = (vmax.get(m) if isinstance(vmax, dict) else vmax)
            vmn = 0.0 if vmn is None else vmn
            if vmx is None:
                vmx = 1.0

        # prepare pagination over sample rows
        saved_files = []
        n_pages = max(1, int(np.ceil(len(samples) / float(max(1, rows_per_page)))))
        for page_idx in range(n_pages):
            start = page_idx * rows_per_page
            chunk = samples[start : start + rows_per_page]
            nrows = len(chunk)
            ncols = max(1, len(references))
            fig_w = ncols * figsize_per_cell[0]
            fig_h = nrows * figsize_per_cell[1]
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)

            # leave margin for rotated sample labels
            plt.subplots_adjust(left=0.12, right=0.88, top=0.95, bottom=0.05)

            any_plotted = False
            im = None
            for r_idx, sample in enumerate(chunk):
                for c_idx, ref in enumerate(references):
                    ax = axes[r_idx][c_idx]
                    df = _get_df_from_store(method_store, sample, ref)
                    if not isinstance(df, pd.DataFrame) or df.size == 0:
                        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
                        ax.set_xticks([])
                        ax.set_yticks([])
                    else:
                        mat = df.values.astype(float)
                        origin = "upper" if flip_display_axes else "lower"
                        im = ax.imshow(mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap)
                        any_plotted = True
                        ax.set_xticks([])
                        ax.set_yticks([])

                    # top title is reference (only for top-row)
                    if r_idx == 0:
                        ax.set_title(str(ref), fontsize=9)

                # draw rotated sample label into left margin centered on the row
                # compute vertical center of this row's axis in figure coords
                # NOTE(review): positions are read before tight_layout below,
                # which may shift axes slightly — confirm label alignment.
                ax0 = axes[r_idx][0]
                ax_y0, ax_y1 = ax0.get_position().y0, ax0.get_position().y1
                y_center = 0.5 * (ax_y0 + ax_y1)
                # place text at x=0.01 (just inside left margin); rotation controls orientation
                fig.text(0.01, y_center, str(chunk[r_idx]), va="center", ha="left", rotation=sample_label_rotation, fontsize=9)

            fig.suptitle(f"{method} — per-sample x per-reference matrices (page {page_idx+1}/{n_pages})", fontsize=12, y=0.99)
            fig.tight_layout(rect=[0.05, 0.02, 0.9, 0.96])

            # colorbar (shared): reuses the last imshow handle, so all panels
            # share the same vmin/vmax scale
            if any_plotted and show_colorbar and (im is not None):
                try:
                    cbar_ax = fig.add_axes([0.9, 0.15, 0.02, 0.7])
                    fig.colorbar(im, cax=cbar_ax, shrink=cbar_shrink)
                except Exception:
                    try:
                        fig.colorbar(im, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
                    except Exception:
                        pass

            # save or show
            if output_dir:
                fname = f"positionwise_{method}_page{page_idx+1}.png"
                outpath = os.path.join(output_dir, fname)
                plt.savefig(outpath, bbox_inches="tight")
                saved_files.append(outpath)
                plt.close(fig)
            else:
                plt.show()
                saved_files.append("")  # placeholder to indicate a figure was shown

        saved_files_by_method[method] = saved_files

    return saved_files_by_method
|