smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,40 @@
1
+ from __future__ import annotations
2
+
1
3
  import numpy as np
2
4
  import seaborn as sns
3
5
  import matplotlib.pyplot as plt
6
+ import scipy.cluster.hierarchy as sch
7
+ import matplotlib.gridspec as gridspec
8
+ import os
9
+ import math
10
+ import pandas as pd
11
+
12
+ from typing import Optional, Mapping, Sequence, Any, Dict, List
13
+ from pathlib import Path
14
+
15
def normalized_mean(matrix: np.ndarray) -> np.ndarray:
    """
    Column-wise mean of *matrix*, min-max normalized to [0, 1].

    NaN entries are ignored in the per-column mean (np.nanmean). NaN-safe
    extrema (np.nanmin / np.nanmax) are used for the normalization so that
    a single all-NaN column (whose nanmean is NaN) does not poison the
    scaling of every other column, as plain ``mean.min()``/``mean.max()``
    would. A small epsilon in the denominator guards against division by
    zero when the mean profile is flat.
    """
    mean = np.nanmean(matrix, axis=0)
    lo = np.nanmin(mean)
    denom = (np.nanmax(mean) - lo) + 1e-9
    return (mean - lo) / denom
19
+
20
def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
    """
    Fraction methylated per column.
    Methylated = 1
    Valid = finite AND not 0

    In the nan0_0minus1 encoding, 0 marks a missing observation, so a
    position contributes to the denominator only when it is finite and
    non-zero. Columns with no valid observations yield 0 rather than NaN.
    """
    arr = np.asarray(matrix)
    finite = np.isfinite(arr)

    n_methyl = np.count_nonzero(finite & (arr == 1), axis=0)
    n_valid = np.count_nonzero(finite & (arr != 0), axis=0)

    frac = np.zeros_like(n_methyl, dtype=float)
    np.divide(n_methyl, n_valid, out=frac, where=n_valid != 0)
    return frac
4
38
 
5
39
  def clean_barplot(ax, mean_values, title):
6
40
  x = np.arange(len(mean_values))
@@ -17,189 +51,1307 @@ def clean_barplot(ax, mean_values, title):
17
51
 
18
52
  ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
19
53
 
54
+ # def combined_hmm_raw_clustermap(
55
+ # adata,
56
+ # sample_col='Sample_Names',
57
+ # reference_col='Reference_strand',
58
+ # hmm_feature_layer="hmm_combined",
59
+ # layer_gpc="nan0_0minus1",
60
+ # layer_cpg="nan0_0minus1",
61
+ # layer_any_c="nan0_0minus1",
62
+ # cmap_hmm="tab10",
63
+ # cmap_gpc="coolwarm",
64
+ # cmap_cpg="viridis",
65
+ # cmap_any_c='coolwarm',
66
+ # min_quality=20,
67
+ # min_length=200,
68
+ # min_mapped_length_to_reference_length_ratio=0.8,
69
+ # min_position_valid_fraction=0.5,
70
+ # sample_mapping=None,
71
+ # save_path=None,
72
+ # normalize_hmm=False,
73
+ # sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
74
+ # bins=None,
75
+ # deaminase=False,
76
+ # min_signal=0
77
+ # ):
78
+
79
+ # results = []
80
+ # if deaminase:
81
+ # signal_type = 'deamination'
82
+ # else:
83
+ # signal_type = 'methylation'
84
+
85
+ # for ref in adata.obs[reference_col].cat.categories:
86
+ # for sample in adata.obs[sample_col].cat.categories:
87
+ # try:
88
+ # subset = adata[
89
+ # (adata.obs[reference_col] == ref) &
90
+ # (adata.obs[sample_col] == sample) &
91
+ # (adata.obs['read_quality'] >= min_quality) &
92
+ # (adata.obs['read_length'] >= min_length) &
93
+ # (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
94
+ # ]
95
+
96
+ # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
97
+ # subset = subset[:, mask]
98
+
99
+ # if subset.shape[0] == 0:
100
+ # print(f" No reads left after filtering for {sample} - {ref}")
101
+ # continue
102
+
103
+ # if bins:
104
+ # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
105
+ # bins_temp = bins
106
+ # else:
107
+ # print(f"Using all reads for clustermap for {sample} - {ref}")
108
+ # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
109
+
110
+ # # Get column positions (not var_names!) of site masks
111
+ # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
112
+ # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
113
+ # any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
114
+ # num_gpc = len(gpc_sites)
115
+ # num_cpg = len(cpg_sites)
116
+ # num_c = len(any_c_sites)
117
+ # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
118
+
119
+ # # Use var_names for x-axis tick labels
120
+ # gpc_labels = subset.var_names[gpc_sites].astype(int)
121
+ # cpg_labels = subset.var_names[cpg_sites].astype(int)
122
+ # any_c_labels = subset.var_names[any_c_sites].astype(int)
123
+
124
+ # stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
125
+ # row_labels, bin_labels = [], []
126
+ # bin_boundaries = []
127
+
128
+ # total_reads = subset.shape[0]
129
+ # percentages = {}
130
+ # last_idx = 0
131
+
132
+ # for bin_label, bin_filter in bins_temp.items():
133
+ # subset_bin = subset[bin_filter].copy()
134
+ # num_reads = subset_bin.shape[0]
135
+ # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
136
+ # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
137
+ # percentages[bin_label] = percent_reads
138
+
139
+ # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
140
+ # # Determine sorting order
141
+ # if sort_by.startswith("obs:"):
142
+ # colname = sort_by.split("obs:")[1]
143
+ # order = np.argsort(subset_bin.obs[colname].values)
144
+ # elif sort_by == "gpc":
145
+ # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
146
+ # order = sch.leaves_list(linkage)
147
+ # elif sort_by == "cpg":
148
+ # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
149
+ # order = sch.leaves_list(linkage)
150
+ # elif sort_by == "gpc_cpg":
151
+ # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
152
+ # order = sch.leaves_list(linkage)
153
+ # elif sort_by == "none":
154
+ # order = np.arange(num_reads)
155
+ # elif sort_by == "any_c":
156
+ # linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
157
+ # order = sch.leaves_list(linkage)
158
+ # else:
159
+ # raise ValueError(f"Unsupported sort_by option: {sort_by}")
160
+
161
+ # stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
162
+ # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
163
+ # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
164
+ # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
165
+
166
+ # row_labels.extend([bin_label] * num_reads)
167
+ # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
168
+ # last_idx += num_reads
169
+ # bin_boundaries.append(last_idx)
170
+
171
+ # if stacked_hmm_feature:
172
+ # hmm_matrix = np.vstack(stacked_hmm_feature)
173
+ # gpc_matrix = np.vstack(stacked_gpc)
174
+ # cpg_matrix = np.vstack(stacked_cpg)
175
+ # any_c_matrix = np.vstack(stacked_any_c)
176
+
177
+ # if hmm_matrix.size > 0:
178
+ # def normalized_mean(matrix):
179
+ # mean = np.nanmean(matrix, axis=0)
180
+ # normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
181
+ # return normalized
182
+
183
+ # def methylation_fraction(matrix):
184
+ # methylated = (matrix == 1).sum(axis=0)
185
+ # valid = (matrix != 0).sum(axis=0)
186
+ # return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
187
+
188
+ # if normalize_hmm:
189
+ # mean_hmm = normalized_mean(hmm_matrix)
190
+ # else:
191
+ # mean_hmm = np.nanmean(hmm_matrix, axis=0)
192
+ # mean_gpc = methylation_fraction(gpc_matrix)
193
+ # mean_cpg = methylation_fraction(cpg_matrix)
194
+ # mean_any_c = methylation_fraction(any_c_matrix)
195
+
196
+ # fig = plt.figure(figsize=(18, 12))
197
+ # gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
198
+ # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
199
+
200
+ # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
201
+ # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
202
+
203
+ # clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
204
+ # clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
205
+ # clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
206
+ # clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
207
+
208
+ # hmm_labels = subset.var_names.astype(int)
209
+ # hmm_label_spacing = 150
210
+ # sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
211
+ # axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
212
+ # axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
213
+ # for boundary in bin_boundaries[:-1]:
214
+ # axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
215
+
216
+ # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
217
+ # axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
218
+ # axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
219
+ # for boundary in bin_boundaries[:-1]:
220
+ # axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
221
+
222
+ # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
223
+ # axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
224
+ # for boundary in bin_boundaries[:-1]:
225
+ # axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
226
+
227
+ # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
228
+ # axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
229
+ # axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
230
+ # for boundary in bin_boundaries[:-1]:
231
+ # axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
232
+
233
+ # plt.tight_layout()
234
+
235
+ # if save_path:
236
+ # save_name = f"{ref} — {sample}"
237
+ # os.makedirs(save_path, exist_ok=True)
238
+ # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
239
+ # out_file = os.path.join(save_path, f"{safe_name}.png")
240
+ # plt.savefig(out_file, dpi=300)
241
+ # print(f"Saved: {out_file}")
242
+ # plt.close()
243
+ # else:
244
+ # plt.show()
245
+
246
+ # print(f"Summary for {sample} - {ref}:")
247
+ # for bin_label, percent in percentages.items():
248
+ # print(f" - {bin_label}: {percent:.1f}%")
249
+
250
+ # results.append({
251
+ # "sample": sample,
252
+ # "ref": ref,
253
+ # "hmm_matrix": hmm_matrix,
254
+ # "gpc_matrix": gpc_matrix,
255
+ # "cpg_matrix": cpg_matrix,
256
+ # "row_labels": row_labels,
257
+ # "bin_labels": bin_labels,
258
+ # "bin_boundaries": bin_boundaries,
259
+ # "percentages": percentages
260
+ # })
261
+
262
+ # #adata.uns['clustermap_results'] = results
263
+
264
+ # except Exception as e:
265
+ # import traceback
266
+ # traceback.print_exc()
267
+ # continue
268
+
20
269
 
