smftools 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. smftools/__init__.py +2 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/__init__.py +0 -0
  4. smftools/cli/archived/cli_flows.py +94 -0
  5. smftools/cli/helpers.py +48 -0
  6. smftools/cli/hmm_adata.py +361 -0
  7. smftools/cli/load_adata.py +637 -0
  8. smftools/cli/preprocess_adata.py +455 -0
  9. smftools/cli/spatial_adata.py +697 -0
  10. smftools/cli_entry.py +434 -0
  11. smftools/config/conversion.yaml +18 -6
  12. smftools/config/deaminase.yaml +18 -11
  13. smftools/config/default.yaml +151 -36
  14. smftools/config/direct.yaml +28 -1
  15. smftools/config/discover_input_files.py +115 -0
  16. smftools/config/experiment_config.py +225 -27
  17. smftools/hmm/HMM.py +12 -1
  18. smftools/hmm/__init__.py +0 -6
  19. smftools/hmm/archived/call_hmm_peaks.py +106 -0
  20. smftools/hmm/call_hmm_peaks.py +318 -90
  21. smftools/informatics/__init__.py +13 -7
  22. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  23. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  24. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  25. smftools/informatics/{helpers → archived/helpers/archived}/aligned_BAM_to_bed.py +6 -4
  26. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  27. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  28. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  29. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  30. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  31. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +1 -1
  32. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  33. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  34. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  35. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  36. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  38. smftools/informatics/{helpers → archived/helpers/archived}/plot_bed_histograms.py +0 -19
  39. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +6 -5
  40. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +7 -7
  41. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  42. smftools/informatics/bam_functions.py +811 -0
  43. smftools/informatics/basecalling.py +67 -0
  44. smftools/informatics/bed_functions.py +366 -0
  45. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +42 -30
  46. smftools/informatics/fasta_functions.py +255 -0
  47. smftools/informatics/h5ad_functions.py +197 -0
  48. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +142 -59
  49. smftools/informatics/modkit_functions.py +129 -0
  50. smftools/informatics/ohe.py +160 -0
  51. smftools/informatics/pod5_functions.py +224 -0
  52. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  53. smftools/plotting/autocorrelation_plotting.py +1 -3
  54. smftools/plotting/general_plotting.py +1084 -363
  55. smftools/plotting/position_stats.py +3 -3
  56. smftools/preprocessing/__init__.py +4 -4
  57. smftools/preprocessing/append_base_context.py +35 -26
  58. smftools/preprocessing/append_binary_layer_by_base_context.py +6 -6
  59. smftools/preprocessing/binarize.py +17 -0
  60. smftools/preprocessing/binarize_on_Youden.py +11 -9
  61. smftools/preprocessing/calculate_complexity_II.py +1 -1
  62. smftools/preprocessing/calculate_coverage.py +16 -13
  63. smftools/preprocessing/calculate_position_Youden.py +42 -26
  64. smftools/preprocessing/calculate_read_modification_stats.py +2 -2
  65. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +1 -1
  66. smftools/preprocessing/filter_reads_on_modification_thresholds.py +20 -20
  67. smftools/preprocessing/flag_duplicate_reads.py +2 -2
  68. smftools/preprocessing/invert_adata.py +1 -1
  69. smftools/preprocessing/load_sample_sheet.py +1 -1
  70. smftools/preprocessing/reindex_references_adata.py +37 -0
  71. smftools/readwrite.py +360 -140
  72. {smftools-0.2.1.dist-info → smftools-0.2.4.dist-info}/METADATA +26 -19
  73. smftools-0.2.4.dist-info/RECORD +176 -0
  74. smftools-0.2.4.dist-info/entry_points.txt +2 -0
  75. smftools/cli.py +0 -184
  76. smftools/informatics/fast5_to_pod5.py +0 -24
  77. smftools/informatics/helpers/__init__.py +0 -73
  78. smftools/informatics/helpers/align_and_sort_BAM.py +0 -86
  79. smftools/informatics/helpers/bam_qc.py +0 -66
  80. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  81. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -378
  82. smftools/informatics/helpers/discover_input_files.py +0 -100
  83. smftools/informatics/helpers/index_fasta.py +0 -12
  84. smftools/informatics/helpers/make_dirs.py +0 -21
  85. smftools/informatics/readwrite.py +0 -106
  86. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  87. smftools/load_adata.py +0 -1346
  88. smftools-0.2.1.dist-info/RECORD +0 -161
  89. smftools-0.2.1.dist-info/entry_points.txt +0 -2
  90. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  91. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  92. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  93. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  94. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  95. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  96. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +0 -0
  97. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  98. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  99. /smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +0 -0
  100. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  101. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  102. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  103. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  104. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  105. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  106. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  107. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  108. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  109. /smftools/informatics/{helpers/binarize_converted_base_identities.py → binarize_converted_base_identities.py} +0 -0
  110. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  111. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archives/add_read_length_and_mapping_qc.py} +0 -0
  112. /smftools/preprocessing/{calculate_complexity.py → archives/calculate_complexity.py} +0 -0
  113. {smftools-0.2.1.dist-info → smftools-0.2.4.dist-info}/WHEEL +0 -0
  114. {smftools-0.2.1.dist-info → smftools-0.2.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,93 @@
1
+ from __future__ import annotations
2
+
1
3
  import numpy as np
2
4
  import seaborn as sns
3
5
  import matplotlib.pyplot as plt
6
+ import scipy.cluster.hierarchy as sch
7
+ import matplotlib.gridspec as gridspec
8
+ import os
9
+ import math
10
+ import pandas as pd
11
+
12
+ from typing import Optional, Mapping, Sequence, Any, Dict, List, Tuple
13
+ from pathlib import Path
14
+
15
def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
    """
    Pick roughly ``n_ticks`` evenly spread tick indices over ``range(n_positions)``.

    At least two ticks are always requested. When there are no more positions
    than requested ticks, every position index is returned. Otherwise the
    evenly spaced positions are rounded to integers and de-duplicated, so the
    result may hold fewer than ``n_ticks`` entries.
    """
    tick_count = int(max(2, n_ticks))

    if n_positions > tick_count:
        # linspace pins both endpoints (0 and n_positions - 1).
        evenly_spaced = np.linspace(0, n_positions - 1, tick_count)
        rounded = np.round(evenly_spaced).astype(int)
        return np.unique(rounded)

    return np.arange(n_positions)
27
+
28
def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix: str | None):
    """
    Select tick labels for the heatmap axis.

    Parameters
    ----------
    subset : AnnData view
        Object exposing ``var_names`` and a ``var`` table.
    sites : np.ndarray[int]
        Indices of the ``subset.var`` positions to annotate (any array-like
        accepted by NumPy indexing works).
    reference : str
        Reference name, used as the prefix of the label column.
    index_col_suffix : None or str
        If None, labels come from ``subset.var_names``.
        Otherwise they come from ``subset.var[f"{reference}_{index_col_suffix}"]``.

    Returns
    -------
    np.ndarray
        Object-dtype array of ``str`` labels, one per entry of ``sites``.
        The same array type is returned on every path (including the empty
        one) so callers can rely on ndarray indexing and ``.tolist()``.

    Raises
    ------
    KeyError
        If ``index_col_suffix`` is given but the derived column is missing
        from ``subset.var``.
    """
    sites = np.asarray(sites)
    if sites.size == 0:
        # Keep the dtype consistent with the non-empty return value.
        return np.array([], dtype=object)

    if index_col_suffix is None:
        # Default behavior: label by var_names.
        source = subset.var_names
    else:
        colname = f"{reference}_{index_col_suffix}"
        if colname not in subset.var:
            raise KeyError(
                f"index_col_suffix='{index_col_suffix}' requires var column '{colname}', "
                f"but it is not present in adata.var."
            )
        source = subset.var[colname]

    # Materialize as a plain ndarray: a pandas Index / extension array would
    # otherwise leak out depending on which branch was taken.
    labels = np.asarray(source.astype(str), dtype=object)
    return labels[sites]
