smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/tools/position_stats.py
CHANGED
```diff
@@ -117,123 +117,485 @@ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
 
     return results_dict
 
-
+import copy
+import warnings
+from typing import Dict, Any, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+# optional imports
+try:
+    from joblib import Parallel, delayed
+    JOBLIB_AVAILABLE = True
+except Exception:
+    JOBLIB_AVAILABLE = False
+
+try:
+    from scipy.stats import chi2_contingency
+    SCIPY_STATS_AVAILABLE = True
+except Exception:
+    SCIPY_STATS_AVAILABLE = False
+
+# -----------------------------
+# Compute positionwise statistic (multi-method + simple site_types)
+# -----------------------------
+import numpy as np
+import pandas as pd
+from typing import List, Optional, Sequence, Dict, Any, Tuple
+from contextlib import contextmanager
+from joblib import Parallel, delayed, cpu_count
+import joblib
+from tqdm import tqdm
+from scipy.stats import chi2_contingency
+import warnings
+import matplotlib.pyplot as plt
+from itertools import cycle
+import os
+import warnings
+
+
+# ---------------------------
+# joblib <-> tqdm integration
+# ---------------------------
+@contextmanager
+def tqdm_joblib(tqdm_object: tqdm):
+    """Context manager to patch joblib to update a tqdm progress bar."""
+    old = joblib.parallel.BatchCompletionCallBack
+
+    class TqdmBatchCompletionCallback(old):  # type: ignore
+        def __call__(self, *args, **kwargs):
+            try:
+                tqdm_object.update(n=self.batch_size)
+            except Exception:
+                tqdm_object.update(1)
+            return super().__call__(*args, **kwargs)
+
+    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
+    try:
+        yield tqdm_object
+    finally:
+        joblib.parallel.BatchCompletionCallBack = old
+
+
+# ---------------------------
+# row workers (upper-triangle only)
+# ---------------------------
+def _chi2_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+    n_pos = X_bin.shape[1]
+    row = np.full((n_pos,), np.nan, dtype=float)
+    xi = X_bin[:, i]
+    for j in range(i, n_pos):
+        xj = X_bin[:, j]
+        mask = (~np.isnan(xi)) & (~np.isnan(xj))
+        if int(mask.sum()) < int(min_count_for_pairwise):
+            continue
+        try:
+            table = pd.crosstab(xi[mask], xj[mask])
+            if table.shape != (2, 2):
+                continue
+            chi2, _, _, _ = chi2_contingency(table, correction=False)
+            row[j] = float(chi2)
+        except Exception:
+            row[j] = np.nan
+    return (i, row)
+
+
+def _relative_risk_row_job(i: int, X_bin: np.ndarray, min_count_for_pairwise: int) -> Tuple[int, np.ndarray]:
+    n_pos = X_bin.shape[1]
+    row = np.full((n_pos,), np.nan, dtype=float)
+    xi = X_bin[:, i]
+    for j in range(i, n_pos):
+        xj = X_bin[:, j]
+        mask = (~np.isnan(xi)) & (~np.isnan(xj))
+        if int(mask.sum()) < int(min_count_for_pairwise):
+            continue
+        a = np.sum((xi[mask] == 1) & (xj[mask] == 1))
+        b = np.sum((xi[mask] == 1) & (xj[mask] == 0))
+        c = np.sum((xi[mask] == 0) & (xj[mask] == 1))
+        d = np.sum((xi[mask] == 0) & (xj[mask] == 0))
+        try:
+            if (a + b) > 0 and (c + d) > 0 and (c > 0):
+                p1 = a / float(a + b)
+                p2 = c / float(c + d)
+                row[j] = float(p1 / p2) if p2 > 0 else np.nan
+            else:
+                row[j] = np.nan
+        except Exception:
+            row[j] = np.nan
+    return (i, row)
+
+def compute_positionwise_statistics(
     adata,
-    layer,
-
-
-
-
-    encoding="signed",
-
+    layer: str,
+    methods: Sequence[str] = ("pearson",),
+    sample_col: str = "Barcode",
+    ref_col: str = "Reference_strand",
+    site_types: Optional[Sequence[str]] = None,
+    encoding: str = "signed",
+    output_key: str = "positionwise_result",
+    min_count_for_pairwise: int = 10,
+    max_threads: Optional[int] = None,
+    reverse_indices_on_store: bool = False,
 ):
     """
-
+    Compute per-(sample,ref) positionwise matrices for methods in `methods`.
 
-
-
-
-        method (str): 'pearson', 'binary_covariance', 'relative_risk', or 'chi_squared'.
-        groupby (str or list): Column(s) in adata.obs to group by.
-        output_key (str): Key in adata.uns to store results.
-        site_config (dict): Optional {ref: [site_types]} to restrict sites per reference.
-        encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
-        max_threads (int): Number of parallel threads to use (joblib).
+    Results stored at:
+        adata.uns[output_key][method][(sample, ref)] = DataFrame
+        adata.uns[output_key + "_n"][method][(sample, ref)] = int(n_reads)
     """
-
-
-
-
-
+    if isinstance(methods, str):
+        methods = [methods]
+    methods = [m.lower() for m in methods]
+
+    # prepare containers
+    adata.uns[output_key] = {m: {} for m in methods}
+    adata.uns[output_key + "_n"] = {m: {} for m in methods}
+
+    # workers
+    if max_threads is None or max_threads <= 0:
+        n_jobs = max(1, cpu_count() or 1)
+    else:
+        n_jobs = max(1, int(max_threads))
+
+    # samples / refs
+    sseries = adata.obs[sample_col]
+    if not pd.api.types.is_categorical_dtype(sseries):
+        sseries = sseries.astype("category")
+    samples = list(sseries.cat.categories)
+
+    rseries = adata.obs[ref_col]
+    if not pd.api.types.is_categorical_dtype(rseries):
+        rseries = rseries.astype("category")
+    references = list(rseries.cat.categories)
+
+    total_tasks = len(samples) * len(references)
+    pbar_outer = tqdm(total=total_tasks, desc="positionwise (sample x ref)", unit="cell")
+
+    for sample in samples:
+        for ref in references:
+            label = (sample, ref)
+            try:
+                mask = (adata.obs[sample_col] == sample) & (adata.obs[ref_col] == ref)
+                subset = adata[mask]
+                n_reads = subset.shape[0]
+
+                # nothing to do -> store empty placeholders
+                if n_reads == 0:
+                    for m in methods:
+                        adata.uns[output_key][m][label] = pd.DataFrame()
+                        adata.uns[output_key + "_n"][m][label] = 0
+                    pbar_outer.update(1)
+                    continue
 
-
-
+                # select var columns based on site_types and reference
+                if site_types:
+                    col_mask = np.zeros(subset.shape[1], dtype=bool)
+                    for st in site_types:
+                        colname = f"{ref}_{st}"
+                        if colname in subset.var.columns:
+                            col_mask |= np.asarray(subset.var[colname].values, dtype=bool)
+                        else:
+                            # if mask not present, warn once (but keep searching)
+                            # user may pass generic site types
+                            pass
+                    if not col_mask.any():
+                        selected_var_idx = np.arange(subset.shape[1])
+                    else:
+                        selected_var_idx = np.nonzero(col_mask)[0]
+                else:
+                    selected_var_idx = np.arange(subset.shape[1])
+
+                if selected_var_idx.size == 0:
+                    for m in methods:
+                        adata.uns[output_key][m][label] = pd.DataFrame()
+                        adata.uns[output_key + "_n"][m][label] = int(n_reads)
+                    pbar_outer.update(1)
+                    continue
 
-
-
+                # extract matrix
+                if (layer in subset.layers) and (subset.layers[layer] is not None):
+                    X = subset.layers[layer]
+                else:
+                    X = subset.X
+                X = np.asarray(X, dtype=float)
+                X = X[:, selected_var_idx]  # (n_reads, n_pos)
+
+                # binary encoding
+                if encoding == "signed":
+                    X_bin = np.where(X == 1, 1.0, np.where(X == -1, 0.0, np.nan))
+                else:
+                    X_bin = np.where(X == 1, 1.0, np.where(X == 0, 0.0, np.nan))
+
+                n_pos = X_bin.shape[1]
+                if n_pos == 0:
+                    for m in methods:
+                        adata.uns[output_key][m][label] = pd.DataFrame()
+                        adata.uns[output_key + "_n"][m][label] = int(n_reads)
+                    pbar_outer.update(1)
+                    continue
 
-
-
+                var_names = list(subset.var_names[selected_var_idx])
+
+                # compute per-method
+                for method in methods:
+                    m = method.lower()
+                    if m == "pearson":
+                        # pairwise Pearson with column demean (nan-aware approximation)
+                        with np.errstate(invalid="ignore"):
+                            col_mean = np.nanmean(X_bin, axis=0)
+                        Xc = X_bin - col_mean  # nan preserved
+                        Xc0 = np.nan_to_num(Xc, nan=0.0)
+                        cov = Xc0.T @ Xc0
+                        denom = (np.sqrt((Xc0**2).sum(axis=0))[:, None] * np.sqrt((Xc0**2).sum(axis=0))[None, :])
+                        with np.errstate(divide="ignore", invalid="ignore"):
+                            mat = np.where(denom != 0.0, cov / denom, np.nan)
+                    elif m == "binary_covariance":
+                        binary = (X_bin == 1).astype(float)
+                        valid = (~np.isnan(X_bin)).astype(float)
+                        with np.errstate(divide="ignore", invalid="ignore"):
+                            numerator = binary.T @ binary
+                            denominator = valid.T @ valid
+                            mat = np.true_divide(numerator, denominator)
+                        mat[~np.isfinite(mat)] = 0.0
+                    elif m in ("chi_squared", "relative_risk"):
+                        if m == "chi_squared":
+                            worker = _chi2_row_job
+                        else:
+                            worker = _relative_risk_row_job
+                        out = np.full((n_pos, n_pos), np.nan, dtype=float)
+                        tasks = (delayed(worker)(i, X_bin, min_count_for_pairwise) for i in range(n_pos))
+                        pbar_rows = tqdm(total=n_pos, desc=f"{m}: rows ({sample}__{ref})", leave=False)
+                        with tqdm_joblib(pbar_rows):
+                            results = Parallel(n_jobs=n_jobs, prefer="processes")(tasks)
+                        pbar_rows.close()
+                        for i, row in results:
+                            out[int(i), :] = row
+                        iu = np.triu_indices(n_pos, k=1)
+                        out[iu[1], iu[0]] = out[iu]
+                        mat = out
+                    else:
+                        raise ValueError(f"Unsupported method: {method}")
+
+                    # optionally reverse order at store-time
+                    if reverse_indices_on_store:
+                        mat_store = np.flip(np.flip(mat, axis=0), axis=1)
+                        idx_names = var_names[::-1]
+                    else:
+                        mat_store = mat
+                        idx_names = var_names
+
+                    # make dataframe with labels
+                    df = pd.DataFrame(mat_store, index=idx_names, columns=idx_names)
+
+                    adata.uns[output_key][m][label] = df
+                    adata.uns[output_key + "_n"][m][label] = int(n_reads)
+
+            except Exception as exc:
+                warnings.warn(f"Failed computing positionwise for {sample}__{ref}: {exc}")
+            finally:
+                pbar_outer.update(1)
+
+    pbar_outer.close()
+    return None
+
+
+# ---------------------------
+# Plotting function
+# ---------------------------
+
+def plot_positionwise_matrices(
+    adata,
+    methods: List[str],
+    cmaps: Optional[List[str]] = None,
+    sample_col: str = "Barcode",
+    ref_col: str = "Reference_strand",
+    output_dir: Optional[str] = None,
+    vmin: Optional[Dict[str, float]] = None,
+    vmax: Optional[Dict[str, float]] = None,
+    figsize_per_cell: Tuple[float, float] = (3.5, 3.5),
+    dpi: int = 160,
+    cbar_shrink: float = 0.9,
+    output_key: str = "positionwise_result",
+    show_colorbar: bool = True,
+    flip_display_axes: bool = False,
+    rows_per_page: int = 6,
+    sample_label_rotation: float = 90.0,
+):
+    """
+    Plot grids of matrices for each method with pagination and rotated sample-row labels.
 
-
-
-
+    New args:
+      - rows_per_page: how many sample rows per page/figure (pagination)
+      - sample_label_rotation: rotation angle (deg) for the sample labels placed in the left margin.
+    Returns:
+      dict mapping method -> list of saved filenames (empty list if figures were shown).
+    """
+    if isinstance(methods, str):
+        methods = [methods]
+    if cmaps is None:
+        cmaps = ["viridis"] * len(methods)
+    cmap_cycle = cycle(cmaps)
+
+    # canonicalize sample/ref order
+    sseries = adata.obs[sample_col]
+    if not pd.api.types.is_categorical_dtype(sseries):
+        sseries = sseries.astype("category")
+    samples = list(sseries.cat.categories)
+
+    rseries = adata.obs[ref_col]
+    if not pd.api.types.is_categorical_dtype(rseries):
+        rseries = rseries.astype("category")
+    references = list(rseries.cat.categories)
+
+    # ensure directories
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    saved_files_by_method = {}
+
+    def _get_df_from_store(store, sample, ref):
+        """
+        try multiple key formats: (sample, ref) tuple, 'sample__ref' string,
+        or str(sample)+'__'+str(ref). Return None if not found.
+        """
+        if store is None:
+            return None
+        # try tuple key
+        key_t = (sample, ref)
+        if key_t in store:
+            return store[key_t]
+        # try string key
+        key_s = f"{sample}__{ref}"
+        if key_s in store:
+            return store[key_s]
+        # try stringified tuple keys (some callers store differently)
+        for k in store.keys():
+            try:
+                if isinstance(k, tuple) and len(k) == 2 and str(k[0]) == str(sample) and str(k[1]) == str(ref):
+                    return store[k]
+                if isinstance(k, str) and key_s == k:
+                    return store[k]
+            except Exception:
+                continue
+        return None
+
+    for method, cmap in zip(methods, cmap_cycle):
+        m = method.lower()
+        method_store = adata.uns.get(output_key, {}).get(m, {})
+        if not method_store:
+            warnings.warn(f"No results found for method '{method}' in adata.uns['{output_key}']. Skipping.", stacklevel=2)
+            saved_files_by_method[method] = []
             continue
 
-
-
-
-
-
-
-
-
-
-
-        if
-
+        # gather numeric values to pick sensible vmin/vmax when not provided
+        vals = []
+        for s in samples:
+            for r in references:
+                df = _get_df_from_store(method_store, s, r)
+                if isinstance(df, pd.DataFrame) and df.size > 0:
+                    a = df.values
+                    a = a[np.isfinite(a)]
+                    if a.size:
+                        vals.append(a)
+        if vals:
+            allvals = np.concatenate(vals)
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            numerator = np.dot(binary.T, binary)
-            denominator = np.dot(valid.T, valid)
-
-            with np.errstate(divide='ignore', invalid='ignore'):
-                mat = np.true_divide(numerator, denominator)
-            mat[~np.isfinite(mat)] = 0
-
-        elif method in ["relative_risk", "chi_squared"]:
-            def compute_row(i):
-                row = np.zeros(n_pos)
-                xi = X_bin[:, i]
-                for j in range(n_pos):
-                    xj = X_bin[:, j]
-                    mask = ~np.isnan(xi) & ~np.isnan(xj)
-                    if np.sum(mask) < 10:
-                        row[j] = np.nan
-                        continue
-                    if method == "relative_risk":
-                        a = np.sum((xi[mask] == 1) & (xj[mask] == 1))
-                        b = np.sum((xi[mask] == 1) & (xj[mask] == 0))
-                        c = np.sum((xi[mask] == 0) & (xj[mask] == 1))
-                        d = np.sum((xi[mask] == 0) & (xj[mask] == 0))
-                        if (a + b) > 0 and (c + d) > 0 and c > 0:
-                            p1 = a / (a + b)
-                            p2 = c / (c + d)
-                            row[j] = p1 / p2 if p2 > 0 else np.nan
-                        else:
-                            row[j] = np.nan
-                    elif method == "chi_squared":
-                        table = pd.crosstab(xi[mask], xj[mask])
-                        if table.shape != (2, 2):
-                            row[j] = np.nan
-                        else:
-                            chi2, _, _, _ = chi2_contingency(table, correction=False)
-                            row[j] = chi2
-                return row
-
-            mat = np.array(
-                Parallel(n_jobs=max_threads)(
-                    delayed(compute_row)(i) for i in tqdm(range(n_pos), desc=f"{method}: {group}")
-                )
-            )
-
+            allvals = np.array([])
+
+        # decide per-method defaults
+        if m == "pearson":
+            vmn = -1.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = -1.0 if vmn is None else vmn
+            vmx = 1.0 if vmx is None else vmx
+        elif m == "binary_covariance":
+            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            vmx = 1.0 if (vmax is None or (isinstance(vmax, dict) and m not in vmax)) else (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = 0.0 if vmn is None else vmn
+            vmx = 1.0 if vmx is None else vmx
         else:
-
+            vmn = 0.0 if (vmin is None or (isinstance(vmin, dict) and m not in vmin)) else (vmin.get(m) if isinstance(vmin, dict) else vmin)
+            if (vmax is None) or (isinstance(vmax, dict) and m not in vmax):
+                vmx = float(np.nanpercentile(allvals, 99.0)) if allvals.size else 1.0
+            else:
+                vmx = (vmax.get(m) if isinstance(vmax, dict) else vmax)
+            vmn = 0.0 if vmn is None else vmn
+            if vmx is None:
+                vmx = 1.0
+
+        # prepare pagination over sample rows
+        saved_files = []
+        n_pages = max(1, int(np.ceil(len(samples) / float(max(1, rows_per_page)))))
+        for page_idx in range(n_pages):
+            start = page_idx * rows_per_page
+            chunk = samples[start : start + rows_per_page]
+            nrows = len(chunk)
+            ncols = max(1, len(references))
+            fig_w = ncols * figsize_per_cell[0]
+            fig_h = nrows * figsize_per_cell[1]
+            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
+
+            # leave margin for rotated sample labels
+            plt.subplots_adjust(left=0.12, right=0.88, top=0.95, bottom=0.05)
+
+            any_plotted = False
+            im = None
+            for r_idx, sample in enumerate(chunk):
+                for c_idx, ref in enumerate(references):
+                    ax = axes[r_idx][c_idx]
+                    df = _get_df_from_store(method_store, sample, ref)
+                    if not isinstance(df, pd.DataFrame) or df.size == 0:
+                        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
+                        ax.set_xticks([])
+                        ax.set_yticks([])
+                    else:
+                        mat = df.values.astype(float)
+                        origin = "upper" if flip_display_axes else "lower"
+                        im = ax.imshow(mat, origin=origin, aspect="auto", vmin=vmn, vmax=vmx, cmap=cmap)
+                        any_plotted = True
+                        ax.set_xticks([])
+                        ax.set_yticks([])
+
+                    # top title is reference (only for top-row)
+                    if r_idx == 0:
+                        ax.set_title(str(ref), fontsize=9)
+
+                # draw rotated sample label into left margin centered on the row
+                # compute vertical center of this row's axis in figure coords
+                ax0 = axes[r_idx][0]
+                ax_y0, ax_y1 = ax0.get_position().y0, ax0.get_position().y1
+                y_center = 0.5 * (ax_y0 + ax_y1)
+                # place text at x=0.01 (just inside left margin); rotation controls orientation
+                fig.text(0.01, y_center, str(chunk[r_idx]), va="center", ha="left", rotation=sample_label_rotation, fontsize=9)
+
+            fig.suptitle(f"{method} — per-sample x per-reference matrices (page {page_idx+1}/{n_pages})", fontsize=12, y=0.99)
+            fig.tight_layout(rect=[0.05, 0.02, 0.9, 0.96])
+
+            # colorbar (shared)
+            if any_plotted and show_colorbar and (im is not None):
+                try:
+                    cbar_ax = fig.add_axes([0.9, 0.15, 0.02, 0.7])
+                    fig.colorbar(im, cax=cbar_ax, shrink=cbar_shrink)
+                except Exception:
+                    try:
+                        fig.colorbar(im, ax=axes.ravel().tolist(), fraction=0.02, pad=0.02)
+                    except Exception:
+                        pass
+
+            # save or show
+            if output_dir:
+                fname = f"positionwise_{method}_page{page_idx+1}.png"
+                outpath = os.path.join(output_dir, fname)
+                plt.savefig(outpath, bbox_inches="tight")
+                saved_files.append(outpath)
+                plt.close(fig)
+            else:
+                plt.show()
+                saved_files.append("")  # placeholder to indicate a figure was shown
+
+        saved_files_by_method[method] = saved_files
 
-
-        mat_df = pd.DataFrame(mat, index=var_names, columns=var_names)
-        adata.uns[output_key][group] = mat_df
-        adata.uns[output_key + "_n"][group] = subset.shape[0]
+    return saved_files_by_method
```