21
270
def combined_hmm_raw_clustermap(
    adata,
    sample_col: str = "Sample_Names",
    reference_col: str = "Reference_strand",

    hmm_feature_layer: str = "hmm_combined",

    layer_gpc: str = "nan0_0minus1",
    layer_cpg: str = "nan0_0minus1",
    layer_any_c: str = "nan0_0minus1",
    layer_a: str = "nan0_0minus1",

    cmap_hmm: str = "tab10",
    cmap_gpc: str = "coolwarm",
    cmap_cpg: str = "viridis",
    cmap_any_c: str = "coolwarm",
    cmap_a: str = "coolwarm",

    min_quality: int = 20,
    min_length: int = 200,
    min_mapped_length_to_reference_length_ratio: float = 0.8,
    min_position_valid_fraction: float = 0.5,

    save_path: str | Path | None = None,
    normalize_hmm: bool = False,

    sort_by: str = "gpc",
    bins: Optional[Dict[str, Any]] = None,

    deaminase: bool = False,
    min_signal: float = 0.0,

    # ---- fixed tick label controls (counts, not spacing)
    n_xticks_hmm: int = 10,
    n_xticks_any_c: int = 8,
    n_xticks_gpc: int = 8,
    n_xticks_cpg: int = 8,
    n_xticks_a: int = 8,
):
    """
    Makes a multi-panel clustermap per (sample, reference):
    HMM panel (always) + optional raw panels for any_C, GpC, CpG, and A sites.

    Panels are added only if the corresponding site mask exists AND has >0 sites.

    Parameters
    ----------
    adata : AnnData
        Reads x positions matrix with per-read obs (read_quality, read_length,
        mapped_length_to_reference_length_ratio) and per-position var site
        masks named ``{ref}_GpC_site`` etc.
    sample_col, reference_col : str
        Categorical obs columns iterated to produce one figure per pair.
    hmm_feature_layer, layer_* : str
        Layer names for the HMM panel and each raw-signal panel.
    min_quality, min_length, min_mapped_length_to_reference_length_ratio :
        Read-level filters applied before plotting.
    min_position_valid_fraction : float
        Position filter on ``{ref}_valid_fraction`` (skipped if absent).
    save_path : str | Path | None
        If given, figures are written as PNGs there; otherwise shown.
    normalize_hmm : bool
        Min-max normalize the mean HMM profile in the top barplot.
    sort_by : str
        'gpc', 'cpg', 'any_c', 'any_a', 'gpc_cpg', 'none', or 'obs:<col>'.
    bins : dict[str, Any] | None
        Optional mapping of bin label -> boolean read mask; defaults to one
        "All" bin. NOTE(review): masks are applied to the *filtered* subset,
        so they must align with it — confirm against callers.
    deaminase : bool
        Only changes the "(methylation)"/"(deamination)" tag in the title.
    min_signal : float
        Currently unused; retained for API compatibility.
    n_xticks_* : int
        Approximate number of x tick labels per panel.

    Returns
    -------
    None. Figures are saved or shown as a side effect. Per-pair failures are
    printed (traceback) and iteration continues.
    """

    def pick_xticks(labels: np.ndarray, n_ticks: int):
        # Evenly spaced tick positions over the label vector, deduplicated
        # so short vectors don't repeat indices.
        if labels.size == 0:
            return [], []
        idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
        idx = np.unique(idx)
        return idx.tolist(), labels[idx].tolist()

    signal_type = "deamination" if deaminase else "methylation"

    for ref in adata.obs[reference_col].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:

            try:
                # ---- subset reads by sample/reference and QC thresholds ----
                subset = adata[
                    (adata.obs[reference_col] == ref) &
                    (adata.obs[sample_col] == sample) &
                    (adata.obs["read_quality"] >= min_quality) &
                    (adata.obs["read_length"] >= min_length) &
                    (
                        adata.obs["mapped_length_to_reference_length_ratio"]
                        > min_mapped_length_to_reference_length_ratio
                    )
                ]

                # ---- valid fraction filter (only if the column exists) ----
                vf_key = f"{ref}_valid_fraction"
                if vf_key in subset.var:
                    mask = subset.var[vf_key].astype(float) > float(min_position_valid_fraction)
                    subset = subset[:, mask]

                if subset.shape[0] == 0:
                    continue

                # ---- bins: default to a single all-reads bin ----
                if bins is None:
                    bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
                else:
                    bins_temp = bins

                # ---- site masks: first matching var column wins ----
                def _sites(*keys):
                    for k in keys:
                        if k in subset.var:
                            return np.where(subset.var[k].values)[0]
                    return np.array([], dtype=int)

                gpc_sites = _sites(f"{ref}_GpC_site")
                cpg_sites = _sites(f"{ref}_CpG_site")
                any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
                any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")

                def _labels(sites):
                    # Genomic coordinates for tick labels; assumes var_names
                    # are integer-like position strings.
                    return subset.var_names[sites].astype(int) if sites.size else np.array([])

                gpc_labels = _labels(gpc_sites)
                cpg_labels = _labels(cpg_sites)
                any_c_labels = _labels(any_c_sites)
                any_a_labels = _labels(any_a_sites)

                # per-bin matrices, stacked vertically in bin order
                stacked_hmm = []
                stacked_any_c = []
                stacked_gpc = []
                stacked_cpg = []
                stacked_any_a = []

                row_labels, bin_labels, bin_boundaries = [], [], []
                total_reads = subset.n_obs
                percentages = {}
                last_idx = 0

                # ---------------- process bins ----------------
                for bin_label, bin_filter in bins_temp.items():
                    sb = subset[bin_filter].copy()
                    n = sb.n_obs
                    if n == 0:
                        continue

                    pct = (n / total_reads) * 100 if total_reads else 0
                    percentages[bin_label] = pct

                    # ---- row ordering within the bin ----
                    if sort_by.startswith("obs:"):
                        # split with maxsplit=1 so an obs column whose name
                        # itself contains "obs:" is not mangled
                        colname = sort_by.split("obs:", 1)[1]
                        order = np.argsort(sb.obs[colname].values)

                    elif sort_by == "gpc" and gpc_sites.size:
                        linkage = sch.linkage(sb[:, gpc_sites].layers[layer_gpc], method="ward")
                        order = sch.leaves_list(linkage)

                    elif sort_by == "cpg" and cpg_sites.size:
                        linkage = sch.linkage(sb[:, cpg_sites].layers[layer_cpg], method="ward")
                        order = sch.leaves_list(linkage)

                    elif sort_by == "any_c" and any_c_sites.size:
                        linkage = sch.linkage(sb[:, any_c_sites].layers[layer_any_c], method="ward")
                        order = sch.leaves_list(linkage)

                    elif sort_by == "any_a" and any_a_sites.size:
                        linkage = sch.linkage(sb[:, any_a_sites].layers[layer_a], method="ward")
                        order = sch.leaves_list(linkage)

                    elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
                        # clusters on the full-width layer, not the site subset
                        linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
                        order = sch.leaves_list(linkage)

                    else:
                        # 'none', or requested sites absent: keep input order
                        order = np.arange(n)

                    sb = sb[order]

                    # ---- collect matrices (only panels with sites) ----
                    stacked_hmm.append(sb.layers[hmm_feature_layer])
                    if any_c_sites.size:
                        stacked_any_c.append(sb[:, any_c_sites].layers[layer_any_c])
                    if gpc_sites.size:
                        stacked_gpc.append(sb[:, gpc_sites].layers[layer_gpc])
                    if cpg_sites.size:
                        stacked_cpg.append(sb[:, cpg_sites].layers[layer_cpg])
                    if any_a_sites.size:
                        stacked_any_a.append(sb[:, any_a_sites].layers[layer_a])

                    row_labels.extend([bin_label] * n)
                    bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
                    last_idx += n
                    bin_boundaries.append(last_idx)

                # ---------------- stack ----------------
                # Guard: every bin may have been empty, in which case
                # np.vstack([]) would raise — skip this (sample, ref) instead.
                if not stacked_hmm:
                    continue

                hmm_matrix = np.vstack(stacked_hmm)
                mean_hmm = normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)

                # (name, matrix, tick labels, cmap, mean profile, n ticks)
                panels = [
                    ("HMM", hmm_matrix, subset.var_names.astype(int), cmap_hmm, mean_hmm, n_xticks_hmm),
                ]

                if stacked_any_c:
                    m = np.vstack(stacked_any_c)
                    panels.append(("any_C", m, any_c_labels, cmap_any_c, methylation_fraction(m), n_xticks_any_c))

                if stacked_gpc:
                    m = np.vstack(stacked_gpc)
                    panels.append(("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc))

                if stacked_cpg:
                    m = np.vstack(stacked_cpg)
                    panels.append(("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg))

                if stacked_any_a:
                    m = np.vstack(stacked_any_a)
                    panels.append(("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a))

                # ---------------- plotting ----------------
                n_panels = len(panels)
                fig = plt.figure(figsize=(4.5 * n_panels, 10))
                gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
                fig.suptitle(f"{sample} — {ref} — {total_reads} reads ({signal_type})",
                             fontsize=14, y=0.98)

                axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
                axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]

                for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):

                    # mean-signal barplot above each heatmap
                    clean_barplot(axes_bar[i], mean_vec, name)

                    sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
                                yticklabels=False, cbar=False)

                    xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
                    axes_heat[i].set_xticks(xtick_pos)
                    axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)

                    # horizontal separators between bins (skip the last edge)
                    for boundary in bin_boundaries[:-1]:
                        axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)

                plt.tight_layout()

                if save_path:
                    save_path = Path(save_path)
                    save_path.mkdir(parents=True, exist_ok=True)
                    safe_name = f"{ref}__{sample}".replace("/", "_")
                    out_file = save_path / f"{safe_name}.png"
                    plt.savefig(out_file, dpi=300)
                    plt.close(fig)
                else:
                    plt.show()

            except Exception:
                # best-effort per (sample, ref): report and move on
                import traceback
                traceback.print_exc()
                continue
514
+
515
+
516
+ # def combined_raw_clustermap(
517
+ # adata,
518
+ # sample_col='Sample_Names',
519
+ # reference_col='Reference_strand',
520
+ # mod_target_bases=['GpC', 'CpG'],
521
+ # layer_any_c="nan0_0minus1",
522
+ # layer_gpc="nan0_0minus1",
523
+ # layer_cpg="nan0_0minus1",
524
+ # layer_a="nan0_0minus1",
525
+ # cmap_any_c="coolwarm",
526
+ # cmap_gpc="coolwarm",
527
+ # cmap_cpg="viridis",
528
+ # cmap_a="coolwarm",
529
+ # min_quality=20,
530
+ # min_length=200,
531
+ # min_mapped_length_to_reference_length_ratio=0.8,
532
+ # min_position_valid_fraction=0.5,
533
+ # sample_mapping=None,
534
+ # save_path=None,
535
+ # sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', 'any_a', or 'obs:<column>'
536
+ # bins=None,
537
+ # deaminase=False,
538
+ # min_signal=0
539
+ # ):
540
+
541
+ # results = []
542
+
543
+ # for ref in adata.obs[reference_col].cat.categories:
544
+ # for sample in adata.obs[sample_col].cat.categories:
545
+ # try:
546
+ # subset = adata[
547
+ # (adata.obs[reference_col] == ref) &
548
+ # (adata.obs[sample_col] == sample) &
549
+ # (adata.obs['read_quality'] >= min_quality) &
550
+ # (adata.obs['mapped_length'] >= min_length) &
551
+ # (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
552
+ # ]
553
+
554
+ # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
555
+ # subset = subset[:, mask]
556
+
557
+ # if subset.shape[0] == 0:
558
+ # print(f" No reads left after filtering for {sample} - {ref}")
559
+ # continue
560
+
561
+ # if bins:
562
+ # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
563
+ # bins_temp = bins
564
+ # else:
565
+ # print(f"Using all reads for clustermap for {sample} - {ref}")
566
+ # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
567
+
568
+ # num_any_c = 0
569
+ # num_gpc = 0
570
+ # num_cpg = 0
571
+ # num_any_a = 0
572
+
573
+ # # Get column positions (not var_names!) of site masks
574
+ # if any(base in ["C", "CpG", "GpC"] for base in mod_target_bases):
575
+ # any_c_sites = np.where(subset.var[f"{ref}_C_site"].values)[0]
576
+ # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
577
+ # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
578
+ # num_any_c = len(any_c_sites)
579
+ # num_gpc = len(gpc_sites)
580
+ # num_cpg = len(cpg_sites)
581
+ # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
582
+
583
+ # # Use var_names for x-axis tick labels
584
+ # gpc_labels = subset.var_names[gpc_sites].astype(int)
585
+ # cpg_labels = subset.var_names[cpg_sites].astype(int)
586
+ # any_c_labels = subset.var_names[any_c_sites].astype(int)
587
+ # stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
588
+
589
+ # if "A" in mod_target_bases:
590
+ # any_a_sites = np.where(subset.var[f"{ref}_A_site"].values)[0]
591
+ # num_any_a = len(any_a_sites)
592
+ # print(f"Found {num_any_a} any_A sites at {any_a_sites} for {sample} - {ref}")
593
+ # any_a_labels = subset.var_names[any_a_sites].astype(int)
594
+ # stacked_any_a = []
595
+
596
+ # row_labels, bin_labels = [], []
597
+ # bin_boundaries = []
598
+
599
+ # total_reads = subset.shape[0]
600
+ # percentages = {}
601
+ # last_idx = 0
602
+
603
+ # for bin_label, bin_filter in bins_temp.items():
604
+ # subset_bin = subset[bin_filter].copy()
605
+ # num_reads = subset_bin.shape[0]
606
+ # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
607
+ # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
608
+ # percentages[bin_label] = percent_reads
609
+
610
+ # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
611
+ # # Determine sorting order
612
+ # if sort_by.startswith("obs:"):
613
+ # colname = sort_by.split("obs:")[1]
614
+ # order = np.argsort(subset_bin.obs[colname].values)
615
+ # elif sort_by == "gpc":
616
+ # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
617
+ # order = sch.leaves_list(linkage)
618
+ # elif sort_by == "cpg":
619
+ # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
620
+ # order = sch.leaves_list(linkage)
621
+ # elif sort_by == "any_c":
622
+ # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
623
+ # order = sch.leaves_list(linkage)
624
+ # elif sort_by == "gpc_cpg":
625
+ # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
626
+ # order = sch.leaves_list(linkage)
627
+ # elif sort_by == "none":
628
+ # order = np.arange(num_reads)
629
+ # elif sort_by == "any_a":
630
+ # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
631
+ # order = sch.leaves_list(linkage)
632
+ # else:
633
+ # raise ValueError(f"Unsupported sort_by option: {sort_by}")
634
+
635
+ # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
636
+ # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
637
+ # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
638
+
639
+ # if num_reads > 0 and num_any_a > 0:
640
+ # # Determine sorting order
641
+ # if sort_by.startswith("obs:"):
642
+ # colname = sort_by.split("obs:")[1]
643
+ # order = np.argsort(subset_bin.obs[colname].values)
644
+ # elif sort_by == "gpc":
645
+ # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
646
+ # order = sch.leaves_list(linkage)
647
+ # elif sort_by == "cpg":
648
+ # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
649
+ # order = sch.leaves_list(linkage)
650
+ # elif sort_by == "any_c":
651
+ # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
652
+ # order = sch.leaves_list(linkage)
653
+ # elif sort_by == "gpc_cpg":
654
+ # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
655
+ # order = sch.leaves_list(linkage)
656
+ # elif sort_by == "none":
657
+ # order = np.arange(num_reads)
658
+ # elif sort_by == "any_a":
659
+ # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
660
+ # order = sch.leaves_list(linkage)
661
+ # else:
662
+ # raise ValueError(f"Unsupported sort_by option: {sort_by}")
663
+
664
+ # stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
665
+
666
+
667
+ # row_labels.extend([bin_label] * num_reads)
668
+ # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
669
+ # last_idx += num_reads
670
+ # bin_boundaries.append(last_idx)
671
+
672
+ # gs_dim = 0
673
+
674
+ # if stacked_any_c:
675
+ # any_c_matrix = np.vstack(stacked_any_c)
676
+ # gpc_matrix = np.vstack(stacked_gpc)
677
+ # cpg_matrix = np.vstack(stacked_cpg)
678
+ # if any_c_matrix.size > 0:
679
+ # mean_gpc = methylation_fraction(gpc_matrix)
680
+ # mean_cpg = methylation_fraction(cpg_matrix)
681
+ # mean_any_c = methylation_fraction(any_c_matrix)
682
+ # gs_dim += 3
683
+
684
+ # if stacked_any_a:
685
+ # any_a_matrix = np.vstack(stacked_any_a)
686
+ # if any_a_matrix.size > 0:
687
+ # mean_any_a = methylation_fraction(any_a_matrix)
688
+ # gs_dim += 1
689
+
690
+
691
+ # fig = plt.figure(figsize=(18, 12))
692
+ # gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
693
+ # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
694
+ # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
695
+ # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
696
+
697
+ # current_ax = 0
698
+
699
+ # if stacked_any_c:
700
+ # if any_c_matrix.size > 0:
701
+ # clean_barplot(axes_bar[current_ax], mean_any_c, f"any C site Modification Signal")
702
+ # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[current_ax], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
703
+ # axes_heat[current_ax].set_xticks(range(0, len(any_c_labels), 20))
704
+ # axes_heat[current_ax].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
705
+ # for boundary in bin_boundaries[:-1]:
706
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
707
+ # current_ax +=1
708
+
709
+ # clean_barplot(axes_bar[current_ax], mean_gpc, f"GpC Modification Signal")
710
+ # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[current_ax], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
711
+ # axes_heat[current_ax].set_xticks(range(0, len(gpc_labels), 5))
712
+ # axes_heat[current_ax].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
713
+ # for boundary in bin_boundaries[:-1]:
714
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
715
+ # current_ax +=1
716
+
717
+ # clean_barplot(axes_bar[current_ax], mean_cpg, f"CpG Modification Signal")
718
+ # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
719
+ # axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
720
+ # for boundary in bin_boundaries[:-1]:
721
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
722
+ # current_ax +=1
723
+
724
+ # results.append({
725
+ # "sample": sample,
726
+ # "ref": ref,
727
+ # "any_c_matrix": any_c_matrix,
728
+ # "gpc_matrix": gpc_matrix,
729
+ # "cpg_matrix": cpg_matrix,
730
+ # "row_labels": row_labels,
731
+ # "bin_labels": bin_labels,
732
+ # "bin_boundaries": bin_boundaries,
733
+ # "percentages": percentages
734
+ # })
735
+
736
+ # if stacked_any_a:
737
+ # if any_a_matrix.size > 0:
738
+ # clean_barplot(axes_bar[current_ax], mean_any_a, f"any A site Modification Signal")
739
+ # sns.heatmap(any_a_matrix, cmap=cmap_a, ax=axes_heat[current_ax], xticklabels=any_a_labels[::20], yticklabels=False, cbar=False)
740
+ # axes_heat[current_ax].set_xticks(range(0, len(any_a_labels), 20))
741
+ # axes_heat[current_ax].set_xticklabels(any_a_labels[::20], rotation=90, fontsize=10)
742
+ # for boundary in bin_boundaries[:-1]:
743
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
744
+ # current_ax +=1
745
+
746
+ # results.append({
747
+ # "sample": sample,
748
+ # "ref": ref,
749
+ # "any_a_matrix": any_a_matrix,
750
+ # "row_labels": row_labels,
751
+ # "bin_labels": bin_labels,
752
+ # "bin_boundaries": bin_boundaries,
753
+ # "percentages": percentages
754
+ # })
755
+
756
+ # plt.tight_layout()
757
+
758
+ # if save_path:
759
+ # save_name = f"{ref} — {sample}"
760
+ # os.makedirs(save_path, exist_ok=True)
761
+ # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
762
+ # out_file = os.path.join(save_path, f"{safe_name}.png")
763
+ # plt.savefig(out_file, dpi=300)
764
+ # print(f"Saved: {out_file}")
765
+ # plt.close()
766
+ # else:
767
+ # plt.show()
768
+
769
+ # print(f"Summary for {sample} - {ref}:")
770
+ # for bin_label, percent in percentages.items():
771
+ # print(f" - {bin_label}: {percent:.1f}%")
58
772
 
59
- subset = subset[:, subset.var[f'position_in_{ref}'] == True]
773
+ # adata.uns['clustermap_results'] = results
774
+
775
+ # except Exception as e:
776
+ # import traceback
777
+ # traceback.print_exc()
778
+ # continue
779
+
780
+ def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
781
+ """
782
+ Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
783
+ Always includes 0 and n_positions-1 when possible.
784
+ """
785
+ n_ticks = int(max(2, n_ticks))
786
+ if n_positions <= n_ticks:
787
+ return np.arange(n_positions)
788
+
789
+ # linspace gives fixed count
790
+ pos = np.linspace(0, n_positions - 1, n_ticks)
791
+ return np.unique(np.round(pos).astype(int))
792
+
793
+ def combined_raw_clustermap(
794
+ adata,
795
+ sample_col: str = "Sample_Names",
796
+ reference_col: str = "Reference_strand",
797
+ mod_target_bases: Sequence[str] = ("GpC", "CpG"),
798
+ layer_any_c: str = "nan0_0minus1",
799
+ layer_gpc: str = "nan0_0minus1",
800
+ layer_cpg: str = "nan0_0minus1",
801
+ layer_a: str = "nan0_0minus1",
802
+ cmap_any_c: str = "coolwarm",
803
+ cmap_gpc: str = "coolwarm",
804
+ cmap_cpg: str = "viridis",
805
+ cmap_a: str = "coolwarm",
806
+ min_quality: float = 20,
807
+ min_length: int = 200,
808
+ min_mapped_length_to_reference_length_ratio: float = 0.8,
809
+ min_position_valid_fraction: float = 0.5,
810
+ sample_mapping: Optional[Mapping[str, str]] = None,
811
+ save_path: str | Path | None = None,
812
+ sort_by: str = "gpc", # 'gpc','cpg','any_c','gpc_cpg','any_a','none','obs:<col>'
813
+ bins: Optional[Dict[str, Any]] = None,
814
+ deaminase: bool = False,
815
+ min_signal: float = 0,
816
+ # NEW tick controls
817
+ n_xticks_any_c: int = 10,
818
+ n_xticks_gpc: int = 10,
819
+ n_xticks_cpg: int = 10,
820
+ n_xticks_any_a: int = 10,
821
+ xtick_rotation: int = 90,
822
+ xtick_fontsize: int = 9,
823
+ ):
824
+ """
825
+ Plot stacked heatmaps + per-position mean barplots for any_C, GpC, CpG, and optional A.
826
+
827
+ Key fixes vs old version:
828
+ - order computed ONCE per bin, applied to all matrices
829
+ - no hard-coded axes indices
830
+ - NaNs excluded from methylation denominators
831
+ - var_names not forced to int
832
+ - fixed count of x tick labels per block (controllable)
833
+ - adata.uns updated once at end
834
+
835
+ Returns
836
+ -------
837
+ results : list[dict]
838
+ One entry per (sample, ref) plot with matrices + bin metadata.
839
+ """
840
+
841
+ results: List[Dict[str, Any]] = []
842
+ save_path = Path(save_path) if save_path is not None else None
843
+ if save_path is not None:
844
+ save_path.mkdir(parents=True, exist_ok=True)
845
+
846
+ # Ensure categorical
847
+ for col in (sample_col, reference_col):
848
+ if col not in adata.obs:
849
+ raise KeyError(f"{col} not in adata.obs")
850
+ if not pd.api.types.is_categorical_dtype(adata.obs[col]):
851
+ adata.obs[col] = adata.obs[col].astype("category")
852
+
853
+ base_set = set(mod_target_bases)
854
+ include_any_c = any(b in {"C", "CpG", "GpC"} for b in base_set)
855
+ include_any_a = "A" in base_set
856
+
857
+ for ref in adata.obs[reference_col].cat.categories:
858
+ for sample in adata.obs[sample_col].cat.categories:
859
+
860
+ # Optionally remap sample label for display
861
+ display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
862
+
863
+ try:
864
+ subset = adata[
865
+ (adata.obs[reference_col] == ref) &
866
+ (adata.obs[sample_col] == sample) &
867
+ (adata.obs["read_quality"] >= min_quality) &
868
+ (adata.obs["mapped_length"] >= min_length) &
869
+ (adata.obs["mapped_length_to_reference_length_ratio"] >= min_mapped_length_to_reference_length_ratio)
870
+ ]
871
+
872
+ # position-level mask
873
+ valid_key = f"{ref}_valid_fraction"
874
+ if valid_key in subset.var:
875
+ mask = subset.var[valid_key].astype(float).values > float(min_position_valid_fraction)
876
+ subset = subset[:, mask]
60
877
 
61
878
  if subset.shape[0] == 0:
62
- print(f"No reads left after filtering for {sample} - {ref}")
879
+ print(f"No reads left after filtering for {display_sample} - {ref}")
63
880
  continue
64
881
 
65
- if bins:
66
- pass
882
+ # bins mode
883
+ if bins is None:
884
+ bins_temp = {"All": (subset.obs[reference_col] == ref)}
67
885
  else:
68
- bins = {"All": (subset.obs['Reference_strand'] != None)}
886
+ bins_temp = bins
69
887
 
70
- # Get column positions (not var_names!) of site masks
71
- gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
72
- cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
888
+ # find sites (positions)
889
+ any_c_sites = gpc_sites = cpg_sites = np.array([], dtype=int)
890
+ any_a_sites = np.array([], dtype=int)
73
891
 
74
- # Use var_names for x-axis tick labels
75
- gpc_labels = subset.var_names[gpc_sites].astype(int)
76
- cpg_labels = subset.var_names[cpg_sites].astype(int)
892
+ num_any_c = num_gpc = num_cpg = num_any_a = 0
77
893
 
78
- stacked_hmm_feature, stacked_gpc, stacked_cpg = [], [], []
79
- row_labels, bin_labels = [], []
80
- bin_boundaries = []
894
+ if include_any_c:
895
+ any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
896
+ gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
897
+ cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
81
898
 
82
- total_reads = subset.shape[0]
899
+ num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
900
+
901
+ any_c_labels = subset.var_names[any_c_sites].astype(str)
902
+ gpc_labels = subset.var_names[gpc_sites].astype(str)
903
+ cpg_labels = subset.var_names[cpg_sites].astype(str)
904
+
905
+ if include_any_a:
906
+ any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
907
+ num_any_a = len(any_a_sites)
908
+ any_a_labels = subset.var_names[any_a_sites].astype(str)
909
+
910
+ stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
911
+ row_labels, bin_labels, bin_boundaries = [], [], []
83
912
  percentages = {}
84
913
  last_idx = 0
914
+ total_reads = subset.shape[0]
85
915
 
86
- for bin_label, bin_filter in bins.items():
916
+ # ----------------------------
917
+ # per-bin stacking
918
+ # ----------------------------
919
+ for bin_label, bin_filter in bins_temp.items():
87
920
  subset_bin = subset[bin_filter].copy()
88
921
  num_reads = subset_bin.shape[0]
89
- percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
922
+ if num_reads == 0:
923
+ percentages[bin_label] = 0.0
924
+ continue
925
+
926
+ percent_reads = (num_reads / total_reads) * 100
90
927
  percentages[bin_label] = percent_reads
91
928
 
92
- if num_reads > 0:
93
- # Determine sorting order
94
- if sort_by.startswith("obs:"):
95
- colname = sort_by.split("obs:")[1]
96
- order = np.argsort(subset_bin.obs[colname].values)
97
- elif sort_by == "gpc":
98
- linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
99
- order = sch.leaves_list(linkage)
100
- elif sort_by == "cpg":
101
- linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
102
- order = sch.leaves_list(linkage)
103
- elif sort_by == "gpc_cpg":
104
- linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
105
- order = sch.leaves_list(linkage)
106
- elif sort_by == "none":
107
- order = np.arange(num_reads)
108
- else:
109
- raise ValueError(f"Unsupported sort_by option: {sort_by}")
110
-
111
- stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
112
- stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
113
- stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
114
-
115
- row_labels.extend([bin_label] * num_reads)
116
- bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
117
- last_idx += num_reads
118
- bin_boundaries.append(last_idx)
119
-
120
- if stacked_hmm_feature:
121
- hmm_matrix = np.vstack(stacked_hmm_feature)
122
- gpc_matrix = np.vstack(stacked_gpc)
123
- cpg_matrix = np.vstack(stacked_cpg)
124
-
125
- def normalized_mean(matrix):
126
- mean = np.nanmean(matrix, axis=0)
127
- normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
128
- return normalized
129
-
130
- def methylation_fraction(matrix):
131
- methylated = (matrix == 1).sum(axis=0)
132
- valid = (matrix != 0).sum(axis=0)
133
- return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
134
-
135
- if normalize_hmm:
136
- mean_hmm = normalized_mean(hmm_matrix)
929
+ # compute order ONCE
930
+ if sort_by.startswith("obs:"):
931
+ colname = sort_by.split("obs:")[1]
932
+ order = np.argsort(subset_bin.obs[colname].values)
933
+
934
+ elif sort_by == "gpc" and num_gpc > 0:
935
+ linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
936
+ order = sch.leaves_list(linkage)
937
+
938
+ elif sort_by == "cpg" and num_cpg > 0:
939
+ linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
940
+ order = sch.leaves_list(linkage)
941
+
942
+ elif sort_by == "any_c" and num_any_c > 0:
943
+ linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
944
+ order = sch.leaves_list(linkage)
945
+
946
+ elif sort_by == "gpc_cpg":
947
+ linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
948
+ order = sch.leaves_list(linkage)
949
+
950
+ elif sort_by == "any_a" and num_any_a > 0:
951
+ linkage = sch.linkage(subset_bin[:, any_a_sites].layers[layer_a], method="ward")
952
+ order = sch.leaves_list(linkage)
953
+
954
+ elif sort_by == "none":
955
+ order = np.arange(num_reads)
956
+
137
957
  else:
138
- mean_hmm = np.nanmean(hmm_matrix, axis=0)
139
- mean_gpc = methylation_fraction(gpc_matrix)
140
- mean_cpg = methylation_fraction(cpg_matrix)
958
+ order = np.arange(num_reads)
141
959
 
142
- fig = plt.figure(figsize=(18, 12))
143
- gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
144
- fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)
960
+ subset_bin = subset_bin[order]
145
961
 