67
+
68
def normalized_mean(matrix: np.ndarray) -> np.ndarray:
    """
    Column-wise NaN-aware mean, min-max scaled to approximately [0, 1].

    Columns that contain only NaN stay NaN in the output, but they no longer
    turn the whole profile into NaN: the scaling range is taken with
    ``np.nanmin`` / ``np.nanmax`` over the finite column means.
    """
    mean = np.nanmean(matrix, axis=0)
    lowest = np.nanmin(mean)
    highest = np.nanmax(mean)
    # Small epsilon keeps a flat profile from dividing by zero.
    denom = (highest - lowest) + 1e-9
    return (mean - lowest) / denom
72
+
73
def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
    """
    Per-column fraction of methylated entries.

    An entry counts as methylated when it equals 1, and as valid when it is
    finite and non-zero. Columns with no valid entries yield 0.0.
    """
    values = np.asarray(matrix)
    finite = np.isfinite(values)

    n_methylated = (finite & (values == 1)).sum(axis=0)
    n_valid = (finite & (values != 0)).sum(axis=0)

    # Pre-filled zeros are left untouched wherever the denominator is 0.
    fraction = np.zeros_like(n_methylated, dtype=float)
    return np.divide(n_methylated, n_valid, out=fraction, where=n_valid != 0)
4
91
 
5
92
  def clean_barplot(ax, mean_values, title):
6
93
  x = np.arange(len(mean_values))
@@ -17,438 +104,1072 @@ def clean_barplot(ax, mean_values, title):
17
104
 
18
105
  ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
19
106
 
107
+ # def combined_hmm_raw_clustermap(
108
+ # adata,
109
+ # sample_col='Sample_Names',
110
+ # reference_col='Reference_strand',
111
+ # hmm_feature_layer="hmm_combined",
112
+ # layer_gpc="nan0_0minus1",
113
+ # layer_cpg="nan0_0minus1",
114
+ # layer_any_c="nan0_0minus1",
115
+ # cmap_hmm="tab10",
116
+ # cmap_gpc="coolwarm",
117
+ # cmap_cpg="viridis",
118
+ # cmap_any_c='coolwarm',
119
+ # min_quality=20,
120
+ # min_length=200,
121
+ # min_mapped_length_to_reference_length_ratio=0.8,
122
+ # min_position_valid_fraction=0.5,
123
+ # sample_mapping=None,
124
+ # save_path=None,
125
+ # normalize_hmm=False,
126
+ # sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
127
+ # bins=None,
128
+ # deaminase=False,
129
+ # min_signal=0
130
+ # ):
131
+
132
+ # results = []
133
+ # if deaminase:
134
+ # signal_type = 'deamination'
135
+ # else:
136
+ # signal_type = 'methylation'
137
+
138
+ # for ref in adata.obs[reference_col].cat.categories:
139
+ # for sample in adata.obs[sample_col].cat.categories:
140
+ # try:
141
+ # subset = adata[
142
+ # (adata.obs[reference_col] == ref) &
143
+ # (adata.obs[sample_col] == sample) &
144
+ # (adata.obs['read_quality'] >= min_quality) &
145
+ # (adata.obs['read_length'] >= min_length) &
146
+ # (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
147
+ # ]
148
+
149
+ # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
150
+ # subset = subset[:, mask]
151
+
152
+ # if subset.shape[0] == 0:
153
+ # print(f" No reads left after filtering for {sample} - {ref}")
154
+ # continue
155
+
156
+ # if bins:
157
+ # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
158
+ # bins_temp = bins
159
+ # else:
160
+ # print(f"Using all reads for clustermap for {sample} - {ref}")
161
+ # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
162
+
163
+ # # Get column positions (not var_names!) of site masks
164
+ # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
165
+ # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
166
+ # any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
167
+ # num_gpc = len(gpc_sites)
168
+ # num_cpg = len(cpg_sites)
169
+ # num_c = len(any_c_sites)
170
+ # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
171
+
172
+ # # Use var_names for x-axis tick labels
173
+ # gpc_labels = subset.var_names[gpc_sites].astype(int)
174
+ # cpg_labels = subset.var_names[cpg_sites].astype(int)
175
+ # any_c_labels = subset.var_names[any_c_sites].astype(int)
176
+
177
+ # stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
178
+ # row_labels, bin_labels = [], []
179
+ # bin_boundaries = []
180
+
181
+ # total_reads = subset.shape[0]
182
+ # percentages = {}
183
+ # last_idx = 0
184
+
185
+ # for bin_label, bin_filter in bins_temp.items():
186
+ # subset_bin = subset[bin_filter].copy()
187
+ # num_reads = subset_bin.shape[0]
188
+ # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
189
+ # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
190
+ # percentages[bin_label] = percent_reads
191
+
192
+ # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
193
+ # # Determine sorting order
194
+ # if sort_by.startswith("obs:"):
195
+ # colname = sort_by.split("obs:")[1]
196
+ # order = np.argsort(subset_bin.obs[colname].values)
197
+ # elif sort_by == "gpc":
198
+ # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
199
+ # order = sch.leaves_list(linkage)
200
+ # elif sort_by == "cpg":
201
+ # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
202
+ # order = sch.leaves_list(linkage)
203
+ # elif sort_by == "gpc_cpg":
204
+ # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
205
+ # order = sch.leaves_list(linkage)
206
+ # elif sort_by == "none":
207
+ # order = np.arange(num_reads)
208
+ # elif sort_by == "any_c":
209
+ # linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
210
+ # order = sch.leaves_list(linkage)
211
+ # else:
212
+ # raise ValueError(f"Unsupported sort_by option: {sort_by}")
213
+
214
+ # stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
215
+ # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
216
+ # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
217
+ # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
218
+
219
+ # row_labels.extend([bin_label] * num_reads)
220
+ # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
221
+ # last_idx += num_reads
222
+ # bin_boundaries.append(last_idx)
223
+
224
+ # if stacked_hmm_feature:
225
+ # hmm_matrix = np.vstack(stacked_hmm_feature)
226
+ # gpc_matrix = np.vstack(stacked_gpc)
227
+ # cpg_matrix = np.vstack(stacked_cpg)
228
+ # any_c_matrix = np.vstack(stacked_any_c)
229
+
230
+ # if hmm_matrix.size > 0:
231
+ # def normalized_mean(matrix):
232
+ # mean = np.nanmean(matrix, axis=0)
233
+ # normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
234
+ # return normalized
235
+
236
+ # def methylation_fraction(matrix):
237
+ # methylated = (matrix == 1).sum(axis=0)
238
+ # valid = (matrix != 0).sum(axis=0)
239
+ # return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
240
+
241
+ # if normalize_hmm:
242
+ # mean_hmm = normalized_mean(hmm_matrix)
243
+ # else:
244
+ # mean_hmm = np.nanmean(hmm_matrix, axis=0)
245
+ # mean_gpc = methylation_fraction(gpc_matrix)
246
+ # mean_cpg = methylation_fraction(cpg_matrix)
247
+ # mean_any_c = methylation_fraction(any_c_matrix)
248
+
249
+ # fig = plt.figure(figsize=(18, 12))
250
+ # gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
251
+ # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
252
+
253
+ # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
254
+ # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
255
+
256
+ # clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
257
+ # clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
258
+ # clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
259
+ # clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
260
+
261
+ # hmm_labels = subset.var_names.astype(int)
262
+ # hmm_label_spacing = 150
263
+ # sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
264
+ # axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
265
+ # axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
266
+ # for boundary in bin_boundaries[:-1]:
267
+ # axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
268
+
269
+ # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
270
+ # axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
271
+ # axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
272
+ # for boundary in bin_boundaries[:-1]:
273
+ # axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
274
+
275
+ # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
276
+ # axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
277
+ # for boundary in bin_boundaries[:-1]:
278
+ # axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
279
+
280
+ # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
281
+ # axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
282
+ # axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
283
+ # for boundary in bin_boundaries[:-1]:
284
+ # axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
285
+
286
+ # plt.tight_layout()
287
+
288
+ # if save_path:
289
+ # save_name = f"{ref} — {sample}"
290
+ # os.makedirs(save_path, exist_ok=True)
291
+ # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
292
+ # out_file = os.path.join(save_path, f"{safe_name}.png")
293
+ # plt.savefig(out_file, dpi=300)
294
+ # print(f"Saved: {out_file}")
295
+ # plt.close()
296
+ # else:
297
+ # plt.show()
298
+
299
+ # print(f"Summary for {sample} - {ref}:")
300
+ # for bin_label, percent in percentages.items():
301
+ # print(f" - {bin_label}: {percent:.1f}%")
302
+
303
+ # results.append({
304
+ # "sample": sample,
305
+ # "ref": ref,
306
+ # "hmm_matrix": hmm_matrix,
307
+ # "gpc_matrix": gpc_matrix,
308
+ # "cpg_matrix": cpg_matrix,
309
+ # "row_labels": row_labels,
310
+ # "bin_labels": bin_labels,
311
+ # "bin_boundaries": bin_boundaries,
312
+ # "percentages": percentages
313
+ # })
314
+
315
+ # #adata.uns['clustermap_results'] = results
316
+
317
+ # except Exception as e:
318
+ # import traceback
319
+ # traceback.print_exc()
320
+ # continue
321
+
20
322
  def combined_hmm_raw_clustermap(
21
323
  adata,
22
- sample_col='Sample_Names',
23
- reference_col='Reference_strand',
24
- hmm_feature_layer="hmm_combined",
25
- layer_gpc="nan0_0minus1",
26
- layer_cpg="nan0_0minus1",
27
- layer_any_c="nan0_0minus1",
28
- cmap_hmm="tab10",
29
- cmap_gpc="coolwarm",
30
- cmap_cpg="viridis",
31
- cmap_any_c='coolwarm',
32
- min_quality=20,
33
- min_length=200,
34
- min_mapped_length_to_reference_length_ratio=0.8,
35
- min_position_valid_fraction=0.5,
36
- sample_mapping=None,
37
- save_path=None,
38
- normalize_hmm=False,
39
- sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
40
- bins=None,
41
- deaminase=False,
42
- min_signal=0
43
- ):
44
- import scipy.cluster.hierarchy as sch
45
- import pandas as pd
46
- import numpy as np
47
- import seaborn as sns
48
- import matplotlib.pyplot as plt
49
- import matplotlib.gridspec as gridspec
50
- import os
324
+ sample_col: str = "Sample_Names",
325
+ reference_col: str = "Reference_strand",
51
326
 
