smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
import seaborn as sns
|
|
3
5
|
import matplotlib.pyplot as plt
|
|
6
|
+
import scipy.cluster.hierarchy as sch
|
|
7
|
+
import matplotlib.gridspec as gridspec
|
|
8
|
+
import os
|
|
9
|
+
import math
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
from typing import Optional, Mapping, Sequence, Any, Dict, List
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
def normalized_mean(matrix: np.ndarray) -> np.ndarray:
    """Column-wise mean of *matrix*, min-max rescaled toward the [0, 1] range.

    NaNs within a column are ignored when averaging (``np.nanmean``).
    The extrema are taken with ``np.nanmin`` / ``np.nanmax`` so that a
    column whose mean is NaN (an all-NaN input column) does not poison
    the scaling of every other column; previously ``mean.max()`` /
    ``mean.min()`` would propagate that NaN into the entire result.
    The small epsilon keeps the division defined when all column means
    are equal.
    """
    mean = np.nanmean(matrix, axis=0)
    # nan-aware extrema: identical to mean.min()/mean.max() for finite
    # means, but robust when the mean vector contains NaN entries.
    lo = np.nanmin(mean)
    hi = np.nanmax(mean)
    denom = (hi - lo) + 1e-9
    return (mean - lo) / denom
|
|
19
|
+
|
|
20
|
+
def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
    """Per-column fraction of methylated calls.

    A position counts as methylated when its value equals 1, and as
    valid when it is finite and nonzero. Columns with zero valid calls
    yield 0.0 instead of dividing by zero.
    """
    arr = np.asarray(matrix)
    finite = np.isfinite(arr)

    # Counts per column: methylated hits and usable (finite, nonzero) calls.
    n_methylated = np.count_nonzero(finite & (arr == 1), axis=0)
    n_valid = np.count_nonzero(finite & (arr != 0), axis=0)

    # Safe element-wise division: entries with no valid calls keep the
    # pre-filled 0.0 from the output buffer.
    fractions = np.zeros_like(n_methylated, dtype=float)
    np.divide(n_methylated, n_valid, out=fractions, where=n_valid != 0)
    return fractions