146
- axes_heat = [fig.add_subplot(gs[1, i]) for i in range(3)]
147
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(3)]
962
+ # stack consistently
963
+ if include_any_c and num_any_c > 0:
964
+ stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_any_c])
965
+ if include_any_c and num_gpc > 0:
966
+ stacked_gpc.append(subset_bin[:, gpc_sites].layers[layer_gpc])
967
+ if include_any_c and num_cpg > 0:
968
+ stacked_cpg.append(subset_bin[:, cpg_sites].layers[layer_cpg])
969
+ if include_any_a and num_any_a > 0:
970
+ stacked_any_a.append(subset_bin[:, any_a_sites].layers[layer_a])
148
971
 
149
- clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
150
- clean_barplot(axes_bar[1], mean_gpc, f"GpC Methylation")
151
- clean_barplot(axes_bar[2], mean_cpg, f"CpG Methylation")
152
-
153
- hmm_labels = subset.var_names.astype(int)
154
- hmm_label_spacing = 150
155
- sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
156
- axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
157
- axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
158
- for boundary in bin_boundaries[:-1]:
159
- axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
972
+ row_labels.extend([bin_label] * num_reads)
973
+ bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
974
+ last_idx += num_reads
975
+ bin_boundaries.append(last_idx)
160
976
 