327
+ hmm_feature_layer: str = "hmm_combined",
328
+
329
+ layer_gpc: str = "nan0_0minus1",
330
+ layer_cpg: str = "nan0_0minus1",
331
+ layer_c: str = "nan0_0minus1",
332
+ layer_a: str = "nan0_0minus1",
333
+
334
+ cmap_hmm: str = "tab10",
335
+ cmap_gpc: str = "coolwarm",
336
+ cmap_cpg: str = "viridis",
337
+ cmap_c: str = "coolwarm",
338
+ cmap_a: str = "coolwarm",
339
+
340
+ min_quality: int = 20,
341
+ min_length: int = 200,
342
+ min_mapped_length_to_reference_length_ratio: float = 0.8,
343
+ min_position_valid_fraction: float = 0.5,
344
+
345
+ save_path: str | Path | None = None,
346
+ normalize_hmm: bool = False,
347
+
348
+ sort_by: str = "gpc",
349
+ bins: Optional[Dict[str, Any]] = None,
350
+
351
+ deaminase: bool = False,
352
+ min_signal: float = 0.0,
353
+
354
+ # ---- fixed tick label controls (counts, not spacing)
355
+ n_xticks_hmm: int = 10,
356
+ n_xticks_any_c: int = 8,
357
+ n_xticks_gpc: int = 8,
358
+ n_xticks_cpg: int = 8,
359
+ n_xticks_a: int = 8,
360
+
361
+ index_col_suffix: str | None = None,
362
+ ):
363
+ """
364
+ Makes a multi-panel clustermap per (sample, reference):
365
+ HMM panel (always) + optional raw panels for C, GpC, CpG, and A sites.
366
+
367
+ Panels are added only if the corresponding site mask exists AND has >0 sites.
368
+
369
+ sort_by options:
370
+ 'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
371
+ """
372
+ def pick_xticks(labels: np.ndarray, n_ticks: int):
373
+ if labels.size == 0:
374
+ return [], []
375
+ idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
376
+ idx = np.unique(idx)
377
+ return idx.tolist(), labels[idx].tolist()
378
+
52
379
  results = []
53
- if deaminase:
54
- signal_type = 'deamination'
55
- else:
56
- signal_type = 'methylation'
380
+ signal_type = "deamination" if deaminase else "methylation"
57
381
 
58
382
  for ref in adata.obs[reference_col].cat.categories:
59
383
  for sample in adata.obs[sample_col].cat.categories:
384
+
60
385
  try:
386
+ # ---- subset reads ----
61
387
  subset = adata[
62
388
  (adata.obs[reference_col] == ref) &
63
389
  (adata.obs[sample_col] == sample) &
64
- (adata.obs['read_quality'] >= min_quality) &
65
- (adata.obs['read_length'] >= min_length) &
66
- (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
390
+ (adata.obs["read_quality"] >= min_quality) &
391
+ (adata.obs["read_length"] >= min_length) &
392
+ (
393
+ adata.obs["mapped_length_to_reference_length_ratio"]
394
+ > min_mapped_length_to_reference_length_ratio
395
+ )
67
396
  ]
68
-
69
- mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
70
- subset = subset[:, mask]
397
+
398
+ # ---- valid fraction filter ----
399
+ vf_key = f"{ref}_valid_fraction"
400
+ if vf_key in subset.var:
401
+ mask = subset.var[vf_key].astype(float) > float(min_position_valid_fraction)
402
+ subset = subset[:, mask]
71
403
 
72
404
  if subset.shape[0] == 0:
73
- print(f" No reads left after filtering for {sample} - {ref}")
74
405
  continue
75
406
 
76
- if bins:
77
- print(f"Using defined bins to subset clustermap for {sample} - {ref}")
78
- bins_temp = bins
407
+ # ---- bins ----
408
+ if bins is None:
409
+ bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
79
410
  else:
80
- print(f"Using all reads for clustermap for {sample} - {ref}")
81
- bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
82
-
83
- # Get column positions (not var_names!) of site masks
84
- gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
85
- cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
86
- any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
87
- num_gpc = len(gpc_sites)
88
- num_cpg = len(cpg_sites)
89
- num_c = len(any_c_sites)
90
- print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
91
-
92
- # Use var_names for x-axis tick labels
93
- gpc_labels = subset.var_names[gpc_sites].astype(int)
94
- cpg_labels = subset.var_names[cpg_sites].astype(int)
95
- any_c_labels = subset.var_names[any_c_sites].astype(int)
96
-
97
- stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
98
- row_labels, bin_labels = [], []
99
- bin_boundaries = []
411
+ bins_temp = bins
100
412
 
101
- total_reads = subset.shape[0]
413
+ # ---- site masks (robust) ----
414
+ def _sites(*keys):
415
+ for k in keys:
416
+ if k in subset.var:
417
+ return np.where(subset.var[k].values)[0]
418
+ return np.array([], dtype=int)
419
+
420
+ gpc_sites = _sites(f"{ref}_GpC_site")
421
+ cpg_sites = _sites(f"{ref}_CpG_site")
422
+ any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
423
+ any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
424
+
425
+ # ---- labels via _select_labels ----
426
+ # HMM uses *all* columns
427
+ hmm_sites = np.arange(subset.n_vars, dtype=int)
428
+ hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
429
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
430
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
431
+ any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
432
+ any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
433
+
434
+ # storage
435
+ stacked_hmm = []
436
+ stacked_any_c = []
437
+ stacked_gpc = []
438
+ stacked_cpg = []
439
+ stacked_any_a = []
440
+
441
+ row_labels, bin_labels, bin_boundaries = [], [], []
442
+ total_reads = subset.n_obs
102
443
  percentages = {}
103
444
  last_idx = 0
104
445
 
446
+ # ---------------- process bins ----------------
105
447
  for bin_label, bin_filter in bins_temp.items():
106
- subset_bin = subset[bin_filter].copy()
107
- num_reads = subset_bin.shape[0]
108
- print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
109
- percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
110
- percentages[bin_label] = percent_reads
448
+ sb = subset[bin_filter].copy()
449
+ n = sb.n_obs
450
+ if n == 0:
451
+ continue
111
452
 
112
- if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
113
- # Determine sorting order
114
- if sort_by.startswith("obs:"):
115
- colname = sort_by.split("obs:")[1]
116
- order = np.argsort(subset_bin.obs[colname].values)
117
- elif sort_by == "gpc":
118
- linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
119
- order = sch.leaves_list(linkage)
120
- elif sort_by == "cpg":
121
- linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
122
- order = sch.leaves_list(linkage)
123
- elif sort_by == "gpc_cpg":
124
- linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
125
- order = sch.leaves_list(linkage)
126
- elif sort_by == "none":
127
- order = np.arange(num_reads)
128
- elif sort_by == "any_c":
129
- linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
130
- order = sch.leaves_list(linkage)
131
- else:
132
- raise ValueError(f"Unsupported sort_by option: {sort_by}")
133
-
134
- stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
135
- stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
136
- stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
137
- stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
138
-
139
- row_labels.extend([bin_label] * num_reads)
140
- bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
141
- last_idx += num_reads
142
- bin_boundaries.append(last_idx)
143
-
144
- if stacked_hmm_feature:
145
- hmm_matrix = np.vstack(stacked_hmm_feature)
146
- gpc_matrix = np.vstack(stacked_gpc)
147
- cpg_matrix = np.vstack(stacked_cpg)
148
- any_c_matrix = np.vstack(stacked_any_c)
453
+ pct = (n / total_reads) * 100 if total_reads else 0
454
+ percentages[bin_label] = pct
149
455
 
150
- if hmm_matrix.size > 0:
151
- def normalized_mean(matrix):
152
- mean = np.nanmean(matrix, axis=0)
153
- normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
154
- return normalized
456
+ # ---- sorting ----
457
+ if sort_by.startswith("obs:"):
458
+ colname = sort_by.split("obs:")[1]
459
+ order = np.argsort(sb.obs[colname].values)
155
460
 
156
- def methylation_fraction(matrix):
157
- methylated = (matrix == 1).sum(axis=0)
158
- valid = (matrix != 0).sum(axis=0)
159
- return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
461
+ elif sort_by == "gpc" and gpc_sites.size:
462
+ linkage = sch.linkage(sb[:, gpc_sites].layers[layer_gpc], method="ward")
463
+ order = sch.leaves_list(linkage)
160
464
 
161
- if normalize_hmm:
162
- mean_hmm = normalized_mean(hmm_matrix)
163
- else:
164
- mean_hmm = np.nanmean(hmm_matrix, axis=0)
165
- mean_gpc = methylation_fraction(gpc_matrix)
166
- mean_cpg = methylation_fraction(cpg_matrix)
167
- mean_any_c = methylation_fraction(any_c_matrix)
168
-
169
- fig = plt.figure(figsize=(18, 12))
170
- gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
171
- fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)
172
-
173
- axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
174
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
175
-
176
- clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
177
- clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
178
- clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
179
- clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
180
-
181
- hmm_labels = subset.var_names.astype(int)
182
- hmm_label_spacing = 150
183
- sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
184
- axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
185
- axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
186
- for boundary in bin_boundaries[:-1]:
187
- axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
188
-
189
- sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
190
- axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
191
- axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
192
- for boundary in bin_boundaries[:-1]:
193
- axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
194
-
195
- sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
196
- axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
197
- for boundary in bin_boundaries[:-1]:
198
- axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
199
-
200
- sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
201
- axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
202
- axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
203
- for boundary in bin_boundaries[:-1]:
204
- axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
205
-
206
- plt.tight_layout()
207
-
208
- if save_path:
209
- save_name = f"{ref} — {sample}"
210
- os.makedirs(save_path, exist_ok=True)
211
- safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
212
- out_file = os.path.join(save_path, f"{safe_name}.png")
213
- plt.savefig(out_file, dpi=300)
214
- print(f"Saved: {out_file}")
215
- plt.close()
216
- else:
217
- plt.show()
218
-
219
- print(f"Summary for {sample} - {ref}:")
220
- for bin_label, percent in percentages.items():
221
- print(f" - {bin_label}: {percent:.1f}%")
222
-
223
- results.append({
224
- "sample": sample,
225
- "ref": ref,
226
- "hmm_matrix": hmm_matrix,
227
- "gpc_matrix": gpc_matrix,
228
- "cpg_matrix": cpg_matrix,
229
- "row_labels": row_labels,
230
- "bin_labels": bin_labels,
231
- "bin_boundaries": bin_boundaries,
232
- "percentages": percentages
233
- })
234
-
235
- adata.uns['clustermap_results'] = results
465
+ elif sort_by == "cpg" and cpg_sites.size:
466
+ linkage = sch.linkage(sb[:, cpg_sites].layers[layer_cpg], method="ward")
467
+ order = sch.leaves_list(linkage)
236
468
 