|
|
4
38
|
|
|
5
39
|
def clean_barplot(ax, mean_values, title):
|
|
6
40
|
x = np.arange(len(mean_values))
|
|
@@ -17,189 +51,1307 @@ def clean_barplot(ax, mean_values, title):
|
|
|
17
51
|
|
|
18
52
|
ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
|
|
19
53
|
|
|
54
|
+
# def combined_hmm_raw_clustermap(
|
|
55
|
+
# adata,
|
|
56
|
+
# sample_col='Sample_Names',
|
|
57
|
+
# reference_col='Reference_strand',
|
|
58
|
+
# hmm_feature_layer="hmm_combined",
|
|
59
|
+
# layer_gpc="nan0_0minus1",
|
|
60
|
+
# layer_cpg="nan0_0minus1",
|
|
61
|
+
# layer_any_c="nan0_0minus1",
|
|
62
|
+
# cmap_hmm="tab10",
|
|
63
|
+
# cmap_gpc="coolwarm",
|
|
64
|
+
# cmap_cpg="viridis",
|
|
65
|
+
# cmap_any_c='coolwarm',
|
|
66
|
+
# min_quality=20,
|
|
67
|
+
# min_length=200,
|
|
68
|
+
# min_mapped_length_to_reference_length_ratio=0.8,
|
|
69
|
+
# min_position_valid_fraction=0.5,
|
|
70
|
+
# sample_mapping=None,
|
|
71
|
+
# save_path=None,
|
|
72
|
+
# normalize_hmm=False,
|
|
73
|
+
# sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
|
|
74
|
+
# bins=None,
|
|
75
|
+
# deaminase=False,
|
|
76
|
+
# min_signal=0
|
|
77
|
+
# ):
|
|
78
|
+
|
|
79
|
+
# results = []
|
|
80
|
+
# if deaminase:
|
|
81
|
+
# signal_type = 'deamination'
|
|
82
|
+
# else:
|
|
83
|
+
# signal_type = 'methylation'
|
|
84
|
+
|
|
85
|
+
# for ref in adata.obs[reference_col].cat.categories:
|
|
86
|
+
# for sample in adata.obs[sample_col].cat.categories:
|
|
87
|
+
# try:
|
|
88
|
+
# subset = adata[
|
|
89
|
+
# (adata.obs[reference_col] == ref) &
|
|
90
|
+
# (adata.obs[sample_col] == sample) &
|
|
91
|
+
# (adata.obs['read_quality'] >= min_quality) &
|
|
92
|
+
# (adata.obs['read_length'] >= min_length) &
|
|
93
|
+
# (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
|
|
94
|
+
# ]
|
|
95
|
+
|
|
96
|
+
# mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
|
|
97
|
+
# subset = subset[:, mask]
|
|
98
|
+
|
|
99
|
+
# if subset.shape[0] == 0:
|
|
100
|
+
# print(f" No reads left after filtering for {sample} - {ref}")
|
|
101
|
+
# continue
|
|
102
|
+
|
|
103
|
+
# if bins:
|
|
104
|
+
# print(f"Using defined bins to subset clustermap for {sample} - {ref}")
|
|
105
|
+
# bins_temp = bins
|
|
106
|
+
# else:
|
|
107
|
+
# print(f"Using all reads for clustermap for {sample} - {ref}")
|
|
108
|
+
# bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
|
|
109
|
+
|
|
110
|
+
# # Get column positions (not var_names!) of site masks
|
|
111
|
+
# gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
|
|
112
|
+
# cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
|
|
113
|
+
# any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
|
|
114
|
+
# num_gpc = len(gpc_sites)
|
|
115
|
+
# num_cpg = len(cpg_sites)
|
|
116
|
+
# num_c = len(any_c_sites)
|
|
117
|
+
# print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
|
|
118
|
+
|
|
119
|
+
# # Use var_names for x-axis tick labels
|
|
120
|
+
# gpc_labels = subset.var_names[gpc_sites].astype(int)
|
|
121
|
+
# cpg_labels = subset.var_names[cpg_sites].astype(int)
|
|
122
|
+
# any_c_labels = subset.var_names[any_c_sites].astype(int)
|
|
123
|
+
|
|
124
|
+
# stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
|
|
125
|
+
# row_labels, bin_labels = [], []
|
|
126
|
+
# bin_boundaries = []
|
|
127
|
+
|
|
128
|
+
# total_reads = subset.shape[0]
|
|
129
|
+
# percentages = {}
|
|
130
|
+
# last_idx = 0
|
|
131
|
+
|
|
132
|
+
# for bin_label, bin_filter in bins_temp.items():
|
|
133
|
+
# subset_bin = subset[bin_filter].copy()
|
|
134
|
+
# num_reads = subset_bin.shape[0]
|
|
135
|
+
# print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
|
|
136
|
+
# percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
|
|
137
|
+
# percentages[bin_label] = percent_reads
|
|
138
|
+
|
|
139
|
+
# if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
|
|
140
|
+
# # Determine sorting order
|
|
141
|
+
# if sort_by.startswith("obs:"):
|
|
142
|
+
# colname = sort_by.split("obs:")[1]
|
|
143
|
+
# order = np.argsort(subset_bin.obs[colname].values)
|
|
144
|
+
# elif sort_by == "gpc":
|
|
145
|
+
# linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
|
|
146
|
+
# order = sch.leaves_list(linkage)
|
|
147
|
+
# elif sort_by == "cpg":
|
|
148
|
+
# linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
149
|
+
# order = sch.leaves_list(linkage)
|
|
150
|
+
# elif sort_by == "gpc_cpg":
|
|
151
|
+
# linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
|
|
152
|
+
# order = sch.leaves_list(linkage)
|
|
153
|
+
# elif sort_by == "none":
|
|
154
|
+
# order = np.arange(num_reads)
|
|
155
|
+
# elif sort_by == "any_c":
|
|
156
|
+
# linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
|
|
157
|
+
# order = sch.leaves_list(linkage)
|
|
158
|
+
# else:
|
|
159
|
+
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
160
|
+
|
|
161
|
+
# stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
|
|
162
|
+
# stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
|
|
163
|
+
# stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
|
|
164
|
+
# stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
|
|
165
|
+
|
|
166
|
+
# row_labels.extend([bin_label] * num_reads)
|
|
167
|
+
# bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
168
|
+
# last_idx += num_reads
|
|
169
|
+
# bin_boundaries.append(last_idx)
|
|
170
|
+
|
|
171
|
+
# if stacked_hmm_feature:
|
|
172
|
+
# hmm_matrix = np.vstack(stacked_hmm_feature)
|
|
173
|
+
# gpc_matrix = np.vstack(stacked_gpc)
|
|
174
|
+
# cpg_matrix = np.vstack(stacked_cpg)
|
|
175
|
+
# any_c_matrix = np.vstack(stacked_any_c)
|
|
176
|
+
|
|
177
|
+
# if hmm_matrix.size > 0:
|
|
178
|
+
# def normalized_mean(matrix):
|
|
179
|
+
# mean = np.nanmean(matrix, axis=0)
|
|
180
|
+
# normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
|
|
181
|
+
# return normalized
|
|
182
|
+
|
|
183
|
+
# def methylation_fraction(matrix):
|
|
184
|
+
# methylated = (matrix == 1).sum(axis=0)
|
|
185
|
+
# valid = (matrix != 0).sum(axis=0)
|
|
186
|
+
# return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
|
|
187
|
+
|
|
188
|
+
# if normalize_hmm:
|
|
189
|
+
# mean_hmm = normalized_mean(hmm_matrix)
|
|
190
|
+
# else:
|
|
191
|
+
# mean_hmm = np.nanmean(hmm_matrix, axis=0)
|
|
192
|
+
# mean_gpc = methylation_fraction(gpc_matrix)
|
|
193
|
+
# mean_cpg = methylation_fraction(cpg_matrix)
|
|
194
|
+
# mean_any_c = methylation_fraction(any_c_matrix)
|
|
195
|
+
|
|
196
|
+
# fig = plt.figure(figsize=(18, 12))
|
|
197
|
+
# gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
|
|
198
|
+
# fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
|
|
199
|
+
|
|
200
|
+
# axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
|
|
201
|
+
# axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
|
|
202
|
+
|
|
203
|
+
# clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
|
|
204
|
+
# clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
|
|
205
|
+
# clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
|
|
206
|
+
# clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
|
|
207
|
+
|
|
208
|
+
# hmm_labels = subset.var_names.astype(int)
|
|
209
|
+
# hmm_label_spacing = 150
|
|
210
|
+
# sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
|
|
211
|
+
# axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
|
|
212
|
+
# axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
|
|
213
|
+
# for boundary in bin_boundaries[:-1]:
|
|
214
|
+
# axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
|
|
215
|
+
|
|
216
|
+
# sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
|
|
217
|
+
# axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
|
|
218
|
+
# axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
|
|
219
|
+
# for boundary in bin_boundaries[:-1]:
|
|
220
|
+
# axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
|
|
221
|
+
|
|
222
|
+
# sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
|
|
223
|
+
# axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
|
|
224
|
+
# for boundary in bin_boundaries[:-1]:
|
|
225
|
+
# axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
|
|
226
|
+
|
|
227
|
+
# sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
|
|
228
|
+
# axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
|
|
229
|
+
# axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
|
|
230
|
+
# for boundary in bin_boundaries[:-1]:
|
|
231
|
+
# axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
|
|
232
|
+
|
|
233
|
+
# plt.tight_layout()
|
|
234
|
+
|
|
235
|
+
# if save_path:
|
|
236
|
+
# save_name = f"{ref} — {sample}"
|
|
237
|
+
# os.makedirs(save_path, exist_ok=True)
|
|
238
|
+
# safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
239
|
+
# out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
240
|
+
# plt.savefig(out_file, dpi=300)
|
|
241
|
+
# print(f"Saved: {out_file}")
|
|
242
|
+
# plt.close()
|
|
243
|
+
# else:
|
|
244
|
+
# plt.show()
|
|
245
|
+
|
|
246
|
+
# print(f"Summary for {sample} - {ref}:")
|
|
247
|
+
# for bin_label, percent in percentages.items():
|
|
248
|
+
# print(f" - {bin_label}: {percent:.1f}%")
|
|
249
|
+
|
|
250
|
+
# results.append({
|
|
251
|
+
# "sample": sample,
|
|
252
|
+
# "ref": ref,
|
|
253
|
+
# "hmm_matrix": hmm_matrix,
|
|
254
|
+
# "gpc_matrix": gpc_matrix,
|
|
255
|
+
# "cpg_matrix": cpg_matrix,
|
|
256
|
+
# "row_labels": row_labels,
|
|
257
|
+
# "bin_labels": bin_labels,
|
|
258
|
+
# "bin_boundaries": bin_boundaries,
|
|
259
|
+
# "percentages": percentages
|
|
260
|
+
# })
|
|
261
|
+
|
|
262
|
+
# #adata.uns['clustermap_results'] = results
|
|
263
|
+
|
|
264
|
+
# except Exception as e:
|
|
265
|
+
# import traceback
|
|
266
|
+
# traceback.print_exc()
|
|
267
|
+
# continue
|
|
268
|
+
|
|
20
269
|
|
|
21
270
|
def combined_hmm_raw_clustermap(
|
|
22
271
|
adata,
|
|
23
|
-
sample_col=
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
272
|
+
sample_col: str = "Sample_Names",
|
|
273
|
+
reference_col: str = "Reference_strand",
|
|
274
|
+
|
|
275
|
+
hmm_feature_layer: str = "hmm_combined",
|
|
276
|
+
|
|
277
|
+
layer_gpc: str = "nan0_0minus1",
|
|
278
|
+
layer_cpg: str = "nan0_0minus1",
|
|
279
|
+
layer_any_c: str = "nan0_0minus1",
|
|
280
|
+
layer_a: str = "nan0_0minus1",
|
|
281
|
+
|
|
282
|
+
cmap_hmm: str = "tab10",
|
|
283
|
+
cmap_gpc: str = "coolwarm",
|
|
284
|
+
cmap_cpg: str = "viridis",
|
|
285
|
+
cmap_any_c: str = "coolwarm",
|
|
286
|
+
cmap_a: str = "coolwarm",
|
|
287
|
+
|
|
288
|
+
min_quality: int = 20,
|
|
289
|
+
min_length: int = 200,
|
|
290
|
+
min_mapped_length_to_reference_length_ratio: float = 0.8,
|
|
291
|
+
min_position_valid_fraction: float = 0.5,
|
|
292
|
+
|
|
293
|
+
save_path: str | Path | None = None,
|
|
294
|
+
normalize_hmm: bool = False,
|
|
295
|
+
|
|
296
|
+
sort_by: str = "gpc",
|
|
297
|
+
bins: Optional[Dict[str, Any]] = None,
|
|
298
|
+
|
|
299
|
+
deaminase: bool = False,
|
|
300
|
+
min_signal: float = 0.0,
|
|
301
|
+
|
|
302
|
+
# ---- fixed tick label controls (counts, not spacing)
|
|
303
|
+
n_xticks_hmm: int = 10,
|
|
304
|
+
n_xticks_any_c: int = 8,
|
|
305
|
+
n_xticks_gpc: int = 8,
|
|
306
|
+
n_xticks_cpg: int = 8,
|
|
307
|
+
n_xticks_a: int = 8,
|
|
37
308
|
):
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
import matplotlib.gridspec as gridspec
|
|
44
|
-
import os
|
|
309
|
+
"""
|
|
310
|
+
Makes a multi-panel clustermap per (sample, reference):
|
|
311
|
+
HMM panel (always) + optional raw panels for any_C, GpC, CpG, and A sites.
|
|
312
|
+
|
|
313
|
+
Panels are added only if the corresponding site mask exists AND has >0 sites.
|
|
45
314
|
|
|
315
|
+
sort_by options:
|
|
316
|
+
'gpc', 'cpg', 'any_c', 'any_a', 'gpc_cpg', 'none', or 'obs:<col>'
|
|
317
|
+
"""
|
|
318
|
+
def pick_xticks(labels: np.ndarray, n_ticks: int):
|
|
319
|
+
if labels.size == 0:
|
|
320
|
+
return [], []
|
|
321
|
+
idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
|
|
322
|
+
idx = np.unique(idx)
|
|
323
|
+
return idx.tolist(), labels[idx].tolist()
|
|
324
|
+
|
|
46
325
|
results = []
|
|
326
|
+
signal_type = "deamination" if deaminase else "methylation"
|
|
47
327
|
|
|
48
|
-
for ref in adata.obs[
|
|
328
|
+
for ref in adata.obs[reference_col].cat.categories:
|
|
49
329
|
for sample in adata.obs[sample_col].cat.categories:
|
|
330
|
+
|
|
50
331
|
try:
|
|
332
|
+
# ---- subset reads ----
|
|
51
333
|
subset = adata[
|
|
52
|
-
(adata.obs[
|
|
334
|
+
(adata.obs[reference_col] == ref) &
|
|
53
335
|
(adata.obs[sample_col] == sample) &
|
|
54
|
-
(adata.obs[
|
|
55
|
-
(adata.obs[
|
|
56
|
-
(
|
|
336
|
+
(adata.obs["read_quality"] >= min_quality) &
|
|
337
|
+
(adata.obs["read_length"] >= min_length) &
|
|
338
|
+
(
|
|
339
|
+
adata.obs["mapped_length_to_reference_length_ratio"]
|
|
340
|
+
> min_mapped_length_to_reference_length_ratio
|
|
341
|
+
)
|
|
57
342
|
]
|
|
343
|
+
|
|
344
|
+
# ---- valid fraction filter ----
|
|
345
|
+
vf_key = f"{ref}_valid_fraction"
|
|
346
|
+
if vf_key in subset.var:
|
|
347
|
+
mask = subset.var[vf_key].astype(float) > float(min_position_valid_fraction)
|
|
348
|
+
subset = subset[:, mask]
|
|
349
|
+
|
|
350
|
+
if subset.shape[0] == 0:
|
|
351
|
+
continue
|
|
352
|
+
|
|
353
|
+
# ---- bins ----
|
|
354
|
+
if bins is None:
|
|
355
|
+
bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
|
|
356
|
+
else:
|
|
357
|
+
bins_temp = bins
|
|
358
|
+
|
|
359
|
+
# ---- site masks (robust) ----
|
|
360
|
+
def _sites(*keys):
|
|
361
|
+
for k in keys:
|
|
362
|
+
if k in subset.var:
|
|
363
|
+
return np.where(subset.var[k].values)[0]
|
|
364
|
+
return np.array([], dtype=int)
|
|
365
|
+
|
|
366
|
+
gpc_sites = _sites(f"{ref}_GpC_site")
|
|
367
|
+
cpg_sites = _sites(f"{ref}_CpG_site")
|
|
368
|
+
any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
|
|
369
|
+
any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
|
|
370
|
+
|
|
371
|
+
def _labels(sites):
|
|
372
|
+
return subset.var_names[sites].astype(int) if sites.size else np.array([])
|
|
373
|
+
|
|
374
|
+
gpc_labels = _labels(gpc_sites)
|
|
375
|
+
cpg_labels = _labels(cpg_sites)
|
|
376
|
+
any_c_labels = _labels(any_c_sites)
|
|
377
|
+
any_a_labels = _labels(any_a_sites)
|
|
378
|
+
|
|
379
|
+
# storage
|
|
380
|
+
stacked_hmm = []
|
|
381
|
+
stacked_any_c = []
|
|
382
|
+
stacked_gpc = []
|
|
383
|
+
stacked_cpg = []
|
|
384
|
+
stacked_any_a = []
|
|
385
|
+
|
|
386
|
+
row_labels, bin_labels, bin_boundaries = [], [], []
|
|
387
|
+
total_reads = subset.n_obs
|
|
388
|
+
percentages = {}
|
|
389
|
+
last_idx = 0
|
|
390
|
+
|
|
391
|
+
# ---------------- process bins ----------------
|
|
392
|
+
for bin_label, bin_filter in bins_temp.items():
|
|
393
|
+
sb = subset[bin_filter].copy()
|
|
394
|
+
n = sb.n_obs
|
|
395
|
+
if n == 0:
|
|
396
|
+
continue
|
|
397
|
+
|
|
398
|
+
pct = (n / total_reads) * 100 if total_reads else 0
|
|
399
|
+
percentages[bin_label] = pct
|
|
400
|
+
|
|
401
|
+
# ---- sorting ----
|
|
402
|
+
if sort_by.startswith("obs:"):
|
|
403
|
+
colname = sort_by.split("obs:")[1]
|
|
404
|
+
order = np.argsort(sb.obs[colname].values)
|
|
405
|
+
|
|
406
|
+
elif sort_by == "gpc" and gpc_sites.size:
|
|
407
|
+
linkage = sch.linkage(sb[:, gpc_sites].layers[layer_gpc], method="ward")
|
|
408
|
+
order = sch.leaves_list(linkage)
|
|
409
|
+
|
|
410
|
+
elif sort_by == "cpg" and cpg_sites.size:
|
|
411
|
+
linkage = sch.linkage(sb[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
412
|
+
order = sch.leaves_list(linkage)
|
|
413
|
+
|
|
414
|
+
elif sort_by == "any_c" and any_c_sites.size:
|
|
415
|
+
linkage = sch.linkage(sb[:, any_c_sites].layers[layer_any_c], method="ward")
|
|
416
|
+
order = sch.leaves_list(linkage)
|
|
417
|
+
|
|
418
|
+
elif sort_by == "any_a" and any_a_sites.size:
|
|
419
|
+
linkage = sch.linkage(sb[:, any_a_sites].layers[layer_a], method="ward")
|
|
420
|
+
order = sch.leaves_list(linkage)
|
|
421
|
+
|
|
422
|
+
elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
|
|
423
|
+
linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
|
|
424
|
+
order = sch.leaves_list(linkage)
|
|
425
|
+
|
|
426
|
+
else:
|
|
427
|
+
order = np.arange(n)
|
|
428
|
+
|
|
429
|
+
sb = sb[order]
|
|
430
|
+
|
|
431
|
+
# ---- collect matrices ----
|
|
432
|
+
stacked_hmm.append(sb.layers[hmm_feature_layer])
|
|
433
|
+
if any_c_sites.size:
|
|
434
|
+
stacked_any_c.append(sb[:, any_c_sites].layers[layer_any_c])
|
|
435
|
+
if gpc_sites.size:
|
|
436
|
+
stacked_gpc.append(sb[:, gpc_sites].layers[layer_gpc])
|
|
437
|
+
if cpg_sites.size:
|
|
438
|
+
stacked_cpg.append(sb[:, cpg_sites].layers[layer_cpg])
|
|
439
|
+
if any_a_sites.size:
|
|
440
|
+
stacked_any_a.append(sb[:, any_a_sites].layers[layer_a])
|
|
441
|
+
|
|
442
|
+
row_labels.extend([bin_label] * n)
|
|
443
|
+
bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
|
|
444
|
+
last_idx += n
|
|
445
|
+
bin_boundaries.append(last_idx)
|
|
446
|
+
|
|
447
|
+
# ---------------- stack ----------------
|
|
448
|
+
hmm_matrix = np.vstack(stacked_hmm)
|
|
449
|
+
mean_hmm = normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
|
|
450
|
+
|
|
451
|
+
panels = [
|
|
452
|
+
("HMM", hmm_matrix, subset.var_names.astype(int), cmap_hmm, mean_hmm, n_xticks_hmm),
|
|
453
|
+
]
|
|
454
|
+
|
|
455
|
+
if stacked_any_c:
|
|
456
|
+
m = np.vstack(stacked_any_c)
|
|
457
|
+
panels.append(("any_C", m, any_c_labels, cmap_any_c, methylation_fraction(m), n_xticks_any_c))
|
|
458
|
+
|
|
459
|
+
if stacked_gpc:
|
|
460
|
+
m = np.vstack(stacked_gpc)
|
|
461
|
+
panels.append(("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc))
|
|
462
|
+
|
|
463
|
+
if stacked_cpg:
|
|
464
|
+
m = np.vstack(stacked_cpg)
|
|
465
|
+
panels.append(("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg))
|
|
466
|
+
|
|
467
|
+
if stacked_any_a:
|
|
468
|
+
m = np.vstack(stacked_any_a)
|
|
469
|
+
panels.append(("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a))
|
|
470
|
+
|
|
471
|
+
# ---------------- plotting ----------------
|
|
472
|
+
n_panels = len(panels)
|
|
473
|
+
fig = plt.figure(figsize=(4.5 * n_panels, 10))
|
|
474
|
+
gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
|
|
475
|
+
fig.suptitle(f"{sample} — {ref} — {total_reads} reads ({signal_type})",
|
|
476
|
+
fontsize=14, y=0.98)
|
|
477
|
+
|
|
478
|
+
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
|
|
479
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
|
|
480
|
+
|
|
481
|
+
for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
|
|
482
|
+
|
|
483
|
+
# ---- your clean barplot ----
|
|
484
|
+
clean_barplot(axes_bar[i], mean_vec, name)
|
|
485
|
+
|
|
486
|
+
# ---- heatmap ----
|
|
487
|
+
sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
|
|
488
|
+
yticklabels=False, cbar=False)
|
|
489
|
+
|
|
490
|
+
# ---- xticks ----
|
|
491
|
+
xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
|
|
492
|
+
axes_heat[i].set_xticks(xtick_pos)
|
|
493
|
+
axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)
|
|
494
|
+
|
|
495
|
+
for boundary in bin_boundaries[:-1]:
|
|
496
|
+
axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)
|
|
497
|
+
|
|
498
|
+
plt.tight_layout()
|
|
499
|
+
|
|
500
|
+
if save_path:
|
|
501
|
+
save_path = Path(save_path)
|
|
502
|
+
save_path.mkdir(parents=True, exist_ok=True)
|
|
503
|
+
safe_name = f"{ref}__{sample}".replace("/", "_")
|
|
504
|
+
out_file = save_path / f"{safe_name}.png"
|
|
505
|
+
plt.savefig(out_file, dpi=300)
|
|
506
|
+
plt.close(fig)
|
|
507
|
+
else:
|
|
508
|
+
plt.show()
|
|
509
|
+
|
|
510
|
+
except Exception:
|
|
511
|
+
import traceback
|
|
512
|
+
traceback.print_exc()
|
|
513
|
+
continue
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
# def combined_raw_clustermap(
|
|
517
|
+
# adata,
|
|
518
|
+
# sample_col='Sample_Names',
|
|
519
|
+
# reference_col='Reference_strand',
|
|
520
|
+
# mod_target_bases=['GpC', 'CpG'],
|
|
521
|
+
# layer_any_c="nan0_0minus1",
|
|
522
|
+
# layer_gpc="nan0_0minus1",
|
|
523
|
+
# layer_cpg="nan0_0minus1",
|
|
524
|
+
# layer_a="nan0_0minus1",
|
|
525
|
+
# cmap_any_c="coolwarm",
|
|
526
|
+
# cmap_gpc="coolwarm",
|
|
527
|
+
# cmap_cpg="viridis",
|
|
528
|
+
# cmap_a="coolwarm",
|
|
529
|
+
# min_quality=20,
|
|
530
|
+
# min_length=200,
|
|
531
|
+
# min_mapped_length_to_reference_length_ratio=0.8,
|
|
532
|
+
# min_position_valid_fraction=0.5,
|
|
533
|
+
# sample_mapping=None,
|
|
534
|
+
# save_path=None,
|
|
535
|
+
# sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', 'any_a', or 'obs:<column>'
|
|
536
|
+
# bins=None,
|
|
537
|
+
# deaminase=False,
|
|
538
|
+
# min_signal=0
|
|
539
|
+
# ):
|
|
540
|
+
|
|
541
|
+
# results = []
|
|
542
|
+
|
|
543
|
+
# for ref in adata.obs[reference_col].cat.categories:
|
|
544
|
+
# for sample in adata.obs[sample_col].cat.categories:
|
|
545
|
+
# try:
|
|
546
|
+
# subset = adata[
|
|
547
|
+
# (adata.obs[reference_col] == ref) &
|
|
548
|
+
# (adata.obs[sample_col] == sample) &
|
|
549
|
+
# (adata.obs['read_quality'] >= min_quality) &
|
|
550
|
+
# (adata.obs['mapped_length'] >= min_length) &
|
|
551
|
+
# (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
|
|
552
|
+
# ]
|
|
553
|
+
|
|
554
|
+
# mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
|
|
555
|
+
# subset = subset[:, mask]
|
|
556
|
+
|
|
557
|
+
# if subset.shape[0] == 0:
|
|
558
|
+
# print(f" No reads left after filtering for {sample} - {ref}")
|
|
559
|
+
# continue
|
|
560
|
+
|
|
561
|
+
# if bins:
|
|
562
|
+
# print(f"Using defined bins to subset clustermap for {sample} - {ref}")
|
|
563
|
+
# bins_temp = bins
|
|
564
|
+
# else:
|
|
565
|
+
# print(f"Using all reads for clustermap for {sample} - {ref}")
|
|
566
|
+
# bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
|
|
567
|
+
|
|
568
|
+
# num_any_c = 0
|
|
569
|
+
# num_gpc = 0
|
|
570
|
+
# num_cpg = 0
|
|
571
|
+
# num_any_a = 0
|
|
572
|
+
|
|
573
|
+
# # Get column positions (not var_names!) of site masks
|
|
574
|
+
# if any(base in ["C", "CpG", "GpC"] for base in mod_target_bases):
|
|
575
|
+
# any_c_sites = np.where(subset.var[f"{ref}_C_site"].values)[0]
|
|
576
|
+
# gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
|
|
577
|
+
# cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
|
|
578
|
+
# num_any_c = len(any_c_sites)
|
|
579
|
+
# num_gpc = len(gpc_sites)
|
|
580
|
+
# num_cpg = len(cpg_sites)
|
|
581
|
+
# print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
|
|
582
|
+
|
|
583
|
+
# # Use var_names for x-axis tick labels
|
|
584
|
+
# gpc_labels = subset.var_names[gpc_sites].astype(int)
|
|
585
|
+
# cpg_labels = subset.var_names[cpg_sites].astype(int)
|
|
586
|
+
# any_c_labels = subset.var_names[any_c_sites].astype(int)
|
|
587
|
+
# stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
|
|
588
|
+
|
|
589
|
+
# if "A" in mod_target_bases:
|
|
590
|
+
# any_a_sites = np.where(subset.var[f"{ref}_A_site"].values)[0]
|
|
591
|
+
# num_any_a = len(any_a_sites)
|
|
592
|
+
# print(f"Found {num_any_a} any_A sites at {any_a_sites} for {sample} - {ref}")
|
|
593
|
+
# any_a_labels = subset.var_names[any_a_sites].astype(int)
|
|
594
|
+
# stacked_any_a = []
|
|
595
|
+
|
|
596
|
+
# row_labels, bin_labels = [], []
|
|
597
|
+
# bin_boundaries = []
|
|
598
|
+
|
|
599
|
+
# total_reads = subset.shape[0]
|
|
600
|
+
# percentages = {}
|
|
601
|
+
# last_idx = 0
|
|
602
|
+
|
|
603
|
+
# for bin_label, bin_filter in bins_temp.items():
|
|
604
|
+
# subset_bin = subset[bin_filter].copy()
|
|
605
|
+
# num_reads = subset_bin.shape[0]
|
|
606
|
+
# print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
|
|
607
|
+
# percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
|
|
608
|
+
# percentages[bin_label] = percent_reads
|
|
609
|
+
|
|
610
|
+
# if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
|
|
611
|
+
# # Determine sorting order
|
|
612
|
+
# if sort_by.startswith("obs:"):
|
|
613
|
+
# colname = sort_by.split("obs:")[1]
|
|
614
|
+
# order = np.argsort(subset_bin.obs[colname].values)
|
|
615
|
+
# elif sort_by == "gpc":
|
|
616
|
+
# linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
|
|
617
|
+
# order = sch.leaves_list(linkage)
|
|
618
|
+
# elif sort_by == "cpg":
|
|
619
|
+
# linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
620
|
+
# order = sch.leaves_list(linkage)
|
|
621
|
+
# elif sort_by == "any_c":
|
|
622
|
+
# linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
|
|
623
|
+
# order = sch.leaves_list(linkage)
|
|
624
|
+
# elif sort_by == "gpc_cpg":
|
|
625
|
+
# linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
|
|
626
|
+
# order = sch.leaves_list(linkage)
|
|
627
|
+
# elif sort_by == "none":
|
|
628
|
+
# order = np.arange(num_reads)
|
|
629
|
+
# elif sort_by == "any_a":
|
|
630
|
+
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
631
|
+
# order = sch.leaves_list(linkage)
|
|
632
|
+
# else:
|
|
633
|
+
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
634
|
+
|
|
635
|
+
# stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
|
|
636
|
+
# stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
|
|
637
|
+
# stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
|
|
638
|
+
|
|
639
|
+
# if num_reads > 0 and num_any_a > 0:
|
|
640
|
+
# # Determine sorting order
|
|
641
|
+
# if sort_by.startswith("obs:"):
|
|
642
|
+
# colname = sort_by.split("obs:")[1]
|
|
643
|
+
# order = np.argsort(subset_bin.obs[colname].values)
|
|
644
|
+
# elif sort_by == "gpc":
|
|
645
|
+
# linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
|
|
646
|
+
# order = sch.leaves_list(linkage)
|
|
647
|
+
# elif sort_by == "cpg":
|
|
648
|
+
# linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
649
|
+
# order = sch.leaves_list(linkage)
|
|
650
|
+
# elif sort_by == "any_c":
|
|
651
|
+
# linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
|
|
652
|
+
# order = sch.leaves_list(linkage)
|
|
653
|
+
# elif sort_by == "gpc_cpg":
|
|
654
|
+
# linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
|
|
655
|
+
# order = sch.leaves_list(linkage)
|
|
656
|
+
# elif sort_by == "none":
|
|
657
|
+
# order = np.arange(num_reads)
|
|
658
|
+
# elif sort_by == "any_a":
|
|
659
|
+
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
660
|
+
# order = sch.leaves_list(linkage)
|
|
661
|
+
# else:
|
|
662
|
+
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
663
|
+
|
|
664
|
+
# stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
# row_labels.extend([bin_label] * num_reads)
|
|
668
|
+
# bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
669
|
+
# last_idx += num_reads
|
|
670
|
+
# bin_boundaries.append(last_idx)
|
|
671
|
+
|
|
672
|
+
# gs_dim = 0
|
|
673
|
+
|
|
674
|
+
# if stacked_any_c:
|
|
675
|
+
# any_c_matrix = np.vstack(stacked_any_c)
|
|
676
|
+
# gpc_matrix = np.vstack(stacked_gpc)
|
|
677
|
+
# cpg_matrix = np.vstack(stacked_cpg)
|
|
678
|
+
# if any_c_matrix.size > 0:
|
|
679
|
+
# mean_gpc = methylation_fraction(gpc_matrix)
|
|
680
|
+
# mean_cpg = methylation_fraction(cpg_matrix)
|
|
681
|
+
# mean_any_c = methylation_fraction(any_c_matrix)
|
|
682
|
+
# gs_dim += 3
|
|
683
|
+
|
|
684
|
+
# if stacked_any_a:
|
|
685
|
+
# any_a_matrix = np.vstack(stacked_any_a)
|
|
686
|
+
# if any_a_matrix.size > 0:
|
|
687
|
+
# mean_any_a = methylation_fraction(any_a_matrix)
|
|
688
|
+
# gs_dim += 1
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
# fig = plt.figure(figsize=(18, 12))
|
|
692
|
+
# gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
|
|
693
|
+
# fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
|
|
694
|
+
# axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
|
|
695
|
+
# axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
|
|
696
|
+
|
|
697
|
+
# current_ax = 0
|
|
698
|
+
|
|
699
|
+
# if stacked_any_c:
|
|
700
|
+
# if any_c_matrix.size > 0:
|
|
701
|
+
# clean_barplot(axes_bar[current_ax], mean_any_c, f"any C site Modification Signal")
|
|
702
|
+
# sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[current_ax], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
|
|
703
|
+
# axes_heat[current_ax].set_xticks(range(0, len(any_c_labels), 20))
|
|
704
|
+
# axes_heat[current_ax].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
|
|
705
|
+
# for boundary in bin_boundaries[:-1]:
|
|
706
|
+
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
707
|
+
# current_ax +=1
|
|
708
|
+
|
|
709
|
+
# clean_barplot(axes_bar[current_ax], mean_gpc, f"GpC Modification Signal")
|
|
710
|
+
# sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[current_ax], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
|
|
711
|
+
# axes_heat[current_ax].set_xticks(range(0, len(gpc_labels), 5))
|
|
712
|
+
# axes_heat[current_ax].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
|
|
713
|
+
# for boundary in bin_boundaries[:-1]:
|
|
714
|
+
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
715
|
+
# current_ax +=1
|
|
716
|
+
|
|
717
|
+
# clean_barplot(axes_bar[current_ax], mean_cpg, f"CpG Modification Signal")
|
|
718
|
+
# sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
|
|
719
|
+
# axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
|
|
720
|
+
# for boundary in bin_boundaries[:-1]:
|
|
721
|
+
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
722
|
+
# current_ax +=1
|
|
723
|
+
|
|
724
|
+
# results.append({
|
|
725
|
+
# "sample": sample,
|
|
726
|
+
# "ref": ref,
|
|
727
|
+
# "any_c_matrix": any_c_matrix,
|
|
728
|
+
# "gpc_matrix": gpc_matrix,
|
|
729
|
+
# "cpg_matrix": cpg_matrix,
|
|
730
|
+
# "row_labels": row_labels,
|
|
731
|
+
# "bin_labels": bin_labels,
|
|
732
|
+
# "bin_boundaries": bin_boundaries,
|
|
733
|
+
# "percentages": percentages
|
|
734
|
+
# })
|
|
735
|
+
|
|
736
|
+
# if stacked_any_a:
|
|
737
|
+
# if any_a_matrix.size > 0:
|
|
738
|
+
# clean_barplot(axes_bar[current_ax], mean_any_a, f"any A site Modification Signal")
|
|
739
|
+
# sns.heatmap(any_a_matrix, cmap=cmap_a, ax=axes_heat[current_ax], xticklabels=any_a_labels[::20], yticklabels=False, cbar=False)
|
|
740
|
+
# axes_heat[current_ax].set_xticks(range(0, len(any_a_labels), 20))
|
|
741
|
+
# axes_heat[current_ax].set_xticklabels(any_a_labels[::20], rotation=90, fontsize=10)
|
|
742
|
+
# for boundary in bin_boundaries[:-1]:
|
|
743
|
+
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
744
|
+
# current_ax +=1
|
|
745
|
+
|
|
746
|
+
# results.append({
|
|
747
|
+
# "sample": sample,
|
|
748
|
+
# "ref": ref,
|
|
749
|
+
# "any_a_matrix": any_a_matrix,
|
|
750
|
+
# "row_labels": row_labels,
|
|
751
|
+
# "bin_labels": bin_labels,
|
|
752
|
+
# "bin_boundaries": bin_boundaries,
|
|
753
|
+
# "percentages": percentages
|
|
754
|
+
# })
|
|
755
|
+
|
|
756
|
+
# plt.tight_layout()
|
|
757
|
+
|
|
758
|
+
# if save_path:
|
|
759
|
+
# save_name = f"{ref} — {sample}"
|
|
760
|
+
# os.makedirs(save_path, exist_ok=True)
|
|
761
|
+
# safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
762
|
+
# out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
763
|
+
# plt.savefig(out_file, dpi=300)
|
|
764
|
+
# print(f"Saved: {out_file}")
|
|
765
|
+
# plt.close()
|
|
766
|
+
# else:
|
|
767
|
+
# plt.show()
|
|
768
|
+
|
|
769
|
+
# print(f"Summary for {sample} - {ref}:")
|
|
770
|
+
# for bin_label, percent in percentages.items():
|
|
771
|
+
# print(f" - {bin_label}: {percent:.1f}%")
|
|
58
772
|
|
|
59
|
-
|
|
773
|
+
# adata.uns['clustermap_results'] = results
|
|
774
|
+
|
|
775
|
+
# except Exception as e:
|
|
776
|
+
# import traceback
|
|
777
|
+
# traceback.print_exc()
|
|
778
|
+
# continue
|
|
779
|
+
|
|
780
|
+
def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
|
|
781
|
+
"""
|
|
782
|
+
Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
|
|
783
|
+
Always includes 0 and n_positions-1 when possible.
|
|
784
|
+
"""
|
|
785
|
+
n_ticks = int(max(2, n_ticks))
|
|
786
|
+
if n_positions <= n_ticks:
|
|
787
|
+
return np.arange(n_positions)
|
|
788
|
+
|
|
789
|
+
# linspace gives fixed count
|
|
790
|
+
pos = np.linspace(0, n_positions - 1, n_ticks)
|
|
791
|
+
return np.unique(np.round(pos).astype(int))
|
|
792
|
+
|
|
793
|
+
def combined_raw_clustermap(
|
|
794
|
+
adata,
|
|
795
|
+
sample_col: str = "Sample_Names",
|
|
796
|
+
reference_col: str = "Reference_strand",
|
|
797
|
+
mod_target_bases: Sequence[str] = ("GpC", "CpG"),
|
|
798
|
+
layer_any_c: str = "nan0_0minus1",
|
|
799
|
+
layer_gpc: str = "nan0_0minus1",
|
|
800
|
+
layer_cpg: str = "nan0_0minus1",
|
|
801
|
+
layer_a: str = "nan0_0minus1",
|
|
802
|
+
cmap_any_c: str = "coolwarm",
|
|
803
|
+
cmap_gpc: str = "coolwarm",
|
|
804
|
+
cmap_cpg: str = "viridis",
|
|
805
|
+
cmap_a: str = "coolwarm",
|
|
806
|
+
min_quality: float = 20,
|
|
807
|
+
min_length: int = 200,
|
|
808
|
+
min_mapped_length_to_reference_length_ratio: float = 0.8,
|
|
809
|
+
min_position_valid_fraction: float = 0.5,
|
|
810
|
+
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
811
|
+
save_path: str | Path | None = None,
|
|
812
|
+
sort_by: str = "gpc", # 'gpc','cpg','any_c','gpc_cpg','any_a','none','obs:<col>'
|
|
813
|
+
bins: Optional[Dict[str, Any]] = None,
|
|
814
|
+
deaminase: bool = False,
|
|
815
|
+
min_signal: float = 0,
|
|
816
|
+
# NEW tick controls
|
|
817
|
+
n_xticks_any_c: int = 10,
|
|
818
|
+
n_xticks_gpc: int = 10,
|
|
819
|
+
n_xticks_cpg: int = 10,
|
|
820
|
+
n_xticks_any_a: int = 10,
|
|
821
|
+
xtick_rotation: int = 90,
|
|
822
|
+
xtick_fontsize: int = 9,
|
|
823
|
+
):
|
|
824
|
+
"""
|
|
825
|
+
Plot stacked heatmaps + per-position mean barplots for any_C, GpC, CpG, and optional A.
|
|
826
|
+
|
|
827
|
+
Key fixes vs old version:
|
|
828
|
+
- order computed ONCE per bin, applied to all matrices
|
|
829
|
+
- no hard-coded axes indices
|
|
830
|
+
- NaNs excluded from methylation denominators
|
|
831
|
+
- var_names not forced to int
|
|
832
|
+
- fixed count of x tick labels per block (controllable)
|
|
833
|
+
- adata.uns updated once at end
|
|
834
|
+
|
|
835
|
+
Returns
|
|
836
|
+
-------
|
|
837
|
+
results : list[dict]
|
|
838
|
+
One entry per (sample, ref) plot with matrices + bin metadata.
|
|
839
|
+
"""
|
|
840
|
+
|
|
841
|
+
results: List[Dict[str, Any]] = []
|
|
842
|
+
save_path = Path(save_path) if save_path is not None else None
|
|
843
|
+
if save_path is not None:
|
|
844
|
+
save_path.mkdir(parents=True, exist_ok=True)
|
|
845
|
+
|
|
846
|
+
# Ensure categorical
|
|
847
|
+
for col in (sample_col, reference_col):
|
|
848
|
+
if col not in adata.obs:
|
|
849
|
+
raise KeyError(f"{col} not in adata.obs")
|
|
850
|
+
if not pd.api.types.is_categorical_dtype(adata.obs[col]):
|
|
851
|
+
adata.obs[col] = adata.obs[col].astype("category")
|
|
852
|
+
|
|
853
|
+
base_set = set(mod_target_bases)
|
|
854
|
+
include_any_c = any(b in {"C", "CpG", "GpC"} for b in base_set)
|
|
855
|
+
include_any_a = "A" in base_set
|
|
856
|
+
|
|
857
|
+
for ref in adata.obs[reference_col].cat.categories:
|
|
858
|
+
for sample in adata.obs[sample_col].cat.categories:
|
|
859
|
+
|
|
860
|
+
# Optionally remap sample label for display
|
|
861
|
+
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
862
|
+
|
|
863
|
+
try:
|
|
864
|
+
subset = adata[
|
|
865
|
+
(adata.obs[reference_col] == ref) &
|
|
866
|
+
(adata.obs[sample_col] == sample) &
|
|
867
|
+
(adata.obs["read_quality"] >= min_quality) &
|
|
868
|
+
(adata.obs["mapped_length"] >= min_length) &
|
|
869
|
+
(adata.obs["mapped_length_to_reference_length_ratio"] >= min_mapped_length_to_reference_length_ratio)
|
|
870
|
+
]
|
|
871
|
+
|
|
872
|
+
# position-level mask
|
|
873
|
+
valid_key = f"{ref}_valid_fraction"
|
|
874
|
+
if valid_key in subset.var:
|
|
875
|
+
mask = subset.var[valid_key].astype(float).values > float(min_position_valid_fraction)
|
|
876
|
+
subset = subset[:, mask]
|
|
60
877
|
|
|
61
878
|
if subset.shape[0] == 0:
|
|
62
|
-
print(f"
|
|
879
|
+
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
63
880
|
continue
|
|
64
881
|
|
|
65
|
-
|
|
66
|
-
|
|
882
|
+
# bins mode
|
|
883
|
+
if bins is None:
|
|
884
|
+
bins_temp = {"All": (subset.obs[reference_col] == ref)}
|
|
67
885
|
else:
|
|
68
|
-
|
|
886
|
+
bins_temp = bins
|
|
69
887
|
|
|
70
|
-
#
|
|
71
|
-
gpc_sites = np.
|
|
72
|
-
|
|
888
|
+
# find sites (positions)
|
|
889
|
+
any_c_sites = gpc_sites = cpg_sites = np.array([], dtype=int)
|
|
890
|
+
any_a_sites = np.array([], dtype=int)
|
|
73
891
|
|
|
74
|
-
|
|
75
|
-
gpc_labels = subset.var_names[gpc_sites].astype(int)
|
|
76
|
-
cpg_labels = subset.var_names[cpg_sites].astype(int)
|
|
892
|
+
num_any_c = num_gpc = num_cpg = num_any_a = 0
|
|
77
893
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
894
|
+
if include_any_c:
|
|
895
|
+
any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
|
|
896
|
+
gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
|
|
897
|
+
cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
|
|
81
898
|
|
|
82
|
-
|
|
899
|
+
num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
|
|
900
|
+
|
|
901
|
+
any_c_labels = subset.var_names[any_c_sites].astype(str)
|
|
902
|
+
gpc_labels = subset.var_names[gpc_sites].astype(str)
|
|
903
|
+
cpg_labels = subset.var_names[cpg_sites].astype(str)
|
|
904
|
+
|
|
905
|
+
if include_any_a:
|
|
906
|
+
any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
|
|
907
|
+
num_any_a = len(any_a_sites)
|
|
908
|
+
any_a_labels = subset.var_names[any_a_sites].astype(str)
|
|
909
|
+
|
|
910
|
+
stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
|
|
911
|
+
row_labels, bin_labels, bin_boundaries = [], [], []
|
|
83
912
|
percentages = {}
|
|
84
913
|
last_idx = 0
|
|
914
|
+
total_reads = subset.shape[0]
|
|
85
915
|
|
|
86
|
-
|
|
916
|
+
# ----------------------------
|
|
917
|
+
# per-bin stacking
|
|
918
|
+
# ----------------------------
|
|
919
|
+
for bin_label, bin_filter in bins_temp.items():
|
|
87
920
|
subset_bin = subset[bin_filter].copy()
|
|
88
921
|
num_reads = subset_bin.shape[0]
|
|
89
|
-
|
|
922
|
+
if num_reads == 0:
|
|
923
|
+
percentages[bin_label] = 0.0
|
|
924
|
+
continue
|
|
925
|
+
|
|
926
|
+
percent_reads = (num_reads / total_reads) * 100
|
|
90
927
|
percentages[bin_label] = percent_reads
|
|
91
928
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
if stacked_hmm_feature:
|
|
121
|
-
hmm_matrix = np.vstack(stacked_hmm_feature)
|
|
122
|
-
gpc_matrix = np.vstack(stacked_gpc)
|
|
123
|
-
cpg_matrix = np.vstack(stacked_cpg)
|
|
124
|
-
|
|
125
|
-
def normalized_mean(matrix):
|
|
126
|
-
mean = np.nanmean(matrix, axis=0)
|
|
127
|
-
normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
|
|
128
|
-
return normalized
|
|
129
|
-
|
|
130
|
-
def methylation_fraction(matrix):
|
|
131
|
-
methylated = (matrix == 1).sum(axis=0)
|
|
132
|
-
valid = (matrix != 0).sum(axis=0)
|
|
133
|
-
return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
|
|
134
|
-
|
|
135
|
-
if normalize_hmm:
|
|
136
|
-
mean_hmm = normalized_mean(hmm_matrix)
|
|
929
|
+
# compute order ONCE
|
|
930
|
+
if sort_by.startswith("obs:"):
|
|
931
|
+
colname = sort_by.split("obs:")[1]
|
|
932
|
+
order = np.argsort(subset_bin.obs[colname].values)
|
|
933
|
+
|
|
934
|
+
elif sort_by == "gpc" and num_gpc > 0:
|
|
935
|
+
linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
|
|
936
|
+
order = sch.leaves_list(linkage)
|
|
937
|
+
|
|
938
|
+
elif sort_by == "cpg" and num_cpg > 0:
|
|
939
|
+
linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
940
|
+
order = sch.leaves_list(linkage)
|
|
941
|
+
|
|
942
|
+
elif sort_by == "any_c" and num_any_c > 0:
|
|
943
|
+
linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
|
|
944
|
+
order = sch.leaves_list(linkage)
|
|
945
|
+
|
|
946
|
+
elif sort_by == "gpc_cpg":
|
|
947
|
+
linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
|
|
948
|
+
order = sch.leaves_list(linkage)
|
|
949
|
+
|
|
950
|
+
elif sort_by == "any_a" and num_any_a > 0:
|
|
951
|
+
linkage = sch.linkage(subset_bin[:, any_a_sites].layers[layer_a], method="ward")
|
|
952
|
+
order = sch.leaves_list(linkage)
|
|
953
|
+
|
|
954
|
+
elif sort_by == "none":
|
|
955
|
+
order = np.arange(num_reads)
|
|
956
|
+
|
|
137
957
|
else:
|
|
138
|
-
|
|
139
|
-
mean_gpc = methylation_fraction(gpc_matrix)
|
|
140
|
-
mean_cpg = methylation_fraction(cpg_matrix)
|
|
958
|
+
order = np.arange(num_reads)
|
|
141
959
|
|
|
142
|
-
|
|
143
|
-
gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
|
|
144
|
-
fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)
|
|
960
|
+
subset_bin = subset_bin[order]
|
|
145
961
|
|
|
146
|
-
|
|
147
|
-
|
|
962
|
+
# stack consistently
|
|
963
|
+
if include_any_c and num_any_c > 0:
|
|
964
|
+
stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_any_c])
|
|
965
|
+
if include_any_c and num_gpc > 0:
|
|
966
|
+
stacked_gpc.append(subset_bin[:, gpc_sites].layers[layer_gpc])
|
|
967
|
+
if include_any_c and num_cpg > 0:
|
|
968
|
+
stacked_cpg.append(subset_bin[:, cpg_sites].layers[layer_cpg])
|
|
969
|
+
if include_any_a and num_any_a > 0:
|
|
970
|
+
stacked_any_a.append(subset_bin[:, any_a_sites].layers[layer_a])
|
|
148
971
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
hmm_labels = subset.var_names.astype(int)
|
|
154
|
-
hmm_label_spacing = 150
|
|
155
|
-
sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
|
|
156
|
-
axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
|
|
157
|
-
axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
|
|
158
|
-
for boundary in bin_boundaries[:-1]:
|
|
159
|
-
axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
|
|
972
|
+
row_labels.extend([bin_label] * num_reads)
|
|
973
|
+
bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
974
|
+
last_idx += num_reads
|
|
975
|
+
bin_boundaries.append(last_idx)
|
|
160
976
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
|
|
977
|
+
# ----------------------------
|
|
978
|
+
# build matrices + means
|
|
979
|
+
# ----------------------------
|
|
980
|
+
blocks = [] # list of dicts describing what to plot in order
|
|
166
981
|
|
|
167
|
-
|
|
168
|
-
|
|
982
|
+
if include_any_c and stacked_any_c:
|
|
983
|
+
any_c_matrix = np.vstack(stacked_any_c)
|
|
984
|
+
gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
|
|
985
|
+
cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
|
|
986
|
+
|
|
987
|
+
mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
|
|
988
|
+
mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
|
|
989
|
+
mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
|
|
990
|
+
|
|
991
|
+
if any_c_matrix.size:
|
|
992
|
+
blocks.append(dict(
|
|
993
|
+
name="any_c",
|
|
994
|
+
matrix=any_c_matrix,
|
|
995
|
+
mean=mean_any_c,
|
|
996
|
+
labels=any_c_labels,
|
|
997
|
+
cmap=cmap_any_c,
|
|
998
|
+
n_xticks=n_xticks_any_c,
|
|
999
|
+
title="any C site Modification Signal"
|
|
1000
|
+
))
|
|
1001
|
+
if gpc_matrix.size:
|
|
1002
|
+
blocks.append(dict(
|
|
1003
|
+
name="gpc",
|
|
1004
|
+
matrix=gpc_matrix,
|
|
1005
|
+
mean=mean_gpc,
|
|
1006
|
+
labels=gpc_labels,
|
|
1007
|
+
cmap=cmap_gpc,
|
|
1008
|
+
n_xticks=n_xticks_gpc,
|
|
1009
|
+
title="GpC Modification Signal"
|
|
1010
|
+
))
|
|
1011
|
+
if cpg_matrix.size:
|
|
1012
|
+
blocks.append(dict(
|
|
1013
|
+
name="cpg",
|
|
1014
|
+
matrix=cpg_matrix,
|
|
1015
|
+
mean=mean_cpg,
|
|
1016
|
+
labels=cpg_labels,
|
|
1017
|
+
cmap=cmap_cpg,
|
|
1018
|
+
n_xticks=n_xticks_cpg,
|
|
1019
|
+
title="CpG Modification Signal"
|
|
1020
|
+
))
|
|
1021
|
+
|
|
1022
|
+
if include_any_a and stacked_any_a:
|
|
1023
|
+
any_a_matrix = np.vstack(stacked_any_a)
|
|
1024
|
+
mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
|
|
1025
|
+
if any_a_matrix.size:
|
|
1026
|
+
blocks.append(dict(
|
|
1027
|
+
name="any_a",
|
|
1028
|
+
matrix=any_a_matrix,
|
|
1029
|
+
mean=mean_any_a,
|
|
1030
|
+
labels=any_a_labels,
|
|
1031
|
+
cmap=cmap_a,
|
|
1032
|
+
n_xticks=n_xticks_any_a,
|
|
1033
|
+
title="any A site Modification Signal"
|
|
1034
|
+
))
|
|
1035
|
+
|
|
1036
|
+
if not blocks:
|
|
1037
|
+
print(f"No matrices to plot for {display_sample} - {ref}")
|
|
1038
|
+
continue
|
|
1039
|
+
|
|
1040
|
+
gs_dim = len(blocks)
|
|
1041
|
+
fig = plt.figure(figsize=(5.5 * gs_dim, 11))
|
|
1042
|
+
gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.02)
|
|
1043
|
+
fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
|
|
1044
|
+
|
|
1045
|
+
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
|
|
1046
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
|
|
1047
|
+
|
|
1048
|
+
# ----------------------------
|
|
1049
|
+
# plot blocks
|
|
1050
|
+
# ----------------------------
|
|
1051
|
+
for i, blk in enumerate(blocks):
|
|
1052
|
+
mat = blk["matrix"]
|
|
1053
|
+
mean = blk["mean"]
|
|
1054
|
+
labels = np.asarray(blk["labels"], dtype=str)
|
|
1055
|
+
n_xticks = blk["n_xticks"]
|
|
1056
|
+
|
|
1057
|
+
# barplot
|
|
1058
|
+
clean_barplot(axes_bar[i], mean, blk["title"])
|
|
1059
|
+
|
|
1060
|
+
# heatmap
|
|
1061
|
+
sns.heatmap(
|
|
1062
|
+
mat,
|
|
1063
|
+
cmap=blk["cmap"],
|
|
1064
|
+
ax=axes_heat[i],
|
|
1065
|
+
yticklabels=False,
|
|
1066
|
+
cbar=False
|
|
1067
|
+
)
|
|
1068
|
+
|
|
1069
|
+
# fixed tick labels
|
|
1070
|
+
tick_pos = _fixed_tick_positions(len(labels), n_xticks)
|
|
1071
|
+
axes_heat[i].set_xticks(tick_pos)
|
|
1072
|
+
axes_heat[i].set_xticklabels(
|
|
1073
|
+
labels[tick_pos],
|
|
1074
|
+
rotation=xtick_rotation,
|
|
1075
|
+
fontsize=xtick_fontsize
|
|
1076
|
+
)
|
|
1077
|
+
|
|
1078
|
+
# bin separators
|
|
169
1079
|
for boundary in bin_boundaries[:-1]:
|
|
170
|
-
axes_heat[
|
|
1080
|
+
axes_heat[i].axhline(y=boundary, color="black", linewidth=2)
|
|
171
1081
|
|
|
172
|
-
|
|
1082
|
+
axes_heat[i].set_xlabel("Position", fontsize=9)
|
|
173
1083
|
|
|
174
|
-
|
|
175
|
-
save_name = f"{ref} — {sample}"
|
|
176
|
-
os.makedirs(save_path, exist_ok=True)
|
|
177
|
-
safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
178
|
-
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
179
|
-
plt.savefig(out_file, dpi=300)
|
|
180
|
-
print(f"📁 Saved: {out_file}")
|
|
1084
|
+
plt.tight_layout()
|
|
181
1085
|
|
|
1086
|
+
# save or show
|
|
1087
|
+
if save_path is not None:
|
|
1088
|
+
safe_name = f"{ref}__{display_sample}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
|
|
1089
|
+
out_file = save_path / f"{safe_name}.png"
|
|
1090
|
+
fig.savefig(out_file, dpi=300)
|
|
1091
|
+
plt.close(fig)
|
|
1092
|
+
print(f"Saved: {out_file}")
|
|
1093
|
+
else:
|
|
182
1094
|
plt.show()
|
|
183
1095
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
1096
|
+
# record results
|
|
1097
|
+
rec = {
|
|
1098
|
+
"sample": str(sample),
|
|
1099
|
+
"ref": str(ref),
|
|
1100
|
+
"row_labels": row_labels,
|
|
1101
|
+
"bin_labels": bin_labels,
|
|
1102
|
+
"bin_boundaries": bin_boundaries,
|
|
1103
|
+
"percentages": percentages,
|
|
1104
|
+
}
|
|
1105
|
+
for blk in blocks:
|
|
1106
|
+
rec[f"{blk['name']}_matrix"] = blk["matrix"]
|
|
1107
|
+
rec[f"{blk['name']}_labels"] = list(map(str, blk["labels"]))
|
|
1108
|
+
results.append(rec)
|
|
1109
|
+
|
|
1110
|
+
print(f"Summary for {display_sample} - {ref}:")
|
|
1111
|
+
for bin_label, percent in percentages.items():
|
|
1112
|
+
print(f" - {bin_label}: {percent:.1f}%")
|
|
201
1113
|
|
|
202
1114
|
except Exception as e:
|
|
203
1115
|
import traceback
|
|
204
1116
|
traceback.print_exc()
|
|
205
1117
|
continue
|
|
1118
|
+
|
|
1119
|
+
# store once at the end (HDF5 safe)
|
|
1120
|
+
# matrices won't be HDF5-safe; store only metadata + maybe hit counts
|
|
1121
|
+
# adata.uns["clustermap_results"] = [
|
|
1122
|
+
# {k: v for k, v in r.items() if not k.endswith("_matrix")}
|
|
1123
|
+
# for r in results
|
|
1124
|
+
# ]
|
|
1125
|
+
|
|
1126
|
+
return results
|
|
1127
|
+
|
|
1128
|
+
def plot_hmm_layers_rolling_by_sample_ref(
    adata,
    layers: Optional[Sequence[str]] = None,
    sample_col: str = "Barcode",
    ref_col: str = "Reference_strand",
    samples: Optional[Sequence[str]] = None,
    references: Optional[Sequence[str]] = None,
    window: int = 51,
    min_periods: int = 1,
    center: bool = True,
    rows_per_page: int = 6,
    figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
    dpi: int = 160,
    output_dir: Optional[str] = None,
    save: bool = True,
    show_raw: bool = False,
    cmap: str = "tab10",
    use_var_coords: bool = True,
):
    """
    For each sample (row) and reference (column), plot the rolling average of
    the positional mean (mean across reads, NaNs ignored) of each layer.

    Parameters
    ----------
    adata : AnnData
        Input annotated data (expects obs columns ``sample_col`` and ``ref_col``).
    layers : list[str] | None
        Which ``adata.layers`` to plot. If None, all layers on ``adata`` are
        used; raises ValueError if there are none.
    sample_col, ref_col : str
        obs columns used to group rows (samples) and columns (references).
    samples, references : optional lists
        Explicit ordering of samples / references. If None, the categorical
        order found in ``adata.obs`` is used.
    window : int
        Rolling window size (odd recommended). If ``window`` is None or <= 1,
        no smoothing is applied.
    min_periods : int
        ``min_periods`` parameter for ``pd.Series.rolling``.
    center : bool
        Center the rolling window.
    rows_per_page : int
        Number of sample rows per figure; extra samples paginate onto new pages.
    figsize_per_cell : (w, h)
        Per-subplot size in inches.
    dpi : int
        Figure dpi when saving.
    output_dir : str | None
        Directory to save pages; created if necessary. If None and
        ``save=True``, the current working directory is used.
    save : bool
        Whether to save PNG files (otherwise figures are shown interactively).
    show_raw : bool
        Draw the unsmoothed mean as a faint line under the smoothed curve.
    cmap : str
        Matplotlib colormap used to color the layer lines.
    use_var_coords : bool
        If True, tries to use ``adata.var_names`` (coerced to int) as x-axis
        coordinates; otherwise (or on failure) uses ``0..n_vars-1``.

    Returns
    -------
    saved_files : list[str]
        List of saved filenames (empty if ``save=False``).

    Raises
    ------
    ValueError
        If ``sample_col``/``ref_col`` are missing from ``adata.obs``, or if
        ``layers`` is None and ``adata`` has no layers.
    """
    # --- basic checks / defaults ---
    if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
        raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")

    def _ordered_categories(col: str, explicit: Optional[Sequence[str]]) -> list:
        # Honor an explicit ordering when given; otherwise take the categorical
        # order from adata.obs (coercing to categorical if needed).
        # NOTE: isinstance(dtype, pd.CategoricalDtype) replaces the deprecated
        # pd.api.types.is_categorical_dtype (removed in modern pandas).
        if explicit is not None:
            return list(explicit)
        series = adata.obs[col]
        if not isinstance(series.dtype, pd.CategoricalDtype):
            series = series.astype("category")
        return list(series.cat.categories)

    samples_all = _ordered_categories(sample_col, samples)
    refs_all = _ordered_categories(ref_col, references)

    # choose layers: if not provided, default to all layers on the object
    if layers is None:
        layers = list(adata.layers.keys())
        if len(layers) == 0:
            raise ValueError("No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot.")
    layers = list(layers)

    # x coordinates (positions): prefer integer var_names, else ordinal index.
    # Plain branching instead of raise/except control flow.
    x_coords = None
    if use_var_coords:
        try:
            x_coords = np.array([int(v) for v in adata.var_names])
        except (TypeError, ValueError):
            x_coords = None  # non-numeric var names -> fall back below
    if x_coords is None:
        x_coords = np.arange(adata.shape[1], dtype=int)

    # output directory (only needed when saving)
    if save:
        outdir = output_dir or os.getcwd()
        os.makedirs(outdir, exist_ok=True)
    else:
        outdir = None

    n_samples = len(samples_all)
    n_refs = len(refs_all)
    total_pages = math.ceil(n_samples / rows_per_page)
    saved_files = []

    # one distinct color per layer
    cmap_obj = plt.get_cmap(cmap)
    n_layers = max(1, len(layers))
    colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]

    def _to_dense(mat) -> np.ndarray:
        # Accept scipy-sparse or array-like; always return a dense ndarray.
        if hasattr(mat, "toarray"):
            try:
                return mat.toarray()
            except Exception:
                pass
        return np.asarray(mat)

    for page in range(total_pages):
        start = page * rows_per_page
        end = min(start + rows_per_page, n_samples)
        chunk = samples_all[start:end]
        nrows = len(chunk)
        ncols = n_refs

        fig, axes = plt.subplots(
            nrows=nrows, ncols=ncols,
            figsize=(figsize_per_cell[0] * ncols, figsize_per_cell[1] * nrows),
            dpi=dpi, squeeze=False,
        )

        for r_idx, sample_name in enumerate(chunk):
            for c_idx, ref_name in enumerate(refs_all):
                ax = axes[r_idx][c_idx]

                # subset adata to this (sample, reference) cell
                mask = (adata.obs[sample_col].values == sample_name) & (adata.obs[ref_col].values == ref_name)
                sub = adata[mask]
                if sub.n_obs == 0:
                    ax.text(0.5, 0.5, "No reads", ha="center", va="center", transform=ax.transAxes, color="gray")
                    ax.set_xticks([])
                    ax.set_yticks([])
                    if r_idx == 0:
                        ax.set_title(str(ref_name), fontsize=9)
                    if c_idx == 0:
                        total_reads = int((adata.obs[sample_col] == sample_name).sum())
                        ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                    continue

                # for each layer, compute positional mean across reads (NaN-aware)
                plotted_any = False
                for li, layer in enumerate(layers):
                    if layer in sub.layers:
                        mat = sub.layers[layer]
                    elif layer == layers[0] and getattr(sub, "X", None) is not None:
                        # fallback: use .X only for the first requested layer
                        mat = sub.X
                    else:
                        # layer not present for this subset
                        continue

                    arr = _to_dense(mat)
                    if arr.size == 0 or arr.shape[1] == 0:
                        continue

                    # column-wise mean ignoring NaNs; cast to float so bool/int
                    # inputs can represent NaN
                    arr = arr.astype(float)
                    with np.errstate(all="ignore"):
                        col_mean = np.nanmean(arr, axis=0)
                    if np.all(np.isnan(col_mean)):
                        continue

                    # smooth via (optionally centered) pandas rolling mean
                    if (window is None) or (window <= 1):
                        smoothed = col_mean
                    else:
                        smoothed = (
                            pd.Series(col_mean)
                            .rolling(window=window, min_periods=min_periods, center=center)
                            .mean()
                            .to_numpy()
                        )

                    # trim x and y to a common length so the plot call cannot
                    # fail when a layer has a different column count than vars
                    L = min(len(col_mean), len(x_coords))
                    x = x_coords[:L]

                    # optionally draw the unsmoothed mean underneath
                    if show_raw:
                        ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)

                    ax.plot(x, smoothed[:L], label=layer, color=colors[li], linewidth=1.2, alpha=0.95, zorder=2)
                    plotted_any = True

                # labels / titles
                if r_idx == 0:
                    ax.set_title(str(ref_name), fontsize=9)
                if c_idx == 0:
                    total_reads = int((adata.obs[sample_col] == sample_name).sum())
                    ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                if r_idx == nrows - 1:
                    ax.set_xlabel("position", fontsize=8)

                # legend (only in the top-left panel to reduce clutter)
                if (r_idx == 0 and c_idx == 0) and plotted_any:
                    ax.legend(fontsize=7, loc="upper right")

                ax.grid(True, alpha=0.2)

        fig.suptitle(f"Rolling mean of layer positional means (window={window}) — page {page+1}/{total_pages}", fontsize=11, y=0.995)
        fig.tight_layout(rect=[0, 0, 1, 0.97])

        if save:
            fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
            # save via the explicit figure handle (robust even when another
            # figure has become the pyplot "current" figure)
            fig.savefig(fname, bbox_inches="tight", dpi=dpi)
            saved_files.append(fname)
        else:
            plt.show()
        plt.close(fig)

    return saved_files
|