161
- sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
162
- axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
163
- axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
164
- for boundary in bin_boundaries[:-1]:
165
- axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
977
+ # ----------------------------
978
+ # build matrices + means
979
+ # ----------------------------
980
+ blocks = [] # list of dicts describing what to plot in order
166
981
 
167
- sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
168
- axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
982
+ if include_any_c and stacked_any_c:
983
+ any_c_matrix = np.vstack(stacked_any_c)
984
+ gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
985
+ cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
986
+
987
+ mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
988
+ mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
989
+ mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
990
+
991
+ if any_c_matrix.size:
992
+ blocks.append(dict(
993
+ name="any_c",
994
+ matrix=any_c_matrix,
995
+ mean=mean_any_c,
996
+ labels=any_c_labels,
997
+ cmap=cmap_any_c,
998
+ n_xticks=n_xticks_any_c,
999
+ title="any C site Modification Signal"
1000
+ ))
1001
+ if gpc_matrix.size:
1002
+ blocks.append(dict(
1003
+ name="gpc",
1004
+ matrix=gpc_matrix,
1005
+ mean=mean_gpc,
1006
+ labels=gpc_labels,
1007
+ cmap=cmap_gpc,
1008
+ n_xticks=n_xticks_gpc,
1009
+ title="GpC Modification Signal"
1010
+ ))
1011
+ if cpg_matrix.size:
1012
+ blocks.append(dict(
1013
+ name="cpg",
1014
+ matrix=cpg_matrix,
1015
+ mean=mean_cpg,
1016
+ labels=cpg_labels,
1017
+ cmap=cmap_cpg,
1018
+ n_xticks=n_xticks_cpg,
1019
+ title="CpG Modification Signal"
1020
+ ))
1021
+
1022
+ if include_any_a and stacked_any_a:
1023
+ any_a_matrix = np.vstack(stacked_any_a)
1024
+ mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
1025
+ if any_a_matrix.size:
1026
+ blocks.append(dict(
1027
+ name="any_a",
1028
+ matrix=any_a_matrix,
1029
+ mean=mean_any_a,
1030
+ labels=any_a_labels,
1031
+ cmap=cmap_a,
1032
+ n_xticks=n_xticks_any_a,
1033
+ title="any A site Modification Signal"
1034
+ ))
1035
+
1036
+ if not blocks:
1037
+ print(f"No matrices to plot for {display_sample} - {ref}")
1038
+ continue
1039
+
1040
+ gs_dim = len(blocks)
1041
+ fig = plt.figure(figsize=(5.5 * gs_dim, 11))
1042
+ gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.02)
1043
+ fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
1044
+
1045
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
1046
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
1047
+
1048
+ # ----------------------------
1049
+ # plot blocks
1050
+ # ----------------------------
1051
+ for i, blk in enumerate(blocks):
1052
+ mat = blk["matrix"]
1053
+ mean = blk["mean"]
1054
+ labels = np.asarray(blk["labels"], dtype=str)
1055
+ n_xticks = blk["n_xticks"]
1056
+
1057
+ # barplot
1058
+ clean_barplot(axes_bar[i], mean, blk["title"])
1059
+
1060
+ # heatmap
1061
+ sns.heatmap(
1062
+ mat,
1063
+ cmap=blk["cmap"],
1064
+ ax=axes_heat[i],
1065
+ yticklabels=False,
1066
+ cbar=False
1067
+ )
1068
+
1069
+ # fixed tick labels
1070
+ tick_pos = _fixed_tick_positions(len(labels), n_xticks)
1071
+ axes_heat[i].set_xticks(tick_pos)
1072
+ axes_heat[i].set_xticklabels(
1073
+ labels[tick_pos],
1074
+ rotation=xtick_rotation,
1075
+ fontsize=xtick_fontsize
1076
+ )
1077
+
1078
+ # bin separators
169
1079
  for boundary in bin_boundaries[:-1]:
