smftools 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +9 -4
- smftools/_version.py +1 -1
- smftools/cli.py +184 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +33 -0
- smftools/config/deaminase.yaml +56 -0
- smftools/config/default.yaml +253 -0
- smftools/config/direct.yaml +17 -0
- smftools/config/experiment_config.py +1191 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +0 -2
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/fast5_to_pod5.py +4 -1
- smftools/informatics/helpers/__init__.py +3 -4
- smftools/informatics/helpers/align_and_sort_BAM.py +34 -7
- smftools/informatics/helpers/aligned_BAM_to_bed.py +35 -24
- smftools/informatics/helpers/binarize_converted_base_identities.py +116 -23
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +365 -42
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +165 -29
- smftools/informatics/helpers/discover_input_files.py +100 -0
- smftools/informatics/helpers/extract_base_identities.py +29 -3
- smftools/informatics/helpers/extract_read_features_from_bam.py +4 -2
- smftools/informatics/helpers/find_conversion_sites.py +5 -4
- smftools/informatics/helpers/modkit_extract_to_adata.py +6 -3
- smftools/informatics/helpers/plot_bed_histograms.py +269 -0
- smftools/informatics/helpers/separate_bam_by_bc.py +2 -2
- smftools/informatics/helpers/split_and_index_BAM.py +1 -5
- smftools/load_adata.py +1346 -0
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +611 -0
- smftools/plotting/general_plotting.py +566 -89
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +13 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +849 -43
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/METADATA +5 -1
- smftools-0.2.1.dist-info/RECORD +161 -0
- smftools-0.2.1.dist-info/entry_points.txt +2 -0
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/evaluation/__init__.py +0 -0
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -17,24 +17,30 @@ def clean_barplot(ax, mean_values, title):
|
|
|
17
17
|
|
|
18
18
|
ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
|
|
19
19
|
|
|
20
|
-
|
|
21
20
|
def combined_hmm_raw_clustermap(
|
|
22
21
|
adata,
|
|
23
22
|
sample_col='Sample_Names',
|
|
23
|
+
reference_col='Reference_strand',
|
|
24
24
|
hmm_feature_layer="hmm_combined",
|
|
25
25
|
layer_gpc="nan0_0minus1",
|
|
26
26
|
layer_cpg="nan0_0minus1",
|
|
27
|
+
layer_any_c="nan0_0minus1",
|
|
27
28
|
cmap_hmm="tab10",
|
|
28
29
|
cmap_gpc="coolwarm",
|
|
29
30
|
cmap_cpg="viridis",
|
|
31
|
+
cmap_any_c='coolwarm',
|
|
30
32
|
min_quality=20,
|
|
31
|
-
min_length=
|
|
33
|
+
min_length=200,
|
|
34
|
+
min_mapped_length_to_reference_length_ratio=0.8,
|
|
35
|
+
min_position_valid_fraction=0.5,
|
|
32
36
|
sample_mapping=None,
|
|
33
37
|
save_path=None,
|
|
34
38
|
normalize_hmm=False,
|
|
35
39
|
sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
|
|
36
|
-
bins=None
|
|
37
|
-
|
|
40
|
+
bins=None,
|
|
41
|
+
deaminase=False,
|
|
42
|
+
min_signal=0
|
|
43
|
+
):
|
|
38
44
|
import scipy.cluster.hierarchy as sch
|
|
39
45
|
import pandas as pd
|
|
40
46
|
import numpy as np
|
|
@@ -44,38 +50,51 @@ def combined_hmm_raw_clustermap(
|
|
|
44
50
|
import os
|
|
45
51
|
|
|
46
52
|
results = []
|
|
53
|
+
if deaminase:
|
|
54
|
+
signal_type = 'deamination'
|
|
55
|
+
else:
|
|
56
|
+
signal_type = 'methylation'
|
|
47
57
|
|
|
48
|
-
for ref in adata.obs[
|
|
58
|
+
for ref in adata.obs[reference_col].cat.categories:
|
|
49
59
|
for sample in adata.obs[sample_col].cat.categories:
|
|
50
60
|
try:
|
|
51
61
|
subset = adata[
|
|
52
|
-
(adata.obs[
|
|
62
|
+
(adata.obs[reference_col] == ref) &
|
|
53
63
|
(adata.obs[sample_col] == sample) &
|
|
54
|
-
(adata.obs['
|
|
64
|
+
(adata.obs['read_quality'] >= min_quality) &
|
|
55
65
|
(adata.obs['read_length'] >= min_length) &
|
|
56
|
-
(adata.obs['
|
|
66
|
+
(adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
|
|
57
67
|
]
|
|
58
68
|
|
|
59
|
-
|
|
69
|
+
mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
|
|
70
|
+
subset = subset[:, mask]
|
|
60
71
|
|
|
61
72
|
if subset.shape[0] == 0:
|
|
62
|
-
print(f"
|
|
73
|
+
print(f" No reads left after filtering for {sample} - {ref}")
|
|
63
74
|
continue
|
|
64
75
|
|
|
65
76
|
if bins:
|
|
66
|
-
|
|
77
|
+
print(f"Using defined bins to subset clustermap for {sample} - {ref}")
|
|
78
|
+
bins_temp = bins
|
|
67
79
|
else:
|
|
68
|
-
|
|
80
|
+
print(f"Using all reads for clustermap for {sample} - {ref}")
|
|
81
|
+
bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
|
|
69
82
|
|
|
70
83
|
# Get column positions (not var_names!) of site masks
|
|
71
84
|
gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
|
|
72
85
|
cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
|
|
86
|
+
any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
|
|
87
|
+
num_gpc = len(gpc_sites)
|
|
88
|
+
num_cpg = len(cpg_sites)
|
|
89
|
+
num_c = len(any_c_sites)
|
|
90
|
+
print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
|
|
73
91
|
|
|
74
92
|
# Use var_names for x-axis tick labels
|
|
75
93
|
gpc_labels = subset.var_names[gpc_sites].astype(int)
|
|
76
94
|
cpg_labels = subset.var_names[cpg_sites].astype(int)
|
|
95
|
+
any_c_labels = subset.var_names[any_c_sites].astype(int)
|
|
77
96
|
|
|
78
|
-
stacked_hmm_feature, stacked_gpc, stacked_cpg = [], [], []
|
|
97
|
+
stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
|
|
79
98
|
row_labels, bin_labels = [], []
|
|
80
99
|
bin_boundaries = []
|
|
81
100
|
|
|
@@ -83,13 +102,14 @@ def combined_hmm_raw_clustermap(
|
|
|
83
102
|
percentages = {}
|
|
84
103
|
last_idx = 0
|
|
85
104
|
|
|
86
|
-
for bin_label, bin_filter in
|
|
105
|
+
for bin_label, bin_filter in bins_temp.items():
|
|
87
106
|
subset_bin = subset[bin_filter].copy()
|
|
88
107
|
num_reads = subset_bin.shape[0]
|
|
108
|
+
print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
|
|
89
109
|
percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
|
|
90
110
|
percentages[bin_label] = percent_reads
|
|
91
111
|
|
|
92
|
-
if num_reads > 0:
|
|
112
|
+
if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
|
|
93
113
|
# Determine sorting order
|
|
94
114
|
if sort_by.startswith("obs:"):
|
|
95
115
|
colname = sort_by.split("obs:")[1]
|
|
@@ -105,12 +125,16 @@ def combined_hmm_raw_clustermap(
|
|
|
105
125
|
order = sch.leaves_list(linkage)
|
|
106
126
|
elif sort_by == "none":
|
|
107
127
|
order = np.arange(num_reads)
|
|
128
|
+
elif sort_by == "any_c":
|
|
129
|
+
linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
|
|
130
|
+
order = sch.leaves_list(linkage)
|
|
108
131
|
else:
|
|
109
132
|
raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
110
133
|
|
|
111
134
|
stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
|
|
112
135
|
stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
|
|
113
136
|
stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
|
|
137
|
+
stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
|
|
114
138
|
|
|
115
139
|
row_labels.extend([bin_label] * num_reads)
|
|
116
140
|
bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
@@ -121,85 +145,538 @@ def combined_hmm_raw_clustermap(
|
|
|
121
145
|
hmm_matrix = np.vstack(stacked_hmm_feature)
|
|
122
146
|
gpc_matrix = np.vstack(stacked_gpc)
|
|
123
147
|
cpg_matrix = np.vstack(stacked_cpg)
|
|
148
|
+
any_c_matrix = np.vstack(stacked_any_c)
|
|
124
149
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
150
|
+
if hmm_matrix.size > 0:
|
|
151
|
+
def normalized_mean(matrix):
|
|
152
|
+
mean = np.nanmean(matrix, axis=0)
|
|
153
|
+
normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
|
|
154
|
+
return normalized
|
|
129
155
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
156
|
+
def methylation_fraction(matrix):
|
|
157
|
+
methylated = (matrix == 1).sum(axis=0)
|
|
158
|
+
valid = (matrix != 0).sum(axis=0)
|
|
159
|
+
return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
|
|
134
160
|
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
axes_heat[0].
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
axes_heat[1].
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
axes_heat[2].
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
"
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
161
|
+
if normalize_hmm:
|
|
162
|
+
mean_hmm = normalized_mean(hmm_matrix)
|
|
163
|
+
else:
|
|
164
|
+
mean_hmm = np.nanmean(hmm_matrix, axis=0)
|
|
165
|
+
mean_gpc = methylation_fraction(gpc_matrix)
|
|
166
|
+
mean_cpg = methylation_fraction(cpg_matrix)
|
|
167
|
+
mean_any_c = methylation_fraction(any_c_matrix)
|
|
168
|
+
|
|
169
|
+
fig = plt.figure(figsize=(18, 12))
|
|
170
|
+
gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
|
|
171
|
+
fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)
|
|
172
|
+
|
|
173
|
+
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
|
|
174
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
|
|
175
|
+
|
|
176
|
+
clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
|
|
177
|
+
clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
|
|
178
|
+
clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
|
|
179
|
+
clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
|
|
180
|
+
|
|
181
|
+
hmm_labels = subset.var_names.astype(int)
|
|
182
|
+
hmm_label_spacing = 150
|
|
183
|
+
sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
|
|
184
|
+
axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
|
|
185
|
+
axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
|
|
186
|
+
for boundary in bin_boundaries[:-1]:
|
|
187
|
+
axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
|
|
188
|
+
|
|
189
|
+
sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
|
|
190
|
+
axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
|
|
191
|
+
axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
|
|
192
|
+
for boundary in bin_boundaries[:-1]:
|
|
193
|
+
axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
|
|
194
|
+
|
|
195
|
+
sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
|
|
196
|
+
axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
|
|
197
|
+
for boundary in bin_boundaries[:-1]:
|
|
198
|
+
axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
|
|
199
|
+
|
|
200
|
+
sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
|
|
201
|
+
axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
|
|
202
|
+
axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
|
|
203
|
+
for boundary in bin_boundaries[:-1]:
|
|
204
|
+
axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
|
|
205
|
+
|
|
206
|
+
plt.tight_layout()
|
|
207
|
+
|
|
208
|
+
if save_path:
|
|
209
|
+
save_name = f"{ref} — {sample}"
|
|
210
|
+
os.makedirs(save_path, exist_ok=True)
|
|
211
|
+
safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
212
|
+
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
213
|
+
plt.savefig(out_file, dpi=300)
|
|
214
|
+
print(f"Saved: {out_file}")
|
|
215
|
+
plt.close()
|
|
216
|
+
else:
|
|
217
|
+
plt.show()
|
|
218
|
+
|
|
219
|
+
print(f"Summary for {sample} - {ref}:")
|
|
220
|
+
for bin_label, percent in percentages.items():
|
|
221
|
+
print(f" - {bin_label}: {percent:.1f}%")
|
|
222
|
+
|
|
223
|
+
results.append({
|
|
224
|
+
"sample": sample,
|
|
225
|
+
"ref": ref,
|
|
226
|
+
"hmm_matrix": hmm_matrix,
|
|
227
|
+
"gpc_matrix": gpc_matrix,
|
|
228
|
+
"cpg_matrix": cpg_matrix,
|
|
229
|
+
"row_labels": row_labels,
|
|
230
|
+
"bin_labels": bin_labels,
|
|
231
|
+
"bin_boundaries": bin_boundaries,
|
|
232
|
+
"percentages": percentages
|
|
233
|
+
})
|
|
234
|
+
|
|
235
|
+
adata.uns['clustermap_results'] = results
|
|
236
|
+
|
|
237
|
+
except Exception as e:
|
|
238
|
+
import traceback
|
|
239
|
+
traceback.print_exc()
|
|
240
|
+
continue
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def combined_raw_clustermap(
|
|
244
|
+
adata,
|
|
245
|
+
sample_col='Sample_Names',
|
|
246
|
+
reference_col='Reference_strand',
|
|
247
|
+
layer_any_c="nan0_0minus1",
|
|
248
|
+
layer_gpc="nan0_0minus1",
|
|
249
|
+
layer_cpg="nan0_0minus1",
|
|
250
|
+
cmap_any_c="coolwarm",
|
|
251
|
+
cmap_gpc="coolwarm",
|
|
252
|
+
cmap_cpg="viridis",
|
|
253
|
+
min_quality=20,
|
|
254
|
+
min_length=200,
|
|
255
|
+
min_mapped_length_to_reference_length_ratio=0.8,
|
|
256
|
+
min_position_valid_fraction=0.5,
|
|
257
|
+
sample_mapping=None,
|
|
258
|
+
save_path=None,
|
|
259
|
+
sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
|
|
260
|
+
bins=None,
|
|
261
|
+
deaminase=False,
|
|
262
|
+
min_signal=0
|
|
263
|
+
):
|
|
264
|
+
import scipy.cluster.hierarchy as sch
|
|
265
|
+
import pandas as pd
|
|
266
|
+
import numpy as np
|
|
267
|
+
import seaborn as sns
|
|
268
|
+
import matplotlib.pyplot as plt
|
|
269
|
+
import matplotlib.gridspec as gridspec
|
|
270
|
+
import os
|
|
271
|
+
|
|
272
|
+
results = []
|
|
273
|
+
|
|
274
|
+
for ref in adata.obs[reference_col].cat.categories:
|
|
275
|
+
for sample in adata.obs[sample_col].cat.categories:
|
|
276
|
+
try:
|
|
277
|
+
subset = adata[
|
|
278
|
+
(adata.obs[reference_col] == ref) &
|
|
279
|
+
(adata.obs[sample_col] == sample) &
|
|
280
|
+
(adata.obs['read_quality'] >= min_quality) &
|
|
281
|
+
(adata.obs['mapped_length'] >= min_length) &
|
|
282
|
+
(adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
|
|
283
|
+
]
|
|
284
|
+
|
|
285
|
+
mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
|
|
286
|
+
subset = subset[:, mask]
|
|
287
|
+
|
|
288
|
+
if subset.shape[0] == 0:
|
|
289
|
+
print(f" No reads left after filtering for {sample} - {ref}")
|
|
290
|
+
continue
|
|
291
|
+
|
|
292
|
+
if bins:
|
|
293
|
+
print(f"Using defined bins to subset clustermap for {sample} - {ref}")
|
|
294
|
+
bins_temp = bins
|
|
295
|
+
else:
|
|
296
|
+
print(f"Using all reads for clustermap for {sample} - {ref}")
|
|
297
|
+
bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
|
|
298
|
+
|
|
299
|
+
# Get column positions (not var_names!) of site masks
|
|
300
|
+
any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
|
|
301
|
+
gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
|
|
302
|
+
cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
|
|
303
|
+
num_any_c = len(any_c_sites)
|
|
304
|
+
num_gpc = len(gpc_sites)
|
|
305
|
+
num_cpg = len(cpg_sites)
|
|
306
|
+
print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
|
|
307
|
+
|
|
308
|
+
# Use var_names for x-axis tick labels
|
|
309
|
+
gpc_labels = subset.var_names[gpc_sites].astype(int)
|
|
310
|
+
cpg_labels = subset.var_names[cpg_sites].astype(int)
|
|
311
|
+
any_c_labels = subset.var_names[any_c_sites].astype(int)
|
|
312
|
+
|
|
313
|
+
stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
|
|
314
|
+
row_labels, bin_labels = [], []
|
|
315
|
+
bin_boundaries = []
|
|
316
|
+
|
|
317
|
+
total_reads = subset.shape[0]
|
|
318
|
+
percentages = {}
|
|
319
|
+
last_idx = 0
|
|
320
|
+
|
|
321
|
+
for bin_label, bin_filter in bins_temp.items():
|
|
322
|
+
subset_bin = subset[bin_filter].copy()
|
|
323
|
+
num_reads = subset_bin.shape[0]
|
|
324
|
+
print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
|
|
325
|
+
percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
|
|
326
|
+
percentages[bin_label] = percent_reads
|
|
327
|
+
|
|
328
|
+
if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
|
|
329
|
+
# Determine sorting order
|
|
330
|
+
if sort_by.startswith("obs:"):
|
|
331
|
+
colname = sort_by.split("obs:")[1]
|
|
332
|
+
order = np.argsort(subset_bin.obs[colname].values)
|
|
333
|
+
elif sort_by == "gpc":
|
|
334
|
+
linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
|
|
335
|
+
order = sch.leaves_list(linkage)
|
|
336
|
+
elif sort_by == "cpg":
|
|
337
|
+
linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
338
|
+
order = sch.leaves_list(linkage)
|
|
339
|
+
elif sort_by == "any_c":
|
|
340
|
+
linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
|
|
341
|
+
order = sch.leaves_list(linkage)
|
|
342
|
+
elif sort_by == "gpc_cpg":
|
|
343
|
+
linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
|
|
344
|
+
order = sch.leaves_list(linkage)
|
|
345
|
+
elif sort_by == "none":
|
|
346
|
+
order = np.arange(num_reads)
|
|
347
|
+
else:
|
|
348
|
+
raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
349
|
+
|
|
350
|
+
stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
|
|
351
|
+
stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
|
|
352
|
+
stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
|
|
353
|
+
|
|
354
|
+
row_labels.extend([bin_label] * num_reads)
|
|
355
|
+
bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
356
|
+
last_idx += num_reads
|
|
357
|
+
bin_boundaries.append(last_idx)
|
|
358
|
+
|
|
359
|
+
if stacked_any_c:
|
|
360
|
+
any_c_matrix = np.vstack(stacked_any_c)
|
|
361
|
+
gpc_matrix = np.vstack(stacked_gpc)
|
|
362
|
+
cpg_matrix = np.vstack(stacked_cpg)
|
|
363
|
+
|
|
364
|
+
if any_c_matrix.size > 0:
|
|
365
|
+
def normalized_mean(matrix):
|
|
366
|
+
mean = np.nanmean(matrix, axis=0)
|
|
367
|
+
normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
|
|
368
|
+
return normalized
|
|
369
|
+
|
|
370
|
+
def methylation_fraction(matrix):
|
|
371
|
+
methylated = (matrix == 1).sum(axis=0)
|
|
372
|
+
valid = (matrix != 0).sum(axis=0)
|
|
373
|
+
return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
|
|
374
|
+
|
|
375
|
+
mean_gpc = methylation_fraction(gpc_matrix)
|
|
376
|
+
mean_cpg = methylation_fraction(cpg_matrix)
|
|
377
|
+
mean_any_c = methylation_fraction(any_c_matrix)
|
|
378
|
+
|
|
379
|
+
fig = plt.figure(figsize=(18, 12))
|
|
380
|
+
gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
|
|
381
|
+
fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
|
|
382
|
+
|
|
383
|
+
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(3)]
|
|
384
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(3)]
|
|
385
|
+
|
|
386
|
+
clean_barplot(axes_bar[0], mean_any_c, f"any C site Modification Signal")
|
|
387
|
+
clean_barplot(axes_bar[1], mean_gpc, f"GpC Modification Signal")
|
|
388
|
+
clean_barplot(axes_bar[2], mean_cpg, f"CpG Modification Signal")
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[0], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
|
|
392
|
+
axes_heat[0].set_xticks(range(0, len(any_c_labels), 20))
|
|
393
|
+
axes_heat[0].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
|
|
394
|
+
for boundary in bin_boundaries[:-1]:
|
|
395
|
+
axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
|
|
396
|
+
|
|
397
|
+
sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
|
|
398
|
+
axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
|
|
399
|
+
axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
|
|
400
|
+
for boundary in bin_boundaries[:-1]:
|
|
401
|
+
axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
|
|
402
|
+
|
|
403
|
+
sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
|
|
404
|
+
axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
|
|
405
|
+
for boundary in bin_boundaries[:-1]:
|
|
406
|
+
axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
|
|
407
|
+
|
|
408
|
+
plt.tight_layout()
|
|
409
|
+
|
|
410
|
+
if save_path:
|
|
411
|
+
save_name = f"{ref} — {sample}"
|
|
412
|
+
os.makedirs(save_path, exist_ok=True)
|
|
413
|
+
safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
414
|
+
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
415
|
+
plt.savefig(out_file, dpi=300)
|
|
416
|
+
print(f"Saved: {out_file}")
|
|
417
|
+
plt.close()
|
|
418
|
+
else:
|
|
419
|
+
plt.show()
|
|
420
|
+
|
|
421
|
+
print(f"Summary for {sample} - {ref}:")
|
|
422
|
+
for bin_label, percent in percentages.items():
|
|
423
|
+
print(f" - {bin_label}: {percent:.1f}%")
|
|
424
|
+
|
|
425
|
+
results.append({
|
|
426
|
+
"sample": sample,
|
|
427
|
+
"ref": ref,
|
|
428
|
+
"any_c_matrix": any_c_matrix,
|
|
429
|
+
"gpc_matrix": gpc_matrix,
|
|
430
|
+
"cpg_matrix": cpg_matrix,
|
|
431
|
+
"row_labels": row_labels,
|
|
432
|
+
"bin_labels": bin_labels,
|
|
433
|
+
"bin_boundaries": bin_boundaries,
|
|
434
|
+
"percentages": percentages
|
|
435
|
+
})
|
|
436
|
+
|
|
437
|
+
adata.uns['clustermap_results'] = results
|
|
201
438
|
|
|
202
439
|
except Exception as e:
|
|
203
440
|
import traceback
|
|
204
441
|
traceback.print_exc()
|
|
205
442
|
continue
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
import os
|
|
446
|
+
import math
|
|
447
|
+
from typing import List, Optional, Sequence, Tuple
|
|
448
|
+
|
|
449
|
+
import numpy as np
|
|
450
|
+
import pandas as pd
|
|
451
|
+
import matplotlib.pyplot as plt
|
|
452
|
+
|
|
453
|
+
def plot_hmm_layers_rolling_by_sample_ref(
    adata,
    layers: Optional[Sequence[str]] = None,
    sample_col: str = "Barcode",
    ref_col: str = "Reference_strand",
    samples: Optional[Sequence[str]] = None,
    references: Optional[Sequence[str]] = None,
    window: int = 51,
    min_periods: int = 1,
    center: bool = True,
    rows_per_page: int = 6,
    figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
    dpi: int = 160,
    output_dir: Optional[str] = None,
    save: bool = True,
    show_raw: bool = False,
    cmap: str = "tab10",
    use_var_coords: bool = True,
):
    """
    For each sample (row) and reference (col) plot the rolling average of the
    positional mean (mean across reads) for each layer listed.

    Parameters
    ----------
    adata : AnnData
        Input annotated data (expects obs columns sample_col and ref_col).
    layers : list[str] | None
        Which adata.layers to plot. If None, all of adata.layers is used;
        if there are no layers at all, a ValueError is raised and the caller
        must pass an explicit list.
    sample_col, ref_col : str
        obs columns used to group rows.
    samples, references : optional lists
        explicit ordering of samples / references. If None, categories in adata.obs are used.
    window : int
        rolling window size (odd recommended). If window <= 1 (or None), no smoothing applied.
    min_periods : int
        min periods param for pd.Series.rolling.
    center : bool
        center the rolling window.
    rows_per_page : int
        paginate rows per page into multiple figures if needed.
    figsize_per_cell : (w, h)
        per-subplot size in inches.
    dpi : int
        figure dpi when saving.
    output_dir : str | None
        directory to save pages; created if necessary. If None and save=True, uses cwd.
    save : bool
        whether to save PNG files (otherwise figures are shown interactively).
    show_raw : bool
        draw unsmoothed mean as faint line under smoothed curve.
    cmap : str
        matplotlib colormap for layer lines.
    use_var_coords : bool
        if True, tries to use adata.var_names (coerced to int) as x-axis
        coordinates; falls back to 0..n_vars-1 when coercion fails or when False.

    Returns
    -------
    saved_files : list[str]
        list of saved filenames (may be empty if save=False).

    Raises
    ------
    ValueError
        if sample_col/ref_col are missing from adata.obs, or if layers is None
        and adata has no layers to autodetect.
    """

    # --- basic checks / defaults ---
    if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
        raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")

    def _ordered_categories(series):
        # Respect existing categorical ordering; otherwise impose pandas'
        # default category order.  (isinstance check replaces the deprecated
        # pd.api.types.is_categorical_dtype.)
        if not isinstance(series.dtype, pd.CategoricalDtype):
            series = series.astype("category")
        return list(series.cat.categories)

    # canonicalize samples / refs
    if samples is None:
        samples_all = _ordered_categories(adata.obs[sample_col])
    else:
        samples_all = list(samples)

    if references is None:
        refs_all = _ordered_categories(adata.obs[ref_col])
    else:
        refs_all = list(references)

    # choose layers: if not provided, default to all layers
    if layers is None:
        layers = list(adata.layers.keys())
        if len(layers) == 0:
            raise ValueError("No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot.")
    layers = list(layers)

    # x coordinates (positions): genomic coordinates from var_names when
    # requested and coercible to int, else positional indices 0..n_vars-1.
    x_coords = None
    if use_var_coords:
        try:
            x_coords = np.array([int(v) for v in adata.var_names])
        except (TypeError, ValueError):
            x_coords = None  # non-numeric var_names -> fall back below
    if x_coords is None:
        x_coords = np.arange(adata.shape[1], dtype=int)

    # make output dir
    if save:
        outdir = output_dir or os.getcwd()
        os.makedirs(outdir, exist_ok=True)
    else:
        outdir = None

    n_samples = len(samples_all)
    n_refs = len(refs_all)
    total_pages = math.ceil(n_samples / rows_per_page)
    saved_files = []

    # color cycle for layers (evenly spaced samples from the colormap)
    cmap_obj = plt.get_cmap(cmap)
    n_layers = max(1, len(layers))
    colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]

    for page in range(total_pages):
        start = page * rows_per_page
        end = min(start + rows_per_page, n_samples)
        chunk = samples_all[start:end]
        nrows = len(chunk)
        ncols = n_refs

        fig_w = figsize_per_cell[0] * ncols
        fig_h = figsize_per_cell[1] * nrows
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                                 figsize=(fig_w, fig_h), dpi=dpi,
                                 squeeze=False)

        for r_idx, sample_name in enumerate(chunk):
            for c_idx, ref_name in enumerate(refs_all):
                ax = axes[r_idx][c_idx]

                # subset adata to this (sample, reference) cell
                mask = (adata.obs[sample_col].values == sample_name) & (adata.obs[ref_col].values == ref_name)
                sub = adata[mask]
                if sub.n_obs == 0:
                    ax.text(0.5, 0.5, "No reads", ha="center", va="center", transform=ax.transAxes, color="gray")
                    ax.set_xticks([])
                    ax.set_yticks([])
                    if r_idx == 0:
                        ax.set_title(str(ref_name), fontsize=9)
                    if c_idx == 0:
                        total_reads = int((adata.obs[sample_col] == sample_name).sum())
                        ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                    continue

                # for each layer, compute positional mean across reads (ignore NaNs)
                plotted_any = False
                for li, layer in enumerate(layers):
                    if layer in sub.layers:
                        mat = sub.layers[layer]
                    elif layer == layers[0] and getattr(sub, "X", None) is not None:
                        # fallback: use .X only for the first requested layer
                        mat = sub.X
                    else:
                        # layer not present for this subset
                        continue

                    # densify sparse matrices; otherwise coerce to ndarray
                    if hasattr(mat, "toarray"):
                        try:
                            arr = mat.toarray()
                        except Exception:
                            arr = np.asarray(mat)
                    else:
                        arr = np.asarray(mat)

                    if arr.size == 0 or arr.shape[1] == 0:
                        continue

                    # column-wise mean ignoring NaNs; cast bool/int to float
                    # so NaN is representable
                    arr = arr.astype(float)
                    with np.errstate(all="ignore"):
                        col_mean = np.nanmean(arr, axis=0)

                    # nothing to draw if every position is NaN
                    if np.all(np.isnan(col_mean)):
                        continue

                    # smooth via pandas rolling (optionally centered)
                    if (window is None) or (window <= 1):
                        smoothed = col_mean
                    else:
                        ser = pd.Series(col_mean)
                        smoothed = ser.rolling(window=window, min_periods=min_periods, center=center).mean().to_numpy()

                    # x axis: genomic coords when they cover this layer's
                    # width, else positional indices (avoids a length
                    # mismatch crash in ax.plot)
                    L = len(col_mean)
                    x = x_coords[:L] if len(x_coords) >= L else np.arange(L)

                    # optionally plot raw faint line under the smoothed curve
                    if show_raw:
                        ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)

                    ax.plot(x, smoothed[:L], label=layer, color=colors[li], linewidth=1.2, alpha=0.95, zorder=2)
                    plotted_any = True

                # labels / titles
                if r_idx == 0:
                    ax.set_title(str(ref_name), fontsize=9)
                if c_idx == 0:
                    total_reads = int((adata.obs[sample_col] == sample_name).sum())
                    ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                if r_idx == nrows - 1:
                    ax.set_xlabel("position", fontsize=8)

                # legend (only show in top-left plot to reduce clutter)
                if (r_idx == 0 and c_idx == 0) and plotted_any:
                    ax.legend(fontsize=7, loc="upper right")

                ax.grid(True, alpha=0.2)

        fig.suptitle(f"Rolling mean of layer positional means (window={window}) — page {page+1}/{total_pages}", fontsize=11, y=0.995)
        fig.tight_layout(rect=[0, 0, 1, 0.97])

        if save:
            fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
            plt.savefig(fname, bbox_inches="tight", dpi=dpi)
            saved_files.append(fname)
        else:
            plt.show()
        plt.close(fig)

    return saved_files
|