smftools 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. smftools/__init__.py +2 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/__init__.py +0 -0
  4. smftools/cli/cli_flows.py +94 -0
  5. smftools/cli/hmm_adata.py +338 -0
  6. smftools/cli/load_adata.py +577 -0
  7. smftools/cli/preprocess_adata.py +363 -0
  8. smftools/cli/spatial_adata.py +564 -0
  9. smftools/cli_entry.py +435 -0
  10. smftools/config/conversion.yaml +11 -6
  11. smftools/config/deaminase.yaml +12 -7
  12. smftools/config/default.yaml +36 -25
  13. smftools/config/direct.yaml +25 -1
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +109 -12
  16. smftools/informatics/__init__.py +13 -7
  17. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  18. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  19. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  20. smftools/informatics/{helpers → archived/helpers/archived}/aligned_BAM_to_bed.py +6 -4
  21. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  22. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  23. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  24. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  25. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  26. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +1 -1
  27. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  28. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  29. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  30. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  31. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  32. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/plot_bed_histograms.py +0 -19
  34. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +6 -5
  35. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +7 -7
  36. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  37. smftools/informatics/bam_functions.py +812 -0
  38. smftools/informatics/basecalling.py +67 -0
  39. smftools/informatics/bed_functions.py +366 -0
  40. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +42 -30
  41. smftools/informatics/fasta_functions.py +255 -0
  42. smftools/informatics/h5ad_functions.py +197 -0
  43. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +142 -59
  44. smftools/informatics/modkit_functions.py +129 -0
  45. smftools/informatics/ohe.py +160 -0
  46. smftools/informatics/pod5_functions.py +224 -0
  47. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  48. smftools/plotting/autocorrelation_plotting.py +1 -3
  49. smftools/plotting/general_plotting.py +1037 -362
  50. smftools/preprocessing/__init__.py +2 -0
  51. smftools/preprocessing/append_base_context.py +3 -3
  52. smftools/preprocessing/append_binary_layer_by_base_context.py +4 -4
  53. smftools/preprocessing/binarize.py +17 -0
  54. smftools/preprocessing/binarize_on_Youden.py +2 -2
  55. smftools/preprocessing/calculate_position_Youden.py +1 -1
  56. smftools/preprocessing/calculate_read_modification_stats.py +1 -1
  57. smftools/preprocessing/filter_reads_on_modification_thresholds.py +19 -19
  58. smftools/preprocessing/flag_duplicate_reads.py +1 -1
  59. smftools/readwrite.py +266 -140
  60. {smftools-0.2.1.dist-info → smftools-0.2.3.dist-info}/METADATA +10 -9
  61. {smftools-0.2.1.dist-info → smftools-0.2.3.dist-info}/RECORD +82 -70
  62. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  63. smftools/cli.py +0 -184
  64. smftools/informatics/fast5_to_pod5.py +0 -24
  65. smftools/informatics/helpers/__init__.py +0 -73
  66. smftools/informatics/helpers/align_and_sort_BAM.py +0 -86
  67. smftools/informatics/helpers/bam_qc.py +0 -66
  68. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  69. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -378
  70. smftools/informatics/helpers/discover_input_files.py +0 -100
  71. smftools/informatics/helpers/index_fasta.py +0 -12
  72. smftools/informatics/helpers/make_dirs.py +0 -21
  73. smftools/informatics/readwrite.py +0 -106
  74. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  75. smftools/load_adata.py +0 -1346
  76. smftools-0.2.1.dist-info/entry_points.txt +0 -2
  77. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  78. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  79. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  80. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +0 -0
  81. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  82. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  83. /smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +0 -0
  84. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  85. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  86. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  87. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  88. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  89. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  90. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  91. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  92. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  93. /smftools/informatics/{helpers/binarize_converted_base_identities.py → binarize_converted_base_identities.py} +0 -0
  94. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  95. {smftools-0.2.1.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  96. {smftools-0.2.1.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,40 @@
1
+ from __future__ import annotations
2
+
1
3
  import numpy as np
2
4
  import seaborn as sns
3
5
  import matplotlib.pyplot as plt
6
+ import scipy.cluster.hierarchy as sch
7
+ import matplotlib.gridspec as gridspec
8
+ import os
9
+ import math
10
+ import pandas as pd
11
+
12
+ from typing import Optional, Mapping, Sequence, Any, Dict, List
13
+ from pathlib import Path
14
+
15
def normalized_mean(matrix: np.ndarray) -> np.ndarray:
    """Column-wise mean of *matrix*, min-max scaled into [0, 1].

    NaN entries are ignored when averaging (``np.nanmean``). The extrema
    used for scaling are NaN-aware as well (``np.nanmin``/``np.nanmax``),
    so a column that is entirely NaN produces NaN only at that position
    instead of poisoning every column of the output.

    Parameters
    ----------
    matrix : np.ndarray
        2D array (reads x positions); may contain NaN.

    Returns
    -------
    np.ndarray
        1D array of scaled column means; a constant input maps to ~0.
    """
    mean = np.nanmean(matrix, axis=0)
    # BUGFIX: plain mean.min()/mean.max() propagate NaN from any all-NaN
    # column, turning the whole result into NaN. Use NaN-aware extrema.
    lo = np.nanmin(mean)
    hi = np.nanmax(mean)
    # epsilon guards against divide-by-zero when the mean is flat
    denom = (hi - lo) + 1e-9
    return (mean - lo) / denom
19
+
20
def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
    """
    Fraction methylated per column.
    Methylated = 1
    Valid = finite AND not 0
    """
    data = np.asarray(matrix)
    finite = np.isfinite(data)

    # Per-column counts: methylated calls (exactly 1) over valid calls
    # (finite and non-zero); NaN/inf entries count toward neither.
    n_methylated = ((data == 1) & finite).sum(axis=0)
    n_valid = (finite & (data != 0)).sum(axis=0)

    # Columns with no valid calls keep the pre-filled 0.0 rather than
    # triggering a divide-by-zero.
    fraction = np.zeros_like(n_methylated, dtype=float)
    np.divide(n_methylated, n_valid, out=fraction, where=n_valid != 0)
    return fraction
4
38
 
5
39
  def clean_barplot(ax, mean_values, title):
6
40
  x = np.arange(len(mean_values))
@@ -17,438 +51,1079 @@ def clean_barplot(ax, mean_values, title):
17
51
 
18
52
  ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
19
53
 
54
+ # def combined_hmm_raw_clustermap(
55
+ # adata,
56
+ # sample_col='Sample_Names',
57
+ # reference_col='Reference_strand',
58
+ # hmm_feature_layer="hmm_combined",
59
+ # layer_gpc="nan0_0minus1",
60
+ # layer_cpg="nan0_0minus1",
61
+ # layer_any_c="nan0_0minus1",
62
+ # cmap_hmm="tab10",
63
+ # cmap_gpc="coolwarm",
64
+ # cmap_cpg="viridis",
65
+ # cmap_any_c='coolwarm',
66
+ # min_quality=20,
67
+ # min_length=200,
68
+ # min_mapped_length_to_reference_length_ratio=0.8,
69
+ # min_position_valid_fraction=0.5,
70
+ # sample_mapping=None,
71
+ # save_path=None,
72
+ # normalize_hmm=False,
73
+ # sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
74
+ # bins=None,
75
+ # deaminase=False,
76
+ # min_signal=0
77
+ # ):
78
+
79
+ # results = []
80
+ # if deaminase:
81
+ # signal_type = 'deamination'
82
+ # else:
83
+ # signal_type = 'methylation'
84
+
85
+ # for ref in adata.obs[reference_col].cat.categories:
86
+ # for sample in adata.obs[sample_col].cat.categories:
87
+ # try:
88
+ # subset = adata[
89
+ # (adata.obs[reference_col] == ref) &
90
+ # (adata.obs[sample_col] == sample) &
91
+ # (adata.obs['read_quality'] >= min_quality) &
92
+ # (adata.obs['read_length'] >= min_length) &
93
+ # (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
94
+ # ]
95
+
96
+ # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
97
+ # subset = subset[:, mask]
98
+
99
+ # if subset.shape[0] == 0:
100
+ # print(f" No reads left after filtering for {sample} - {ref}")
101
+ # continue
102
+
103
+ # if bins:
104
+ # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
105
+ # bins_temp = bins
106
+ # else:
107
+ # print(f"Using all reads for clustermap for {sample} - {ref}")
108
+ # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
109
+
110
+ # # Get column positions (not var_names!) of site masks
111
+ # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
112
+ # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
113
+ # any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
114
+ # num_gpc = len(gpc_sites)
115
+ # num_cpg = len(cpg_sites)
116
+ # num_c = len(any_c_sites)
117
+ # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
118
+
119
+ # # Use var_names for x-axis tick labels
120
+ # gpc_labels = subset.var_names[gpc_sites].astype(int)
121
+ # cpg_labels = subset.var_names[cpg_sites].astype(int)
122
+ # any_c_labels = subset.var_names[any_c_sites].astype(int)
123
+
124
+ # stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
125
+ # row_labels, bin_labels = [], []
126
+ # bin_boundaries = []
127
+
128
+ # total_reads = subset.shape[0]
129
+ # percentages = {}
130
+ # last_idx = 0
131
+
132
+ # for bin_label, bin_filter in bins_temp.items():
133
+ # subset_bin = subset[bin_filter].copy()
134
+ # num_reads = subset_bin.shape[0]
135
+ # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
136
+ # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
137
+ # percentages[bin_label] = percent_reads
138
+
139
+ # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
140
+ # # Determine sorting order
141
+ # if sort_by.startswith("obs:"):
142
+ # colname = sort_by.split("obs:")[1]
143
+ # order = np.argsort(subset_bin.obs[colname].values)
144
+ # elif sort_by == "gpc":
145
+ # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
146
+ # order = sch.leaves_list(linkage)
147
+ # elif sort_by == "cpg":
148
+ # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
149
+ # order = sch.leaves_list(linkage)
150
+ # elif sort_by == "gpc_cpg":
151
+ # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
152
+ # order = sch.leaves_list(linkage)
153
+ # elif sort_by == "none":
154
+ # order = np.arange(num_reads)
155
+ # elif sort_by == "any_c":
156
+ # linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
157
+ # order = sch.leaves_list(linkage)
158
+ # else:
159
+ # raise ValueError(f"Unsupported sort_by option: {sort_by}")
160
+
161
+ # stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
162
+ # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
163
+ # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
164
+ # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
165
+
166
+ # row_labels.extend([bin_label] * num_reads)
167
+ # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
168
+ # last_idx += num_reads
169
+ # bin_boundaries.append(last_idx)
170
+
171
+ # if stacked_hmm_feature:
172
+ # hmm_matrix = np.vstack(stacked_hmm_feature)
173
+ # gpc_matrix = np.vstack(stacked_gpc)
174
+ # cpg_matrix = np.vstack(stacked_cpg)
175
+ # any_c_matrix = np.vstack(stacked_any_c)
176
+
177
+ # if hmm_matrix.size > 0:
178
+ # def normalized_mean(matrix):
179
+ # mean = np.nanmean(matrix, axis=0)
180
+ # normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
181
+ # return normalized
182
+
183
+ # def methylation_fraction(matrix):
184
+ # methylated = (matrix == 1).sum(axis=0)
185
+ # valid = (matrix != 0).sum(axis=0)
186
+ # return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
187
+
188
+ # if normalize_hmm:
189
+ # mean_hmm = normalized_mean(hmm_matrix)
190
+ # else:
191
+ # mean_hmm = np.nanmean(hmm_matrix, axis=0)
192
+ # mean_gpc = methylation_fraction(gpc_matrix)
193
+ # mean_cpg = methylation_fraction(cpg_matrix)
194
+ # mean_any_c = methylation_fraction(any_c_matrix)
195
+
196
+ # fig = plt.figure(figsize=(18, 12))
197
+ # gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
198
+ # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
199
+
200
+ # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
201
+ # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
202
+
203
+ # clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
204
+ # clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
205
+ # clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
206
+ # clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
207
+
208
+ # hmm_labels = subset.var_names.astype(int)
209
+ # hmm_label_spacing = 150
210
+ # sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
211
+ # axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
212
+ # axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
213
+ # for boundary in bin_boundaries[:-1]:
214
+ # axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
215
+
216
+ # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
217
+ # axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
218
+ # axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
219
+ # for boundary in bin_boundaries[:-1]:
220
+ # axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
221
+
222
+ # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
223
+ # axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
224
+ # for boundary in bin_boundaries[:-1]:
225
+ # axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
226
+
227
+ # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
228
+ # axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
229
+ # axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
230
+ # for boundary in bin_boundaries[:-1]:
231
+ # axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
232
+
233
+ # plt.tight_layout()
234
+
235
+ # if save_path:
236
+ # save_name = f"{ref} — {sample}"
237
+ # os.makedirs(save_path, exist_ok=True)
238
+ # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
239
+ # out_file = os.path.join(save_path, f"{safe_name}.png")
240
+ # plt.savefig(out_file, dpi=300)
241
+ # print(f"Saved: {out_file}")
242
+ # plt.close()
243
+ # else:
244
+ # plt.show()
245
+
246
+ # print(f"Summary for {sample} - {ref}:")
247
+ # for bin_label, percent in percentages.items():
248
+ # print(f" - {bin_label}: {percent:.1f}%")
249
+
250
+ # results.append({
251
+ # "sample": sample,
252
+ # "ref": ref,
253
+ # "hmm_matrix": hmm_matrix,
254
+ # "gpc_matrix": gpc_matrix,
255
+ # "cpg_matrix": cpg_matrix,
256
+ # "row_labels": row_labels,
257
+ # "bin_labels": bin_labels,
258
+ # "bin_boundaries": bin_boundaries,
259
+ # "percentages": percentages
260
+ # })
261
+
262
+ # #adata.uns['clustermap_results'] = results
263
+
264
+ # except Exception as e:
265
+ # import traceback
266
+ # traceback.print_exc()
267
+ # continue
268
+
269
+
20
270
  def combined_hmm_raw_clustermap(
21
271
  adata,
22
- sample_col='Sample_Names',
23
- reference_col='Reference_strand',
24
- hmm_feature_layer="hmm_combined",
25
- layer_gpc="nan0_0minus1",
26
- layer_cpg="nan0_0minus1",
27
- layer_any_c="nan0_0minus1",
28
- cmap_hmm="tab10",
29
- cmap_gpc="coolwarm",
30
- cmap_cpg="viridis",
31
- cmap_any_c='coolwarm',
32
- min_quality=20,
33
- min_length=200,
34
- min_mapped_length_to_reference_length_ratio=0.8,
35
- min_position_valid_fraction=0.5,
36
- sample_mapping=None,
37
- save_path=None,
38
- normalize_hmm=False,
39
- sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
40
- bins=None,
41
- deaminase=False,
42
- min_signal=0
43
- ):
44
- import scipy.cluster.hierarchy as sch
45
- import pandas as pd
46
- import numpy as np
47
- import seaborn as sns
48
- import matplotlib.pyplot as plt
49
- import matplotlib.gridspec as gridspec
50
- import os
272
+ sample_col: str = "Sample_Names",
273
+ reference_col: str = "Reference_strand",
274
+
275
+ hmm_feature_layer: str = "hmm_combined",
276
+
277
+ layer_gpc: str = "nan0_0minus1",
278
+ layer_cpg: str = "nan0_0minus1",
279
+ layer_any_c: str = "nan0_0minus1",
280
+ layer_a: str = "nan0_0minus1",
281
+
282
+ cmap_hmm: str = "tab10",
283
+ cmap_gpc: str = "coolwarm",
284
+ cmap_cpg: str = "viridis",
285
+ cmap_any_c: str = "coolwarm",
286
+ cmap_a: str = "coolwarm",
287
+
288
+ min_quality: int = 20,
289
+ min_length: int = 200,
290
+ min_mapped_length_to_reference_length_ratio: float = 0.8,
291
+ min_position_valid_fraction: float = 0.5,
292
+
293
+ save_path: str | Path | None = None,
294
+ normalize_hmm: bool = False,
295
+
296
+ sort_by: str = "gpc",
297
+ bins: Optional[Dict[str, Any]] = None,
298
+
299
+ deaminase: bool = False,
300
+ min_signal: float = 0.0,
301
+
302
+ # ---- fixed tick label controls (counts, not spacing)
303
+ n_xticks_hmm: int = 10,
304
+ n_xticks_any_c: int = 8,
305
+ n_xticks_gpc: int = 8,
306
+ n_xticks_cpg: int = 8,
307
+ n_xticks_a: int = 8,
308
+ ):
309
+ """
310
+ Makes a multi-panel clustermap per (sample, reference):
311
+ HMM panel (always) + optional raw panels for any_C, GpC, CpG, and A sites.
312
+
313
+ Panels are added only if the corresponding site mask exists AND has >0 sites.
51
314
 
315
+ sort_by options:
316
+ 'gpc', 'cpg', 'any_c', 'any_a', 'gpc_cpg', 'none', or 'obs:<col>'
317
+ """
318
+ def pick_xticks(labels: np.ndarray, n_ticks: int):
319
+ if labels.size == 0:
320
+ return [], []
321
+ idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
322
+ idx = np.unique(idx)
323
+ return idx.tolist(), labels[idx].tolist()
324
+
52
325
  results = []
53
- if deaminase:
54
- signal_type = 'deamination'
55
- else:
56
- signal_type = 'methylation'
326
+ signal_type = "deamination" if deaminase else "methylation"
57
327
 
58
328
  for ref in adata.obs[reference_col].cat.categories:
59
329
  for sample in adata.obs[sample_col].cat.categories:
330
+
60
331
  try:
332
+ # ---- subset reads ----
61
333
  subset = adata[
62
334
  (adata.obs[reference_col] == ref) &
63
335
  (adata.obs[sample_col] == sample) &
64
- (adata.obs['read_quality'] >= min_quality) &
65
- (adata.obs['read_length'] >= min_length) &
66
- (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
336
+ (adata.obs["read_quality"] >= min_quality) &
337
+ (adata.obs["read_length"] >= min_length) &
338
+ (
339
+ adata.obs["mapped_length_to_reference_length_ratio"]
340
+ > min_mapped_length_to_reference_length_ratio
341
+ )
67
342
  ]
68
-
69
- mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
70
- subset = subset[:, mask]
343
+
344
+ # ---- valid fraction filter ----
345
+ vf_key = f"{ref}_valid_fraction"
346
+ if vf_key in subset.var:
347
+ mask = subset.var[vf_key].astype(float) > float(min_position_valid_fraction)
348
+ subset = subset[:, mask]
71
349
 
72
350
  if subset.shape[0] == 0:
73
- print(f" No reads left after filtering for {sample} - {ref}")
74
351
  continue
75
352
 
76
- if bins:
77
- print(f"Using defined bins to subset clustermap for {sample} - {ref}")
78
- bins_temp = bins
353
+ # ---- bins ----
354
+ if bins is None:
355
+ bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
79
356
  else:
80
- print(f"Using all reads for clustermap for {sample} - {ref}")
81
- bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
82
-
83
- # Get column positions (not var_names!) of site masks
84
- gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
85
- cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
86
- any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
87
- num_gpc = len(gpc_sites)
88
- num_cpg = len(cpg_sites)
89
- num_c = len(any_c_sites)
90
- print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
91
-
92
- # Use var_names for x-axis tick labels
93
- gpc_labels = subset.var_names[gpc_sites].astype(int)
94
- cpg_labels = subset.var_names[cpg_sites].astype(int)
95
- any_c_labels = subset.var_names[any_c_sites].astype(int)
96
-
97
- stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
98
- row_labels, bin_labels = [], []
99
- bin_boundaries = []
357
+ bins_temp = bins
100
358
 
101
- total_reads = subset.shape[0]
359
+ # ---- site masks (robust) ----
360
+ def _sites(*keys):
361
+ for k in keys:
362
+ if k in subset.var:
363
+ return np.where(subset.var[k].values)[0]
364
+ return np.array([], dtype=int)
365
+
366
+ gpc_sites = _sites(f"{ref}_GpC_site")
367
+ cpg_sites = _sites(f"{ref}_CpG_site")
368
+ any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
369
+ any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
370
+
371
+ def _labels(sites):
372
+ return subset.var_names[sites].astype(int) if sites.size else np.array([])
373
+
374
+ gpc_labels = _labels(gpc_sites)
375
+ cpg_labels = _labels(cpg_sites)
376
+ any_c_labels = _labels(any_c_sites)
377
+ any_a_labels = _labels(any_a_sites)
378
+
379
+ # storage
380
+ stacked_hmm = []
381
+ stacked_any_c = []
382
+ stacked_gpc = []
383
+ stacked_cpg = []
384
+ stacked_any_a = []
385
+
386
+ row_labels, bin_labels, bin_boundaries = [], [], []
387
+ total_reads = subset.n_obs
102
388
  percentages = {}
103
389
  last_idx = 0
104
390
 
391
+ # ---------------- process bins ----------------
105
392
  for bin_label, bin_filter in bins_temp.items():
106
- subset_bin = subset[bin_filter].copy()
107
- num_reads = subset_bin.shape[0]
108
- print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
109
- percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
110
- percentages[bin_label] = percent_reads
393
+ sb = subset[bin_filter].copy()
394
+ n = sb.n_obs
395
+ if n == 0:
396
+ continue
111
397
 
112
- if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
113
- # Determine sorting order
114
- if sort_by.startswith("obs:"):
115
- colname = sort_by.split("obs:")[1]
116
- order = np.argsort(subset_bin.obs[colname].values)
117
- elif sort_by == "gpc":
118
- linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
119
- order = sch.leaves_list(linkage)
120
- elif sort_by == "cpg":
121
- linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
122
- order = sch.leaves_list(linkage)
123
- elif sort_by == "gpc_cpg":
124
- linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
125
- order = sch.leaves_list(linkage)
126
- elif sort_by == "none":
127
- order = np.arange(num_reads)
128
- elif sort_by == "any_c":
129
- linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
130
- order = sch.leaves_list(linkage)
131
- else:
132
- raise ValueError(f"Unsupported sort_by option: {sort_by}")
133
-
134
- stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
135
- stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
136
- stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
137
- stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
138
-
139
- row_labels.extend([bin_label] * num_reads)
140
- bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
141
- last_idx += num_reads
142
- bin_boundaries.append(last_idx)
143
-
144
- if stacked_hmm_feature:
145
- hmm_matrix = np.vstack(stacked_hmm_feature)
146
- gpc_matrix = np.vstack(stacked_gpc)
147
- cpg_matrix = np.vstack(stacked_cpg)
148
- any_c_matrix = np.vstack(stacked_any_c)
398
+ pct = (n / total_reads) * 100 if total_reads else 0
399
+ percentages[bin_label] = pct
149
400
 
150
- if hmm_matrix.size > 0:
151
- def normalized_mean(matrix):
152
- mean = np.nanmean(matrix, axis=0)
153
- normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
154
- return normalized
401
+ # ---- sorting ----
402
+ if sort_by.startswith("obs:"):
403
+ colname = sort_by.split("obs:")[1]
404
+ order = np.argsort(sb.obs[colname].values)
155
405
 
156
- def methylation_fraction(matrix):
157
- methylated = (matrix == 1).sum(axis=0)
158
- valid = (matrix != 0).sum(axis=0)
159
- return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
406
+ elif sort_by == "gpc" and gpc_sites.size:
407
+ linkage = sch.linkage(sb[:, gpc_sites].layers[layer_gpc], method="ward")
408
+ order = sch.leaves_list(linkage)
160
409
 
161
- if normalize_hmm:
162
- mean_hmm = normalized_mean(hmm_matrix)
163
- else:
164
- mean_hmm = np.nanmean(hmm_matrix, axis=0)
165
- mean_gpc = methylation_fraction(gpc_matrix)
166
- mean_cpg = methylation_fraction(cpg_matrix)
167
- mean_any_c = methylation_fraction(any_c_matrix)
168
-
169
- fig = plt.figure(figsize=(18, 12))
170
- gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
171
- fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)
172
-
173
- axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
174
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
175
-
176
- clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
177
- clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
178
- clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
179
- clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
180
-
181
- hmm_labels = subset.var_names.astype(int)
182
- hmm_label_spacing = 150
183
- sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
184
- axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
185
- axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
186
- for boundary in bin_boundaries[:-1]:
187
- axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
188
-
189
- sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
190
- axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
191
- axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
192
- for boundary in bin_boundaries[:-1]:
193
- axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
194
-
195
- sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
196
- axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
197
- for boundary in bin_boundaries[:-1]:
198
- axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
199
-
200
- sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
201
- axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
202
- axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
203
- for boundary in bin_boundaries[:-1]:
204
- axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
205
-
206
- plt.tight_layout()
207
-
208
- if save_path:
209
- save_name = f"{ref} — {sample}"
210
- os.makedirs(save_path, exist_ok=True)
211
- safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
212
- out_file = os.path.join(save_path, f"{safe_name}.png")
213
- plt.savefig(out_file, dpi=300)
214
- print(f"Saved: {out_file}")
215
- plt.close()
216
- else:
217
- plt.show()
218
-
219
- print(f"Summary for {sample} - {ref}:")
220
- for bin_label, percent in percentages.items():
221
- print(f" - {bin_label}: {percent:.1f}%")
222
-
223
- results.append({
224
- "sample": sample,
225
- "ref": ref,
226
- "hmm_matrix": hmm_matrix,
227
- "gpc_matrix": gpc_matrix,
228
- "cpg_matrix": cpg_matrix,
229
- "row_labels": row_labels,
230
- "bin_labels": bin_labels,
231
- "bin_boundaries": bin_boundaries,
232
- "percentages": percentages
233
- })
234
-
235
- adata.uns['clustermap_results'] = results
410
+ elif sort_by == "cpg" and cpg_sites.size:
411
+ linkage = sch.linkage(sb[:, cpg_sites].layers[layer_cpg], method="ward")
412
+ order = sch.leaves_list(linkage)
236
413
 
237
- except Exception as e:
414
+ elif sort_by == "any_c" and any_c_sites.size:
415
+ linkage = sch.linkage(sb[:, any_c_sites].layers[layer_any_c], method="ward")
416
+ order = sch.leaves_list(linkage)
417
+
418
+ elif sort_by == "any_a" and any_a_sites.size:
419
+ linkage = sch.linkage(sb[:, any_a_sites].layers[layer_a], method="ward")
420
+ order = sch.leaves_list(linkage)
421
+
422
+ elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
423
+ linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
424
+ order = sch.leaves_list(linkage)
425
+
426
+ else:
427
+ order = np.arange(n)
428
+
429
+ sb = sb[order]
430
+
431
+ # ---- collect matrices ----
432
+ stacked_hmm.append(sb.layers[hmm_feature_layer])
433
+ if any_c_sites.size:
434
+ stacked_any_c.append(sb[:, any_c_sites].layers[layer_any_c])
435
+ if gpc_sites.size:
436
+ stacked_gpc.append(sb[:, gpc_sites].layers[layer_gpc])
437
+ if cpg_sites.size:
438
+ stacked_cpg.append(sb[:, cpg_sites].layers[layer_cpg])
439
+ if any_a_sites.size:
440
+ stacked_any_a.append(sb[:, any_a_sites].layers[layer_a])
441
+
442
+ row_labels.extend([bin_label] * n)
443
+ bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
444
+ last_idx += n
445
+ bin_boundaries.append(last_idx)
446
+
447
+ # ---------------- stack ----------------
448
+ hmm_matrix = np.vstack(stacked_hmm)
449
+ mean_hmm = normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
450
+
451
+ panels = [
452
+ ("HMM", hmm_matrix, subset.var_names.astype(int), cmap_hmm, mean_hmm, n_xticks_hmm),
453
+ ]
454
+
455
+ if stacked_any_c:
456
+ m = np.vstack(stacked_any_c)
457
+ panels.append(("any_C", m, any_c_labels, cmap_any_c, methylation_fraction(m), n_xticks_any_c))
458
+
459
+ if stacked_gpc:
460
+ m = np.vstack(stacked_gpc)
461
+ panels.append(("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc))
462
+
463
+ if stacked_cpg:
464
+ m = np.vstack(stacked_cpg)
465
+ panels.append(("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg))
466
+
467
+ if stacked_any_a:
468
+ m = np.vstack(stacked_any_a)
469
+ panels.append(("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a))
470
+
471
+ # ---------------- plotting ----------------
472
+ n_panels = len(panels)
473
+ fig = plt.figure(figsize=(4.5 * n_panels, 10))
474
+ gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
475
+ fig.suptitle(f"{sample} — {ref} — {total_reads} reads ({signal_type})",
476
+ fontsize=14, y=0.98)
477
+
478
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
479
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
480
+
481
+ for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
482
+
483
+ # ---- your clean barplot ----
484
+ clean_barplot(axes_bar[i], mean_vec, name)
485
+
486
+ # ---- heatmap ----
487
+ sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
488
+ yticklabels=False, cbar=False)
489
+
490
+ # ---- xticks ----
491
+ xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
492
+ axes_heat[i].set_xticks(xtick_pos)
493
+ axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)
494
+
495
+ for boundary in bin_boundaries[:-1]:
496
+ axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)
497
+
498
+ plt.tight_layout()
499
+
500
+ if save_path:
501
+ save_path = Path(save_path)
502
+ save_path.mkdir(parents=True, exist_ok=True)
503
+ safe_name = f"{ref}__{sample}".replace("/", "_")
504
+ out_file = save_path / f"{safe_name}.png"
505
+ plt.savefig(out_file, dpi=300)
506
+ plt.close(fig)
507
+ else:
508
+ plt.show()
509
+
510
+ except Exception:
238
511
  import traceback
239
512
  traceback.print_exc()
240
513
  continue
241
514
 
242
515
 
516
+ # def combined_raw_clustermap(
517
+ # adata,
518
+ # sample_col='Sample_Names',
519
+ # reference_col='Reference_strand',
520
+ # mod_target_bases=['GpC', 'CpG'],
521
+ # layer_any_c="nan0_0minus1",
522
+ # layer_gpc="nan0_0minus1",
523
+ # layer_cpg="nan0_0minus1",
524
+ # layer_a="nan0_0minus1",
525
+ # cmap_any_c="coolwarm",
526
+ # cmap_gpc="coolwarm",
527
+ # cmap_cpg="viridis",
528
+ # cmap_a="coolwarm",
529
+ # min_quality=20,
530
+ # min_length=200,
531
+ # min_mapped_length_to_reference_length_ratio=0.8,
532
+ # min_position_valid_fraction=0.5,
533
+ # sample_mapping=None,
534
+ # save_path=None,
535
+ # sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', 'any_a', or 'obs:<column>'
536
+ # bins=None,
537
+ # deaminase=False,
538
+ # min_signal=0
539
+ # ):
540
+
541
+ # results = []
542
+
543
+ # for ref in adata.obs[reference_col].cat.categories:
544
+ # for sample in adata.obs[sample_col].cat.categories:
545
+ # try:
546
+ # subset = adata[
547
+ # (adata.obs[reference_col] == ref) &
548
+ # (adata.obs[sample_col] == sample) &
549
+ # (adata.obs['read_quality'] >= min_quality) &
550
+ # (adata.obs['mapped_length'] >= min_length) &
551
+ # (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
552
+ # ]
553
+
554
+ # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
555
+ # subset = subset[:, mask]
556
+
557
+ # if subset.shape[0] == 0:
558
+ # print(f" No reads left after filtering for {sample} - {ref}")
559
+ # continue
560
+
561
+ # if bins:
562
+ # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
563
+ # bins_temp = bins
564
+ # else:
565
+ # print(f"Using all reads for clustermap for {sample} - {ref}")
566
+ # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
567
+
568
+ # num_any_c = 0
569
+ # num_gpc = 0
570
+ # num_cpg = 0
571
+ # num_any_a = 0
572
+
573
+ # # Get column positions (not var_names!) of site masks
574
+ # if any(base in ["C", "CpG", "GpC"] for base in mod_target_bases):
575
+ # any_c_sites = np.where(subset.var[f"{ref}_C_site"].values)[0]
576
+ # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
577
+ # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
578
+ # num_any_c = len(any_c_sites)
579
+ # num_gpc = len(gpc_sites)
580
+ # num_cpg = len(cpg_sites)
581
+ # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
582
+
583
+ # # Use var_names for x-axis tick labels
584
+ # gpc_labels = subset.var_names[gpc_sites].astype(int)
585
+ # cpg_labels = subset.var_names[cpg_sites].astype(int)
586
+ # any_c_labels = subset.var_names[any_c_sites].astype(int)
587
+ # stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
588
+
589
+ # if "A" in mod_target_bases:
590
+ # any_a_sites = np.where(subset.var[f"{ref}_A_site"].values)[0]
591
+ # num_any_a = len(any_a_sites)
592
+ # print(f"Found {num_any_a} any_A sites at {any_a_sites} for {sample} - {ref}")
593
+ # any_a_labels = subset.var_names[any_a_sites].astype(int)
594
+ # stacked_any_a = []
595
+
596
+ # row_labels, bin_labels = [], []
597
+ # bin_boundaries = []
598
+
599
+ # total_reads = subset.shape[0]
600
+ # percentages = {}
601
+ # last_idx = 0
602
+
603
+ # for bin_label, bin_filter in bins_temp.items():
604
+ # subset_bin = subset[bin_filter].copy()
605
+ # num_reads = subset_bin.shape[0]
606
+ # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
607
+ # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
608
+ # percentages[bin_label] = percent_reads
609
+
610
+ # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
611
+ # # Determine sorting order
612
+ # if sort_by.startswith("obs:"):
613
+ # colname = sort_by.split("obs:")[1]
614
+ # order = np.argsort(subset_bin.obs[colname].values)
615
+ # elif sort_by == "gpc":
616
+ # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
617
+ # order = sch.leaves_list(linkage)
618
+ # elif sort_by == "cpg":
619
+ # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
620
+ # order = sch.leaves_list(linkage)
621
+ # elif sort_by == "any_c":
622
+ # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
623
+ # order = sch.leaves_list(linkage)
624
+ # elif sort_by == "gpc_cpg":
625
+ # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
626
+ # order = sch.leaves_list(linkage)
627
+ # elif sort_by == "none":
628
+ # order = np.arange(num_reads)
629
+ # elif sort_by == "any_a":
630
+ # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
631
+ # order = sch.leaves_list(linkage)
632
+ # else:
633
+ # raise ValueError(f"Unsupported sort_by option: {sort_by}")
634
+
635
+ # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
636
+ # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
637
+ # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
638
+
639
+ # if num_reads > 0 and num_any_a > 0:
640
+ # # Determine sorting order
641
+ # if sort_by.startswith("obs:"):
642
+ # colname = sort_by.split("obs:")[1]
643
+ # order = np.argsort(subset_bin.obs[colname].values)
644
+ # elif sort_by == "gpc":
645
+ # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
646
+ # order = sch.leaves_list(linkage)
647
+ # elif sort_by == "cpg":
648
+ # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
649
+ # order = sch.leaves_list(linkage)
650
+ # elif sort_by == "any_c":
651
+ # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
652
+ # order = sch.leaves_list(linkage)
653
+ # elif sort_by == "gpc_cpg":
654
+ # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
655
+ # order = sch.leaves_list(linkage)
656
+ # elif sort_by == "none":
657
+ # order = np.arange(num_reads)
658
+ # elif sort_by == "any_a":
659
+ # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
660
+ # order = sch.leaves_list(linkage)
661
+ # else:
662
+ # raise ValueError(f"Unsupported sort_by option: {sort_by}")
663
+
664
+ # stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
665
+
666
+
667
+ # row_labels.extend([bin_label] * num_reads)
668
+ # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
669
+ # last_idx += num_reads
670
+ # bin_boundaries.append(last_idx)
671
+
672
+ # gs_dim = 0
673
+
674
+ # if stacked_any_c:
675
+ # any_c_matrix = np.vstack(stacked_any_c)
676
+ # gpc_matrix = np.vstack(stacked_gpc)
677
+ # cpg_matrix = np.vstack(stacked_cpg)
678
+ # if any_c_matrix.size > 0:
679
+ # mean_gpc = methylation_fraction(gpc_matrix)
680
+ # mean_cpg = methylation_fraction(cpg_matrix)
681
+ # mean_any_c = methylation_fraction(any_c_matrix)
682
+ # gs_dim += 3
683
+
684
+ # if stacked_any_a:
685
+ # any_a_matrix = np.vstack(stacked_any_a)
686
+ # if any_a_matrix.size > 0:
687
+ # mean_any_a = methylation_fraction(any_a_matrix)
688
+ # gs_dim += 1
689
+
690
+
691
+ # fig = plt.figure(figsize=(18, 12))
692
+ # gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
693
+ # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
694
+ # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
695
+ # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
696
+
697
+ # current_ax = 0
698
+
699
+ # if stacked_any_c:
700
+ # if any_c_matrix.size > 0:
701
+ # clean_barplot(axes_bar[current_ax], mean_any_c, f"any C site Modification Signal")
702
+ # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[current_ax], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
703
+ # axes_heat[current_ax].set_xticks(range(0, len(any_c_labels), 20))
704
+ # axes_heat[current_ax].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
705
+ # for boundary in bin_boundaries[:-1]:
706
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
707
+ # current_ax +=1
708
+
709
+ # clean_barplot(axes_bar[current_ax], mean_gpc, f"GpC Modification Signal")
710
+ # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[current_ax], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
711
+ # axes_heat[current_ax].set_xticks(range(0, len(gpc_labels), 5))
712
+ # axes_heat[current_ax].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
713
+ # for boundary in bin_boundaries[:-1]:
714
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
715
+ # current_ax +=1
716
+
717
+ # clean_barplot(axes_bar[current_ax], mean_cpg, f"CpG Modification Signal")
718
+ # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
719
+ # axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
720
+ # for boundary in bin_boundaries[:-1]:
721
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
722
+ # current_ax +=1
723
+
724
+ # results.append({
725
+ # "sample": sample,
726
+ # "ref": ref,
727
+ # "any_c_matrix": any_c_matrix,
728
+ # "gpc_matrix": gpc_matrix,
729
+ # "cpg_matrix": cpg_matrix,
730
+ # "row_labels": row_labels,
731
+ # "bin_labels": bin_labels,
732
+ # "bin_boundaries": bin_boundaries,
733
+ # "percentages": percentages
734
+ # })
735
+
736
+ # if stacked_any_a:
737
+ # if any_a_matrix.size > 0:
738
+ # clean_barplot(axes_bar[current_ax], mean_any_a, f"any A site Modification Signal")
739
+ # sns.heatmap(any_a_matrix, cmap=cmap_a, ax=axes_heat[current_ax], xticklabels=any_a_labels[::20], yticklabels=False, cbar=False)
740
+ # axes_heat[current_ax].set_xticks(range(0, len(any_a_labels), 20))
741
+ # axes_heat[current_ax].set_xticklabels(any_a_labels[::20], rotation=90, fontsize=10)
742
+ # for boundary in bin_boundaries[:-1]:
743
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
744
+ # current_ax +=1
745
+
746
+ # results.append({
747
+ # "sample": sample,
748
+ # "ref": ref,
749
+ # "any_a_matrix": any_a_matrix,
750
+ # "row_labels": row_labels,
751
+ # "bin_labels": bin_labels,
752
+ # "bin_boundaries": bin_boundaries,
753
+ # "percentages": percentages
754
+ # })
755
+
756
+ # plt.tight_layout()
757
+
758
+ # if save_path:
759
+ # save_name = f"{ref} — {sample}"
760
+ # os.makedirs(save_path, exist_ok=True)
761
+ # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
762
+ # out_file = os.path.join(save_path, f"{safe_name}.png")
763
+ # plt.savefig(out_file, dpi=300)
764
+ # print(f"Saved: {out_file}")
765
+ # plt.close()
766
+ # else:
767
+ # plt.show()
768
+
769
+ # print(f"Summary for {sample} - {ref}:")
770
+ # for bin_label, percent in percentages.items():
771
+ # print(f" - {bin_label}: {percent:.1f}%")
772
+
773
+ # adata.uns['clustermap_results'] = results
774
+
775
+ # except Exception as e:
776
+ # import traceback
777
+ # traceback.print_exc()
778
+ # continue
779
+
780
+ def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
781
+ """
782
+ Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
783
+ Always includes 0 and n_positions-1 when possible.
784
+ """
785
+ n_ticks = int(max(2, n_ticks))
786
+ if n_positions <= n_ticks:
787
+ return np.arange(n_positions)
788
+
789
+ # linspace gives fixed count
790
+ pos = np.linspace(0, n_positions - 1, n_ticks)
791
+ return np.unique(np.round(pos).astype(int))
792
+
243
793
  def combined_raw_clustermap(
244
794
  adata,
245
- sample_col='Sample_Names',
246
- reference_col='Reference_strand',
247
- layer_any_c="nan0_0minus1",
248
- layer_gpc="nan0_0minus1",
249
- layer_cpg="nan0_0minus1",
250
- cmap_any_c="coolwarm",
251
- cmap_gpc="coolwarm",
252
- cmap_cpg="viridis",
253
- min_quality=20,
254
- min_length=200,
255
- min_mapped_length_to_reference_length_ratio=0.8,
256
- min_position_valid_fraction=0.5,
257
- sample_mapping=None,
258
- save_path=None,
259
- sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
260
- bins=None,
261
- deaminase=False,
262
- min_signal=0
263
- ):
264
- import scipy.cluster.hierarchy as sch
265
- import pandas as pd
266
- import numpy as np
267
- import seaborn as sns
268
- import matplotlib.pyplot as plt
269
- import matplotlib.gridspec as gridspec
270
- import os
795
+ sample_col: str = "Sample_Names",
796
+ reference_col: str = "Reference_strand",
797
+ mod_target_bases: Sequence[str] = ("GpC", "CpG"),
798
+ layer_any_c: str = "nan0_0minus1",
799
+ layer_gpc: str = "nan0_0minus1",
800
+ layer_cpg: str = "nan0_0minus1",
801
+ layer_a: str = "nan0_0minus1",
802
+ cmap_any_c: str = "coolwarm",
803
+ cmap_gpc: str = "coolwarm",
804
+ cmap_cpg: str = "viridis",
805
+ cmap_a: str = "coolwarm",
806
+ min_quality: float = 20,
807
+ min_length: int = 200,
808
+ min_mapped_length_to_reference_length_ratio: float = 0.8,
809
+ min_position_valid_fraction: float = 0.5,
810
+ sample_mapping: Optional[Mapping[str, str]] = None,
811
+ save_path: str | Path | None = None,
812
+ sort_by: str = "gpc", # 'gpc','cpg','any_c','gpc_cpg','any_a','none','obs:<col>'
813
+ bins: Optional[Dict[str, Any]] = None,
814
+ deaminase: bool = False,
815
+ min_signal: float = 0,
816
+ # NEW tick controls
817
+ n_xticks_any_c: int = 10,
818
+ n_xticks_gpc: int = 10,
819
+ n_xticks_cpg: int = 10,
820
+ n_xticks_any_a: int = 10,
821
+ xtick_rotation: int = 90,
822
+ xtick_fontsize: int = 9,
823
+ ):
824
+ """
825
+ Plot stacked heatmaps + per-position mean barplots for any_C, GpC, CpG, and optional A.
271
826
 
272
- results = []
827
+ Key fixes vs old version:
828
+ - order computed ONCE per bin, applied to all matrices
829
+ - no hard-coded axes indices
830
+ - NaNs excluded from methylation denominators
831
+ - var_names not forced to int
832
+ - fixed count of x tick labels per block (controllable)
833
+ - adata.uns updated once at end
834
+
835
+ Returns
836
+ -------
837
+ results : list[dict]
838
+ One entry per (sample, ref) plot with matrices + bin metadata.
839
+ """
840
+
841
+ results: List[Dict[str, Any]] = []
842
+ save_path = Path(save_path) if save_path is not None else None
843
+ if save_path is not None:
844
+ save_path.mkdir(parents=True, exist_ok=True)
845
+
846
+ # Ensure categorical
847
+ for col in (sample_col, reference_col):
848
+ if col not in adata.obs:
849
+ raise KeyError(f"{col} not in adata.obs")
850
+ if not pd.api.types.is_categorical_dtype(adata.obs[col]):
851
+ adata.obs[col] = adata.obs[col].astype("category")
852
+
853
+ base_set = set(mod_target_bases)
854
+ include_any_c = any(b in {"C", "CpG", "GpC"} for b in base_set)
855
+ include_any_a = "A" in base_set
273
856
 
274
857
  for ref in adata.obs[reference_col].cat.categories:
275
858
  for sample in adata.obs[sample_col].cat.categories:
859
+
860
+ # Optionally remap sample label for display
861
+ display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
862
+
276
863
  try:
277
864
  subset = adata[
278
865
  (adata.obs[reference_col] == ref) &
279
866
  (adata.obs[sample_col] == sample) &
280
- (adata.obs['read_quality'] >= min_quality) &
281
- (adata.obs['mapped_length'] >= min_length) &
282
- (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
867
+ (adata.obs["read_quality"] >= min_quality) &
868
+ (adata.obs["mapped_length"] >= min_length) &
869
+ (adata.obs["mapped_length_to_reference_length_ratio"] >= min_mapped_length_to_reference_length_ratio)
283
870
  ]
284
871
 
285
- mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
286
- subset = subset[:, mask]
872
+ # position-level mask
873
+ valid_key = f"{ref}_valid_fraction"
874
+ if valid_key in subset.var:
875
+ mask = subset.var[valid_key].astype(float).values > float(min_position_valid_fraction)
876
+ subset = subset[:, mask]
287
877
 
288
878
  if subset.shape[0] == 0:
289
- print(f" No reads left after filtering for {sample} - {ref}")
879
+ print(f"No reads left after filtering for {display_sample} - {ref}")
290
880
  continue
291
881
 
292
- if bins:
293
- print(f"Using defined bins to subset clustermap for {sample} - {ref}")
294
- bins_temp = bins
882
+ # bins mode
883
+ if bins is None:
884
+ bins_temp = {"All": (subset.obs[reference_col] == ref)}
295
885
  else:
296
- print(f"Using all reads for clustermap for {sample} - {ref}")
297
- bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
298
-
299
- # Get column positions (not var_names!) of site masks
300
- any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
301
- gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
302
- cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
303
- num_any_c = len(any_c_sites)
304
- num_gpc = len(gpc_sites)
305
- num_cpg = len(cpg_sites)
306
- print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
307
-
308
- # Use var_names for x-axis tick labels
309
- gpc_labels = subset.var_names[gpc_sites].astype(int)
310
- cpg_labels = subset.var_names[cpg_sites].astype(int)
311
- any_c_labels = subset.var_names[any_c_sites].astype(int)
312
-
313
- stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
314
- row_labels, bin_labels = [], []
315
- bin_boundaries = []
886
+ bins_temp = bins
316
887
 
317
- total_reads = subset.shape[0]
888
+ # find sites (positions)
889
+ any_c_sites = gpc_sites = cpg_sites = np.array([], dtype=int)
890
+ any_a_sites = np.array([], dtype=int)
891
+
892
+ num_any_c = num_gpc = num_cpg = num_any_a = 0
893
+
894
+ if include_any_c:
895
+ any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
896
+ gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
897
+ cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
898
+
899
+ num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
900
+
901
+ any_c_labels = subset.var_names[any_c_sites].astype(str)
902
+ gpc_labels = subset.var_names[gpc_sites].astype(str)
903
+ cpg_labels = subset.var_names[cpg_sites].astype(str)
904
+
905
+ if include_any_a:
906
+ any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
907
+ num_any_a = len(any_a_sites)
908
+ any_a_labels = subset.var_names[any_a_sites].astype(str)
909
+
910
+ stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
911
+ row_labels, bin_labels, bin_boundaries = [], [], []
318
912
  percentages = {}
319
913
  last_idx = 0
914
+ total_reads = subset.shape[0]
320
915
 
916
+ # ----------------------------
917
+ # per-bin stacking
918
+ # ----------------------------
321
919
  for bin_label, bin_filter in bins_temp.items():
322
920
  subset_bin = subset[bin_filter].copy()
323
921
  num_reads = subset_bin.shape[0]
324
- print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
325
- percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
922
+ if num_reads == 0:
923
+ percentages[bin_label] = 0.0
924
+ continue
925
+
926
+ percent_reads = (num_reads / total_reads) * 100
326
927
  percentages[bin_label] = percent_reads
327
928
 
328
- if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
329
- # Determine sorting order
330
- if sort_by.startswith("obs:"):
331
- colname = sort_by.split("obs:")[1]
332
- order = np.argsort(subset_bin.obs[colname].values)
333
- elif sort_by == "gpc":
334
- linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
335
- order = sch.leaves_list(linkage)
336
- elif sort_by == "cpg":
337
- linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
338
- order = sch.leaves_list(linkage)
339
- elif sort_by == "any_c":
340
- linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
341
- order = sch.leaves_list(linkage)
342
- elif sort_by == "gpc_cpg":
343
- linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
344
- order = sch.leaves_list(linkage)
345
- elif sort_by == "none":
346
- order = np.arange(num_reads)
347
- else:
348
- raise ValueError(f"Unsupported sort_by option: {sort_by}")
929
+ # compute order ONCE
930
+ if sort_by.startswith("obs:"):
931
+ colname = sort_by.split("obs:")[1]
932
+ order = np.argsort(subset_bin.obs[colname].values)
349
933
 
350
- stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
351
- stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
352
- stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
934
+ elif sort_by == "gpc" and num_gpc > 0:
935
+ linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
936
+ order = sch.leaves_list(linkage)
353
937
 
354
- row_labels.extend([bin_label] * num_reads)
355
- bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
356
- last_idx += num_reads
357
- bin_boundaries.append(last_idx)
938
+ elif sort_by == "cpg" and num_cpg > 0:
939
+ linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
940
+ order = sch.leaves_list(linkage)
358
941
 
359
- if stacked_any_c:
942
+ elif sort_by == "any_c" and num_any_c > 0:
943
+ linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
944
+ order = sch.leaves_list(linkage)
945
+
946
+ elif sort_by == "gpc_cpg":
947
+ linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
948
+ order = sch.leaves_list(linkage)
949
+
950
+ elif sort_by == "any_a" and num_any_a > 0:
951
+ linkage = sch.linkage(subset_bin[:, any_a_sites].layers[layer_a], method="ward")
952
+ order = sch.leaves_list(linkage)
953
+
954
+ elif sort_by == "none":
955
+ order = np.arange(num_reads)
956
+
957
+ else:
958
+ order = np.arange(num_reads)
959
+
960
+ subset_bin = subset_bin[order]
961
+
962
+ # stack consistently
963
+ if include_any_c and num_any_c > 0:
964
+ stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_any_c])
965
+ if include_any_c and num_gpc > 0:
966
+ stacked_gpc.append(subset_bin[:, gpc_sites].layers[layer_gpc])
967
+ if include_any_c and num_cpg > 0:
968
+ stacked_cpg.append(subset_bin[:, cpg_sites].layers[layer_cpg])
969
+ if include_any_a and num_any_a > 0:
970
+ stacked_any_a.append(subset_bin[:, any_a_sites].layers[layer_a])
971
+
972
+ row_labels.extend([bin_label] * num_reads)
973
+ bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
974
+ last_idx += num_reads
975
+ bin_boundaries.append(last_idx)
976
+
977
+ # ----------------------------
978
+ # build matrices + means
979
+ # ----------------------------
980
+ blocks = [] # list of dicts describing what to plot in order
981
+
982
+ if include_any_c and stacked_any_c:
360
983
  any_c_matrix = np.vstack(stacked_any_c)
361
- gpc_matrix = np.vstack(stacked_gpc)
362
- cpg_matrix = np.vstack(stacked_cpg)
363
-
364
- if any_c_matrix.size > 0:
365
- def normalized_mean(matrix):
366
- mean = np.nanmean(matrix, axis=0)
367
- normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
368
- return normalized
369
-
370
- def methylation_fraction(matrix):
371
- methylated = (matrix == 1).sum(axis=0)
372
- valid = (matrix != 0).sum(axis=0)
373
- return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
374
-
375
- mean_gpc = methylation_fraction(gpc_matrix)
376
- mean_cpg = methylation_fraction(cpg_matrix)
377
- mean_any_c = methylation_fraction(any_c_matrix)
378
-
379
- fig = plt.figure(figsize=(18, 12))
380
- gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
381
- fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
382
-
383
- axes_heat = [fig.add_subplot(gs[1, i]) for i in range(3)]
384
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(3)]
385
-
386
- clean_barplot(axes_bar[0], mean_any_c, f"any C site Modification Signal")
387
- clean_barplot(axes_bar[1], mean_gpc, f"GpC Modification Signal")
388
- clean_barplot(axes_bar[2], mean_cpg, f"CpG Modification Signal")
389
-
984
+ gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
985
+ cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
986
+
987
+ mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
988
+ mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
989
+ mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
990
+
991
+ if any_c_matrix.size:
992
+ blocks.append(dict(
993
+ name="any_c",
994
+ matrix=any_c_matrix,
995
+ mean=mean_any_c,
996
+ labels=any_c_labels,
997
+ cmap=cmap_any_c,
998
+ n_xticks=n_xticks_any_c,
999
+ title="any C site Modification Signal"
1000
+ ))
1001
+ if gpc_matrix.size:
1002
+ blocks.append(dict(
1003
+ name="gpc",
1004
+ matrix=gpc_matrix,
1005
+ mean=mean_gpc,
1006
+ labels=gpc_labels,
1007
+ cmap=cmap_gpc,
1008
+ n_xticks=n_xticks_gpc,
1009
+ title="GpC Modification Signal"
1010
+ ))
1011
+ if cpg_matrix.size:
1012
+ blocks.append(dict(
1013
+ name="cpg",
1014
+ matrix=cpg_matrix,
1015
+ mean=mean_cpg,
1016
+ labels=cpg_labels,
1017
+ cmap=cmap_cpg,
1018
+ n_xticks=n_xticks_cpg,
1019
+ title="CpG Modification Signal"
1020
+ ))
1021
+
1022
+ if include_any_a and stacked_any_a:
1023
+ any_a_matrix = np.vstack(stacked_any_a)
1024
+ mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
1025
+ if any_a_matrix.size:
1026
+ blocks.append(dict(
1027
+ name="any_a",
1028
+ matrix=any_a_matrix,
1029
+ mean=mean_any_a,
1030
+ labels=any_a_labels,
1031
+ cmap=cmap_a,
1032
+ n_xticks=n_xticks_any_a,
1033
+ title="any A site Modification Signal"
1034
+ ))
1035
+
1036
+ if not blocks:
1037
+ print(f"No matrices to plot for {display_sample} - {ref}")
1038
+ continue
390
1039
 
391
- sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[0], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
392
- axes_heat[0].set_xticks(range(0, len(any_c_labels), 20))
393
- axes_heat[0].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
394
- for boundary in bin_boundaries[:-1]:
395
- axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
396
-
397
- sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
398
- axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
399
- axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
400
- for boundary in bin_boundaries[:-1]:
401
- axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
402
-
403
- sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
404
- axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
405
- for boundary in bin_boundaries[:-1]:
406
- axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
407
-
408
- plt.tight_layout()
409
-
410
- if save_path:
411
- save_name = f"{ref} — {sample}"
412
- os.makedirs(save_path, exist_ok=True)
413
- safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
414
- out_file = os.path.join(save_path, f"{safe_name}.png")
415
- plt.savefig(out_file, dpi=300)
416
- print(f"Saved: {out_file}")
417
- plt.close()
418
- else:
419
- plt.show()
420
-
421
- print(f"Summary for {sample} - {ref}:")
422
- for bin_label, percent in percentages.items():
423
- print(f" - {bin_label}: {percent:.1f}%")
424
-
425
- results.append({
426
- "sample": sample,
427
- "ref": ref,
428
- "any_c_matrix": any_c_matrix,
429
- "gpc_matrix": gpc_matrix,
430
- "cpg_matrix": cpg_matrix,
431
- "row_labels": row_labels,
432
- "bin_labels": bin_labels,
433
- "bin_boundaries": bin_boundaries,
434
- "percentages": percentages
435
- })
436
-
437
- adata.uns['clustermap_results'] = results
1040
+ gs_dim = len(blocks)
1041
+ fig = plt.figure(figsize=(5.5 * gs_dim, 11))
1042
+ gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.02)
1043
+ fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
1044
+
1045
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
1046
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
1047
+
1048
+ # ----------------------------
1049
+ # plot blocks
1050
+ # ----------------------------
1051
+ for i, blk in enumerate(blocks):
1052
+ mat = blk["matrix"]
1053
+ mean = blk["mean"]
1054
+ labels = np.asarray(blk["labels"], dtype=str)
1055
+ n_xticks = blk["n_xticks"]
1056
+
1057
+ # barplot
1058
+ clean_barplot(axes_bar[i], mean, blk["title"])
1059
+
1060
+ # heatmap
1061
+ sns.heatmap(
1062
+ mat,
1063
+ cmap=blk["cmap"],
1064
+ ax=axes_heat[i],
1065
+ yticklabels=False,
1066
+ cbar=False
1067
+ )
1068
+
1069
+ # fixed tick labels
1070
+ tick_pos = _fixed_tick_positions(len(labels), n_xticks)
1071
+ axes_heat[i].set_xticks(tick_pos)
1072
+ axes_heat[i].set_xticklabels(
1073
+ labels[tick_pos],
1074
+ rotation=xtick_rotation,
1075
+ fontsize=xtick_fontsize
1076
+ )
1077
+
1078
+ # bin separators
1079
+ for boundary in bin_boundaries[:-1]:
1080
+ axes_heat[i].axhline(y=boundary, color="black", linewidth=2)
1081
+
1082
+ axes_heat[i].set_xlabel("Position", fontsize=9)
1083
+
1084
+ plt.tight_layout()
1085
+
1086
+ # save or show
1087
+ if save_path is not None:
1088
+ safe_name = f"{ref}__{display_sample}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
1089
+ out_file = save_path / f"{safe_name}.png"
1090
+ fig.savefig(out_file, dpi=300)
1091
+ plt.close(fig)
1092
+ print(f"Saved: {out_file}")
1093
+ else:
1094
+ plt.show()
1095
+
1096
+ # record results
1097
+ rec = {
1098
+ "sample": str(sample),
1099
+ "ref": str(ref),
1100
+ "row_labels": row_labels,
1101
+ "bin_labels": bin_labels,
1102
+ "bin_boundaries": bin_boundaries,
1103
+ "percentages": percentages,
1104
+ }
1105
+ for blk in blocks:
1106
+ rec[f"{blk['name']}_matrix"] = blk["matrix"]
1107
+ rec[f"{blk['name']}_labels"] = list(map(str, blk["labels"]))
1108
+ results.append(rec)
1109
+
1110
+ print(f"Summary for {display_sample} - {ref}:")
1111
+ for bin_label, percent in percentages.items():
1112
+ print(f" - {bin_label}: {percent:.1f}%")
438
1113
 
439
1114
  except Exception as e:
440
1115
  import traceback
441
1116
  traceback.print_exc()
442
1117
  continue
443
-
444
1118
 
445
- import os
446
- import math
447
- from typing import List, Optional, Sequence, Tuple
1119
+ # store once at the end (HDF5 safe)
1120
+ # matrices won't be HDF5-safe; store only metadata + maybe hit counts
1121
+ # adata.uns["clustermap_results"] = [
1122
+ # {k: v for k, v in r.items() if not k.endswith("_matrix")}
1123
+ # for r in results
1124
+ # ]
448
1125
 
449
- import numpy as np
450
- import pandas as pd
451
- import matplotlib.pyplot as plt
1126
+ return results
452
1127
 
453
1128
  def plot_hmm_layers_rolling_by_sample_ref(
454
1129
  adata,