170
- axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
1080
+ axes_heat[i].axhline(y=boundary, color="black", linewidth=2)
171
1081
 
172
- plt.tight_layout()
1082
+ axes_heat[i].set_xlabel("Position", fontsize=9)
173
1083
 
174
- if save_path:
175
- save_name = f"{ref} — {sample}"
176
- os.makedirs(save_path, exist_ok=True)
177
- safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
178
- out_file = os.path.join(save_path, f"{safe_name}.png")
179
- plt.savefig(out_file, dpi=300)
180
- print(f"📁 Saved: {out_file}")
1084
+ plt.tight_layout()
181
1085
 
1086
+ # save or show
1087
+ if save_path is not None:
1088
+ safe_name = f"{ref}__{display_sample}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
1089
+ out_file = save_path / f"{safe_name}.png"
1090
+ fig.savefig(out_file, dpi=300)
1091
+ plt.close(fig)
1092
+ print(f"Saved: {out_file}")
1093
+ else:
182
1094
  plt.show()
183
1095
 
184
- print(f"📊 Summary for {sample} - {ref}:")
185
- for bin_label, percent in percentages.items():
186
- print(f" - {bin_label}: {percent:.1f}%")
187
-
188
- results.append({
189
- "sample": sample,
190
- "ref": ref,
191
- "hmm_matrix": hmm_matrix,
192
- "gpc_matrix": gpc_matrix,
193
- "cpg_matrix": cpg_matrix,
194
- "row_labels": row_labels,
195
- "bin_labels": bin_labels,
196
- "bin_boundaries": bin_boundaries,
197
- "percentages": percentages
198
- })
199
-
200
- adata.uns['clustermap_results'] = results
1096
+ # record results
1097
+ rec = {
1098
+ "sample": str(sample),
1099
+ "ref": str(ref),
1100
+ "row_labels": row_labels,
1101
+ "bin_labels": bin_labels,
1102
+ "bin_boundaries": bin_boundaries,
1103
+ "percentages": percentages,
1104
+ }
1105
+ for blk in blocks:
1106
+ rec[f"{blk['name']}_matrix"] = blk["matrix"]
1107
+ rec[f"{blk['name']}_labels"] = list(map(str, blk["labels"]))
1108
+ results.append(rec)
1109
+
1110
+ print(f"Summary for {display_sample} - {ref}:")
1111
+ for bin_label, percent in percentages.items():
1112
+ print(f" - {bin_label}: {percent:.1f}%")
201
1113
 