237
- except Exception as e:
469
+ elif sort_by == "c" and any_c_sites.size:
470
+ linkage = sch.linkage(sb[:, any_c_sites].layers[layer_c], method="ward")
471
+ order = sch.leaves_list(linkage)
472
+
473
+ elif sort_by == "a" and any_a_sites.size:
474
+ linkage = sch.linkage(sb[:, any_a_sites].layers[layer_a], method="ward")
475
+ order = sch.leaves_list(linkage)
476
+
477
+ elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
478
+ linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
479
+ order = sch.leaves_list(linkage)
480
+
481
+ elif sort_by == "hmm" and hmm_sites.size:
482
+ linkage = sch.linkage(sb[:, hmm_sites].layers[hmm_feature_layer], method="ward")
483
+ order = sch.leaves_list(linkage)
484
+
485
+ else:
486
+ order = np.arange(n)
487
+
488
+ sb = sb[order]
489
+
490
+ # ---- collect matrices ----
491
+ stacked_hmm.append(sb.layers[hmm_feature_layer])
492
+ if any_c_sites.size:
493
+ stacked_any_c.append(sb[:, any_c_sites].layers[layer_c])
494
+ if gpc_sites.size:
495
+ stacked_gpc.append(sb[:, gpc_sites].layers[layer_gpc])
496
+ if cpg_sites.size:
497
+ stacked_cpg.append(sb[:, cpg_sites].layers[layer_cpg])
498
+ if any_a_sites.size:
499
+ stacked_any_a.append(sb[:, any_a_sites].layers[layer_a])
500
+
501
+ row_labels.extend([bin_label] * n)
502
+ bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
503
+ last_idx += n
504
+ bin_boundaries.append(last_idx)
505
+
506
+ # ---------------- stack ----------------
507
+ hmm_matrix = np.vstack(stacked_hmm)
508
+ mean_hmm = normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
509
+
510
+ panels = [
511
+ (f"HMM - {hmm_feature_layer}", hmm_matrix, hmm_labels, cmap_hmm, mean_hmm, n_xticks_hmm),
512
+ ]
513
+
514
+ if stacked_any_c:
515
+ m = np.vstack(stacked_any_c)
516
+ panels.append(("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c))
517
+
518
+ if stacked_gpc:
519
+ m = np.vstack(stacked_gpc)
520
+ panels.append(("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc))
521
+
522
+ if stacked_cpg:
523
+ m = np.vstack(stacked_cpg)
524
+ panels.append(("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg))
525
+
526
+ if stacked_any_a:
527
+ m = np.vstack(stacked_any_a)
528
+ panels.append(("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a))
529
+
530
+ # ---------------- plotting ----------------
531
+ n_panels = len(panels)
532
+ fig = plt.figure(figsize=(4.5 * n_panels, 10))
533
+ gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
534
+ fig.suptitle(f"{sample} — {ref} — {total_reads} reads ({signal_type})",
535
+ fontsize=14, y=0.98)
536
+
537
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
538
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
539
+
540
+ for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
541
+
542
+ # ---- per-position mean barplot ----
543
+ clean_barplot(axes_bar[i], mean_vec, name)
544
+
545
+ # ---- heatmap ----
546
+ sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
547
+ yticklabels=False, cbar=False)
548
+
549
+ # ---- xticks ----
550
+ xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
551
+ axes_heat[i].set_xticks(xtick_pos)
552
+ axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)
553
+
554
+ for boundary in bin_boundaries[:-1]:
555
+ axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)
556
+
557
+ plt.tight_layout()
558
+
559
+ if save_path:
560
+ save_path = Path(save_path)
561
+ save_path.mkdir(parents=True, exist_ok=True)
562
+ safe_name = f"{ref}__{sample}".replace("/", "_")
563
+ out_file = save_path / f"{safe_name}.png"
564
+ plt.savefig(out_file, dpi=300)
565
+ plt.close(fig)
566
+ else:
567
+ plt.show()
568
+
569
+ except Exception:
238
570
  import traceback
239
571
  traceback.print_exc()
240
572
  continue
241
573
 
242
574
 
575
+ # def combined_raw_clustermap(
576
+ # adata,
577
+ # sample_col='Sample_Names',
578
+ # reference_col='Reference_strand',
579
+ # mod_target_bases=['GpC', 'CpG'],
580
+ # layer_any_c="nan0_0minus1",
581
+ # layer_gpc="nan0_0minus1",
582
+ # layer_cpg="nan0_0minus1",
583
+ # layer_a="nan0_0minus1",
584
+ # cmap_any_c="coolwarm",
585
+ # cmap_gpc="coolwarm",
586
+ # cmap_cpg="viridis",
587
+ # cmap_a="coolwarm",
588
+ # min_quality=20,
589
+ # min_length=200,
590
+ # min_mapped_length_to_reference_length_ratio=0.8,
591
+ # min_position_valid_fraction=0.5,
592
+ # sample_mapping=None,
593
+ # save_path=None,
594
+ # sort_by="gpc", # options: 'gpc', 'cpg', 'any_c', 'gpc_cpg', 'none', 'any_a', or 'obs:<column>'
595
+ # bins=None,
596
+ # deaminase=False,
597
+ # min_signal=0
598
+ # ):
599
+
600
+ # results = []
601
+
602
+ # for ref in adata.obs[reference_col].cat.categories:
603
+ # for sample in adata.obs[sample_col].cat.categories:
604
+ # try:
605
+ # subset = adata[
606
+ # (adata.obs[reference_col] == ref) &
607
+ # (adata.obs[sample_col] == sample) &
608
+ # (adata.obs['read_quality'] >= min_quality) &
609
+ # (adata.obs['mapped_length'] >= min_length) &
610
+ # (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
611
+ # ]
612
+
613
+ # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
614
+ # subset = subset[:, mask]
615
+
616
+ # if subset.shape[0] == 0:
617
+ # print(f" No reads left after filtering for {sample} - {ref}")
618
+ # continue
619
+
620
+ # if bins:
621
+ # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
622
+ # bins_temp = bins
623
+ # else:
624
+ # print(f"Using all reads for clustermap for {sample} - {ref}")
625
+ # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
626
+
627
+ # num_any_c = 0
628
+ # num_gpc = 0
629
+ # num_cpg = 0
630
+ # num_any_a = 0
631
+
632
+ # # Get column positions (not var_names!) of site masks
633
+ # if any(base in ["C", "CpG", "GpC"] for base in mod_target_bases):
634
+ # any_c_sites = np.where(subset.var[f"{ref}_C_site"].values)[0]
635
+ # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
636
+ # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
637
+ # num_any_c = len(any_c_sites)
638
+ # num_gpc = len(gpc_sites)
639
+ # num_cpg = len(cpg_sites)
640
+ # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
641
+
642
+ # # Use var_names for x-axis tick labels
643
+ # gpc_labels = subset.var_names[gpc_sites].astype(int)
644
+ # cpg_labels = subset.var_names[cpg_sites].astype(int)
645
+ # any_c_labels = subset.var_names[any_c_sites].astype(int)
646
+ # stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
647
+
648
+ # if "A" in mod_target_bases:
649
+ # any_a_sites = np.where(subset.var[f"{ref}_A_site"].values)[0]
650
+ # num_any_a = len(any_a_sites)
651
+ # print(f"Found {num_any_a} any_A sites at {any_a_sites} for {sample} - {ref}")
652
+ # any_a_labels = subset.var_names[any_a_sites].astype(int)
653
+ # stacked_any_a = []
654
+
655
+ # row_labels, bin_labels = [], []
656
+ # bin_boundaries = []
657
+
658
+ # total_reads = subset.shape[0]
659
+ # percentages = {}
660
+ # last_idx = 0
661
+
662
+ # for bin_label, bin_filter in bins_temp.items():
663
+ # subset_bin = subset[bin_filter].copy()
664
+ # num_reads = subset_bin.shape[0]
665
+ # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
666
+ # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
667
+ # percentages[bin_label] = percent_reads
668
+
669
+ # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
670
+ # # Determine sorting order
671
+ # if sort_by.startswith("obs:"):
672
+ # colname = sort_by.split("obs:")[1]
673
+ # order = np.argsort(subset_bin.obs[colname].values)
674
+ # elif sort_by == "gpc":
675
+ # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
676
+ # order = sch.leaves_list(linkage)
677
+ # elif sort_by == "cpg":
678
+ # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
679
+ # order = sch.leaves_list(linkage)
680
+ # elif sort_by == "any_c":
681
+ # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
682
+ # order = sch.leaves_list(linkage)
683
+ # elif sort_by == "gpc_cpg":
684
+ # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
685
+ # order = sch.leaves_list(linkage)
686
+ # elif sort_by == "none":
687
+ # order = np.arange(num_reads)
688
+ # elif sort_by == "any_a":
689
+ # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
690
+ # order = sch.leaves_list(linkage)
691
+ # else:
692
+ # raise ValueError(f"Unsupported sort_by option: {sort_by}")
693
+
694
+ # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
695
+ # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
696
+ # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
697
+
698
+ # if num_reads > 0 and num_any_a > 0:
699
+ # # Determine sorting order
700
+ # if sort_by.startswith("obs:"):
701
+ # colname = sort_by.split("obs:")[1]
702
+ # order = np.argsort(subset_bin.obs[colname].values)
703
+ # elif sort_by == "gpc":
704
+ # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
705
+ # order = sch.leaves_list(linkage)
706
+ # elif sort_by == "cpg":
707
+ # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
708
+ # order = sch.leaves_list(linkage)
709
+ # elif sort_by == "any_c":
710
+ # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
711
+ # order = sch.leaves_list(linkage)
712
+ # elif sort_by == "gpc_cpg":
713
+ # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
714
+ # order = sch.leaves_list(linkage)
715
+ # elif sort_by == "none":
716
+ # order = np.arange(num_reads)
717
+ # elif sort_by == "any_a":
718
+ # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
719
+ # order = sch.leaves_list(linkage)
720
+ # else:
721
+ # raise ValueError(f"Unsupported sort_by option: {sort_by}")
722
+
723
+ # stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
724
+
725
+
726
+ # row_labels.extend([bin_label] * num_reads)
727
+ # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
728
+ # last_idx += num_reads
729
+ # bin_boundaries.append(last_idx)
730
+
731
+ # gs_dim = 0
732
+
733
+ # if stacked_any_c:
734
+ # any_c_matrix = np.vstack(stacked_any_c)
735
+ # gpc_matrix = np.vstack(stacked_gpc)
736
+ # cpg_matrix = np.vstack(stacked_cpg)
737
+ # if any_c_matrix.size > 0:
738
+ # mean_gpc = methylation_fraction(gpc_matrix)
739
+ # mean_cpg = methylation_fraction(cpg_matrix)
740
+ # mean_any_c = methylation_fraction(any_c_matrix)
741
+ # gs_dim += 3
742
+
743
+ # if stacked_any_a:
744
+ # any_a_matrix = np.vstack(stacked_any_a)
745
+ # if any_a_matrix.size > 0:
746
+ # mean_any_a = methylation_fraction(any_a_matrix)
747
+ # gs_dim += 1
748
+
749
+
750
+ # fig = plt.figure(figsize=(18, 12))
751
+ # gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
752
+ # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
753
+ # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
754
+ # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
755
+
756
+ # current_ax = 0
757
+
758
+ # if stacked_any_c:
759
+ # if any_c_matrix.size > 0:
760
+ # clean_barplot(axes_bar[current_ax], mean_any_c, f"any C site Modification Signal")
761
+ # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[current_ax], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
762
+ # axes_heat[current_ax].set_xticks(range(0, len(any_c_labels), 20))
763
+ # axes_heat[current_ax].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
764
+ # for boundary in bin_boundaries[:-1]:
765
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
766
+ # current_ax +=1
767
+
768
+ # clean_barplot(axes_bar[current_ax], mean_gpc, f"GpC Modification Signal")
769
+ # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[current_ax], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
770
+ # axes_heat[current_ax].set_xticks(range(0, len(gpc_labels), 5))
771
+ # axes_heat[current_ax].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
772
+ # for boundary in bin_boundaries[:-1]:
773
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
774
+ # current_ax +=1
775
+
776
+ # clean_barplot(axes_bar[current_ax], mean_cpg, f"CpG Modification Signal")
777
+ # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
778
+ # axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
779
+ # for boundary in bin_boundaries[:-1]:
780
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
781
+ # current_ax +=1
782
+
783
+ # results.append({
784
+ # "sample": sample,
785
+ # "ref": ref,
786
+ # "any_c_matrix": any_c_matrix,
787
+ # "gpc_matrix": gpc_matrix,
788
+ # "cpg_matrix": cpg_matrix,
789
+ # "row_labels": row_labels,
790
+ # "bin_labels": bin_labels,
791
+ # "bin_boundaries": bin_boundaries,
792
+ # "percentages": percentages
793
+ # })
794
+
795
+ # if stacked_any_a:
796
+ # if any_a_matrix.size > 0:
797
+ # clean_barplot(axes_bar[current_ax], mean_any_a, f"any A site Modification Signal")
798
+ # sns.heatmap(any_a_matrix, cmap=cmap_a, ax=axes_heat[current_ax], xticklabels=any_a_labels[::20], yticklabels=False, cbar=False)
799
+ # axes_heat[current_ax].set_xticks(range(0, len(any_a_labels), 20))
800
+ # axes_heat[current_ax].set_xticklabels(any_a_labels[::20], rotation=90, fontsize=10)
801
+ # for boundary in bin_boundaries[:-1]:
802
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
803
+ # current_ax +=1
804
+
805
+ # results.append({
806
+ # "sample": sample,
807
+ # "ref": ref,
808
+ # "any_a_matrix": any_a_matrix,
809
+ # "row_labels": row_labels,
810
+ # "bin_labels": bin_labels,
811
+ # "bin_boundaries": bin_boundaries,
812
+ # "percentages": percentages
813
+ # })
814
+
815
+ # plt.tight_layout()
816
+
817
+ # if save_path:
818
+ # save_name = f"{ref} — {sample}"
819
+ # os.makedirs(save_path, exist_ok=True)
820
+ # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
821
+ # out_file = os.path.join(save_path, f"{safe_name}.png")
822
+ # plt.savefig(out_file, dpi=300)
823
+ # print(f"Saved: {out_file}")
824
+ # plt.close()
825
+ # else:
826
+ # plt.show()
827
+
828
+ # print(f"Summary for {sample} - {ref}:")
829
+ # for bin_label, percent in percentages.items():
830
+ # print(f" - {bin_label}: {percent:.1f}%")
831
+
832
+ # adata.uns['clustermap_results'] = results
833
+
834
+ # except Exception as e:
835
+ # import traceback
836
+ # traceback.print_exc()
837
+ # continue
838
+
243
839
  def combined_raw_clustermap(
244
840
  adata,
245
- sample_col='Sample_Names',
246
- reference_col='Reference_strand',
247
- layer_any_c="nan0_0minus1",
248
- layer_gpc="nan0_0minus1",
249
- layer_cpg="nan0_0minus1",
250
- cmap_any_c="coolwarm",
251
- cmap_gpc="coolwarm",
252
- cmap_cpg="viridis",
253
- min_quality=20,
254
- min_length=200,
255
- min_mapped_length_to_reference_length_ratio=0.8,
256
- min_position_valid_fraction=0.5,
257
- sample_mapping=None,
258
- save_path=None,
259
- sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
260
- bins=None,
261
- deaminase=False,
262
- min_signal=0
263
- ):
264
- import scipy.cluster.hierarchy as sch
265
- import pandas as pd
266
- import numpy as np
267
- import seaborn as sns
268
- import matplotlib.pyplot as plt
269
- import matplotlib.gridspec as gridspec
270
- import os
841
+ sample_col: str = "Sample_Names",
842
+ reference_col: str = "Reference_strand",
843
+ mod_target_bases: Sequence[str] = ("GpC", "CpG"),
844
+ layer_c: str = "nan0_0minus1",
845
+ layer_gpc: str = "nan0_0minus1",
846
+ layer_cpg: str = "nan0_0minus1",
847
+ layer_a: str = "nan0_0minus1",
848
+ cmap_c: str = "coolwarm",
849
+ cmap_gpc: str = "coolwarm",
850
+ cmap_cpg: str = "viridis",
851
+ cmap_a: str = "coolwarm",
852
+ min_quality: float = 20,
853
+ min_length: int = 200,
854
+ min_mapped_length_to_reference_length_ratio: float = 0.8,
855
+ min_position_valid_fraction: float = 0.5,
856
+ sample_mapping: Optional[Mapping[str, str]] = None,
857
+ save_path: str | Path | None = None,
858
+ sort_by: str = "gpc", # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
859
+ bins: Optional[Dict[str, Any]] = None,
860
+ deaminase: bool = False,
861
+ min_signal: float = 0,
862
+ n_xticks_any_c: int = 10,
863
+ n_xticks_gpc: int = 10,
864
+ n_xticks_cpg: int = 10,
865
+ n_xticks_any_a: int = 10,
866
+ xtick_rotation: int = 90,
867
+ xtick_fontsize: int = 9,
868
+ index_col_suffix: str | None = None,
869
+ ):
870
+ """
871
+ Plot stacked heatmaps + per-position mean barplots for C, GpC, CpG, and optional A.
271
872
 
272
- results = []
873
+ Key fixes vs old version:
874
+ - order computed ONCE per bin, applied to all matrices
875
+ - no hard-coded axes indices
876
+ - NaNs excluded from methylation denominators
877
+ - var_names not forced to int
878
+ - fixed count of x tick labels per block (controllable)
879
+ - adata.uns updated once at end
880
+
881
+ Returns
882
+ -------
883
+ results : list[dict]
884
+ One entry per (sample, ref) plot with matrices + bin metadata.
885
+ """
886
+
887
+ results: List[Dict[str, Any]] = []
888
+ save_path = Path(save_path) if save_path is not None else None
889
+ if save_path is not None:
890
+ save_path.mkdir(parents=True, exist_ok=True)
891
+
892
+ # Ensure categorical
893
+ for col in (sample_col, reference_col):
894
+ if col not in adata.obs:
895
+ raise KeyError(f"{col} not in adata.obs")
896
+ if not pd.api.types.is_categorical_dtype(adata.obs[col]):
897
+ adata.obs[col] = adata.obs[col].astype("category")
898
+
899
+ base_set = set(mod_target_bases)
900
+ include_any_c = any(b in {"C", "CpG", "GpC"} for b in base_set)
901
+ include_any_a = "A" in base_set
273
902
 
274
903
  for ref in adata.obs[reference_col].cat.categories:
275
904
  for sample in adata.obs[sample_col].cat.categories:
905
+
906
+ # Optionally remap sample label for display
907
+ display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
908
+
276
909
  try:
277
910
  subset = adata[
278
911
  (adata.obs[reference_col] == ref) &
279
912
  (adata.obs[sample_col] == sample) &
280
- (adata.obs['read_quality'] >= min_quality) &
281
- (adata.obs['mapped_length'] >= min_length) &
282
- (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
913
+ (adata.obs["read_quality"] >= min_quality) &
914
+ (adata.obs["mapped_length"] >= min_length) &
915
+ (adata.obs["mapped_length_to_reference_length_ratio"] >= min_mapped_length_to_reference_length_ratio)
283
916
  ]
284
917
 
285
- mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
286
- subset = subset[:, mask]
918
+ # position-level mask
919
+ valid_key = f"{ref}_valid_fraction"
920
+ if valid_key in subset.var:
921
+ mask = subset.var[valid_key].astype(float).values > float(min_position_valid_fraction)
922
+ subset = subset[:, mask]
287
923
 
288
924
  if subset.shape[0] == 0:
289
- print(f" No reads left after filtering for {sample} - {ref}")
925
+ print(f"No reads left after filtering for {display_sample} - {ref}")
290
926
  continue
291
927
 
292
- if bins:
293
- print(f"Using defined bins to subset clustermap for {sample} - {ref}")
294
- bins_temp = bins
928
+ # bins mode
929
+ if bins is None:
930
+ bins_temp = {"All": (subset.obs[reference_col] == ref)}
295
931
  else:
296
- print(f"Using all reads for clustermap for {sample} - {ref}")
297
- bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
298
-
299
- # Get column positions (not var_names!) of site masks
300
- any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
301
- gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
302
- cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
303
- num_any_c = len(any_c_sites)
304
- num_gpc = len(gpc_sites)
305
- num_cpg = len(cpg_sites)
306
- print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
307
-
308
- # Use var_names for x-axis tick labels
309
- gpc_labels = subset.var_names[gpc_sites].astype(int)
310
- cpg_labels = subset.var_names[cpg_sites].astype(int)
311
- any_c_labels = subset.var_names[any_c_sites].astype(int)
312
-
313
- stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
314
- row_labels, bin_labels = [], []
315
- bin_boundaries = []
932
+ bins_temp = bins
316
933
 
317
- total_reads = subset.shape[0]
934
+ # find sites (positions)
935
+ any_c_sites = gpc_sites = cpg_sites = np.array([], dtype=int)
936
+ any_a_sites = np.array([], dtype=int)
937
+
938
+ num_any_c = num_gpc = num_cpg = num_any_a = 0
939
+
940
+ if include_any_c:
941
+ any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
942
+ gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
943
+ cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
944
+
945
+ num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
946
+
947
+ any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
948
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
949
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
950
+
951
+ if include_any_a:
952
+ any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
953
+ num_any_a = len(any_a_sites)
954
+ any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
955
+
956
+ stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
957
+ row_labels, bin_labels, bin_boundaries = [], [], []
318
958
  percentages = {}
319
959
  last_idx = 0
960
+ total_reads = subset.shape[0]
320
961
 
962
+ # ----------------------------
963
+ # per-bin stacking
964
+ # ----------------------------
321
965
  for bin_label, bin_filter in bins_temp.items():
322
966
  subset_bin = subset[bin_filter].copy()
323
967
  num_reads = subset_bin.shape[0]
324
- print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
325
- percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
968
+ if num_reads == 0:
969
+ percentages[bin_label] = 0.0
970
+ continue
971
+
972
+ percent_reads = (num_reads / total_reads) * 100
326
973
  percentages[bin_label] = percent_reads
327
974
 
328
- if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
329
- # Determine sorting order
330
- if sort_by.startswith("obs:"):
331
- colname = sort_by.split("obs:")[1]
332
- order = np.argsort(subset_bin.obs[colname].values)
333
- elif sort_by == "gpc":
334
- linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
335
- order = sch.leaves_list(linkage)
336
- elif sort_by == "cpg":
337
- linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
338
- order = sch.leaves_list(linkage)
339
- elif sort_by == "any_c":
340
- linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
341
- order = sch.leaves_list(linkage)
342
- elif sort_by == "gpc_cpg":
343
- linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
344
- order = sch.leaves_list(linkage)
345
- elif sort_by == "none":
346
- order = np.arange(num_reads)
347
- else:
348
- raise ValueError(f"Unsupported sort_by option: {sort_by}")
975
+ # compute order ONCE
976
+ if sort_by.startswith("obs:"):
977
+ colname = sort_by.split("obs:")[1]
978
+ order = np.argsort(subset_bin.obs[colname].values)
349
979
 
350
- stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
351
- stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
352
- stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
980
+ elif sort_by == "gpc" and num_gpc > 0:
981
+ linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
982
+ order = sch.leaves_list(linkage)
353
983
 
354
- row_labels.extend([bin_label] * num_reads)
355
- bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
356
- last_idx += num_reads
357
- bin_boundaries.append(last_idx)
984
+ elif sort_by == "cpg" and num_cpg > 0:
985
+ linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
986
+ order = sch.leaves_list(linkage)
358
987
 
359
- if stacked_any_c:
988
+ elif sort_by == "c" and num_any_c > 0:
989
+ linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_c], method="ward")
990
+ order = sch.leaves_list(linkage)
991
+
992
+ elif sort_by == "gpc_cpg":
993
+ linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
994
+ order = sch.leaves_list(linkage)
995
+
996
+ elif sort_by == "a" and num_any_a > 0:
997
+ linkage = sch.linkage(subset_bin[:, any_a_sites].layers[layer_a], method="ward")
998
+ order = sch.leaves_list(linkage)
999
+
1000
+ elif sort_by == "none":
1001
+ order = np.arange(num_reads)
1002
+
1003
+ else:
1004
+ order = np.arange(num_reads)
1005
+
1006
+ subset_bin = subset_bin[order]
1007
+
1008
+ # stack consistently
1009
+ if include_any_c and num_any_c > 0:
1010
+ stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_c])
1011
+ if include_any_c and num_gpc > 0:
1012
+ stacked_gpc.append(subset_bin[:, gpc_sites].layers[layer_gpc])
1013
+ if include_any_c and num_cpg > 0:
1014
+ stacked_cpg.append(subset_bin[:, cpg_sites].layers[layer_cpg])
1015
+ if include_any_a and num_any_a > 0:
1016
+ stacked_any_a.append(subset_bin[:, any_a_sites].layers[layer_a])
1017
+
1018
+ row_labels.extend([bin_label] * num_reads)
1019
+ bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
1020
+ last_idx += num_reads
1021
+ bin_boundaries.append(last_idx)
1022
+
1023
+ # ----------------------------
1024
+ # build matrices + means
1025
+ # ----------------------------
1026
+ blocks = [] # list of dicts describing what to plot in order
1027
+
1028
+ if include_any_c and stacked_any_c:
360
1029
  any_c_matrix = np.vstack(stacked_any_c)
361
- gpc_matrix = np.vstack(stacked_gpc)
362
- cpg_matrix = np.vstack(stacked_cpg)
363
-
364
- if any_c_matrix.size > 0:
365
- def normalized_mean(matrix):
366
- mean = np.nanmean(matrix, axis=0)
367
- normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
368
- return normalized
369
-
370
- def methylation_fraction(matrix):
371
- methylated = (matrix == 1).sum(axis=0)
372
- valid = (matrix != 0).sum(axis=0)
373
- return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
374
-
375
- mean_gpc = methylation_fraction(gpc_matrix)
376
- mean_cpg = methylation_fraction(cpg_matrix)
377
- mean_any_c = methylation_fraction(any_c_matrix)
378
-
379
- fig = plt.figure(figsize=(18, 12))
380
- gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
381
- fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
382
-
383
- axes_heat = [fig.add_subplot(gs[1, i]) for i in range(3)]
384
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(3)]
385
-
386
- clean_barplot(axes_bar[0], mean_any_c, f"any C site Modification Signal")
387
- clean_barplot(axes_bar[1], mean_gpc, f"GpC Modification Signal")
388
- clean_barplot(axes_bar[2], mean_cpg, f"CpG Modification Signal")
389
-
1030
+ gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
1031
+ cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
1032
+
1033
+ mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
1034
+ mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
1035
+ mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
1036
+
1037
+ if any_c_matrix.size:
1038
+ blocks.append(dict(
1039
+ name="c",
1040
+ matrix=any_c_matrix,
1041
+ mean=mean_any_c,
1042
+ labels=any_c_labels,
1043
+ cmap=cmap_c,
1044
+ n_xticks=n_xticks_any_c,
1045
+ title="any C site Modification Signal"
1046
+ ))
1047
+ if gpc_matrix.size:
1048
+ blocks.append(dict(
1049
+ name="gpc",
1050
+ matrix=gpc_matrix,
1051
+ mean=mean_gpc,
1052
+ labels=gpc_labels,
1053
+ cmap=cmap_gpc,
1054
+ n_xticks=n_xticks_gpc,
1055
+ title="GpC Modification Signal"
1056
+ ))
1057
+ if cpg_matrix.size:
1058
+ blocks.append(dict(
1059
+ name="cpg",
1060
+ matrix=cpg_matrix,
1061
+ mean=mean_cpg,
1062
+ labels=cpg_labels,
1063
+ cmap=cmap_cpg,
1064
+ n_xticks=n_xticks_cpg,
1065
+ title="CpG Modification Signal"
1066
+ ))
1067
+
1068
+ if include_any_a and stacked_any_a:
1069
+ any_a_matrix = np.vstack(stacked_any_a)
1070
+ mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
1071
+ if any_a_matrix.size:
1072
+ blocks.append(dict(
1073
+ name="a",
1074
+ matrix=any_a_matrix,
1075
+ mean=mean_any_a,
1076
+ labels=any_a_labels,
1077
+ cmap=cmap_a,
1078
+ n_xticks=n_xticks_any_a,
1079
+ title="any A site Modification Signal"
1080
+ ))
1081
+
1082
+ if not blocks:
1083
+ print(f"No matrices to plot for {display_sample} - {ref}")
1084
+ continue
390
1085
 