202
1114
  except Exception as e:
203
1115
  import traceback
204
1116
  traceback.print_exc()
205
1117
  continue
1118
+
1119
+ # store once at the end (HDF5 safe)
1120
+ # matrices won't be HDF5-safe; store only metadata + maybe hit counts
1121
+ # adata.uns["clustermap_results"] = [
1122
+ # {k: v for k, v in r.items() if not k.endswith("_matrix")}
1123
+ # for r in results
1124
+ # ]
1125
+
1126
+ return results
1127
+
1128
+ def plot_hmm_layers_rolling_by_sample_ref(
1129
+ adata,
1130
+ layers: Optional[Sequence[str]] = None,
1131
+ sample_col: str = "Barcode",
1132
+ ref_col: str = "Reference_strand",
1133
+ samples: Optional[Sequence[str]] = None,
1134
+ references: Optional[Sequence[str]] = None,
1135
+ window: int = 51,
1136
+ min_periods: int = 1,
1137
+ center: bool = True,
1138
+ rows_per_page: int = 6,
1139
+ figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
1140
+ dpi: int = 160,
1141
+ output_dir: Optional[str] = None,
1142
+ save: bool = True,
1143
+ show_raw: bool = False,
1144
+ cmap: str = "tab10",
1145
+ use_var_coords: bool = True,
1146
+ ):
1147
+ """
1148
+ For each sample (row) and reference (col) plot the rolling average of the
1149
+ positional mean (mean across reads) for each layer listed.
1150
+
1151
+ Parameters
1152
+ ----------
1153
+ adata : AnnData
1154
+ Input annotated data (expects obs columns sample_col and ref_col).
1155
+ layers : list[str] | None
1156
+ Which adata.layers to plot. If None, attempts to autodetect layers whose
1157
+ matrices look like "HMM" outputs (else will error). If None and layers
1158
+ cannot be found, user must pass a list.
1159
+ sample_col, ref_col : str
1160
+ obs columns used to group rows.
1161
+ samples, references : optional lists
1162
+ explicit ordering of samples / references. If None, categories in adata.obs are used.
1163
+ window : int
1164
+ rolling window size (odd recommended). If window <= 1, no smoothing applied.
1165
+ min_periods : int
1166
+ min periods param for pd.Series.rolling.
1167
+ center : bool
1168
+ center the rolling window.
1169
+ rows_per_page : int
1170
+ paginate rows per page into multiple figures if needed.
1171
+ figsize_per_cell : (w,h)
1172
+ per-subplot size in inches.
1173
+ dpi : int
1174
+ figure dpi when saving.
1175
+ output_dir : str | None
1176
+ directory to save pages; created if necessary. If None and save=True, uses cwd.
1177
+ save : bool
1178
+ whether to save PNG files.
1179
+ show_raw : bool
1180
+ draw unsmoothed mean as faint line under smoothed curve.
1181
+ cmap : str
1182
+ matplotlib colormap for layer lines.
1183
+ use_var_coords : bool
1184
+ if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.
1185
+
1186
+ Returns
1187
+ -------
1188
+ saved_files : list[str]
1189
+ list of saved filenames (may be empty if save=False).
1190
+ """
1191
+
1192
+ # --- basic checks / defaults ---
1193
+ if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
1194
+ raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")
1195
+
1196
+ # canonicalize samples / refs
1197
+ if samples is None:
1198
+ sseries = adata.obs[sample_col]
1199
+ if not pd.api.types.is_categorical_dtype(sseries):
1200
+ sseries = sseries.astype("category")
1201
+ samples_all = list(sseries.cat.categories)
1202
+ else:
1203
+ samples_all = list(samples)
1204
+
1205
+ if references is None:
1206
+ rseries = adata.obs[ref_col]
1207
+ if not pd.api.types.is_categorical_dtype(rseries):
1208
+ rseries = rseries.astype("category")
1209
+ refs_all = list(rseries.cat.categories)
1210
+ else:
1211
+ refs_all = list(references)
1212
+
1213
+ # choose layers: if not provided, try a sensible default: all layers
1214
+ if layers is None:
1215
+ layers = list(adata.layers.keys())
1216
+ if len(layers) == 0:
1217
+ raise ValueError("No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot.")
1218
+ layers = list(layers)
1219
+
1220
+ # x coordinates (positions)
1221
+ try:
1222
+ if use_var_coords:
1223
+ x_coords = np.array([int(v) for v in adata.var_names])
1224
+ else:
1225
+ raise Exception("user disabled var coords")
1226
+ except Exception:
1227
+ # fallback to 0..n_vars-1
1228
+ x_coords = np.arange(adata.shape[1], dtype=int)
1229
+
1230
+ # make output dir
1231
+ if save:
1232
+ outdir = output_dir or os.getcwd()
1233
+ os.makedirs(outdir, exist_ok=True)
1234
+ else:
1235
+ outdir = None
1236
+
1237
+ n_samples = len(samples_all)
1238
+ n_refs = len(refs_all)
1239
+ total_pages = math.ceil(n_samples / rows_per_page)
1240
+ saved_files = []
1241
+
1242
+ # color cycle for layers
1243
+ cmap_obj = plt.get_cmap(cmap)
1244
+ n_layers = max(1, len(layers))
1245
+ colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]
1246
+
1247
+ for page in range(total_pages):
1248
+ start = page * rows_per_page
1249
+ end = min(start + rows_per_page, n_samples)
1250
+ chunk = samples_all[start:end]
1251
+ nrows = len(chunk)
1252
+ ncols = n_refs
1253
+
1254
+ fig_w = figsize_per_cell[0] * ncols
1255
+ fig_h = figsize_per_cell[1] * nrows
1256
+ fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
1257
+ figsize=(fig_w, fig_h), dpi=dpi,
1258
+ squeeze=False)
1259
+
1260
+ for r_idx, sample_name in enumerate(chunk):
1261
+ for c_idx, ref_name in enumerate(refs_all):
1262
+ ax = axes[r_idx][c_idx]
1263
+
1264
+ # subset adata
1265
+ mask = (adata.obs[sample_col].values == sample_name) & (adata.obs[ref_col].values == ref_name)
1266
+ sub = adata[mask]
1267
+ if sub.n_obs == 0:
1268
+ ax.text(0.5, 0.5, "No reads", ha="center", va="center", transform=ax.transAxes, color="gray")
1269
+ ax.set_xticks([])
1270
+ ax.set_yticks([])
1271
+ if r_idx == 0:
1272
+ ax.set_title(str(ref_name), fontsize=9)
1273
+ if c_idx == 0:
1274
+ total_reads = int((adata.obs[sample_col] == sample_name).sum())
1275
+ ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
1276
+ continue
1277
+
1278
+ # for each layer, compute positional mean across reads (ignore NaNs)
1279
+ plotted_any = False
1280
+ for li, layer in enumerate(layers):
1281
+ if layer in sub.layers:
1282
+ mat = sub.layers[layer]
1283
+ else:
1284
+ # fallback: try .X only for the first layer if layer not present
1285
+ if layer == layers[0] and getattr(sub, "X", None) is not None:
1286
+ mat = sub.X
1287
+ else:
1288
+ # layer not present for this subset
1289
+ continue
1290
+
1291
+ # convert matrix to numpy 2D
1292
+ if hasattr(mat, "toarray"):
1293
+ try:
1294
+ arr = mat.toarray()
1295
+ except Exception:
1296
+ arr = np.asarray(mat)
1297
+ else:
1298
+ arr = np.asarray(mat)
1299
+
1300
+ if arr.size == 0 or arr.shape[1] == 0:
1301
+ continue
1302
+
1303
+ # compute column-wise mean ignoring NaNs
1304
+ # if arr is boolean or int, convert to float to support NaN
1305
+ arr = arr.astype(float)
1306
+ with np.errstate(all="ignore"):
1307
+ col_mean = np.nanmean(arr, axis=0)
1308
+
1309
+ # If all-NaN, skip
1310
+ if np.all(np.isnan(col_mean)):
1311
+ continue
1312
+
1313
+ # smooth via pandas rolling (centered)
1314
+ if (window is None) or (window <= 1):
1315
+ smoothed = col_mean
1316
+ else:
1317
+ ser = pd.Series(col_mean)
1318
+ smoothed = ser.rolling(window=window, min_periods=min_periods, center=center).mean().to_numpy()
1319
+
1320
+ # x axis: x_coords (trim/pad to match length)
1321
+ L = len(col_mean)
1322
+ x = x_coords[:L]
1323
+
1324
+ # optionally plot raw faint line first
1325
+ if show_raw:
1326
+ ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)
1327
+
1328
+ ax.plot(x, smoothed[:L], label=layer, color=colors[li], linewidth=1.2, alpha=0.95, zorder=2)
1329
+ plotted_any = True
1330
+
1331
+ # labels / titles
1332
+ if r_idx == 0:
1333
+ ax.set_title(str(ref_name), fontsize=9)
1334
+ if c_idx == 0:
1335
+ total_reads = int((adata.obs[sample_col] == sample_name).sum())
1336
+ ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
1337
+ if r_idx == nrows - 1:
1338
+ ax.set_xlabel("position", fontsize=8)
1339
+
1340
+ # legend (only show in top-left plot to reduce clutter)
1341
+ if (r_idx == 0 and c_idx == 0) and plotted_any:
1342
+ ax.legend(fontsize=7, loc="upper right")
1343
+
1344
+ ax.grid(True, alpha=0.2)
1345
+
1346
+ fig.suptitle(f"Rolling mean of layer positional means (window={window}) — page {page+1}/{total_pages}", fontsize=11, y=0.995)
1347
+ fig.tight_layout(rect=[0, 0, 1, 0.97])
1348
+
1349
+ if save:
1350
+ fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
1351
+ plt.savefig(fname, bbox_inches="tight", dpi=dpi)
1352
+ saved_files.append(fname)
1353
+ else:
1354
+ plt.show()
1355
+ plt.close(fig)
1356
+
1357
+ return saved_files