391
- sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[0], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
392
- axes_heat[0].set_xticks(range(0, len(any_c_labels), 20))
393
- axes_heat[0].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
394
- for boundary in bin_boundaries[:-1]:
395
- axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
396
-
397
- sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
398
- axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
399
- axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
400
- for boundary in bin_boundaries[:-1]:
401
- axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
402
-
403
- sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
404
- axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
405
- for boundary in bin_boundaries[:-1]:
406
- axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
407
-
408
- plt.tight_layout()
409
-
410
- if save_path:
411
- save_name = f"{ref} — {sample}"
412
- os.makedirs(save_path, exist_ok=True)
413
- safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
414
- out_file = os.path.join(save_path, f"{safe_name}.png")
415
- plt.savefig(out_file, dpi=300)
416
- print(f"Saved: {out_file}")
417
- plt.close()
418
- else:
419
- plt.show()
420
-
421
- print(f"Summary for {sample} - {ref}:")
422
- for bin_label, percent in percentages.items():
423
- print(f" - {bin_label}: {percent:.1f}%")
424
-
425
- results.append({
426
- "sample": sample,
427
- "ref": ref,
428
- "any_c_matrix": any_c_matrix,
429
- "gpc_matrix": gpc_matrix,
430
- "cpg_matrix": cpg_matrix,
431
- "row_labels": row_labels,
432
- "bin_labels": bin_labels,
433
- "bin_boundaries": bin_boundaries,
434
- "percentages": percentages
435
- })
436
-
437
- adata.uns['clustermap_results'] = results
1086
+ gs_dim = len(blocks)
1087
+ fig = plt.figure(figsize=(5.5 * gs_dim, 11))
1088
+ gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.02)
1089
+ fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
1090
+
1091
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
1092
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
1093
+
1094
+ # ----------------------------
1095
+ # plot blocks
1096
+ # ----------------------------
1097
+ for i, blk in enumerate(blocks):
1098
+ mat = blk["matrix"]
1099
+ mean = blk["mean"]
1100
+ labels = np.asarray(blk["labels"], dtype=str)
1101
+ n_xticks = blk["n_xticks"]
1102
+
1103
+ # barplot
1104
+ clean_barplot(axes_bar[i], mean, blk["title"])
1105
+
1106
+ # heatmap
1107
+ sns.heatmap(
1108
+ mat,
1109
+ cmap=blk["cmap"],
1110
+ ax=axes_heat[i],
1111
+ yticklabels=False,
1112
+ cbar=False
1113
+ )
1114
+
1115
+ # fixed tick labels
1116
+ tick_pos = _fixed_tick_positions(len(labels), n_xticks)
1117
+ axes_heat[i].set_xticks(tick_pos)
1118
+ axes_heat[i].set_xticklabels(
1119
+ labels[tick_pos],
1120
+ rotation=xtick_rotation,
1121
+ fontsize=xtick_fontsize
1122
+ )
1123
+
1124
+ # bin separators
1125
+ for boundary in bin_boundaries[:-1]:
1126
+ axes_heat[i].axhline(y=boundary, color="black", linewidth=2)
1127
+
1128
+ axes_heat[i].set_xlabel("Position", fontsize=9)
1129
+
1130
+ plt.tight_layout()
1131
+
1132
+ # save or show
1133
+ if save_path is not None:
1134
+ safe_name = f"{ref}__{display_sample}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
1135
+ out_file = save_path / f"{safe_name}.png"
1136
+ fig.savefig(out_file, dpi=300)
1137
+ plt.close(fig)
1138
+ print(f"Saved: {out_file}")
1139
+ else:
1140
+ plt.show()
1141
+
1142
+ # record results
1143
+ rec = {
1144
+ "sample": str(sample),
1145
+ "ref": str(ref),
1146
+ "row_labels": row_labels,
1147
+ "bin_labels": bin_labels,
1148
+ "bin_boundaries": bin_boundaries,
1149
+ "percentages": percentages,
1150
+ }
1151
+ for blk in blocks:
1152
+ rec[f"{blk['name']}_matrix"] = blk["matrix"]
1153
+ rec[f"{blk['name']}_labels"] = list(map(str, blk["labels"]))
1154
+ results.append(rec)
1155
+
1156
+ print(f"Summary for {display_sample} - {ref}:")
1157
+ for bin_label, percent in percentages.items():
1158
+ print(f" - {bin_label}: {percent:.1f}%")
438
1159
 
439
1160
  except Exception as e:
440
1161
  import traceback
441
1162
  traceback.print_exc()
442
1163
  continue
443
-
444
1164
 
445
- import os
446
- import math
447
- from typing import List, Optional, Sequence, Tuple
1165
+ # store once at the end (HDF5 safe)
1166
+ # matrices won't be HDF5-safe; store only metadata + maybe hit counts
1167
+ # adata.uns["clustermap_results"] = [
1168
+ # {k: v for k, v in r.items() if not k.endswith("_matrix")}
1169
+ # for r in results
1170
+ # ]
448
1171
 
449
- import numpy as np
450
- import pandas as pd
451
- import matplotlib.pyplot as plt
1172
+ return results
452
1173
 
453
1174
  def plot_hmm_layers_rolling_by_sample_ref(
454
1175
  adata,
@@ -466,7 +1187,7 @@ def plot_hmm_layers_rolling_by_sample_ref(
466
1187
  output_dir: Optional[str] = None,
467
1188
  save: bool = True,
468
1189
  show_raw: bool = False,
469
- cmap: str = "tab10",
1190
+ cmap: str = "tab20",
470
1191
  use_var_coords: bool = True,
471
1192
  ):
472
1193
  """